# Source code for syng_bts.core

"""Core experiment functions for SyNG-BTS.

Public API
----------
- ``generate()`` — train a model and produce synthetic samples.
- ``pilot_study()`` — sweep over pilot sizes with replicated draws.
- ``transfer()`` — pre-train on source, fine-tune on target.
"""

from __future__ import annotations

import re
from dataclasses import dataclass
from pathlib import Path

import numpy as np
import pandas as pd
import torch

from .data_utils import _derive_dataname, _validate_feature_data, resolve_data
from .helper_train import VerbosityLevel, _resolve_verbose
from .helper_training import (
    TrainedModel,
    training_AEs,
    training_flows,
    training_GANs,
    training_iter,
)
from .helper_utils import (
    Gaussian_aug,
    create_labels,
    draw_pilot,
    inverse_log2,
    preprocessinglog2,
    set_all_seeds,
)
from .inference import run_generation, run_reconstruction
from .result import PilotResult, SyngResult

# =========================================================================
# Private helpers
# =========================================================================


@dataclass
class PreparedData:
    """Bundle of pre-processed inputs produced by :func:`_prepare_data`.

    ``generate()``, ``pilot_study()``, and ``transfer()`` all begin with
    the same resolve → validate → convert → label steps.  This container
    carries the outcome of those steps so each entry point receives one
    validated object rather than repeating the pipeline.
    """

    df: pd.DataFrame
    """The resolved input DataFrame, left unmodified."""

    colnames: list[str]
    """Column names taken from :attr:`df`."""

    oridata: torch.Tensor
    """Float32 tensor of the data values, log-transformed when requested."""

    n_samples: int
    """Row count of :attr:`oridata`."""

    orilabels: torch.Tensor
    """Label tensor (one-hot or single-column)."""

    oriblurlabels: torch.Tensor
    """Blurred-label tensor."""

    dataname: str
    """Short dataset name used for metadata / filenames."""

    effective_groups: np.ndarray | None
    """Group array after resolution (explicit > bundled > ``None``)."""

    group_mapping: dict[int, object] | None
    """Lookup from numeric label (0, 1) to the original group value.
    Follows the same convention as :func:`create_labels` so that labels
    round-trip consistently.  ``None`` when the data has no groups."""

    apply_log: bool
    """``True`` when ``log2(x + 1)`` was applied to :attr:`oridata`."""


@dataclass
class TrainingContext:
    """Runtime training details returned next to the trained model.

    :func:`orchestrate_training` fills this sidecar with the
    non-architectural settings that later stages need for
    reconstruction parity and metadata assembly, so private
    ``_train_*`` keys do not have to be stored in
    ``TrainedModel.arch_params``.
    """

    random_seed: int
    """Seed used for the training split."""

    val_ratio: float
    """Fraction of the data held out for validation during training."""

    batch_size: int
    """Batch size derived from ``batch_frac``."""

    num_epochs: int
    """Maximum number of epochs after configuration was resolved."""

    early_stop: bool
    """``True`` when early stopping was enabled."""

    early_stop_num: int
    """Patience value for early stopping."""

    rawdata: torch.Tensor
    """Training data after blur-label appending and augmentation."""

    rawlabels: torch.Tensor
    """Training labels after augmentation."""


def _prepare_data(
    *,
    data: pd.DataFrame | str | Path,
    name: str | None,
    groups: pd.Series | np.ndarray | None,
    apply_log: bool,
) -> PreparedData:
    """Shared data-preparation pipeline for public API functions.

    Resolves the input data, validates it, derives the dataset name,
    converts to a float32 tensor (with optional log2 transform),
    resolves groups, and creates labels.

    Parameters
    ----------
    data : DataFrame, str, or Path
        Input data — a pandas DataFrame, a path to a CSV file, or the
        name of a bundled dataset.
    name : str or None
        Short name override.  Derived automatically when ``None``.
    groups : pd.Series, np.ndarray, or None
        Optional binary group labels.
    apply_log : bool
        Apply ``log2(x + 1)`` preprocessing.

    Returns
    -------
    PreparedData

    Raises
    ------
    ValueError
        If ``apply_log`` is true and the data contains negative values.

    Warns
    -----
    UserWarning
        If ``apply_log`` is true and the data contains non-integer
        values, which suggests it may already be log-transformed.
    """
    df, bundled_groups = resolve_data(data)
    _validate_feature_data(df)

    # Convert once: the same array feeds both the sanity checks and the
    # tensor construction below.
    values = df.to_numpy()

    # Light sanity checks when user requests automatic log2 preprocessing.
    # - If negatives are present, that's invalid for log2 and we raise.
    # - If data contains non-integer values, warn the user because the
    #   input may already be transformed (double-logging risk).
    if apply_log:
        if (values < 0).any():
            raise ValueError("Input contains negative values; cannot apply log2.")
        # Treat non-integer values as suspicious (non-fatal).
        if not (values.shape[0] == 0 or np.allclose(values, np.round(values))):
            import warnings

            # stacklevel=3 skips this helper and the function calling it,
            # so the warning is reported at the user's call site.
            # NOTE(review): assumes this helper is called directly from
            # the public API functions — confirm.
            warnings.warn(
                "apply_log=True but input contains non-integer values — the data may already "
                "be log-transformed. To avoid double-logging, pass apply_log=False.",
                UserWarning,
                stacklevel=3,
            )

    dataname = _derive_dataname(data, name)

    colnames = list(df.columns)
    # Copy so the tensor never shares memory with the caller's DataFrame.
    oridata = torch.from_numpy(values.copy()).to(torch.float32)

    if apply_log:
        oridata = preprocessinglog2(oridata)

    n_samples = oridata.shape[0]

    effective_groups = _resolve_effective_groups(
        groups,
        bundled_groups,
        n_samples=n_samples,
        param_name="groups",
    )

    orilabels, oriblurlabels = create_labels(
        n_samples=n_samples,
        groups=effective_groups,
    )

    # Build group_mapping: maps numeric label → original group value.
    # Uses the same convention as create_labels (base = groups[0] → 0).
    group_mapping: dict[int, object] | None = None
    if effective_groups is not None:
        base = effective_groups[0]
        group_mapping = {0: base}
        # dict.fromkeys gives order-preserving unique values; the first
        # value that differs from the base becomes label 1.
        for value in dict.fromkeys(effective_groups):
            if value != base:
                group_mapping[1] = value
                break

    return PreparedData(
        df=df,
        colnames=colnames,
        oridata=oridata,
        n_samples=n_samples,
        orilabels=orilabels,
        oriblurlabels=oriblurlabels,
        dataname=dataname,
        effective_groups=effective_groups,
        group_mapping=group_mapping,
        apply_log=apply_log,
    )


