From 8bffef4d0fc9ebc3b318f3d9489653b9bea0b52d Mon Sep 17 00:00:00 2001 From: jalengg Date: Wed, 25 Feb 2026 12:36:24 -0600 Subject: [PATCH 1/9] T1: add MedGAN source files and generators directory --- examples/generate_synthetic_mimic3_medgan.py | 163 +++++++ ...synthetic_data_generation_mimic3_medgan.py | 390 +++++++++++++++ pyhealth/models/generators/__init__.py | 0 pyhealth/models/generators/medgan.py | 447 ++++++++++++++++++ 4 files changed, 1000 insertions(+) create mode 100644 examples/generate_synthetic_mimic3_medgan.py create mode 100644 examples/synthetic_data_generation_mimic3_medgan.py create mode 100644 pyhealth/models/generators/__init__.py create mode 100644 pyhealth/models/generators/medgan.py diff --git a/examples/generate_synthetic_mimic3_medgan.py b/examples/generate_synthetic_mimic3_medgan.py new file mode 100644 index 000000000..3c163926a --- /dev/null +++ b/examples/generate_synthetic_mimic3_medgan.py @@ -0,0 +1,163 @@ +#!/usr/bin/env python3 +""" +Generate synthetic MIMIC-III patients using a trained MedGAN checkpoint. +Uses simple 0.5 threshold - MedGAN doesn't require post-processing. 
+""" + +import os +import argparse +import torch +import numpy as np +import pandas as pd +from pyhealth.models.generators.medgan import MedGAN + + +def main(): + parser = argparse.ArgumentParser(description="Generate synthetic patients using trained MedGAN") + parser.add_argument("--checkpoint", required=True, help="Path to trained MedGAN checkpoint (.pth)") + parser.add_argument("--vocab", required=True, help="Path to ICD-9 vocabulary file (.txt)") + parser.add_argument("--data_matrix", required=True, help="Path to training data matrix (.npy)") + parser.add_argument("--output", required=True, help="Path to output CSV file") + parser.add_argument("--n_samples", type=int, default=10000, help="Number of synthetic patients to generate") + parser.add_argument("--threshold", type=float, default=0.5, help="Threshold for binarization (binary mode only)") + + # Mode parameters + parser.add_argument("--data_mode", type=str, default="binary", choices=["binary", "count"], + help="Data mode: 'binary' (default) or 'count'") + parser.add_argument("--count_activation", type=str, default="relu", choices=["relu", "softplus"], + help="Activation for count mode: 'relu' (default) or 'softplus'") + parser.add_argument("--count_loss", type=str, default="mse", choices=["mse", "poisson"], + help="Loss function for count mode: 'mse' (default) or 'poisson'") + + args = parser.parse_args() + + device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + print(f"Using device: {device}") + print(f"Data mode: {args.data_mode}") + + # Load vocabulary + print(f"Loading vocabulary from {args.vocab}") + with open(args.vocab, 'r') as f: + code_vocab = [line.strip() for line in f] + print(f"Loaded {len(code_vocab)} ICD-9 codes") + + # Load data matrix to get architecture dimensions + print(f"Loading data matrix from {args.data_matrix}") + data_matrix = np.load(args.data_matrix) + n_codes = data_matrix.shape[1] + print(f"Data matrix shape: {data_matrix.shape}") + if args.data_mode == 
"binary": + print(f"Real data avg codes/patient: {data_matrix.sum(axis=1).mean():.2f}") + else: + print(f"Real data avg code occurrences/patient: {data_matrix.sum(axis=1).mean():.2f}") + print(f"Real data max count: {data_matrix.max():.0f}") + + # Load checkpoint + print(f"\nLoading checkpoint from {args.checkpoint}") + checkpoint = torch.load(args.checkpoint, map_location=device, weights_only=False) + + # Initialize MedGAN with same architecture + print("Initializing MedGAN model...") + if args.data_mode == "binary": + model = MedGAN.from_binary_matrix( + binary_matrix=data_matrix, + latent_dim=128, + autoencoder_hidden_dim=128, + discriminator_hidden_dim=256, + minibatch_averaging=True, + data_mode=args.data_mode + ).to(device) + else: # count mode + model = MedGAN.from_count_matrix( + count_matrix=data_matrix, + latent_dim=128, + autoencoder_hidden_dim=128, + discriminator_hidden_dim=256, + minibatch_averaging=True, + count_activation=args.count_activation, + count_loss=args.count_loss + ).to(device) + + # Load trained weights + model.autoencoder.load_state_dict(checkpoint['autoencoder_state_dict']) + model.generator.load_state_dict(checkpoint['generator_state_dict']) + model.discriminator.load_state_dict(checkpoint['discriminator_state_dict']) + + model.eval() + print("Model loaded successfully") + + # Generate synthetic patients + print(f"\nGenerating {args.n_samples} synthetic patients...") + + with torch.no_grad(): + # Generate data + synthetic_data = model.generate(args.n_samples, device) + + # Apply transform (binary threshold or count round+clip) + if args.data_mode == "binary": + discrete_data = model.sample_transform(synthetic_data, threshold=args.threshold) + else: + discrete_data = model.sample_transform(synthetic_data) + + data_matrix_synthetic = discrete_data.cpu().numpy() + + # Calculate statistics + avg_codes = data_matrix_synthetic.sum(axis=1).mean() + std_codes = data_matrix_synthetic.sum(axis=1).std() + min_codes = 
data_matrix_synthetic.sum(axis=1).min() + max_codes = data_matrix_synthetic.sum(axis=1).max() + sparsity = (data_matrix_synthetic == 0).mean() + + print(f"\nSynthetic data statistics:") + if args.data_mode == "binary": + print(f" Avg codes per patient: {avg_codes:.2f} ± {std_codes:.2f}") + else: + print(f" Avg code occurrences per patient: {avg_codes:.2f} ± {std_codes:.2f}") + print(f" Max count: {data_matrix_synthetic.max():.0f}") + print(f" Range: [{min_codes:.0f}, {max_codes:.0f}]") + print(f" Sparsity: {sparsity:.4f}") + + # Check heterogeneity + unique_profiles = len(set(tuple(row) for row in data_matrix_synthetic)) + print(f" Unique patient profiles: {unique_profiles}/{args.n_samples} ({unique_profiles/args.n_samples*100:.1f}%)") + + # Convert to CSV format (SUBJECT_ID, ICD9_CODE) + print(f"\nConverting to CSV format...") + records = [] + for patient_idx in range(args.n_samples): + patient_id = f"SYNTHETIC_{patient_idx+1:06d}" + + if args.data_mode == "binary": + # Binary mode: include codes where value == 1 + code_indices = np.where(data_matrix_synthetic[patient_idx] == 1)[0] + for code_idx in code_indices: + records.append({ + 'SUBJECT_ID': patient_id, + 'ICD9_CODE': code_vocab[code_idx] + }) + else: # count mode + # Count mode: repeat codes based on their counts + for code_idx in range(n_codes): + count = int(data_matrix_synthetic[patient_idx, code_idx]) + for _ in range(count): + records.append({ + 'SUBJECT_ID': patient_id, + 'ICD9_CODE': code_vocab[code_idx] + }) + + df = pd.DataFrame(records) + print(f"Created {len(df)} diagnosis records for {args.n_samples} patients") + + # Save to CSV + print(f"\nSaving to {args.output}") + df.to_csv(args.output, index=False) + + file_size_mb = os.path.getsize(args.output) / (1024 * 1024) + print(f"Saved {file_size_mb:.1f} MB") + + print("\n✓ Generation complete!") + print(f"Output: {args.output}") + + +if __name__ == '__main__': + main() diff --git a/examples/synthetic_data_generation_mimic3_medgan.py 
"""
Synthetic data generation using MedGAN on MIMIC-III data.

This example demonstrates how to train MedGAN to generate synthetic ICD-9
matrices from MIMIC-III data, following PyHealth conventions.

Usage:
    python examples/synthetic_data_generation_mimic3_medgan.py \
        --autoencoder_epochs 5 --gan_epochs 10 --batch_size 16
"""

import argparse
import os

import numpy as np
import torch
from torch.utils.data import DataLoader


def train_medgan(model, dataloader, n_epochs, device, save_dir, lr=0.001,
                 weight_decay=0.0001, b1=0.5, b2=0.9):
    """Train a MedGAN model (adversarial phase) using the synthEHRella recipe.

    Per batch: the generator (and, at a reduced learning rate, the shared
    autoencoder decoder) is updated first with the discriminator frozen, then
    the discriminator is updated on real and detached fake samples.

    Args:
        model: MedGAN model exposing ``.generator``, ``.discriminator``,
            ``.autoencoder.decoder`` / ``.autoencoder.decode`` and
            ``.latent_dim``.
        dataloader: iterable of real-data batches (float tensors).
        n_epochs: number of training epochs.
        device: device to train on.
        save_dir: directory for checkpoints and loss history; created if it
            does not exist.
        lr: Adam learning rate (discriminator and decoder use ``lr * 0.1``).
        weight_decay: L2 regularization.
        b1: Beta1 for Adam.
        b2: Beta2 for Adam.

    Returns:
        dict with ``'g_losses'`` and ``'d_losses'`` (per-epoch averages).
    """
    # Fix: checkpoints and loss_history.npy were previously written to a
    # directory that this function never created, crashing on the first save.
    os.makedirs(save_dir, exist_ok=True)

    def generator_loss(y_fake):
        """Original synthEHRella generator loss: -E[log D(G(z))]."""
        return -torch.mean(torch.log(y_fake + 1e-12))

    def discriminator_loss(outputs, labels):
        """Original synthEHRella discriminator loss (manual BCE with eps)."""
        return (-torch.mean(labels * torch.log(outputs + 1e-12))
                - torch.mean((1 - labels) * torch.log(1. - outputs + 1e-12)))

    # Generator optimizer also fine-tunes the shared decoder at a reduced rate.
    optimizer_g = torch.optim.Adam([
        {'params': model.generator.parameters()},
        {'params': model.autoencoder.decoder.parameters(), 'lr': lr * 0.1}
    ], lr=lr, betas=(b1, b2), weight_decay=weight_decay)

    optimizer_d = torch.optim.Adam(model.discriminator.parameters(),
                                   lr=lr * 0.1, betas=(b1, b2),
                                   weight_decay=weight_decay)

    g_losses = []
    d_losses = []

    print("=" * 60)
    print("Epoch | D_loss | G_loss | Progress")
    print("=" * 60)

    for epoch in range(n_epochs):
        epoch_g_loss = 0.0
        epoch_d_loss = 0.0
        num_batches = 0

        for real_data in dataloader:
            real_data = real_data.to(device)
            batch_size = real_data.size(0)

            valid = torch.ones(batch_size).to(device)  # 1D real labels
            fake = torch.zeros(batch_size).to(device)  # 1D fake labels
            z = torch.randn(batch_size, model.latent_dim).to(device)

            # ---------------------
            #  Train Generator
            # ---------------------
            # Freeze the discriminator so g_loss.backward() does not
            # accumulate gradients into it.
            for p in model.discriminator.parameters():
                p.requires_grad = False

            # Generate fake samples: latent -> generator -> decoder
            fake_samples = model.autoencoder.decode(model.generator(z))
            fake_output = model.discriminator(fake_samples).view(-1)
            g_loss = generator_loss(fake_output)

            optimizer_g.zero_grad()
            g_loss.backward()
            optimizer_g.step()

            # ---------------------
            #  Train Discriminator
            # ---------------------
            for p in model.discriminator.parameters():
                p.requires_grad = True

            optimizer_d.zero_grad()

            # Real samples
            real_output = model.discriminator(real_data).view(-1)
            real_loss = discriminator_loss(real_output, valid)
            real_loss.backward()

            # Fake samples, detached so only the discriminator gets gradients
            fake_output = model.discriminator(fake_samples.detach()).view(-1)
            fake_loss = discriminator_loss(fake_output, fake)
            fake_loss.backward()

            # Total discriminator loss (for logging only; backward already done)
            d_loss = (real_loss + fake_loss) / 2
            optimizer_d.step()

            epoch_g_loss += g_loss.item()
            epoch_d_loss += d_loss.item()
            num_batches += 1

        # calculate average losses
        avg_g_loss = epoch_g_loss / num_batches
        avg_d_loss = epoch_d_loss / num_batches

        # store losses for tracking
        g_losses.append(avg_g_loss)
        d_losses.append(avg_d_loss)

        progress = (epoch + 1) / n_epochs * 100
        print(f"{epoch+1:5d} | {avg_d_loss:.4f} | {avg_g_loss:.4f} | {progress:5.1f}%")

        # save a checkpoint every 50 epochs
        if (epoch + 1) % 50 == 0:
            checkpoint_path = os.path.join(save_dir, f"medgan_epoch_{epoch+1}.pth")
            torch.save({
                'epoch': epoch + 1,
                'generator_state_dict': model.generator.state_dict(),
                'discriminator_state_dict': model.discriminator.state_dict(),
                'autoencoder_state_dict': model.autoencoder.state_dict(),
                'optimizer_g_state_dict': optimizer_g.state_dict(),
                'optimizer_d_state_dict': optimizer_d.state_dict(),
                'g_losses': g_losses,
                'd_losses': d_losses,
            }, checkpoint_path)
            print(f"Checkpoint saved to {checkpoint_path}")

    print("=" * 60)
    print("GAN Training Completed!")
    if g_losses:  # guard against n_epochs == 0
        print(f"Final G_loss: {g_losses[-1]:.4f}")
        print(f"Final D_loss: {d_losses[-1]:.4f}")

    # save loss history
    loss_history = {
        'g_losses': g_losses,
        'd_losses': d_losses,
    }
    np.save(os.path.join(save_dir, "loss_history.npy"), loss_history)

    return loss_history


def main():
    """Train MedGAN on MIMIC-III diagnoses and export a synthetic matrix."""
    # Imported here so this module can be imported without pyhealth installed
    # (e.g. to reuse train_medgan standalone).
    from pyhealth.datasets import MIMIC3Dataset
    from pyhealth.datasets.icd9_matrix import create_icd9_matrix, ICD9MatrixDataset
    from pyhealth.models.generators.medgan import MedGAN

    parser = argparse.ArgumentParser(description="Train MedGAN for synthetic data generation")
    parser.add_argument("--data_path", type=str, default="./data_files", help="path to MIMIC-III data")
    parser.add_argument("--output_path", type=str, default="./medgan_results", help="Output directory")
    parser.add_argument("--autoencoder_epochs", type=int, default=100, help="Autoencoder pretraining epochs")
    parser.add_argument("--gan_epochs", type=int, default=1000, help="GAN training epochs")
    parser.add_argument("--latent_dim", type=int, default=128, help="Latent dimension")
    parser.add_argument("--hidden_dim", type=int, default=128, help="Hidden dimension")
    parser.add_argument("--batch_size", type=int, default=128, help="Batch size")
    parser.add_argument("--lr", type=float, default=0.001, help="adam: learning rate")
    parser.add_argument("--weight_decay", type=float, default=0.0001, help="l2 regularization")
    parser.add_argument("--b1", type=float, default=0.5, help="adam: decay of first order momentum of gradient")
    parser.add_argument("--b2", type=float, default=0.9, help="adam: decay of second order momentum of gradient")
    parser.add_argument("--save_dir", type=str, default="medgan_results", help="directory to save results")
    args = parser.parse_args()

    # setup
    os.makedirs(args.output_path, exist_ok=True)
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"Using device: {device}")

    # load MIMIC-III data
    print("Loading MIMIC-III data")
    dataset = MIMIC3Dataset(root=args.data_path, tables=["DIAGNOSES_ICD"])

    # create ICD-9 matrix using utility function
    print("Creating ICD-9 matrix")
    icd9_matrix, icd9_types = create_icd9_matrix(dataset, args.output_path)
    print(f"ICD-9 matrix shape: {icd9_matrix.shape}")

    # initialize MedGAN model
    print("Initializing MedGAN model...")
    model = MedGAN.from_binary_matrix(
        binary_matrix=icd9_matrix,
        latent_dim=args.latent_dim,
        autoencoder_hidden_dim=args.hidden_dim,
        discriminator_hidden_dim=args.hidden_dim,
        minibatch_averaging=True
    ).to(device)  # .to() on the Module moves all registered submodules

    # make a dataloader
    print("Creating dataloader...")
    dataloader = DataLoader(
        ICD9MatrixDataset(icd9_matrix),
        batch_size=args.batch_size,
        shuffle=True
    )

    # autoencoder pretraining
    print("Pretraining autoencoder...")
    autoencoder_losses = model.pretrain_autoencoder(
        dataloader=dataloader,
        epochs=args.autoencoder_epochs,
        lr=args.lr,
        device=device
    )

    # train GAN
    print("Training GAN...")
    gan_loss_history = train_medgan(
        model=model,
        dataloader=dataloader,
        n_epochs=args.gan_epochs,
        device=device,
        save_dir=args.save_dir,
        lr=args.lr,
        weight_decay=args.weight_decay,
        b1=args.b1,
        b2=args.b2
    )

    # generate synthetic data
    print("Generating synthetic data...")
    with torch.no_grad():
        synthetic_data = model.generate(1000, device)
        binary_data = model.sample_transform(synthetic_data, threshold=0.5)
    synthetic_matrix = binary_data.cpu().numpy()

    # save model + synthetic matrix
    print("Saving results...")
    torch.save({
        'model_config': {
            'latent_dim': args.latent_dim,
            'hidden_dim': args.hidden_dim,
            'autoencoder_hidden_dim': args.hidden_dim,
            'discriminator_hidden_dim': args.hidden_dim,
            'input_dim': icd9_matrix.shape[1],
        },
        'generator_state_dict': model.generator.state_dict(),
        'discriminator_state_dict': model.discriminator.state_dict(),
        'autoencoder_state_dict': model.autoencoder.state_dict(),
    }, os.path.join(args.output_path, "medgan_final.pth"))

    np.save(os.path.join(args.output_path, "synthetic_binary_matrix.npy"), synthetic_matrix)

    # save loss histories
    loss_history = {
        'autoencoder_losses': autoencoder_losses,
        'gan_losses': gan_loss_history,
    }
    np.save(os.path.join(args.output_path, "loss_history.npy"), loss_history)

    # print final stats
    print("\n" + "=" * 50)
    print("TRAINING COMPLETED")
    print("=" * 50)
    print(f"Real data shape: {icd9_matrix.shape}")
    print(f"Real data mean activation: {icd9_matrix.mean():.4f}")
    print(f"Real data sparsity: {(icd9_matrix == 0).mean():.4f}")
    print(f"Synthetic data shape: {synthetic_matrix.shape}")
    print(f"Synthetic data mean activation: {synthetic_matrix.mean():.4f}")
    print(f"Synthetic data sparsity: {(synthetic_matrix == 0).mean():.4f}")
    print(f"Results saved to: {args.output_path}")
    print("=" * 50)

    print("\nGenerated synthetic data in original MIMIC3 ICD-9 format.")


if __name__ == "__main__":
    main()

"""
Slurm script example:

