Source code for syng_bts.helper_models

from __future__ import annotations

# %%
import torch
import torch.nn as nn

# %%


[docs] class AE(nn.Module):
[docs] def __init__(self, num_features): super().__init__() self.encoder = nn.Sequential( nn.Linear(int(num_features), 256), nn.ReLU(True), nn.Linear(256, 128), nn.ReLU(True), nn.Linear(128, 64), nn.ReLU(True), ) self.decoder = nn.Sequential( nn.Linear(64, 128), nn.ReLU(True), nn.Linear(128, 256), nn.ReLU(True), nn.Linear(256, int(num_features)), nn.ReLU(True), )
[docs] def forward(self, x): encoded = self.encoder(x) decoded = self.decoder(encoded) return encoded, decoded
# %%
class VAE(nn.Module):
    """Fully connected variational autoencoder with a 32-dimensional latent.

    The encoder maps ``num_features -> 256 -> 128 -> 64`` using
    Linear + BatchNorm1d + in-place ReLU blocks.  Two linear heads project
    the 64-wide encoder output to the latent mean and log-variance (32 each).
    The decoder maps ``32 -> 64 -> 128 -> 256 -> num_features``; its final
    linear layer is followed by a ReLU, so reconstructions are non-negative.
    """

    def __init__(self, num_features):
        """Build the encoder, latent heads and decoder.

        Args:
            num_features: Width of one input sample; coerced with ``int()``.
        """
        super().__init__()
        n_in = int(num_features)

        self.encoder = nn.Sequential(
            nn.Linear(n_in, 256),
            nn.BatchNorm1d(256),
            nn.ReLU(True),
            nn.Linear(256, 128),
            nn.BatchNorm1d(128),
            nn.ReLU(True),
            nn.Linear(128, 64),
            nn.BatchNorm1d(64),
            nn.ReLU(True),
        )

        self.z_mean = nn.Linear(64, 32)
        self.z_log_var = nn.Linear(64, 32)

        self.decoder = nn.Sequential(
            nn.Linear(32, 64),
            nn.BatchNorm1d(64),
            nn.ReLU(True),
            nn.Linear(64, 128),
            nn.BatchNorm1d(128),
            nn.ReLU(True),
            nn.Linear(128, 256),
            nn.BatchNorm1d(256),
            nn.ReLU(True),
            nn.Linear(256, n_in),
            nn.ReLU(True),
        )

    def reparameterize(self, z_mu, z_log_var, deterministic=False):
        """Sample a latent vector with the reparameterisation trick.

        Args:
            z_mu: Latent mean.
            z_log_var: Latent log-variance, same shape as ``z_mu``.
            deterministic: When true, skip sampling and return ``z_mu``.

        Returns:
            ``z_mu`` itself when ``deterministic`` is true, otherwise
            ``z_mu + eps * exp(z_log_var / 2)`` with standard-normal ``eps``.
        """
        if deterministic:
            return z_mu
        # randn_like draws the noise on z_mu's device and in its dtype; a bare
        # torch.randn(...) would always allocate on the default device and
        # break once the model is moved to an accelerator.
        eps = torch.randn_like(z_mu)
        return z_mu + eps * torch.exp(z_log_var / 2.0)

    def _encode(self, x, deterministic=False):
        """Run the encoder and heads; return ``(z_mean, z_log_var, encoded)``."""
        hidden = self.encoder(x)
        z_mean = self.z_mean(hidden)
        z_log_var = self.z_log_var(hidden)
        encoded = self.reparameterize(z_mean, z_log_var, deterministic)
        return z_mean, z_log_var, encoded

    def encoding_fn(self, x, deterministic=False):
        """Return only the latent encoding of ``x``.

        Args:
            x: Input tensor with ``num_features`` columns.
            deterministic: When true, return the latent mean without noise.

        Returns:
            The (sampled or deterministic) 32-wide latent tensor.
        """
        _, _, encoded = self._encode(x, deterministic)
        return encoded

    def forward(self, x, deterministic=False):
        """Encode ``x``, sample the latent and reconstruct.

        Args:
            x: Input tensor with ``num_features`` columns.
            deterministic: When true, use the latent mean without noise.

        Returns:
            Tuple ``(encoded, z_mean, z_log_var, decoded)``.
        """
        z_mean, z_log_var, encoded = self._encode(x, deterministic)
        decoded = self.decoder(encoded)
        return encoded, z_mean, z_log_var, decoded