def _parse_model_spec(model: str) -> tuple[str, int, int]:
    """Parse a model string like ``'VAE1-10'`` into components.

    The model string format is ``<NAME><recon_weight>-<kl_weight>``
    where ``recon_weight`` and ``kl_weight`` specify the ratio of
    reconstruction loss to KL divergence (e.g. ``'VAE1-10'`` means
    ``recon_weight=1, kl_weight=10``, i.e. a 1:10 ratio).

    Returns
    -------
    tuple[str, int, int]
        ``(modelname, reconstruction_term_weight, kl_weight)``
        — e.g. ``("VAE", 1, 10)``.
    """
    parts = re.split(r"([A-Z]+)(\d)([-+])(\d+)", model)
    if len(parts) > 1:
        return parts[1], int(parts[2]), int(parts[4])
    return model, 1, 1


def _build_loss_df(log_dict: dict, modelname: str) -> pd.DataFrame:
    """Convert a raw log_dict from a training helper into a tidy DataFrame.

    Parameters
    ----------
    log_dict : dict
        Raw loss series as returned by ``TrainedModel.log_dict``.
    modelname : str
        The short model name (``"AE"``, ``"VAE"``, ``"CVAE"``, ``"GAN"``,
        ``"WGAN"``, ``"WGANGP"``, ``"maf"``, ``"glow"``, ``"realnvp"``, etc.).

    Returns
    -------
    pd.DataFrame
    """
    if modelname == "AE":
        # AE logs total loss only (no KL/reconstruction split)
        return pd.DataFrame(
            {
                "train_loss": log_dict.get("train_loss_per_batch", []),
                "val_loss": log_dict.get("val_loss_per_batch", []),
            }
        )
    if modelname in ("VAE", "CVAE"):
        return pd.DataFrame(
            {
                "kl": log_dict.get(
                    "val_kl_loss_per_batch",
                    log_dict.get("train_kl_loss_per_batch", []),
                ),
                "recons": log_dict.get(
                    "val_reconstruction_loss_per_batch",
                    log_dict.get("train_reconstruction_loss_per_batch", []),
                ),
            }
        )
    if "GAN" in modelname:
        return pd.DataFrame(
            {
                "discriminator": log_dict["train_discriminator_loss_per_batch"],
                "generator": log_dict["train_generator_loss_per_batch"],
            }
        )
    # Flows — per-epoch loss
    return pd.DataFrame({"train_loss": log_dict["train_loss_per_epoch"]})


# ---------------------------------------------------------------------------
# Training orchestration
# ---------------------------------------------------------------------------


def orchestrate_training(
    *,
    rawdata: torch.Tensor,
    rawlabels: torch.Tensor,
    oriblurlabels: torch.Tensor,
    modelname: str,
    kl_weight: int = 1,
    reconstruction_term_weight: int = 1,
    batch_frac: float = 0.1,
    random_seed: int = 123,
    epoch: int | None = None,
    early_stop_patience: int | None = None,
    learning_rate: float = 0.0005,
    val_ratio: float = 0.2,
    off_aug: str | None = None,
    AE_head_num: int = 2,
    Gaussian_head_num: int = 9,
    model_state: dict | None = None,
    cap: bool = False,
    loss_fn: str = "MSE",
    use_scheduler: bool = False,
    step_size: int = 10,
    gamma: float = 0.5,
    CVAE_wide_network: bool = False,
    verbose: int = 1,
) -> tuple[TrainedModel, TrainingContext]:
    """Centralized training orchestrator.

    Resolves early-stopping configuration, appends blur-labels for
    non-CVAE multi-group data, applies offline augmentation, computes
    batch size, and dispatches to the appropriate model-family
    training wrapper.

    Model parsing is **external** — the caller passes ``modelname``,
    ``kl_weight``, and ``reconstruction_term_weight``
    (see :func:`_parse_model_spec`).

    Parameters
    ----------
    rawdata : torch.Tensor
        Input data tensor (pre-blur-label, pre-augmentation).
    rawlabels : torch.Tensor
        Label tensor.
    oriblurlabels : torch.Tensor
        Blurred-label tensor for two-group training.
    modelname : str
        Short model name (``"AE"``, ``"VAE"``, ``"GAN"``, etc.).
    kl_weight : int
        KL divergence weight (VAE/CVAE only).
    reconstruction_term_weight : int
        Reconstruction loss weight (VAE/CVAE only).
    batch_frac : float
        Batch size as a fraction of sample count.
    random_seed : int
        Random seed for reproducibility.
    epoch : int or None
        Fixed epoch count, or ``None`` for early stopping.
    early_stop_patience : int or None
        Early stopping patience, or ``None``.
    learning_rate : float
        Optimizer learning rate.
    val_ratio : float
        Validation split ratio (AE family only).
    off_aug : str or None
        Offline augmentation: ``"AE_head"``, ``"Gaussian_head"``,
        or ``None``.  Any other value results in no augmentation.
    AE_head_num : int
        Fold multiplier for AE-head augmentation.
    Gaussian_head_num : int
        Fold multiplier for Gaussian-head augmentation.
    model_state : dict or None
        Pre-trained model state for transfer learning.
    cap : bool
        Cap generated values (AE family training).
    loss_fn : str
        Loss function name (AE family).
    use_scheduler : bool
        Enable learning-rate scheduler (AE family).
    step_size : int
        Scheduler step size.
    gamma : float
        Scheduler gamma.
    CVAE_wide_network : bool
        Use wider encoder/decoder for CVAE (512→256→128→64
        instead of 256→128→64).  Ignored for non-CVAE models.
    verbose : int
        Verbosity level.

    Returns
    -------
    tuple[TrainedModel, TrainingContext]
        The trained model and a sidecar context with runtime info
        needed for inference and metadata assembly.

    Raises
    ------
    ValueError
        If ``modelname`` does not belong to a supported model family,
        or if ``epoch`` / ``early_stop_patience`` are invalid
        (see :func:`_resolve_early_stopping_config`).
    """
    # --- 1. Resolve early-stopping config --------------------------------
    num_epochs, early_stop, early_stop_num = _resolve_early_stopping_config(
        epoch=epoch,
        early_stop_patience=early_stop_patience,
        default_max_epochs=1000,
        default_patience=30,
    )

    # --- 2. Append blur-labels for non-CVAE two-group data ---------------
    if (modelname != "CVAE") and (torch.unique(rawlabels).shape[0] > 1):
        rawdata = torch.cat((rawdata, oriblurlabels), dim=1)

    # --- 3. Offline augmentation -----------------------------------------
    if off_aug == "Gaussian_head":
        rawdata, rawlabels = Gaussian_aug(
            rawdata, rawlabels, multiplier=[Gaussian_head_num]
        )
    elif off_aug == "AE_head":
        # TODO Change hardcoded training config for AE head augmentation
        # to be more flexible
        feed_data, feed_labels = training_iter(
            iter_times=AE_head_num,
            rawdata=rawdata,
            rawlabels=rawlabels,
            random_seed=random_seed,
            modelname="AE",
            num_epochs=1000,
            # Floor at 1 (as in step 4) so fewer than five samples do not
            # round down to a batch size of zero.
            batch_size=max(1, round(rawdata.shape[0] * 0.1)),
            learning_rate=0.0005,
            early_stop=False,
            early_stop_num=30,
            kl_weight=1,
            loss_fn="MSE",
            replace=True,
            verbose=verbose,
        )
        rawdata = feed_data
        rawlabels = feed_labels

    # --- 4. Compute batch size -------------------------------------------
    # Computed after augmentation so it reflects the augmented sample count.
    batch_size = max(1, round(rawdata.shape[0] * batch_frac))

    # --- 5. Dispatch to model-family training ----------------------------
    # "GAN" is tested before "AE"; both are substring checks on modelname.
    if "GAN" in modelname:
        trained = training_GANs(
            rawdata=rawdata,
            rawlabels=rawlabels,
            batch_size=batch_size,
            random_seed=random_seed,
            modelname=modelname,
            num_epochs=num_epochs,
            learning_rate=learning_rate,
            early_stop=early_stop,
            early_stop_num=early_stop_num,
            model_state=model_state,
            verbose=verbose,
        )
    elif "AE" in modelname:
        trained = training_AEs(
            rawdata=rawdata,
            rawlabels=rawlabels,
            batch_size=batch_size,
            random_seed=random_seed,
            modelname=modelname,
            num_epochs=num_epochs,
            learning_rate=learning_rate,
            val_ratio=val_ratio,
            kl_weight=kl_weight,
            reconstruction_term_weight=reconstruction_term_weight,
            early_stop=early_stop,
            early_stop_num=early_stop_num,
            model_state=model_state,
            cap=cap,
            loss_fn=loss_fn,
            use_scheduler=use_scheduler,
            step_size=step_size,
            gamma=gamma,
            wide_network=CVAE_wide_network,
            verbose=verbose,
        )
    elif modelname in ("maf", "realnvp", "glow", "maf-split", "maf-split-glow"):
        # Flows take the batch fraction directly rather than batch_size.
        trained = training_flows(
            rawdata=rawdata,
            batch_frac=batch_frac,
            valid_batch_frac=0.3,
            random_seed=random_seed,
            modelname=modelname,
            num_blocks=5,
            num_epochs=num_epochs,
            learning_rate=learning_rate,
            num_hidden=226,
            early_stop=early_stop,
            early_stop_num=early_stop_num,
            model_state=model_state,
            verbose=verbose,
        )
    else:
        raise ValueError(f"Unsupported model: {modelname!r}")

    ctx = TrainingContext(
        random_seed=random_seed,
        val_ratio=val_ratio,
        batch_size=batch_size,
        num_epochs=num_epochs,
        early_stop=early_stop,
        early_stop_num=early_stop_num,
        rawdata=rawdata,
        rawlabels=rawlabels,
    )

    return trained, ctx