#!/bin/bash
#SBATCH --account=jalengg-ic
#SBATCH --job-name=medgan_pyhealth
#SBATCH --output=logs/medgan_pyhealth_%j.out
#SBATCH --error=logs/medgan_pyhealth_%j.err
#SBATCH --partition=IllinoisComputes-GPU  # Change to appropriate partition
#SBATCH --gres=gpu:1                      # Request 1 GPU
#SBATCH --cpus-per-task=4
#SBATCH --mem=32G
#SBATCH --time=12:00:00

cd "$SLURM_SUBMIT_DIR"

# Load modules or activate environment
# module load python/3.10
# module load cuda/11.7
# conda activate pyhealth

mkdir -p logs
mkdir -p medgan_results

# Set parameters (matching original synthEHRella defaults)
export AUTOENCODER_EPOCHS=100
export GAN_EPOCHS=1000
export BATCH_SIZE=128
export LATENT_DIM=128
export HIDDEN_DIM=128
export LEARNING_RATE=0.001
export WEIGHT_DECAY=0.0001
export BETA1=0.5
export BETA2=0.9

# Run the PyHealth MedGAN script
# (note: the previously documented --postprocess flag does not exist and
# would make argparse exit with an error)
python examples/synthetic_data_generation_mimic3_medgan.py \
    --data_path ./data_files \
    --output_path ./medgan_results \
    --autoencoder_epochs $AUTOENCODER_EPOCHS \
    --gan_epochs $GAN_EPOCHS \
    --batch_size $BATCH_SIZE \
    --latent_dim $LATENT_DIM \
    --hidden_dim $HIDDEN_DIM \
    --lr $LEARNING_RATE \
    --weight_decay $WEIGHT_DECAY \
    --b1 $BETA1 \
    --b2 $BETA2

echo "PyHealth MedGAN training completed!"
echo "Results saved to: ./medgan_results/"
"""
"""MedGAN (Choi et al., MLHC 2017) for synthetic EHR matrix generation.

