from __future__ import annotations
from pathlib import Path
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns
import umap.umap_ as umap # noqa: F401
from matplotlib.figure import Figure
from umap import UMAP
from .data_utils import resolve_data
def _coerce_groups(
groups: pd.Series | np.ndarray | list | tuple | pd.Index | None,
*,
param_name: str,
expected_len: int,
) -> pd.Series | None:
"""Coerce optional group labels to a length-validated Series."""
if groups is None:
return None
if isinstance(groups, pd.Series):
coerced = groups.reset_index(drop=True)
elif isinstance(groups, (np.ndarray, list, tuple, pd.Index)):
coerced = pd.Series(groups)
else:
raise TypeError(
f"{param_name} must be a pandas Series, numpy array, list, tuple, "
f"or None; got {type(groups).__name__}."
)
if len(coerced) != expected_len:
raise ValueError(
f"{param_name} length ({len(coerced)}) does not match the number "
f"of rows in the corresponding dataset ({expected_len})."
)
return coerced.astype(str)
[docs]
def heatmap_eval(
    real_data: pd.DataFrame,
    generated_data: pd.DataFrame | None = None,
    *,
    apply_log: bool = True,
    cmap: str = "YlGnBu",
) -> Figure:
    r"""Draw heatmap(s) of real data and, optionally, generated data.

    With only *real_data*, a single heatmap is drawn.  With both datasets,
    two heatmaps are drawn side by side: generated data on the left and
    real data (with the colour bar) on the right.

    Parameters
    ----------
    real_data : pd.DataFrame
        The original/real data.  Only numeric columns are plotted.
    generated_data : pd.DataFrame or None, optional
        The generated/synthetic data.  Only numeric columns are plotted.
        If ``None``, only *real_data* is shown.
    apply_log : bool, default True
        Whether to apply ``log2(x + 1)`` to each plotted dataset before
        drawing.
    cmap : str, default ``"YlGnBu"``
        Colormap passed to :func:`seaborn.heatmap`.

    Returns
    -------
    Figure
        The matplotlib Figure containing the heatmap(s).
    """
    # Keep numeric columns only; copies leave the caller's frames untouched.
    numeric_real = real_data.select_dtypes(include=["number"]).copy()
    numeric_gen = None
    if generated_data is not None:
        numeric_gen = generated_data.select_dtypes(include=["number"]).copy()

    if apply_log:
        numeric_real = np.log2(numeric_real + 1)
        if numeric_gen is not None:
            numeric_gen = np.log2(numeric_gen + 1)

    if numeric_gen is None:
        fig = plt.figure(figsize=(6, 6))
        # No ``ax`` is given, so seaborn draws on the current axes, i.e. on
        # the figure created just above.
        real_ax = sns.heatmap(numeric_real, cbar=True, cmap=cmap)
        panels = [(real_ax, "Real Data")]
    else:
        # The right panel is slightly wider to leave room for its colour bar.
        fig, (gen_ax, real_ax) = plt.subplots(
            ncols=2,
            figsize=(12, 6),
            gridspec_kw={"width_ratios": [0.5, 0.55]},
        )
        sns.heatmap(numeric_gen, ax=gen_ax, cbar=False, cmap=cmap)
        sns.heatmap(numeric_real, ax=real_ax, cbar=True, cmap=cmap)
        panels = [(gen_ax, "Generated Data"), (real_ax, "Real Data")]

    for panel_ax, panel_title in panels:
        panel_ax.set_title(panel_title)
        panel_ax.set_xlabel("Features")
        panel_ax.set_ylabel("Samples")

    fig.tight_layout()
    return fig