# %%
class CVAE(nn.Module):
    """Conditional VAE that appends a one-column label to input and latent.

    The encoder consumes ``num_features + 1`` columns (features plus a
    label of dimension 1) and the decoder consumes ``32 + 1`` columns
    (latent plus the same label).  The decoder hidden widths are the
    encoder widths reversed, and its last linear layer is followed by a
    ReLU.
    """

    def __init__(self, num_features, num_classes, wide_network=False):
        """Build encoder, latent heads and decoder.

        Args:
            num_features: Number of feature columns in one sample.
            num_classes: Stored on the instance; not used when building
                the layers here.
            wide_network: ``True`` selects the wider encoder
                (512 -> 256 -> 128 -> 64), suited to high-dimensional data
                such as RNA; ``False`` keeps the original structure
                (256 -> 128 -> 64), suited to miRNA data.
        """
        super().__init__()
        self.num_features = num_features
        self.num_classes = num_classes
        self.wide_network = wide_network

        enc_hidden = (512, 256, 128, 64) if wide_network else (256, 128, 64)

        # Encoder input is the features plus the single label column.
        self.encoder = nn.Sequential(
            *self._dense_blocks(num_features + 1, enc_hidden)
        )

        bottleneck_dim = enc_hidden[-1]
        self.z_mean = nn.Linear(bottleneck_dim, 32)
        self.z_log_var = nn.Linear(bottleneck_dim, 32)

        # Decoder mirrors the encoder; its input is latent + label column.
        dec_hidden = enc_hidden[::-1]
        dec_modules = self._dense_blocks(32 + 1, dec_hidden)
        dec_modules.append(nn.Linear(dec_hidden[-1], num_features))
        dec_modules.append(nn.ReLU(True))
        self.decoder = nn.Sequential(*dec_modules)

    @staticmethod
    def _dense_blocks(in_dim, widths):
        """Return a flat list of Linear/BatchNorm1d/ReLU triples."""
        blocks = []
        for width in widths:
            blocks.append(nn.Linear(in_dim, width))
            blocks.append(nn.BatchNorm1d(width))
            blocks.append(nn.ReLU(True))
            in_dim = width
        return blocks

    def reparameterize(self, z_mu, z_log_var, deterministic=False):
        """Sample a latent with the reparameterisation trick.

        Returns ``z_mu`` unchanged when ``deterministic`` is true, otherwise
        ``z_mu + eps * exp(z_log_var / 2)`` with standard-normal ``eps``
        drawn via ``torch.randn_like(z_mu)``.
        """
        if deterministic:
            return z_mu
        noise = torch.randn_like(z_mu)
        std = torch.exp(z_log_var / 2.0)
        return z_mu + noise * std

    def encoding_fn(self, x, y, deterministic=False):
        """Encode features ``x`` conditioned on label ``y``.

        ``x`` and ``y`` are concatenated along dim 1 before the encoder.

        Returns:
            Tuple ``(z_mean, z_log_var, encoded)``.
        """
        hidden = self.encoder(torch.cat((x, y), dim=1))
        z_mean = self.z_mean(hidden)
        z_log_var = self.z_log_var(hidden)
        return z_mean, z_log_var, self.reparameterize(z_mean, z_log_var, deterministic)

    def decoding_fn(self, encoded, y):
        """Decode a latent tensor conditioned on label ``y``.

        ``encoded`` and ``y`` are concatenated along dim 1 before the decoder.
        """
        conditioned = torch.cat((encoded, y), dim=1)
        return self.decoder(conditioned)

    def forward(self, x, y, deterministic=False):
        """Run the full encode/decode pass.

        Returns:
            Tuple ``(encoded, z_mean, z_log_var, decoded)``.
        """
        z_mean, z_log_var, encoded = self.encoding_fn(x, y, deterministic)
        return encoded, z_mean, z_log_var, self.decoding_fn(encoded, y)
# %%
class GAN(torch.nn.Module):
    """Container for a fully connected generator and discriminator.

    The generator maps ``latent_dim -> 128 -> 256 -> num_features`` with an
    in-place ReLU after every linear layer (including the last).  The
    discriminator maps ``num_features -> 256 -> 128 -> 1`` and leaves its
    single output unactivated.  No ``forward`` is defined; use
    ``generator_forward`` and ``discriminator_forward``.
    """

    def __init__(self, num_features, latent_dim=32):
        """Build both sub-networks.

        Args:
            num_features: Width of a generated / discriminated sample.
            latent_dim: Width of the generator's noise input.
        """
        super().__init__()
        self.num_features = num_features

        gen_layers = [
            nn.Linear(latent_dim, 128),
            nn.ReLU(True),
            nn.Linear(128, 256),
            nn.ReLU(True),
            nn.Linear(256, num_features),
            nn.ReLU(True),
        ]
        self.generator = nn.Sequential(*gen_layers)

        disc_layers = [
            nn.Linear(num_features, 256),
            nn.ReLU(True),
            nn.Linear(256, 128),
            nn.ReLU(True),
            nn.Linear(128, 1),
        ]
        self.discriminator = nn.Sequential(*disc_layers)

    def generator_forward(self, z):
        """Map low-dimensional noise ``z`` through the generator."""
        return self.generator(z)

    def discriminator_forward(self, img):
        """Return the discriminator's raw output (logits) for ``img``."""
        return self.discriminator(img)