# Source code for syng_bts.synthesize

"""SyntheSize integration — sample-size evaluation via classifier learning curves.

This module provides classifier-based evaluation of synthetic data across
candidate sample sizes, using stratified cross-validation and inverse
power-law curve fitting.

Public API
----------
- :func:`evaluate_sample_sizes` — Evaluate classifiers across candidate sample
  sizes using stratified cross-validation.
- :func:`plot_sample_sizes` — Visualize inverse power-law function (IPLF)
  learning curves fitted to evaluation metrics.

References
----------
- SyntheSize (R): https://github.com/LXQin/SyntheSize
- SyntheSize (Python): https://github.com/LXQin/SyntheSize_py
"""

from __future__ import annotations

import inspect
from collections.abc import Callable
from numbers import Integral
from typing import TYPE_CHECKING

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
from scipy.optimize import approx_fprime, curve_fit
from scipy.stats import norm
from sklearn.ensemble import RandomForestClassifier
from sklearn.linear_model import LogisticRegressionCV
from sklearn.metrics import accuracy_score, f1_score, roc_auc_score
from sklearn.model_selection import StratifiedKFold
from sklearn.neighbors import KNeighborsClassifier
from sklearn.preprocessing import scale
from sklearn.svm import SVC
from xgboost import DMatrix
from xgboost import train as xgb_train

from .helper_train import VerbosityLevel, _resolve_verbose

if TYPE_CHECKING:
    from .result import SyngResult

# ---------------------------------------------------------------------------
# Private classifier helpers
# ---------------------------------------------------------------------------


def _logis(
    train_data: np.ndarray,
    train_labels: np.ndarray,
    test_data: np.ndarray,
    test_labels: np.ndarray,
) -> dict[str, float]:
    """Ridge (L2-penalised) logistic regression classifier.

    Fits ``LogisticRegressionCV`` (10 candidate ``C`` values, 5 internal
    folds, ``liblinear`` solver, accuracy scoring) on the training split and
    scores it on the test split.

    Returns
    -------
    dict[str, float]
        ``"f1"`` (macro-averaged), ``"accuracy"`` and ``"auc"``. AUC uses the
        second probability column when there are two classes, otherwise the
        one-vs-one macro average.
    """
    accepted = inspect.signature(LogisticRegressionCV).parameters
    options: dict[str, object] = dict(
        Cs=10,
        cv=5,
        solver="liblinear",
        scoring="accuracy",
        random_state=0,
        max_iter=1000,
    )

    # Pick whichever way of requesting an L2 penalty the installed
    # scikit-learn accepts: ``l1_ratios=(0,)`` when that parameter exists,
    # else ``penalty="l2"``.
    # NOTE(review): on releases that accept both ``l1_ratios`` and
    # ``penalty``, only ``l1_ratios`` is set here and ``penalty`` keeps its
    # default — confirm that is L2 for the supported scikit-learn range.
    if "l1_ratios" in accepted:
        options["l1_ratios"] = (0,)
    elif "penalty" in accepted:
        options["penalty"] = "l2"

    if "use_legacy_attributes" in accepted:
        options["use_legacy_attributes"] = False

    clf = LogisticRegressionCV(**options)
    clf.fit(train_data, train_labels)

    class_scores = clf.predict_proba(test_data)
    hard_labels = clf.predict(test_data)

    is_binary = class_scores.shape[1] == 2
    auc = (
        roc_auc_score(test_labels, class_scores[:, 1])
        if is_binary
        else roc_auc_score(
            test_labels, class_scores, multi_class="ovo", average="macro"
        )
    )

    return {
        "f1": f1_score(test_labels, hard_labels, average="macro"),
        "accuracy": accuracy_score(test_labels, hard_labels),
        "auc": auc,
    }


def _svm(
    train_data: np.ndarray,
    train_labels: np.ndarray,
    test_data: np.ndarray,
    test_labels: np.ndarray,
) -> dict[str, float]:
    """Support Vector Machine classifier.

    Fits a default-parameter ``SVC`` with ``probability=True`` (needed for
    ``predict_proba``) and scores it on the test split.

    Returns
    -------
    dict[str, float]
        ``"f1"`` (macro-averaged), ``"accuracy"`` and ``"auc"``. AUC uses the
        second probability column when there are two classes, otherwise the
        one-vs-one macro average.
    """
    svc = SVC(probability=True)
    svc.fit(train_data, train_labels)

    class_scores = svc.predict_proba(test_data)
    hard_labels = svc.predict(test_data)

    if class_scores.shape[1] != 2:
        auc = roc_auc_score(
            test_labels, class_scores, multi_class="ovo", average="macro"
        )
    else:
        auc = roc_auc_score(test_labels, class_scores[:, 1])

    return {
        "f1": f1_score(test_labels, hard_labels, average="macro"),
        "accuracy": accuracy_score(test_labels, hard_labels),
        "auc": auc,
    }