def _infer_from_trained(
    trained: TrainedModel,
    *,
    new_size: int | list[int],
    ctx: TrainingContext,
    cap: bool = False,
) -> tuple[torch.Tensor, torch.Tensor | None]:
    """Generate samples and, for the AE family, reconstruct the train split.

    Post-training inference is kept apart from the training
    orchestrators: generation goes through
    :func:`inference.run_generation` and reconstruction through
    :func:`inference.run_reconstruction`.

    Parameters
    ----------
    trained : TrainedModel
        Output from the training dispatch layer.
    new_size : int or list[int]
        Number of synthetic samples to generate.
    ctx : TrainingContext
        Sidecar context from :func:`orchestrate_training` supplying the
        training data, batch size, seed, and validation ratio used to
        rebuild the training split for reconstruction.
    cap : bool
        When true, the per-column maximum and standard deviation of
        ``ctx.rawdata`` are handed to the generator as capping values.

    Returns
    -------
    tuple[torch.Tensor, torch.Tensor | None]
        ``(generated_data, reconstructed_data)``; the second item is
        ``None`` unless the model family is ``"ae"``.
    """
    from torch.utils.data import DataLoader, TensorDataset, random_split

    model_family = trained.arch_params["family"]

    # Capping statistics are only computed on request.
    col_max = None
    col_sd = None
    if cap:
        col_max = torch.max(ctx.rawdata, dim=0)[0]
        col_sd = torch.std(ctx.rawdata, dim=0, unbiased=True)

    generated = run_generation(
        trained,
        num_samples=new_size,
        col_max=col_max,
        col_sd=col_sd,
    )

    # Reconstruction exists for the AE family only.
    if model_family != "ae":
        return generated, None

    num_features = trained.arch_params["num_features"]
    full_dataset = TensorDataset(ctx.rawdata, ctx.rawlabels)

    # Mirror the legacy reconstruction path as closely as possible:
    # training_AEs reconstructs over the train-split DataLoader
    # (shuffle=True, drop_last=True).  Re-seed before splitting so the
    # split is drawn from the same seed as during training.
    set_all_seeds(ctx.random_seed)
    n_val = int(len(full_dataset) * ctx.val_ratio)
    n_train = len(full_dataset) - n_val
    train_split, _unused_val_split = random_split(full_dataset, [n_train, n_val])
    train_loader = DataLoader(
        train_split,
        batch_size=ctx.batch_size,
        shuffle=True,
        drop_last=True,
    )

    reconstructed, _ = run_reconstruction(
        trained,
        data_loader=train_loader,
        n_features=num_features,
    )
    return generated, reconstructed


def _compute_new_size(
    orilabels: torch.Tensor,
    n_samples: int,
    new_size: int | list[int],
) -> int | list[int]:
    """Compute the generation size, honouring group balance.

    - If *new_size* is a ``list``, it is returned as-is after validation
      (must match the number of groups).
    - If *new_size* is an ``int`` and the data has groups, returns
      ``[n_group_0, n_group_1]`` preserving the original group ratio and
      summing to *new_size*.
    - If *new_size* is an ``int`` and the data has no groups, returns
      *new_size* unchanged.

    Notes
    -----
    For grouped data, ``group_0`` is the base group used by
    :func:`create_labels` (the first group value encountered in the input),
    and ``group_1`` is the other group.
    """
    if isinstance(new_size, bool):
        raise TypeError("new_size must be an int or list[int], got bool.")

    n_groups = len(torch.unique(orilabels))
    has_groups = n_groups > 1

    if isinstance(new_size, list):
        if not has_groups:
            raise ValueError(
                "new_size as a list requires grouped data, but the dataset "
                "has only a single group."
            )
        if len(new_size) != n_groups:
            raise ValueError(
                f"new_size list length ({len(new_size)}) must match the "
                f"number of groups ({n_groups})."
            )
        if any(isinstance(v, bool) or not isinstance(v, int) for v in new_size):
            raise TypeError("new_size list values must be integers.")
        if any(v < 0 for v in new_size):
            raise ValueError("new_size list values must be non-negative.")
        return new_size

    if not isinstance(new_size, int):
        raise TypeError(
            f"new_size must be an int or list[int], got {type(new_size).__name__}."
        )
    if new_size < 0:
        raise ValueError("new_size must be non-negative.")

    if has_groups:
        n0 = int((orilabels == 0).sum())
        n1 = int((orilabels == 1).sum())
        total = n0 + n1
        new_n0 = round(new_size * n0 / total)
        new_n1 = new_size - new_n0

        if n0 != n1:
            import warnings

            warnings.warn(
                "Grouped generation with integer new_size preserves the original "
                "class ratio and may remain imbalanced. "
                f"Input groups are ({n0}, {n1}), so new_size={new_size} "
                f"becomes [{new_n0}, {new_n1}]. "
                "Pass an explicit list (for example [500, 500]) to control "
                "per-group counts.",
                UserWarning,
                stacklevel=3,
            )
        return [new_n0, new_n1]

    return new_size


