diff --git a/docs/api/datasets.rst b/docs/api/datasets.rst index 8d9a59d21..15862a6a8 100644 --- a/docs/api/datasets.rst +++ b/docs/api/datasets.rst @@ -244,5 +244,6 @@ Available Datasets datasets/pyhealth.datasets.ClinVarDataset datasets/pyhealth.datasets.COSMICDataset datasets/pyhealth.datasets.TCGAPRADDataset + datasets/pyhealth.datasets.TCGACRCkDataset datasets/pyhealth.datasets.splitter datasets/pyhealth.datasets.utils diff --git a/docs/api/datasets/pyhealth.datasets.TCGACRCkDataset.rst b/docs/api/datasets/pyhealth.datasets.TCGACRCkDataset.rst new file mode 100644 index 000000000..5d24c4433 --- /dev/null +++ b/docs/api/datasets/pyhealth.datasets.TCGACRCkDataset.rst @@ -0,0 +1,11 @@ +pyhealth.datasets.TCGACRCkDataset +================================= + +The Cancer Genome Atlas Colorectal Carcinoma (TCGA-CRC) dataset is a comprehensive molecular and histopathology tile dataset for slide-level MSI prediction. + +This dataset normalizes the public TCGA-CRCk metadata into PyHealth's patient / visit / event representation. Each slide is represented as one patient/visit and each tile is represented as one event. + +.. autoclass:: pyhealth.datasets.TCGACRCkDataset + :members: + :undoc-members: + :show-inheritance: diff --git a/docs/api/models.rst b/docs/api/models.rst index 7c3ac7c4b..4599243c2 100644 --- a/docs/api/models.rst +++ b/docs/api/models.rst @@ -206,3 +206,4 @@ API Reference models/pyhealth.models.BIOT models/pyhealth.models.unified_multimodal_embedding_docs models/pyhealth.models.califorest + models/pyhealth.models.TissueAwareSimCLR diff --git a/docs/api/models/pyhealth.models.TissueAwareSimCLR.rst b/docs/api/models/pyhealth.models.TissueAwareSimCLR.rst new file mode 100644 index 000000000..ab5d72528 --- /dev/null +++ b/docs/api/models/pyhealth.models.TissueAwareSimCLR.rst @@ -0,0 +1,9 @@ +pyhealth.models.TissueAwareSimCLR +=================================== + +Tissue-aware SimCLR ResNet-18 MIL classifier with for slide-level MSI prediction [TCGA-CRCk]. 
+ +.. autoclass:: pyhealth.models.TissueAwareSimCLR + :members: + :undoc-members: + :show-inheritance: \ No newline at end of file diff --git a/docs/api/tasks.rst b/docs/api/tasks.rst index 23a4e06e5..69cca36f5 100644 --- a/docs/api/tasks.rst +++ b/docs/api/tasks.rst @@ -230,3 +230,4 @@ Available Tasks Mutation Pathogenicity (COSMIC) Cancer Survival Prediction (TCGA) Cancer Mutation Burden (TCGA) + Slide-level MSI Classification (TCGA-CRCk) \ No newline at end of file diff --git a/docs/api/tasks/pyhealth.tasks.TCGACRCkMSIClassification.rst b/docs/api/tasks/pyhealth.tasks.TCGACRCkMSIClassification.rst new file mode 100644 index 000000000..8fbf186e5 --- /dev/null +++ b/docs/api/tasks/pyhealth.tasks.TCGACRCkMSIClassification.rst @@ -0,0 +1,11 @@ +pyhealth.tasks.TCGACRCkMSIClassification +======================================== + +Slide-level MSI classification task for TCGA-CRCk. + +This task groups all tile events from the same slide into a single bag for binary MSI classification. It is designed for the TCGA-CRCk dataset and uses PyHealth's 'time_image' input type so each sample can be represented as a bag of image paths with simple monotonic timestamps. + +.. 
autoclass:: pyhealth.tasks.TCGACRCkMSIClassification + :members: + :undoc-members: + :show-inheritance: \ No newline at end of file diff --git a/examples/tcga_crck_simclr/tcga_crck_msi_classification_tissue_aware_simclr.py b/examples/tcga_crck_simclr/tcga_crck_msi_classification_tissue_aware_simclr.py new file mode 100644 index 000000000..00b96e097 --- /dev/null +++ b/examples/tcga_crck_simclr/tcga_crck_msi_classification_tissue_aware_simclr.py @@ -0,0 +1,255 @@ +from __future__ import annotations + +import argparse +from pathlib import Path +from urllib.parse import urlsplit +from urllib.request import urlretrieve + +import numpy as np +import torch +from sklearn.model_selection import train_test_split + +from pyhealth.datasets import TCGACRCkDataset, get_dataloader +from pyhealth.models import TissueAwareSimCLR +from pyhealth.tasks import TCGACRCkMSIClassification +from pyhealth.trainer import Trainer + + +DATA_ROOT = Path.home() / "TCGA_CRCk" +CACHE_DIR = Path("/home/ubuntu/.cache/pyhealth_local") +CHECKPOINT_CACHE_DIR = Path.home() / ".cache" / "pyhealth_checkpoints" + +# Keep downstream params close to the paper +BATCH_SIZE = 32 +MAX_TILES = 1000 +MAX_EPOCHS = 100 +HIDDEN_DIM = 128 +DROPOUT = 0.25 +LR = 5e-3 +MOMENTUM = 0.6 +WEIGHT_DECAY = 1e-4 +# PATIENCE = 50 +POOLING = "attention" + +# Runtime optimization only +TILE_CHUNK_SIZE = 1024 +SEED = 42 + +# Ablation commands: +# +# 1) Main experiment: pretrained encoder + fine-tuning +# python /examples/tcga_crck_simclr/tcga_crck_msi_classification_tissue_aware_simclr.py \ +# --pretrain-from-checkpoint /path/to/checkpoint.ckpt +# +# 2) Ablation 1: no pretraining (random initialization) +# python /examples/tcga_crck_simclr/tcga_crck_msi_classification_tissue_aware_simclr.py +# +# 3) Ablation 2: pretrained encoder + frozen encoder +# python /examples/tcga_crck_simclr/tcga_crck_msi_classification_tissue_aware_simclr.py \ +# --pretrain-from-checkpoint /path/to/checkpoint.ckpt \ +# --freeze-encoder + + +def 
parse_args() -> argparse.Namespace: + """Parses command-line arguments for downstream MSI classification. + + Returns: + argparse.Namespace: Parsed arguments containing the optional + pretrained checkpoint path and encoder freezing flag. + """ + parser = argparse.ArgumentParser() + parser.add_argument( + "--pretrain-from-checkpoint", + type=str, + default=None, + help="Local path or HTTP(S) URL for a pretrained encoder checkpoint.", + ) + parser.add_argument( + "--freeze-encoder", + action="store_true", + help="Freeze the encoder during downstream training.", + ) + return parser.parse_args() + + +def resolve_checkpoint_path(checkpoint_spec: str | None) -> str | None: + """Resolves a checkpoint spec into a usable local checkpoint path. + + If the input is an HTTP(S) URL, the checkpoint is downloaded into the + local cache directory. If it is a local path, the file must already exist. + + Args: + checkpoint_spec: Local filesystem path or HTTP(S) URL to a pretrained + encoder checkpoint. If None, no checkpoint is used. + + Returns: + str | None: Local path to the checkpoint, or None if no checkpoint + was provided. + + Raises: + FileNotFoundError: If a provided local checkpoint path does not exist. 
+ """ + if checkpoint_spec is None: + return None + + parts = urlsplit(checkpoint_spec) + if parts.scheme in {"http", "https"}: + filename = Path(parts.path).name or "checkpoint.ckpt" + CHECKPOINT_CACHE_DIR.mkdir(parents=True, exist_ok=True) + local_path = CHECKPOINT_CACHE_DIR / filename + if not local_path.exists(): + print(f"Downloading checkpoint to {local_path}", flush=True) + urlretrieve(checkpoint_spec, local_path) + else: + print(f"Using cached checkpoint at {local_path}", flush=True) + return str(local_path) + + local_path = Path(checkpoint_spec).expanduser() + if not local_path.exists(): + raise FileNotFoundError(f"Checkpoint does not exist: {local_path}") + return str(local_path) + + +def build_splits(sample_dataset): + """Builds train, validation, and test splits from the task dataset. + + The function reads the `data_split` field from each sample, separates + train and test data accordingly, and then creates a stratified validation + split from the training partition. + + Args: + sample_dataset: Task-specific PyHealth dataset produced by + `set_task(...)`. + + Returns: + tuple: A tuple of `(train_dataset, val_dataset, test_dataset)`. + + Raises: + ValueError: If an unknown split label is encountered or if train/test + samples are missing. 
+ """ + train_indices = [] + test_indices = [] + + for i in range(len(sample_dataset)): + split = str(sample_dataset[i]["data_split"]).strip().lower() + if split in {"train", "training", "tr"}: + train_indices.append(i) + elif split in {"test", "testing", "te"}: + test_indices.append(i) + else: + raise ValueError(f"Unknown data_split: {split}") + + if not train_indices or not test_indices: + raise ValueError( + f"Expected both train and test samples, got train={len(train_indices)}, " + f"test={len(test_indices)}" + ) + + train_labels = [int(sample_dataset[i]["label"]) for i in train_indices] + train_indices, val_indices = train_test_split( + train_indices, + test_size=0.2, + random_state=SEED, + stratify=train_labels, + ) + + train_dataset = sample_dataset.subset(train_indices) + val_dataset = sample_dataset.subset(val_indices) + test_dataset = sample_dataset.subset(test_indices) + return train_dataset, val_dataset, test_dataset + + +def main() -> None: + """Runs downstream MSI classification with a TissueAwareSimCLR encoder. + + This function: + 1. Parses command-line arguments. + 2. Sets random seeds. + 3. Checks for CUDA availability. + 4. Builds the TCGA-CRCk dataset and downstream MSI task. + 5. Splits the dataset into train/validation/test sets. + 6. Creates dataloaders, model, and trainer. + 7. Trains the model and evaluates on validation and test splits. + + Raises: + RuntimeError: If CUDA is unavailable. 
+ """ + args = parse_args() + + torch.manual_seed(SEED) + np.random.seed(SEED) + + if not torch.cuda.is_available(): + raise RuntimeError("CUDA GPU is required for this run, but no GPU was found.") + + device = torch.device("cuda") + + torch.backends.cudnn.benchmark = True + torch.backends.cuda.matmul.allow_tf32 = True + torch.backends.cudnn.allow_tf32 = True + torch.set_float32_matmul_precision("high") + + checkpoint_path = resolve_checkpoint_path(args.pretrain_from_checkpoint) + print(f"Resolved checkpoint: {checkpoint_path}", flush=True) + + base_dataset = TCGACRCkDataset( + root=str(DATA_ROOT), + cache_dir=str(CACHE_DIR), + ) + + sample_dataset = base_dataset.set_task( + TCGACRCkMSIClassification(max_tiles=MAX_TILES) + ) + print(f"Task dataset size: {len(sample_dataset)}", flush=True) + + train_dataset, val_dataset, test_dataset = build_splits(sample_dataset) + print( + f"Split sizes | train={len(train_dataset)} val={len(val_dataset)} test={len(test_dataset)}", + flush=True, + ) + + train_loader = get_dataloader(train_dataset, batch_size=BATCH_SIZE, shuffle=True) + val_loader = get_dataloader(val_dataset, batch_size=BATCH_SIZE, shuffle=False) + test_loader = get_dataloader(test_dataset, batch_size=BATCH_SIZE, shuffle=False) + + model = TissueAwareSimCLR( + dataset=train_dataset, + checkpoint_path=checkpoint_path, + hidden_dim=HIDDEN_DIM, + dropout=DROPOUT, + freeze_encoder=args.freeze_encoder, + pooling=POOLING, + tile_chunk_size=TILE_CHUNK_SIZE, + use_bf16=(device.type == "cuda"), + ).to(device) + + trainer = Trainer( + model=model, + metrics=["accuracy", "balanced_accuracy", "precision", "recall", "f1", "roc_auc", "pr_auc"], + device=str(device), + enable_logging=False, + ) + + trainer.train( + train_dataloader=train_loader, + val_dataloader=val_loader, + epochs=MAX_EPOCHS, + optimizer_class=torch.optim.Adam, + optimizer_params={"lr": LR, "betas": (MOMENTUM, 0.999)}, + weight_decay=WEIGHT_DECAY, + monitor="balanced_accuracy", + monitor_criterion="max", +# 
patience=PATIENCE, + load_best_model_at_last=True, + ) + + print("\nValidation metrics:", flush=True) + print(trainer.evaluate(val_loader), flush=True) + + print("\nTest metrics:", flush=True) + print(trainer.evaluate(test_loader), flush=True) + + +if __name__ == "__main__": + main() \ No newline at end of file diff --git a/pyhealth-env/bin/Activate.ps1 b/pyhealth-env/bin/Activate.ps1 new file mode 100644 index 000000000..b49d77ba4 --- /dev/null +++ b/pyhealth-env/bin/Activate.ps1 @@ -0,0 +1,247 @@ +<# +.Synopsis +Activate a Python virtual environment for the current PowerShell session. + +.Description +Pushes the python executable for a virtual environment to the front of the +$Env:PATH environment variable and sets the prompt to signify that you are +in a Python virtual environment. Makes use of the command line switches as +well as the `pyvenv.cfg` file values present in the virtual environment. + +.Parameter VenvDir +Path to the directory that contains the virtual environment to activate. The +default value for this is the parent of the directory that the Activate.ps1 +script is located within. + +.Parameter Prompt +The prompt prefix to display when this virtual environment is activated. By +default, this prompt is the name of the virtual environment folder (VenvDir) +surrounded by parentheses and followed by a single space (ie. '(.venv) '). + +.Example +Activate.ps1 +Activates the Python virtual environment that contains the Activate.ps1 script. + +.Example +Activate.ps1 -Verbose +Activates the Python virtual environment that contains the Activate.ps1 script, +and shows extra information about the activation as it executes. + +.Example +Activate.ps1 -VenvDir C:\Users\MyUser\Common\.venv +Activates the Python virtual environment located in the specified location. 
+ +.Example +Activate.ps1 -Prompt "MyPython" +Activates the Python virtual environment that contains the Activate.ps1 script, +and prefixes the current prompt with the specified string (surrounded in +parentheses) while the virtual environment is active. + +.Notes +On Windows, it may be required to enable this Activate.ps1 script by setting the +execution policy for the user. You can do this by issuing the following PowerShell +command: + +PS C:\> Set-ExecutionPolicy -ExecutionPolicy RemoteSigned -Scope CurrentUser + +For more information on Execution Policies: +https://go.microsoft.com/fwlink/?LinkID=135170 + +#> +Param( + [Parameter(Mandatory = $false)] + [String] + $VenvDir, + [Parameter(Mandatory = $false)] + [String] + $Prompt +) + +<# Function declarations --------------------------------------------------- #> + +<# +.Synopsis +Remove all shell session elements added by the Activate script, including the +addition of the virtual environment's Python executable from the beginning of +the PATH variable. + +.Parameter NonDestructive +If present, do not remove this function from the global namespace for the +session. + +#> +function global:deactivate ([switch]$NonDestructive) { + # Revert to original values + + # The prior prompt: + if (Test-Path -Path Function:_OLD_VIRTUAL_PROMPT) { + Copy-Item -Path Function:_OLD_VIRTUAL_PROMPT -Destination Function:prompt + Remove-Item -Path Function:_OLD_VIRTUAL_PROMPT + } + + # The prior PYTHONHOME: + if (Test-Path -Path Env:_OLD_VIRTUAL_PYTHONHOME) { + Copy-Item -Path Env:_OLD_VIRTUAL_PYTHONHOME -Destination Env:PYTHONHOME + Remove-Item -Path Env:_OLD_VIRTUAL_PYTHONHOME + } + + # The prior PATH: + if (Test-Path -Path Env:_OLD_VIRTUAL_PATH) { + Copy-Item -Path Env:_OLD_VIRTUAL_PATH -Destination Env:PATH + Remove-Item -Path Env:_OLD_VIRTUAL_PATH + } + + # Just remove the VIRTUAL_ENV altogether: + if (Test-Path -Path Env:VIRTUAL_ENV) { + Remove-Item -Path env:VIRTUAL_ENV + } + + # Just remove VIRTUAL_ENV_PROMPT altogether. 
+ if (Test-Path -Path Env:VIRTUAL_ENV_PROMPT) { + Remove-Item -Path env:VIRTUAL_ENV_PROMPT + } + + # Just remove the _PYTHON_VENV_PROMPT_PREFIX altogether: + if (Get-Variable -Name "_PYTHON_VENV_PROMPT_PREFIX" -ErrorAction SilentlyContinue) { + Remove-Variable -Name _PYTHON_VENV_PROMPT_PREFIX -Scope Global -Force + } + + # Leave deactivate function in the global namespace if requested: + if (-not $NonDestructive) { + Remove-Item -Path function:deactivate + } +} + +<# +.Description +Get-PyVenvConfig parses the values from the pyvenv.cfg file located in the +given folder, and returns them in a map. + +For each line in the pyvenv.cfg file, if that line can be parsed into exactly +two strings separated by `=` (with any amount of whitespace surrounding the =) +then it is considered a `key = value` line. The left hand string is the key, +the right hand is the value. + +If the value starts with a `'` or a `"` then the first and last character is +stripped from the value before being captured. + +.Parameter ConfigDir +Path to the directory that contains the `pyvenv.cfg` file. +#> +function Get-PyVenvConfig( + [String] + $ConfigDir +) { + Write-Verbose "Given ConfigDir=$ConfigDir, obtain values in pyvenv.cfg" + + # Ensure the file exists, and issue a warning if it doesn't (but still allow the function to continue). + $pyvenvConfigPath = Join-Path -Resolve -Path $ConfigDir -ChildPath 'pyvenv.cfg' -ErrorAction Continue + + # An empty map will be returned if no config file is found. + $pyvenvConfig = @{ } + + if ($pyvenvConfigPath) { + + Write-Verbose "File exists, parse `key = value` lines" + $pyvenvConfigContent = Get-Content -Path $pyvenvConfigPath + + $pyvenvConfigContent | ForEach-Object { + $keyval = $PSItem -split "\s*=\s*", 2 + if ($keyval[0] -and $keyval[1]) { + $val = $keyval[1] + + # Remove extraneous quotations around a string value. 
+ if ("'""".Contains($val.Substring(0, 1))) { + $val = $val.Substring(1, $val.Length - 2) + } + + $pyvenvConfig[$keyval[0]] = $val + Write-Verbose "Adding Key: '$($keyval[0])'='$val'" + } + } + } + return $pyvenvConfig +} + + +<# Begin Activate script --------------------------------------------------- #> + +# Determine the containing directory of this script +$VenvExecPath = Split-Path -Parent $MyInvocation.MyCommand.Definition +$VenvExecDir = Get-Item -Path $VenvExecPath + +Write-Verbose "Activation script is located in path: '$VenvExecPath'" +Write-Verbose "VenvExecDir Fullname: '$($VenvExecDir.FullName)" +Write-Verbose "VenvExecDir Name: '$($VenvExecDir.Name)" + +# Set values required in priority: CmdLine, ConfigFile, Default +# First, get the location of the virtual environment, it might not be +# VenvExecDir if specified on the command line. +if ($VenvDir) { + Write-Verbose "VenvDir given as parameter, using '$VenvDir' to determine values" +} +else { + Write-Verbose "VenvDir not given as a parameter, using parent directory name as VenvDir." + $VenvDir = $VenvExecDir.Parent.FullName.TrimEnd("\\/") + Write-Verbose "VenvDir=$VenvDir" +} + +# Next, read the `pyvenv.cfg` file to determine any required value such +# as `prompt`. +$pyvenvCfg = Get-PyVenvConfig -ConfigDir $VenvDir + +# Next, set the prompt from the command line, or the config file, or +# just use the name of the virtual environment folder. +if ($Prompt) { + Write-Verbose "Prompt specified as argument, using '$Prompt'" +} +else { + Write-Verbose "Prompt not specified as argument to script, checking pyvenv.cfg value" + if ($pyvenvCfg -and $pyvenvCfg['prompt']) { + Write-Verbose " Setting based on value in pyvenv.cfg='$($pyvenvCfg['prompt'])'" + $Prompt = $pyvenvCfg['prompt']; + } + else { + Write-Verbose " Setting prompt based on parent's directory's name. 
(Is the directory name passed to venv module when creating the virtual environment)" + Write-Verbose " Got leaf-name of $VenvDir='$(Split-Path -Path $venvDir -Leaf)'" + $Prompt = Split-Path -Path $venvDir -Leaf + } +} + +Write-Verbose "Prompt = '$Prompt'" +Write-Verbose "VenvDir='$VenvDir'" + +# Deactivate any currently active virtual environment, but leave the +# deactivate function in place. +deactivate -nondestructive + +# Now set the environment variable VIRTUAL_ENV, used by many tools to determine +# that there is an activated venv. +$env:VIRTUAL_ENV = $VenvDir + +if (-not $Env:VIRTUAL_ENV_DISABLE_PROMPT) { + + Write-Verbose "Setting prompt to '$Prompt'" + + # Set the prompt to include the env name + # Make sure _OLD_VIRTUAL_PROMPT is global + function global:_OLD_VIRTUAL_PROMPT { "" } + Copy-Item -Path function:prompt -Destination function:_OLD_VIRTUAL_PROMPT + New-Variable -Name _PYTHON_VENV_PROMPT_PREFIX -Description "Python virtual environment prompt prefix" -Scope Global -Option ReadOnly -Visibility Public -Value $Prompt + + function global:prompt { + Write-Host -NoNewline -ForegroundColor Green "($_PYTHON_VENV_PROMPT_PREFIX) " + _OLD_VIRTUAL_PROMPT + } + $env:VIRTUAL_ENV_PROMPT = $Prompt +} + +# Clear PYTHONHOME +if (Test-Path -Path Env:PYTHONHOME) { + Copy-Item -Path Env:PYTHONHOME -Destination Env:_OLD_VIRTUAL_PYTHONHOME + Remove-Item -Path Env:PYTHONHOME +} + +# Add the venv to the PATH +Copy-Item -Path Env:PATH -Destination Env:_OLD_VIRTUAL_PATH +$Env:PATH = "$VenvExecDir$([System.IO.Path]::PathSeparator)$Env:PATH" diff --git a/pyhealth-env/bin/activate b/pyhealth-env/bin/activate new file mode 100644 index 000000000..394929f83 --- /dev/null +++ b/pyhealth-env/bin/activate @@ -0,0 +1,69 @@ +# This file must be used with "source bin/activate" *from bash* +# you cannot run it directly + +deactivate () { + # reset old environment variables + if [ -n "${_OLD_VIRTUAL_PATH:-}" ] ; then + PATH="${_OLD_VIRTUAL_PATH:-}" + export PATH + unset _OLD_VIRTUAL_PATH 
+ fi + if [ -n "${_OLD_VIRTUAL_PYTHONHOME:-}" ] ; then + PYTHONHOME="${_OLD_VIRTUAL_PYTHONHOME:-}" + export PYTHONHOME + unset _OLD_VIRTUAL_PYTHONHOME + fi + + # This should detect bash and zsh, which have a hash command that must + # be called to get it to forget past commands. Without forgetting + # past commands the $PATH changes we made may not be respected + if [ -n "${BASH:-}" -o -n "${ZSH_VERSION:-}" ] ; then + hash -r 2> /dev/null + fi + + if [ -n "${_OLD_VIRTUAL_PS1:-}" ] ; then + PS1="${_OLD_VIRTUAL_PS1:-}" + export PS1 + unset _OLD_VIRTUAL_PS1 + fi + + unset VIRTUAL_ENV + unset VIRTUAL_ENV_PROMPT + if [ ! "${1:-}" = "nondestructive" ] ; then + # Self destruct! + unset -f deactivate + fi +} + +# unset irrelevant variables +deactivate nondestructive + +VIRTUAL_ENV=/home/ubuntu/PyHealth/pyhealth-env +export VIRTUAL_ENV + +_OLD_VIRTUAL_PATH="$PATH" +PATH="$VIRTUAL_ENV/"bin":$PATH" +export PATH + +# unset PYTHONHOME if set +# this will fail if PYTHONHOME is set to the empty string (which is bad anyway) +# could use `if (set -u; : $PYTHONHOME) ;` in bash +if [ -n "${PYTHONHOME:-}" ] ; then + _OLD_VIRTUAL_PYTHONHOME="${PYTHONHOME:-}" + unset PYTHONHOME +fi + +if [ -z "${VIRTUAL_ENV_DISABLE_PROMPT:-}" ] ; then + _OLD_VIRTUAL_PS1="${PS1:-}" + PS1='(pyhealth-env) '"${PS1:-}" + export PS1 + VIRTUAL_ENV_PROMPT='(pyhealth-env) ' + export VIRTUAL_ENV_PROMPT +fi + +# This should detect bash and zsh, which have a hash command that must +# be called to get it to forget past commands. Without forgetting +# past commands the $PATH changes we made may not be respected +if [ -n "${BASH:-}" -o -n "${ZSH_VERSION:-}" ] ; then + hash -r 2> /dev/null +fi diff --git a/pyhealth-env/bin/activate.csh b/pyhealth-env/bin/activate.csh new file mode 100644 index 000000000..f16bd6b4b --- /dev/null +++ b/pyhealth-env/bin/activate.csh @@ -0,0 +1,26 @@ +# This file must be used with "source bin/activate.csh" *from csh*. +# You cannot run it directly. +# Created by Davide Di Blasi . 
+# Ported to Python 3.3 venv by Andrew Svetlov + +alias deactivate 'test $?_OLD_VIRTUAL_PATH != 0 && setenv PATH "$_OLD_VIRTUAL_PATH" && unset _OLD_VIRTUAL_PATH; rehash; test $?_OLD_VIRTUAL_PROMPT != 0 && set prompt="$_OLD_VIRTUAL_PROMPT" && unset _OLD_VIRTUAL_PROMPT; unsetenv VIRTUAL_ENV; unsetenv VIRTUAL_ENV_PROMPT; test "\!:*" != "nondestructive" && unalias deactivate' + +# Unset irrelevant variables. +deactivate nondestructive + +setenv VIRTUAL_ENV /home/ubuntu/PyHealth/pyhealth-env + +set _OLD_VIRTUAL_PATH="$PATH" +setenv PATH "$VIRTUAL_ENV/"bin":$PATH" + + +set _OLD_VIRTUAL_PROMPT="$prompt" + +if (! "$?VIRTUAL_ENV_DISABLE_PROMPT") then + set prompt = '(pyhealth-env) '"$prompt" + setenv VIRTUAL_ENV_PROMPT '(pyhealth-env) ' +endif + +alias pydoc python -m pydoc + +rehash diff --git a/pyhealth-env/bin/activate.fish b/pyhealth-env/bin/activate.fish new file mode 100644 index 000000000..176f84d63 --- /dev/null +++ b/pyhealth-env/bin/activate.fish @@ -0,0 +1,69 @@ +# This file must be used with "source /bin/activate.fish" *from fish* +# (https://fishshell.com/); you cannot run it directly. + +function deactivate -d "Exit virtual environment and return to normal shell environment" + # reset old environment variables + if test -n "$_OLD_VIRTUAL_PATH" + set -gx PATH $_OLD_VIRTUAL_PATH + set -e _OLD_VIRTUAL_PATH + end + if test -n "$_OLD_VIRTUAL_PYTHONHOME" + set -gx PYTHONHOME $_OLD_VIRTUAL_PYTHONHOME + set -e _OLD_VIRTUAL_PYTHONHOME + end + + if test -n "$_OLD_FISH_PROMPT_OVERRIDE" + set -e _OLD_FISH_PROMPT_OVERRIDE + # prevents error when using nested fish instances (Issue #93858) + if functions -q _old_fish_prompt + functions -e fish_prompt + functions -c _old_fish_prompt fish_prompt + functions -e _old_fish_prompt + end + end + + set -e VIRTUAL_ENV + set -e VIRTUAL_ENV_PROMPT + if test "$argv[1]" != "nondestructive" + # Self-destruct! + functions -e deactivate + end +end + +# Unset irrelevant variables. 
+deactivate nondestructive + +set -gx VIRTUAL_ENV /home/ubuntu/PyHealth/pyhealth-env + +set -gx _OLD_VIRTUAL_PATH $PATH +set -gx PATH "$VIRTUAL_ENV/"bin $PATH + +# Unset PYTHONHOME if set. +if set -q PYTHONHOME + set -gx _OLD_VIRTUAL_PYTHONHOME $PYTHONHOME + set -e PYTHONHOME +end + +if test -z "$VIRTUAL_ENV_DISABLE_PROMPT" + # fish uses a function instead of an env var to generate the prompt. + + # Save the current fish_prompt function as the function _old_fish_prompt. + functions -c fish_prompt _old_fish_prompt + + # With the original prompt function renamed, we can override with our own. + function fish_prompt + # Save the return status of the last command. + set -l old_status $status + + # Output the venv prompt; color taken from the blue of the Python logo. + printf "%s%s%s" (set_color 4B8BBE) '(pyhealth-env) ' (set_color normal) + + # Restore the return status of the previous command. + echo "exit $old_status" | . + # Output the original/"old" prompt. + _old_fish_prompt + end + + set -gx _OLD_FISH_PROMPT_OVERRIDE "$VIRTUAL_ENV" + set -gx VIRTUAL_ENV_PROMPT '(pyhealth-env) ' +end diff --git a/pyhealth-env/bin/f2py b/pyhealth-env/bin/f2py new file mode 100755 index 000000000..ff3893aad --- /dev/null +++ b/pyhealth-env/bin/f2py @@ -0,0 +1,8 @@ +#!/home/ubuntu/PyHealth/pyhealth-env/bin/python3 +# -*- coding: utf-8 -*- +import re +import sys +from numpy.f2py.f2py2e import main +if __name__ == '__main__': + sys.argv[0] = re.sub(r'(-script\.pyw|\.exe)?$', '', sys.argv[0]) + sys.exit(main()) diff --git a/pyhealth-env/bin/numpy-config b/pyhealth-env/bin/numpy-config new file mode 100755 index 000000000..9ce052fb2 --- /dev/null +++ b/pyhealth-env/bin/numpy-config @@ -0,0 +1,8 @@ +#!/home/ubuntu/PyHealth/pyhealth-env/bin/python3 +# -*- coding: utf-8 -*- +import re +import sys +from numpy._configtool import main +if __name__ == '__main__': + sys.argv[0] = re.sub(r'(-script\.pyw|\.exe)?$', '', sys.argv[0]) + sys.exit(main()) diff --git a/pyhealth-env/bin/pip 
b/pyhealth-env/bin/pip new file mode 100755 index 000000000..970fa9b14 --- /dev/null +++ b/pyhealth-env/bin/pip @@ -0,0 +1,8 @@ +#!/home/ubuntu/PyHealth/pyhealth-env/bin/python3 +# -*- coding: utf-8 -*- +import re +import sys +from pip._internal.cli.main import main +if __name__ == '__main__': + sys.argv[0] = re.sub(r'(-script\.pyw|\.exe)?$', '', sys.argv[0]) + sys.exit(main()) diff --git a/pyhealth-env/bin/pip3 b/pyhealth-env/bin/pip3 new file mode 100755 index 000000000..970fa9b14 --- /dev/null +++ b/pyhealth-env/bin/pip3 @@ -0,0 +1,8 @@ +#!/home/ubuntu/PyHealth/pyhealth-env/bin/python3 +# -*- coding: utf-8 -*- +import re +import sys +from pip._internal.cli.main import main +if __name__ == '__main__': + sys.argv[0] = re.sub(r'(-script\.pyw|\.exe)?$', '', sys.argv[0]) + sys.exit(main()) diff --git a/pyhealth-env/bin/pip3.10 b/pyhealth-env/bin/pip3.10 new file mode 100755 index 000000000..970fa9b14 --- /dev/null +++ b/pyhealth-env/bin/pip3.10 @@ -0,0 +1,8 @@ +#!/home/ubuntu/PyHealth/pyhealth-env/bin/python3 +# -*- coding: utf-8 -*- +import re +import sys +from pip._internal.cli.main import main +if __name__ == '__main__': + sys.argv[0] = re.sub(r'(-script\.pyw|\.exe)?$', '', sys.argv[0]) + sys.exit(main()) diff --git a/pyhealth-env/bin/python b/pyhealth-env/bin/python new file mode 120000 index 000000000..b8a0adbbb --- /dev/null +++ b/pyhealth-env/bin/python @@ -0,0 +1 @@ +python3 \ No newline at end of file diff --git a/pyhealth-env/bin/python3 b/pyhealth-env/bin/python3 new file mode 120000 index 000000000..ae65fdaa1 --- /dev/null +++ b/pyhealth-env/bin/python3 @@ -0,0 +1 @@ +/usr/bin/python3 \ No newline at end of file diff --git a/pyhealth-env/bin/python3.10 b/pyhealth-env/bin/python3.10 new file mode 120000 index 000000000..b8a0adbbb --- /dev/null +++ b/pyhealth-env/bin/python3.10 @@ -0,0 +1 @@ +python3 \ No newline at end of file diff --git a/pyhealth-env/lib64 b/pyhealth-env/lib64 new file mode 120000 index 000000000..7951405f8 --- /dev/null +++ 
class TCGACRCkDataset(BaseDataset):
    """Dataset class for the TCGA-CRCk histopathology tile dataset.

    The raw ``CRC_DX_TRAIN`` / ``CRC_DX_TEST`` folders of class-labelled PNG
    tiles are normalized into a single metadata CSV
    (``tcga_crck_metadata-pyhealth.csv``), which PyHealth then ingests as one
    ``tcga_crck`` event table: one event per tile, grouped under the
    originating patient (first 12 characters of the slide barcode).

    Attributes:
        root (str): Root directory of the raw data.
        metadata_path (str): Path of the normalized metadata CSV in ``root``.
        classes (List[str]): Label folder names; ``MSIMUT`` maps to label 1,
            ``MSS`` to label 0.
    """

    classes: List[str] = ["MSIMUT", "MSS"]

    def __init__(
        self,
        root: str,
        dataset_name: Optional[str] = "tcga_crck",
        config_path: Optional[str] = None,
        cache_dir: Optional[str] = None,
        num_workers: Optional[int] = None,
        dev: bool = False,
    ) -> None:
        """Initializes the TCGA-CRCk dataset.

        Args:
            root: Root directory containing the raw TCGA-CRCk image folders.
            dataset_name: Dataset name used by PyHealth cache management.
            config_path: Optional YAML config path. If omitted, the bundled
                tcga_crck.yaml config is used.
            cache_dir: Optional cache directory for processed artifacts.
            num_workers: Number of workers used by the PyHealth processing
                pipeline. Defaults to 1 when not given.
            dev: Whether to enable development-mode shortcuts.

        Raises:
            FileNotFoundError: If ``root`` does not exist, or if the metadata
                CSV is missing after preparation.
            ValueError: If the metadata CSV fails validation.
        """
        self.root = root

        if num_workers is None:
            num_workers = 1

        self._verify_root()

        if config_path is None:
            logger.info("No config path provided, using default config")
            config_path = Path(__file__).parent / "configs" / "tcga_crck.yaml"

        self.metadata_path = os.path.join(
            self.root,
            "tcga_crck_metadata-pyhealth.csv",
        )

        # BUG FIX: the metadata CSV must exist *before* BaseDataset.__init__
        # runs, because the base class loads the table file named in the YAML
        # config ("tcga_crck_metadata-pyhealth.csv"). Previously the CSV was
        # only generated after super().__init__(), so a fresh raw-data root
        # always failed on first load.
        if not os.path.exists(self.metadata_path):
            logger.info("Preparing TCGA-CRCk metadata...")
            self.prepare_metadata(self.root)

        # Validate the CSV up front so table loading gets a clean input.
        self._verify_metadata()

        super().__init__(
            root=root,
            tables=["tcga_crck"],
            dataset_name=dataset_name,
            config_path=config_path,
            cache_dir=cache_dir,
            num_workers=num_workers,
            dev=dev,
        )

    @staticmethod
    def prepare_metadata(root: str) -> None:
        """Creates the PyHealth metadata CSV from the raw image folders.

        Walks ``CRC_DX_TRAIN`` / ``CRC_DX_TEST`` and their ``MSIMUT`` /
        ``MSS`` subfolders, parsing the slide barcode out of each tile
        filename. Tiles whose names do not match the expected
        ``blk-...-<barcode>.png`` pattern are silently skipped.

        Args:
            root: Root directory containing CRC_DX_TRAIN and
                CRC_DX_TEST folders with class-specific PNG tiles.
        """
        # Tile names look like "blk-<random>-TCGA-XX-YYYY-01Z-00-DX1.png";
        # group(1) captures the full slide barcode.
        wsi_regex = re.compile(r"^blk-.+-(TCGA-..-....-...-..-...)\.png$")
        csv_path = Path(os.path.join(root, "tcga_crck_metadata-pyhealth.csv"))
        csv_data = []
        for split in ["TRAIN", "TEST"]:
            for label_name in TCGACRCkDataset.classes:
                raw_dir = Path(os.path.join(root, f"CRC_DX_{split}", label_name))
                if not raw_dir.is_dir():
                    # BUG FIX: the two concatenated string parts previously
                    # rendered as "...dataset.Expected..." (missing space).
                    logger.warning(
                        "Unexpected format for raw TCGA-CRCk dataset. "
                        f"Expected directory at {raw_dir}"
                    )
                    continue
                # Per-slide running counter used as the tile index.
                wsi_dict = {}
                for tile_path in raw_dir.glob("*.png"):
                    match = wsi_regex.search(tile_path.name)
                    if match:
                        slide_id = match.group(1)

                        tile_index = wsi_dict.get(slide_id, 0)
                        wsi_dict[slide_id] = tile_index + 1

                        csv_data.append(
                            {
                                # Patient barcode = first 12 chars of slide ID.
                                "patient_id": slide_id[:12],
                                "slide_id": slide_id,
                                "tile_path": str(tile_path.resolve()),
                                "tile_index": tile_index,
                                "data_split": split.lower(),
                                # MSIMUT -> 1 (positive), MSS -> 0.
                                "label": 1 if label_name == "MSIMUT" else 0,
                            }
                        )
        with open(csv_path, "w", newline="", encoding="utf-8") as f:
            fields = [
                "patient_id",
                "slide_id",
                "tile_path",
                "tile_index",
                "data_split",
                "label",
            ]
            writer = csv.DictWriter(f, fieldnames=fields)
            writer.writeheader()
            writer.writerows(csv_data)

    def _verify_root(self) -> None:
        """Verifies that the dataset root exists.

        Raises:
            FileNotFoundError: If ``self.root`` does not exist.
        """
        if not os.path.exists(self.root):
            raise FileNotFoundError(f"Dataset root does not exist: {self.root}")

    def _verify_metadata(self) -> None:
        """Verifies that the normalized metadata file is present and valid.

        Checks the required column set, that every referenced tile file
        exists on disk, and that split/label values are in their allowed
        domains.

        Raises:
            FileNotFoundError: If the CSV or any referenced tile is missing.
            ValueError: If columns are missing or values are out of domain.
        """
        if not os.path.isfile(self.metadata_path):
            raise FileNotFoundError(
                f"Dataset metadata file does not exist: {self.metadata_path}"
            )

        df = pd.read_csv(self.metadata_path)
        required_cols = {
            "patient_id",
            "slide_id",
            "tile_path",
            "tile_index",
            "data_split",
            "label",
        }
        missing = required_cols.difference(df.columns)
        if missing:
            raise ValueError(
                "Metadata file is missing required columns: "
                f"{sorted(missing)}"
            )
        # An empty file is unusual but not fatal; warn and continue.
        if df.empty:
            logger.warning("Metadata file is empty.")

        nonexistent_paths = [
            p for p in df["tile_path"].tolist() if not os.path.isfile(str(p))
        ]
        if nonexistent_paths:
            raise FileNotFoundError(
                "Some metadata paths do not exist. Example: "
                f"{nonexistent_paths[0]}"
            )

        invalid_splits = set(df["data_split"].unique()).difference(
            {"train", "test"}
        )
        if invalid_splits:
            raise ValueError(
                f"data_split must be train/test. Found: {invalid_splits}"
            )

        invalid_labels = set(df["label"].unique()).difference({0, 1})
        if invalid_labels:
            raise ValueError(
                f"label must be binary 0/1. Found: {invalid_labels}"
            )

    @property
    def default_task(self):
        """Returns the default task for this dataset.

        Imported lazily to avoid a circular import between the datasets and
        tasks packages.
        """
        from pyhealth.tasks import TCGACRCkMSIClassification

        return TCGACRCkMSIClassification()