def _knn(
    train_data: np.ndarray,
    train_labels: np.ndarray,
    test_data: np.ndarray,
    test_labels: np.ndarray,
) -> dict[str, float]:
    """K-Nearest Neighbors classifier.

    Fits a 5-neighbour ``KNeighborsClassifier`` and scores it on the test
    split.

    Returns
    -------
    dict[str, float]
        ``"f1"`` (macro-averaged), ``"accuracy"`` and ``"auc"``. AUC uses the
        second probability column when there are two classes, otherwise the
        one-vs-one macro average.
    """
    neighbours = KNeighborsClassifier(n_neighbors=5)
    neighbours.fit(train_data, train_labels)

    class_scores = neighbours.predict_proba(test_data)
    hard_labels = neighbours.predict(test_data)

    if class_scores.shape[1] != 2:
        auc = roc_auc_score(
            test_labels, class_scores, multi_class="ovo", average="macro"
        )
    else:
        auc = roc_auc_score(test_labels, class_scores[:, 1])

    return {
        "f1": f1_score(test_labels, hard_labels, average="macro"),
        "accuracy": accuracy_score(test_labels, hard_labels),
        "auc": auc,
    }


def _rf(
    train_data: np.ndarray,
    train_labels: np.ndarray,
    test_data: np.ndarray,
    test_labels: np.ndarray,
) -> dict[str, float]:
    """Random Forest classifier.

    Fits a 100-tree ``RandomForestClassifier`` and scores it on the test
    split. No ``random_state`` is passed, so individual fits are not seeded.

    Returns
    -------
    dict[str, float]
        ``"f1"`` (macro-averaged), ``"accuracy"`` and ``"auc"``. AUC uses the
        second probability column when there are two classes, otherwise the
        one-vs-one macro average.
    """
    forest = RandomForestClassifier(n_estimators=100)
    forest.fit(train_data, train_labels)

    class_scores = forest.predict_proba(test_data)
    hard_labels = forest.predict(test_data)

    if class_scores.shape[1] != 2:
        auc = roc_auc_score(
            test_labels, class_scores, multi_class="ovo", average="macro"
        )
    else:
        auc = roc_auc_score(test_labels, class_scores[:, 1])

    return {
        "f1": f1_score(test_labels, hard_labels, average="macro"),
        "accuracy": accuracy_score(test_labels, hard_labels),
        "auc": auc,
    }


def _xgb(
    train_data: np.ndarray,
    train_labels: np.ndarray,
    test_data: np.ndarray,
    test_labels: np.ndarray,
) -> dict[str, float]:
    """XGBoost classifier.

    Trains a booster for 10 rounds with ``binary:logistic`` (two classes) or
    ``multi:softprob`` (more classes, ``num_class`` taken from the distinct
    training labels). Binary predictions threshold the probability at 0.5;
    multiclass predictions take the arg-max column.

    Returns
    -------
    dict[str, float]
        ``"f1"`` (macro-averaged), ``"accuracy"`` and ``"auc"`` (one-vs-one
        macro average in the multiclass case).
    """
    n_classes = len(np.unique(train_labels))
    train_matrix = DMatrix(train_data, label=train_labels)
    test_matrix = DMatrix(test_data, label=test_labels)

    params = (
        {"objective": "binary:logistic", "eval_metric": "auc"}
        if n_classes == 2
        else {
            "objective": "multi:softprob",
            "num_class": n_classes,
            "eval_metric": "mlogloss",
        }
    )

    booster = xgb_train(params, train_matrix, num_boost_round=10)
    scores = booster.predict(test_matrix)

    # Binary objective -> 1-D positive-class probabilities;
    # softprob -> one probability column per class.
    if scores.ndim != 1:
        hard_labels = np.argmax(scores, axis=1)
        auc = roc_auc_score(test_labels, scores, multi_class="ovo", average="macro")
    else:
        hard_labels = (scores > 0.5).astype(int)
        auc = roc_auc_score(test_labels, scores)

    return {
        "f1": f1_score(test_labels, hard_labels, average="macro"),
        "accuracy": accuracy_score(test_labels, hard_labels),
        "auc": auc,
    }


# Dispatch table from canonical method name to its private classifier.
# Every entry takes (train_data, train_labels, test_data, test_labels) and
# returns a dict with "f1", "accuracy" and "auc" values.
_CLASSIFIER_MAP: dict[
    str, Callable[[np.ndarray, np.ndarray, np.ndarray, np.ndarray], dict[str, float]]
] = dict(
    LOGIS=_logis,
    SVM=_svm,
    KNN=_knn,
    RF=_rf,
    XGB=_xgb,
)

