Source code for syng_bts.result

"""
Result objects for SyNG-BTS experiment outputs.

This module defines result classes that experiment functions return
instead of writing directly to disk. Results carry generated data, loss logs,
reconstructed data, and trained model state, which are all accessible as attributes.
"""

from __future__ import annotations

import json
from dataclasses import dataclass, field
from pathlib import Path
from typing import Any

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns
import torch
import torch.nn as nn


def _json_serializable(obj: Any) -> Any:
    """Convert non-JSON-serializable objects to JSON-safe equivalents."""
    if isinstance(obj, tuple):
        return list(obj)
    if isinstance(obj, np.integer):
        return int(obj)
    if isinstance(obj, np.floating):
        return float(obj)
    if isinstance(obj, np.ndarray):
        return obj.tolist()
    if isinstance(obj, Path):
        return str(obj)
    if isinstance(obj, set):
        return sorted(obj)
    return str(obj)


[docs] @dataclass class SyngResult: """Result of a single SyNG-BTS model training and generation run. Attributes ---------- generated_data : pd.DataFrame Synthetic samples with the original column names preserved. loss : pd.DataFrame Training loss log (columns depend on the model family). reconstructed_data : pd.DataFrame or None Reconstructions of the input data (AE/VAE/CVAE only). original_data : pd.DataFrame or None The full original input data. model_state : dict or None The ``state_dict()`` of the trained model, suitable for ``torch.save()`` / ``torch.load()``. metadata : dict Run parameters and summary statistics, e.g. model name, kl_weight, seed, epoch count, input data dimensions. original_groups : pd.Series or None Group labels for the original input data. Populated when groups were provided or bundled with the dataset. generated_groups : pd.Series or None Group labels for the generated data, derived from the label column produced during generation and mapped back to the original group values. reconstructed_groups : pd.Series or None Group labels for the reconstructed data (AE/VAE/CVAE only), derived from the label column and mapped back to original group values. Examples -------- >>> result = generate(data="SKCMPositive_4", model="VAE1-10", epoch=5) >>> result.generated_data.head() >>> result.save("./my_output/") >>> figs = result.plot_loss() # dict[str, Figure] """ generated_data: pd.DataFrame loss: pd.DataFrame reconstructed_data: pd.DataFrame | None = None original_data: pd.DataFrame | None = None model_state: dict[str, Any] | None = None metadata: dict[str, Any] = field(default_factory=dict) original_groups: pd.Series | None = None generated_groups: pd.Series | None = None reconstructed_groups: pd.Series | None = None # Non-serialized lazy model cache (excluded from dataclass init) _cached_model: nn.Module | None = field( default=None, init=False, repr=False, compare=False ) # ------------------------------------------------------------------ # Lazy model resolver # ------------------------------------------------------------------ def _resolve_model(self) -> nn.Module: """Return the cached model, rebuilding from state if needed. Uses ``model_state`` and ``metadata["arch_params"]`` to rebuild the model via :func:`model_factory.rebuild_model`. The rebuilt model is cached for subsequent calls. Returns ------- nn.Module The trained model in ``eval()`` mode. Raises ------ ValueError If ``model_state`` or ``metadata["arch_params"]`` is missing. """ if self._cached_model is not None: return self._cached_model if self.model_state is None: raise ValueError( "Cannot resolve model: 'model_state' is None. " "Ensure the SyngResult was created with a model_state, " "or loaded from a directory that contains a .pt file." ) arch_params = self.metadata.get("arch_params") if arch_params is None: raise ValueError( "Cannot resolve model: 'arch_params' is missing from metadata. " "Ensure the SyngResult was created by generate() v3.1+ or " "loaded from a directory with the metadata JSON." ) from .model_factory import rebuild_model self._cached_model = rebuild_model(arch_params, self.model_state) return self._cached_model # ------------------------------------------------------------------ # Post-training generation # ------------------------------------------------------------------
[docs] def generate_new_samples( self, n: int, *, mode: str = "new", ) -> SyngResult: """Generate new synthetic samples from the trained model. This method reuses the same generation and post-processing path as :func:`generate`, applying the same inverse-log transform and column naming. Parameters ---------- n : int Number of new samples to generate. mode : str How to incorporate the new samples: - ``"new"`` (default): return a **new** ``SyngResult`` whose ``generated_data`` contains only the newly generated samples. All other fields (loss, metadata, model_state, etc.) are copied from ``self``. - ``"overwrite"``: **replace** ``self.generated_data`` with the new samples and return ``self``. - ``"append"``: **append** the new samples to ``self.generated_data`` and return ``self``. Returns ------- SyngResult The result containing the new samples (see *mode*). Raises ------ ValueError If ``model_state`` is ``None``, ``arch_params`` is missing from metadata, or *mode* is not one of the accepted values. Examples -------- >>> result = generate(data="SKCMPositive_4", model="VAE1-10", epoch=5) >>> new_result = result.generate_new_samples(200) >>> new_result.generated_data.shape[0] 200 >>> # After save/load round-trip: >>> loaded = SyngResult.load("output/") >>> more = loaded.generate_new_samples(100, mode="append") >>> more.generated_data.shape[0] # original + 100 """ if isinstance(n, bool) or not isinstance(n, int) or n <= 0: raise ValueError(f"n must be a positive integer, got {n!r}") if mode not in ("new", "overwrite", "append"): raise ValueError( f"mode must be 'new', 'overwrite', or 'append', got {mode!r}" ) from .helper_training import TrainedModel from .inference import run_generation # Resolve the trained model (lazy rebuild from state_dict) model = self._resolve_model() arch_params = self.metadata["arch_params"] trained = TrainedModel( model=model, model_state=self.model_state, arch_params=arch_params, log_dict={}, epochs_trained=self.metadata.get("epochs_trained", 0), ) # Generate raw samples via the unified inference dispatcher gen_tensor = run_generation(trained, num_samples=n) # Post-processing: same as generate() in core.py gen_np = gen_tensor.detach().numpy() colnames = list(self.generated_data.columns) modelname = arch_params.get("modelname", "") gen_labels: pd.Series | None = None # Strip the appended label/blur-label column when the generated # tensor has more columns than the original feature set. if gen_np.shape[1] > len(colnames): gen_labels = pd.Series(gen_np[:, -1], name="label") gen_np = gen_np[:, :-1] gen_df = pd.DataFrame(gen_np, columns=colnames) if gen_df.columns.tolist() != colnames: raise RuntimeError( "Column order mismatch in generated_data. " "This is an internal error; please report it." ) # Inverse log transform if the original run used apply_log apply_log = self.metadata.get("apply_log", False) if apply_log: from .helper_utils import inverse_log2 gen_df = inverse_log2(gen_df) # Derive generated_groups from labels + group_mapping group_mapping = self.metadata.get("group_mapping") new_gen_groups: pd.Series | None = None if group_mapping is not None and gen_labels is not None: from .core import _labels_to_groups new_gen_groups = _labels_to_groups( gen_labels, group_mapping, modelname=modelname ) def _as_series_or_none(value: Any) -> pd.Series | None: if value is None: return None if isinstance(value, pd.Series): return value.reset_index(drop=True) if isinstance(value, np.ndarray): return pd.Series(value, name="label") if isinstance(value, list): return pd.Series(value, name="label") return None # Apply mode if mode == "new": new_metadata = self.metadata.copy() new_metadata["generated_labels"] = gen_labels return SyngResult( generated_data=gen_df, loss=self.loss.copy(), reconstructed_data=( self.reconstructed_data.copy() if self.reconstructed_data is not None else None ), original_data=( self.original_data.copy() if self.original_data is not None else None ), model_state=self.model_state, metadata=new_metadata, original_groups=( self.original_groups.copy() if self.original_groups is not None else None ), generated_groups=new_gen_groups, reconstructed_groups=( self.reconstructed_groups.copy() if self.reconstructed_groups is not None else None ), ) elif mode == "overwrite": self.generated_data = gen_df self.metadata["generated_labels"] = gen_labels self.generated_groups = new_gen_groups return self else: # append self.generated_data = pd.concat( [self.generated_data, gen_df], ignore_index=True ) if gen_labels is None: self.metadata["generated_labels"] = None else: old_labels = _as_series_or_none(self.metadata.get("generated_labels")) if old_labels is None: self.metadata["generated_labels"] = gen_labels else: self.metadata["generated_labels"] = pd.concat( [old_labels, gen_labels], ignore_index=True ) # Append generated groups if new_gen_groups is None: self.generated_groups = None elif self.generated_groups is None: self.generated_groups = new_gen_groups else: self.generated_groups = pd.concat( [self.generated_groups, new_gen_groups], ignore_index=True ) return self
# ------------------------------------------------------------------ # Convenience methods # ------------------------------------------------------------------
[docs] def save( self, output_dir: str | Path, prefix: str | None = None, ) -> dict[str, Path]: """Save all non-None results to *output_dir*. Files are written into a single flat directory. CSVs include column headers. Model state is saved as a ``.pt`` file. Metadata is written as a human-readable JSON file. Parameters ---------- output_dir : str or Path Directory to write files into (created if it does not exist). prefix : str or None Optional filename prefix. When ``None``, uses ``metadata["dataname"]`` if available, otherwise ``"syng"``. Returns ------- dict[str, Path] Mapping of output type (``"generated"``, ``"loss"``, ``"reconstructed"``, ``"model"``, ``"metadata"``) to the written file path. """ out = Path(output_dir) out.mkdir(parents=True, exist_ok=True) if prefix is None: prefix = self.metadata.get("dataname", "syng") model_tag = self.metadata.get("model", "") if model_tag: stem = f"{prefix}_{model_tag}" else: stem = prefix paths: dict[str, Path] = {} # Generated data gen_path = out / f"{stem}_generated.csv" self.generated_data.to_csv(gen_path, index=False) paths["generated"] = gen_path # Loss log loss_path = out / f"{stem}_loss.csv" self.loss.to_csv(loss_path, index=False) paths["loss"] = loss_path # Reconstructed data if self.reconstructed_data is not None: recon_path = out / f"{stem}_reconstructed.csv" self.reconstructed_data.to_csv(recon_path, index=False) paths["reconstructed"] = recon_path # Original data if self.original_data is not None: orig_path = out / f"{stem}_original.csv" self.original_data.to_csv(orig_path, index=True) paths["original"] = orig_path # Group attributes if self.original_groups is not None: og_path = out / f"{stem}_original_groups.csv" self.original_groups.to_frame().to_csv(og_path, index=False) paths["original_groups"] = og_path if self.generated_groups is not None: gg_path = out / f"{stem}_generated_groups.csv" self.generated_groups.to_frame().to_csv(gg_path, index=False) paths["generated_groups"] = gg_path if self.reconstructed_groups is not None: rg_path = out / f"{stem}_reconstructed_groups.csv" self.reconstructed_groups.to_frame().to_csv(rg_path, index=False) paths["reconstructed_groups"] = rg_path # Model state dict if self.model_state is not None: model_path = out / f"{stem}_model.pt" torch.save(self.model_state, model_path) paths["model"] = model_path # Metadata as human-readable JSON if self.metadata: meta_path = out / f"{stem}_metadata.json" meta_path.write_text( json.dumps(self.metadata, indent=2, default=_json_serializable), encoding="utf-8", ) paths["metadata"] = meta_path return paths
[docs] def plot_loss( self, running_average_window: int = 25, x_axis: str = "epochs", ) -> dict[str, plt.Figure]: """Plot the training loss curve(s), one figure per loss column. Each returned figure shows the raw loss series (``alpha=0.4``) and a running-average overlay. Parameters ---------- running_average_window : int Window size for the running-average overlay. Must be > 0. Default: 25. x_axis : str ``"epochs"`` (default) maps the x-axis to epoch space using ``metadata["epochs_trained"]`` (must be present and > 0). ``"iterations"`` numbers data points 0…N-1. Returns ------- dict[str, matplotlib.figure.Figure] ``{loss_column_name: figure}`` for every column in ``self.loss``. Raises ------ ValueError If *running_average_window* ≤ 0, if *x_axis* is not ``"iterations"`` or ``"epochs"``, if ``x_axis="epochs"`` but ``metadata["epochs_trained"]`` is missing or ≤ 0, or if the window is larger than a loss series. """ if running_average_window <= 0: raise ValueError( f"running_average_window must be > 0, got {running_average_window}" ) if x_axis not in ("iterations", "epochs"): raise ValueError(f"x_axis must be 'iterations' or 'epochs', got {x_axis!r}") num_epochs: float | None = None if x_axis == "epochs": raw_num_epochs = self.metadata.get("epochs_trained") if raw_num_epochs is None or isinstance(raw_num_epochs, bool): raise ValueError( "x_axis='epochs' requires metadata['epochs_trained'] > 0, " f"got {raw_num_epochs!r}" ) try: num_epochs = float(raw_num_epochs) except (TypeError, ValueError) as exc: raise ValueError( "x_axis='epochs' requires metadata['epochs_trained'] > 0, " f"got {raw_num_epochs!r}" ) from exc if num_epochs <= 0: raise ValueError( "x_axis='epochs' requires metadata['epochs_trained'] > 0, " f"got {raw_num_epochs!r}" ) figures: dict[str, plt.Figure] = {} for col in self.loss.columns: values = self.loss[col].to_numpy() if running_average_window > len(values): raise ValueError( f"running_average_window ({running_average_window}) is larger " f"than the '{col}' series length ({len(values)})" ) fig, ax = plt.subplots() # --- Build x-coordinates --- if x_axis == "epochs": assert num_epochs is not None x = np.linspace(0, num_epochs, len(values)) ax.set_xlabel("Epochs") else: x = np.arange(len(values)) ax.set_xlabel("Iterations") ax.plot(x, values, alpha=0.4, label=f"{col} (raw)") # --- Running average --- kernel = np.ones(running_average_window) / running_average_window smoothed = np.convolve(values, kernel, mode="valid") offset = running_average_window - 1 ax.plot( x[offset:], smoothed, label=f"{col} (avg, w={running_average_window})", ) ax.set_ylabel("Loss") ax.set_title(f"{col} loss") # --- Y-axis scaling: ignore the initial spike --- n = len(values) skip = n // 2 if n < 1001 else 1000 if n > skip: later_max = float(np.max(values[skip:])) if later_max > 0: ax.set_ylim([0, later_max * 1.5]) ax.legend() fig.tight_layout() figures[col] = fig return figures
[docs] def plot_heatmap( self, which: str = "generated", log_scale: bool = True ) -> plt.Figure: """Render a seaborn heatmap of generated or reconstructed data. Parameters ---------- which : str ``"generated"``, ``"reconstructed"``, or ``"original"``. log_scale : bool If ``True`` (default), apply ``log2(x + 1)`` scaling to the data before plotting. This compresses wide-ranging values and often produces more readable heatmaps. Returns ------- matplotlib.figure.Figure The heatmap figure (not shown; caller decides when to display). Raises ------ ValueError If *which* is ``"reconstructed"`` but no reconstructed data exists, or if *which* is not a recognised value. """ if which == "generated": df = self.generated_data elif which == "reconstructed": if self.reconstructed_data is None: raise ValueError( "No reconstructed data available in this result. " "Reconstructed data is only produced by AE/VAE/CVAE models." ) df = self.reconstructed_data elif which == "original": if self.original_data is None: raise ValueError( "No original data available in this result. " "Pass original_data when constructing the result." ) df = self.original_data else: raise ValueError( f"Unknown value which={which!r}; " f"expected 'generated', 'reconstructed', or 'original'." ) data = df.to_numpy() if log_scale: data = np.log2(data + 1) fig, ax = plt.subplots() sns.heatmap(data, cmap="YlGnBu", ax=ax) title = f"{which.capitalize()} data" if log_scale: title += " (log2)" ax.set_title(title) fig.tight_layout() return fig
[docs] def summary(self) -> str: """Return a short textual summary of this result. Returns ------- str A paragraph describing the run dimensions, epoch count, and final loss values. """ meta = self.metadata model = meta.get("model", "unknown") n_gen, n_feat = self.generated_data.shape epochs = meta.get("epochs_trained", "?") # Summarise final loss values final_losses = { col: f"{self.loss[col].iloc[-1]:.4f}" for col in self.loss.columns } loss_str = ", ".join(f"{k}={v}" for k, v in final_losses.items()) parts = [ f"Model: {model}", f"Generated data: {n_gen} samples × {n_feat} features", f"Epochs trained: {epochs}", f"Final loss: {loss_str}", ] if self.reconstructed_data is not None: r, c = self.reconstructed_data.shape parts.append(f"Reconstructed data: {r} rows × {c} cols") if self.original_data is not None: r, c = self.original_data.shape parts.append(f"Original data: {r} rows × {c} cols") if self.original_groups is not None: n_classes = self.original_groups.nunique() parts.append(f"Groups: {n_classes} classes") if "seed" in meta: parts.append(f"Random seed: {meta['seed']}") return " | ".join(parts)
def __repr__(self) -> str: n_gen, n_feat = self.generated_data.shape model = self.metadata.get("model", "?") has_recon = self.reconstructed_data is not None has_original = self.original_data is not None has_model = self.model_state is not None has_groups = self.original_groups is not None return ( f"SyngResult(model={model!r}, " f"generated={n_gen}×{n_feat}, " f"loss_cols={list(self.loss.columns)}, " f"has_reconstructed={has_recon}, " f"has_original={has_original}, " f"has_model_state={has_model}, " f"has_groups={has_groups})" ) # ------------------------------------------------------------------ # Loader # ------------------------------------------------------------------
[docs] @classmethod def load( cls, directory: str | Path, prefix: str | None = None, ) -> SyngResult: """Load a previously saved ``SyngResult`` from disk. Parameters ---------- directory : str or Path Directory that contains the saved files. prefix : str or None The filename stem (everything before ``_generated.csv``). When ``None``, auto-detected from ``*_generated.csv`` files in the directory; exactly one match is required. Returns ------- SyngResult Reconstructed result with all available artifacts. Raises ------ FileNotFoundError If the required ``*_generated.csv`` or ``*_loss.csv`` file is missing. ValueError If *prefix* is ``None`` and zero or more than one ``*_generated.csv`` file is found (ambiguous). """ d = Path(directory) if prefix is None: candidates = sorted(d.glob("*_generated.csv")) if len(candidates) == 0: raise FileNotFoundError(f"No *_generated.csv files found in {d}") if len(candidates) > 1: stems = [c.name.removesuffix("_generated.csv") for c in candidates] looks_like_pilot_dir = any( "_pilot" in c.name and "_draw" in c.name for c in candidates ) if looks_like_pilot_dir: raise ValueError( "Multiple generated files found and directory appears to " f"contain PilotResult outputs: {stems}. " "SyngResult.load() loads one run at a time; pass prefix " "for a specific run stem, e.g. '<dataname>_pilot50_draw1_<model>'." ) raise ValueError( f"Multiple generated files found in {d}: {stems}. " "Specify 'prefix' to disambiguate." ) stem = candidates[0].name.removesuffix("_generated.csv") else: stem = prefix # --- Required files --- gen_path = d / f"{stem}_generated.csv" if not gen_path.exists(): raise FileNotFoundError(f"Required file not found: {gen_path}") generated_data = pd.read_csv(gen_path) loss_path = d / f"{stem}_loss.csv" if not loss_path.exists(): raise FileNotFoundError(f"Required file not found: {loss_path}") loss = pd.read_csv(loss_path) # --- Optional files --- recon_path = d / f"{stem}_reconstructed.csv" reconstructed_data = pd.read_csv(recon_path) if recon_path.exists() else None orig_path = d / f"{stem}_original.csv" original_data = ( pd.read_csv(orig_path, index_col=0) if orig_path.exists() else None ) model_path = d / f"{stem}_model.pt" model_state = ( torch.load(model_path, weights_only=False) if model_path.exists() else None ) # --- Group sidecar files --- def _load_groups(suffix: str) -> pd.Series | None: path = d / f"{stem}_{suffix}.csv" if not path.exists(): return None df = pd.read_csv(path) return df.iloc[:, 0].rename("group") original_groups = _load_groups("original_groups") generated_groups = _load_groups("generated_groups") reconstructed_groups = _load_groups("reconstructed_groups") meta_path = d / f"{stem}_metadata.json" if meta_path.exists(): metadata = json.loads(meta_path.read_text(encoding="utf-8")) # Restore tuples that were serialised as lists if "input_shape" in metadata and isinstance(metadata["input_shape"], list): metadata["input_shape"] = tuple(metadata["input_shape"]) # Restore group_mapping keys from JSON strings to ints if "group_mapping" in metadata and isinstance( metadata["group_mapping"], dict ): metadata["group_mapping"] = { int(k): v for k, v in metadata["group_mapping"].items() } else: metadata = {} return cls( generated_data=generated_data, loss=loss, reconstructed_data=reconstructed_data, original_data=original_data, model_state=model_state, metadata=metadata, original_groups=original_groups, generated_groups=generated_groups, reconstructed_groups=reconstructed_groups, )
[docs] @dataclass class PilotResult: """Result of a pilot study run across multiple pilot sizes and draws. Attributes ---------- runs : dict[tuple[int, int], SyngResult] Mapping of ``(pilot_size, draw_index)`` → individual run result. ``draw_index`` is 1-based (1 through 5). original_data : pd.DataFrame or None The full original input data (before subsetting). metadata : dict Shared metadata across all runs (model, data dimensions, etc.). Examples -------- >>> result = pilot_study(data="SKCMPositive_4", pilot_size=[50, 100], ...) >>> result.runs[(50, 1)].generated_data.head() >>> result.save("./pilot_output/") """ runs: dict[tuple[int, int], SyngResult] = field(default_factory=dict) original_data: pd.DataFrame | None = None metadata: dict[str, Any] = field(default_factory=dict) # ------------------------------------------------------------------ # Convenience methods # ------------------------------------------------------------------
[docs] def save( self, output_dir: str | Path, prefix: str | None = None, ) -> dict[tuple[int, int], dict[str, Path]]: """Save all individual run results to *output_dir*. Each run is saved with a filename that encodes the pilot size and draw index. Parameters ---------- output_dir : str or Path Directory to write files into (created if it does not exist). prefix : str or None Optional filename prefix. Falls back to ``metadata["dataname"]`` or ``"syng"``. Returns ------- dict[tuple[int, int], dict[str, Path]] Nested mapping: ``(pilot_size, draw) → {output_type → path}``. """ out = Path(output_dir) out.mkdir(parents=True, exist_ok=True) if prefix is None: prefix = self.metadata.get("dataname", "syng") all_paths: dict[tuple[int, int], dict[str, Path]] = {} for (pilot_size, draw), result in sorted(self.runs.items()): run_prefix = f"{prefix}_pilot{pilot_size}_draw{draw}" all_paths[(pilot_size, draw)] = result.save(out, prefix=run_prefix) # Save top-level original data if self.original_data is not None: orig_path = out / f"{prefix}_original.csv" self.original_data.to_csv(orig_path, index=True) # Save top-level metadata if self.metadata: meta_path = out / f"{prefix}_pilot_metadata.json" meta_path.write_text( json.dumps(self.metadata, indent=2, default=_json_serializable), encoding="utf-8", ) return all_paths
[docs] def plot_loss( self, style: str = "overlay_runs", running_average_window: int = 25, x_axis: str = "epochs", truncate: bool = True, ) -> dict[tuple[int, int], dict[str, plt.Figure]] | dict[str, plt.Figure]: """Plot loss curves for every run in the pilot study. Parameters ---------- style : str Plotting style for loss trajectories. - ``"per_run"`` (default): one figure per run per loss column, delegating to :meth:`SyngResult.plot_loss`. - ``"overlay_runs"``: overlay all runs on the same plot for each loss column. Only the running-average line is drawn per run (no raw trace) to keep the plot readable. - ``"mean_band"``: plot the mean loss trajectory across all runs for each loss column, with a shaded ±1 std band. Mean and std are computed on raw loss values; the mean line is then optionally smoothed with a running average. For all styles, y-axis scaling is applied to reduce the effect of large initial spikes (analogous to :meth:`SyngResult.plot_loss`). running_average_window : int Window size for the running-average overlay. Must be > 0. Default: 25. x_axis : str ``"epochs"`` (default) maps the x-axis to epoch space using each run's ``metadata["epochs_trained"]``. ``"iterations"`` numbers data points 0…N-1. truncate : bool Only relevant for ``style="mean_band"`` and ``style="overlay_runs"``. When ``True`` (default), only epochs/iterations common to **all** runs are plotted (truncated to the shortest run). When ``False``, all epochs/iterations are plotted; statistics are computed from whichever runs still have data at each point. Returns ------- dict[tuple[int, int], dict[str, Figure]] or dict[str, Figure] ``style="per_run"``: nested dict keyed by ``(pilot_size, draw)`` → ``{column: Figure}``. ``style="overlay_runs"`` or ``style="mean_band"``: flat dict ``{column: Figure}``. Raises ------ ValueError If *style* is not one of the accepted values, if *running_average_window* ≤ 0, or if *x_axis* is invalid. Examples -------- >>> figs = pilot_result.plot_loss(style="overlay_runs") >>> figs = pilot_result.plot_loss(style="mean_band", truncate=False) """ # --- Input validation --- valid_styles = ("per_run", "overlay_runs", "mean_band") if style not in valid_styles: raise ValueError(f"style must be one of {valid_styles!r}, got {style!r}") if running_average_window <= 0: raise ValueError( f"running_average_window must be > 0, got {running_average_window}" ) if x_axis not in ("iterations", "epochs"): raise ValueError(f"x_axis must be 'iterations' or 'epochs', got {x_axis!r}") # --- style="per_run": delegate to SyngResult.plot_loss() --- if style == "per_run": return { key: result.plot_loss( running_average_window=running_average_window, x_axis=x_axis, ) for key, result in sorted(self.runs.items()) } # --- Shared helpers for "overlay_runs" and "mean_band" --- def _collect_loss_columns() -> list[str]: """Return the ordered union of loss column names across runs.""" cols: list[str] = [] seen: set[str] = set() for result in self.runs.values(): for col in result.loss.columns: if col not in seen: cols.append(col) seen.add(col) return cols def _resolve_epochs(result: SyngResult, key: tuple[int, int]) -> float: """Validate and return ``epochs_trained`` for a single run.""" raw = result.metadata.get("epochs_trained") if raw is None or isinstance(raw, bool): raise ValueError( f"x_axis='epochs' requires metadata['epochs_trained'] > 0 " f"for run {key}, got {raw!r}" ) try: val = float(raw) except (TypeError, ValueError) as exc: raise ValueError( f"x_axis='epochs' requires metadata['epochs_trained'] > 0 " f"for run {key}, got {raw!r}" ) from exc if val <= 0: raise ValueError( f"x_axis='epochs' requires metadata['epochs_trained'] > 0 " f"for run {key}, got {raw!r}" ) return val def _build_x(length: int, num_epochs: float | None) -> np.ndarray: """Build x-coordinate array.""" if x_axis == "epochs" and num_epochs is not None: return np.linspace(0, num_epochs, length) return np.arange(length) def _apply_ylim_scaling(ax, values: np.ndarray) -> None: """Set y-axis limits to suppress the initial spike.""" n = len(values) skip = n // 2 if n < 1001 else 1000 if n > skip: later_max = float(np.nanmax(values[skip:])) if later_max > 0: ax.set_ylim([0, later_max * 1.5]) def _later_max_for_scaling(values: np.ndarray) -> float | None: """Return post-spike max used for y-axis scaling, if available.""" n = len(values) skip = n // 2 if n < 1001 else 1000 if n <= skip: return None later_max = float(np.nanmax(values[skip:])) if later_max <= 0: return None return later_max all_columns = _collect_loss_columns() # --- style="overlay_runs": overlay running-average per run --- if style == "overlay_runs": figures: dict[str, plt.Figure] = {} cmap = plt.colormaps["tab10"] kernel = np.ones(running_average_window) / running_average_window for col in all_columns: fig, ax = plt.subplots() colour_idx = 0 later_max_values: list[float] = [] eligible_runs: list[tuple[tuple[int, int], SyngResult]] = [ (key, result) for key, result in sorted(self.runs.items()) if col in result.loss.columns ] if not eligible_runs: ax.set_xlabel("Epochs" if x_axis == "epochs" else "Iterations") ax.set_ylabel("Loss") ax.set_title(f"{col} loss (overlay_runs)") fig.tight_layout() figures[col] = fig continue # If truncating, use the common prefix length across runs. truncate_len: int | None = None common_epochs: float | None = None if truncate: truncate_len = min( len(result.loss[col]) for _key, result in eligible_runs ) if x_axis == "epochs": common_epochs = min( _resolve_epochs(result, key) for key, result in eligible_runs ) for (ps, draw), result in eligible_runs: values = result.loss[col].to_numpy() if truncate_len is not None: values = values[:truncate_len] num_epochs: float | None = None if x_axis == "epochs": if truncate and common_epochs is not None: num_epochs = common_epochs else: num_epochs = _resolve_epochs(result, (ps, draw)) x_full = _build_x(len(values), num_epochs) colour = cmap(colour_idx % 10) colour_idx += 1 # Only plot running-average values when possible; otherwise # fall back to raw values (short series). if running_average_window <= len(values): smoothed = np.convolve(values, kernel, mode="valid") offset = running_average_window - 1 ax.plot( x_full[offset:], smoothed, alpha=0.7, color=colour, label=f"pilot={ps} draw={draw}", ) later_max = _later_max_for_scaling(smoothed) if later_max is not None: later_max_values.append(later_max) else: ax.plot( x_full, values, alpha=0.7, color=colour, label=f"pilot={ps} draw={draw} (raw)", ) later_max = _later_max_for_scaling(values) if later_max is not None: later_max_values.append(later_max) ax.set_xlabel("Epochs" if x_axis == "epochs" else "Iterations") ax.set_ylabel("Loss") ax.set_title(f"{col} loss (overlay_runs)") ax.legend(fontsize="x-small") # Y-axis scaling: ignore each run's initial spike, then combine if later_max_values: ax.set_ylim([0, max(later_max_values) * 1.5]) fig.tight_layout() figures[col] = fig return figures # --- style="mean_band": mean ± std across runs --- figures = {} kernel = np.ones(running_average_window) / running_average_window for col in all_columns: # Gather raw loss arrays for this column arrays: list[np.ndarray] = [] epoch_values: list[float] = [] for (ps, draw), result in sorted(self.runs.items()): if col not in result.loss.columns: continue arrays.append(result.loss[col].to_numpy()) if x_axis == "epochs": epoch_values.append(_resolve_epochs(result, (ps, draw))) if not arrays: continue # Determine target length and stack lengths = [len(a) for a in arrays] if truncate: target_len = min(lengths) stacked = np.array([a[:target_len] for a in arrays]) else: target_len = max(lengths) stacked = np.full((len(arrays), target_len), np.nan) for i, a in enumerate(arrays): stacked[i, : len(a)] = a mean_vals = np.nanmean(stacked, axis=0) std_vals = np.nanstd(stacked, axis=0) # Build x-coordinates if x_axis == "epochs": # Use the epoch range corresponding to the target length if truncate: ref_epochs = min(epoch_values) else: ref_epochs = max(epoch_values) x_full = np.linspace(0, ref_epochs, target_len) else: x_full = np.arange(target_len) fig, ax = plt.subplots() # Shaded std band (on raw values) ax.fill_between( x_full, mean_vals - std_vals, mean_vals + std_vals, alpha=0.3, color="tab:blue", label="±1 std", ) # Smoothed mean line if running_average_window <= target_len: smoothed_mean = np.convolve(mean_vals, kernel, mode="valid") offset = running_average_window - 1 ax.plot( x_full[offset:], smoothed_mean, color="tab:blue", label=f"mean (avg w={running_average_window})", ) else: # Window too large for smoothing; plot raw mean ax.plot( x_full, mean_vals, color="tab:blue", label="mean", ) ax.set_xlabel("Epochs" if x_axis == "epochs" else "Iterations") ax.set_ylabel("Loss") n_runs = stacked.shape[0] ax.set_title(f"{col} loss (mean_band, n={n_runs})") ax.legend() # Y-axis scaling from mean values _apply_ylim_scaling(ax, mean_vals) fig.tight_layout() figures[col] = fig return figures
[docs] def summary(self) -> str: """Return an aggregate summary of all pilot runs. Returns ------- str Multi-line summary with one line per run. """ lines = [f"PilotResult: {len(self.runs)} runs"] model = self.metadata.get("model", "?") lines.append(f"Model: {model}") pilot_sizes = sorted({ps for ps, _ in self.runs}) lines.append(f"Pilot sizes: {pilot_sizes}") if self.original_data is not None: r, c = self.original_data.shape lines.append(f"Original data: {r} rows × {c} cols") for key in sorted(self.runs): r = self.runs[key] n_gen = r.generated_data.shape[0] final_losses = { col: f"{r.loss[col].iloc[-1]:.4f}" for col in r.loss.columns } loss_str = ", ".join(f"{k}={v}" for k, v in final_losses.items()) lines.append( f" pilot={key[0]}, draw={key[1]}: " f"{n_gen} generated, final loss: {loss_str}" ) return "\n".join(lines)
def __repr__(self) -> str: n_runs = len(self.runs) pilot_sizes = sorted({ps for ps, _ in self.runs}) model = self.metadata.get("model", "?") has_original = self.original_data is not None return ( f"PilotResult(model={model!r}, n_runs={n_runs}, " f"pilot_sizes={pilot_sizes}, has_original={has_original})" )