class TissueAwareSimCLR(BaseModel):
    """Multiple-instance-learning classifier with a SimCLR-initialized
    ResNet-18 tile encoder.

    Designed for TCGA-CRCk slide-level MSI prediction: each sample is a bag
    of image tiles. Tiles are encoded with ResNet-18, projected into a hidden
    space, pooled into a single bag embedding (gated attention or masked
    mean), and classified. The encoder can be initialized from a pretrained
    checkpoint and optionally frozen for downstream training.
    """

    def __init__(
        self,
        dataset: SampleDataset,
        checkpoint_path: Optional[str] = None,
        hidden_dim: int = 128,
        dropout: float = 0.25,
        freeze_encoder: bool = False,
        pooling: str = "attention",
        tile_chunk_size: int = 1024,
        use_bf16: bool = False,
    ) -> None:
        """Initializes the tissue-aware SimCLR classifier.

        Args:
            dataset: PyHealth sample dataset used to infer feature and label
                metadata.
            checkpoint_path: Optional path to a pretrained encoder checkpoint.
            hidden_dim: Output dimension of the projection layer before
                pooling.
            dropout: Dropout probability applied before the final classifier.
            freeze_encoder: Whether to freeze encoder weights during training.
            pooling: Bag pooling strategy. Must be "attention" or "mean".
            tile_chunk_size: Number of tiles to encode at once (memory
                control).
            use_bf16: Whether to use bfloat16 autocast on CUDA during
                encoding and classification.

        Raises:
            ValueError: If the dataset does not expose exactly one feature
                key or exactly one label key, if `pooling` is unsupported, or
                if `tile_chunk_size` is less than 1.
        """
        super().__init__(dataset=dataset)

        if len(self.feature_keys) != 1:
            raise ValueError(
                f"{self.__class__.__name__} expects exactly one feature key, "
                f"but got {self.feature_keys}."
            )
        if len(self.label_keys) != 1:
            raise ValueError(
                f"{self.__class__.__name__} expects exactly one label key, "
                f"but got {self.label_keys}."
            )
        if pooling not in {"attention", "mean"}:
            raise ValueError("pooling must be either 'attention' or 'mean'.")
        if tile_chunk_size < 1:
            raise ValueError("tile_chunk_size must be at least 1.")

        self.feature_key = self.feature_keys[0]
        self.label_key = self.label_keys[0]
        self.hidden_dim = hidden_dim
        self.pooling = pooling
        self.freeze_encoder = freeze_encoder
        self.tile_chunk_size = tile_chunk_size
        self.use_bf16 = use_bf16

        # ResNet-18 backbone with its classification head replaced by
        # Identity; the penultimate features (fc.in_features wide) feed the
        # projection layer.
        backbone = models.resnet18(weights=None)
        self.encoder_dim = backbone.fc.in_features
        backbone.fc = nn.Identity()
        self.encoder = backbone

        if checkpoint_path is not None:
            self._load_encoder_checkpoint(checkpoint_path)

        if freeze_encoder:
            for param in self.encoder.parameters():
                param.requires_grad = False

        self.proj = nn.Linear(self.encoder_dim, hidden_dim)
        self.dropout = nn.Dropout(dropout)

        # Gated (tanh x sigmoid) attention scoring layers for bag pooling.
        self.attn_v = nn.Linear(hidden_dim, hidden_dim)
        self.attn_u = nn.Linear(hidden_dim, hidden_dim)
        self.attn_w = nn.Linear(hidden_dim, 1)

        self.classifier = nn.Linear(hidden_dim, self.get_output_size())

        # Class-imbalance weight for BCE: (#negatives / #positives) over the
        # training split. IMPROVED: each sample is materialized once — the
        # original indexed dataset[i] twice per element inside a
        # comprehension.
        train_labels = []
        for i in range(len(dataset)):
            sample = dataset[i]
            split = str(sample.get("data_split", "")).strip().lower()
            if split in {"train", "training", "tr"}:
                train_labels.append(int(sample["label"]))

        num_pos = sum(train_labels)
        num_neg = len(train_labels) - num_pos
        # max(..., 1) guards division by zero when no positives exist.
        pos_weight = num_neg / max(num_pos, 1)

        # Registered as a buffer so it follows the module across devices and
        # is saved in state_dict without being a trainable parameter.
        self.register_buffer(
            "pos_weight_tensor",
            torch.tensor([pos_weight], dtype=torch.float32),
        )

    def _load_encoder_checkpoint(self, checkpoint_path: str) -> None:
        """Loads encoder weights from a plain PyTorch checkpoint.

        Accepts several checkpoint layouts ("state_dict", "model", "encoder"
        wrappers) and repeatedly strips common key prefixes before applying
        the weights to the ResNet encoder with strict=False.

        Args:
            checkpoint_path: Path to the checkpoint file on disk.

        Raises:
            FileNotFoundError: If the checkpoint path does not exist.
            ValueError: If the checkpoint does not contain a usable state
                dict.
        """
        if not os.path.isfile(checkpoint_path):
            raise FileNotFoundError(f"Checkpoint does not exist: {checkpoint_path}")

        # NOTE(review): weights_only=False unpickles arbitrary objects — only
        # load checkpoints from trusted sources.
        checkpoint = torch.load(
            checkpoint_path,
            map_location="cpu",
            weights_only=False,
        )

        state_dict = checkpoint
        if isinstance(checkpoint, dict):
            if "state_dict" in checkpoint:
                state_dict = checkpoint["state_dict"]
            elif "model" in checkpoint:
                state_dict = checkpoint["model"]
            elif "encoder" in checkpoint:
                state_dict = checkpoint["encoder"]

        if not isinstance(state_dict, dict):
            raise ValueError("Checkpoint did not contain a usable state dict.")

        cleaned_state_dict = {}
        removable_prefixes = [
            "model.resnet.",
            "resnet.",
            "model.",
            "module.",
            "encoder.",
            "backbone.",
            "online_network.encoder.",
            "encoder_q.",
        ]

        for key, value in state_dict.items():
            # Skip non-tensor entries (e.g. bookkeeping scalars).
            if not torch.is_tensor(value):
                continue

            # Strip wrapper prefixes repeatedly — checkpoints may nest them
            # (e.g. "module.encoder.layer1...").
            new_key = key
            changed = True
            while changed:
                changed = False
                for prefix in removable_prefixes:
                    if new_key.startswith(prefix):
                        new_key = new_key[len(prefix):]
                        changed = True
            cleaned_state_dict[new_key] = value

        # strict=False: the projection head of a SimCLR checkpoint has no
        # counterpart in the bare ResNet, so mismatches are expected.
        missing, unexpected = self.encoder.load_state_dict(
            cleaned_state_dict,
            strict=False,
        )

        if missing:
            print(f"[Warning] Missing keys: {missing}")
        if unexpected:
            print(f"[Warning] Unexpected keys: {unexpected}")

    def _extract_images_and_mask(
        self,
        feature: Union[torch.Tensor, Tuple[torch.Tensor, ...], list],
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Converts a batch of bags into a padded image tensor and validity mask.

        Supports either an already-batched image tensor or the
        list/tuple-based bag structure commonly returned by PyHealth
        processors.

        Args:
            feature: Input bag payload. Either a tensor with shape
                [B, N, C, H, W] or [N, C, H, W], or a list/tuple of
                per-sample bag tensors.

        Returns:
            A tuple of:
                - images: Padded image tensor of shape [B, N, C, H, W].
                - mask: Boolean tensor of shape [B, N] marking valid tiles.

        Raises:
            ValueError: If the payload type is unsupported, tensors have an
                unexpected shape, a bag does not contain a tensor, or the
                batch is empty.
        """
        # Case 1: already a proper batched tensor.
        if torch.is_tensor(feature):
            images = feature
            if images.dim() == 4:
                images = images.unsqueeze(0)
            if images.dim() != 5:
                raise ValueError(
                    f"Expected image tensor with shape [B, N, C, H, W], got {tuple(images.shape)}"
                )
            # All-zero tiles are treated as padding; real tiles have some
            # non-zero pixel mass.
            mask = images.abs().sum(dim=(2, 3, 4)) > 0
            return images, mask

        # Case 2: PyHealth often gives a list/tuple per sample.
        if not isinstance(feature, (list, tuple)):
            raise ValueError("Unsupported tile_bag payload type.")

        bag_tensors = []
        for item in feature:
            # Unwrap arbitrarily nested single-element containers until the
            # actual bag tensor is reached.
            bag = item
            while isinstance(bag, (tuple, list)) and len(bag) > 0:
                bag = bag[0]

            if not torch.is_tensor(bag):
                raise ValueError("Expected each bag to contain a tensor.")

            # Some processors add a leading singleton batch dim — drop it.
            if bag.dim() == 5 and bag.size(0) == 1:
                bag = bag.squeeze(0)

            if bag.dim() != 4:
                raise ValueError(
                    f"Expected each bag tensor to have shape [N, C, H, W], got {tuple(bag.shape)}"
                )

            bag_tensors.append(bag)

        if not bag_tensors:
            raise ValueError("Received empty batch of bags.")

        # Pad every bag to the largest tile count in the batch.
        max_tiles = max(bag.shape[0] for bag in bag_tensors)
        c, h, w = bag_tensors[0].shape[1:]

        images = torch.zeros(
            (len(bag_tensors), max_tiles, c, h, w),
            dtype=bag_tensors[0].dtype,
        )
        mask = torch.zeros((len(bag_tensors), max_tiles), dtype=torch.bool)

        for i, bag in enumerate(bag_tensors):
            n = bag.shape[0]
            images[i, :n] = bag
            mask[i, :n] = True

        return images, mask

    def _encode_flat_images(self, flat_images: torch.Tensor) -> torch.Tensor:
        """Encodes flattened tile images into normalized projected features.

        Images are processed in chunks of `tile_chunk_size` to bound memory.
        When `use_bf16` is set and the model lives on CUDA, encoding runs
        under bfloat16 autocast. With a frozen encoder, only the encoder
        forward runs under `torch.no_grad()` — the projection stays
        trainable.

        IMPROVED: the encode/flatten/project/normalize sequence was
        previously duplicated across three AMP/frozen branches; it is now a
        single local helper, so the four combinations cannot drift apart.

        Args:
            flat_images: Tensor of tile images with shape
                [num_tiles, C, H, W].

        Returns:
            Tensor of L2-normalized tile features with shape
            [num_tiles, hidden_dim].
        """

        def _embed(chunk: torch.Tensor) -> torch.Tensor:
            # Gradients are disabled only for the frozen encoder pass; the
            # projection layer always remains trainable.
            if self.freeze_encoder:
                with torch.no_grad():
                    enc = self.encoder(chunk)
            else:
                enc = self.encoder(chunk)
            enc = torch.flatten(enc, start_dim=1)
            return F.normalize(self.proj(enc), dim=-1)

        outputs = []
        use_amp = self.use_bf16 and self.device.type == "cuda"

        # Frozen encoder: keep BatchNorm in eval mode so running stats are
        # not updated by the tile batches.
        if self.freeze_encoder:
            self.encoder.eval()

        for start in range(0, flat_images.size(0), self.tile_chunk_size):
            chunk = flat_images[start : start + self.tile_chunk_size]
            chunk = chunk.to(self.device, non_blocking=True).float()
            # channels_last improves conv throughput on modern GPUs.
            chunk = chunk.contiguous(memory_format=torch.channels_last)

            if use_amp:
                with torch.autocast(device_type="cuda", dtype=torch.bfloat16):
                    feat = _embed(chunk)
            else:
                feat = _embed(chunk)

            # Accumulate in float32 regardless of autocast dtype.
            outputs.append(feat.float())

        return torch.cat(outputs, dim=0)

    def _pool_bag(
        self,
        tile_features: torch.Tensor,
        mask: torch.Tensor,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Pools tile-level features into a single bag representation.

        Mean pooling distributes weight uniformly over valid tiles; gated
        attention scores each tile with tanh/sigmoid gates followed by a
        masked softmax.

        Args:
            tile_features: Tile embeddings of shape [B, N, hidden_dim].
            mask: Boolean validity mask of shape [B, N].

        Returns:
            A tuple of:
                - bag_embeddings: Pooled bag representations, [B, hidden_dim].
                - weights: Pooling weights, [B, N].
        """
        if self.pooling == "mean":
            weights = mask.float()
            # clamp_min avoids division by zero for all-padded rows.
            denom = weights.sum(dim=1, keepdim=True).clamp_min(1e-6)
            weights = weights / denom
            bag_embeddings = torch.sum(weights.unsqueeze(-1) * tile_features, dim=1)
            return bag_embeddings, weights

        attn_v = torch.tanh(self.attn_v(tile_features))
        attn_u = torch.sigmoid(self.attn_u(tile_features))
        attn_logits = self.attn_w(attn_v * attn_u).squeeze(-1)

        # Padded tiles get -inf so softmax assigns them zero weight;
        # nan_to_num guards rows that are entirely padding (all -inf
        # softmax -> NaN).
        attn_logits = attn_logits.masked_fill(~mask, float("-inf"))
        weights = torch.softmax(attn_logits, dim=1)
        weights = torch.nan_to_num(weights, nan=0.0)

        bag_embeddings = torch.sum(weights.unsqueeze(-1) * tile_features, dim=1)
        return bag_embeddings, weights

    def forward(self, **kwargs) -> Dict[str, torch.Tensor]:
        """Runs the forward pass for slide-level prediction.

        Image bags arrive under `self.feature_key`; labels are optional
        under `self.label_key`.

        Args:
            **kwargs: Model inputs containing:
                - `self.feature_key`: Bagged image tiles.
                - `self.label_key`: Optional labels for loss computation.

        Returns:
            A dictionary containing:
                - "loss": Training loss if labels are provided, else `None`.
                - "y_prob": Predicted probabilities.
                - "y_true": Ground-truth labels if provided, else `None`.
                - "logit": Raw classifier logits in float32.
                - "attention_weights": Tile-level pooling weights.

        Raises:
            ValueError: If the extracted image tensor does not have 3
                channels.
        """
        images, mask = self._extract_images_and_mask(kwargs[self.feature_key])

        batch_size, num_tiles, channels, height, width = images.shape
        if channels != 3:
            raise ValueError("ResNet expects 3-channel images.")

        # Flatten bags so the encoder sees a plain [B*N, C, H, W] batch.
        flat_images = images.reshape(batch_size * num_tiles, channels, height, width)
        tile_features = self._encode_flat_images(flat_images)
        tile_features = tile_features.reshape(batch_size, num_tiles, -1)

        mask = mask.to(self.device, non_blocking=True)
        bag_embeddings, weights = self._pool_bag(tile_features, mask)

        use_amp = self.use_bf16 and self.device.type == "cuda"
        if use_amp:
            with torch.autocast(device_type="cuda", dtype=torch.bfloat16):
                bag_embeddings = self.dropout(bag_embeddings)
                logit = self.classifier(bag_embeddings)
        else:
            bag_embeddings = self.dropout(bag_embeddings)
            logit = self.classifier(bag_embeddings)

        # Losses and probabilities are always computed in float32.
        logit_fp32 = logit.float()
        y_prob = self.prepare_y_prob(logit_fp32)

        y_true = None
        loss = None

        if self.label_key in kwargs:
            y_true = kwargs[self.label_key].to(self.device)

            if self.mode == "multiclass":
                y_true = y_true.squeeze(-1).long()
                loss = nn.CrossEntropyLoss()(logit_fp32, y_true)

            elif self.mode in {"binary", "multilabel"}:
                y_true = y_true.float()
                if y_true.dim() == 1:
                    y_true = y_true.unsqueeze(-1)
                # pos_weight counters the MSIMUT/MSS class imbalance.
                loss = F.binary_cross_entropy_with_logits(
                    logit_fp32,
                    y_true,
                    pos_weight=self.pos_weight_tensor,
                )

            elif self.mode == "regression":
                y_true = y_true.float()
                loss = nn.MSELoss()(logit_fp32, y_true)

        return {
            "loss": loss,
            "y_prob": y_prob,
            "y_true": y_true,
            "logit": logit_fp32,
            "attention_weights": weights,
        }
class TCGACRCkMSIClassification(BaseTask):
    """Builds one bag-of-tiles MSI classification sample per patient.

    Every tile event belonging to a patient is merged into a single bag,
    the binary MSI label is checked for consistency across tiles, and a
    patient-level train/test split is derived from the tile-level metadata.
    """

    # Distinct from any earlier task name so cached task outputs are rebuilt.
    task_name: str = "TCGACRCkMSIClassificationPatientLevel"

    input_schema: Dict[str, object] = {
        "tile_bag": (
            "time_image",
            {"image_size": 224, "mode": "RGB", "max_images": 1000},
        )
    }
    output_schema: Dict[str, str] = {"label": "binary"}

    def __init__(self, max_tiles: Optional[int] = 1000) -> None:
        """Initializes the TCGA-CRCk MSI classification task.

        Args:
            max_tiles: Maximum number of tile images per patient bag. If
                `None`, every available tile is used.
        """
        self.max_tiles = max_tiles
        processor_kwargs: Dict[str, object] = {"image_size": 224, "mode": "RGB"}
        if max_tiles is not None:
            processor_kwargs["max_images"] = max_tiles
        # The instance-level schema shadows the class default so the image
        # processor honors the configured bag size.
        self.input_schema = {"tile_bag": ("time_image", processor_kwargs)}

    @staticmethod
    def _normalize_split(value: object) -> str:
        """Normalizes split labels into a canonical string form.

        Args:
            value: Raw split value from event metadata.

        Returns:
            "train" or "test" for recognized aliases; otherwise the
            lowercase, stripped string unchanged.
        """
        text = str(value).strip().lower()
        aliases = {
            "train": "train",
            "training": "train",
            "tr": "train",
            "test": "test",
            "testing": "test",
            "te": "test",
        }
        return aliases.get(text, text)

    def __call__(self, patient: Patient) -> List[Dict]:
        """Builds one patient-level MSI classification sample.

        Collects the patient's TCGA-CRCk events, verifies label consistency
        across tiles, infers the patient-level data split, and returns a
        single bag-of-tiles sample.

        Args:
            patient: Patient object containing TCGA-CRCk tile events.

        Returns:
            A single-element list with the sample dictionary, or an empty
            list when the patient has no matching events.

        Raises:
            ValueError: If tile labels for the patient are inconsistent, or
                a valid train/test split cannot be inferred.
        """
        events: List[Event] = patient.get_events(event_type="tcga_crck")
        if not events:
            return []

        # Deterministic ordering: by slide barcode, then tile index.
        ordered = sorted(
            events,
            key=lambda ev: (str(ev["slide_id"]), int(ev["tile_index"])),
        )

        labels = {int(ev["label"]) for ev in ordered}
        if len(labels) != 1:
            raise ValueError(
                f"Inconsistent labels for patient {patient.patient_id}: {sorted(labels)}"
            )

        # "test" wins over "train" when a patient spans both.
        splits = {self._normalize_split(ev["data_split"]) for ev in ordered}
        if "test" in splits:
            data_split = "test"
        elif "train" in splits:
            data_split = "train"
        else:
            raise ValueError(
                f"Could not infer split for patient {patient.patient_id}: {sorted(splits)}"
            )

        sample = {
            "patient_id": str(patient.patient_id),
            "visit_id": str(patient.patient_id),
            # Bag order doubles as the synthetic per-tile timestamp.
            "tile_bag": (
                [str(ev["tile_path"]) for ev in ordered],
                [float(i) for i in range(len(ordered))],
            ),
            "label": labels.pop(),
            "data_split": data_split,
        }
        return [sample]
b/test-resources/tcga_crck/tcga_crck_metadata-pyhealth.csv new file mode 100644 index 000000000..6127db881 --- /dev/null +++ b/test-resources/tcga_crck/tcga_crck_metadata-pyhealth.csv @@ -0,0 +1,6 @@ +patient_id,slide_id,tile_path,tile_index,data_split,label +TCGA-CC-0003,TCGA-CC-0003-01Z-00-DX2,/home/ubuntu/PyHealth/test-resources/tcga_crck/CRC_DX_TRAIN/MSIMUT/blk-WWMIEPLQMACE-TCGA-CC-0003-01Z-00-DX2.png,0,train,1 +TCGA-CC-0003,TCGA-CC-0003-01Z-00-DX1,/home/ubuntu/PyHealth/test-resources/tcga_crck/CRC_DX_TRAIN/MSIMUT/blk-WWKFIIMTFPSG-TCGA-CC-0003-01Z-00-DX1.png,0,train,1 +TCGA-BB-0002,TCGA-BB-0002-01Z-00-DX2,/home/ubuntu/PyHealth/test-resources/tcga_crck/CRC_DX_TRAIN/MSS/blk-YYKVVRTRMCKG-TCGA-BB-0002-01Z-00-DX2.png,0,train,0 +TCGA-BB-0002,TCGA-BB-0002-01Z-00-DX1,/home/ubuntu/PyHealth/test-resources/tcga_crck/CRC_DX_TRAIN/MSS/blk-YYIVRGPNDWWN-TCGA-BB-0002-01Z-00-DX1.png,0,train,0 +TCGA-AA-0001,TCGA-AA-0001-01Z-00-DX1,/home/ubuntu/PyHealth/test-resources/tcga_crck/CRC_DX_TEST/MSS/blk-YYNNMCRNWWEP-TCGA-AA-0001-01Z-00-DX1.png,0,test,0 diff --git a/tests/core/test_tcga_crck.py b/tests/core/test_tcga_crck.py new file mode 100644 index 000000000..ac1966494 --- /dev/null +++ b/tests/core/test_tcga_crck.py @@ -0,0 +1,222 @@ +"""Fast synthetic unit tests for TCGACRCkDataset and TCGACRCkMSIClassification. + +This test file uses only synthetic data generated inside +test-resources/tcga_crck during the test run. 
+""" + +from __future__ import annotations + +import shutil +import unittest +from pathlib import Path + +import numpy as np +from PIL import Image + +from pyhealth.datasets import TCGACRCkDataset +from pyhealth.tasks import TCGACRCkMSIClassification + + +class _SyntheticTCGACRCkData: + """Shared helpers for a tiny synthetic TCGA-CRCk fixture.""" + + @classmethod + def setUpClass(cls): + cls.test_dir = ( + Path(__file__).parent.parent.parent / "test-resources" / "tcga_crck" + ) + cls.test_dir.mkdir(parents=True, exist_ok=True) + + cls._remove_raw_fixture() + cls._build_raw_fixture() + + # Overwrite the CSV from the fresh synthetic raw files each run. + TCGACRCkDataset.prepare_metadata(str(cls.test_dir)) + + @classmethod + def tearDownClass(cls): + """Remove only synthetic image folders, keep root folder and CSV.""" + cls._remove_raw_fixture() + + @classmethod + def _remove_raw_fixture(cls) -> None: + """Deletes only raw image directories and keeps the metadata CSV.""" + for dirname in ["CRC_DX_TRAIN", "CRC_DX_TEST"]: + dir_path = cls.test_dir / dirname + if dir_path.exists(): + shutil.rmtree(dir_path, ignore_errors=True) + + @classmethod + def _make_image(cls, path: Path) -> None: + """Creates a tiny synthetic RGB PNG image.""" + path.parent.mkdir(parents=True, exist_ok=True) + image = Image.fromarray( + np.random.randint(0, 255, (8, 8, 3), dtype=np.uint8), + mode="RGB", + ) + image.save(path) + + @classmethod + def _tile_filename(cls, random_prefix: str, slide_id: str) -> str: + """Returns a tile filename matching the dataset regex convention.""" + return f"blk-{random_prefix}-{slide_id}.png" + + @classmethod + def _build_raw_fixture(cls) -> None: + """Builds a tiny synthetic raw TCGA-CRCk-style folder tree. 
+ + Current dataset convention: + 1 -> MSIMUT + 0 -> MSS + """ + # 3 patients total, 5 images total + slide_specs = [ + ( + "TCGA-AA-0001", + "TCGA-AA-0001-01Z-00-DX1", + "test", + 0, + ["YYNNMCRNWWEP"], + ), + ( + "TCGA-BB-0002", + "TCGA-BB-0002-01Z-00-DX1", + "train", + 0, + ["YYIVRGPNDWWN"], + ), + ( + "TCGA-BB-0002", + "TCGA-BB-0002-01Z-00-DX2", + "train", + 0, + ["YYKVVRTRMCKG"], + ), + ( + "TCGA-CC-0003", + "TCGA-CC-0003-01Z-00-DX1", + "train", + 1, + ["WWKFIIMTFPSG"], + ), + ( + "TCGA-CC-0003", + "TCGA-CC-0003-01Z-00-DX2", + "train", + 1, + ["WWMIEPLQMACE"], + ), + ] + + for _, slide_id, data_split, label, random_prefixes in slide_specs: + split_dir = "CRC_DX_TRAIN" if data_split == "train" else "CRC_DX_TEST" + label_dir = "MSIMUT" if label == 1 else "MSS" + + for random_prefix in random_prefixes: + filename = cls._tile_filename(random_prefix, slide_id) + image_path = cls.test_dir / split_dir / label_dir / filename + cls._make_image(image_path) + + +class TestTCGACRCkDataset(_SyntheticTCGACRCkData, unittest.TestCase): + """Fast test cases for TCGACRCkDataset.""" + + @classmethod + def setUpClass(cls): + super().setUpClass() + cls.dataset = TCGACRCkDataset(root=str(cls.test_dir)) + + def test_dataset_initialization(self): + """Test that the dataset initializes correctly.""" + self.assertIsNotNone(self.dataset) + self.assertEqual(self.dataset.dataset_name, "tcga_crck") + + def test_num_patients(self): + """Test the number of unique synthetic patient IDs.""" + self.assertEqual(len(self.dataset.unique_patient_ids), 3) + + def test_get_patient_single_slide(self): + """Test retrieving a patient record with a single slide.""" + patient = self.dataset.get_patient("TCGA-AA-0001") + self.assertIsNotNone(patient) + self.assertEqual(patient.patient_id, "TCGA-AA-0001") + + events = patient.get_events(event_type="tcga_crck") + self.assertEqual(len(events), 1) + + event = events[0] + self.assertEqual(event["slide_id"], "TCGA-AA-0001-01Z-00-DX1") + 
self.assertEqual(event["data_split"], "test") + self.assertEqual(event["label"], "0") + self.assertTrue(str(event["tile_path"]).endswith(".png")) + + def test_get_patient_multi_slide(self): + """Test retrieving a patient record with multiple slides.""" + patient = self.dataset.get_patient("TCGA-BB-0002") + self.assertIsNotNone(patient) + self.assertEqual(patient.patient_id, "TCGA-BB-0002") + + events = patient.get_events(event_type="tcga_crck") + self.assertEqual(len(events), 2) + + slide_ids = {event["slide_id"] for event in events} + self.assertEqual( + slide_ids, + { + "TCGA-BB-0002-01Z-00-DX1", + "TCGA-BB-0002-01Z-00-DX2", + }, + ) + + def test_event_fields_exist(self): + """Test that event records contain the expected fields.""" + patient = self.dataset.get_patient("TCGA-CC-0003") + events = patient.get_events(event_type="tcga_crck") + + self.assertGreater(len(events), 0) + event = events[0] + self.assertIn("slide_id", event) + self.assertIn("tile_path", event) + self.assertIn("tile_index", event) + self.assertIn("data_split", event) + self.assertIn("label", event) + + +class TestTCGACRCkMSIClassification(_SyntheticTCGACRCkData, unittest.TestCase): + """Fast test cases for TCGACRCkMSIClassification.""" + + @classmethod + def setUpClass(cls): + super().setUpClass() + cls.dataset = TCGACRCkDataset(root=str(cls.test_dir)) + cls.task = TCGACRCkMSIClassification(max_tiles=1) + + def test_default_task(self): + """Test that the dataset exposes the default task.""" + self.assertIsInstance(self.dataset.default_task, TCGACRCkMSIClassification) + + def test_task_raw_output_single_slide_patient(self): + """Test raw task output on a patient with one slide.""" + patient = self.dataset.get_patient("TCGA-AA-0001") + samples = self.task(patient) + + self.assertEqual(len(samples), 1) + + sample = samples[0] + self.assertEqual(sample["patient_id"], "TCGA-AA-0001") + self.assertEqual(sample["visit_id"], "TCGA-AA-0001") + self.assertEqual(sample["label"], 0) + + tile_paths, 
tile_times = sample["tile_bag"] + self.assertEqual(len(tile_paths), 1) + self.assertEqual(len(tile_times), 1) + self.assertEqual(tile_times, [0.0]) + + def test_set_task_runs(self): + """Single end-to-end smoke test for set_task().""" + sample_ds = self.dataset.set_task(self.task) + self.assertGreater(len(sample_ds), 0) + + +if __name__ == "__main__": + unittest.main() \ No newline at end of file diff --git a/tests/core/test_tcga_crck_simclr_mil.py b/tests/core/test_tcga_crck_simclr_mil.py new file mode 100644 index 000000000..78cb51965 --- /dev/null +++ b/tests/core/test_tcga_crck_simclr_mil.py @@ -0,0 +1,124 @@ +"""Unit tests for TissueAwareSimCLR.""" + +from __future__ import annotations + +import tempfile +import unittest +from pathlib import Path + +import torch + +from pyhealth.datasets import create_sample_dataset, get_dataloader +from pyhealth.models import TissueAwareSimCLR + +class TestTissueAwareSimCLR(unittest.TestCase): + """Synthetic tests for TissueAwareSimCLR.""" + + def _build_sample_dataset(self): + samples = [ + { + "patient_id": "patient-0", + "visit_id": "visit-0", + "tile_bag": torch.randn(2, 3, 4, 4).tolist(), + "label": 1, + }, + { + "patient_id": "patient-1", + "visit_id": "visit-1", + "tile_bag": torch.randn(2, 3, 4, 4).tolist(), + "label": 0, + }, + ] + return create_sample_dataset( + samples=samples, + input_schema={"tile_bag": "tensor"}, + output_schema={"label": "binary"}, + dataset_name="tcga_crck_test", + ) + + def test_model_instantiation(self): + """Test that the model initializes correctly.""" + dataset = self._build_sample_dataset() + model = TissueAwareSimCLR( + dataset=dataset, + hidden_dim=16, + dropout=0.1, + freeze_encoder=True, + pooling="attention", + ) + self.assertIsNotNone(model) + self.assertEqual(model.feature_key, "tile_bag") + self.assertEqual(model.label_key, "label") + + def test_forward_pass_shapes(self): + """Test forward pass output keys and tensor shapes.""" + dataset = self._build_sample_dataset() + model = 
class TestTissueAwareSimCLR(unittest.TestCase):
    """Synthetic tests for TissueAwareSimCLR."""

    @staticmethod
    def _build_sample_dataset():
        """Builds a two-sample dataset of random 2-tile bags."""

        def make_sample(idx: int, label: int) -> dict:
            return {
                "patient_id": f"patient-{idx}",
                "visit_id": f"visit-{idx}",
                "tile_bag": torch.randn(2, 3, 4, 4).tolist(),
                "label": label,
            }

        return create_sample_dataset(
            samples=[make_sample(0, 1), make_sample(1, 0)],
            input_schema={"tile_bag": "tensor"},
            output_schema={"label": "binary"},
            dataset_name="tcga_crck_test",
        )

    @staticmethod
    def _first_batch(dataset):
        """Returns the first full batch from a deterministic dataloader."""
        loader = get_dataloader(dataset, batch_size=2, shuffle=False)
        return next(iter(loader))

    def test_model_instantiation(self):
        """Test that the model initializes correctly."""
        model = TissueAwareSimCLR(
            dataset=self._build_sample_dataset(),
            hidden_dim=16,
            dropout=0.1,
            freeze_encoder=True,
            pooling="attention",
        )
        self.assertIsNotNone(model)
        self.assertEqual(model.feature_key, "tile_bag")
        self.assertEqual(model.label_key, "label")

    def test_forward_pass_shapes(self):
        """Test forward pass output keys and tensor shapes."""
        dataset = self._build_sample_dataset()
        model = TissueAwareSimCLR(dataset=dataset, hidden_dim=16)
        ret = model(**self._first_batch(dataset))

        for key in ("loss", "y_prob", "y_true", "logit", "attention_weights"):
            self.assertIn(key, ret)

        self.assertEqual(ret["logit"].shape, (2, 1))
        self.assertEqual(ret["y_prob"].shape, (2, 1))
        self.assertEqual(ret["y_true"].shape, (2, 1))
        self.assertEqual(ret["attention_weights"].shape[0], 2)

    def test_backward_pass(self):
        """Test that gradients can be computed."""
        dataset = self._build_sample_dataset()
        model = TissueAwareSimCLR(dataset=dataset, hidden_dim=16)
        ret = model(**self._first_batch(dataset))

        self.assertIsNotNone(ret["loss"])
        ret["loss"].backward()
        self.assertIsNotNone(model.classifier.weight.grad)

    def test_mean_pooling_variant(self):
        """Test the model with mean MIL pooling."""
        dataset = self._build_sample_dataset()
        model = TissueAwareSimCLR(dataset=dataset, hidden_dim=8, pooling="mean")
        ret = model(**self._first_batch(dataset))

        self.assertEqual(ret["logit"].shape, (2, 1))
        self.assertEqual(ret["attention_weights"].shape, (2, 2))

    def test_checkpoint_loading(self):
        """Test loading a prefixed encoder checkpoint."""
        dataset = self._build_sample_dataset()
        donor = TissueAwareSimCLR(dataset=dataset, hidden_dim=8)

        # Re-key the donor encoder weights with the common "encoder." prefix
        # that the loader is expected to strip.
        prefixed_state = {
            f"encoder.{name}": tensor.cpu()
            for name, tensor in donor.encoder.state_dict().items()
        }

        with tempfile.TemporaryDirectory() as tmpdir:
            ckpt_path = Path(tmpdir) / "simclr_encoder.pt"
            torch.save({"state_dict": prefixed_state}, ckpt_path)

            loaded_model = TissueAwareSimCLR(
                dataset=dataset,
                checkpoint_path=str(ckpt_path),
                hidden_dim=8,
            )
        self.assertIsNotNone(loaded_model)


if __name__ == "__main__":
    unittest.main()