A pretrained autoencoder provides a decoder that maps the generator's
continuous latent output back into the (binary or count) code space; the
discriminator optionally uses minibatch averaging to discourage mode
collapse.
"""

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from typing import Dict, List, Optional, Tuple, Union
from torch.utils.data import DataLoader

try:
    from pyhealth.models import BaseModel
except ImportError:
    # Fallback so this module can be used/tested standalone without pyhealth.
    class BaseModel(nn.Module):
        """Minimal stand-in exposing only what MedGAN relies on."""

        def __init__(self, dataset=None, **kwargs):
            super().__init__()
            self.dataset = dataset

        @property
        def device(self) -> torch.device:
            # device of the first registered parameter
            return next(self.parameters()).device


class MedGANAutoencoder(nn.Module):
    """Single-hidden-layer autoencoder pretrained before GAN training.

    The decoder output activation depends on the data mode: sigmoid for
    binary matrices, ReLU/Softplus (non-negative) for count matrices.
    """

    def __init__(self, input_dim: int, hidden_dim: int = 128,
                 data_mode: str = "binary", count_activation: str = "relu"):
        super().__init__()
        self.data_mode = data_mode
        self.encoder = nn.Sequential(
            nn.Linear(input_dim, hidden_dim),
            nn.Tanh()
        )
        # Conditional decoder activation based on data mode
        if data_mode == "binary":
            out_activation = nn.Sigmoid()
        else:  # count mode: keep outputs non-negative
            out_activation = nn.ReLU() if count_activation == "relu" else nn.Softplus()
        self.decoder = nn.Sequential(
            nn.Linear(hidden_dim, input_dim),
            out_activation
        )

    def forward(self, x):
        return self.decoder(self.encoder(x))

    def encode(self, x):
        return self.encoder(x)

    def decode(self, x):
        return self.decoder(x)


class MedGANGenerator(nn.Module):
    """Generator: two residual blocks (Linear -> BatchNorm -> activation + shortcut).

    The residual shortcuts require ``latent_dim == hidden_dim``; this is now
    validated eagerly instead of failing with a cryptic shape error in forward.
    """

    def __init__(self, latent_dim: int = 128, hidden_dim: int = 128):
        super().__init__()
        if latent_dim != hidden_dim:
            raise ValueError(
                "MedGANGenerator residual connections require "
                f"latent_dim == hidden_dim, got {latent_dim} != {hidden_dim}"
            )
        self.linear1 = nn.Linear(latent_dim, hidden_dim)
        self.bn1 = nn.BatchNorm1d(hidden_dim, eps=0.001, momentum=0.01)
        self.activation1 = nn.ReLU()

        self.linear2 = nn.Linear(hidden_dim, hidden_dim)
        self.bn2 = nn.BatchNorm1d(hidden_dim, eps=0.001, momentum=0.01)
        self.activation2 = nn.Tanh()

    def forward(self, x):
        # residual block 1
        out1 = x + self.activation1(self.bn1(self.linear1(x)))
        # residual block 2
        out2 = out1 + self.activation2(self.bn2(self.linear2(out1)))
        return out2


class MedGANDiscriminator(nn.Module):
    """MLP discriminator with optional minibatch averaging.

    With minibatch averaging the batch mean of the inputs is concatenated to
    every sample, letting the discriminator detect low-diversity batches.
    """

    def __init__(self, input_dim: int, hidden_dim: int = 256, minibatch_averaging: bool = True):
        super().__init__()
        self.minibatch_averaging = minibatch_averaging
        model_input_dim = input_dim * 2 if minibatch_averaging else input_dim

        self.model = nn.Sequential(
            nn.Linear(model_input_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim // 2),
            nn.ReLU(),
            nn.Linear(hidden_dim // 2, 1),
            nn.Sigmoid()
        )

    def forward(self, x):
        if self.minibatch_averaging:
            batch_mean = torch.mean(x, dim=0).repeat(x.shape[0], 1)
            x = torch.cat((x, batch_mean), dim=1)
        return self.model(x)


class MedGAN(BaseModel):
    """MedGAN wrapper tying the autoencoder, generator and discriminator together.

    For raw matrices prefer the :meth:`from_binary_matrix` /
    :meth:`from_count_matrix` constructors, which size the networks from the
    matrix and install a matching feature extractor.
    """

    def __init__(
        self,
        dataset,
        feature_keys: List[str],
        label_key: str,
        mode: str = "generation",
        data_mode: str = "binary",
        count_activation: str = "relu",
        count_loss: str = "mse",
        latent_dim: int = 128,
        hidden_dim: int = 128,
        autoencoder_hidden_dim: int = 128,
        discriminator_hidden_dim: int = 256,
        minibatch_averaging: bool = True,
        **kwargs
    ):
        # dummy wrapper for BaseModel compatibility
        class DummyWrapper:
            def __init__(self, dataset, feature_keys, label_key):
                self.dataset = dataset
                self.input_schema = {key: "multilabel" for key in feature_keys}
                self.output_schema = {label_key: "multilabel"}
                self.input_processors = {}
                self.output_processors = {}

        super().__init__(dataset=DummyWrapper(dataset, feature_keys, label_key))

        # Fix: these were accepted but never stored, so the default
        # _extract_features_from_batch crashed on self.feature_keys.
        self.feature_keys = feature_keys
        self.label_key = label_key

        self.data_mode = data_mode
        self.count_activation = count_activation
        self.count_loss = count_loss
        self.latent_dim = latent_dim
        self.hidden_dim = hidden_dim
        self.minibatch_averaging = minibatch_averaging

        # build vocab (simplified); from_*_matrix overrides input_dim afterwards
        self.global_vocab = self._build_global_vocab(dataset, feature_keys)
        self.input_dim = len(self.global_vocab)

        # init components
        self.autoencoder = MedGANAutoencoder(
            input_dim=self.input_dim,
            hidden_dim=autoencoder_hidden_dim,
            data_mode=data_mode,
            count_activation=count_activation
        )
        # generator hidden dim must match the autoencoder code dim it feeds
        self.generator = MedGANGenerator(latent_dim=latent_dim, hidden_dim=autoencoder_hidden_dim)
        self.discriminator = MedGANDiscriminator(
            input_dim=self.input_dim,
            hidden_dim=discriminator_hidden_dim,
            minibatch_averaging=minibatch_averaging
        )

        self._init_weights()

    @classmethod
    def from_binary_matrix(
        cls,
        binary_matrix: np.ndarray,
        latent_dim: int = 128,
        hidden_dim: int = 128,
        autoencoder_hidden_dim: int = 128,
        discriminator_hidden_dim: int = 256,
        minibatch_averaging: bool = True,
        data_mode: str = "binary",
        count_activation: str = "relu",
        count_loss: str = "mse",
        **kwargs
    ):
        """Create a MedGAN directly from a (patients x codes) matrix."""

        class MatrixWrapper:
            """Duck-typed dataset over a dense matrix."""

            def __init__(self, matrix):
                self.matrix = matrix
                self.input_processors = {}
                self.output_processors = {}

            def __len__(self):
                return self.matrix.shape[0]

            def __getitem__(self, idx):
                return {"binary_vector": torch.tensor(self.matrix[idx], dtype=torch.float32)}

            def iter_patients(self):
                """iterate over patients"""
                for i in range(len(self)):
                    yield type('Patient', (), {
                        'binary_vector': self.matrix[i],
                        'patient_id': f'patient_{i}'
                    })()

        model = cls(
            dataset=MatrixWrapper(binary_matrix),
            feature_keys=["binary_vector"],
            label_key="binary_vector",
            data_mode=data_mode,
            count_activation=count_activation,
            count_loss=count_loss,
            latent_dim=latent_dim,
            hidden_dim=hidden_dim,
            autoencoder_hidden_dim=autoencoder_hidden_dim,
            discriminator_hidden_dim=discriminator_hidden_dim,
            minibatch_averaging=minibatch_averaging,
            **kwargs
        )

        # The vocab-derived input_dim is meaningless for a raw matrix;
        # rebuild the networks with the true feature dimension.
        model.input_dim = binary_matrix.shape[1]
        model.autoencoder = MedGANAutoencoder(
            input_dim=model.input_dim,
            hidden_dim=autoencoder_hidden_dim,
            data_mode=data_mode,
            count_activation=count_activation
        )
        model.generator = MedGANGenerator(latent_dim=latent_dim, hidden_dim=autoencoder_hidden_dim)
        model.discriminator = MedGANDiscriminator(
            input_dim=model.input_dim,
            hidden_dim=discriminator_hidden_dim,
            minibatch_averaging=minibatch_averaging
        )
        # Fix: weight init was previously applied only to the throwaway
        # networks built in __init__, not to the rebuilt ones.
        model._init_weights()

        # Keep all components on one device.
        try:
            device = next(model.parameters()).device
        except StopIteration:
            device = torch.device('cpu')
        model.autoencoder = model.autoencoder.to(device)
        model.generator = model.generator.to(device)
        model.discriminator = model.discriminator.to(device)

        # Raw matrices bypass processors: features come straight from the batch.
        def extract_features(batch_data, device):
            return batch_data["binary_vector"].to(device)

        model._extract_features_from_batch = extract_features

        return model

    @classmethod
    def from_count_matrix(
        cls,
        count_matrix: np.ndarray,
        latent_dim: int = 128,
        hidden_dim: int = 128,
        autoencoder_hidden_dim: int = 128,
        discriminator_hidden_dim: int = 256,
        minibatch_averaging: bool = True,
        count_activation: str = "relu",
        count_loss: str = "mse",
        **kwargs
    ):
        """Create MedGAN model from a count matrix (integers >= 0).

        Convenience wrapper around :meth:`from_binary_matrix` with
        ``data_mode="count"``; the name is kept to be explicit about the
        expected input format.

        Args:
            count_matrix: array (n_patients, n_features) of counts (>= 0).
            latent_dim: dimension of the generator latent space.
            hidden_dim: hidden dimension for the generator.
            autoencoder_hidden_dim: hidden dimension for the autoencoder.
            discriminator_hidden_dim: hidden dimension for the discriminator.
            minibatch_averaging: enable minibatch averaging in the discriminator.
            count_activation: decoder activation ("relu" or "softplus").
            count_loss: reconstruction loss ("mse" or "poisson").
            **kwargs: forwarded to :meth:`from_binary_matrix`.

        Returns:
            MedGAN model configured for count mode.
        """
        return cls.from_binary_matrix(
            binary_matrix=count_matrix,
            latent_dim=latent_dim,
            hidden_dim=hidden_dim,
            autoencoder_hidden_dim=autoencoder_hidden_dim,
            discriminator_hidden_dim=discriminator_hidden_dim,
            minibatch_averaging=minibatch_averaging,
            data_mode="count",
            count_activation=count_activation,
            count_loss=count_loss,
            **kwargs
        )

    def _build_global_vocab(self, dataset, feature_keys: List[str]) -> List[str]:
        """Build a sorted code vocabulary from list/str patient features.

        Non-list/str features (e.g. raw vectors from MatrixWrapper) are
        ignored; from_*_matrix overrides input_dim afterwards.
        """
        vocab = set()
        for patient in dataset.iter_patients():
            for feature_key in feature_keys:
                if hasattr(patient, feature_key):
                    feature_values = getattr(patient, feature_key)
                    if isinstance(feature_values, list):
                        vocab.update(feature_values)
                    elif isinstance(feature_values, str):
                        vocab.add(feature_values)
        return sorted(vocab)

    def _init_weights(self):
        """Xavier-init linear layers; unit-gain init for batch norms."""
        def weights_init(m):
            if isinstance(m, nn.Linear):
                nn.init.xavier_uniform_(m.weight)
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)
            elif isinstance(m, nn.BatchNorm1d):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)

        self.autoencoder.apply(weights_init)
        self.generator.apply(weights_init)
        self.discriminator.apply(weights_init)

    def forward(self, **kwargs) -> Dict[str, torch.Tensor]:
        """Forward pass: real features plus raw (un-decoded) generator output.

        NOTE(review): fake_samples here are latent codes; decoding happens in
        :meth:`generate` and in the training loop.
        """
        features = self._extract_features_from_batch(kwargs, self.device)
        noise = torch.randn(features.shape[0], self.latent_dim, device=self.device)
        fake_samples = self.generator(noise)
        return {"real_features": features, "fake_samples": fake_samples}

    def generate(self, n_samples: int, device: torch.device = None) -> torch.Tensor:
        """Sample n_samples synthetic rows (continuous; see sample_transform).

        NOTE: leaves generator/autoencoder in eval mode, matching the
        original behavior.
        """
        if device is None:
            device = self.device

        self.generator.eval()
        self.autoencoder.eval()
        with torch.no_grad():
            noise = torch.randn(n_samples, self.latent_dim, device=device)
            # decode the latent output into the code space
            generated = self.autoencoder.decode(self.generator(noise))
        return generated

    def discriminate(self, x: torch.Tensor) -> torch.Tensor:
        """Score samples as real (1) vs fake (0)."""
        return self.discriminator(x)

    def pretrain_autoencoder(self, dataloader: DataLoader, epochs: int = 100,
                             lr: float = 0.001, device: torch.device = None):
        """Pretrain the autoencoder; returns the per-epoch average losses."""
        if device is None:
            device = self.device

        # Ensure the autoencoder is on the requested device
        self.autoencoder = self.autoencoder.to(device)

        print("Pretraining Autoencoder...")
        print("=" * 50)
        print("Epoch | A_loss | Progress")
        print("=" * 50)

        optimizer = torch.optim.Adam(self.autoencoder.parameters(), lr=lr)

        # Reconstruction loss depends on the data mode
        if self.data_mode == "binary":
            criterion = nn.BCELoss()
        elif self.count_loss == "mse":
            criterion = nn.MSELoss()
        else:  # poisson
            criterion = nn.PoissonNLLLoss(log_input=False)

        a_losses = []
        self.autoencoder.train()

        for epoch in range(epochs):
            total_loss = 0.0
            num_batches = 0

            for batch in dataloader:
                # handle both raw-tensor and dict batches
                if isinstance(batch, torch.Tensor):
                    features = batch.to(device)
                else:
                    features = self._extract_features_from_batch(batch, device)

                reconstructed = self.autoencoder(features)
                loss = criterion(reconstructed, features)

                optimizer.zero_grad()
                loss.backward()
                optimizer.step()

                total_loss += loss.item()
                num_batches += 1

            avg_loss = total_loss / num_batches
            a_losses.append(avg_loss)

            # print every epoch for short runs, every 10 epochs otherwise
            print_freq = 1 if epochs <= 50 else 10
            if (epoch + 1) % print_freq == 0 or epoch == 0 or epoch == epochs - 1:
                progress = (epoch + 1) / epochs * 100
                print(f"{epoch+1:5d} | {avg_loss:.4f} | {progress:5.1f}%")

        print("=" * 50)
        print("Autoencoder Pretraining Completed!")
        print(f"Final A_loss: {a_losses[-1]:.4f}")

        return a_losses

    def _extract_features_from_batch(self, batch_data, device: torch.device) -> torch.Tensor:
        """Concatenate the configured feature keys from a dict batch."""
        features = [batch_data[key] for key in self.feature_keys if key in batch_data]
        if len(features) == 1:
            return features[0].to(device)
        return torch.cat(features, dim=1).to(device)

    def sample_transform(self, samples: torch.Tensor, threshold: float = 0.5) -> torch.Tensor:
        """Discretize generator output: threshold (binary) or round+clip (count)."""
        if self.data_mode == "binary":
            return (samples > threshold).float()
        # count mode: non-negative integers
        return torch.clamp(torch.round(samples), min=0)

    def train_step(self, batch, optimizer_g, optimizer_d, optimizer_ae=None):
        """Single adversarial training step (standard non-saturating BCE GAN)."""
        real_features = self._extract_features_from_batch(batch, self.device)

        # train discriminator on real vs detached fake
        optimizer_d.zero_grad()
        noise = torch.randn(real_features.shape[0], self.latent_dim, device=self.device)
        fake_samples = self.generator(noise)

        real_predictions = self.discriminator(real_features)
        fake_predictions = self.discriminator(fake_samples.detach())

        d_loss = F.binary_cross_entropy(real_predictions, torch.ones_like(real_predictions)) + \
            F.binary_cross_entropy(fake_predictions, torch.zeros_like(fake_predictions))
        d_loss.backward()
        optimizer_d.step()

        # train generator to fool the (just-updated) discriminator
        optimizer_g.zero_grad()
        fake_predictions = self.discriminator(fake_samples)
        g_loss = F.binary_cross_entropy(fake_predictions, torch.ones_like(fake_predictions))
        g_loss.backward()
        optimizer_g.step()

        return {"d_loss": d_loss.item(), "g_loss": g_loss.item()}
+ + Aggregates all ICD-9 diagnosis codes across all admissions into a + single flat list per patient, matching the ``multi_hot`` input schema + expected by :class:`~pyhealth.models.MedGAN`. + + Args: + None + + Examples: + >>> task = MedGANGenerationMIMIC3() + >>> task.task_name + 'MedGANGenerationMIMIC3' + """ + + task_name = "MedGANGenerationMIMIC3" + input_schema = {"visits": "multi_hot"} + output_schema = {} + _icd_col = "diagnoses_icd/icd9_code" + + def __call__(self, patient) -> List[Dict]: + admissions = list(patient.get_events(event_type="admissions")) + codes = [] + for adm in admissions: + visit_codes = ( + patient.get_events( + event_type="diagnoses_icd", + filters=[("hadm_id", "==", adm.hadm_id)], + return_df=True, + ) + .select(pl.col(self._icd_col)) + .to_series() + .drop_nulls() + .to_list() + ) + codes.extend(visit_codes) + if not codes: + return [] + return [{"patient_id": patient.patient_id, "visits": codes}] + + +class MedGANGenerationMIMIC4(MedGANGenerationMIMIC3): + """MedGAN generation task for MIMIC-IV. + + Identical to :class:`MedGANGenerationMIMIC3` but uses the MIMIC-IV + ICD column name ``diagnoses_icd/icd_code``. 
+ + Examples: + >>> task = MedGANGenerationMIMIC4() + >>> task.task_name + 'MedGANGenerationMIMIC4' + """ + + task_name = "MedGANGenerationMIMIC4" + _icd_col = "diagnoses_icd/icd_code" + + +medgan_generation_mimic3_fn = MedGANGenerationMIMIC3() +medgan_generation_mimic4_fn = MedGANGenerationMIMIC4() From 202fac8e46e396376392e4df5b890b21e08b4971 Mon Sep 17 00:00:00 2001 From: jalengg Date: Sun, 8 Mar 2026 21:31:40 -0500 Subject: [PATCH 3/9] T3: refactor MedGAN to proper BaseModel with train_model/synthesize_dataset --- pyhealth/models/generators/medgan.py | 724 ++++++++++++++------------- 1 file changed, 369 insertions(+), 355 deletions(-) diff --git a/pyhealth/models/generators/medgan.py b/pyhealth/models/generators/medgan.py index 5fb06ca67..c4ff8fa56 100644 --- a/pyhealth/models/generators/medgan.py +++ b/pyhealth/models/generators/medgan.py @@ -1,95 +1,119 @@ +"""MedGAN: Medical Generative Adversarial Network for synthetic EHR generation. + +Reference: + Choi et al., "Generating Multi-label Discrete Patient Records using + Generative Adversarial Networks", MLHC 2017. +""" + +import os +import time +from typing import Dict, List, Optional + +import numpy as np import torch import torch.nn as nn import torch.nn.functional as F -import numpy as np -from typing import Dict, List, Optional, Tuple, Union -from torch.utils.data import DataLoader +from torch.utils.data import DataLoader, Dataset from pyhealth.models import BaseModel +class MedGANDataset(Dataset): + """Dataset wrapper for MedGAN training from a numpy binary matrix.""" + + def __init__(self, data): + self.data = data.astype(np.float32) + + def __len__(self): + return len(self.data) + + def __getitem__(self, idx): + return torch.from_numpy(self.data[idx]) + + class MedGANAutoencoder(nn.Module): - """simple autoencoder for pretraining""" + """Linear autoencoder for MedGAN pretraining. 
- def __init__(self, input_dim: int, hidden_dim: int = 128, - data_mode: str = "binary", count_activation: str = "relu"): + Args: + input_dim (int): Dimensionality of the input (vocabulary size). + hidden_dim (int): Dimensionality of the latent space. Default: 128. + """ + + def __init__(self, input_dim: int, hidden_dim: int = 128): super().__init__() - self.data_mode = data_mode self.encoder = nn.Sequential( nn.Linear(input_dim, hidden_dim), - nn.Tanh() + nn.Tanh(), + ) + self.decoder = nn.Sequential( + nn.Linear(hidden_dim, input_dim), + nn.Sigmoid(), ) - # Conditional decoder activation based on data mode - if data_mode == "binary": - self.decoder = nn.Sequential( - nn.Linear(hidden_dim, input_dim), - nn.Sigmoid() - ) - else: # count mode - activation = nn.ReLU() if count_activation == "relu" else nn.Softplus() - self.decoder = nn.Sequential( - nn.Linear(hidden_dim, input_dim), - activation - ) - def forward(self, x): - encoded = self.encoder(x) - decoded = self.decoder(encoded) - return decoded - + return self.decoder(self.encoder(x)) + def encode(self, x): return self.encoder(x) - + def decode(self, x): return self.decoder(x) -# ONLY USE ADMISSIONS AND DIAGNOSES FOR EVERYTHING class MedGANGenerator(nn.Module): - """generator with residual connections""" - + """Generator with residual connections. + + Args: + latent_dim (int): Dimensionality of the noise input. Default: 128. + hidden_dim (int): Width of hidden layers. Default: 128. 
+ """ + def __init__(self, latent_dim: int = 128, hidden_dim: int = 128): super().__init__() self.linear1 = nn.Linear(latent_dim, hidden_dim) self.bn1 = nn.BatchNorm1d(hidden_dim, eps=0.001, momentum=0.01) self.activation1 = nn.ReLU() - + self.linear2 = nn.Linear(hidden_dim, hidden_dim) self.bn2 = nn.BatchNorm1d(hidden_dim, eps=0.001, momentum=0.01) self.activation2 = nn.Tanh() - + def forward(self, x): - # residual block 1 residual = x out = self.activation1(self.bn1(self.linear1(x))) out1 = out + residual - - # residual block 2 + residual = out1 out = self.activation2(self.bn2(self.linear2(out1))) out2 = out + residual - + return out2 class MedGANDiscriminator(nn.Module): - """discriminator with minibatch averaging""" - - def __init__(self, input_dim: int, hidden_dim: int = 256, minibatch_averaging: bool = True): + """Discriminator with minibatch averaging. + + Args: + input_dim (int): Dimensionality of the input. + hidden_dim (int): Width of hidden layers. Default: 256. + minibatch_averaging (bool): Concatenate batch mean to each sample. Default: True. 
+ """ + + def __init__(self, input_dim: int, hidden_dim: int = 256, + minibatch_averaging: bool = True): super().__init__() self.minibatch_averaging = minibatch_averaging model_input_dim = input_dim * 2 if minibatch_averaging else input_dim - + self.model = nn.Sequential( nn.Linear(model_input_dim, hidden_dim), nn.ReLU(), nn.Linear(hidden_dim, hidden_dim // 2), nn.ReLU(), nn.Linear(hidden_dim // 2, 1), - nn.Sigmoid() + nn.Sigmoid(), ) - + def forward(self, x): if self.minibatch_averaging: x_mean = torch.mean(x, dim=0).repeat(x.shape[0], 1) @@ -97,351 +121,341 @@ def forward(self, x): return self.model(x) +def _weights_init(m): + """Xavier uniform initialization for linear layers.""" + if isinstance(m, nn.Linear): + nn.init.xavier_uniform_(m.weight) + if m.bias is not None: + nn.init.constant_(m.bias, 0) + elif isinstance(m, nn.BatchNorm1d): + nn.init.constant_(m.weight, 1) + nn.init.constant_(m.bias, 0) + + class MedGAN(BaseModel): - """MedGAN for binary matrix generation""" - + """MedGAN: Medical Generative Adversarial Network. + + Generates synthetic binary EHR records via a two-phase training process: + (1) pre-train a linear autoencoder, then (2) adversarial training with + standard BCE loss. The generator maps noise to the autoencoder's latent + space, and the decoder projects back to binary medical codes. + + Reference: + Choi et al., "Generating Multi-label Discrete Patient Records using + Generative Adversarial Networks", MLHC 2017. + + Args: + dataset (SampleDataset): A fitted SampleDataset with + ``input_schema = {"visits": "multi_hot"}``. + latent_dim (int): Dimensionality of the generator latent space. Default: 128. + hidden_dim (int): Hidden layer width for the generator. Default: 128. + autoencoder_hidden_dim (int): Autoencoder latent dimension. Default: 128. + discriminator_hidden_dim (int): Discriminator hidden layer width. Default: 256. + minibatch_averaging (bool): Use minibatch averaging in discriminator. Default: True. 
+ batch_size (int): Training batch size. Default: 512. + ae_epochs (int): Autoencoder pre-training epochs. Default: 100. + gan_epochs (int): Adversarial training epochs. Default: 200. + ae_lr (float): Autoencoder learning rate. Default: 0.001. + gan_lr (float): GAN learning rate. Default: 0.001. + save_dir (str): Checkpoint save directory. Default: ``"./medgan_checkpoints"``. + **kwargs: Additional arguments passed to ``BaseModel``. + + Examples: + >>> from pyhealth.datasets.sample_dataset import InMemorySampleDataset + >>> dataset = InMemorySampleDataset( + ... samples=[ + ... {"patient_id": "p1", "visits": ["A", "B", "C"]}, + ... {"patient_id": "p2", "visits": ["A", "C", "D"]}, + ... ], + ... input_schema={"visits": "multi_hot"}, + ... output_schema={}, + ... ) + >>> model = MedGAN(dataset, latent_dim=32, hidden_dim=32) + >>> isinstance(model, MedGAN) + True + """ + def __init__( self, dataset, - feature_keys: List[str], - label_key: str, - mode: str = "generation", - data_mode: str = "binary", - count_activation: str = "relu", - count_loss: str = "mse", latent_dim: int = 128, hidden_dim: int = 128, autoencoder_hidden_dim: int = 128, discriminator_hidden_dim: int = 256, minibatch_averaging: bool = True, - **kwargs + batch_size: int = 512, + ae_epochs: int = 100, + gan_epochs: int = 200, + ae_lr: float = 0.001, + gan_lr: float = 0.001, + save_dir: str = "./medgan_checkpoints", + **kwargs, ): - # dummy wrapper for BaseModel compatibility - class DummyWrapper: - def __init__(self, dataset, feature_keys, label_key): - self.dataset = dataset - self.input_schema = {key: "multilabel" for key in feature_keys} - self.output_schema = {label_key: "multilabel"} - self.input_processors = {} - self.output_processors = {} - - wrapped_dataset = DummyWrapper(dataset, feature_keys, label_key) - super().__init__(dataset=wrapped_dataset) - - self.data_mode = data_mode - self.count_activation = count_activation - self.count_loss = count_loss + super().__init__(dataset=dataset) + 
self.latent_dim = latent_dim self.hidden_dim = hidden_dim - self.minibatch_averaging = minibatch_averaging - - # build vocab (simplified) - self.global_vocab = self._build_global_vocab(dataset, feature_keys) - self.input_dim = len(self.global_vocab) - - # init components + self.batch_size = batch_size + self.ae_epochs = ae_epochs + self.gan_epochs = gan_epochs + self.ae_lr = ae_lr + self.gan_lr = gan_lr + self.save_dir = save_dir + + # Derive vocabulary size from processor + processor = dataset.input_processors["visits"] + self.input_dim = processor.size() + + # Build reverse lookup: index -> code string + self._idx_to_code: List[Optional[str]] = [None] * self.input_dim + for code, idx in processor.label_vocab.items(): + self._idx_to_code[idx] = code + + # Initialize components self.autoencoder = MedGANAutoencoder( input_dim=self.input_dim, hidden_dim=autoencoder_hidden_dim, - data_mode=data_mode, - count_activation=count_activation ) - self.generator = MedGANGenerator(latent_dim=latent_dim, hidden_dim=autoencoder_hidden_dim) + self.generator = MedGANGenerator( + latent_dim=latent_dim, + hidden_dim=autoencoder_hidden_dim, + ) self.discriminator = MedGANDiscriminator( input_dim=self.input_dim, hidden_dim=discriminator_hidden_dim, - minibatch_averaging=minibatch_averaging - ) - - self._init_weights() - - @classmethod - def from_binary_matrix( - cls, - binary_matrix: np.ndarray, - latent_dim: int = 128, - hidden_dim: int = 128, - autoencoder_hidden_dim: int = 128, - discriminator_hidden_dim: int = 256, - minibatch_averaging: bool = True, - data_mode: str = "binary", - count_activation: str = "relu", - count_loss: str = "mse", - **kwargs - ): - """create MedGAN model from binary matrix (ICD-9, etc.)""" - class MatrixWrapper: - def __init__(self, matrix): - self.matrix = matrix - self.input_processors = {} - self.output_processors = {} - - def __len__(self): - return self.matrix.shape[0] - - def __getitem__(self, idx): - return {"binary_vector": 
torch.tensor(self.matrix[idx], dtype=torch.float32)} - - def iter_patients(self): - """iterate over patients""" - for i in range(len(self)): - yield type('Patient', (), { - 'binary_vector': self.matrix[i], - 'patient_id': f'patient_{i}' - })() - - dummy_dataset = MatrixWrapper(binary_matrix) - - model = cls( - dataset=dummy_dataset, - feature_keys=["binary_vector"], - label_key="binary_vector", - data_mode=data_mode, - count_activation=count_activation, - count_loss=count_loss, - latent_dim=latent_dim, - hidden_dim=hidden_dim, - autoencoder_hidden_dim=autoencoder_hidden_dim, - discriminator_hidden_dim=discriminator_hidden_dim, minibatch_averaging=minibatch_averaging, - **kwargs - ) - - # override input dimension - model.input_dim = binary_matrix.shape[1] - - # reinitialize components with correct dimensions - model.autoencoder = MedGANAutoencoder( - input_dim=model.input_dim, - hidden_dim=autoencoder_hidden_dim, - data_mode=data_mode, - count_activation=count_activation ) - model.generator = MedGANGenerator(latent_dim=latent_dim, hidden_dim=autoencoder_hidden_dim) - model.discriminator = MedGANDiscriminator( - input_dim=model.input_dim, - hidden_dim=discriminator_hidden_dim, - minibatch_averaging=minibatch_averaging + + self.autoencoder.apply(_weights_init) + self.generator.apply(_weights_init) + self.discriminator.apply(_weights_init) + + def forward(self, **kwargs) -> Dict[str, torch.Tensor]: + """Not used in GAN context.""" + raise NotImplementedError( + "Use train_model() for training and synthesize_dataset() for generation." 
) - - # Move all components to the same device as the model - device = next(model.parameters()).device if hasattr(model, 'parameters') else torch.device('cpu') - model.autoencoder = model.autoencoder.to(device) - model.generator = model.generator.to(device) - model.discriminator = model.discriminator.to(device) - - # override feature extraction - def extract_features(batch_data, device): - return batch_data["binary_vector"].to(device) - - model._extract_features_from_batch = extract_features - - return model - - @classmethod - def from_count_matrix( - cls, - count_matrix: np.ndarray, - latent_dim: int = 128, - hidden_dim: int = 128, - autoencoder_hidden_dim: int = 128, - discriminator_hidden_dim: int = 256, - minibatch_averaging: bool = True, - count_activation: str = "relu", - count_loss: str = "mse", - **kwargs - ): - """Create MedGAN model from count matrix (integers >= 0) - This is a convenience method that calls from_binary_matrix with data_mode="count". - The name is kept as count_matrix to be explicit about expected input format. + def train_model(self, train_dataset, val_dataset=None): + """Train MedGAN on a SampleDataset. + + Phase 1: pre-train the autoencoder with BCE reconstruction loss. + Phase 2: adversarial training with standard BCE GAN loss (not WGAN). Args: - count_matrix: numpy array with shape (n_patients, n_features) containing counts (integers >= 0) - latent_dim: dimension of latent space for generator - hidden_dim: hidden dimension for generator - autoencoder_hidden_dim: hidden dimension for autoencoder - discriminator_hidden_dim: hidden dimension for discriminator - minibatch_averaging: whether to use minibatch averaging in discriminator - count_activation: activation function for count mode ("relu" or "softplus") - count_loss: loss function for count mode ("mse" or "poisson") - **kwargs: additional arguments + train_dataset: A fitted SampleDataset with + ``input_schema = {"visits": "multi_hot"}``. + val_dataset: Unused. 
Accepted for API compatibility. Returns: - MedGAN model configured for count mode + None """ - return cls.from_binary_matrix( - binary_matrix=count_matrix, - latent_dim=latent_dim, - hidden_dim=hidden_dim, - autoencoder_hidden_dim=autoencoder_hidden_dim, - discriminator_hidden_dim=discriminator_hidden_dim, - minibatch_averaging=minibatch_averaging, - data_mode="count", - count_activation=count_activation, - count_loss=count_loss, - **kwargs + device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + self.to(device) + print(f"Training MedGAN on: {device}") + + # Build multi-hot matrix from pre-encoded tensors + tensors = [train_dataset[i]["visits"] for i in range(len(train_dataset))] + data_matrix = torch.stack(tensors).numpy() + + medgan_ds = MedGANDataset(data=data_matrix) + sampler = torch.utils.data.sampler.RandomSampler( + data_source=medgan_ds, replacement=True, + ) + dataloader = DataLoader( + medgan_ds, + batch_size=self.batch_size, + shuffle=False, + num_workers=0, + drop_last=True, + sampler=sampler, ) - def _build_global_vocab(self, dataset, feature_keys: List[str]) -> List[str]: - """build vocab from dataset (simplified)""" - vocab = set() - for patient in dataset.iter_patients(): - for feature_key in feature_keys: - if hasattr(patient, feature_key): - feature_values = getattr(patient, feature_key) - if isinstance(feature_values, list): - vocab.update(feature_values) - elif isinstance(feature_values, str): - vocab.add(feature_values) - return sorted(list(vocab)) - - def _init_weights(self): - """init weights""" - def weights_init(m): - if isinstance(m, nn.Linear): - nn.init.xavier_uniform_(m.weight) - if m.bias is not None: - nn.init.constant_(m.bias, 0) - elif isinstance(m, nn.BatchNorm1d): - nn.init.constant_(m.weight, 1) - nn.init.constant_(m.bias, 0) - - self.autoencoder.apply(weights_init) - self.generator.apply(weights_init) - self.discriminator.apply(weights_init) - - def forward(self, **kwargs) -> Dict[str, torch.Tensor]: - 
"""forward pass""" - features = self._extract_features_from_batch(kwargs, self.device) - noise = torch.randn(features.shape[0], self.latent_dim, device=self.device) - fake_samples = self.generator(noise) - return {"real_features": features, "fake_samples": fake_samples} - - def generate(self, n_samples: int, device: torch.device = None) -> torch.Tensor: - """generate synthetic samples""" - if device is None: - device = self.device - - self.generator.eval() - self.autoencoder.eval() - with torch.no_grad(): - noise = torch.randn(n_samples, self.latent_dim, device=device) - generated = self.generator(noise) - # use autoencoder decoder to get final output - generated = self.autoencoder.decode(generated) - - return generated - - def discriminate(self, x: torch.Tensor) -> torch.Tensor: - """discriminate real vs fake""" - return self.discriminator(x) - - def pretrain_autoencoder(self, dataloader: DataLoader, epochs: int = 100, lr: float = 0.001, device: torch.device = None): - """pretrain autoencoder with detailed loss tracking""" - if device is None: - device = self.device - - # Ensure autoencoder is on the correct device - self.autoencoder = self.autoencoder.to(device) - - print("Pretraining Autoencoder...") - print("="*50) - print("Epoch | A_loss | Progress") - print("="*50) - - optimizer = torch.optim.Adam(self.autoencoder.parameters(), lr=lr) - - # Conditional loss function based on data mode - if self.data_mode == "binary": - criterion = nn.BCELoss() - elif self.count_loss == "mse": - criterion = nn.MSELoss() - else: # poisson - criterion = nn.PoissonNLLLoss(log_input=False) - - # Track losses for plotting - a_losses = [] - + os.makedirs(self.save_dir, exist_ok=True) + + # ---- Phase 1: Autoencoder pretraining ---- + print(f"Phase 1: Pretraining autoencoder for {self.ae_epochs} epochs...") + optimizer_ae = torch.optim.Adam( + self.autoencoder.parameters(), lr=self.ae_lr, + ) + criterion_ae = nn.BCELoss() + self.autoencoder.train() - - for epoch in range(epochs): - 
total_loss = 0 - num_batches = 0 - + for epoch in range(self.ae_epochs): + total_loss = 0.0 + n_batches = 0 for batch in dataloader: - # handle both tensor and dict inputs - if isinstance(batch, torch.Tensor): - features = batch.to(device) - else: - features = self._extract_features_from_batch(batch, device) - - reconstructed = self.autoencoder(features) - loss = criterion(reconstructed, features) - - optimizer.zero_grad() + real = batch.to(device) + recon = self.autoencoder(real) + loss = criterion_ae(recon, real) + + optimizer_ae.zero_grad() loss.backward() - optimizer.step() - + optimizer_ae.step() + total_loss += loss.item() - num_batches += 1 - - avg_loss = total_loss / num_batches - a_losses.append(avg_loss) - - # Print progress every epoch for shorter training runs, every 10 for longer runs - print_freq = 1 if epochs <= 50 else 10 - if (epoch + 1) % print_freq == 0 or epoch == 0 or epoch == epochs - 1: - progress = (epoch + 1) / epochs * 100 - print(f"{epoch+1:5d} | {avg_loss:.4f} | {progress:5.1f}%") - - print("="*50) - print("Autoencoder Pretraining Completed!") - print(f"Final A_loss: {a_losses[-1]:.4f}") - - return a_losses - - def _extract_features_from_batch(self, batch_data, device: torch.device) -> torch.Tensor: - """extract features from batch""" - features = [] - for feature_key in self.feature_keys: - if feature_key in batch_data: - features.append(batch_data[feature_key]) - - if len(features) == 1: - return features[0].to(device) - else: - return torch.cat(features, dim=1).to(device) - - def sample_transform(self, samples: torch.Tensor, threshold: float = 0.5) -> torch.Tensor: - """Convert to discrete values based on data mode""" - if self.data_mode == "binary": - return (samples > threshold).float() - else: # count mode - return torch.clamp(torch.round(samples), min=0) - - def train_step(self, batch, optimizer_g, optimizer_d, optimizer_ae=None): - """single training step""" - real_features = self._extract_features_from_batch(batch, self.device) 
- - # train discriminator - optimizer_d.zero_grad() - noise = torch.randn(real_features.shape[0], self.latent_dim, device=self.device) - fake_samples = self.generator(noise) - - real_predictions = self.discriminator(real_features) - fake_predictions = self.discriminator(fake_samples.detach()) - - d_loss = F.binary_cross_entropy(real_predictions, torch.ones_like(real_predictions)) + \ - F.binary_cross_entropy(fake_predictions, torch.zeros_like(fake_predictions)) - d_loss.backward() - optimizer_d.step() - - # train generator - optimizer_g.zero_grad() - fake_predictions = self.discriminator(fake_samples) - g_loss = F.binary_cross_entropy(fake_predictions, torch.ones_like(fake_predictions)) - g_loss.backward() - optimizer_g.step() - - return {"d_loss": d_loss.item(), "g_loss": g_loss.item()} \ No newline at end of file + n_batches += 1 + + if (epoch + 1) % max(1, self.ae_epochs // 10) == 0 or epoch == 0: + avg = total_loss / n_batches + print(f" AE epoch {epoch + 1}/{self.ae_epochs} loss={avg:.4f}") + + # ---- Phase 2: Adversarial training ---- + print(f"Phase 2: Adversarial training for {self.gan_epochs} epochs...") + optimizer_g = torch.optim.Adam( + list(self.generator.parameters()) + + list(self.autoencoder.decoder.parameters()), + lr=self.gan_lr, + ) + optimizer_d = torch.optim.Adam( + self.discriminator.parameters(), lr=self.gan_lr, + ) + + best_d_loss = float("inf") + + for epoch in range(self.gan_epochs): + epoch_d_loss = 0.0 + epoch_g_loss = 0.0 + n_batches = 0 + + self.generator.train() + self.discriminator.train() + self.autoencoder.eval() + self.autoencoder.decoder.train() + + for batch in dataloader: + real = batch.to(device) + bs = real.size(0) + + # --- Train Discriminator --- + optimizer_d.zero_grad() + noise = torch.randn(bs, self.latent_dim, device=device) + fake_hidden = self.generator(noise) + fake = self.autoencoder.decode(fake_hidden) + + real_pred = self.discriminator(real) + fake_pred = self.discriminator(fake.detach()) + + d_loss = ( + 
F.binary_cross_entropy(real_pred, torch.ones_like(real_pred)) + + F.binary_cross_entropy(fake_pred, torch.zeros_like(fake_pred)) + ) + d_loss.backward() + optimizer_d.step() + + # --- Train Generator --- + optimizer_g.zero_grad() + fake_pred = self.discriminator(fake) + g_loss = F.binary_cross_entropy( + fake_pred, torch.ones_like(fake_pred), + ) + g_loss.backward() + optimizer_g.step() + + epoch_d_loss += d_loss.item() + epoch_g_loss += g_loss.item() + n_batches += 1 + + avg_d = epoch_d_loss / n_batches + avg_g = epoch_g_loss / n_batches + + if (epoch + 1) % max(1, self.gan_epochs // 10) == 0 or epoch == 0: + print( + f" GAN epoch {epoch + 1}/{self.gan_epochs} " + f"D_loss={avg_d:.4f} G_loss={avg_g:.4f}" + ) + + # Save best checkpoint + if avg_d < best_d_loss: + best_d_loss = avg_d + self.save_model(os.path.join(self.save_dir, "best.pt")) + + # Save final checkpoint + self.save_model(os.path.join(self.save_dir, "final.pt")) + print("Training complete.") + + def synthesize_dataset( + self, num_samples: int, random_sampling: bool = True, + ) -> List[Dict]: + """Generate synthetic patient records. + + Each synthetic patient is a flat list of ICD code strings decoded from + a generated binary vector, matching the ``multi_hot`` input schema. + + Args: + num_samples (int): Number of synthetic patients to generate. + random_sampling (bool): Unused; accepted for API compatibility. + + Returns: + list of dict: Synthetic patient records. Each dict has: + ``"patient_id"`` (str): e.g. ``"synthetic_0"``. + ``"visits"`` (list of str): flat list of decoded ICD code strings. 
+ """ + device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + self.to(device) + + self.generator.eval() + self.autoencoder.eval() + + gen_samples = np.zeros((num_samples, self.input_dim), dtype=np.float32) + n_full = num_samples // self.batch_size + + with torch.no_grad(): + for i in range(n_full): + z = torch.randn(self.batch_size, self.latent_dim, device=device) + fake = self.autoencoder.decode(self.generator(z)) + gen_samples[i * self.batch_size : (i + 1) * self.batch_size] = ( + fake.cpu().numpy() + ) + + remaining = num_samples % self.batch_size + if remaining > 0: + z = torch.randn(remaining, self.latent_dim, device=device) + fake = self.autoencoder.decode(self.generator(z)) + gen_samples[n_full * self.batch_size :] = fake.cpu().numpy() + + # Binarize at threshold 0.5 + gen_samples = (gen_samples >= 0.5).astype(np.float32) + + # Decode to code strings + results: List[Dict] = [] + for i in range(num_samples): + codes = [ + self._idx_to_code[idx] + for idx in np.where(gen_samples[i] == 1.0)[0] + if self._idx_to_code[idx] not in (None, "", "") + ] + results.append({ + "patient_id": f"synthetic_{i}", + "visits": codes, + }) + return results + + def save_model(self, path: str): + """Save model weights to a checkpoint file. + + Args: + path (str): File path to write the checkpoint. + """ + torch.save( + { + "autoencoder": self.autoencoder.state_dict(), + "generator": self.generator.state_dict(), + "discriminator": self.discriminator.state_dict(), + "input_dim": self.input_dim, + "latent_dim": self.latent_dim, + "idx_to_code": self._idx_to_code, + }, + path, + ) + + def load_model(self, path: str): + """Load model weights from a checkpoint file. + + Args: + path (str): File path to read the checkpoint. 
+ """ + ckpt = torch.load(path, map_location=self.device) + self.autoencoder.load_state_dict(ckpt["autoencoder"]) + self.generator.load_state_dict(ckpt["generator"]) + self.discriminator.load_state_dict(ckpt["discriminator"]) From cfc29e241e8b5754c5edf00fcdea4c2225f11a76 Mon Sep 17 00:00:00 2001 From: jalengg Date: Sun, 8 Mar 2026 21:41:36 -0500 Subject: [PATCH 4/9] T4-T7: add PyHealth 2.0 examples, register MedGAN in models/__init__, remove old example --- examples/generate_synthetic_mimic3_medgan.py | 197 ++------- examples/medgan_mimic3_training.py | 34 ++ ...synthetic_data_generation_mimic3_medgan.py | 390 ------------------ pyhealth/models/__init__.py | 1 + 4 files changed, 72 insertions(+), 550 deletions(-) create mode 100644 examples/medgan_mimic3_training.py delete mode 100644 examples/synthetic_data_generation_mimic3_medgan.py diff --git a/examples/generate_synthetic_mimic3_medgan.py b/examples/generate_synthetic_mimic3_medgan.py index 3c163926a..b1e1bda10 100644 --- a/examples/generate_synthetic_mimic3_medgan.py +++ b/examples/generate_synthetic_mimic3_medgan.py @@ -1,163 +1,40 @@ -#!/usr/bin/env python3 -""" -Generate synthetic MIMIC-III patients using a trained MedGAN checkpoint. -Uses simple 0.5 threshold - MedGAN doesn't require post-processing. 
-""" +"""Generate synthetic MIMIC-III patient records using a trained MedGAN checkpoint.""" +import json -import os -import argparse -import torch -import numpy as np -import pandas as pd +from pyhealth.datasets import MIMIC3Dataset +from pyhealth.tasks import medgan_generation_mimic3_fn from pyhealth.models.generators.medgan import MedGAN - -def main(): - parser = argparse.ArgumentParser(description="Generate synthetic patients using trained MedGAN") - parser.add_argument("--checkpoint", required=True, help="Path to trained MedGAN checkpoint (.pth)") - parser.add_argument("--vocab", required=True, help="Path to ICD-9 vocabulary file (.txt)") - parser.add_argument("--data_matrix", required=True, help="Path to training data matrix (.npy)") - parser.add_argument("--output", required=True, help="Path to output CSV file") - parser.add_argument("--n_samples", type=int, default=10000, help="Number of synthetic patients to generate") - parser.add_argument("--threshold", type=float, default=0.5, help="Threshold for binarization (binary mode only)") - - # Mode parameters - parser.add_argument("--data_mode", type=str, default="binary", choices=["binary", "count"], - help="Data mode: 'binary' (default) or 'count'") - parser.add_argument("--count_activation", type=str, default="relu", choices=["relu", "softplus"], - help="Activation for count mode: 'relu' (default) or 'softplus'") - parser.add_argument("--count_loss", type=str, default="mse", choices=["mse", "poisson"], - help="Loss function for count mode: 'mse' (default) or 'poisson'") - - args = parser.parse_args() - - device = torch.device("cuda" if torch.cuda.is_available() else "cpu") - print(f"Using device: {device}") - print(f"Data mode: {args.data_mode}") - - # Load vocabulary - print(f"Loading vocabulary from {args.vocab}") - with open(args.vocab, 'r') as f: - code_vocab = [line.strip() for line in f] - print(f"Loaded {len(code_vocab)} ICD-9 codes") - - # Load data matrix to get architecture dimensions - 
print(f"Loading data matrix from {args.data_matrix}") - data_matrix = np.load(args.data_matrix) - n_codes = data_matrix.shape[1] - print(f"Data matrix shape: {data_matrix.shape}") - if args.data_mode == "binary": - print(f"Real data avg codes/patient: {data_matrix.sum(axis=1).mean():.2f}") - else: - print(f"Real data avg code occurrences/patient: {data_matrix.sum(axis=1).mean():.2f}") - print(f"Real data max count: {data_matrix.max():.0f}") - - # Load checkpoint - print(f"\nLoading checkpoint from {args.checkpoint}") - checkpoint = torch.load(args.checkpoint, map_location=device, weights_only=False) - - # Initialize MedGAN with same architecture - print("Initializing MedGAN model...") - if args.data_mode == "binary": - model = MedGAN.from_binary_matrix( - binary_matrix=data_matrix, - latent_dim=128, - autoencoder_hidden_dim=128, - discriminator_hidden_dim=256, - minibatch_averaging=True, - data_mode=args.data_mode - ).to(device) - else: # count mode - model = MedGAN.from_count_matrix( - count_matrix=data_matrix, - latent_dim=128, - autoencoder_hidden_dim=128, - discriminator_hidden_dim=256, - minibatch_averaging=True, - count_activation=args.count_activation, - count_loss=args.count_loss - ).to(device) - - # Load trained weights - model.autoencoder.load_state_dict(checkpoint['autoencoder_state_dict']) - model.generator.load_state_dict(checkpoint['generator_state_dict']) - model.discriminator.load_state_dict(checkpoint['discriminator_state_dict']) - - model.eval() - print("Model loaded successfully") - - # Generate synthetic patients - print(f"\nGenerating {args.n_samples} synthetic patients...") - - with torch.no_grad(): - # Generate data - synthetic_data = model.generate(args.n_samples, device) - - # Apply transform (binary threshold or count round+clip) - if args.data_mode == "binary": - discrete_data = model.sample_transform(synthetic_data, threshold=args.threshold) - else: - discrete_data = model.sample_transform(synthetic_data) - - data_matrix_synthetic = 
discrete_data.cpu().numpy() - - # Calculate statistics - avg_codes = data_matrix_synthetic.sum(axis=1).mean() - std_codes = data_matrix_synthetic.sum(axis=1).std() - min_codes = data_matrix_synthetic.sum(axis=1).min() - max_codes = data_matrix_synthetic.sum(axis=1).max() - sparsity = (data_matrix_synthetic == 0).mean() - - print(f"\nSynthetic data statistics:") - if args.data_mode == "binary": - print(f" Avg codes per patient: {avg_codes:.2f} ± {std_codes:.2f}") - else: - print(f" Avg code occurrences per patient: {avg_codes:.2f} ± {std_codes:.2f}") - print(f" Max count: {data_matrix_synthetic.max():.0f}") - print(f" Range: [{min_codes:.0f}, {max_codes:.0f}]") - print(f" Sparsity: {sparsity:.4f}") - - # Check heterogeneity - unique_profiles = len(set(tuple(row) for row in data_matrix_synthetic)) - print(f" Unique patient profiles: {unique_profiles}/{args.n_samples} ({unique_profiles/args.n_samples*100:.1f}%)") - - # Convert to CSV format (SUBJECT_ID, ICD9_CODE) - print(f"\nConverting to CSV format...") - records = [] - for patient_idx in range(args.n_samples): - patient_id = f"SYNTHETIC_{patient_idx+1:06d}" - - if args.data_mode == "binary": - # Binary mode: include codes where value == 1 - code_indices = np.where(data_matrix_synthetic[patient_idx] == 1)[0] - for code_idx in code_indices: - records.append({ - 'SUBJECT_ID': patient_id, - 'ICD9_CODE': code_vocab[code_idx] - }) - else: # count mode - # Count mode: repeat codes based on their counts - for code_idx in range(n_codes): - count = int(data_matrix_synthetic[patient_idx, code_idx]) - for _ in range(count): - records.append({ - 'SUBJECT_ID': patient_id, - 'ICD9_CODE': code_vocab[code_idx] - }) - - df = pd.DataFrame(records) - print(f"Created {len(df)} diagnosis records for {args.n_samples} patients") - - # Save to CSV - print(f"\nSaving to {args.output}") - df.to_csv(args.output, index=False) - - file_size_mb = os.path.getsize(args.output) / (1024 * 1024) - print(f"Saved {file_size_mb:.1f} MB") - - print("\n✓ 
Generation complete!") - print(f"Output: {args.output}") - - -if __name__ == '__main__': - main() +# Update this to your local MIMIC-III path before running +MIMIC3_ROOT = "/path/to/mimic3" + +# 1. Reconstruct dataset — required to initialise MedGAN's vocabulary from the processor. +base_dataset = MIMIC3Dataset( + root=MIMIC3_ROOT, + tables=["diagnoses_icd"], +) +sample_dataset = base_dataset.set_task(medgan_generation_mimic3_fn) + +# 2. Instantiate model (training params are unused during generation; +# they must match your training configuration for checkpoint compatibility). +model = MedGAN( + dataset=sample_dataset, + latent_dim=128, + hidden_dim=128, + batch_size=128, + save_dir="./medgan_checkpoints/", +) + +# 3. Load trained checkpoint +model.load_model("./medgan_checkpoints/best.pt") + +# 4. Generate synthetic patients — each patient is a flat bag-of-codes (no visit structure) +synthetic = model.synthesize_dataset(num_samples=10000) +print(f"Generated {len(synthetic)} synthetic patients") +print(f"Example record: {synthetic[0]}") + +# 5. Save to JSON +output_path = "synthetic_medgan_10k.json" +with open(output_path, "w") as f: + json.dump(synthetic, f, indent=2) +print(f"Saved to {output_path}") diff --git a/examples/medgan_mimic3_training.py b/examples/medgan_mimic3_training.py new file mode 100644 index 000000000..cb5ca5b97 --- /dev/null +++ b/examples/medgan_mimic3_training.py @@ -0,0 +1,34 @@ +"""Train MedGAN on MIMIC-III diagnosis codes and save a checkpoint.""" + +# 1. Load MIMIC-III dataset +from pyhealth.datasets import MIMIC3Dataset +from pyhealth.datasets import split_by_patient +from pyhealth.tasks import medgan_generation_mimic3_fn +from pyhealth.models.generators.medgan import MedGAN + +base_dataset = MIMIC3Dataset( + root="/path/to/mimic3", + tables=["diagnoses_icd"], +) + +# 2. 
Apply generation task — flattens all ICD codes per patient into a bag-of-codes +sample_dataset = base_dataset.set_task(medgan_generation_mimic3_fn) +print(f"{len(sample_dataset)} patients after filtering") + +# 3. Patient-level split — required for generative models to prevent data leakage across splits +train_dataset, val_dataset, _ = split_by_patient(sample_dataset, [0.8, 0.1, 0.1]) + +# 4. Instantiate and train — reduce epochs for testing; 100+ recommended for quality +model = MedGAN( + dataset=sample_dataset, + latent_dim=128, + hidden_dim=128, + batch_size=128, + ae_epochs=100, + gan_epochs=200, + save_dir="./medgan_checkpoints/", +) +model.train_model(train_dataset, val_dataset) + +# 5. Checkpoint is saved automatically to save_dir by train_model +print("Training complete. Checkpoint saved to ./medgan_checkpoints/") diff --git a/examples/synthetic_data_generation_mimic3_medgan.py b/examples/synthetic_data_generation_mimic3_medgan.py deleted file mode 100644 index c834d912c..000000000 --- a/examples/synthetic_data_generation_mimic3_medgan.py +++ /dev/null @@ -1,390 +0,0 @@ -""" -Synthetic data generation using MedGAN on MIMIC-III data. - -This example demonstrates how to train MedGAN to generate synthetic ICD-9 matrices -from MIMIC-III data, following PyHealth conventions. -""" - -import os -import torch -import numpy as np -import argparse -from torch.utils.data import DataLoader -import pickle -import json -from tqdm import tqdm -import pandas as pd - -from pyhealth.datasets import MIMIC3Dataset -from pyhealth.datasets.icd9_matrix import create_icd9_matrix, ICD9MatrixDataset -from pyhealth.models.generators.medgan import MedGAN - -""" -python examples/synthetic_data_generation_mimic3_medgan.py --autoencoder_epochs 5 --gan_epochs 10 --batch_size 16 -""" -def train_medgan(model, dataloader, n_epochs, device, save_dir, lr=0.001, weight_decay=0.0001, b1=0.5, b2=0.9): - """ - Train MedGAN model using the original synthEHRella approach. 
- - Args: - model: MedGAN model - dataloader: DataLoader for training data - n_epochs: Number of training epochs - device: Device to train on - save_dir: Directory to save checkpoints - lr: Learning rate - weight_decay: Weight decay for regularization - b1: Beta1 for Adam optimizer - b2: Beta2 for Adam optimizer - - Returns: - loss_history: Dictionary containing loss history - """ - - def generator_loss(y_fake): - """ - Original synthEHRella generator loss - """ - # standard GAN generator loss - want fake samples to be classified as real - return -torch.mean(torch.log(y_fake + 1e-12)) - - def discriminator_loss(outputs, labels): - """ - Original synthEHRella discriminator loss - """ - loss = -torch.mean(labels * torch.log(outputs + 1e-12)) - torch.mean((1 - labels) * torch.log(1. - outputs + 1e-12)) - return loss - - optimizer_g = torch.optim.Adam([ - {'params': model.generator.parameters()}, - {'params': model.autoencoder.decoder.parameters(), 'lr': lr * 0.1} - ], lr=lr, betas=(b1, b2), weight_decay=weight_decay) - - optimizer_d = torch.optim.Adam(model.discriminator.parameters(), - lr=lr * 0.1, betas=(b1, b2), weight_decay=weight_decay) - - g_losses = [] - d_losses = [] - - print("="*60) - print("Epoch | D_loss | G_loss | Progress") - print("="*60) - - for epoch in range(n_epochs): - epoch_g_loss = 0.0 - epoch_d_loss = 0.0 - num_batches = 0 - - for i, real_data in enumerate(dataloader): - real_data = real_data.to(device) - batch_size = real_data.size(0) - - valid = torch.ones(batch_size).to(device) # 1D tensor - fake = torch.zeros(batch_size).to(device) # 1D tensor - - z = torch.randn(batch_size, model.latent_dim).to(device) - - # Disable discriminator gradients for generator training to prevent discriminator from being updated - for p in model.discriminator.parameters(): - p.requires_grad = False - - # generate fake samples - fake_samples = model.generator(z) - fake_samples = model.autoencoder.decode(fake_samples) - - # generator loss using original medgan loss 
function - fake_output = model.discriminator(fake_samples).view(-1) - g_loss = generator_loss(fake_output) - - optimizer_g.zero_grad() - g_loss.backward() - optimizer_g.step() - - # --------------------- - # Train Discriminator - # --------------------- - - # Enable discriminator gradients - for p in model.discriminator.parameters(): - p.requires_grad = True - - optimizer_d.zero_grad() - - # Real samples - real_output = model.discriminator(real_data).view(-1) - real_loss = discriminator_loss(real_output, valid) - real_loss.backward() - - # Fake samples (detached) - fake_output = model.discriminator(fake_samples.detach()).view(-1) - fake_loss = discriminator_loss(fake_output, fake) - fake_loss.backward() - - # Total discriminator loss - d_loss = (real_loss + fake_loss) / 2 - - optimizer_d.step() - - # Track losses - epoch_g_loss += g_loss.item() - epoch_d_loss += d_loss.item() - num_batches += 1 - - # calculate average losses - avg_g_loss = epoch_g_loss / num_batches - avg_d_loss = epoch_d_loss / num_batches - - # store losses for trackin - g_losses.append(avg_g_loss) - d_losses.append(avg_d_loss) - - progress = (epoch + 1) / n_epochs * 100 - print(f"{epoch+1:5d} | {avg_d_loss:.4f} | {avg_g_loss:.4f} | {progress:5.1f}%") - - # save every 50 epochs - if (epoch + 1) % 50 == 0: - checkpoint_path = os.path.join(save_dir, f"medgan_epoch_{epoch+1}.pth") - torch.save({ - 'epoch': epoch + 1, - 'generator_state_dict': model.generator.state_dict(), - 'discriminator_state_dict': model.discriminator.state_dict(), - 'autoencoder_state_dict': model.autoencoder.state_dict(), - 'optimizer_g_state_dict': optimizer_g.state_dict(), - 'optimizer_d_state_dict': optimizer_d.state_dict(), - 'g_losses': g_losses, - 'd_losses': d_losses, - }, checkpoint_path) - print(f"Checkpoint saved to {checkpoint_path}") - - print("="*60) - print("GAN Training Completed!") - print(f"Final G_loss: {g_losses[-1]:.4f}") - print(f"Final D_loss: {d_losses[-1]:.4f}") - - # save loss history - loss_history = { 
- 'g_losses': g_losses, - 'd_losses': d_losses, - } - np.save(os.path.join(save_dir, "loss_history.npy"), loss_history) - - return loss_history - - - - -def main(): - parser = argparse.ArgumentParser(description="Train MedGAN for synthetic data generation") - parser.add_argument("--data_path", type=str, default="./data_files", help="path to MIMIC-III data") - parser.add_argument("--output_path", type=str, default="./medgan_results", help="Output directory") - parser.add_argument("--autoencoder_epochs", type=int, default=100, help="Autoencoder pretraining epochs") - parser.add_argument("--gan_epochs", type=int, default=1000, help="GAN training epochs") - parser.add_argument("--latent_dim", type=int, default=128, help="Latent dimension") - parser.add_argument("--hidden_dim", type=int, default=128, help="Hidden dimension") - parser.add_argument("--batch_size", type=int, default=128, help="Batch size") - parser.add_argument("--lr", type=float, default=0.001, help="adam: learning rate") - parser.add_argument("--weight_decay", type=float, default=0.0001, help="l2 regularization") - parser.add_argument("--b1", type=float, default=0.5, help="adam: decay of first order momentum of gradient") - parser.add_argument("--b2", type=float, default=0.9, help="adam: decay of second order momentum of gradient") - parser.add_argument("--save_dir", type=str, default="medgan_results", help="directory to save results") - args = parser.parse_args() - - # setup - os.makedirs(args.output_path, exist_ok=True) - device = torch.device("cuda" if torch.cuda.is_available() else "cpu") - print(f"Using device: {device}") - - # load MIMIC-III data - print("Loading MIMIC-III data") - dataset = MIMIC3Dataset(root=args.data_path, tables=["DIAGNOSES_ICD"]) - - # create ICD-9 matrix using utility function - print("Creating ICD-9 matrix") - icd9_matrix, icd9_types = create_icd9_matrix(dataset, args.output_path) - print(f"ICD-9 matrix shape: {icd9_matrix.shape}") - - - # initialize MedGAN model - 
print("Initializing MedGAN model...") - model = MedGAN.from_binary_matrix( - binary_matrix=icd9_matrix, - latent_dim=args.latent_dim, - autoencoder_hidden_dim=args.hidden_dim, - discriminator_hidden_dim=args.hidden_dim, - minibatch_averaging=True - ) - - # device stuff - model = model.to(device) - model.autoencoder = model.autoencoder.to(device) - model.generator = model.generator.to(device) - model.discriminator = model.discriminator.to(device) - - # make a dataloader - print("Creating dataloader...") - icd9_matrix_dataset = ICD9MatrixDataset(icd9_matrix) - dataloader = DataLoader( - icd9_matrix_dataset, - batch_size=args.batch_size, - shuffle=True - ) - - # autoencoder pretraining - print("Pretraining autoencoder...") - autoencoder_losses = model.pretrain_autoencoder( - dataloader=dataloader, - epochs=args.autoencoder_epochs, - lr=args.lr, - device=device - ) - - # train GAN - print("Training GAN...") - gan_loss_history = train_medgan( - model=model, - dataloader=dataloader, - n_epochs=args.gan_epochs, - device=device, - save_dir=args.save_dir, - lr=args.lr, - weight_decay=args.weight_decay, - b1=args.b1, - b2=args.b2 - ) - - # generate synthetic data - print("Generating synthetic data...") - with torch.no_grad(): - synthetic_data = model.generate(1000, device) - binary_data = model.sample_transform(synthetic_data, threshold=0.5) - - synthetic_matrix = binary_data.cpu().numpy() - - # save - print("Saving results...") - torch.save({ - 'model_config': { - 'latent_dim': args.latent_dim, - 'hidden_dim': args.hidden_dim, - 'autoencoder_hidden_dim': args.hidden_dim, - 'discriminator_hidden_dim': args.hidden_dim, - 'input_dim': icd9_matrix.shape[1], - }, - 'generator_state_dict': model.generator.state_dict(), - 'discriminator_state_dict': model.discriminator.state_dict(), - 'autoencoder_state_dict': model.autoencoder.state_dict(), - }, os.path.join(args.output_path, "medgan_final.pth")) - - np.save(os.path.join(args.output_path, "synthetic_binary_matrix.npy"), 
synthetic_matrix) - - # save loss histories - loss_history = { - 'autoencoder_losses': autoencoder_losses, - 'gan_losses': gan_loss_history, - } - np.save(os.path.join(args.output_path, "loss_history.npy"), loss_history) - - # print final stats - print("\n" + "="*50) - print("TRAINING COMPLETED") - print("="*50) - print(f"Real data shape: {icd9_matrix.shape}") - print(f"Real data mean activation: {icd9_matrix.mean():.4f}") - print(f"Real data sparsity: {(icd9_matrix == 0).mean():.4f}") - print(f"Synthetic data shape: {synthetic_matrix.shape}") - print(f"Synthetic data mean activation: {synthetic_matrix.mean():.4f}") - print(f"Synthetic data sparsity: {(synthetic_matrix == 0).mean():.4f}") - print(f"Results saved to: {args.output_path}") - print("="*50) - - print("\nGenerated synthetic data in original MIMIC3 ICD-9 format.") - - -if __name__ == "__main__": - main() - -""" -Slurm script example: - -#!/bin/bash -#SBATCH --account=jalenj4-ic -#SBATCH --job-name=medgan_pyhealth -#SBATCH --output=logs/medgan_pyhealth_%j.out -#SBATCH --error=logs/medgan_pyhealth_%j.err -#SBATCH --partition=IllinoisComputes-GPU # Change to appropriate partition -#SBATCH --gres=gpu:1 # Request 1 GPU -#SBATCH --cpus-per-task=4 -#SBATCH --mem=32G -#SBATCH --time=12:00:00 - -# Change to the directory where you submitted the job -cd "$SLURM_SUBMIT_DIR" - -# Print useful Slurm environment variables for debugging -echo "SLURM_JOB_ID: $SLURM_JOB_ID" -echo "SLURM_JOB_NODELIST: $SLURM_JOB_NODELIST" -echo "SLURM_NTASKS: $SLURM_NTASKS" -echo "SLURM_CPUS_ON_NODE: $SLURM_CPUS_ON_NODE" -echo "SLURM_GPUS_ON_NODE: $SLURM_GPUS_ON_NODE" -echo "SLURM_GPUS: $SLURM_GPUS" -echo "CUDA_VISIBLE_DEVICES: $CUDA_VISIBLE_DEVICES" - -# Optional: check what GPU(s) is/are actually visible -echo "Running nvidia-smi to confirm GPU availability:" -nvidia-smi - -# Load modules or activate environment -# module load python/3.10 -# module load cuda/11.7 -# conda activate pyhealth - -# Create output directories -mkdir -p logs 
-mkdir -p medgan_results - -# Set parameters (matching original synthEHRella defaults) -export AUTOENCODER_EPOCHS=100 -export GAN_EPOCHS=1000 -export BATCH_SIZE=128 -export LATENT_DIM=128 -export HIDDEN_DIM=128 -export NUM_SAMPLES=1000 -export LEARNING_RATE=0.001 -export WEIGHT_DECAY=0.0001 -export BETA1=0.5 -export BETA2=0.9 - -echo "Starting PyHealth MedGAN training with parameters:" -echo " Autoencoder epochs: $AUTOENCODER_EPOCHS" -echo " GAN epochs: $GAN_EPOCHS" -echo " Batch size: $BATCH_SIZE" -echo " Latent dimension: $LATENT_DIM" -echo " Hidden dimension: $HIDDEN_DIM" -echo " Number of synthetic samples: $NUM_SAMPLES" -echo " Learning rate: $LEARNING_RATE" -echo " Weight decay: $WEIGHT_DECAY" -echo " Beta1: $BETA1" -echo " Beta2: $BETA2" - -# Run the comprehensive PyHealth MedGAN script -python examples/synthetic_data_generation_mimic3_medgan.py \ - --data_path ./data_files \ - --output_path ./medgan_results \ - --autoencoder_epochs $AUTOENCODER_EPOCHS \ - --gan_epochs $GAN_EPOCHS \ - --batch_size $BATCH_SIZE \ - --latent_dim $LATENT_DIM \ - --hidden_dim $HIDDEN_DIM \ - --lr $LEARNING_RATE \ - --weight_decay $WEIGHT_DECAY \ - --b1 $BETA1 \ - --b2 $BETA2 \ - --postprocess - -echo "PyHealth MedGAN training completed!" 
-echo "Results saved to: ./medgan_results/" -echo "Check the following files:" -echo " - synthetic_binary_matrix.npy: Raw synthetic data" -echo " - medgan_final.pth: Trained model" -echo " - loss_history.npy: Training loss history" -""" \ No newline at end of file diff --git a/pyhealth/models/__init__.py b/pyhealth/models/__init__.py index a13b18a51..32bfa9338 100644 --- a/pyhealth/models/__init__.py +++ b/pyhealth/models/__init__.py @@ -11,6 +11,7 @@ from .jamba_ehr import JambaEHR, JambaLayer from .logistic_regression import LogisticRegression from .gan import GAN +from .generators.medgan import MedGAN from .gnn import GAT, GCN from .graph_torchvision_model import Graph_TorchvisionModel from .grasp import GRASP, GRASPLayer From c0209faae4019bb5865064f3d48b40d3282b97db Mon Sep 17 00:00:00 2001 From: jalengg Date: Sun, 8 Mar 2026 21:43:51 -0500 Subject: [PATCH 5/9] T8: add integration tests (8 pass, 4 skip without MIMIC-III) --- pyhealth/models/generators/medgan.py | 8 +- tests/integration/test_medgan_end_to_end.py | 358 ++++++++++++++++++++ 2 files changed, 363 insertions(+), 3 deletions(-) create mode 100644 tests/integration/test_medgan_end_to_end.py diff --git a/pyhealth/models/generators/medgan.py b/pyhealth/models/generators/medgan.py index c4ff8fa56..5b689734e 100644 --- a/pyhealth/models/generators/medgan.py +++ b/pyhealth/models/generators/medgan.py @@ -271,7 +271,8 @@ def train_model(self, train_dataset, val_dataset=None): sampler=sampler, ) - os.makedirs(self.save_dir, exist_ok=True) + if self.save_dir: + os.makedirs(self.save_dir, exist_ok=True) # ---- Phase 1: Autoencoder pretraining ---- print(f"Phase 1: Pretraining autoencoder for {self.ae_epochs} epochs...") @@ -366,12 +367,13 @@ def train_model(self, train_dataset, val_dataset=None): ) # Save best checkpoint - if avg_d < best_d_loss: + if self.save_dir and avg_d < best_d_loss: best_d_loss = avg_d self.save_model(os.path.join(self.save_dir, "best.pt")) # Save final checkpoint - 
self.save_model(os.path.join(self.save_dir, "final.pt")) + if self.save_dir: + self.save_model(os.path.join(self.save_dir, "final.pt")) print("Training complete.") def synthesize_dataset( diff --git a/tests/integration/test_medgan_end_to_end.py b/tests/integration/test_medgan_end_to_end.py new file mode 100644 index 000000000..cb5e90491 --- /dev/null +++ b/tests/integration/test_medgan_end_to_end.py @@ -0,0 +1,358 @@ +"""End-to-end integration tests for the MedGAN synthetic EHR generation pipeline. + +Category A tests use InMemorySampleDataset with synthetic data — no external +data required and must always pass. + +Category B tests require actual MIMIC-III data and are skipped gracefully when +the data is unavailable. + +The bootstrap pattern mirrors test_corgan_end_to_end.py: load MedGAN and +InMemorySampleDataset via importlib while stubbing out heavy optional +dependencies (einops, litdata, etc.) that are not yet in the venv. +""" + +import importlib.util +import os +import sys +import tempfile +import unittest +from unittest.mock import MagicMock + + +# --------------------------------------------------------------------------- +# Bootstrap: load MedGAN, BaseModel, and InMemorySampleDataset without +# triggering pyhealth.models.__init__ or pyhealth.datasets.__init__. +# --------------------------------------------------------------------------- + + +def _bootstrap(): + """Load MedGAN, BaseModel, and InMemorySampleDataset via importlib. + + Returns: + (BaseModel, MedGAN, InMemorySampleDataset) + """ + import pyhealth # noqa: F401 — top-level __init__ has no heavy deps + + # Stub pyhealth.datasets so that base_model.py's + # "from ..datasets import SampleDataset" resolves cleanly. + if "pyhealth.datasets" not in sys.modules: + ds_stub = MagicMock() + + class _FakeSampleDataset: + pass + + ds_stub.SampleDataset = _FakeSampleDataset + sys.modules["pyhealth.datasets"] = ds_stub + + # Stub pyhealth.models so we can control loading without the real __init__. 
+ if "pyhealth.models" not in sys.modules or isinstance( + sys.modules["pyhealth.models"], MagicMock + ): + models_stub = MagicMock() + sys.modules["pyhealth.models"] = models_stub + else: + models_stub = sys.modules["pyhealth.models"] + + # Processors are safe to import normally. + from pyhealth.processors import PROCESSOR_REGISTRY # noqa: F401 + + def _load_file(mod_name, filepath): + spec = importlib.util.spec_from_file_location(mod_name, filepath) + mod = importlib.util.module_from_spec(spec) + sys.modules[mod_name] = mod + spec.loader.exec_module(mod) + return mod + + root = os.path.dirname( + os.path.dirname(os.path.dirname(os.path.abspath(__file__))) + ) + base = os.path.join(root, "pyhealth", "models") + + # Load base_model and expose via stub. + bm_mod = _load_file( + "pyhealth.models.base_model", os.path.join(base, "base_model.py") + ) + BaseModel = bm_mod.BaseModel + models_stub.BaseModel = BaseModel + + gen_stub = MagicMock() + sys.modules.setdefault("pyhealth.models.generators", gen_stub) + + # Load MedGAN directly. + medgan_mod = _load_file( + "pyhealth.models.generators.medgan", + os.path.join(base, "generators", "medgan.py"), + ) + MedGAN = medgan_mod.MedGAN + + # Stub litdata so sample_dataset.py can be loaded. + if "litdata" not in sys.modules: + litdata_pkg = MagicMock() + litdata_pkg.StreamingDataset = type( + "StreamingDataset", (), {"__init__": lambda self, *a, **kw: None} + ) + litdata_utilities = MagicMock() + litdata_utilities_train_test = MagicMock() + litdata_utilities_train_test.deepcopy_dataset = lambda x: x + litdata_utilities.train_test_split = litdata_utilities_train_test + litdata_pkg.utilities = litdata_utilities + sys.modules["litdata"] = litdata_pkg + sys.modules["litdata.utilities"] = litdata_utilities + sys.modules["litdata.utilities.train_test_split"] = ( + litdata_utilities_train_test + ) + + # Load sample_dataset.py directly (bypasses datasets/__init__.py). 
+ ds_file_mod = _load_file( + "pyhealth.datasets.sample_dataset", + os.path.join(root, "pyhealth", "datasets", "sample_dataset.py"), + ) + InMemorySampleDataset = ds_file_mod.InMemorySampleDataset + + return BaseModel, MedGAN, InMemorySampleDataset + + +BaseModel, MedGAN, InMemorySampleDataset = _bootstrap() + +import torch # noqa: E402 + + +# --------------------------------------------------------------------------- +# Shared helpers +# --------------------------------------------------------------------------- + +_SMALL_SAMPLES = [ + {"patient_id": "p1", "visits": ["A", "B", "C"]}, + {"patient_id": "p2", "visits": ["A", "C", "D"]}, + {"patient_id": "p3", "visits": ["B", "D", "E"]}, + {"patient_id": "p4", "visits": ["A", "B", "C", "D"]}, + {"patient_id": "p5", "visits": ["C", "E"]}, + {"patient_id": "p6", "visits": ["A", "D", "E"]}, + {"patient_id": "p7", "visits": ["B", "C", "D"]}, + {"patient_id": "p8", "visits": ["A", "E"]}, +] + +_SMALL_MODEL_KWARGS = dict( + latent_dim=4, + hidden_dim=4, + autoencoder_hidden_dim=4, + discriminator_hidden_dim=8, + batch_size=4, + ae_epochs=1, + gan_epochs=1, + save_dir=None, +) + + +def _make_dataset(samples=None): + if samples is None: + samples = _SMALL_SAMPLES + return InMemorySampleDataset( + samples=samples, + input_schema={"visits": "multi_hot"}, + output_schema={}, + ) + + +# --------------------------------------------------------------------------- +# Category A: In-Memory Integration Tests (must always pass) +# --------------------------------------------------------------------------- + + +class TestMedGANIsBaseModelInstance(unittest.TestCase): + """MedGAN model is an instance of BaseModel.""" + + def test_model_is_basemodel_instance(self): + dataset = _make_dataset() + model = MedGAN(dataset, **_SMALL_MODEL_KWARGS) + self.assertIsInstance(model, BaseModel) + + +class TestMedGANFeatureKeys(unittest.TestCase): + """model.feature_keys equals ['visits'].""" + + def test_feature_keys(self): + dataset = _make_dataset() 
+ model = MedGAN(dataset, **_SMALL_MODEL_KWARGS) + self.assertEqual(model.feature_keys, ["visits"]) + + +class TestMedGANVocabSize(unittest.TestCase): + """MedGAN.input_dim matches processor.size().""" + + def test_vocab_size_matches_processor(self): + dataset = _make_dataset() + expected = dataset.input_processors["visits"].size() + model = MedGAN(dataset, **_SMALL_MODEL_KWARGS) + self.assertEqual(model.input_dim, expected) + + +class TestMedGANForwardRaisesNotImplementedError(unittest.TestCase): + """Calling forward() raises NotImplementedError.""" + + def test_forward_not_implemented(self): + dataset = _make_dataset() + model = MedGAN(dataset, **_SMALL_MODEL_KWARGS) + with self.assertRaises(NotImplementedError): + model.forward() + + +class TestMedGANTrainModelRuns(unittest.TestCase): + """train_model completes one epoch without error.""" + + def test_train_model_runs_one_epoch(self): + dataset = _make_dataset() + model = MedGAN(dataset, **_SMALL_MODEL_KWARGS) + try: + model.train_model(dataset, val_dataset=None) + except Exception as exc: + self.fail(f"train_model raised an unexpected exception: {exc}") + + +class TestMedGANSynthesizeCount(unittest.TestCase): + """synthesize_dataset(num_samples=5) returns exactly 5 dicts.""" + + def setUp(self): + dataset = _make_dataset() + self.model = MedGAN(dataset, **_SMALL_MODEL_KWARGS) + + def test_synthesize_returns_correct_count(self): + result = self.model.synthesize_dataset(num_samples=5) + self.assertIsInstance(result, list) + self.assertEqual(len(result), 5) + + +class TestMedGANSynthesizeOutputStructure(unittest.TestCase): + """Each synthesized dict has patient_id (str) and visits (flat list of str).""" + + def setUp(self): + dataset = _make_dataset() + self.model = MedGAN(dataset, **_SMALL_MODEL_KWARGS) + + def test_synthesize_output_structure(self): + result = self.model.synthesize_dataset(num_samples=3) + for i, item in enumerate(result): + self.assertIsInstance(item, dict, f"Item {i} is not a dict") + 
self.assertIn("patient_id", item, f"Item {i} missing 'patient_id'") + self.assertIn("visits", item, f"Item {i} missing 'visits'") + self.assertIsInstance( + item["patient_id"], str, f"patient_id in item {i} is not a str" + ) + self.assertIsInstance( + item["visits"], list, f"visits in item {i} is not a list" + ) + for code in item["visits"]: + self.assertIsInstance( + code, str, f"code '{code}' in item {i} is not a str" + ) + + +class TestMedGANSaveLoadRoundtrip(unittest.TestCase): + """save_model then load_model; synthesize_dataset still returns correct count.""" + + def test_save_load_roundtrip(self): + dataset = _make_dataset() + model = MedGAN(dataset, **_SMALL_MODEL_KWARGS) + with tempfile.TemporaryDirectory() as tmpdir: + ckpt_path = os.path.join(tmpdir, "medgan_test.pt") + model.save_model(ckpt_path) + self.assertTrue( + os.path.exists(ckpt_path), + f"Expected checkpoint at {ckpt_path}", + ) + model.load_model(ckpt_path) + result = model.synthesize_dataset(num_samples=3) + self.assertEqual(len(result), 3) + + +# --------------------------------------------------------------------------- +# Category B: MIMIC-III Integration Tests (skipped if data unavailable) +# --------------------------------------------------------------------------- + +_MIMIC3_PATH = os.environ.get( + "PYHEALTH_MIMIC3_PATH", + "/srv/local/data/physionet.org/files/mimiciii/1.4", +) + + +class TestMedGANMIMIC3Integration(unittest.TestCase): + """End-to-end pipeline test with actual MIMIC-III data.""" + + @classmethod + def setUpClass(cls): + cls.skip_integration = False + cls.skip_reason = "" + try: + _saved_stub = sys.modules.pop("pyhealth.datasets", None) + try: + import importlib as _il + + _il.invalidate_caches() + from pyhealth.datasets import MIMIC3Dataset as _MIMIC3Dataset + from pyhealth.tasks.medgan_generation import ( + MedGANGenerationMIMIC3, + ) + except (ImportError, ModuleNotFoundError) as exc: + if _saved_stub is not None: + sys.modules["pyhealth.datasets"] = _saved_stub + 
raise ImportError(str(exc)) from exc + + cls.dataset = _MIMIC3Dataset( + root=_MIMIC3_PATH, + tables=["diagnoses_icd"], + ) + task = MedGANGenerationMIMIC3() + cls.sample_dataset = cls.dataset.set_task(task) + except (FileNotFoundError, OSError, ImportError, ValueError) as exc: + cls.skip_integration = True + cls.skip_reason = str(exc) + + def setUp(self): + if self.skip_integration: + self.skipTest( + f"MIMIC-III integration test skipped: {self.skip_reason}" + ) + + def test_mimic3_set_task_returns_nonempty_dataset(self): + self.assertGreater(len(self.sample_dataset), 0) + + def test_mimic3_sample_keys(self): + for sample in self.sample_dataset: + self.assertIn("patient_id", sample) + self.assertIn("visits", sample) + + def test_mimic3_visits_are_flat_multihot_tensors(self): + processor = self.sample_dataset.input_processors["visits"] + vocab_size = processor.size() + for sample in self.sample_dataset: + visits = sample["visits"] + self.assertIsInstance(visits, torch.Tensor) + self.assertEqual(visits.shape, (vocab_size,)) + self.assertEqual(visits.dtype, torch.float32) + self.assertTrue( + torch.all((visits == 0.0) | (visits == 1.0)), + "visits tensor contains values outside {0, 1}", + ) + + def test_mimic3_full_pipeline_train_and_synthesize(self): + with tempfile.TemporaryDirectory() as tmpdir: + model = MedGAN( + self.sample_dataset, + latent_dim=64, + hidden_dim=64, + batch_size=32, + ae_epochs=1, + gan_epochs=1, + save_dir=tmpdir, + ) + model.train_model(self.sample_dataset, val_dataset=None) + synthetic = model.synthesize_dataset(num_samples=10) + self.assertEqual(len(synthetic), 10) + for item in synthetic: + self.assertIn("patient_id", item) + self.assertIn("visits", item) + self.assertIsInstance(item["visits"], list) + + +if __name__ == "__main__": + unittest.main() From 1f5286e91ade4930da0a717a3fae558046bd26c0 Mon Sep 17 00:00:00 2001 From: jalengg Date: Sun, 8 Mar 2026 21:56:29 -0500 Subject: [PATCH 6/9] Add MedGAN Colab notebook for MIMIC-III synthetic 
data generation --- examples/medgan_mimic3_colab.ipynb | 584 +++++++++++++++++++++++++++++ 1 file changed, 584 insertions(+) create mode 100644 examples/medgan_mimic3_colab.ipynb diff --git a/examples/medgan_mimic3_colab.ipynb b/examples/medgan_mimic3_colab.ipynb new file mode 100644 index 000000000..339028ca4 --- /dev/null +++ b/examples/medgan_mimic3_colab.ipynb @@ -0,0 +1,584 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "id": "cell-0", + "metadata": {}, + "source": [ + "# MedGAN Synthetic Data Generation for MIMIC-III\n", + "\n", + "_Last updated: 2026-03-08_\n", + "\n", + "This notebook trains MedGAN on your MIMIC-III data and generates synthetic patients.\n", + "\n", + "## What You'll Need\n", + "\n", + "1. **MIMIC-III Access**: Download these files from PhysioNet:\n", + " - `PATIENTS.csv` — patient demographics\n", + " - `ADMISSIONS.csv` — hospital admission records\n", + " - `DIAGNOSES_ICD.csv` — ICD-9 diagnosis codes\n", + "\n", + "2. **Google Colab**: Free tier works. GPU recommended but not required (MedGAN is lightweight).\n", + "\n", + "## How It Works\n", + "\n", + "MedGAN uses a two-phase training process:\n", + "1. **Phase 1 — Autoencoder pretraining**: A linear autoencoder learns a compressed representation of binary diagnosis vectors.\n", + "2. **Phase 2 — Adversarial training**: A generator maps random noise to the autoencoder's latent space, and the decoder projects back to binary codes. A discriminator (with minibatch averaging) distinguishes real vs. synthetic records.\n", + "\n", + "Each patient is represented as a flat **bag-of-codes** (binary vector indicating which ICD-9 codes appear across all visits). 
MedGAN does not model visit structure.\n", + "\n", + "## References\n", + "\n", + "- [MedGAN Paper](https://arxiv.org/abs/1703.06490) — Choi et al., MLHC 2017\n", + "- [PyHealth Documentation](https://pyhealth.readthedocs.io/)\n", + "- [MIMIC-III Access](https://physionet.org/content/mimiciii/)" + ] + }, + { + "cell_type": "markdown", + "id": "cell-1", + "metadata": {}, + "source": [ + "---\n", + "# 1. Setup & Installation" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "cell-2", + "metadata": {}, + "outputs": [], + "source": [ + "import subprocess\n", + "import sys\n", + "\n", + "FORK = 'jalengg'\n", + "BRANCH = 'medgan-pr-integration'\n", + "install_url = f\"git+https://github.com/{FORK}/PyHealth.git@{BRANCH}\"\n", + "\n", + "subprocess.run(\n", + " [sys.executable, \"-m\", \"pip\", \"uninstall\", \"pyhealth\", \"-y\"],\n", + " capture_output=True, text=True,\n", + ")\n", + "result = subprocess.run(\n", + " [sys.executable, \"-m\", \"pip\", \"install\", install_url,\n", + " \"--quiet\", \"--no-cache-dir\"],\n", + " capture_output=True, text=True,\n", + ")\n", + "if result.returncode != 0:\n", + " print(result.stderr)\n", + " raise RuntimeError(\"PyHealth installation failed.\")\n", + "print(f\"PyHealth installed from {FORK}/{BRANCH}\")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "cell-3", + "metadata": {}, + "outputs": [], + "source": [ + "import os\n", + "import shutil\n", + "import torch\n", + "import numpy as np\n", + "import pandas as pd\n", + "from IPython.display import display\n", + "from google.colab import drive, files\n", + "\n", + "print(f\"PyTorch: {torch.__version__}\")\n", + "if torch.cuda.is_available():\n", + " print(f\"GPU: {torch.cuda.get_device_name(0)}\")\n", + "else:\n", + " print(\"No GPU detected. 
MedGAN is lightweight and runs fine on CPU.\")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "cell-4", + "metadata": {}, + "outputs": [], + "source": [ + "# Mount Google Drive for persistent storage\n", + "print(\"Mounting Google Drive...\")\n", + "if not os.path.ismount('/content/drive'):\n", + " drive.mount('/content/drive', force_remount=True)\n", + "else:\n", + " print(\"Drive already mounted\")\n", + "print(\"Google Drive mounted\")\n", + "\n", + "BASE_DIR = '/content/drive/MyDrive/MedGAN_Training'\n", + "DATA_DIR = f'{BASE_DIR}/data'\n", + "CHECKPOINT_DIR = f'{BASE_DIR}/checkpoints'\n", + "OUTPUT_DIR = f'{BASE_DIR}/output'\n", + "\n", + "for d in [DATA_DIR, CHECKPOINT_DIR, OUTPUT_DIR]:\n", + " os.makedirs(d, exist_ok=True)\n", + "\n", + "print(f\"\\nDirectory structure:\")\n", + "print(f\" Base: {BASE_DIR}\")\n", + "print(f\" Data: {DATA_DIR}\")\n", + "print(f\" Checkpoints: {CHECKPOINT_DIR}\")\n", + "print(f\" Output: {OUTPUT_DIR}\")" + ] + }, + { + "cell_type": "markdown", + "id": "cell-5", + "metadata": {}, + "source": [ + "---\n", + "# 2. 
Configuration" + ] + }, + { + "cell_type": "markdown", + "id": "cell-6", + "metadata": {}, + "source": [ + "Configure your training and generation parameters below.\n", + "\n", + "**For Quick Demo (recommended first time):**\n", + "- Leave defaults (10 AE epochs, 20 GAN epochs, 50 samples)\n", + "\n", + "**For Production Quality:**\n", + "- Set `AE_EPOCHS = 100`, `GAN_EPOCHS = 200`\n", + "- Set `N_SYNTHETIC_SAMPLES = 10000`" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "cell-7", + "metadata": {}, + "outputs": [], + "source": [ + "# ============================================================\n", + "# CONFIGURATION\n", + "# ============================================================\n", + "\n", + "# Training\n", + "AE_EPOCHS = 10 # Demo: 10, Production: 100\n", + "GAN_EPOCHS = 20 # Demo: 20, Production: 200\n", + "BATCH_SIZE = 128 # Larger batches stabilize GAN training\n", + "AE_LR = 0.001 # Autoencoder learning rate\n", + "GAN_LR = 0.001 # GAN learning rate\n", + "\n", + "# Model architecture\n", + "LATENT_DIM = 128 # Generator noise / AE latent dimension\n", + "HIDDEN_DIM = 128 # Generator hidden width\n", + "AE_HIDDEN_DIM = 128 # Autoencoder hidden width\n", + "DISC_HIDDEN_DIM = 256 # Discriminator hidden width\n", + "\n", + "# Generation\n", + "N_SYNTHETIC_SAMPLES = 50 # Demo: 50, Production: 10000\n", + "\n", + "# Display\n", + "print(\"=\" * 60)\n", + "print(\"MEDGAN CONFIGURATION\")\n", + "print(\"=\" * 60)\n", + "print(f\"Training:\")\n", + "print(f\" AE epochs: {AE_EPOCHS} | GAN epochs: {GAN_EPOCHS}\")\n", + "print(f\" Batch size: {BATCH_SIZE}\")\n", + "print(f\" AE LR: {AE_LR} | GAN LR: {GAN_LR}\")\n", + "print(f\"\\nGeneration:\")\n", + "print(f\" Synthetic samples: {N_SYNTHETIC_SAMPLES}\")\n", + "print(\"=\" * 60)" + ] + }, + { + "cell_type": "markdown", + "id": "cell-8", + "metadata": {}, + "source": [ + "---\n", + "# 3. 
Data Upload" + ] + }, + { + "cell_type": "markdown", + "id": "cell-9", + "metadata": {}, + "source": [ + "Upload your MIMIC-III CSV files:\n", + "\n", + "1. `PATIENTS.csv` — patient demographics\n", + "2. `ADMISSIONS.csv` — admission records\n", + "3. `DIAGNOSES_ICD.csv` — ICD-9 diagnosis codes\n", + "\n", + "Files persist across Colab sessions when saved to Google Drive." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "cell-10", + "metadata": {}, + "outputs": [], + "source": [ + "required_files = {\n", + " 'PATIENTS.csv': 'Patient demographics',\n", + " 'ADMISSIONS.csv': 'Admission records',\n", + " 'DIAGNOSES_ICD.csv': 'ICD-9 diagnosis codes',\n", + "}\n", + "existing = {f: os.path.exists(f'{DATA_DIR}/{f}') for f in required_files}\n", + "missing = [f for f, ok in existing.items() if not ok]\n", + "\n", + "if not missing:\n", + " print(\"All MIMIC-III files found in Drive (no upload needed):\")\n", + " for fname in required_files:\n", + " size_mb = os.path.getsize(f'{DATA_DIR}/{fname}') / 1024 / 1024\n", + " print(f\" {fname} ({size_mb:.1f} MB)\")\n", + " print(f\"\\nFiles reused from: {DATA_DIR}\")\n", + "else:\n", + " print(\"MIMIC-III file status:\")\n", + " for fname, desc in required_files.items():\n", + " mark = \"OK\" if existing[fname] else \"MISSING\"\n", + " print(f\" [{mark}] {fname} - {desc}\")\n", + "\n", + " print(f\"\\nUploading {len(missing)} missing file(s)...\")\n", + " uploaded = files.upload()\n", + "\n", + " for uploaded_name, data in uploaded.items():\n", + " matched = None\n", + " for req in required_files:\n", + " base = req.replace('.csv', '')\n", + " if base in uploaded_name and uploaded_name.endswith('.csv'):\n", + " matched = req\n", + " break\n", + " if matched:\n", + " tmp = f'/content/{uploaded_name}'\n", + " with open(tmp, 'wb') as f:\n", + " f.write(data)\n", + " dest = f'{DATA_DIR}/{matched}'\n", + " shutil.copy(tmp, dest)\n", + " size_mb = os.path.getsize(dest) / 1024 / 1024\n", + " print(f\" Saved 
{matched} ({size_mb:.1f} MB)\")\n", + " else:\n", + " print(f\" Unrecognised file: {uploaded_name} (skipped)\")\n", + "\n", + " missing = [f for f in required_files if not os.path.exists(f'{DATA_DIR}/{f}')]\n", + " if missing:\n", + " raise FileNotFoundError(\n", + " f\"Still missing: {missing}. Please re-run this cell.\"\n", + " )\n", + " print(\"\\nAll 3 MIMIC-III files present.\")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "cell-11", + "metadata": {}, + "outputs": [], + "source": [ + "print(\"Validating MIMIC-III files...\")\n", + "\n", + "_patients = pd.read_csv(f'{DATA_DIR}/PATIENTS.csv', nrows=5)\n", + "assert 'SUBJECT_ID' in _patients.columns\n", + "print(f\" PATIENTS.csv: {len(_patients.columns)} columns\")\n", + "\n", + "_admissions = pd.read_csv(f'{DATA_DIR}/ADMISSIONS.csv', nrows=5)\n", + "assert 'HADM_ID' in _admissions.columns\n", + "print(f\" ADMISSIONS.csv: {len(_admissions.columns)} columns\")\n", + "\n", + "_diagnoses = pd.read_csv(f'{DATA_DIR}/DIAGNOSES_ICD.csv', nrows=5)\n", + "assert 'ICD9_CODE' in _diagnoses.columns\n", + "print(f\" DIAGNOSES_ICD.csv: {len(_diagnoses.columns)} columns\")\n", + "\n", + "del _patients, _admissions, _diagnoses\n", + "print(\"\\nAll files validated.\")" + ] + }, + { + "cell_type": "markdown", + "id": "cell-12", + "metadata": {}, + "source": [ + "---\n", + "# 4. Training" + ] + }, + { + "cell_type": "markdown", + "id": "cell-13", + "metadata": {}, + "source": [ + "**What happens during training:**\n", + "\n", + "1. **Dataset loading**: PyHealth reads MIMIC-III and creates one sample per patient (flat list of all ICD-9 codes across all visits).\n", + "2. **Multi-hot encoding**: The `MultiHotProcessor` converts each patient's code list into a binary vector of shape `(vocab_size,)`.\n", + "3. **Phase 1 — AE pretraining**: A linear autoencoder learns to reconstruct the binary vectors using BCE loss.\n", + "4. **Phase 2 — GAN training**: The generator maps noise to the AE latent space. 
The AE decoder projects to binary codes. The discriminator (with minibatch averaging) provides the adversarial signal via standard BCE loss.\n", + "5. **Checkpoint**: Best and final checkpoints saved to `CHECKPOINT_DIR`." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "cell-14", + "metadata": {}, + "outputs": [], + "source": [ + "from pyhealth.datasets import MIMIC3Dataset, split_by_patient\n", + "from pyhealth.tasks import medgan_generation_mimic3_fn\n", + "from pyhealth.models.generators.medgan import MedGAN\n", + "\n", + "print(\"Loading MIMIC-III dataset...\")\n", + "base_dataset = MIMIC3Dataset(\n", + " root=DATA_DIR,\n", + " tables=[\"diagnoses_icd\"],\n", + " dev=False,\n", + ")\n", + "print(f\"Loaded {len(base_dataset.unique_patient_ids)} patients\")\n", + "\n", + "print(\"Applying MedGAN generation task...\")\n", + "sample_dataset = base_dataset.set_task(medgan_generation_mimic3_fn)\n", + "print(f\"Eligible patients: {len(sample_dataset)}\")\n", + "\n", + "train_dataset, val_dataset, _ = split_by_patient(sample_dataset, [0.8, 0.1, 0.1])\n", + "print(f\"Split: {len(train_dataset)} train / {len(val_dataset)} val\")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "cell-15", + "metadata": {}, + "outputs": [], + "source": [ + "# Initialize and train MedGAN\n", + "print(\"Initializing MedGAN model...\")\n", + "model = MedGAN(\n", + " dataset=sample_dataset,\n", + " latent_dim=LATENT_DIM,\n", + " hidden_dim=HIDDEN_DIM,\n", + " autoencoder_hidden_dim=AE_HIDDEN_DIM,\n", + " discriminator_hidden_dim=DISC_HIDDEN_DIM,\n", + " batch_size=BATCH_SIZE,\n", + " ae_epochs=AE_EPOCHS,\n", + " gan_epochs=GAN_EPOCHS,\n", + " ae_lr=AE_LR,\n", + " gan_lr=GAN_LR,\n", + " save_dir=CHECKPOINT_DIR,\n", + ")\n", + "print(f\"Vocabulary size: {model.input_dim} ICD-9 codes\")\n", + "total_params = sum(p.numel() for p in model.parameters())\n", + "print(f\"Total parameters: {total_params:,}\")\n", + "\n", + "print(\"\\nStarting 
training...\")\n", + "print(\"=\" * 60)\n", + "model.train_model(train_dataset, val_dataset)\n", + "print(\"=\" * 60)\n", + "print(f\"Training complete! Checkpoints saved to: {CHECKPOINT_DIR}\")" + ] + }, + { + "cell_type": "markdown", + "id": "cell-16", + "metadata": {}, + "source": [ + "---\n", + "# 5. Generation" + ] + }, + { + "cell_type": "markdown", + "id": "cell-17", + "metadata": {}, + "source": [ + "**How generation works:**\n", + "\n", + "1. Sample random noise vectors from a standard normal distribution\n", + "2. Generator maps noise to the autoencoder's latent space\n", + "3. AE decoder projects latent vectors to binary code space (sigmoid output)\n", + "4. Threshold at 0.5 to produce binary vectors\n", + "5. Map active indices back to ICD-9 code strings\n", + "\n", + "Each synthetic patient is a flat **bag-of-codes** (no visit structure)." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "cell-18", + "metadata": {}, + "outputs": [], + "source": [ + "print(f\"Generating {N_SYNTHETIC_SAMPLES} synthetic patients...\")\n", + "synthetic = model.synthesize_dataset(num_samples=N_SYNTHETIC_SAMPLES)\n", + "print(f\"Generated {len(synthetic)} synthetic patients\")\n", + "\n", + "# Summary statistics\n", + "codes_per_patient = [len(p['visits']) for p in synthetic]\n", + "avg_codes = np.mean(codes_per_patient)\n", + "non_empty = sum(1 for c in codes_per_patient if c > 0)\n", + "\n", + "print(f\"\\nStatistics:\")\n", + "print(f\" Non-empty patients: {non_empty}/{len(synthetic)}\")\n", + "print(f\" Avg codes per patient: {avg_codes:.2f}\")\n", + "print(f\" Min codes: {min(codes_per_patient)}\")\n", + "print(f\" Max codes: {max(codes_per_patient)}\")\n", + "\n", + "# Preview\n", + "preview = []\n", + "for p in synthetic[:10]:\n", + " sample_codes = ', '.join(p['visits'][:5]) + ('...' 
if len(p['visits']) > 5 else '')\n", + " preview.append({\n", + " 'patient_id': p['patient_id'],\n", + " 'n_codes': len(p['visits']),\n", + " 'sample_codes': sample_codes or '(empty)',\n", + " })\n", + "display(pd.DataFrame(preview))" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "cell-19", + "metadata": {}, + "outputs": [], + "source": "# Save as CSV (flat SUBJECT_ID, ICD9_CODE — one row per code per patient)\nrows = []\nfor p in synthetic:\n for code in p['visits']:\n rows.append({'SUBJECT_ID': p['patient_id'], 'ICD9_CODE': code})\n\ndf_synthetic = pd.DataFrame(rows)\ncsv_path = f'{OUTPUT_DIR}/medgan_synthetic_data.csv'\ndf_synthetic.to_csv(csv_path, index=False)\n\nprint(f\"{len(df_synthetic):,} records saved to: {csv_path}\")\nprint(f\"Columns: SUBJECT_ID, ICD9_CODE\")\nprint(f\"\\nSample rows:\")\ndisplay(df_synthetic.head(8))" + }, + { + "cell_type": "markdown", + "id": "cell-20", + "metadata": {}, + "source": [ + "---\n", + "# 6. Results & Download" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "cell-21", + "metadata": {}, + "outputs": [], + "source": [ + "print(\"=\" * 60)\n", + "print(\"DATA QUALITY CHECKS\")\n", + "print(\"=\" * 60)\n", + "\n", + "unique_patients = df_synthetic['SUBJECT_ID'].nunique()\n", + "print(f\"\\nUnique patients: {unique_patients} out of {N_SYNTHETIC_SAMPLES} requested\")\n", + "\n", + "# Check for empty values\n", + "empty_subjects = df_synthetic['SUBJECT_ID'].isna().sum()\n", + "empty_codes = df_synthetic['ICD9_CODE'].isna().sum()\n", + "print(f\"\\nEmpty values:\")\n", + "print(f\" Subject IDs: {empty_subjects} (should be 0)\")\n", + "print(f\" ICD9 codes: {empty_codes} (should be 0)\")\n", + "\n", + "# Code distribution\n", + "codes_per_patient = df_synthetic.groupby('SUBJECT_ID').size()\n", + "print(f\"\\nCodes per patient:\")\n", + "print(f\" Min: {codes_per_patient.min()}\")\n", + "print(f\" Max: {codes_per_patient.max()}\")\n", + "print(f\" Mean: 
{codes_per_patient.mean():.2f}\")\n", + "print(f\" Median: {codes_per_patient.median():.2f}\")\n", + "\n", + "# Unique codes used\n", + "unique_codes = df_synthetic['ICD9_CODE'].nunique()\n", + "print(f\"\\nUnique ICD-9 codes in synthetic data: {unique_codes}\")\n", + "\n", + "# Heterogeneity\n", + "unique_profiles = len(set(\n", + " tuple(sorted(df_synthetic[df_synthetic['SUBJECT_ID'] == pid]['ICD9_CODE'].tolist()))\n", + " for pid in df_synthetic['SUBJECT_ID'].unique()\n", + "))\n", + "print(f\"Unique patient profiles: {unique_profiles}/{unique_patients} \"\n", + " f\"({unique_profiles/max(unique_patients,1)*100:.1f}%)\")\n", + "\n", + "print(\"\\n\" + \"=\" * 60)\n", + "print(\"QUALITY CHECKS COMPLETE\")\n", + "print(\"=\" * 60)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "cell-22", + "metadata": {}, + "outputs": [], + "source": [ + "print(\"=\" * 60)\n", + "print(\"DOWNLOAD SYNTHETIC DATA\")\n", + "print(\"=\" * 60)\n", + "\n", + "print(f\"\\nYour synthetic data is ready:\")\n", + "print(f\" File: medgan_synthetic_data.csv\")\n", + "print(f\" Patients: {unique_patients:,}\")\n", + "print(f\" Total records: {len(df_synthetic):,}\")\n", + "print(f\" Size: {os.path.getsize(csv_path) / (1024*1024):.2f} MB\")\n", + "\n", + "print(f\"\\nDownloading...\")\n", + "files.download(csv_path)\n", + "\n", + "print(f\"\\nFile also saved in Google Drive:\")\n", + "print(f\" {csv_path}\")" + ] + }, + { + "cell_type": "markdown", + "id": "cell-23", + "metadata": {}, + "source": [ + "---\n", + "## Congratulations!\n", + "\n", + "You've successfully:\n", + "1. Trained a MedGAN model on your MIMIC-III data\n", + "2. Generated synthetic patients\n", + "3. Validated the synthetic data quality\n", + "4. 
Downloaded the CSV file\n", + "\n", + "## Next Steps\n", + "\n", + "**Use your synthetic data:**\n", + "- Train predictive models (readmission, mortality, etc.)\n", + "- Evaluate data utility via Train-on-Synthetic, Test-on-Real (TSTR)\n", + "- Share data without privacy concerns\n", + "\n", + "**Generate more samples:**\n", + "- Change `N_SYNTHETIC_SAMPLES` and re-run Section 5\n", + "- No need to retrain if the model is still in memory\n", + "\n", + "**Production training:**\n", + "- Set `AE_EPOCHS = 100`, `GAN_EPOCHS = 200`\n", + "- Set `N_SYNTHETIC_SAMPLES = 10000`\n", + "\n", + "## Troubleshooting\n", + "\n", + "| Symptom | Fix |\n", + "|---------|-----|\n", + "| All synthetic patients empty | Increase `GAN_EPOCHS` |\n", + "| All patients have identical codes | Increase `GAN_EPOCHS`, check `AE_EPOCHS` reduced loss |\n", + "| Training loss not decreasing | Try `GAN_LR = 0.0002` |\n", + "| Out of memory | Reduce `BATCH_SIZE` to 64 or 32 |\n", + "\n", + "## References\n", + "\n", + "- [MedGAN Paper](https://arxiv.org/abs/1703.06490) — Choi et al., MLHC 2017\n", + "- [PyHealth Docs](https://pyhealth.readthedocs.io/)\n", + "- [GitHub Issues](https://github.com/sunlabuiuc/PyHealth/issues)" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3", + "language": "python", + "name": "python3" + }, + "language_info": { + "name": "python", + "version": "3.10.0" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} \ No newline at end of file From 730dc1b08d3140754c7f027b7e63d0c3b372ede9 Mon Sep 17 00:00:00 2001 From: jalengg Date: Sun, 8 Mar 2026 22:30:13 -0500 Subject: [PATCH 7/9] Fix Colab compatibility: pin numpy, add try/except import guards --- examples/medgan_mimic3_colab.ipynb | 23 +-------- pyhealth/datasets/__init__.py | 25 ++++++++-- pyhealth/models/__init__.py | 75 ++++++++++++++++++++++-------- 3 files changed, 76 insertions(+), 47 deletions(-) diff --git a/examples/medgan_mimic3_colab.ipynb b/examples/medgan_mimic3_colab.ipynb index 
339028ca4..96ade2091 100644 --- a/examples/medgan_mimic3_colab.ipynb +++ b/examples/medgan_mimic3_colab.ipynb @@ -50,28 +50,7 @@ "id": "cell-2", "metadata": {}, "outputs": [], - "source": [ - "import subprocess\n", - "import sys\n", - "\n", - "FORK = 'jalengg'\n", - "BRANCH = 'medgan-pr-integration'\n", - "install_url = f\"git+https://github.com/{FORK}/PyHealth.git@{BRANCH}\"\n", - "\n", - "subprocess.run(\n", - " [sys.executable, \"-m\", \"pip\", \"uninstall\", \"pyhealth\", \"-y\"],\n", - " capture_output=True, text=True,\n", - ")\n", - "result = subprocess.run(\n", - " [sys.executable, \"-m\", \"pip\", \"install\", install_url,\n", - " \"--quiet\", \"--no-cache-dir\"],\n", - " capture_output=True, text=True,\n", - ")\n", - "if result.returncode != 0:\n", - " print(result.stderr)\n", - " raise RuntimeError(\"PyHealth installation failed.\")\n", - "print(f\"PyHealth installed from {FORK}/{BRANCH}\")" - ] + "source": "import subprocess\nimport sys\n\n# Record Colab's pre-installed numpy version BEFORE installing PyHealth.\n# PyHealth's transitive deps may try to upgrade numpy, which breaks the\n# already-loaded C extensions (causes \"cannot import name '_center'\" etc.).\n_np_ver = subprocess.run(\n [sys.executable, \"-c\", \"import numpy; print(numpy.__version__)\"],\n capture_output=True, text=True,\n).stdout.strip()\nprint(f\"Colab numpy version: {_np_ver}\")\n\nFORK = 'jalengg'\nBRANCH = 'medgan-pr-integration'\ninstall_url = f\"git+https://github.com/{FORK}/PyHealth.git@{BRANCH}\"\n\n# Uninstall old pyhealth (if any), then install from branch.\n# Pin numpy to Colab's version to prevent upgrade-induced breakage.\nsubprocess.run(\n [sys.executable, \"-m\", \"pip\", \"uninstall\", \"pyhealth\", \"-y\"],\n capture_output=True, text=True,\n)\nresult = subprocess.run(\n [sys.executable, \"-m\", \"pip\", \"install\", install_url,\n f\"numpy=={_np_ver}\", \"--quiet\", \"--no-cache-dir\"],\n capture_output=True, text=True,\n)\nif result.returncode != 0:\n 
print(result.stderr)\n raise RuntimeError(\"PyHealth installation failed.\")\n\n# Verify numpy wasn't silently upgraded\n_np_after = subprocess.run(\n [sys.executable, \"-c\", \"import numpy; print(numpy.__version__)\"],\n capture_output=True, text=True,\n).stdout.strip()\nif _np_after != _np_ver:\n print(f\"WARNING: numpy changed {_np_ver} -> {_np_after}, restarting kernel...\")\n import os\n os.kill(os.getpid(), 9) # force kernel restart\n\nprint(f\"PyHealth installed from {FORK}/{BRANCH}\")\nprint(f\"numpy: {_np_after}\")" }, { "cell_type": "code", diff --git a/pyhealth/datasets/__init__.py b/pyhealth/datasets/__init__.py index effb47133..6933723a9 100644 --- a/pyhealth/datasets/__init__.py +++ b/pyhealth/datasets/__init__.py @@ -48,10 +48,16 @@ def __init__(self, *args, **kwargs): from .base_dataset import BaseDataset from .cardiology import CardiologyDataset -from .chestxray14 import ChestXray14Dataset +try: + from .chestxray14 import ChestXray14Dataset +except ImportError: + pass # PIL/torchvision unavailable from .clinvar import ClinVarDataset from .cosmic import COSMICDataset -from .covid19_cxr import COVID19CXRDataset +try: + from .covid19_cxr import COVID19CXRDataset +except ImportError: + pass # PIL/torchvision unavailable from .dreamt import DREAMTDataset from .ehrshot import EHRShotDataset from .eicu import eICUDataset @@ -63,7 +69,10 @@ def __init__(self, *args, **kwargs): from .omop import OMOPDataset from .sample_dataset import SampleBuilder, SampleDataset, create_sample_dataset from .shhs import SHHSDataset -from .sleepedf import SleepEDFDataset +try: + from .sleepedf import SleepEDFDataset +except ImportError: + pass # mne unavailable from .bmd_hs import BMDHSDataset from .support2 import Support2Dataset from .tcga_prad import TCGAPRADDataset @@ -76,8 +85,14 @@ def __init__(self, *args, **kwargs): split_by_visit, split_by_visit_conformal, ) -from .tuab import TUABDataset -from .tuev import TUEVDataset +try: + from .tuab import TUABDataset +except 
ImportError: + pass # mne unavailable; TUABDataset not registered +try: + from .tuev import TUEVDataset +except ImportError: + pass # mne unavailable; TUEVDataset not registered from .utils import ( collate_fn_dict, collate_fn_dict_with_padding, diff --git a/pyhealth/models/__init__.py b/pyhealth/models/__init__.py index 32bfa9338..447c3b913 100644 --- a/pyhealth/models/__init__.py +++ b/pyhealth/models/__init__.py @@ -1,8 +1,14 @@ from .adacare import AdaCare, AdaCareLayer from .agent import Agent, AgentLayer from .base_model import BaseModel -from .biot import BIOT -from .cnn import CNN, CNNLayer +try: + from .biot import BIOT +except ImportError: + pass # einops unavailable +try: + from .cnn import CNN, CNNLayer +except ImportError: + pass # PIL/torchvision unavailable from .concare import ConCare, ConCareLayer from .contrawr import ContraWR, ResBlock2D from .deepr import Deepr, DeeprLayer @@ -13,34 +19,63 @@ from .gan import GAN from .generators.medgan import MedGAN from .gnn import GAT, GCN -from .graph_torchvision_model import Graph_TorchvisionModel -from .grasp import GRASP, GRASPLayer -from .medlink import MedLink +try: + from .graph_torchvision_model import Graph_TorchvisionModel +except ImportError: + pass # torchvision unavailable +try: + from .grasp import GRASP, GRASPLayer +except ImportError: + pass # sklearn unavailable from .micron import MICRON, MICRONLayer from .mlp import MLP -from .molerec import MoleRec, MoleRecLayer +try: + from .molerec import MoleRec, MoleRecLayer +except ImportError: + pass # rdkit unavailable from .retain import RETAIN, RETAINLayer from .rnn import MultimodalRNN, RNN, RNNLayer -from .safedrug import SafeDrug, SafeDrugLayer +try: + from .safedrug import SafeDrug, SafeDrugLayer +except ImportError: + pass # rdkit unavailable from .sparcnet import DenseBlock, DenseLayer, SparcNet, TransitionLayer from .stagenet import StageNet, StageNetLayer from .stagenet_mha import StageAttentionNet, StageNetAttentionLayer from .tcn import 
TCN, TCNLayer -from .tfm_tokenizer import ( - TFMTokenizer, - TFM_VQVAE2_deep, - TFM_TOKEN_Classifier, - get_tfm_tokenizer_2x2x8, - get_tfm_token_classifier_64x4, - load_embedding_weights, -) -from .torchvision_model import TorchvisionModel +try: + from .tfm_tokenizer import ( + TFMTokenizer, + TFM_VQVAE2_deep, + TFM_TOKEN_Classifier, + get_tfm_tokenizer_2x2x8, + get_tfm_token_classifier_64x4, + load_embedding_weights, + ) +except ImportError: + pass # einops unavailable +try: + from .torchvision_model import TorchvisionModel +except ImportError: + pass # torchvision unavailable from .transformer import Transformer, TransformerLayer -from .transformers_model import TransformersModel +try: + from .transformers_model import TransformersModel +except ImportError: + pass # transformers unavailable from .ehrmamba import EHRMamba, MambaBlock from .vae import VAE -from .vision_embedding import VisionEmbeddingModel -from .text_embedding import TextEmbedding -from .sdoh import SdohClassifier +try: + from .vision_embedding import VisionEmbeddingModel +except ImportError: + pass # PIL/torchvision unavailable +try: + from .text_embedding import TextEmbedding +except ImportError: + pass # transformers unavailable +try: + from .sdoh import SdohClassifier +except ImportError: + pass # transformers/peft unavailable from .medlink import MedLink from .unified_embedding import UnifiedMultimodalEmbeddingModel, SinusoidalTimeEmbedding From 40fd2eeaf4b8b76b91985765c93a9d118a71c3f4 Mon Sep 17 00:00:00 2001 From: jalengg Date: Sun, 8 Mar 2026 22:43:28 -0500 Subject: [PATCH 8/9] Fix Colab install: let numpy upgrade then auto-restart kernel --- examples/medgan_mimic3_colab.ipynb | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/examples/medgan_mimic3_colab.ipynb b/examples/medgan_mimic3_colab.ipynb index 96ade2091..5d433f427 100644 --- a/examples/medgan_mimic3_colab.ipynb +++ b/examples/medgan_mimic3_colab.ipynb @@ -50,7 +50,7 @@ "id": "cell-2", "metadata": {}, "outputs": 
[], - "source": "import subprocess\nimport sys\n\n# Record Colab's pre-installed numpy version BEFORE installing PyHealth.\n# PyHealth's transitive deps may try to upgrade numpy, which breaks the\n# already-loaded C extensions (causes \"cannot import name '_center'\" etc.).\n_np_ver = subprocess.run(\n [sys.executable, \"-c\", \"import numpy; print(numpy.__version__)\"],\n capture_output=True, text=True,\n).stdout.strip()\nprint(f\"Colab numpy version: {_np_ver}\")\n\nFORK = 'jalengg'\nBRANCH = 'medgan-pr-integration'\ninstall_url = f\"git+https://github.com/{FORK}/PyHealth.git@{BRANCH}\"\n\n# Uninstall old pyhealth (if any), then install from branch.\n# Pin numpy to Colab's version to prevent upgrade-induced breakage.\nsubprocess.run(\n [sys.executable, \"-m\", \"pip\", \"uninstall\", \"pyhealth\", \"-y\"],\n capture_output=True, text=True,\n)\nresult = subprocess.run(\n [sys.executable, \"-m\", \"pip\", \"install\", install_url,\n f\"numpy=={_np_ver}\", \"--quiet\", \"--no-cache-dir\"],\n capture_output=True, text=True,\n)\nif result.returncode != 0:\n print(result.stderr)\n raise RuntimeError(\"PyHealth installation failed.\")\n\n# Verify numpy wasn't silently upgraded\n_np_after = subprocess.run(\n [sys.executable, \"-c\", \"import numpy; print(numpy.__version__)\"],\n capture_output=True, text=True,\n).stdout.strip()\nif _np_after != _np_ver:\n print(f\"WARNING: numpy changed {_np_ver} -> {_np_after}, restarting kernel...\")\n import os\n os.kill(os.getpid(), 9) # force kernel restart\n\nprint(f\"PyHealth installed from {FORK}/{BRANCH}\")\nprint(f\"numpy: {_np_after}\")" + "source": "import subprocess\nimport sys\nimport os\n\n# Record Colab's pre-installed numpy version so we can detect if it changed.\n_np_before = subprocess.run(\n [sys.executable, \"-c\", \"import numpy; print(numpy.__version__)\"],\n capture_output=True, text=True,\n).stdout.strip()\n\nFORK = 'jalengg'\nBRANCH = 'medgan-pr-integration'\ninstall_url = 
f\"git+https://github.com/{FORK}/PyHealth.git@{BRANCH}\"\n\nsubprocess.run(\n [sys.executable, \"-m\", \"pip\", \"uninstall\", \"pyhealth\", \"-y\"],\n capture_output=True, text=True,\n)\nresult = subprocess.run(\n [sys.executable, \"-m\", \"pip\", \"install\", install_url, \"--quiet\", \"--no-cache-dir\"],\n capture_output=True, text=True,\n)\nif result.returncode != 0:\n print(result.stderr)\n raise RuntimeError(\"PyHealth installation failed.\")\n\n# Check if numpy was upgraded. If so, the already-loaded C extensions are\n# stale and will cause \"cannot import name '_center'\" errors. The only\n# reliable fix is to restart the kernel so Python loads the new .so files.\n_np_after = subprocess.run(\n [sys.executable, \"-c\", \"import numpy; print(numpy.__version__)\"],\n capture_output=True, text=True,\n).stdout.strip()\n\nif _np_after != _np_before:\n print(f\"numpy upgraded {_np_before} -> {_np_after}. Restarting kernel...\")\n print(\">>> After restart, re-run this cell. It will skip the restart on the second run. <<<\")\n os.kill(os.getpid(), 9) # force kernel restart\n\nprint(f\"PyHealth installed from {FORK}/{BRANCH}\")\nprint(f\"numpy: {_np_after}\")" }, { "cell_type": "code", From 786625d6fd18ebdb4688d34862690d270eda2bdf Mon Sep 17 00:00:00 2001 From: jalengg Date: Mon, 16 Mar 2026 15:55:52 -0500 Subject: [PATCH 9/9] Fix: remove icustays from MIMIC3Dataset default tables --- pyhealth/datasets/mimic3.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/pyhealth/datasets/mimic3.py b/pyhealth/datasets/mimic3.py index 7e569d2f3..3ff38d5ca 100644 --- a/pyhealth/datasets/mimic3.py +++ b/pyhealth/datasets/mimic3.py @@ -15,7 +15,7 @@ class MIMIC3Dataset(BaseDataset): A dataset class for handling MIMIC-III data. This class is responsible for loading and managing the MIMIC-III dataset, - which includes tables such as patients, admissions, and icustays. + which includes tables such as patients, admissions, diagnoses_icd, etc. 
Attributes: root (str): The root directory where the dataset is stored. @@ -53,7 +53,7 @@ def __init__( if config_path is None: logger.info("No config path provided, using default config") config_path = Path(__file__).parent / "configs" / "mimic3.yaml" - default_tables = ["patients", "admissions", "icustays"] + default_tables = ["patients", "admissions"] tables = default_tables + tables if "prescriptions" in tables: warnings.warn(