# Accepted spellings for each canonical classifier name. Keys are upper-case
# because user input is normalised with ``str.upper()`` before the lookup.
_METHOD_ALIASES: dict[str, str] = {
    alias: canonical
    for canonical, aliases in (
        ("LOGIS", ("LOGIS", "LOGISTIC", "LR")),
        ("SVM", ("SVM",)),
        ("KNN", ("KNN",)),
        ("RF", ("RF", "RANDOM_FOREST")),
        ("XGB", ("XGB", "XGBOOST")),
    )
    for alias in aliases
}


def _print_eval_progress(
    step: int,
    total_steps: int,
    size_index: int,
    n_sizes: int,
    n: int,
    draw: int,
    method: str,
) -> None:
    """Redraw the single-line evaluation progress bar (MINIMAL verbosity).

    The line begins with ``\\r`` and is printed without a trailing newline,
    so each call overwrites the previous one. *size_index* is zero-based and
    is displayed one-based. Example for step 3 of 10::

        Progress |██████░░░░░░░░░░░░░░| 3/10 size=1/3 (n=50), draw=1, method=RF
    """
    width = 20
    done = int(width * (step / total_steps))
    bar = "".join(["\u2588" * done, "\u2591" * (width - done)])
    line = (
        "\rProgress |"
        + bar
        + f"| {step}/{total_steps} "
        + f"size={size_index + 1}/{n_sizes} (n={n}), "
        + f"draw={draw}, method={method}"
    )
    print(line, end="", flush=True)


def _resolve_methods(methods: list[str] | str | None) -> list[str]:
    """Resolve classifier names (case-insensitive, aliases allowed).

    Parameters
    ----------
    methods : list of str, str, or None
        Requested classifiers. ``None`` selects all five canonical methods.
        A single string is treated as a one-element list.

    Returns
    -------
    list of str
        Canonical names in first-seen order with duplicates removed, so an
        alias and its canonical name (e.g. ``"RF"`` and ``"random_forest"``)
        do not train and report the same classifier twice.

    Raises
    ------
    ValueError
        If any entry is not a recognised name (including non-string entries)
        or if *methods* is empty.
    """
    if methods is None:
        return ["LOGIS", "SVM", "KNN", "RF", "XGB"]

    # A bare string would otherwise be iterated character by character.
    if isinstance(methods, str):
        methods = [methods]

    valid_options = sorted(set(_METHOD_ALIASES.values()))
    resolved: list[str] = []
    for m in methods:
        canonical = _METHOD_ALIASES.get(m.upper()) if isinstance(m, str) else None
        if canonical is None:
            raise ValueError(
                f"Unknown classifier method: {m!r}. "
                f"Valid options: {valid_options}"
            )
        if canonical not in resolved:
            resolved.append(canonical)

    # An empty selection would silently produce an empty, column-less result.
    if not resolved:
        raise ValueError("'methods' must contain at least one classifier name.")
    return resolved


def _resolve_data_and_groups(
    data: pd.DataFrame | SyngResult,
    groups: np.ndarray | pd.Series | list | None,
    which: str,
) -> tuple[pd.DataFrame, np.ndarray | pd.Series]:
    """Pick the feature table and group labels to evaluate.

    Parameters
    ----------
    data : pd.DataFrame or SyngResult
        Input data source.
    groups : array-like or None
        Explicit group labels. Mandatory for a DataFrame; for a SyngResult
        they take precedence over the labels stored on the result.
    which : str
        SyngResult field selector: ``"generated"``, ``"original"`` or
        ``"reconstructed"``. Ignored for DataFrame input.

    Returns
    -------
    tuple[pd.DataFrame, np.ndarray | pd.Series]
        The ``(features, group_labels)`` pair.

    Raises
    ------
    TypeError
        If *data* is neither a DataFrame nor a SyngResult.
    ValueError
        For an invalid *which*, a missing original/reconstructed table, or
        when no group labels can be determined.
    """
    from .result import SyngResult

    if not isinstance(data, SyngResult):
        if not isinstance(data, pd.DataFrame):
            raise TypeError(
                f"'data' must be a pd.DataFrame or SyngResult, got {type(data).__name__}"
            )
        if groups is None:
            raise ValueError("'groups' is required when 'data' is a DataFrame.")
        return data, groups

    field_names = {
        "generated": ("generated_data", "generated_groups"),
        "original": ("original_data", "original_groups"),
        "reconstructed": ("reconstructed_data", "reconstructed_groups"),
    }
    valid_which = tuple(field_names)
    if which not in valid_which:
        raise ValueError(
            f"Invalid 'which' value: {which!r}. Must be one of {valid_which}."
        )

    data_attr, groups_attr = field_names[which]
    frame = getattr(data, data_attr)
    # Only the optional original/reconstructed tables are checked for None.
    if which != "generated" and frame is None:
        raise ValueError(f"SyngResult has no {data_attr}.")

    stored_groups = getattr(data, groups_attr)
    labels = stored_groups if groups is None else groups
    if labels is None:
        raise ValueError(
            f"SyngResult has no {which}_groups and no explicit 'groups' provided."
        )
    return frame, labels