def _labels_to_groups(
    labels: pd.Series,
    group_mapping: dict[int, object],
    *,
    modelname: str,
) -> pd.Series:
    """Map numeric labels back to original group values.

    Parameters
    ----------
    labels : pd.Series
        Raw label column stripped from generated or reconstructed data.
        For CVAE these are integer class indices (0, 1, …).
        For non-CVAE models they are blur-labels (0–1 for group 0,
        9–10 for group 1).
    group_mapping : dict[int, object]
        ``{0: base_group_value, 1: other_group_value}`` produced by
        :func:`_prepare_data`.
    modelname : str
        Short model name — determines the conversion strategy.

    Returns
    -------
    pd.Series
        Group values with ``name="group"``.
    """
    if modelname == "CVAE":
        int_labels = labels.round().astype(int)
    else:
        # Blur-label ranges: [0, 1] → group 0, [9, 10] → group 1.
        # Threshold at 5 cleanly separates the two ranges.
        int_labels = (labels >= 5).astype(int)
    return int_labels.map(group_mapping).rename("group")


def _split_features_and_labels(
    tensor: torch.Tensor,
    colnames: list[str],
) -> tuple[pd.DataFrame, pd.Series | None]:
    """Convert a model output tensor into a feature frame plus stripped labels.

    When the tensor has more columns than ``colnames``, its last column
    is removed and returned as a Series named ``"label"``; otherwise the
    label item is ``None``.
    """
    # .cpu() is a no-op for CPU tensors and makes the conversion work for
    # tensors that live on another device.
    values = tensor.detach().cpu().numpy()

    labels = None
    if values.shape[1] > len(colnames):
        labels = pd.Series(values[:, -1], name="label")
        values = values[:, :-1]

    return pd.DataFrame(values, columns=list(colnames)), labels


def _assemble_result(
    *,
    gen_data: torch.Tensor,
    recon_data: torch.Tensor | None,
    trained: TrainedModel,
    colnames: list[str],
    modelname: str,
    model: str,
    dataname: str,
    n_samples: int,
    num_epochs: int,
    random_seed: int,
    kl_weight: int,
    reconstruction_term_weight: int,
    early_stop: bool,
    early_stop_num: int,
    apply_log: bool,
    original_data: pd.DataFrame,
    original_groups: pd.Series | None = None,
    group_mapping: dict[int, object] | None = None,
    extra_metadata: dict | None = None,
) -> SyngResult:
    """Assemble a :class:`SyngResult` from training/inference outputs.

    Centralises label stripping (for CVAE conditioning labels and
    non-CVAE blur-labels), DataFrame construction, inverse log
    transform, column-order validation, group derivation, loss
    assembly, metadata assembly, and ``SyngResult`` construction.

    Parameters
    ----------
    gen_data : torch.Tensor
        Raw generated data tensor.
    recon_data : torch.Tensor or None
        Raw reconstructed data tensor (AE-family only).
    trained : TrainedModel
        Training output from the dispatch layer.
    colnames : list[str]
        Original column names.
    modelname : str
        Short model name (e.g. ``"VAE"``, ``"CVAE"``).
    model : str
        Full model specification string (e.g. ``"VAE1-10"``).
    dataname : str
        Dataset name for metadata.
    n_samples : int
        Number of original samples (for ``input_shape``).
    num_epochs : int
        Maximum epoch count configured.
    random_seed : int
        Random seed used.
    kl_weight : int
        KL weight used.
    reconstruction_term_weight : int
        Reconstruction loss weight used.
    early_stop : bool
        Whether early stopping was enabled.
    early_stop_num : int
        Early stopping patience value.
    apply_log : bool
        Whether ``log2(x + 1)`` preprocessing was applied.
    original_data : pd.DataFrame
        Original data subset to attach to the result.
    original_groups : pd.Series or None
        Group labels for the original data.
    group_mapping : dict[int, object] or None
        Maps numeric labels (0, 1) back to original group values.
        When provided, generated and reconstructed groups are derived
        from the stripped label columns.
    extra_metadata : dict or None
        Additional metadata entries (e.g. ``pilot_size``, ``draw``,
        ``pilot_indices``).

    Returns
    -------
    SyngResult

    Raises
    ------
    RuntimeError
        If the generated or reconstructed frame's columns do not match
        ``colnames`` (internal consistency check).
    ValueError
        If ``extra_metadata`` contains a key already used by the
        standard metadata.
    """
    # Strip the appended label/blur-label column when the tensor has more
    # columns than the original feature set.  This handles both CVAE
    # conditioning labels and non-CVAE blur-labels appended during
    # training for two-group data.
    gen_df, gen_labels = _split_features_and_labels(gen_data, colnames)

    recon_df = None
    recon_labels = None
    if recon_data is not None:
        recon_df, recon_labels = _split_features_and_labels(recon_data, colnames)

    # Inverse log transform to count scale
    if apply_log:
        gen_df = inverse_log2(gen_df)
        if recon_df is not None:
            recon_df = inverse_log2(recon_df)

    # Validate column order consistency
    if gen_df.columns.tolist() != colnames:
        raise RuntimeError(
            "Column order mismatch in generated_data. "
            "This is an internal error; please report it."
        )
    if recon_df is not None and recon_df.columns.tolist() != colnames:
        raise RuntimeError(
            "Column order mismatch in reconstructed_data. "
            "This is an internal error; please report it."
        )

    # --- Derive group Series from labels + mapping -----------------------
    generated_groups: pd.Series | None = None
    reconstructed_groups: pd.Series | None = None
    if group_mapping is not None:
        if gen_labels is not None:
            generated_groups = _labels_to_groups(
                gen_labels, group_mapping, modelname=modelname
            )
        if recon_labels is not None:
            reconstructed_groups = _labels_to_groups(
                recon_labels, group_mapping, modelname=modelname
            )

    loss_df = _build_loss_df(trained.log_dict, modelname)

    metadata: dict = {
        "model": model,
        "modelname": modelname,
        "dataname": dataname,
        "num_epochs": num_epochs,
        "epochs_trained": trained.epochs_trained,
        "seed": random_seed,
        "kl_weight": kl_weight,
        "reconstruction_term_weight": reconstruction_term_weight,
        "input_shape": (n_samples, len(colnames)),
        "early_stop": early_stop,
        "early_stop_patience": early_stop_num,
        "generated_labels": gen_labels,
        "reconstructed_labels": recon_labels,
        "apply_log": apply_log,
        "group_mapping": group_mapping,
        # Underscore-prefixed architecture keys are treated as private and
        # kept out of the public metadata.
        "arch_params": {
            k: v for k, v in trained.arch_params.items() if not k.startswith("_")
        },
    }

    if extra_metadata is not None:
        # Refuse to silently overwrite any of the standard keys above.
        overlapping = set(extra_metadata).intersection(metadata)
        if overlapping:
            overlap_str = ", ".join(sorted(overlapping))
            raise ValueError(
                f"extra_metadata contains reserved metadata keys: {overlap_str}"
            )
        metadata.update(extra_metadata)

    return SyngResult(
        generated_data=gen_df,
        loss=loss_df,
        reconstructed_data=recon_df,
        original_data=original_data,
        model_state=trained.model_state,
        metadata=metadata,
        original_groups=original_groups,
        generated_groups=generated_groups,
        reconstructed_groups=reconstructed_groups,
    )