[docs]
def UMAP_eval(
    real_data: pd.DataFrame,
    generated_data: pd.DataFrame | None = None,
    *,
    apply_log: bool = True,
    groups_real: pd.Series | np.ndarray | list | tuple | pd.Index | None = None,
    groups_generated: pd.Series | np.ndarray | list | tuple | pd.Index | None = None,
    random_seed: int = 42,
    legend_pos: str = "best",
) -> Figure:
    r"""Create a UMAP visualization comparing real and generated data.

    Uses UMAP dimensionality reduction to visualize high-dimensional
    data in 2D, with optional group styling.

    Parameters
    ----------
    real_data : pd.DataFrame
        The original/real data.  Only numeric columns are used.
    generated_data : pd.DataFrame or None, optional
        The generated/synthetic data.  If ``None``, only *real_data* is
        visualised.  Otherwise its first ``k`` columns (``k`` being the
        number of numeric columns in *real_data*) are taken by position and
        relabelled with the real data's numeric column names; features with
        zero variance in the generated data are dropped from both datasets
        before fitting.
    apply_log : bool, default True
        Whether to apply ``log2(x + 1)`` transformation to both real
        and generated data before dimensionality reduction.
    groups_real : pd.Series, np.ndarray, list, tuple, pd.Index, or None, optional
        Group labels for real samples, one per row.  Used for marker style.
    groups_generated : pd.Series, np.ndarray, list, tuple, pd.Index, or None, optional
        Group labels for generated samples, one per row.  Used for marker
        style.  When *generated_data* is given, group styling is applied
        only if **both** *groups_real* and *groups_generated* are provided;
        otherwise group labels are ignored.
    random_seed : int, default 42
        Random seed for UMAP reproducibility.
    legend_pos : str, default ``"best"``
        Legend position (``"best"``, ``"upper right"``, ``"lower left"``, …).

    Returns
    -------
    Figure
        The matplotlib Figure containing the UMAP scatter plot.

    Raises
    ------
    TypeError
        If group labels that are used for styling are of an unsupported
        container type.
    ValueError
        If group labels that are used for styling do not have one entry per
        row of the corresponding dataset.
    """
    # Select numeric columns and apply log transformation if requested.
    real_data_processed = real_data.select_dtypes(include=[np.number]).copy()
    if apply_log:
        real_data_processed = np.log2(real_data_processed + 1)

    if generated_data is None:
        # Normalise and validate labels before fitting UMAP so that bad
        # input fails early with a clear message.
        real_labels = _coerce_groups(
            groups_real,
            param_name="groups_real",
            expected_len=len(real_data_processed),
        )

        reducer = UMAP(random_state=random_seed)
        embedding = reducer.fit_transform(real_data_processed.values)
        umap_df = pd.DataFrame(embedding, columns=["UMAP1", "UMAP2"])

        fig, ax = plt.subplots(figsize=(10, 8))
        if real_labels is not None:
            umap_df["Group"] = real_labels.to_numpy()
            # NOTE(review): no ``hue`` is set here, so ``palette`` likely has
            # no visible effect — confirm whether hue="Group" was intended.
            sns.scatterplot(
                data=umap_df,
                x="UMAP1",
                y="UMAP2",
                style="Group",
                palette="bright",
                ax=ax,
            )
            ax.legend(title="Group", loc=legend_pos)
            ax.set_title("UMAP Projection of Real Data with Groups")
        else:
            ax.scatter(umap_df["UMAP1"], umap_df["UMAP2"], alpha=0.7)
            ax.set_title("UMAP Projection of Real Data")
        return fig

    # Align generated data with the real numeric columns by position.
    gen_data_processed = generated_data.iloc[:, : real_data_processed.shape[1]].copy()
    gen_data_processed.columns = real_data_processed.columns
    if apply_log:
        gen_data_processed = np.log2(gen_data_processed + 1)

    # Filter out features with zero variance in generated data.
    non_zero_var_cols = gen_data_processed.var(axis=0) != 0
    real_filtered = real_data_processed.loc[:, non_zero_var_cols]
    gen_filtered = gen_data_processed.loc[:, non_zero_var_cols]

    # Group styling needs labels for both datasets; a one-sided set of
    # labels is ignored (and therefore not validated).
    combined_groups = None
    if groups_real is not None and groups_generated is not None:
        real_labels = _coerce_groups(
            groups_real,
            param_name="groups_real",
            expected_len=real_filtered.shape[0],
        )
        gen_labels = _coerce_groups(
            groups_generated,
            param_name="groups_generated",
            expected_len=gen_filtered.shape[0],
        )
        combined_groups = list(real_labels) + list(gen_labels)

    # Fit one embedding on the stacked data so both sets share a space.
    combined_data = np.vstack((real_filtered.values, gen_filtered.values))
    combined_labels = np.array(
        ["Real"] * real_filtered.shape[0] + ["Generated"] * gen_filtered.shape[0]
    )
    reducer = UMAP(random_state=random_seed)
    embedding = reducer.fit_transform(combined_data)
    umap_df = pd.DataFrame(embedding, columns=["UMAP1", "UMAP2"])
    umap_df["Data Type"] = combined_labels

    fig, ax = plt.subplots(figsize=(10, 8))
    if combined_groups is not None:
        umap_df["Group"] = combined_groups
        sns.scatterplot(
            data=umap_df,
            x="UMAP1",
            y="UMAP2",
            hue="Data Type",
            style="Group",
            palette="bright",
            ax=ax,
        )
        ax.legend(title="Data Type / Group", loc=legend_pos)
        ax.set_title("UMAP Projection of Real and Generated Data with Groups")
    else:
        sns.scatterplot(
            data=umap_df,
            x="UMAP1",
            y="UMAP2",
            hue="Data Type",
            palette="bright",
            ax=ax,
        )
        ax.legend(title="Data Type", loc=legend_pos)
        ax.set_title("UMAP Projection of Real and Generated Data")
    return fig