def _allocate_stratified_counts(
    total_size: int,
    group_counts: dict[str, int],
) -> dict[str, int]:
    """Split *total_size* across groups proportionally (largest remainder).

    Each group first receives the floor of its proportional share (capped at
    its available count). Leftover units then go one per group, in order of
    decreasing fractional remainder (ties keep *group_counts* order), skipping
    groups already at capacity, until the total is reached.

    Raises
    ------
    ValueError
        If *total_size* exceeds the total number of available rows, or the
        counts cannot be made to sum to *total_size*.
    """
    available = sum(group_counts.values())
    if total_size > available:
        raise ValueError(
            f"Requested sample size {total_size} exceeds available rows "
            f"({available})."
        )

    shares = {g: total_size * n_g / available for g, n_g in group_counts.items()}
    counts = {g: min(int(np.floor(shares[g])), group_counts[g]) for g in group_counts}

    deficit = total_size - sum(counts.values())
    if deficit > 0:
        by_remainder = sorted(
            group_counts, key=lambda g: shares[g] - counts[g], reverse=True
        )
        while deficit > 0:
            # Groups with spare capacity at the start of this round-robin pass.
            open_groups = [g for g in by_remainder if counts[g] < group_counts[g]]
            if not open_groups:
                break
            for g in open_groups[:deficit]:
                counts[g] += 1
            deficit -= min(deficit, len(open_groups))

    if sum(counts.values()) != total_size:
        raise ValueError(
            "Could not allocate stratified sample counts that sum to the "
            f"requested size {total_size}."
        )

    return counts


# ---------------------------------------------------------------------------
# Curve fitting helpers
# ---------------------------------------------------------------------------


def _power_law(x: float, a: float, b: float, c: float) -> float:
    """Inverse power-law function: ``(1 - a) - b * x^c``.

    Evaluates element-wise when *x* is a NumPy array or pandas Series.
    """
    return (1 - a) - (b * (x**c))


def _fit_power_law(
    acc_table: pd.DataFrame,
    metric_name: str,
) -> tuple[pd.DataFrame, bool, bool]:
    """Fit :func:`_power_law` to ``acc_table[metric_name]`` against ``"n"``.

    Parameters
    ----------
    acc_table : pd.DataFrame
        Must contain columns ``"n"`` and *metric_name*. Not modified.
    metric_name : str
        Column to fit.

    Returns
    -------
    (table, fit_ok, ci_ok)
        *table* is a copy of *acc_table* sorted by ``"n"``. When *fit_ok* is
        true it has a ``"predicted"`` column; when *ci_ok* is also true it has
        ``"ci_low"`` / ``"ci_high"`` (pointwise 95% band, delta method).
        The fit is skipped when there are fewer points than parameters or
        ``curve_fit`` fails; the band is skipped when the parameter
        covariance is not finite (e.g. zero residual degrees of freedom).
    """
    initial_params = [0, 1, -0.5]
    max_iterations = 50000

    # Sorted so the fitted line is drawn left to right.
    table = acc_table.sort_values("n").reset_index(drop=True)

    # curve_fit raises TypeError for fewer points than parameters (e.g.
    # metrics for only one or two sample sizes); treat that as "no fit".
    if len(table) < len(initial_params):
        return table, False, False

    try:
        n_values = table["n"].to_numpy(dtype=float)
        observed = table[metric_name].to_numpy(dtype=float)
        popt, pcov = curve_fit(
            _power_law,
            n_values,
            observed,
            p0=initial_params,
            maxfev=max_iterations,
        )
    except (RuntimeError, ValueError):
        return table, False, False

    table["predicted"] = _power_law(n_values, *popt)

    if not np.all(np.isfinite(pcov)):
        return table, True, False

    # Delta method: Var[f(x; p)] ~= J pcov J^T, where J is the gradient of the
    # prediction with respect to the fitted PARAMETERS (a, b, c) — one row per
    # sample size, one column per parameter.
    epsilon = np.sqrt(np.finfo(float).eps)
    jacobian = np.vstack(
        [
            approx_fprime(popt, lambda params, x=x: _power_law(x, *params), epsilon)
            for x in n_values
        ]
    )
    pred_var = np.sum((jacobian @ pcov) * jacobian, axis=1)
    # Round-off can push near-zero variances slightly negative.
    pred_std = np.sqrt(np.clip(pred_var, 0.0, None))
    z = norm.ppf(0.975)
    table["ci_low"] = table["predicted"] - z * pred_std
    table["ci_high"] = table["predicted"] + z * pred_std
    return table, True, True