def _resolve_early_stopping_config(
    epoch: int | None,
    early_stop_patience: int | None,
    default_max_epochs: int = 1000,
    default_patience: int = 30,
) -> tuple[int, bool, int]:
    """Resolve epoch count and early stopping configuration.

    Parameters
    ----------
    epoch : int or None
        User-specified fixed epoch count.
    early_stop_patience : int or None
        User-specified early stopping patience.
    default_max_epochs : int
        Default maximum epochs when early stopping is enabled.
    default_patience : int
        Default patience when early stopping is enabled but patience not specified.

    Returns
    -------
    tuple[int, bool, int]
        ``(num_epochs, early_stop, early_stop_num)`` where:
        - ``num_epochs`` is the maximum epoch count to run
        - ``early_stop`` is whether early stopping is enabled
        - ``early_stop_num`` is the patience value to use
    """
    if epoch is not None:
        if isinstance(epoch, bool) or not isinstance(epoch, int) or epoch <= 0:
            raise ValueError(f"epoch must be a positive integer or None, got {epoch!r}")

    if early_stop_patience is not None:
        if (
            isinstance(early_stop_patience, bool)
            or not isinstance(early_stop_patience, int)
            or early_stop_patience <= 0
        ):
            raise ValueError(
                "early_stop_patience must be a positive integer or None, "
                f"got {early_stop_patience!r}"
            )

    if epoch is not None and early_stop_patience is not None:
        # Both provided: run up to `epoch` epochs with early stopping
        return epoch, True, early_stop_patience
    elif epoch is not None:
        # Only epoch: run exactly that many epochs, no early stopping
        return epoch, False, default_patience
    elif early_stop_patience is not None:
        # Only patience: early stop with default max epochs
        return default_max_epochs, True, early_stop_patience
    else:
        # Neither: default early stopping with default patience
        return default_max_epochs, True, default_patience


def _validate_n_draws(n_draws: int, *, param_name: str = "n_draws") -> int:
    """Validate replicated draw count parameters.

    Parameters
    ----------
    n_draws : int
        Number of replicated random draws.
    param_name : str
        Parameter name used in error messages.

    Returns
    -------
    int
        Validated draw count.
    """
    if isinstance(n_draws, bool) or not isinstance(n_draws, int) or n_draws <= 0:
        raise ValueError(f"'{param_name}' must be a positive integer, got {n_draws!r}")
    return n_draws


def _coerce_groups_array(
    groups: pd.Series | np.ndarray,
    *,
    n_samples: int,
    param_name: str,
) -> np.ndarray:
    """Normalize user/bundled groups to a 1D numpy array."""
    if isinstance(groups, pd.Series):
        arr = groups.to_numpy()
    elif isinstance(groups, np.ndarray):
        arr = groups
    else:
        raise TypeError(
            f"'{param_name}' must be a pandas Series, numpy ndarray, or None, "
            f"got {type(groups).__name__}"
        )

    if arr.ndim != 1:
        raise ValueError(
            f"'{param_name}' must be one-dimensional, got shape {arr.shape}"
        )

    if arr.shape[0] != n_samples:
        raise ValueError(
            f"'{param_name}' length ({arr.shape[0]}) must match number of "
            f"samples ({n_samples})"
        )

    return arr


def _validate_binary_groups(groups: np.ndarray, *, param_name: str) -> None:
    """Enforce v3.1 binary-group scope (<= 2 distinct classes)."""
    unique = pd.Series(groups).dropna().unique()
    if len(unique) > 2:
        raise ValueError(
            f"'{param_name}' has {len(unique)} classes, but SyNG-BTS supports only "
            "binary groups (at most 2 classes)."
        )


def _resolve_effective_groups(
    explicit_groups: pd.Series | np.ndarray | None,
    bundled_groups: pd.Series | None,
    *,
    n_samples: int,
    param_name: str,
) -> np.ndarray | None:
    """Resolve groups using explicit precedence: argument > bundled."""
    candidate = explicit_groups if explicit_groups is not None else bundled_groups
    if candidate is None:
        return None

    group_array = _coerce_groups_array(
        candidate,
        n_samples=n_samples,
        param_name=param_name,
    )
    _validate_binary_groups(group_array, param_name=param_name)
    return group_array


# =========================================================================
# Public API
# =========================================================================