[docs]
def _head_tail_positions(n_rows: int, n_samples: int | None) -> list[int]:
    """Return positions of the first and last ``n_samples`` rows of a dataset.

    All positions are returned when *n_samples* is ``None`` or not smaller
    than *n_rows*.  Otherwise the per-end count is capped at ``n_rows // 2``
    so the head and tail selections never overlap.
    """
    if n_samples is None or n_samples >= n_rows:
        return list(range(n_rows))
    n_each = min(n_samples, n_rows // 2)
    return [*range(n_each), *range(n_rows - n_each, n_rows)]


def evaluation(
    real_data: pd.DataFrame | str | Path,
    generated_data: pd.DataFrame | str | Path,
    *,
    real_groups: pd.Series | np.ndarray | list | tuple | pd.Index | None = None,
    generated_groups: pd.Series | np.ndarray | list | tuple | pd.Index | None = None,
    n_samples: int | None = 200,
    apply_log: bool = True,
    random_seed: int = 42,
) -> dict[str, Figure]:
    r"""Preprocessing and visualization of generated vs real data.

    Loads and preprocesses the input data, then creates heatmap and UMAP
    visualizations comparing generated and real datasets.

    Parameters
    ----------
    real_data : pd.DataFrame, str, or Path
        The original/real dataset.  Accepts a DataFrame, a file path, or
        a bundled dataset name (resolved via :func:`resolve_data`).
    generated_data : pd.DataFrame, str, or Path
        The generated/synthetic dataset.  Same input types as *real_data*.
        Its first ``k`` columns (``k`` being the number of numeric columns
        in the real dataset) are taken by position and relabelled with the
        real dataset's numeric column names.
    real_groups : pd.Series, np.ndarray, list, tuple, pd.Index, or None, optional
        Group labels for the real samples.  When provided, takes
        precedence over any bundled groups resolved from *real_data*.
        Values are used as-is for plot labels (converted to ``str``).
    generated_groups : pd.Series, np.ndarray, list, tuple, pd.Index, or None, optional
        Group labels for the generated samples.  When provided, takes
        precedence over any bundled groups resolved from *generated_data*.
        Values are used as-is for plot labels (converted to ``str``).
    n_samples : int or None, default 200
        Number of samples from each end of the dataset to use for
        visualization (to keep UMAP fast), i.e. up to ``2 * n_samples``
        rows per dataset.  A dataset with at most *n_samples* rows is used
        in full.  If ``None``, all samples are used.
    apply_log : bool, default True
        Whether to apply ``log2(x + 1)`` transformation to both real
        and generated data before comparison.
    random_seed : int, default 42
        Random seed for UMAP reproducibility.

    Returns
    -------
    dict[str, Figure]
        ``{"heatmap": <Figure>, "umap": <Figure>}`` — the two evaluation
        figures.  This function does not call ``plt.show()``; the caller
        decides when to call ``plt.show()`` or ``fig.savefig()``.

    Raises
    ------
    ValueError
        If *n_samples* is less than 1, or if explicitly passed group labels
        do not have one entry per row of the corresponding dataset.
    TypeError
        If explicitly passed group labels are of an unsupported type.
    """
    # A non-positive sample count would select zero rows and only fail later
    # inside the plotting code, so reject it up front.
    if n_samples is not None and n_samples < 1:
        raise ValueError(
            f"n_samples must be a positive integer or None; got {n_samples}."
        )

    real_df, bundled_groups_real = resolve_data(real_data)
    gen_df, bundled_groups_gen = resolve_data(generated_data)

    # --- Resolve group labels -----------------------------------------------
    # Precedence: explicit parameter > bundled groups > None
    groups_real = _coerce_groups(
        real_groups,
        param_name="real_groups",
        expected_len=len(real_df),
    )
    groups_generated = _coerce_groups(
        generated_groups,
        param_name="generated_groups",
        expected_len=len(gen_df),
    )
    # NOTE(review): bundled groups are assumed to be pandas Series with one
    # entry per row; their length is not validated here — confirm against
    # resolve_data.
    if groups_real is None and bundled_groups_real is not None:
        groups_real = bundled_groups_real.reset_index(drop=True).astype(str)
    if groups_generated is None and bundled_groups_gen is not None:
        groups_generated = bundled_groups_gen.reset_index(drop=True).astype(str)

    # --- Prepare numeric matrices -------------------------------------------
    real_numeric = real_df.select_dtypes(include=[np.number])
    gen_numeric = gen_df.iloc[:, : real_numeric.shape[1]].copy()
    gen_numeric.columns = real_numeric.columns

    # When apply_log is True, log-transform both real and generated data so
    # they are compared in the same (log2) scale.  Generated data is now
    # returned in count scale by the experiment API.
    if apply_log:
        real_numeric = np.log2(real_numeric + 1)
        gen_numeric = np.log2(gen_numeric + 1)

    # --- Sub-sample for speed -----------------------------------------------
    real_idx = _head_tail_positions(len(real_numeric), n_samples)
    gen_idx = _head_tail_positions(len(gen_numeric), n_samples)

    real_sub = real_numeric.iloc[real_idx]
    gen_sub = gen_numeric.iloc[gen_idx]
    groups_real_sub = groups_real.iloc[real_idx] if groups_real is not None else None
    groups_gen_sub = (
        groups_generated.iloc[gen_idx] if groups_generated is not None else None
    )

    # --- Produce figures ----------------------------------------------------
    # The data was already log-transformed above (if requested), so the
    # plotting helpers must not transform it a second time.
    fig_heatmap = heatmap_eval(
        real_data=real_sub, generated_data=gen_sub, apply_log=False
    )
    fig_umap = UMAP_eval(
        real_data=real_sub,
        generated_data=gen_sub,
        apply_log=False,
        groups_real=groups_real_sub,
        groups_generated=groups_gen_sub,
        random_seed=random_seed,
    )
    return {"heatmap": fig_heatmap, "umap": fig_umap}