def _fit_curve(
    acc_table: pd.DataFrame,
    metric_name: str,
    n_target: int | list | None = None,
    plot: bool = True,
    ax: plt.Axes | None = None,
    annotation: str = "",
) -> plt.Axes | None:
    """Fit an inverse power-law curve to evaluation metrics and optionally plot it.

    Parameters
    ----------
    acc_table : pd.DataFrame
        Must contain columns ``"n"`` and *metric_name*.
    metric_name : str
        Column in *acc_table* to fit against.
    n_target : int, list, or None
        Unused in this implementation (reserved for future extrapolation).
    plot : bool
        Whether to create a plot.
    ax : matplotlib Axes or None
        Axes to draw on; a new figure is created when ``None``.
    annotation : str
        Subplot title.

    Returns
    -------
    matplotlib Axes or None
        The axes drawn on, or ``None`` when *plot* is false. Observed points
        are always drawn; the fitted line and 95% CI band only when the fit
        (respectively the covariance estimate) succeeded.
    """
    table, fit_ok, ci_ok = _fit_power_law(acc_table, metric_name)

    if not plot:
        return None

    if ax is None:
        _, ax = plt.subplots(figsize=(10, 6))

    ax.scatter(
        table["n"],
        table[metric_name],
        label="Actual Data",
        color="red",
    )
    if fit_ok:
        ax.plot(
            table["n"],
            table["predicted"],
            label="Fitted",
            color="blue",
            linestyle="--",
        )
    if ci_ok:
        ax.fill_between(
            table["n"],
            table["ci_low"],
            table["ci_high"],
            color="blue",
            alpha=0.2,
            label="95% CI",
        )
    ax.set_xlabel("Sample Size")
    ax.legend(loc="best")
    ax.set_title(annotation)
    ax.set_ylim(0.4, 1)
    return ax


# ---------------------------------------------------------------------------
# Public API
# ---------------------------------------------------------------------------


def _normalize_sample_sizes(
    sample_sizes: list[int] | np.ndarray | pd.Series | int,
    n_rows: int,
) -> list[int]:
    """Validate *sample_sizes* and return them as a list of Python ints.

    Arrays and Series are converted with ``tolist()``. A bare integer ``k``
    (``bool`` excluded) expands to ``k`` equidistant sizes from ``n_rows / k``
    to *n_rows*, rounded to integers.

    Raises
    ------
    ValueError
        If the result is empty, contains a non-integer or non-positive value,
        or contains a size larger than *n_rows*.
    """
    if isinstance(sample_sizes, (np.ndarray, pd.Series)):
        sample_sizes = sample_sizes.tolist()  # type: ignore[assignment]

    if isinstance(sample_sizes, (int, np.integer)) and not isinstance(
        sample_sizes, bool
    ):
        k = int(sample_sizes)
        if k <= 0:
            raise ValueError(f"'sample_sizes' as int must be positive, got {k}.")
        sample_sizes = (
            np.round(np.linspace(n_rows / k, n_rows, k)).astype(int).tolist()
        )

    if not sample_sizes:
        raise ValueError("'sample_sizes' must be a non-empty list of integers.")

    normalized: list[int] = []
    for s in sample_sizes:
        if isinstance(s, bool) or not isinstance(s, Integral) or int(s) <= 0:
            raise ValueError(
                f"All sample sizes must be positive integers, got {s!r}."
            )
        normalized.append(int(s))

    for s in normalized:
        if s > n_rows:
            raise ValueError(f"Sample size {s} exceeds available rows ({n_rows}).")

    return normalized