def generate(
    data: pd.DataFrame | str | Path,
    *,
    name: str | None = None,
    groups: pd.Series | np.ndarray | None = None,
    new_size: int | list[int] = 500,
    model: str = "VAE1-10",
    apply_log: bool = True,
    batch_frac: float = 0.1,
    learning_rate: float = 0.0005,
    epoch: int | None = None,
    val_ratio: float = 0.2,
    early_stop_patience: int | None = None,
    off_aug: str | None = None,
    AE_head_num: int = 2,
    Gaussian_head_num: int = 9,
    use_scheduler: bool = False,
    step_size: int = 10,
    gamma: float = 0.5,
    cap: bool = False,
    random_seed: int = 123,
    CVAE_wide_network: bool = False,
    output_dir: str | Path | None = None,
    verbose: int | str = "minimal",
) -> SyngResult:
    """Train a deep generative model and generate synthetic data.

    This is the primary entry point for training a single model and
    generating synthetic samples.  It replaces the legacy
    ``ApplyExperiment`` function.

    Parameters
    ----------
    data : DataFrame, str, or Path
        Input data — a pandas DataFrame, a path to a CSV file, or the
        name of a bundled dataset (e.g. ``"SKCMPositive_4"``).
    name : str or None
        Short name for output filenames.  Derived automatically when
        ``None``.
    groups : pd.Series, np.ndarray, or None
        Optional binary group labels.  When provided, these labels take
        precedence over bundled dataset groups.
    new_size : int or list[int]
        Generation size.

        - If ``int``: generate exactly ``new_size`` samples.  For grouped
          data, counts are split by the input group ratio and rounded to
          integers.
        - If ``list[int]``: explicit grouped counts
          ``[n_group_0, n_group_1]``.

        For grouped data, ``group_0`` is the base group used by
        :func:`create_labels` (first encountered group value) and
        ``group_1`` is the other group.
    model : str
        Model specification, e.g. ``"VAE1-10"`` (parsed into model type
        and kl_weight).
    apply_log : bool
        Apply ``log2(x + 1)`` preprocessing.
    batch_frac : float
        Batch size as a fraction of sample count.
    learning_rate : float
        Optimizer learning rate.
    epoch : int or None
        Fixed epoch count, or ``None`` for early stopping.  The
        interaction between *epoch* and *early_stop_patience*:

        =========== ======================= ==================================================
        ``epoch``   ``early_stop_patience`` Behaviour
        =========== ======================= ==================================================
        ``None``    ``None``                Early stopping ON, patience=30, max 1 000 epochs
        ``None``    ``30``                  Early stopping ON, patience=30, max 1 000 epochs
        ``500``     ``None``                Early stopping OFF, run exactly 500 epochs
        ``500``     ``30``                  Early stopping ON, patience=30, max 500 epochs
        =========== ======================= ==================================================
    val_ratio : float
        Validation split ratio (AE family only).
    early_stop_patience : int or None
        Stop if loss does not improve for this many epochs.  When
        ``None`` and ``epoch`` is also ``None``, defaults to ``30``.
    off_aug : str or None
        Offline augmentation: ``"AE_head"``, ``"Gaussian_head"``, or
        ``None``.
    AE_head_num : int
        Fold multiplier for AE-head augmentation.
    Gaussian_head_num : int
        Fold multiplier for Gaussian-head augmentation.
    use_scheduler : bool
        Enable learning-rate scheduler (AE family).
    step_size : int
        Scheduler step size.
    gamma : float
        Scheduler gamma.
    cap : bool
        Cap generated values to observed range.
    random_seed : int
        Random seed for reproducibility.
    CVAE_wide_network : bool
        Use wider encoder/decoder for CVAE (512→256→128→64 instead of
        256→128→64).  Suitable for high-dimensional data like RNA.
        Ignored for non-CVAE models (default: ``False``).
    output_dir : str, Path, or None
        If set, automatically save results to this directory.
    verbose : int or str
        Verbosity level for training output.

        - ``"silent"`` or ``0`` — no output during training.
        - ``"minimal"`` or ``1`` (default) — print only training
          summaries and early-stopping messages.
        - ``"detailed"`` or ``2`` — print per-epoch progress (epoch
          number, loss values, elapsed time, learning rate).

    Returns
    -------
    SyngResult
        Rich result object containing generated data, loss log,
        reconstructed data (AE/VAE/CVAE), model state, and metadata.

    Notes
    -----
    The training loss function is fixed to ``"MSE"`` in this entry point;
    it is not exposed as a parameter.
    """
    # --- 0. Resolve verbose level ----------------------------------------
    verbose_level = _resolve_verbose(verbose)

    # --- 1. Prepare data (resolve, validate, convert, label) -------------
    prep = _prepare_data(data=data, name=name, groups=groups, apply_log=apply_log)

    # --- 2. Parse model spec ---------------------------------------------
    modelname, recon_weight, kl_weight = _parse_model_spec(model)

    # --- 3. Compute new_size (group-balanced if needed) ------------------
    effective_new_size = _compute_new_size(prep.orilabels, prep.n_samples, new_size)

    # --- 4. Train (orchestrate: early-stop, blur-label, aug, dispatch) ---
    # ``ctx`` carries the resolved training settings (num_epochs, early_stop,
    # early_stop_num) that are recorded in the result below.
    trained, ctx = orchestrate_training(
        rawdata=prep.oridata,
        rawlabels=prep.orilabels,
        oriblurlabels=prep.oriblurlabels,
        modelname=modelname,
        kl_weight=kl_weight,
        reconstruction_term_weight=recon_weight,
        batch_frac=batch_frac,
        random_seed=random_seed,
        epoch=epoch,
        early_stop_patience=early_stop_patience,
        learning_rate=learning_rate,
        val_ratio=val_ratio,
        off_aug=off_aug,
        AE_head_num=AE_head_num,
        Gaussian_head_num=Gaussian_head_num,
        cap=cap,
        loss_fn="MSE",
        use_scheduler=use_scheduler,
        step_size=step_size,
        gamma=gamma,
        CVAE_wide_network=CVAE_wide_network,
        verbose=verbose_level,
    )

    # --- 5. Infer --------------------------------------------------------
    # ``cap`` is forwarded to inference as well as to training.
    gen_data, recon_data = _infer_from_trained(
        trained,
        new_size=effective_new_size,
        ctx=ctx,
        cap=cap,
    )

    # --- 6. Assemble SyngResult ------------------------------------------
    # Groups are stored as a Series named "group" with a fresh 0..n-1 index.
    original_groups: pd.Series | None = None
    if prep.effective_groups is not None:
        original_groups = pd.Series(prep.effective_groups, name="group").reset_index(
            drop=True
        )

    result = _assemble_result(
        gen_data=gen_data,
        recon_data=recon_data,
        trained=trained,
        colnames=prep.colnames,
        modelname=modelname,
        model=model,
        dataname=prep.dataname,
        n_samples=prep.n_samples,
        num_epochs=ctx.num_epochs,
        random_seed=random_seed,
        kl_weight=kl_weight,
        reconstruction_term_weight=recon_weight,
        early_stop=ctx.early_stop,
        early_stop_num=ctx.early_stop_num,
        apply_log=prep.apply_log,
        # A copy is handed over so the result does not alias ``prep.df``.
        original_data=prep.df.copy(),
        original_groups=original_groups,
        group_mapping=prep.group_mapping,
    )

    if output_dir is not None:
        result.save(output_dir)

    return result