def evaluate_sample_sizes(
    data: pd.DataFrame | SyngResult,
    sample_sizes: list[int] | np.ndarray | pd.Series | int,
    groups: np.ndarray | pd.Series | list | None = None,
    which: str = "generated",
    n_draws: int = 5,
    apply_log: bool = True,
    methods: list[str] | None = None,
    verbose: int | str = "minimal",
) -> pd.DataFrame:
    r"""Evaluate classifiers across candidate sample sizes.

    For each classifier and each candidate sample size, performs *n_draws*
    rounds of stratified sampling (proportional to class distribution),
    applies 5-fold cross-validation, and averages metrics across folds.

    Parameters
    ----------
    data : pd.DataFrame or SyngResult
        The dataset to evaluate. When a :class:`~syng_bts.result.SyngResult`
        is provided, the *which* parameter selects the data attribute and
        groups are auto-resolved from the corresponding ``*_groups`` field.
    sample_sizes : list[int], np.ndarray, pd.Series, or int
        Candidate sample sizes to evaluate. Accepts a list, numpy array, or
        pandas Series of positive integers. When a **single int** is provided
        it is interpreted as the *number* of equidistant sizes to create — the
        maximum equals the number of data rows. For example,
        ``sample_sizes=3`` with 15-row data produces ``[5, 10, 15]``.
    groups : array-like or None
        Class labels corresponding to the rows of *data*. **Required** when
        *data* is a ``pd.DataFrame``. When provided alongside a
        ``SyngResult``, overrides the auto-resolved groups.
    which : str, default ``"generated"``
        Selector when *data* is a ``SyngResult``: ``"generated"``,
        ``"original"``, or ``"reconstructed"``.
    n_draws : int, default 5
        Number of resampling repetitions for each sample size. Must be a
        positive integer (NumPy integers are accepted, ``bool`` is not).
    apply_log : bool, default True
        When ``True``, a ``log2(x + 1)`` transform is applied to the data
        before evaluation.
    methods : list[str] or None
        Classifier names to evaluate. Accepts canonical names (``'LOGIS'``,
        ``'SVM'``, ``'KNN'``, ``'RF'``, ``'XGB'``) and common aliases
        (``'LOGISTIC'``, ``'LR'``, ``'RANDOM_FOREST'``, ``'XGBOOST'``).
        Defaults to all five classifiers.
    verbose : int or str, default "minimal"
        Controls output verbosity. Accepts ``0`` / ``"silent"`` (no output),
        ``1`` / ``"minimal"`` (one dynamic overall progress bar across all
        sample sizes, draws, and methods), or ``2`` / ``"detailed"``
        (per-draw/method metric lines).

    Returns
    -------
    pd.DataFrame
        Columns: ``total_size``, ``draw``, ``method``, ``f1_score``,
        ``accuracy``, ``auc``.

    Raises
    ------
    TypeError
        If *data* is not a ``pd.DataFrame`` or ``SyngResult``.
    ValueError
        If *groups* is missing when required, *which* is invalid, *methods*
        contains unknown names, *sample_sizes* is empty or contains
        non-positive values, or any sample size exceeds the number of
        available rows.

    Notes
    -----
    Features are converted to ``float64`` before evaluation, then
    standardised per fold (zero-variance training columns are left
    untouched). Subsampling draws from NumPy's global random state
    (``np.random.choice``); the fold split uses a fixed ``random_state=42``.

    Examples
    --------
    Using a DataFrame:

    >>> df = pd.read_csv("mydata.csv")
    >>> groups = df.pop("group")
    >>> result = evaluate_sample_sizes(df, sample_sizes=[50, 100], groups=groups)

    Using a SyngResult:

    >>> from syng_bts import generate
    >>> sr = generate(data="BRCASubtypeSel_test", model="CVAE1-20", epoch=10)
    >>> result = evaluate_sample_sizes(sr, sample_sizes=[50], which="generated")
    """
    # --- Resolve verbose level ---
    verbose_level = _resolve_verbose(verbose)

    # --- Resolve data and groups ---
    resolved_data, resolved_groups = _resolve_data_and_groups(data, groups, which)

    # --- Validate data shape/content ---
    if resolved_data.shape[0] == 0 or resolved_data.shape[1] == 0:
        raise ValueError("'data' must have at least 1 row and 1 column.")

    non_numeric_cols = [
        col
        for col in resolved_data.columns
        if not pd.api.types.is_numeric_dtype(resolved_data[col])
    ]
    if non_numeric_cols:
        raise ValueError(
            "'data' must contain only numeric columns; non-numeric columns: "
            f"{non_numeric_cols}"
        )

    group_arr = np.asarray(resolved_groups)
    if group_arr.ndim != 1:
        raise ValueError("'groups' must be one-dimensional.")
    if len(group_arr) != len(resolved_data):
        raise ValueError(
            "Length mismatch: 'groups' must have one label per data row "
            f"(groups={len(group_arr)}, rows={len(resolved_data)})."
        )
    if len(group_arr) == 0:
        raise ValueError("'groups' must be non-empty.")

    unique_labels = np.unique(group_arr.astype(str))
    if len(unique_labels) < 2:
        raise ValueError("At least two unique groups are required for evaluation.")

    # --- Resolve and validate methods ---
    resolved_methods = _resolve_methods(methods)

    # --- Normalise sample_sizes to list[int] ---
    n_rows = len(resolved_data)
    normalized_sample_sizes = _normalize_sample_sizes(sample_sizes, n_rows)

    # --- Validate n_draws (same integer rules as sample sizes) ---
    if isinstance(n_draws, bool) or not isinstance(n_draws, Integral) or n_draws < 1:
        raise ValueError(f"'n_draws' must be a positive integer, got {n_draws!r}.")
    n_draws = int(n_draws)

    n_splits = 5

    # --- Apply log transform if requested ---
    if apply_log:
        resolved_data = np.log2(resolved_data + 1)

    # Materialise the features as float64 once. Besides avoiding float32
    # numerical-warning spam in sklearn scaling, this is required for
    # correctness: the per-fold standardisation below writes float z-scores
    # back into the fold arrays, which would silently truncate them if the
    # data were integer or boolean (e.g. raw counts with apply_log=False).
    data_matrix = resolved_data.to_numpy(dtype=np.float64, na_value=np.nan)

    # Encode groups (as strings) to integer labels 0..K-1 in sorted order.
    group_arr = np.array([str(item) for item in group_arr])
    unique_groups, labels = np.unique(group_arr, return_inverse=True)

    # Per-group row counts and row indices.
    group_counts = {
        g: int(np.sum(labels == i)) for i, g in enumerate(unique_groups)
    }
    group_indices_dict = {
        g: np.flatnonzero(labels == i) for i, g in enumerate(unique_groups)
    }

    # Feasibility checks per requested sample size for stratified 5-fold CV
    for s in normalized_sample_sizes:
        if s < n_splits * len(unique_groups):
            raise ValueError(
                "Sample size is too small for 5-fold stratified CV across all "
                f"classes: n={s}, classes={len(unique_groups)}, minimum="
                f"{n_splits * len(unique_groups)}."
            )
        counts = _allocate_stratified_counts(s, group_counts)
        too_small_groups = [group for group, c in counts.items() if c < n_splits]
        if too_small_groups:
            raise ValueError(
                "Sample size yields too few samples per class for 5-fold "
                "stratified CV. Increase sample size or reduce class imbalance. "
                f"n={s}, groups={too_small_groups}."
            )

    results: list[dict] = []
    total_steps_overall = (
        len(normalized_sample_sizes) * n_draws * len(resolved_methods)
    )
    overall_step_counter = 0
    # An int random_state makes every split() call reproduce the same folds,
    # so one splitter can be shared by all draws.
    skf = StratifiedKFold(n_splits=n_splits, shuffle=True, random_state=42)

    for n_index, n in enumerate(normalized_sample_sizes):
        if verbose_level >= VerbosityLevel.DETAILED:
            print(
                f"\nRunning sample size index "
                f"{n_index + 1}/{len(normalized_sample_sizes)} (n = {n})\n"
            )

        # Per-group counts depend only on n, not on the draw.
        allocation = _allocate_stratified_counts(n, group_counts)

        for draw in range(n_draws):
            # Stratified subsample (without replacement within each group).
            idx = np.concatenate(
                [
                    np.random.choice(group_indices_dict[g], allocation[g], replace=False)
                    for g in unique_groups
                ]
            )
            dat_candidate = data_matrix[idx]
            labels_candidate = labels[idx]

            # Accumulate per-fold metrics per classifier
            metrics: dict[str, dict[str, list]] = {
                method: {"f1": [], "accuracy": [], "auc": []}
                for method in resolved_methods
            }

            for train_index, test_index in skf.split(dat_candidate, labels_candidate):
                # Fancy indexing returns copies, so in-place scaling below does
                # not touch dat_candidate.
                train_data = dat_candidate[train_index]
                test_data = dat_candidate[test_index]
                train_labels = labels_candidate[train_index]
                test_labels = labels_candidate[test_index]

                # Scale non-zero-variance features. Guarded because scale()
                # rejects a zero-column selection.
                # NOTE(review): the test fold is standardised with its own
                # statistics rather than the training fold's; confirm this is
                # intentional (e.g. parity with the SyntheSize reference).
                non_zero_std = train_data.std(axis=0) != 0
                if non_zero_std.any():
                    train_data[:, non_zero_std] = scale(train_data[:, non_zero_std])
                    test_data[:, non_zero_std] = scale(test_data[:, non_zero_std])

                for method in resolved_methods:
                    clf_func = _CLASSIFIER_MAP[method]
                    res = clf_func(train_data, train_labels, test_data, test_labels)
                    for key in ("f1", "accuracy", "auc"):
                        metrics[method][key].append(res[key])

            for method in resolved_methods:
                mean_f1 = float(np.mean(metrics[method]["f1"]))
                mean_acc = float(np.mean(metrics[method]["accuracy"]))
                mean_auc = float(np.mean(metrics[method]["auc"]))

                overall_step_counter += 1
                if verbose_level == VerbosityLevel.MINIMAL:
                    _print_eval_progress(
                        step=overall_step_counter,
                        total_steps=total_steps_overall,
                        size_index=n_index,
                        n_sizes=len(normalized_sample_sizes),
                        n=n,
                        draw=draw,
                        method=method,
                    )
                elif verbose_level >= VerbosityLevel.DETAILED:
                    print(
                        f"[n={n}, draw={draw}, method={method}] "
                        f"F1: {mean_f1:.4f}, Acc: {mean_acc:.4f}, "
                        f"AUC: {mean_auc:.4f}"
                    )

                results.append(
                    {
                        "total_size": n,
                        "draw": draw,
                        "method": method,
                        "f1_score": mean_f1,
                        "accuracy": mean_acc,
                        "auc": mean_auc,
                    }
                )

    if verbose_level == VerbosityLevel.MINIMAL:
        print()  # move past final \r line

    return pd.DataFrame(results)
def plot_sample_sizes(
    metric_real: pd.DataFrame,
    n_target: int | list,
    metric_generated: pd.DataFrame | None = None,
    metric_name: str = "f1_score",
) -> plt.Figure:
    r"""Visualize IPLF learning curves fitted from evaluation metrics.

    Fits inverse power-law curves to the evaluation metrics produced by
    :func:`evaluate_sample_sizes` and plots observed values, fitted curves,
    and 95% confidence intervals — one row of panels per method. The
    returned figure is never displayed automatically — call
    ``fig.savefig(...)`` or ``plt.show()`` explicitly to display or save.

    Parameters
    ----------
    metric_real : pd.DataFrame
        Metrics from :func:`evaluate_sample_sizes` on real data.
    n_target : int or list
        Target sample sizes for extrapolation reference.
    metric_generated : pd.DataFrame or None
        Metrics from :func:`evaluate_sample_sizes` on generated data. When
        provided, a second column of panels is added.
    metric_name : str, default ``"f1_score"``
        Metric to visualize (``"f1_score"``, ``"accuracy"``, or ``"auc"``).

    Returns
    -------
    matplotlib.figure.Figure
        The figure containing the learning-curve panels.

    Raises
    ------
    ValueError
        If *metric_name* is not supported, a metrics table lacks required
        columns, *metric_real* is empty, or *metric_generated* lacks rows for
        a method present in *metric_real*.

    Examples
    --------
    >>> metrics = evaluate_sample_sizes(df, [50, 100, 200], groups=g)
    >>> fig = plot_sample_sizes(metrics, n_target=300)
    >>> fig.savefig("learning_curves.png")
    """
    supported_metrics = {"f1_score", "accuracy", "auc"}
    if metric_name not in supported_metrics:
        raise ValueError(
            f"Invalid metric_name {metric_name!r}. "
            f"Valid options: {sorted(supported_metrics)}"
        )

    needed_columns = {"total_size", "draw", "method", metric_name}
    absent_real = needed_columns - set(metric_real.columns)
    if absent_real:
        raise ValueError(
            f"metric_real is missing required columns: {sorted(absent_real)}"
        )
    if metric_real.empty:
        raise ValueError("metric_real must be non-empty.")
    if metric_generated is not None:
        absent_generated = needed_columns - set(metric_generated.columns)
        if absent_generated:
            raise ValueError(
                "metric_generated is missing required columns: "
                f"{sorted(absent_generated)}"
            )

    method_names = metric_real["method"].unique()
    n_rows = len(method_names)
    n_cols = 1 if metric_generated is None else 2
    # squeeze=False always yields a 2-D (rows x cols) array of Axes.
    fig, panel_grid = plt.subplots(
        n_rows, n_cols, figsize=(15, 5 * n_rows), squeeze=False
    )

    def _averaged(frame: pd.DataFrame) -> pd.DataFrame:
        """Average *metric_name* over draws for each (size, method) pair."""
        summary = frame.groupby(["total_size", "method"]).agg({metric_name: "mean"})
        return summary.reset_index().rename(columns={"total_size": "n"})

    for panel_row, method in zip(panel_grid, method_names):
        real_subset = metric_real.loc[metric_real["method"] == method]
        _fit_curve(
            _averaged(real_subset),
            metric_name,
            n_target=n_target,
            plot=True,
            ax=panel_row[0],
            annotation=f"{method}: Real ({metric_name})",
        )

        if metric_generated is None:
            continue

        generated_subset = metric_generated.loc[metric_generated["method"] == method]
        if generated_subset.empty:
            raise ValueError(
                "metric_generated must include rows for every method in "
                f"metric_real. Missing method: {method!r}."
            )
        _fit_curve(
            _averaged(generated_subset),
            metric_name,
            n_target=n_target,
            plot=True,
            ax=panel_row[1],
            annotation=f"{method}: Generated ({metric_name})",
        )

    fig.tight_layout()
    return fig