def pilot_study(
    data: pd.DataFrame | str | Path,
    pilot_size: list[int],
    *,
    name: str | None = None,
    groups: pd.Series | np.ndarray | None = None,
    n_draws: int = 5,
    model: str = "VAE1-10",
    apply_log: bool = True,
    batch_frac: float = 0.1,
    learning_rate: float = 0.0005,
    epoch: int | None = None,
    early_stop_patience: int | None = None,
    off_aug: str | None = None,
    AE_head_num: int = 2,
    Gaussian_head_num: int = 9,
    random_seed: int = 123,
    CVAE_wide_network: bool = False,
    output_dir: str | Path | None = None,
    verbose: int | str = "minimal",
) -> PilotResult:
    """Sweep over pilot sizes with replicated random draws.

    For each pilot size, *n_draws* random sub-samples are drawn from the
    original data.  A model is trained on each sub-sample and synthetic
    data equal to *n_draws* times the sub-sample size is generated.

    This replaces the legacy ``PilotExperiment`` function.

    Parameters
    ----------
    data : DataFrame, str, or Path
        Input data.
    pilot_size : list[int]
        List of pilot sizes to evaluate.
    name : str or None
        Short name for output filenames.
    groups : pd.Series, np.ndarray, or None
        Optional binary group labels.  When provided, these labels take
        precedence over bundled dataset groups.
    n_draws : int
        Number of replicated random draws per pilot size (default: 5).
        Must be a positive integer.
    model : str
        Model specification (e.g. ``"VAE1-10"``).
    apply_log : bool
        Apply ``log2(x + 1)`` preprocessing.
    batch_frac : float
        Batch size as a fraction of sample count.
    learning_rate : float
        Optimizer learning rate.
    epoch : int or None
        Fixed epoch count or ``None`` for early stopping.  See
        :func:`generate` for the full interaction table.
    early_stop_patience : int or None
        Stop if loss does not improve for this many epochs.  When
        ``None`` and ``epoch`` is also ``None``, defaults to ``30``.
        See :func:`generate` for the full interaction table.
    off_aug : str or None
        Offline augmentation mode.
    AE_head_num : int
        Fold multiplier for AE-head augmentation.
    Gaussian_head_num : int
        Fold multiplier for Gaussian-head augmentation.
    random_seed : int
        Base random seed for reproducibility.
    CVAE_wide_network : bool
        Use wider encoder/decoder for CVAE — see :func:`generate`.
    output_dir : str, Path, or None
        If set, automatically save results to this directory.
    verbose : int or str
        Verbosity level — see :func:`generate` for details.

    Returns
    -------
    PilotResult
        Wrapper containing one ``SyngResult`` per (pilot_size, draw).

    Raises
    ------
    ValueError
        If *n_draws* is not a positive integer (checked by
        ``_validate_n_draws`` before any data is loaded).

    Notes
    -----
    Runs are keyed by ``(pilot_size, draw)`` where ``draw`` is 1-based.
    The draw index is what is passed to ``draw_pilot`` as ``seednum``;
    *random_seed* is passed unchanged to every training run.
    """
    n_draws = _validate_n_draws(n_draws, param_name="n_draws")

    # --- 0. Resolve verbose level ----------------------------------------
    verbose_level = _resolve_verbose(verbose)

    # --- 1. Prepare data (resolve, validate, convert, label) -------------
    prep = _prepare_data(data=data, name=name, groups=groups, apply_log=apply_log)

    # --- 2. Parse model spec ---------------------------------------------
    modelname, recon_weight, kl_weight = _parse_model_spec(model)

    # --- 3. Pilot loop ---------------------------------------------------
    runs: dict[tuple[int, int], SyngResult] = {}
    last_ctx: TrainingContext | None = None

    # Calculate total number of runs for progress logging
    total_runs = len(pilot_size) * n_draws
    current_run = 0

    for n_pilot in pilot_size:
        for rand_pilot in range(1, n_draws + 1):
            current_run += 1

            # Log progress before training (if verbosity >= MINIMAL)
            if verbose_level >= VerbosityLevel.MINIMAL:
                print(
                    f"[Pilot size {n_pilot}] Draw {rand_pilot}/{n_draws} "
                    f"(training no. {current_run}/{total_runs})"
                )

            # Draw pilot sub-sample.  The 1-based draw index is used as the
            # sub-sampling seed; it is not derived from ``random_seed``.
            rawdata, rawlabels, rawblurlabels, pilot_indices = draw_pilot(
                dataset=prep.oridata,
                labels=prep.orilabels,
                blurlabels=prep.oriblurlabels,
                n_pilot=n_pilot,
                seednum=rand_pilot,
            )

            # new_size for this pilot (group-balanced if needed)
            effective_new_size = _compute_new_size(
                rawlabels,
                int(rawdata.shape[0]),
                n_draws * n_pilot,
            )

            # Train (orchestrate: early-stop, blur-label, aug, dispatch)
            trained, ctx = orchestrate_training(
                rawdata=rawdata,
                rawlabels=rawlabels,
                oriblurlabels=rawblurlabels,
                modelname=modelname,
                kl_weight=kl_weight,
                reconstruction_term_weight=recon_weight,
                batch_frac=batch_frac,
                random_seed=random_seed,
                epoch=epoch,
                early_stop_patience=early_stop_patience,
                learning_rate=learning_rate,
                off_aug=off_aug,
                AE_head_num=AE_head_num,
                Gaussian_head_num=Gaussian_head_num,
                CVAE_wide_network=CVAE_wide_network,
                verbose=verbose_level,
            )
            last_ctx = ctx

            gen_data, recon_data = _infer_from_trained(
                trained,
                new_size=effective_new_size,
                ctx=ctx,
            )

            # -- Assemble SyngResult for this run -------------------------
            # Rows of the original frame (and groups) selected for this pilot.
            pilot_original = prep.df.iloc[pilot_indices.numpy()].copy()
            pilot_groups: pd.Series | None = None
            if prep.effective_groups is not None:
                pilot_groups = pd.Series(
                    prep.effective_groups[pilot_indices.numpy()], name="group"
                ).reset_index(drop=True)

            runs[(n_pilot, rand_pilot)] = _assemble_result(
                gen_data=gen_data,
                recon_data=recon_data,
                trained=trained,
                colnames=prep.colnames,
                modelname=modelname,
                model=model,
                dataname=prep.dataname,
                # NOTE(review): this is the full dataset's sample count, not
                # the pilot size — confirm that is the intended metadata.
                n_samples=prep.n_samples,
                num_epochs=ctx.num_epochs,
                random_seed=random_seed,
                kl_weight=kl_weight,
                reconstruction_term_weight=recon_weight,
                early_stop=ctx.early_stop,
                early_stop_num=ctx.early_stop_num,
                apply_log=prep.apply_log,
                original_data=pilot_original,
                original_groups=pilot_groups,
                group_mapping=prep.group_mapping,
                extra_metadata={
                    "pilot_size": n_pilot,
                    "draw": rand_pilot,
                    "pilot_indices": pilot_indices.tolist(),
                },
            )

            # Newline after each run for readability
            if verbose_level >= VerbosityLevel.MINIMAL:
                print()

    # Resolve num_epochs for PilotResult metadata — use last training
    # context if available, otherwise resolve defaults directly.
    # ``last_ctx`` is still None only when no run executed (empty pilot_size).
    if last_ctx is not None:
        resolved_num_epochs = last_ctx.num_epochs
    else:
        resolved_num_epochs, _, _ = _resolve_early_stopping_config(
            epoch=epoch,
            early_stop_patience=early_stop_patience,
        )

    # --- 4. Assemble PilotResult -----------------------------------------
    pilot_result = PilotResult(
        runs=runs,
        original_data=prep.df.copy(),
        metadata={
            "model": model,
            "modelname": modelname,
            "dataname": prep.dataname,
            "pilot_sizes": pilot_size,
            "num_epochs": resolved_num_epochs,
            "seed": random_seed,
        },
    )

    if output_dir is not None:
        pilot_result.save(output_dir)

    return pilot_result


def transfer(
    source_data: pd.DataFrame | str | Path,
    target_data: pd.DataFrame | str | Path,
    *,
    source_name: str | None = None,
    target_name: str | None = None,
    source_groups: pd.Series | np.ndarray | None = None,
    target_groups: pd.Series | np.ndarray | None = None,
    new_size: int | list[int] = 500,
    model: str = "VAE1-10",
    apply_log: bool = True,
    batch_frac: float = 0.1,
    learning_rate: float = 0.0005,
    epoch: int | None = None,
    early_stop_patience: int | None = None,
    off_aug: str | None = None,
    AE_head_num: int = 2,
    Gaussian_head_num: int = 9,
    random_seed: int = 123,
    CVAE_wide_network: bool = False,
    output_dir: str | Path | None = None,
    verbose: int | str = "minimal",
) -> SyngResult:
    """Train on source data, then fine-tune and generate on target data.

    The model is first trained on *source_data* and its learned state is
    kept in-memory, then fine-tuned on *target_data*.  This is a
    single-run operation returning a :class:`SyngResult`.

    This replaces the legacy ``TransferExperiment`` function.

    Parameters
    ----------
    source_data : DataFrame, str, or Path
        Pre-training dataset.
    target_data : DataFrame, str, or Path
        Fine-tuning / target dataset.
    source_name : str or None
        Short name for the source dataset.
    target_name : str or None
        Short name for the target dataset.
    source_groups : pd.Series, np.ndarray, or None
        Optional binary groups for the source dataset.
    target_groups : pd.Series, np.ndarray, or None
        Optional binary groups for the target dataset.
    new_size : int or list[int]
        Generation size for the fine-tuned target model.

        - If ``int``: generate exactly ``new_size`` samples.  For grouped
          data, counts are split by the target input group ratio and
          rounded to integers.
        - If ``list[int]``: explicit grouped counts
          ``[n_group_0, n_group_1]``.

        For grouped data, ``group_0`` is the base group used by
        :func:`create_labels` (first encountered group value) and
        ``group_1`` is the other group.
    model : str
        Model specification.
    apply_log : bool
        Apply log2 preprocessing.
    batch_frac : float
        Batch fraction.
    learning_rate : float
        Learning rate.
    epoch : int or None
        Fixed epoch count, or ``None`` for early stopping.  See
        :func:`generate` for the full interaction table.
    early_stop_patience : int or None
        Stop if loss does not improve for this many epochs.  When
        ``None`` and ``epoch`` is also ``None``, defaults to ``30``.
        See :func:`generate` for the full interaction table.
    off_aug : str or None
        Offline augmentation mode.
    AE_head_num : int
        Fold multiplier for AE-head augmentation.
    Gaussian_head_num : int
        Fold multiplier for Gaussian-head augmentation.
    random_seed : int
        Random seed.
    CVAE_wide_network : bool
        Use wider encoder/decoder for CVAE — see :func:`generate`.
    output_dir : str, Path, or None
        If set, save results here.
    verbose : int or str
        Verbosity level — see :func:`generate` for details.

    Returns
    -------
    SyngResult
        Result from the fine-tuned target-phase model.

    Notes
    -----
    Both phases are trained with the same hyperparameters (*epoch*,
    *early_stop_patience*, *learning_rate*, *batch_frac*, *random_seed*,
    augmentation settings).  All result metadata (column names, dataset
    name, sample count, epochs, groups) comes from the target phase.
    """
    # --- 0. Resolve verbose level ----------------------------------------
    verbose_level = _resolve_verbose(verbose)

    # --- 1. Prepare source and target data -------------------------------
    # The same ``apply_log`` setting is applied to both datasets.
    source_prep = _prepare_data(
        data=source_data,
        name=source_name,
        groups=source_groups,
        apply_log=apply_log,
    )
    target_prep = _prepare_data(
        data=target_data,
        name=target_name,
        groups=target_groups,
        apply_log=apply_log,
    )

    # --- 2. Parse model spec ---------------------------------------------
    modelname, recon_weight, kl_weight = _parse_model_spec(model)

    # --- 3. Pre-train on source ------------------------------------------
    # Only the learned model state is carried forward; the source training
    # context is not used afterwards.
    source_trained, _source_ctx = orchestrate_training(
        rawdata=source_prep.oridata,
        rawlabels=source_prep.orilabels,
        oriblurlabels=source_prep.oriblurlabels,
        modelname=modelname,
        kl_weight=kl_weight,
        reconstruction_term_weight=recon_weight,
        batch_frac=batch_frac,
        random_seed=random_seed,
        epoch=epoch,
        early_stop_patience=early_stop_patience,
        learning_rate=learning_rate,
        off_aug=off_aug,
        AE_head_num=AE_head_num,
        Gaussian_head_num=Gaussian_head_num,
        CVAE_wide_network=CVAE_wide_network,
        verbose=verbose_level,
    )
    source_model_state = source_trained.model_state

    # --- 4. Fine-tune on target ------------------------------------------
    effective_new_size = _compute_new_size(
        target_prep.orilabels,
        target_prep.n_samples,
        new_size,
    )

    # Same call as the source phase, plus ``model_state`` from pre-training.
    target_trained, target_ctx = orchestrate_training(
        rawdata=target_prep.oridata,
        rawlabels=target_prep.orilabels,
        oriblurlabels=target_prep.oriblurlabels,
        modelname=modelname,
        kl_weight=kl_weight,
        reconstruction_term_weight=recon_weight,
        batch_frac=batch_frac,
        random_seed=random_seed,
        epoch=epoch,
        early_stop_patience=early_stop_patience,
        learning_rate=learning_rate,
        off_aug=off_aug,
        AE_head_num=AE_head_num,
        Gaussian_head_num=Gaussian_head_num,
        CVAE_wide_network=CVAE_wide_network,
        model_state=source_model_state,
        verbose=verbose_level,
    )

    # --- 5. Infer --------------------------------------------------------
    gen_data, recon_data = _infer_from_trained(
        target_trained,
        new_size=effective_new_size,
        ctx=target_ctx,
    )

    # --- 6. Assemble SyngResult ------------------------------------------
    # Groups are stored as a Series named "group" with a fresh 0..n-1 index.
    target_original_groups: pd.Series | None = None
    if target_prep.effective_groups is not None:
        target_original_groups = pd.Series(
            target_prep.effective_groups, name="group"
        ).reset_index(drop=True)

    result = _assemble_result(
        gen_data=gen_data,
        recon_data=recon_data,
        trained=target_trained,
        colnames=target_prep.colnames,
        modelname=modelname,
        model=model,
        dataname=target_prep.dataname,
        n_samples=target_prep.n_samples,
        num_epochs=target_ctx.num_epochs,
        random_seed=random_seed,
        kl_weight=kl_weight,
        reconstruction_term_weight=recon_weight,
        early_stop=target_ctx.early_stop,
        early_stop_num=target_ctx.early_stop_num,
        apply_log=target_prep.apply_log,
        original_data=target_prep.df.copy(),
        original_groups=target_original_groups,
        group_mapping=target_prep.group_mapping,
    )

    if output_dir is not None:
        result.save(output_dir)

    return result