diff --git a/.gitignore b/.gitignore index aaf66fc36..23f645787 100644 --- a/.gitignore +++ b/.gitignore @@ -8,6 +8,9 @@ resource* debug_entry* playground.ipynb +# External TFM-Tokenizer (do not commit) +TFM-Tokenizer/ + # Byte-compiled / optimized / DLL files __pycache__/ *.py[cod] diff --git a/contrawr_coverage_1x2.png b/contrawr_coverage_1x2.png new file mode 100644 index 000000000..b33c1e373 Binary files /dev/null and b/contrawr_coverage_1x2.png differ diff --git a/contrawr_coverage_solid.png b/contrawr_coverage_solid.png new file mode 100644 index 000000000..b09ab9729 Binary files /dev/null and b/contrawr_coverage_solid.png differ diff --git a/contrawr_results_1x2.png b/contrawr_results_1x2.png new file mode 100644 index 000000000..eb672e8f5 Binary files /dev/null and b/contrawr_results_1x2.png differ diff --git a/contrawr_setsize_1x2.png b/contrawr_setsize_1x2.png new file mode 100644 index 000000000..137209610 Binary files /dev/null and b/contrawr_setsize_1x2.png differ diff --git a/contrawr_setsize_solid.png b/contrawr_setsize_solid.png new file mode 100644 index 000000000..a39f56f6c Binary files /dev/null and b/contrawr_setsize_solid.png differ diff --git a/coverage_plot.png b/coverage_plot.png new file mode 100644 index 000000000..cb233e330 Binary files /dev/null and b/coverage_plot.png differ diff --git a/examples/conformal_eeg/RUN_TFM_32_COMMANDS.txt b/examples/conformal_eeg/RUN_TFM_32_COMMANDS.txt new file mode 100644 index 000000000..47abeb808 --- /dev/null +++ b/examples/conformal_eeg/RUN_TFM_32_COMMANDS.txt @@ -0,0 +1,41 @@ +# TFM: 32 jobs in parallel (same style as ContraWR). Run from repo root. +# Logs go to logs/ (bare filenames are written under repo root logs/). +# Default tokenizer + TUEV/TUAB classifier paths are set in the scripts; no need to pass them. + +# --- Events (TUEV), 16 jobs --- +CUDA_VISIBLE_DEVICES=0 python examples/conformal_eeg/tuev_naive_cp_conformal.py --dataset tuev --model tfm --tfm-skip-train --seeds 1,2,3,4,5 --split-seed 0 --alpha 0.2 --log-file naive_cp_alpha02_events_seeds5.log & +CUDA_VISIBLE_DEVICES=1 python examples/conformal_eeg/tuev_naive_cp_conformal.py --dataset tuev --model tfm --tfm-skip-train --seeds 1,2,3,4,5 --split-seed 0 --alpha 0.1 --log-file naive_cp_alpha01_events_seeds5.log & +CUDA_VISIBLE_DEVICES=2 python examples/conformal_eeg/tuev_naive_cp_conformal.py --dataset tuev --model tfm --tfm-skip-train --seeds 1,2,3,4,5 --split-seed 0 --alpha 0.05 --log-file naive_cp_alpha005_events_seeds5.log & +CUDA_VISIBLE_DEVICES=3 python examples/conformal_eeg/tuev_naive_cp_conformal.py --dataset tuev --model tfm --tfm-skip-train --seeds 1,2,3,4,5 --split-seed 0 --alpha 0.01 --log-file naive_cp_alpha001_events_seeds5.log & +CUDA_VISIBLE_DEVICES=4 python examples/conformal_eeg/tuev_kde_cp_conformal.py --dataset tuev --model tfm --tfm-skip-train --seeds 1,2,3,4,5 --split-seed 0 --alpha 0.2 --log-file kde_cp_alpha02_events_seeds5.log & +CUDA_VISIBLE_DEVICES=5 python examples/conformal_eeg/tuev_kde_cp_conformal.py --dataset tuev --model tfm --tfm-skip-train --seeds 1,2,3,4,5 --split-seed 0 --alpha 0.1 --log-file kde_cp_alpha01_events_seeds5.log & +CUDA_VISIBLE_DEVICES=6 python examples/conformal_eeg/tuev_kde_cp_conformal.py --dataset tuev --model tfm --tfm-skip-train --seeds 1,2,3,4,5 --split-seed 0 --alpha 0.05 --log-file kde_cp_alpha005_events_seeds5.log & +CUDA_VISIBLE_DEVICES=7 python examples/conformal_eeg/tuev_kde_cp_conformal.py --dataset tuev --model tfm --tfm-skip-train --seeds 1,2,3,4,5 --split-seed 0 --alpha 0.01 --log-file kde_cp_alpha001_events_seeds5.log & +CUDA_VISIBLE_DEVICES=0 python examples/conformal_eeg/tuev_kmeans_conformal.py --dataset tuev --model tfm --tfm-skip-train --seeds 1,2,3,4,5 --split-seed 0 --alpha 0.2 --log-file kmeans_cp_alpha02_events_seeds5.log & +CUDA_VISIBLE_DEVICES=1 python examples/conformal_eeg/tuev_kmeans_conformal.py --dataset tuev --model tfm --tfm-skip-train --seeds 1,2,3,4,5 --split-seed 0 --alpha 0.1 --log-file kmeans_cp_alpha01_events_seeds5.log & +CUDA_VISIBLE_DEVICES=2 python examples/conformal_eeg/tuev_kmeans_conformal.py --dataset tuev --model tfm --tfm-skip-train --seeds 1,2,3,4,5 --split-seed 0 --alpha 0.05 --log-file kmeans_cp_alpha005_events_seeds5.log & +CUDA_VISIBLE_DEVICES=3 python examples/conformal_eeg/tuev_kmeans_conformal.py --dataset tuev --model tfm --tfm-skip-train --seeds 1,2,3,4,5 --split-seed 0 --alpha 0.01 --log-file kmeans_cp_alpha001_events_seeds5.log & +CUDA_VISIBLE_DEVICES=4 python examples/conformal_eeg/tuev_ncp_conformal.py --dataset tuev --model tfm --tfm-skip-train --seeds 1,2,3,4,5 --split-seed 0 --alpha 0.2 --log-file ncp_alpha02_events_seeds5.log & +CUDA_VISIBLE_DEVICES=5 python examples/conformal_eeg/tuev_ncp_conformal.py --dataset tuev --model tfm --tfm-skip-train --seeds 1,2,3,4,5 --split-seed 0 --alpha 0.1 --log-file ncp_alpha01_events_seeds5.log & +CUDA_VISIBLE_DEVICES=6 python examples/conformal_eeg/tuev_ncp_conformal.py --dataset tuev --model tfm --tfm-skip-train --seeds 1,2,3,4,5 --split-seed 0 --alpha 0.05 --log-file ncp_alpha005_events_seeds5.log & +CUDA_VISIBLE_DEVICES=7 python examples/conformal_eeg/tuev_ncp_conformal.py --dataset tuev --model tfm --tfm-skip-train --seeds 1,2,3,4,5 --split-seed 0 --alpha 0.01 --log-file ncp_alpha001_events_seeds5.log & + +# --- Abnormal (TUAB), 16 jobs. Use --dataset tuab so task and default classifier path are correct. --- +CUDA_VISIBLE_DEVICES=0 python examples/conformal_eeg/tuev_naive_cp_conformal.py --dataset tuab --model tfm --tfm-skip-train --seeds 1,2,3,4,5 --split-seed 0 --alpha 0.2 --log-file naive_cp_alpha02_abnormal_seeds5.log & +CUDA_VISIBLE_DEVICES=1 python examples/conformal_eeg/tuev_naive_cp_conformal.py --dataset tuab --model tfm --tfm-skip-train --seeds 1,2,3,4,5 --split-seed 0 --alpha 0.1 --log-file naive_cp_alpha01_abnormal_seeds5.log & +CUDA_VISIBLE_DEVICES=2 python examples/conformal_eeg/tuev_naive_cp_conformal.py --dataset tuab --model tfm --tfm-skip-train --seeds 1,2,3,4,5 --split-seed 0 --alpha 0.05 --log-file naive_cp_alpha005_abnormal_seeds5.log & +CUDA_VISIBLE_DEVICES=3 python examples/conformal_eeg/tuev_naive_cp_conformal.py --dataset tuab --model tfm --tfm-skip-train --seeds 1,2,3,4,5 --split-seed 0 --alpha 0.01 --log-file naive_cp_alpha001_abnormal_seeds5.log & +CUDA_VISIBLE_DEVICES=4 python examples/conformal_eeg/tuev_kde_cp_conformal.py --dataset tuab --model tfm --tfm-skip-train --seeds 1,2,3,4,5 --split-seed 0 --alpha 0.2 --log-file kde_cp_alpha02_abnormal_seeds5.log & +CUDA_VISIBLE_DEVICES=5 python examples/conformal_eeg/tuev_kde_cp_conformal.py --dataset tuab --model tfm --tfm-skip-train --seeds 1,2,3,4,5 --split-seed 0 --alpha 0.1 --log-file kde_cp_alpha01_abnormal_seeds5.log & +CUDA_VISIBLE_DEVICES=6 python examples/conformal_eeg/tuev_kde_cp_conformal.py --dataset tuab --model tfm --tfm-skip-train --seeds 1,2,3,4,5 --split-seed 0 --alpha 0.05 --log-file kde_cp_alpha005_abnormal_seeds5.log & +CUDA_VISIBLE_DEVICES=7 python examples/conformal_eeg/tuev_kde_cp_conformal.py --dataset tuab --model tfm --tfm-skip-train --seeds 1,2,3,4,5 --split-seed 0 --alpha 0.01 --log-file kde_cp_alpha001_abnormal_seeds5.log & +CUDA_VISIBLE_DEVICES=0 python examples/conformal_eeg/tuev_kmeans_conformal.py --dataset tuab --model tfm --tfm-skip-train --seeds 1,2,3,4,5 --split-seed 0 --alpha 0.2 --log-file kmeans_cp_alpha02_abnormal_seeds5.log & +CUDA_VISIBLE_DEVICES=1 python examples/conformal_eeg/tuev_kmeans_conformal.py --dataset tuab --model tfm --tfm-skip-train --seeds 1,2,3,4,5 --split-seed 0 --alpha 0.1 --log-file kmeans_cp_alpha01_abnormal_seeds5.log & +CUDA_VISIBLE_DEVICES=2 python examples/conformal_eeg/tuev_kmeans_conformal.py --dataset tuab --model tfm --tfm-skip-train --seeds 1,2,3,4,5 --split-seed 0 --alpha 0.05 --log-file kmeans_cp_alpha005_abnormal_seeds5.log & +CUDA_VISIBLE_DEVICES=3 python examples/conformal_eeg/tuev_kmeans_conformal.py --dataset tuab --model tfm --tfm-skip-train --seeds 1,2,3,4,5 --split-seed 0 --alpha 0.01 --log-file kmeans_cp_alpha001_abnormal_seeds5.log & +CUDA_VISIBLE_DEVICES=4 python examples/conformal_eeg/tuev_ncp_conformal.py --dataset tuab --model tfm --tfm-skip-train --seeds 1,2,3,4,5 --split-seed 0 --alpha 0.2 --log-file ncp_alpha02_abnormal_seeds5.log & +CUDA_VISIBLE_DEVICES=5 python examples/conformal_eeg/tuev_ncp_conformal.py --dataset tuab --model tfm --tfm-skip-train --seeds 1,2,3,4,5 --split-seed 0 --alpha 0.1 --log-file ncp_alpha01_abnormal_seeds5.log & +CUDA_VISIBLE_DEVICES=6 python examples/conformal_eeg/tuev_ncp_conformal.py --dataset tuab --model tfm --tfm-skip-train --seeds 1,2,3,4,5 --split-seed 0 --alpha 0.05 --log-file ncp_alpha005_abnormal_seeds5.log & +CUDA_VISIBLE_DEVICES=7 python examples/conformal_eeg/tuev_ncp_conformal.py --dataset tuab --model tfm --tfm-skip-train --seeds 1,2,3,4,5 --split-seed 0 --alpha 0.01 --log-file ncp_alpha001_abnormal_seeds5.log & + +wait diff --git a/examples/conformal_eeg/model_utils.py b/examples/conformal_eeg/model_utils.py new file mode 100644 index 000000000..fe259c8ce --- /dev/null +++ b/examples/conformal_eeg/model_utils.py @@ -0,0 +1,158 @@ +"""Shared helpers for TUEV conformal scripts: model choice (ContraWR vs TFM-Tokenizer) and STFT dataset wrapper. + +TFM runs use the same experimental protocol as ContraWR (split_seed, run_seeds, alpha, +ratios, fixed test set) so results are directly comparable. + +TFM loading options: +- Single checkpoint: --tfm-checkpoint (full model or tokenizer-only; use {seed} for per-seed). +- Two checkpoints (pretrained tokenizer + finetuned classifier): --tfm-tokenizer-checkpoint + and --tfm-classifier-checkpoint (use {seed} in classifier path for per-seed). Use + --tfm-skip-train to run calibration + inference only. + +Scripts support --dataset tuev|tuab (same protocol; TUAB uses binary task, TUEV uses multiclass). +""" + +from __future__ import annotations + +import numpy as np +import torch + +from pyhealth.models import ContraWR, TFMTokenizer + + +RESAMPLING_RATE = 200 # TFM-Tokenizer standard; n_fft=200, hop_length=100, 100 freq bins + + +def get_stft_torch(X: torch.Tensor, resampling_rate: int = RESAMPLING_RATE) -> torch.Tensor: + """Per-channel magnitude STFT matching TFM-Tokenizer repo. Input (B, C, T) -> output (B, C, 100, T').""" + B, C, T = X.shape + x_temp = X.reshape(B * C, T) + window = torch.hann_window(resampling_rate, device=X.device, dtype=X.dtype) + stft_complex = torch.stft( + x_temp, + n_fft=resampling_rate, + hop_length=resampling_rate // 2, + window=window, + onesided=True, + return_complex=True, + center=False, + ) + # (B*C, n_fft//2+1, T') -> take first 100 freq bins + x_stft_temp = torch.abs(stft_complex)[:, : resampling_rate // 2, :] + x_stft_temp = x_stft_temp.reshape(B, C, resampling_rate // 2, -1) + return x_stft_temp + + +class AddSTFTDataset: + """Wraps a TUEV/TUAB task dataset to add per-channel 'stft' for TFM-Tokenizer. + Keeps 'signal' as (C, T); adds 'stft' as (C, 100, T') with n_fft=200, hop_length=100. + Matches the original TFM-Tokenizer training pipeline (16 token sequences per sample). + """ + + def __init__(self, base, n_fft: int = 200, hop_length: int = 100): + self._base = base + self.n_fft = n_fft + self.hop_length = hop_length + self.input_schema = {**getattr(base, "input_schema", {}), "stft": "tensor"} + self.output_schema = getattr(base, "output_schema", {}) + + @property + def output_processors(self): + """Forward so BaseModel.get_output_size() works (TFMTokenizer/ContraWR).""" + return getattr(self._base, "output_processors", {}) + + @property + def input_processors(self): + """Forward in case the model reads input_processors.""" + return getattr(self._base, "input_processors", {}) + + def __len__(self) -> int: + return len(self._base) + + def __getitem__(self, i: int): + sample = dict(self._base[i]) + signal = sample["signal"] + signal = np.asarray(signal, dtype=np.float32) + if signal.ndim == 1: + signal = signal.reshape(1, -1) + # Normalize by 95th percentile of |signal| per channel (axis=-1), matching TFM training + scale = np.quantile( + np.abs(signal), q=0.95, axis=-1, method="linear", keepdims=True + ) + 1e-8 + signal = np.asarray(signal / scale, dtype=np.float32) + # signal (C, T) -> tensor (float32 for tokenizer) + signal_t = torch.from_numpy(signal) + sample["signal"] = signal_t + # Per-channel STFT: (1, C, T) -> (1, C, 100, T') + stft = get_stft_torch(signal_t.unsqueeze(0), resampling_rate=self.n_fft) + sample["stft"] = stft.squeeze(0) + return sample + + def subset(self, indices): + return AddSTFTDataset(self._base.subset(indices), self.n_fft, self.hop_length) + + def set_shuffle(self, shuffle: bool) -> None: + """Forward to base dataset so get_dataloader() works (pyhealth.datasets.utils).""" + if hasattr(self._base, "set_shuffle"): + self._base.set_shuffle(shuffle) + + +def _load_tfm_checkpoint(model, checkpoint_path: str, device: str): + """Load TFM checkpoint: full model state_dict or tokenizer-only (legacy).""" + ckpt = torch.load(checkpoint_path, map_location=device) + state = ckpt.get("model_state_dict", ckpt.get("state_dict", ckpt)) + if not isinstance(state, dict): + model.load_pretrained_weights(checkpoint_path, map_location=device) + return + keys = list(state.keys()) + if any(str(k).startswith("tokenizer.") or str(k).startswith("classifier.") for k in keys): + model.load_state_dict(state, strict=False) + print(f" Loaded full model from {checkpoint_path}") + else: + model.load_pretrained_weights(checkpoint_path, map_location=device) + + +def _load_tfm_classifier_checkpoint(model, checkpoint_path: str, device: str): + """Load classifier-only checkpoint into model.classifier. Handles keys with or without 'classifier.' prefix.""" + ckpt = torch.load(checkpoint_path, map_location=device) + state = ckpt.get("model_state_dict", ckpt.get("state_dict", ckpt)) + if not isinstance(state, dict): + return + keys = list(state.keys()) + if any(str(k).startswith("classifier.") for k in keys): + model.load_state_dict(state, strict=False) + print(f" Loaded classifier from {checkpoint_path}") + else: + model.classifier.load_state_dict(state, strict=False) + print(f" Loaded classifier from {checkpoint_path}") + + +def get_model(args, sample_dataset, device: str): + """Build ContraWR or TFMTokenizer from args.model. Use sample_dataset (possibly AddSTFTDataset for TFM). + Loading options (TFM): + - args.tfm_checkpoint: single path (full model or tokenizer-only). + - args.tfm_tokenizer_checkpoint + args.tfm_classifier_checkpoint: pretrained tokenizer + per-seed finetuned classifier. + """ + if getattr(args, "model", "contrawr").lower() == "tfm": + model = TFMTokenizer( + dataset=sample_dataset, + n_freq=100, + emb_size=getattr(args, "tfm_emb_size", 64), + code_book_size=getattr(args, "tfm_code_book_size", 8192), + ) + model = model.to(device) + tokenizer_ckpt = getattr(args, "tfm_tokenizer_checkpoint", None) + classifier_ckpt = getattr(args, "tfm_classifier_checkpoint", None) + single_ckpt = getattr(args, "tfm_checkpoint", None) + if tokenizer_ckpt and classifier_ckpt: + model.load_pretrained_weights(tokenizer_ckpt, map_location=device) + _load_tfm_classifier_checkpoint(model, classifier_ckpt, device) + elif single_ckpt: + _load_tfm_checkpoint(model, single_ckpt, device) + if getattr(args, "tfm_freeze_tokenizer", False): + for p in model.tokenizer.parameters(): + p.requires_grad = False + return model + else: + model = ContraWR(dataset=sample_dataset, n_fft=getattr(args, "n_fft", 128)) + return model.to(device) diff --git a/examples/conformal_eeg/run_tfm_grid_8gpu.sh b/examples/conformal_eeg/run_tfm_grid_8gpu.sh new file mode 100755 index 000000000..a7639c37c --- /dev/null +++ b/examples/conformal_eeg/run_tfm_grid_8gpu.sh @@ -0,0 +1,98 @@ +#!/usr/bin/env bash +# Run full TFM conformal grid on 8 GPUs (4 alphas × 2 datasets × 4 methods = 32 jobs). +# Dynamic assignment: first free GPU takes the next job (faster when runtimes vary). +# Usage: set TOK, TUEV_CLF, TUAB_CLF below, then: bash run_tfm_grid_8gpu.sh +# From repo root: bash examples/conformal_eeg/run_tfm_grid_8gpu.sh +# Run 8 at a time (4 waves): WAVES_OF_8=1 bash examples/conformal_eeg/run_tfm_grid_8gpu.sh + +set -e +SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)" +cd "$SCRIPT_DIR/../.." + +# --- PATHS (structure: .../TFM_Tokenizer_multiple_finetuned_on_TUEV/TFM_Tokenizer_multiple_finetuned_on_TUEV_{seed}/best_model.pth) --- +TOK="${TOK:-/srv/local/data/arjunc4/tfm_tokenizer_last.pth}" +# Escape } in {seed} so bash does not close ${VAR:-default} at the first } +TUEV_CLF="${TUEV_CLF:-/srv/local/data/arjunc4/TFM_Tokenizer_multiple_finetuned_on_TUEV/TFM_Tokenizer_multiple_finetuned_on_TUEV_{seed\}/best_model.pth}" +TUAB_CLF="${TUAB_CLF:-/srv/local/data/arjunc4/TFM_Tokenizer_multiple_finetuned_on_TUAB/TFM_Tokenizer_multiple_finetuned_on_TUAB_{seed\}/best_model.pth}" +# Logs go to repo root logs/ (create if missing) +REPO_ROOT="$SCRIPT_DIR/../.." +LOG_DIR="${LOG_DIR:-$REPO_ROOT/logs}" +mkdir -p "$LOG_DIR" + +run_one() { + local gpu="$1" + local method="$2" + local dataset="$3" + local alpha="$4" + local script + case "$method" in + naive) script="tuev_naive_cp_conformal.py" ;; + kde) script="tuev_kde_cp_conformal.py" ;; + kmeans) script="tuev_kmeans_conformal.py" ;; + ncp) script="tuev_ncp_conformal.py" ;; + *) echo "Unknown method $method"; exit 1 ;; + esac + local clf_path + [ "$dataset" = tuev ] && clf_path="$TUEV_CLF" || clf_path="$TUAB_CLF" + CUDA_VISIBLE_DEVICES=$gpu python "examples/conformal_eeg/$script" \ + --dataset "$dataset" --model tfm --alpha "$alpha" \ + --tfm-tokenizer-checkpoint "$TOK" \ + --tfm-classifier-checkpoint "$clf_path" \ + --tfm-skip-train --seeds 1,2,3,4,5 --split-seed 0 \ + --log-file "$LOG_DIR/${method}_${dataset}_alpha${alpha}.log" +} + +# Build list of jobs: method dataset alpha (32 jobs) +jobs=() +for method in naive kde kmeans ncp; do + for dataset in tuev tuab; do + for alpha in 0.2 0.1 0.05 0.01; do + jobs+=("$method $dataset $alpha") + done + done +done +total=${#jobs[@]} + +if [[ -n "${WAVES_OF_8:-}" ]]; then + # Run 4 waves of 8 jobs: wait for each wave to finish before starting the next + for wave in 0 1 2 3; do + start=$((wave * 8)) + echo "=== Wave $((wave + 1))/4 (jobs $((start + 1))–$((start + 8))) ===" + for gpu in 0 1 2 3 4 5 6 7; do + idx=$((start + gpu)) + read -r method dataset alpha <<< "${jobs[idx]}" + run_one "$gpu" "$method" "$dataset" "$alpha" & + done + wait + echo "Wave $((wave + 1)) done." + done + echo "All 32 jobs done. Logs in $LOG_DIR" +else + # Dynamic assignment: shared job counter, first free GPU claims next job + NEXT_JOB_FILE=$(mktemp) + LOCK_FILE=$(mktemp) + echo 0 > "$NEXT_JOB_FILE" + cleanup() { rm -f "$NEXT_JOB_FILE" "$LOCK_FILE"; } + trap cleanup EXIT + + claim_next_index() { + flock 200 bash -c "read -r idx < \"$NEXT_JOB_FILE\"; echo \$((idx + 1)) > \"$NEXT_JOB_FILE\"; printf '%s\n' \"\$idx\"" + } 200>>"$LOCK_FILE" + + worker() { + local gpu=$1 + while true; do + local job_idx + job_idx=$(claim_next_index) + [[ "$job_idx" -ge "$total" ]] && break + read -r method dataset alpha <<< "${jobs[job_idx]}" + run_one "$gpu" "$method" "$dataset" "$alpha" + done + } + + for gpu in 0 1 2 3 4 5 6 7; do + worker "$gpu" & + done + wait + echo "All 32 jobs done. Logs in $LOG_DIR" +fi diff --git a/examples/conformal_eeg/tuev_kde_cp_conformal.py b/examples/conformal_eeg/tuev_kde_cp_conformal.py new file mode 100644 index 000000000..7fcffc342 --- /dev/null +++ b/examples/conformal_eeg/tuev_kde_cp_conformal.py @@ -0,0 +1,399 @@ +"""CP with covariate shift correction (KDE) on TUEV EEG Events using ContraWR. + +Baseline: CovariateLabel with KDE-based likelihood ratio weighting (CoDrug-style). +Requires cal and test embeddings to fit KDEs and compute weights. + +With --n-seeds > 1: fixes the test set (--split-seed), runs multiple training runs +with different seeds, reports coverage / set size / accuracy as mean ± std. + +Example (from repo root): + python examples/conformal_eeg/tuev_kde_cp_conformal.py --alpha 0.1 --n-seeds 5 --split-seed 0 --log-file kde_cp_alpha01_seeds5.log + +Run for PI baselines (alpha=0.2, 0.1, 0.05 with error bars): + for a in 0.2 0.1 0.05; do python examples/conformal_eeg/tuev_kde_cp_conformal.py --alpha $a --n-seeds 5 --split-seed 0 --log-file kde_cp_alpha${a}_seeds5.log; done + Or in parallel on 3 GPUs: same with CUDA_VISIBLE_DEVICES=0/1/2 and & wait. +""" + +from __future__ import annotations + +import argparse +import random +import sys +from pathlib import Path + +_script_dir = Path(__file__).resolve().parent +if str(_script_dir) not in sys.path: + sys.path.insert(0, str(_script_dir)) + +import numpy as np +import torch + + +class _Tee: + def __init__(self, stream, file): + self._stream = stream + self._file = file + + def write(self, data): + self._stream.write(data) + self._file.write(data) + self._file.flush() + + def flush(self): + self._stream.flush() + self._file.flush() + + +from pyhealth.calib.predictionset.covariate import CovariateLabel +from pyhealth.calib.utils import extract_embeddings +from pyhealth.datasets import TUEVDataset, TUABDataset, get_dataloader, split_by_sample_conformal +from pyhealth.tasks import EEGEventsTUEV, EEGAbnormalTUAB +from pyhealth.trainer import Trainer, get_metrics_fn + +from model_utils import AddSTFTDataset, get_model + +DEFAULT_ROOT = {"tuev": "/srv/local/data/TUH/tuh_eeg_events/v2.0.0/edf", "tuab": "/srv/local/data/TUH/tuh_eeg_abnormal/v3.0.0/edf"} + + +def parse_args() -> argparse.Namespace: + parser = argparse.ArgumentParser( + description="CP with covariate shift correction (KDE / CovariateLabel) on TUEV/TUAB EEG using ContraWR or TFM." + ) + parser.add_argument("--dataset", type=str, default="tuev", choices=["tuev", "tuab"], help="EEG dataset: tuev or tuab.") + parser.add_argument("--root", type=str, default=None, help="Path to dataset edf/ folder. Default per --dataset.") + parser.add_argument("--subset", type=str, default="both", choices=["train", "eval", "both"]) + parser.add_argument("--seed", type=int, default=42) + parser.add_argument("--n-seeds", type=int, default=1, help="Number of runs for mean±std. Test set fixed when > 1.") + parser.add_argument("--split-seed", type=int, default=0, help="Fixed seed for initial split (fixes test set when n-seeds > 1).") + parser.add_argument("--seeds", type=str, default=None, help="Comma-separated run seeds. Overrides --seed and --n-seeds.") + parser.add_argument("--batch-size", type=int, default=64) + parser.add_argument("--epochs", type=int, default=20) + parser.add_argument("--alpha", type=float, default=0.1, help="Miscoverage rate (e.g. 0.1 => 90%% target coverage).") + parser.add_argument("--ratios", type=float, nargs=4, default=(0.6, 0.1, 0.15, 0.15), metavar=("TRAIN", "VAL", "CAL", "TEST")) + parser.add_argument("--n-fft", type=int, default=128) + parser.add_argument("--model", type=str, default="contrawr", choices=["contrawr", "tfm"], help="Backbone: contrawr or tfm (TFM-Tokenizer).") + parser.add_argument("--tfm-checkpoint", type=str, default=None, help="Path to TFM checkpoint (full model or tokenizer). Use {seed} for per-seed paths.") + parser.add_argument("--tfm-tokenizer-checkpoint", type=str, default=None, help="Path to pretrained TFM tokenizer (shared). Use with --tfm-classifier-checkpoint for inference.") + parser.add_argument("--tfm-classifier-checkpoint", type=str, default=None, help="Path to finetuned classifier. Use {seed} for per-seed paths.") + parser.add_argument("--tfm-skip-train", action="store_true", help="Skip training; load checkpoint(s) and run calibration + inference only.") + parser.add_argument("--tfm-freeze-tokenizer", action="store_true", help="Freeze tokenizer when fine-tuning; only train classifier.") + parser.add_argument("--tfm-epochs", type=int, default=5, help="Epochs when fine-tuning TFM. Ignored if --tfm-skip-train.") + parser.add_argument("--tfm-lr", type=float, default=1e-4, help="Learning rate when fine-tuning TFM.") + parser.add_argument("--device", type=str, default=None) + parser.add_argument("--log-file", type=str, default=None) + parser.add_argument("--cache-dir", type=str, default=None, help="Per-job cache dir to avoid races when running 8 in parallel.") + parser.add_argument("--quick-test", action="store_true", help="dev=True, max 2000 samples, 2 epochs.") + return parser.parse_args() + + +def set_seed(seed: int) -> None: + random.seed(seed) + np.random.seed(seed) + torch.manual_seed(seed) + if torch.cuda.is_available(): + torch.cuda.manual_seed_all(seed) + + +def _split_remainder_into_train_val_cal(sample_dataset, remainder_indices, ratios, run_seed): + r0, r1, r2, r3 = ratios + remainder_frac = 1.0 - r3 + if remainder_frac <= 0: + raise ValueError("Test ratio must be < 1.") + r_train = r0 / remainder_frac + r_val = r1 / remainder_frac + remainder = np.asarray(remainder_indices, dtype=np.int64) + np.random.seed(run_seed) + shuffled = np.random.permutation(remainder) + M = len(shuffled) + train_end = int(M * r_train) + val_end = int(M * (r_train + r_val)) + train_index = shuffled[:train_end] + val_index = shuffled[train_end:val_end] + cal_index = shuffled[val_end:] + train_ds = sample_dataset.subset(train_index.tolist()) + val_ds = sample_dataset.subset(val_index.tolist()) + cal_ds = sample_dataset.subset(cal_index.tolist()) + return train_ds, val_ds, cal_ds + + +def _run_one_kde_cp( + sample_dataset, + train_ds, + val_ds, + cal_ds, + test_ds, + test_loader, + args, + device, + epochs, + task_mode="multiclass", + return_metrics=False, +): + """Train ContraWR, extract cal + test embeddings, calibrate CovariateLabel (KDE), evaluate.""" + train_loader = get_dataloader(train_ds, batch_size=args.batch_size, shuffle=True) + val_loader = get_dataloader(val_ds, batch_size=args.batch_size, shuffle=False) if len(val_ds) else None + + model_name = "TFM-Tokenizer" if args.model.lower() == "tfm" else "ContraWR" + print("\n" + "=" * 80) + print(f"STEP 3: Train {model_name}" if not getattr(args, "tfm_skip_train", False) else f"STEP 3: Load {model_name} (skip train)") + print("=" * 80) + model = get_model(args, sample_dataset, device) + trainer = Trainer(model=model, device=device, enable_logging=False) + if not getattr(args, "tfm_skip_train", False): + optimizer_params = None + if args.model.lower() == "tfm" and ( + getattr(args, "tfm_checkpoint", None) + or (getattr(args, "tfm_tokenizer_checkpoint", None) and getattr(args, "tfm_classifier_checkpoint", None)) + ): + optimizer_params = {"lr": getattr(args, "tfm_lr", 1e-4)} + trainer.train( + train_dataloader=train_loader, + val_dataloader=val_loader, + epochs=epochs, + monitor="accuracy" if val_loader is not None else None, + optimizer_params=optimizer_params, + ) + + if not return_metrics: + print("\nBase model performance on test set:") + y_true_base, y_prob_base, _ = trainer.inference(test_loader) + base_metrics = get_metrics_fn(task_mode)(y_true_base, y_prob_base, metrics=["accuracy", "f1_weighted"]) + for k, v in base_metrics.items(): + print(f" {k}: {v:.4f}") + + print("\n" + "=" * 80) + print("STEP 4: CP with covariate shift correction (CovariateLabel / KDE)") + print("=" * 80) + print(f"Target miscoverage alpha: {args.alpha} (target coverage {1 - args.alpha:.0%})") + + print("Extracting calibration embeddings...") + cal_embeddings = extract_embeddings(model, cal_ds, batch_size=args.batch_size, device=device) + print("Extracting test embeddings...") + test_embeddings = extract_embeddings(model, test_ds, batch_size=args.batch_size, device=device) + if not return_metrics: + print(f" cal_embeddings shape: {cal_embeddings.shape}, test_embeddings shape: {test_embeddings.shape}") + + predictor = CovariateLabel(model=model, alpha=float(args.alpha)) + predictor.calibrate( + cal_dataset=cal_ds, + cal_embeddings=cal_embeddings, + test_embeddings=test_embeddings, + ) + + y_true, y_prob, _loss, extra = Trainer(model=predictor, enable_logging=False).inference(test_loader, additional_outputs=["y_predset"]) + metrics = get_metrics_fn(task_mode)(y_true, y_prob, metrics=["accuracy", "miscoverage_ps"], y_predset=extra["y_predset"]) + predset = extra["y_predset"] + predset_t = torch.tensor(predset) if isinstance(predset, np.ndarray) else predset + avg_set_size = predset_t.float().sum(dim=1).mean().item() + miscoverage = metrics["miscoverage_ps"] + if isinstance(miscoverage, np.ndarray): + miscoverage = float(miscoverage.item() if miscoverage.size == 1 else miscoverage.mean()) + else: + miscoverage = float(miscoverage) + coverage = 1.0 - miscoverage + + if return_metrics: + return {"accuracy": float(metrics["accuracy"]), "coverage": coverage, "miscoverage": miscoverage, "avg_set_size": avg_set_size} + + print("\nCovariateLabel (KDE) Results:") + print(f" Accuracy: {metrics['accuracy']:.4f}") + print(f" Empirical miscoverage: {miscoverage:.4f}") + print(f" Empirical coverage: {coverage:.4f}") + print(f" Average set size: {avg_set_size:.2f}") + print("\n--- Single-run summary (for reporting) ---") + print(f" alpha={args.alpha}, target_coverage={1 - args.alpha:.2f}, empirical_coverage={coverage:.4f}, miscoverage={miscoverage:.4f}, accuracy={metrics['accuracy']:.4f}, avg_set_size={avg_set_size:.2f}") + + +def main() -> None: + args = parse_args() + if args.n_seeds <= 1 and args.seeds is None: + set_seed(args.seed) + + orig_stdout, orig_stderr = sys.stdout, sys.stderr + log_file = None + if args.log_file: + p = Path(args.log_file) + if "/" not in args.log_file and not p.is_absolute(): + Path("logs").mkdir(parents=True, exist_ok=True) + args.log_file = str(Path("logs") / args.log_file) + log_file = open(args.log_file, "w", encoding="utf-8") + sys.stdout = _Tee(orig_stdout, log_file) + sys.stderr = _Tee(orig_stderr, log_file) + try: + _run(args) + finally: + if log_file is not None: + sys.stdout, sys.stderr = orig_stdout, orig_stderr + log_file.close() + + +def _run(args: argparse.Namespace) -> None: + device = args.device or ("cuda:0" if torch.cuda.is_available() else "cpu") + dataset_name = getattr(args, "dataset", "tuev") + root = Path(args.root or DEFAULT_ROOT[dataset_name]) + # If --root was passed but not --dataset, infer dataset from path so TUAB root => TUAB task + if args.root is not None and "--dataset" not in sys.argv: + root_str = str(root).lower() + if "abnormal" in root_str or "tuab" in root_str: + dataset_name = "tuab" + elif "events" in root_str or "tuev" in root_str: + dataset_name = "tuev" + if not root.exists(): + raise FileNotFoundError(f"Dataset root not found: {root}. Set --root for {dataset_name}.") + + if args.model.lower() == "tfm": + if not getattr(args, "tfm_tokenizer_checkpoint", None): + args.tfm_tokenizer_checkpoint = "/srv/local/data/arjunc4/tfm_tokenizer_last.pth" + if not getattr(args, "tfm_classifier_checkpoint", None): + args.tfm_classifier_checkpoint = ( + "/srv/local/data/arjunc4/TFM_Tokenizer_multiple_finetuned_on_TUAB/TFM_Tokenizer_multiple_finetuned_on_TUAB_{seed}/best_model.pth" + if dataset_name == "tuab" + else "/srv/local/data/arjunc4/TFM_Tokenizer_multiple_finetuned_on_TUEV/TFM_Tokenizer_multiple_finetuned_on_TUEV_{seed}/best_model.pth" + ) + + if args.quick_test: + epochs = 2 + elif args.model.lower() == "tfm" and ( + getattr(args, "tfm_checkpoint", None) + or (getattr(args, "tfm_tokenizer_checkpoint", None) and getattr(args, "tfm_classifier_checkpoint", None)) + ): + epochs = getattr(args, "tfm_epochs", 5) + else: + epochs = args.epochs + quick_test_max = 2000 + if args.quick_test: + print("*** QUICK TEST MODE ***") + + task_mode = "binary" if dataset_name == "tuab" else "multiclass" + print("=" * 80) + print(f"STEP 1: Load {dataset_name.upper()} + build task dataset") + print("=" * 80) + cache_base = getattr(args, "cache_dir", None) + if dataset_name == "tuab": + cache_dir = (cache_base.rstrip("/") + "_tuab") if cache_base else "examples/conformal_eeg/cache_tuab" + dataset = TUABDataset(root=str(root), subset=args.subset, dev=args.quick_test) + sample_dataset = dataset.set_task(EEGAbnormalTUAB(), cache_dir=cache_dir) + else: + cache_dir = cache_base or "examples/conformal_eeg/cache" + dataset = TUEVDataset(root=str(root), subset=args.subset, dev=args.quick_test) + sample_dataset = dataset.set_task( + EEGEventsTUEV(resample_rate=200), cache_dir=cache_dir + ) + if args.quick_test and len(sample_dataset) > quick_test_max: + sample_dataset = sample_dataset.subset(range(quick_test_max)) + print(f"Capped to {quick_test_max} samples.") + if args.model.lower() == "tfm": + sample_dataset = AddSTFTDataset(sample_dataset, n_fft=200, hop_length=100) + print("Wrapped dataset with STFT for TFM-Tokenizer.") + print(f"Task samples: {len(sample_dataset)} (task_mode={task_mode})") + + print("\n--- Experiment configuration ---") + print(f" dataset: {dataset_name}, dataset_root: {root}, subset: {args.subset}") + print(f" ratios: train/val/cal/test = {args.ratios[0]:.2f}/{args.ratios[1]:.2f}/{args.ratios[2]:.2f}/{args.ratios[3]:.2f}") + print(f" alpha: {args.alpha} (target coverage {1 - args.alpha:.0%})") + print(f" epochs: {epochs}, batch_size: {args.batch_size}, device: {device}, seed: {args.seed}") + + if len(sample_dataset) == 0: + raise RuntimeError("No samples.") + + ratios = list(args.ratios) + use_multi_seed = args.n_seeds > 1 or args.seeds is not None + if use_multi_seed: + if args.seeds: + run_seeds = [int(s.strip()) for s in args.seeds.split(",")] + elif getattr(args, "tfm_skip_train", False) and args.model.lower() == "tfm": + run_seeds = list(range(1, 1 + args.n_seeds)) + else: + run_seeds = [args.seed + i for i in range(args.n_seeds)] + n_runs = len(run_seeds) + print(f" multi_seed: n_runs={n_runs}, run_seeds={run_seeds}, split_seed={args.split_seed} (fixed test set)") + print(f"Multi-seed mode: {n_runs} runs (fixed test set), run seeds: {run_seeds}") + + if not use_multi_seed: + if args.model.lower() == "tfm": + ckpt = getattr(args, "tfm_checkpoint", None) + if ckpt and "{seed}" in ckpt: + args.tfm_checkpoint = ckpt.replace("{seed}", str(args.seed)) + clf = getattr(args, "tfm_classifier_checkpoint", None) + if clf and "{seed}" in clf: + args.tfm_classifier_checkpoint = clf.replace("{seed}", str(args.seed)) + print("\n" + "=" * 80) + print("STEP 2: Split train/val/cal/test") + print("=" * 80) + train_ds, val_ds, cal_ds, test_ds = split_by_sample_conformal(dataset=sample_dataset, ratios=ratios, seed=args.seed) + print(f"Train: {len(train_ds)}, Val: {len(val_ds)}, Cal: {len(cal_ds)}, Test: {len(test_ds)}") + test_loader = get_dataloader(test_ds, batch_size=args.batch_size, shuffle=False) + _run_one_kde_cp(sample_dataset, train_ds, val_ds, cal_ds, test_ds, test_loader, args, device, epochs, task_mode=task_mode) + print("\n--- Split sizes and seed (for reporting) ---") + print(f" train={len(train_ds)}, val={len(val_ds)}, cal={len(cal_ds)}, test={len(test_ds)}, seed={args.seed}") + return + + print("\n" + "=" * 80) + print("STEP 2: Fix test set (split-seed), then run multiple train/cal splits") + print("=" * 80) + train_idx, val_idx, cal_idx, test_idx = split_by_sample_conformal(dataset=sample_dataset, ratios=ratios, seed=args.split_seed, get_index=True) + train_index = train_idx.numpy() if hasattr(train_idx, "numpy") else np.array(train_idx) + val_index = val_idx.numpy() if hasattr(val_idx, "numpy") else np.array(val_idx) + cal_index = cal_idx.numpy() if hasattr(cal_idx, "numpy") else np.array(cal_idx) + test_index = test_idx.numpy() if hasattr(test_idx, "numpy") else np.array(test_idx) + remainder_indices = np.concatenate([train_index, val_index, cal_index]) + test_ds = sample_dataset.subset(test_index.tolist()) + test_loader = get_dataloader(test_ds, batch_size=args.batch_size, shuffle=False) + n_test = len(test_ds) + print(f"Fixed test set size: {n_test}") + + accs, coverages, miscoverages, set_sizes = [], [], [], [] + tfm_ckpt_original = getattr(args, "tfm_checkpoint", None) + tfm_classifier_ckpt_original = getattr(args, "tfm_classifier_checkpoint", None) + for run_i, run_seed in enumerate(run_seeds): + print("\n" + "=" * 80) + print(f"Run {run_i + 1} / {n_runs} (seed={run_seed})") + print("=" * 80) + if tfm_ckpt_original and "{seed}" in tfm_ckpt_original: + args.tfm_checkpoint = tfm_ckpt_original.replace("{seed}", str(run_seed)) + if tfm_classifier_ckpt_original and "{seed}" in tfm_classifier_ckpt_original: + args.tfm_classifier_checkpoint = tfm_classifier_ckpt_original.replace("{seed}", str(run_seed)) + set_seed(run_seed) + train_ds, val_ds, cal_ds = _split_remainder_into_train_val_cal(sample_dataset, remainder_indices, ratios, run_seed) + print(f"Train: {len(train_ds)}, Val: {len(val_ds)}, Cal: {len(cal_ds)}") + m = _run_one_kde_cp(sample_dataset, train_ds, val_ds, cal_ds, test_ds, test_loader, args, device, epochs, task_mode=task_mode, return_metrics=True) + if tfm_ckpt_original and "{seed}" in tfm_ckpt_original: + args.tfm_checkpoint = tfm_ckpt_original + if tfm_classifier_ckpt_original and "{seed}" in tfm_classifier_ckpt_original: + args.tfm_classifier_checkpoint = tfm_classifier_ckpt_original + accs.append(m["accuracy"]) + coverages.append(m["coverage"]) + miscoverages.append(m["miscoverage"]) + set_sizes.append(m["avg_set_size"]) + + accs = np.array(accs) + coverages = np.array(coverages) + miscoverages_arr = np.array(miscoverages) + set_sizes = np.array(set_sizes) + + print("\n" + "=" * 80) + print("Per-run KDE CP results (fixed test set)") + print("=" * 80) + print(f" {'Run':<4} {'Seed':<6} {'Accuracy':<10} {'Coverage':<10} {'Miscoverage':<12} {'Avg set size':<12}") + print(" " + "-" * 54) + for i in range(n_runs): + print(f" {i+1:<4} {run_seeds[i]:<6} {accs[i]:<10.4f} {coverages[i]:<10.4f} {miscoverages_arr[i]:<12.4f} {set_sizes[i]:<12.2f}") + + print("\n" + "=" * 80) + print("KDE CP summary (mean ± std over {} runs, fixed test set)".format(n_runs)) + print("=" * 80) + print(f" Accuracy: {accs.mean():.4f} ± {accs.std():.4f}") + print(f" Empirical coverage: {coverages.mean():.4f} ± {coverages.std():.4f}") + print(f" Empirical miscoverage: {miscoverages_arr.mean():.4f} ± {miscoverages_arr.std():.4f}") + print(f" Average set size: {set_sizes.mean():.2f} ± {set_sizes.std():.2f}") + print(f" Target coverage: {1 - args.alpha:.0%} (alpha={args.alpha})") + print(f" Test set size: {n_test} (fixed across runs)") + print(f" Run seeds: {run_seeds}") + print("\n--- Min / Max (across runs) ---") + print(f" Coverage: [{coverages.min():.4f}, {coverages.max():.4f}]") + print(f" Set size: [{set_sizes.min():.2f}, {set_sizes.max():.2f}]") + print(f" Accuracy: [{accs.min():.4f}, {accs.max():.4f}]") + + +if __name__ == "__main__": + main() diff --git a/examples/conformal_eeg/tuev_kmeans_conformal.py b/examples/conformal_eeg/tuev_kmeans_conformal.py index b5a043dad..66386cadc 100644 --- a/examples/conformal_eeg/tuev_kmeans_conformal.py +++ b/examples/conformal_eeg/tuev_kmeans_conformal.py @@ -8,8 +8,12 @@ 5) Calibrates a ClusterLabel prediction-set predictor (K-means clustering). 6) Evaluates prediction-set coverage/miscoverage and efficiency on the test split. +With --n-seeds > 1: fixes the test set (--split-seed), runs multiple training runs +with different seeds, reports coverage / set size / accuracy as mean ± std. + Example (from repo root): python examples/conformal_eeg/tuev_kmeans_conformal.py --root /srv/local/data/TUH/tuh_eeg_events/v2.0.0/edf --n-clusters 5 + python examples/conformal_eeg/tuev_kmeans_conformal.py --alpha 0.2 --n-seeds 5 --split-seed 0 --log-file kmeans_cp_alpha02_seeds5.log python examples/conformal_eeg/tuev_kmeans_conformal.py --quick-test --log-file quicktest_kmeans.log Notes: @@ -24,6 +28,10 @@ import sys from pathlib import Path +_script_dir = Path(__file__).resolve().parent +if str(_script_dir) not in sys.path: + sys.path.insert(0, str(_script_dir)) + import numpy as np import torch @@ -47,24 +55,26 @@ def flush(self): from pyhealth.calib.predictionset.cluster import ClusterLabel from pyhealth.calib.utils import extract_embeddings -from pyhealth.datasets import TUEVDataset, get_dataloader, split_by_sample_conformal -from pyhealth.models import ContraWR -from pyhealth.tasks import EEGEventsTUEV +from pyhealth.datasets import TUEVDataset, TUABDataset, get_dataloader, split_by_sample_conformal +from pyhealth.tasks import EEGEventsTUEV, EEGAbnormalTUAB from pyhealth.trainer import Trainer, get_metrics_fn +from model_utils import AddSTFTDataset, get_model + +DEFAULT_ROOT = {"tuev": "/srv/local/data/TUH/tuh_eeg_events/v2.0.0/edf", "tuab": "/srv/local/data/TUH/tuh_eeg_abnormal/v3.0.0/edf"} + def parse_args() -> argparse.Namespace: parser = argparse.ArgumentParser( - description="K-means cluster-based conformal prediction (ClusterLabel) on TUEV EEG events using ContraWR." - ) - parser.add_argument( - "--root", - type=str, - default="/srv/local/data/TUH/tuh_eeg_events/v2.0.0/edf", - help="Path to TUEV edf/ folder.", + description="K-means cluster-based conformal prediction (ClusterLabel) on TUEV/TUAB EEG using ContraWR or TFM." ) + parser.add_argument("--dataset", type=str, default="tuev", choices=["tuev", "tuab"], help="EEG dataset: tuev or tuab.") + parser.add_argument("--root", type=str, default=None, help="Path to dataset edf/ folder. Default per --dataset.") parser.add_argument("--subset", type=str, default="both", choices=["train", "eval", "both"]) parser.add_argument("--seed", type=int, default=42) + parser.add_argument("--n-seeds", type=int, default=1, help="Number of runs for mean±std. Test set fixed when > 1.") + parser.add_argument("--split-seed", type=int, default=0, help="Fixed seed for initial split (fixes test set when n-seeds > 1).") + parser.add_argument("--seeds", type=str, default=None, help="Comma-separated run seeds. Overrides --seed and --n-seeds.") parser.add_argument("--batch-size", type=int, default=64) parser.add_argument("--epochs", type=int, default=20) parser.add_argument("--alpha", type=float, default=0.1, help="Miscoverage rate (e.g., 0.1 => 90% target coverage).") @@ -82,7 +92,15 @@ def parse_args() -> argparse.Namespace: default=5, help="Number of K-means clusters for cluster-specific thresholds.", ) - parser.add_argument("--n-fft", type=int, default=128, help="STFT FFT size used by ContraWR.") + parser.add_argument("--n-fft", type=int, default=128, help="STFT FFT size (ContraWR and TFM-Tokenizer).") + parser.add_argument("--model", type=str, default="contrawr", choices=["contrawr", "tfm"], help="Backbone: contrawr or tfm (TFM-Tokenizer).") + parser.add_argument("--tfm-checkpoint", type=str, default=None, help="Path to TFM checkpoint (full model or tokenizer). Use {seed} for per-seed paths.") + parser.add_argument("--tfm-tokenizer-checkpoint", type=str, default=None, help="Path to pretrained TFM tokenizer (shared). Use with --tfm-classifier-checkpoint for inference.") + parser.add_argument("--tfm-classifier-checkpoint", type=str, default=None, help="Path to finetuned classifier. Use {seed} for per-seed paths.") + parser.add_argument("--tfm-skip-train", action="store_true", help="Skip training; load checkpoint(s) and run calibration + inference only.") + parser.add_argument("--tfm-freeze-tokenizer", action="store_true", help="Freeze tokenizer when fine-tuning; only train classifier.") + parser.add_argument("--tfm-epochs", type=int, default=5, help="Epochs when fine-tuning TFM. Ignored if --tfm-skip-train.") + parser.add_argument("--tfm-lr", type=float, default=1e-4, help="Learning rate when fine-tuning TFM.") parser.add_argument( "--device", type=str, @@ -95,6 +113,7 @@ def parse_args() -> argparse.Namespace: default=None, help="Path to log file. Stdout and stderr are teed to this file.", ) + parser.add_argument("--cache-dir", type=str, default=None, help="Per-job cache dir to avoid races when running 8 in parallel.") parser.add_argument( "--quick-test", action="store_true", @@ -111,89 +130,71 @@ def set_seed(seed: int) -> None: torch.cuda.manual_seed_all(seed) -def main() -> None: - args = parse_args() - set_seed(args.seed) - - orig_stdout, orig_stderr = sys.stdout, sys.stderr - log_file = None - if args.log_file: - log_file = open(args.log_file, "w", encoding="utf-8") - sys.stdout = _Tee(orig_stdout, log_file) - sys.stderr = _Tee(orig_stderr, log_file) - - try: - _run(args) - finally: - if log_file is not None: - sys.stdout = orig_stdout - sys.stderr = orig_stderr - log_file.close() - - -def _run(args: argparse.Namespace) -> None: - device = args.device or ("cuda:0" if torch.cuda.is_available() else "cpu") - root = Path(args.root) - if not root.exists(): - raise FileNotFoundError( - f"TUEV root not found: {root}. " - "Pass --root to point to your downloaded TUEV edf/ directory." - ) - - epochs = 2 if args.quick_test else args.epochs - quick_test_max_samples = 2000 # cap samples so quick-test finishes in ~5-10 min - if args.quick_test: - print("*** QUICK TEST MODE (dev=True, 2 epochs, max 2000 samples) ***") - - print("=" * 80) - print("STEP 1: Load TUEV + build task dataset") - print("=" * 80) - dataset = TUEVDataset(root=str(root), subset=args.subset, dev=args.quick_test) - sample_dataset = dataset.set_task(EEGEventsTUEV(), cache_dir="examples/conformal_eeg/cache") - if args.quick_test and len(sample_dataset) > quick_test_max_samples: - sample_dataset = sample_dataset.subset(range(quick_test_max_samples)) - print(f"Capped to {quick_test_max_samples} samples for quick-test.") - - print(f"Task samples: {len(sample_dataset)}") - print(f"Input schema: {sample_dataset.input_schema}") - print(f"Output schema: {sample_dataset.output_schema}") - - if len(sample_dataset) == 0: - raise RuntimeError("No samples produced. Verify TUEV root/subset/task.") - - print("\n" + "=" * 80) - print("STEP 2: Split train/val/cal/test") - print("=" * 80) - train_ds, val_ds, cal_ds, test_ds = split_by_sample_conformal( - dataset=sample_dataset, ratios=list(args.ratios), seed=args.seed - ) - print(f"Train: {len(train_ds)}") - print(f"Val: {len(val_ds)}") - print(f"Cal: {len(cal_ds)}") - print(f"Test: {len(test_ds)}") - +def _split_remainder_into_train_val_cal(sample_dataset, remainder_indices, ratios, run_seed): + r0, r1, r2, r3 = ratios + remainder_frac = 1.0 - r3 + if remainder_frac <= 0: + raise ValueError("Test ratio must be < 1.") + r_train = r0 / remainder_frac + r_val = r1 / remainder_frac + remainder = np.asarray(remainder_indices, dtype=np.int64) + np.random.seed(run_seed) + shuffled = np.random.permutation(remainder) + M = len(shuffled) + train_end = int(M * r_train) + val_end = int(M * (r_train + r_val)) + train_index = shuffled[:train_end] + val_index = shuffled[train_end:val_end] + cal_index = shuffled[val_end:] + train_ds = sample_dataset.subset(train_index.tolist()) + val_ds = sample_dataset.subset(val_index.tolist()) + cal_ds = sample_dataset.subset(cal_index.tolist()) + return train_ds, val_ds, cal_ds + + +def _run_one_kmeans_cp( + sample_dataset, + train_ds, + val_ds, + cal_ds, + test_loader, + args, + device, + epochs, + run_seed, + task_mode="multiclass", + return_metrics=False, +): train_loader = get_dataloader(train_ds, batch_size=args.batch_size, shuffle=True) val_loader = get_dataloader(val_ds, batch_size=args.batch_size, shuffle=False) if len(val_ds) else None - test_loader = get_dataloader(test_ds, batch_size=args.batch_size, shuffle=False) + model_name = "TFM-Tokenizer" if args.model.lower() == "tfm" else "ContraWR" print("\n" + "=" * 80) - print("STEP 3: Train ContraWR") + print(f"STEP 3: Train {model_name}" if not getattr(args, "tfm_skip_train", False) else f"STEP 3: Load {model_name} (skip train)") print("=" * 80) - model = ContraWR(dataset=sample_dataset, n_fft=args.n_fft).to(device) + model = get_model(args, sample_dataset, device) trainer = Trainer(model=model, device=device, enable_logging=False) + if not getattr(args, "tfm_skip_train", False): + optimizer_params = None + if args.model.lower() == "tfm" and ( + getattr(args, "tfm_checkpoint", None) + or (getattr(args, "tfm_tokenizer_checkpoint", None) and getattr(args, "tfm_classifier_checkpoint", None)) + ): + optimizer_params = {"lr": getattr(args, "tfm_lr", 1e-4)} + trainer.train( + train_dataloader=train_loader, + val_dataloader=val_loader, + epochs=epochs, + monitor="accuracy" if val_loader is not None else None, + optimizer_params=optimizer_params, + ) - trainer.train( - train_dataloader=train_loader, - val_dataloader=val_loader, - epochs=epochs, - monitor="accuracy" if val_loader is not None else None, - ) - - print("\nBase model performance on test set:") - y_true_base, y_prob_base, _loss_base = trainer.inference(test_loader) - base_metrics = get_metrics_fn("multiclass")(y_true_base, y_prob_base, metrics=["accuracy", "f1_weighted"]) - for metric, value in base_metrics.items(): - print(f" {metric}: {value:.4f}") + if not return_metrics: + print("\nBase model performance on test set:") + y_true_base, y_prob_base, _loss_base = trainer.inference(test_loader) + base_metrics = get_metrics_fn(task_mode)(y_true_base, y_prob_base, metrics=["accuracy", "f1_weighted"]) + for metric, value in base_metrics.items(): + print(f" {metric}: {value:.4f}") print("\n" + "=" * 80) print("STEP 4: K-means Cluster-Based Conformal Prediction (ClusterLabel)") @@ -213,7 +214,7 @@ def _run(args: argparse.Namespace) -> None: model=model, alpha=float(args.alpha), n_clusters=args.n_clusters, - random_state=args.seed, + random_state=run_seed, ) print("Calibrating ClusterLabel predictor (fits K-means and computes cluster-specific thresholds)...") cluster_predictor.calibrate( @@ -223,11 +224,11 @@ def _run(args: argparse.Namespace) -> None: ) print("Evaluating ClusterLabel predictor on test set...") - y_true, y_prob, _loss, extra = Trainer(model=cluster_predictor).inference( + y_true, y_prob, _loss, extra = Trainer(model=cluster_predictor, enable_logging=False).inference( test_loader, additional_outputs=["y_predset"] ) - cluster_metrics = get_metrics_fn("multiclass")( + cluster_metrics = get_metrics_fn(task_mode)( y_true, y_prob, metrics=["accuracy", "miscoverage_ps"], @@ -246,13 +247,247 @@ def _run(args: argparse.Namespace) -> None: miscoverage = float(miscoverage.item() if miscoverage.size == 1 else miscoverage.mean()) else: miscoverage = float(miscoverage) + coverage = 1.0 - miscoverage + + if return_metrics: + return { + "accuracy": float(cluster_metrics["accuracy"]), + "coverage": coverage, + "miscoverage": miscoverage, + "avg_set_size": avg_set_size, + } print("\nClusterLabel Results:") print(f" Accuracy: {cluster_metrics['accuracy']:.4f}") print(f" Empirical miscoverage: {miscoverage:.4f}") - print(f" Empirical coverage: {1 - miscoverage:.4f}") + print(f" Empirical coverage: {coverage:.4f}") print(f" Average set size: {avg_set_size:.2f}") print(f" Number of clusters: {args.n_clusters}") + print("\n--- Single-run summary (for reporting) ---") + print(f" alpha={args.alpha}, target_coverage={1 - args.alpha:.2f}, empirical_coverage={coverage:.4f}, miscoverage={miscoverage:.4f}, accuracy={cluster_metrics['accuracy']:.4f}, avg_set_size={avg_set_size:.2f}") + + +def main() -> None: + args = parse_args() + if args.n_seeds <= 1 and args.seeds is None: + set_seed(args.seed) + + orig_stdout, orig_stderr = sys.stdout, sys.stderr + log_file = None + if args.log_file: + p = Path(args.log_file) + if "/" not in args.log_file and not p.is_absolute(): + Path("logs").mkdir(parents=True, exist_ok=True) + args.log_file = str(Path("logs") / args.log_file) + log_file = open(args.log_file, "w", encoding="utf-8") + sys.stdout = _Tee(orig_stdout, log_file) + sys.stderr = _Tee(orig_stderr, log_file) + + try: + _run(args) + finally: + if log_file is not None: + sys.stdout = orig_stdout + sys.stderr = orig_stderr + log_file.close() + + +def _run(args: argparse.Namespace) -> None: + device = args.device or ("cuda:0" if torch.cuda.is_available() else "cpu") + dataset_name = getattr(args, "dataset", "tuev") + root = Path(args.root or DEFAULT_ROOT[dataset_name]) + # If --root was passed but not --dataset, infer dataset from path so TUAB root => TUAB task + if args.root is not None and "--dataset" not in sys.argv: + root_str = str(root).lower() + if "abnormal" in root_str or "tuab" in root_str: + dataset_name = "tuab" + elif "events" in root_str or "tuev" in root_str: + dataset_name = "tuev" + if not root.exists(): + raise FileNotFoundError(f"Dataset root not found: {root}. Set --root for {dataset_name}.") + + if args.model.lower() == "tfm": + if not getattr(args, "tfm_tokenizer_checkpoint", None): + args.tfm_tokenizer_checkpoint = "/srv/local/data/arjunc4/tfm_tokenizer_last.pth" + if not getattr(args, "tfm_classifier_checkpoint", None): + args.tfm_classifier_checkpoint = ( + "/srv/local/data/arjunc4/TFM_Tokenizer_multiple_finetuned_on_TUAB/TFM_Tokenizer_multiple_finetuned_on_TUAB_{seed}/best_model.pth" + if dataset_name == "tuab" + else "/srv/local/data/arjunc4/TFM_Tokenizer_multiple_finetuned_on_TUEV/TFM_Tokenizer_multiple_finetuned_on_TUEV_{seed}/best_model.pth" + ) + + if args.quick_test: + epochs = 2 + elif args.model.lower() == "tfm" and ( + getattr(args, "tfm_checkpoint", None) + or (getattr(args, "tfm_tokenizer_checkpoint", None) and getattr(args, "tfm_classifier_checkpoint", None)) + ): + epochs = getattr(args, "tfm_epochs", 5) + else: + epochs = args.epochs + quick_test_max_samples = 2000 + if args.quick_test: + print("*** QUICK TEST MODE (dev=True, 2 epochs, max 2000 samples) ***") + + task_mode = "binary" if dataset_name == "tuab" else "multiclass" + print("=" * 80) + print(f"STEP 1: Load {dataset_name.upper()} + build task dataset") + print("=" * 80) + cache_base = getattr(args, "cache_dir", None) + if dataset_name == "tuab": + cache_dir = (cache_base.rstrip("/") + "_tuab") if cache_base else "examples/conformal_eeg/cache_tuab" + dataset = TUABDataset(root=str(root), subset=args.subset, dev=args.quick_test) + sample_dataset = dataset.set_task(EEGAbnormalTUAB(), cache_dir=cache_dir) + else: + cache_dir = cache_base or "examples/conformal_eeg/cache" + dataset = TUEVDataset(root=str(root), subset=args.subset, dev=args.quick_test) + sample_dataset = dataset.set_task( + EEGEventsTUEV(resample_rate=200), cache_dir=cache_dir + ) + if args.quick_test and len(sample_dataset) > quick_test_max_samples: + sample_dataset = sample_dataset.subset(range(quick_test_max_samples)) + print(f"Capped to {quick_test_max_samples} samples for quick-test.") + if args.model.lower() == "tfm": + sample_dataset = AddSTFTDataset(sample_dataset, n_fft=200, hop_length=100) + print("Wrapped dataset with STFT for TFM-Tokenizer.") + + print(f"Task samples: {len(sample_dataset)} (task_mode={task_mode})") + print(f"Input schema: {sample_dataset.input_schema}") + print(f"Output schema: {sample_dataset.output_schema}") + + if len(sample_dataset) == 0: + raise RuntimeError("No samples produced. Verify TUEV root/subset/task.") + + ratios = list(args.ratios) + use_multi_seed = args.n_seeds > 1 or args.seeds is not None + if use_multi_seed: + if args.seeds: + run_seeds = [int(s.strip()) for s in args.seeds.split(",")] + elif getattr(args, "tfm_skip_train", False) and args.model.lower() == "tfm": + run_seeds = list(range(1, 1 + args.n_seeds)) + else: + run_seeds = [args.seed + i for i in range(args.n_seeds)] + n_runs = len(run_seeds) + print(f"\n--- Experiment configuration ---") + print(f" dataset_root: {root}, subset: {args.subset}") + print(f" ratios: train/val/cal/test = {ratios[0]:.2f}/{ratios[1]:.2f}/{ratios[2]:.2f}/{ratios[3]:.2f}") + print(f" alpha: {args.alpha} (target coverage {1 - args.alpha:.0%})") + print(f" multi_seed: n_runs={n_runs}, run_seeds={run_seeds}, split_seed={args.split_seed} (fixed test set)") + + if not use_multi_seed: + if args.model.lower() == "tfm": + ckpt = getattr(args, "tfm_checkpoint", None) + if ckpt and "{seed}" in ckpt: + args.tfm_checkpoint = ckpt.replace("{seed}", str(args.seed)) + clf = getattr(args, "tfm_classifier_checkpoint", None) + if clf and "{seed}" in clf: + args.tfm_classifier_checkpoint = clf.replace("{seed}", str(args.seed)) + print("\n" + "=" * 80) + print("STEP 2: Split train/val/cal/test") + print("=" * 80) + train_ds, val_ds, cal_ds, test_ds = split_by_sample_conformal( + dataset=sample_dataset, ratios=ratios, seed=args.seed + ) + print(f"Train: {len(train_ds)}, Val: {len(val_ds)}, Cal: {len(cal_ds)}, Test: {len(test_ds)}") + test_loader = get_dataloader(test_ds, batch_size=args.batch_size, shuffle=False) + _run_one_kmeans_cp( + sample_dataset, + train_ds, + val_ds, + cal_ds, + test_loader, + args, + device, + epochs, + args.seed, + task_mode=task_mode, + return_metrics=False, + ) + print("\n--- Split sizes and seed (for reporting) ---") + print(f" train={len(train_ds)}, val={len(val_ds)}, cal={len(cal_ds)}, test={len(test_ds)}, seed={args.seed}") + return + + print("\n" + "=" * 80) + print("STEP 2: Fix test set (split-seed), then run multiple train/cal splits") + print("=" * 80) + train_idx, val_idx, cal_idx, test_idx = split_by_sample_conformal( + dataset=sample_dataset, ratios=ratios, seed=args.split_seed, get_index=True + ) + train_index = train_idx.numpy() if hasattr(train_idx, "numpy") else np.array(train_idx) + val_index = val_idx.numpy() if hasattr(val_idx, "numpy") else np.array(val_idx) + cal_index = cal_idx.numpy() if hasattr(cal_idx, "numpy") else np.array(cal_idx) + test_index = test_idx.numpy() if hasattr(test_idx, "numpy") else np.array(test_idx) + remainder_indices = np.concatenate([train_index, val_index, cal_index]) + test_ds = sample_dataset.subset(test_index.tolist()) + test_loader = get_dataloader(test_ds, batch_size=args.batch_size, shuffle=False) + n_test = len(test_ds) + print(f"Fixed test set size: {n_test}") + + accs, coverages, miscoverages, set_sizes = [], [], [], [] + tfm_ckpt_original = getattr(args, "tfm_checkpoint", None) + tfm_classifier_ckpt_original = getattr(args, "tfm_classifier_checkpoint", None) + for run_i, run_seed in enumerate(run_seeds): + print("\n" + "=" * 80) + print(f"Run {run_i + 1} / {n_runs} (seed={run_seed})") + print("=" * 80) + if tfm_ckpt_original and "{seed}" in tfm_ckpt_original: + args.tfm_checkpoint = tfm_ckpt_original.replace("{seed}", str(run_seed)) + if tfm_classifier_ckpt_original and "{seed}" in tfm_classifier_ckpt_original: + args.tfm_classifier_checkpoint = tfm_classifier_ckpt_original.replace("{seed}", str(run_seed)) + set_seed(run_seed) + train_ds, val_ds, cal_ds = _split_remainder_into_train_val_cal( + sample_dataset, remainder_indices, ratios, run_seed + ) + print(f"Train: {len(train_ds)}, Val: {len(val_ds)}, Cal: {len(cal_ds)}") + m = _run_one_kmeans_cp( + sample_dataset, + train_ds, + val_ds, + cal_ds, + test_loader, + args, + device, + epochs, + run_seed, + task_mode=task_mode, + return_metrics=True, + ) + if tfm_ckpt_original and "{seed}" in tfm_ckpt_original: + args.tfm_checkpoint = tfm_ckpt_original + if tfm_classifier_ckpt_original and "{seed}" in tfm_classifier_ckpt_original: + args.tfm_classifier_checkpoint = tfm_classifier_ckpt_original + accs.append(m["accuracy"]) + coverages.append(m["coverage"]) + miscoverages.append(m["miscoverage"]) + set_sizes.append(m["avg_set_size"]) + + accs = np.array(accs) + coverages = np.array(coverages) + miscoverages_arr = np.array(miscoverages) + set_sizes = np.array(set_sizes) + + print("\n" + "=" * 80) + print("Per-run ClusterLabel results (fixed test set)") + print("=" * 80) + print(f" {'Run':<4} {'Seed':<6} {'Accuracy':<10} {'Coverage':<10} {'Miscoverage':<12} {'Avg set size':<12}") + print(" " + "-" * 54) + for i in range(n_runs): + print(f" {i+1:<4} {run_seeds[i]:<6} {accs[i]:<10.4f} {coverages[i]:<10.4f} {miscoverages_arr[i]:<12.4f} {set_sizes[i]:<12.2f}") + + print("\n" + "=" * 80) + print("ClusterLabel summary (mean ± std over {} runs, fixed test set)".format(n_runs)) + print("=" * 80) + print(f" Accuracy: {accs.mean():.4f} ± {accs.std():.4f}") + print(f" Empirical coverage: {coverages.mean():.4f} ± {coverages.std():.4f}") + print(f" Empirical miscoverage: {miscoverages_arr.mean():.4f} ± {miscoverages_arr.std():.4f}") + print(f" Average set size: {set_sizes.mean():.2f} ± {set_sizes.std():.2f}") + print(f" Target coverage: {1 - args.alpha:.0%} (alpha={args.alpha})") + print(f" Test set size: {n_test} (fixed across runs)") + print(f" Run seeds: {run_seeds}") + print("\n--- Min / Max (across runs) ---") + print(f" Coverage: [{coverages.min():.4f}, {coverages.max():.4f}]") + print(f" Set size: [{set_sizes.min():.2f}, {set_sizes.max():.2f}]") + print(f" Accuracy: [{accs.min():.4f}, {accs.max():.4f}]") if __name__ == "__main__": diff --git a/examples/conformal_eeg/tuev_naive_cp_conformal.py b/examples/conformal_eeg/tuev_naive_cp_conformal.py new file mode 100644 index 000000000..6e21e6b45 --- /dev/null +++ b/examples/conformal_eeg/tuev_naive_cp_conformal.py @@ -0,0 +1,390 @@ +"""Naive (split) Conformal Prediction (BaseConformal) on TUEV EEG Events using ContraWR. + +Baseline: standard split conformal prediction with a single threshold on the +calibration set. No covariate or neighborhood correction. + +With --n-seeds > 1: fixes the test set (--split-seed), runs multiple training runs +with different seeds, reports coverage / set size / accuracy as mean ± std. + +Example (from repo root): + python examples/conformal_eeg/tuev_naive_cp_conformal.py --alpha 0.1 --n-seeds 5 --split-seed 0 --log-file naive_cp_alpha01_seeds5.log + +Run for PI baselines (alpha=0.2, 0.1, 0.05 with error bars): + for a in 0.2 0.1 0.05; do python examples/conformal_eeg/tuev_naive_cp_conformal.py --alpha $a --n-seeds 5 --split-seed 0 --log-file naive_cp_alpha${a}_seeds5.log; done + Or in parallel on 3 GPUs: same with CUDA_VISIBLE_DEVICES=0/1/2 and & wait. + +TFM inference (pretrained tokenizer + 5 finetuned classifier checkpoints, same protocol as ContraWR): + TUEV: --dataset tuev --model tfm --tfm-tokenizer-checkpoint .../Pretrained_tfm_tokenizer_2x2x8/tfm_tokenizer_last.pth --tfm-classifier-checkpoint .../TFM_Tokenizer_multiple_finetuned_on_TUEV_{seed}/best_model.pth --tfm-skip-train --seeds 1,2,3,4,5 --alpha 0.1 + TUAB: --dataset tuab --model tfm --tfm-tokenizer-checkpoint .../Pretrained_tfm_tokenizer_2x2x8/tfm_tokenizer_last.pth --tfm-classifier-checkpoint .../TFM_Tokenizer_multiple_finetuned_on_TUAB_{seed}/best_model.pth --tfm-skip-train --seeds 1,2,3,4,5 --alpha 0.1 +""" + +from __future__ import annotations + +import argparse +import random +import sys +from pathlib import Path + +_script_dir = Path(__file__).resolve().parent +if str(_script_dir) not in sys.path: + sys.path.insert(0, str(_script_dir)) + +import numpy as np +import torch + + +class _Tee: + def __init__(self, stream, file): + self._stream = stream + self._file = file + + def write(self, data): + self._stream.write(data) + self._file.write(data) + self._file.flush() + + def flush(self): + self._stream.flush() + self._file.flush() + + +from pyhealth.calib.predictionset.base_conformal import BaseConformal +from pyhealth.datasets import TUEVDataset, TUABDataset, get_dataloader, split_by_sample_conformal +from pyhealth.tasks import EEGEventsTUEV, EEGAbnormalTUAB +from pyhealth.trainer import Trainer, get_metrics_fn + +from model_utils import AddSTFTDataset, get_model + +# Default roots when --root is not set (TUEV vs TUAB) +DEFAULT_ROOT = { + "tuev": "/srv/local/data/TUH/tuh_eeg_events/v2.0.0/edf", + "tuab": "/srv/local/data/TUH/tuh_eeg_abnormal/v3.0.0/edf", +} + + +def parse_args() -> argparse.Namespace: + parser = argparse.ArgumentParser( + description="Naive (split) conformal prediction (BaseConformal) on TUEV/TUAB EEG using ContraWR or TFM." + ) + parser.add_argument("--dataset", type=str, default="tuev", choices=["tuev", "tuab"], help="EEG dataset: tuev (events, 6-class) or tuab (abnormal, binary).") + parser.add_argument("--root", type=str, default=None, help="Path to dataset edf/ folder. Default: TUEV or TUAB path per --dataset.") + parser.add_argument("--subset", type=str, default="both", choices=["train", "eval", "both"]) + parser.add_argument("--seed", type=int, default=42) + parser.add_argument("--n-seeds", type=int, default=1, help="Number of runs for mean±std. Test set fixed when > 1.") + parser.add_argument("--split-seed", type=int, default=0, help="Fixed seed for initial split (fixes test set when n-seeds > 1).") + parser.add_argument("--seeds", type=str, default=None, help="Comma-separated run seeds. Overrides --seed and --n-seeds.") + parser.add_argument("--batch-size", type=int, default=64) + parser.add_argument("--epochs", type=int, default=20) + parser.add_argument("--alpha", type=float, default=0.1, help="Miscoverage rate (e.g. 0.1 => 90%% target coverage).") + parser.add_argument("--ratios", type=float, nargs=4, default=(0.6, 0.1, 0.15, 0.15), metavar=("TRAIN", "VAL", "CAL", "TEST")) + parser.add_argument("--n-fft", type=int, default=128) + parser.add_argument("--model", type=str, default="contrawr", choices=["contrawr", "tfm"], help="Backbone: contrawr or tfm (TFM-Tokenizer).") + parser.add_argument("--tfm-checkpoint", type=str, default=None, help="Path to TFM checkpoint (full model or tokenizer). Use {seed} for per-seed paths.") + parser.add_argument("--tfm-tokenizer-checkpoint", type=str, default=None, help="Path to pretrained TFM tokenizer (shared). Use with --tfm-classifier-checkpoint for inference.") + parser.add_argument("--tfm-classifier-checkpoint", type=str, default=None, help="Path to finetuned classifier. Use {seed} for per-seed paths, e.g. .../finetuned_TUEV_{seed}/best_model.pth.") + parser.add_argument("--tfm-skip-train", action="store_true", help="Skip training; load checkpoint(s) and run calibration + inference only.") + parser.add_argument("--tfm-freeze-tokenizer", action="store_true", help="Freeze tokenizer when fine-tuning; only train classifier.") + parser.add_argument("--tfm-epochs", type=int, default=5, help="Epochs when fine-tuning TFM. Ignored if --tfm-skip-train.") + parser.add_argument("--tfm-lr", type=float, default=1e-4, help="Learning rate when fine-tuning TFM.") + parser.add_argument("--device", type=str, default=None) + parser.add_argument("--log-file", type=str, default=None) + parser.add_argument("--cache-dir", type=str, default=None, help="Per-job cache dir to avoid races when running 8 in parallel. Default: examples/conformal_eeg/cache or cache_tuab.") + parser.add_argument("--quick-test", action="store_true", help="dev=True, max 2000 samples, 2 epochs.") + return parser.parse_args() + + +def set_seed(seed: int) -> None: + random.seed(seed) + np.random.seed(seed) + torch.manual_seed(seed) + if torch.cuda.is_available(): + torch.cuda.manual_seed_all(seed) + + +def _split_remainder_into_train_val_cal(sample_dataset, remainder_indices, ratios, run_seed): + r0, r1, r2, r3 = ratios + remainder_frac = 1.0 - r3 + if remainder_frac <= 0: + raise ValueError("Test ratio must be < 1.") + r_train = r0 / remainder_frac + r_val = r1 / remainder_frac + remainder = np.asarray(remainder_indices, dtype=np.int64) + np.random.seed(run_seed) + shuffled = np.random.permutation(remainder) + M = len(shuffled) + train_end = int(M * r_train) + val_end = int(M * (r_train + r_val)) + train_index = shuffled[:train_end] + val_index = shuffled[train_end:val_end] + cal_index = shuffled[val_end:] + train_ds = sample_dataset.subset(train_index.tolist()) + val_ds = sample_dataset.subset(val_index.tolist()) + cal_ds = sample_dataset.subset(cal_index.tolist()) + return train_ds, val_ds, cal_ds + + +def _run_one_naive_cp(sample_dataset, train_ds, val_ds, cal_ds, test_loader, args, device, epochs, task_mode="multiclass", return_metrics=False): + train_loader = get_dataloader(train_ds, batch_size=args.batch_size, shuffle=True) + val_loader = get_dataloader(val_ds, batch_size=args.batch_size, shuffle=False) if len(val_ds) else None + + model_name = "TFM-Tokenizer" if args.model.lower() == "tfm" else "ContraWR" + print("\n" + "=" * 80) + print(f"STEP 3: Train {model_name}" if not getattr(args, "tfm_skip_train", False) else f"STEP 3: Load {model_name} (skip train)") + print("=" * 80) + model = get_model(args, sample_dataset, device) + trainer = Trainer(model=model, device=device, enable_logging=False) + if not getattr(args, "tfm_skip_train", False): + optimizer_params = None + if args.model.lower() == "tfm" and ( + getattr(args, "tfm_checkpoint", None) + or (getattr(args, "tfm_tokenizer_checkpoint", None) and getattr(args, "tfm_classifier_checkpoint", None)) + ): + optimizer_params = {"lr": getattr(args, "tfm_lr", 1e-4)} + trainer.train( + train_dataloader=train_loader, + val_dataloader=val_loader, + epochs=epochs, + monitor="accuracy" if val_loader is not None else None, + optimizer_params=optimizer_params, + ) + + if not return_metrics: + print("\nBase model performance on test set:") + y_true_base, y_prob_base, _ = trainer.inference(test_loader) + base_metrics = get_metrics_fn(task_mode)(y_true_base, y_prob_base, metrics=["accuracy", "f1_weighted"]) + for k, v in base_metrics.items(): + print(f" {k}: {v:.4f}") + + print("\n" + "=" * 80) + print("STEP 4: Naive Conformal Prediction (BaseConformal)") + print("=" * 80) + print(f"Target miscoverage alpha: {args.alpha} (target coverage {1 - args.alpha:.0%})") + + predictor = BaseConformal(model=model, alpha=float(args.alpha)) + predictor.calibrate(cal_dataset=cal_ds) + + y_true, y_prob, _loss, extra = Trainer(model=predictor, enable_logging=False).inference(test_loader, additional_outputs=["y_predset"]) + metrics = get_metrics_fn(task_mode)(y_true, y_prob, metrics=["accuracy", "miscoverage_ps"], y_predset=extra["y_predset"]) + predset = extra["y_predset"] + predset_t = torch.tensor(predset) if isinstance(predset, np.ndarray) else predset + avg_set_size = predset_t.float().sum(dim=1).mean().item() + miscoverage = metrics["miscoverage_ps"] + if isinstance(miscoverage, np.ndarray): + miscoverage = float(miscoverage.item() if miscoverage.size == 1 else miscoverage.mean()) + else: + miscoverage = float(miscoverage) + coverage = 1.0 - miscoverage + + if return_metrics: + return {"accuracy": float(metrics["accuracy"]), "coverage": coverage, "miscoverage": miscoverage, "avg_set_size": avg_set_size} + + print("\nNaive CP (BaseConformal) Results:") + print(f" Accuracy: {metrics['accuracy']:.4f}") + print(f" Empirical miscoverage: {miscoverage:.4f}") + print(f" Empirical coverage: {coverage:.4f}") + print(f" Average set size: {avg_set_size:.2f}") + print("\n--- Single-run summary (for reporting) ---") + print(f" alpha={args.alpha}, target_coverage={1 - args.alpha:.2f}, empirical_coverage={coverage:.4f}, miscoverage={miscoverage:.4f}, accuracy={metrics['accuracy']:.4f}, avg_set_size={avg_set_size:.2f}") + + +def main() -> None: + args = parse_args() + if args.n_seeds <= 1 and args.seeds is None: + set_seed(args.seed) + + orig_stdout, orig_stderr = sys.stdout, sys.stderr + log_file = None + if args.log_file: + # If bare filename (no path), write under repo root logs/ + p = Path(args.log_file) + if "/" not in args.log_file and not p.is_absolute(): + Path("logs").mkdir(parents=True, exist_ok=True) + args.log_file = str(Path("logs") / args.log_file) + log_file = open(args.log_file, "w", encoding="utf-8") + sys.stdout = _Tee(orig_stdout, log_file) + sys.stderr = _Tee(orig_stderr, log_file) + try: + _run(args) + finally: + if log_file is not None: + sys.stdout, sys.stderr = orig_stdout, orig_stderr + log_file.close() + + +# Default TFM paths (used when --tfm-tokenizer-checkpoint / --tfm-classifier-checkpoint not set) +DEFAULT_TFM_TOKENIZER = "/srv/local/data/arjunc4/tfm_tokenizer_last.pth" +DEFAULT_TFM_CLF_TUEV = "/srv/local/data/arjunc4/TFM_Tokenizer_multiple_finetuned_on_TUEV/TFM_Tokenizer_multiple_finetuned_on_TUEV_{seed}/best_model.pth" +DEFAULT_TFM_CLF_TUAB = "/srv/local/data/arjunc4/TFM_Tokenizer_multiple_finetuned_on_TUAB/TFM_Tokenizer_multiple_finetuned_on_TUAB_{seed}/best_model.pth" + + +def _run(args: argparse.Namespace) -> None: + device = args.device or ("cuda:0" if torch.cuda.is_available() else "cpu") + dataset_name = getattr(args, "dataset", "tuev") + root = Path(args.root or DEFAULT_ROOT[dataset_name]) + # If --root was passed but not --dataset, infer dataset from path so TUAB root => TUAB task + if args.root is not None and "--dataset" not in sys.argv: + root_str = str(root).lower() + if "abnormal" in root_str or "tuab" in root_str: + dataset_name = "tuab" + elif "events" in root_str or "tuev" in root_str: + dataset_name = "tuev" + if not root.exists(): + raise FileNotFoundError(f"Dataset root not found: {root}. Set --root for {dataset_name}.") + + if getattr(args, "model", "contrawr").lower() == "tfm": + if not getattr(args, "tfm_tokenizer_checkpoint", None): + args.tfm_tokenizer_checkpoint = DEFAULT_TFM_TOKENIZER + if not getattr(args, "tfm_classifier_checkpoint", None): + args.tfm_classifier_checkpoint = DEFAULT_TFM_CLF_TUAB if dataset_name == "tuab" else DEFAULT_TFM_CLF_TUEV + + # Same protocol (seeds, alpha, ratios, split_seed) for ContraWR and TFM for comparable results. + if args.quick_test: + epochs = 2 + elif args.model.lower() == "tfm" and ( + getattr(args, "tfm_checkpoint", None) + or (getattr(args, "tfm_tokenizer_checkpoint", None) and getattr(args, "tfm_classifier_checkpoint", None)) + ): + epochs = getattr(args, "tfm_epochs", 5) + else: + epochs = args.epochs + quick_test_max = 2000 + if args.quick_test: + print("*** QUICK TEST MODE ***") + + task_mode = "binary" if dataset_name == "tuab" else "multiclass" + print("=" * 80) + print(f"STEP 1: Load {dataset_name.upper()} + build task dataset") + print("=" * 80) + cache_base = getattr(args, "cache_dir", None) + if dataset_name == "tuab": + cache_dir = (cache_base.rstrip("/") + "_tuab") if cache_base else "examples/conformal_eeg/cache_tuab" + dataset = TUABDataset(root=str(root), subset=args.subset, dev=args.quick_test) + sample_dataset = dataset.set_task(EEGAbnormalTUAB(), cache_dir=cache_dir) + else: + cache_dir = cache_base or "examples/conformal_eeg/cache" + dataset = TUEVDataset(root=str(root), subset=args.subset, dev=args.quick_test) + sample_dataset = dataset.set_task( + EEGEventsTUEV(resample_rate=200), cache_dir=cache_dir + ) + if args.quick_test and len(sample_dataset) > quick_test_max: + sample_dataset = sample_dataset.subset(range(quick_test_max)) + print(f"Capped to {quick_test_max} samples.") + if args.model.lower() == "tfm": + # TFM tokenizer needs n_fft=200, hop_length=100 so STFT time = temporal patches + sample_dataset = AddSTFTDataset( + sample_dataset, n_fft=200, hop_length=100 + ) + print("Wrapped dataset with STFT for TFM-Tokenizer.") + print(f"Task samples: {len(sample_dataset)} (task_mode={task_mode})") + + print("\n--- Experiment configuration ---") + print(f" dataset: {dataset_name}, dataset_root: {root}, subset: {args.subset}") + print(f" ratios: train/val/cal/test = {args.ratios[0]:.2f}/{args.ratios[1]:.2f}/{args.ratios[2]:.2f}/{args.ratios[3]:.2f}") + print(f" alpha: {args.alpha} (target coverage {1 - args.alpha:.0%})") + print(f" epochs: {epochs}, batch_size: {args.batch_size}, device: {device}, seed: {args.seed}") + + if len(sample_dataset) == 0: + raise RuntimeError("No samples.") + + ratios = list(args.ratios) + use_multi_seed = args.n_seeds > 1 or args.seeds is not None + if use_multi_seed: + if args.seeds: + run_seeds = [int(s.strip()) for s in args.seeds.split(",")] + elif getattr(args, "tfm_skip_train", False) and args.model.lower() == "tfm": + run_seeds = list(range(1, 1 + args.n_seeds)) + else: + run_seeds = [args.seed + i for i in range(args.n_seeds)] + n_runs = len(run_seeds) + print(f" multi_seed: n_runs={n_runs}, run_seeds={run_seeds}, split_seed={args.split_seed} (fixed test set)") + print(f"Multi-seed mode: {n_runs} runs (fixed test set), run seeds: {run_seeds}") + + if not use_multi_seed: + # Single run: substitute {seed} in TFM checkpoint paths so loading works + if args.model.lower() == "tfm": + ckpt = getattr(args, "tfm_checkpoint", None) + if ckpt and "{seed}" in ckpt: + args.tfm_checkpoint = ckpt.replace("{seed}", str(args.seed)) + clf = getattr(args, "tfm_classifier_checkpoint", None) + if clf and "{seed}" in clf: + args.tfm_classifier_checkpoint = clf.replace("{seed}", str(args.seed)) + print("\n" + "=" * 80) + print("STEP 2: Split train/val/cal/test") + print("=" * 80) + train_ds, val_ds, cal_ds, test_ds = split_by_sample_conformal(dataset=sample_dataset, ratios=ratios, seed=args.seed) + print(f"Train: {len(train_ds)}, Val: {len(val_ds)}, Cal: {len(cal_ds)}, Test: {len(test_ds)}") + test_loader = get_dataloader(test_ds, batch_size=args.batch_size, shuffle=False) + _run_one_naive_cp(sample_dataset, train_ds, val_ds, cal_ds, test_loader, args, device, epochs, task_mode=task_mode) + print("\n--- Split sizes and seed (for reporting) ---") + print(f" train={len(train_ds)}, val={len(val_ds)}, cal={len(cal_ds)}, test={len(test_ds)}, seed={args.seed}") + return + + print("\n" + "=" * 80) + print("STEP 2: Fix test set (split-seed), then run multiple train/cal splits") + print("=" * 80) + train_idx, val_idx, cal_idx, test_idx = split_by_sample_conformal(dataset=sample_dataset, ratios=ratios, seed=args.split_seed, get_index=True) + train_index = train_idx.numpy() if hasattr(train_idx, "numpy") else np.array(train_idx) + val_index = val_idx.numpy() if hasattr(val_idx, "numpy") else np.array(val_idx) + cal_index = cal_idx.numpy() if hasattr(cal_idx, "numpy") else np.array(cal_idx) + test_index = test_idx.numpy() if hasattr(test_idx, "numpy") else np.array(test_idx) + remainder_indices = np.concatenate([train_index, val_index, cal_index]) + test_ds = sample_dataset.subset(test_index.tolist()) + test_loader = get_dataloader(test_ds, batch_size=args.batch_size, shuffle=False) + n_test = len(test_ds) + print(f"Fixed test set size: {n_test}") + + accs, coverages, miscoverages, set_sizes = [], [], [], [] + tfm_ckpt_original = getattr(args, "tfm_checkpoint", None) + tfm_classifier_ckpt_original = getattr(args, "tfm_classifier_checkpoint", None) + for run_i, run_seed in enumerate(run_seeds): + print("\n" + "=" * 80) + print(f"Run {run_i + 1} / {n_runs} (seed={run_seed})") + print("=" * 80) + if tfm_ckpt_original and "{seed}" in tfm_ckpt_original: + args.tfm_checkpoint = tfm_ckpt_original.replace("{seed}", str(run_seed)) + if tfm_classifier_ckpt_original and "{seed}" in tfm_classifier_ckpt_original: + args.tfm_classifier_checkpoint = tfm_classifier_ckpt_original.replace("{seed}", str(run_seed)) + set_seed(run_seed) + train_ds, val_ds, cal_ds = _split_remainder_into_train_val_cal(sample_dataset, remainder_indices, ratios, run_seed) + print(f"Train: {len(train_ds)}, Val: {len(val_ds)}, Cal: {len(cal_ds)}") + m = _run_one_naive_cp(sample_dataset, train_ds, val_ds, cal_ds, test_loader, args, device, epochs, task_mode=task_mode, return_metrics=True) + if tfm_ckpt_original and "{seed}" in tfm_ckpt_original: + args.tfm_checkpoint = tfm_ckpt_original + if tfm_classifier_ckpt_original and "{seed}" in tfm_classifier_ckpt_original: + args.tfm_classifier_checkpoint = tfm_classifier_ckpt_original + accs.append(m["accuracy"]) + coverages.append(m["coverage"]) + miscoverages.append(m["miscoverage"]) + set_sizes.append(m["avg_set_size"]) + + accs = np.array(accs) + coverages = np.array(coverages) + miscoverages_arr = np.array(miscoverages) + set_sizes = np.array(set_sizes) + + print("\n" + "=" * 80) + print("Per-run Naive CP results (fixed test set)") + print("=" * 80) + print(f" {'Run':<4} {'Seed':<6} {'Accuracy':<10} {'Coverage':<10} {'Miscoverage':<12} {'Avg set size':<12}") + print(" " + "-" * 54) + for i in range(n_runs): + print(f" {i+1:<4} {run_seeds[i]:<6} {accs[i]:<10.4f} {coverages[i]:<10.4f} {miscoverages_arr[i]:<12.4f} {set_sizes[i]:<12.2f}") + + print("\n" + "=" * 80) + print("Naive CP summary (mean ± std over {} runs, fixed test set)".format(n_runs)) + print("=" * 80) + print(f" Accuracy: {accs.mean():.4f} ± {accs.std():.4f}") + print(f" Empirical coverage: {coverages.mean():.4f} ± {coverages.std():.4f}") + print(f" Empirical miscoverage: {miscoverages_arr.mean():.4f} ± {miscoverages_arr.std():.4f}") + print(f" Average set size: {set_sizes.mean():.2f} ± {set_sizes.std():.2f}") + print(f" Target coverage: {1 - args.alpha:.0%} (alpha={args.alpha})") + print(f" Test set size: {n_test} (fixed across runs)") + print(f" Run seeds: {run_seeds}") + print("\n--- Min / Max (across runs) ---") + print(f" Coverage: [{coverages.min():.4f}, {coverages.max():.4f}]") + print(f" Set size: [{set_sizes.min():.2f}, {set_sizes.max():.2f}]") + print(f" Accuracy: [{accs.min():.4f}, {accs.max():.4f}]") + + +if __name__ == "__main__": + main() diff --git a/examples/conformal_eeg/tuev_ncp_conformal.py b/examples/conformal_eeg/tuev_ncp_conformal.py index 9b51a7756..fdabacfff 100644 --- a/examples/conformal_eeg/tuev_ncp_conformal.py +++ b/examples/conformal_eeg/tuev_ncp_conformal.py @@ -24,6 +24,10 @@ import sys from pathlib import Path +_script_dir = Path(__file__).resolve().parent +if str(_script_dir) not in sys.path: + sys.path.insert(0, str(_script_dir)) + import numpy as np import torch @@ -47,22 +51,21 @@ def flush(self): from pyhealth.calib.predictionset.cluster import NeighborhoodLabel from pyhealth.calib.utils import extract_embeddings -from pyhealth.datasets import TUEVDataset, get_dataloader, split_by_sample_conformal -from pyhealth.models import ContraWR -from pyhealth.tasks import EEGEventsTUEV +from pyhealth.datasets import TUEVDataset, TUABDataset, get_dataloader, split_by_sample_conformal +from pyhealth.tasks import EEGEventsTUEV, EEGAbnormalTUAB from pyhealth.trainer import Trainer, get_metrics_fn +from model_utils import AddSTFTDataset, get_model + +DEFAULT_ROOT = {"tuev": "/srv/local/data/TUH/tuh_eeg_events/v2.0.0/edf", "tuab": "/srv/local/data/TUH/tuh_eeg_abnormal/v3.0.0/edf"} + def parse_args() -> argparse.Namespace: parser = argparse.ArgumentParser( - description="Neighborhood conformal prediction (NCP) on TUEV EEG events using ContraWR." - ) - parser.add_argument( - "--root", - type=str, - default="/srv/local/data/TUH/tuh_eeg_events/v2.0.0/edf", - help="Path to TUEV edf/ folder.", + description="Neighborhood conformal prediction (NCP) on TUEV/TUAB EEG using ContraWR or TFM." ) + parser.add_argument("--dataset", type=str, default="tuev", choices=["tuev", "tuab"], help="EEG dataset: tuev or tuab.") + parser.add_argument("--root", type=str, default=None, help="Path to dataset edf/ folder. Default per --dataset.") parser.add_argument("--subset", type=str, default="both", choices=["train", "eval", "both"]) parser.add_argument("--seed", type=int, default=42, help="Run seed (or first of run seeds when n-seeds > 1).") parser.add_argument( @@ -106,7 +109,15 @@ def parse_args() -> argparse.Namespace: default=100.0, help="Temperature for NCP exponential weights; smaller => more localization.", ) - parser.add_argument("--n-fft", type=int, default=128, help="STFT FFT size used by ContraWR.") + parser.add_argument("--n-fft", type=int, default=128, help="STFT FFT size (ContraWR and TFM-Tokenizer).") + parser.add_argument("--model", type=str, default="contrawr", choices=["contrawr", "tfm"], help="Backbone: contrawr or tfm (TFM-Tokenizer).") + parser.add_argument("--tfm-checkpoint", type=str, default=None, help="Path to TFM checkpoint (full model or tokenizer). Use {seed} for per-seed paths.") + parser.add_argument("--tfm-tokenizer-checkpoint", type=str, default=None, help="Path to pretrained TFM tokenizer (shared). Use with --tfm-classifier-checkpoint for inference.") + parser.add_argument("--tfm-classifier-checkpoint", type=str, default=None, help="Path to finetuned classifier. Use {seed} for per-seed paths.") + parser.add_argument("--tfm-skip-train", action="store_true", help="Skip training; load checkpoint(s) and run calibration + inference only.") + parser.add_argument("--tfm-freeze-tokenizer", action="store_true", help="Freeze tokenizer when fine-tuning; only train classifier.") + parser.add_argument("--tfm-epochs", type=int, default=5, help="Epochs when fine-tuning TFM. Ignored if --tfm-skip-train.") + parser.add_argument("--tfm-lr", type=float, default=1e-4, help="Learning rate when fine-tuning TFM.") parser.add_argument( "--device", type=str, @@ -119,6 +130,7 @@ def parse_args() -> argparse.Namespace: default=None, help="Path to log file. Stdout and stderr are teed to this file.", ) + parser.add_argument("--cache-dir", type=str, default=None, help="Per-job cache dir to avoid races when running 8 in parallel.") parser.add_argument( "--quick-test", action="store_true", @@ -168,6 +180,7 @@ def _run_one_ncp( args, device, epochs, + task_mode="multiclass", return_metrics=False, ): """Train ContraWR, calibrate NCP, evaluate on test. Optionally return metrics dict for aggregation.""" @@ -175,21 +188,30 @@ def _run_one_ncp( val_loader = get_dataloader(val_ds, batch_size=args.batch_size, shuffle=False) if len(val_ds) else None print("\n" + "=" * 80) - print("STEP 3: Train ContraWR") + model_name = "TFM-Tokenizer" if args.model.lower() == "tfm" else "ContraWR" + print(f"STEP 3: Train {model_name}" if not getattr(args, "tfm_skip_train", False) else f"STEP 3: Load {model_name} (skip train)") print("=" * 80) - model = ContraWR(dataset=sample_dataset, n_fft=args.n_fft).to(device) + model = get_model(args, sample_dataset, device) trainer = Trainer(model=model, device=device, enable_logging=False) - trainer.train( - train_dataloader=train_loader, - val_dataloader=val_loader, - epochs=epochs, - monitor="accuracy" if val_loader is not None else None, - ) + if not getattr(args, "tfm_skip_train", False): + optimizer_params = None + if args.model.lower() == "tfm" and ( + getattr(args, "tfm_checkpoint", None) + or (getattr(args, "tfm_tokenizer_checkpoint", None) and getattr(args, "tfm_classifier_checkpoint", None)) + ): + optimizer_params = {"lr": getattr(args, "tfm_lr", 1e-4)} + trainer.train( + train_dataloader=train_loader, + val_dataloader=val_loader, + epochs=epochs, + monitor="accuracy" if val_loader is not None else None, + optimizer_params=optimizer_params, + ) if not return_metrics: print("\nBase model performance on test set:") y_true_base, y_prob_base, _loss_base = trainer.inference(test_loader) - base_metrics = get_metrics_fn("multiclass")( + base_metrics = get_metrics_fn(task_mode)( y_true_base, y_prob_base, metrics=["accuracy", "f1_weighted"] ) for metric, value in base_metrics.items(): @@ -201,6 +223,7 @@ def _run_one_ncp( print(f"Target miscoverage alpha: {args.alpha} (target coverage {1 - args.alpha:.0%})") print(f"k_neighbors: {args.k_neighbors}, lambda_L: {args.lambda_L}") + print("Extracting calibration embeddings...") cal_embeddings = extract_embeddings(model, cal_ds, batch_size=args.batch_size, device=device) if not return_metrics: print(f" cal_embeddings shape: {cal_embeddings.shape}") @@ -213,10 +236,10 @@ def _run_one_ncp( ) ncp_predictor.calibrate(cal_dataset=cal_ds, cal_embeddings=cal_embeddings) - y_true, y_prob, _loss, extra = Trainer(model=ncp_predictor).inference( + y_true, y_prob, _loss, extra = Trainer(model=ncp_predictor, enable_logging=False).inference( test_loader, additional_outputs=["y_predset"] ) - ncp_metrics = get_metrics_fn("multiclass")( + ncp_metrics = get_metrics_fn(task_mode)( y_true, y_prob, metrics=["accuracy", "miscoverage_ps"], y_predset=extra["y_predset"] ) predset = extra["y_predset"] @@ -259,6 +282,10 @@ def main() -> None: orig_stdout, orig_stderr = sys.stdout, sys.stderr log_file = None if args.log_file: + p = Path(args.log_file) + if "/" not in args.log_file and not p.is_absolute(): + Path("logs").mkdir(parents=True, exist_ok=True) + args.log_file = str(Path("logs") / args.log_file) log_file = open(args.log_file, "w", encoding="utf-8") sys.stdout = _Tee(orig_stdout, log_file) sys.stderr = _Tee(orig_stderr, log_file) @@ -274,28 +301,67 @@ def main() -> None: def _run(args: argparse.Namespace) -> None: device = args.device or ("cuda:0" if torch.cuda.is_available() else "cpu") - root = Path(args.root) + dataset_name = getattr(args, "dataset", "tuev") + root = Path(args.root or DEFAULT_ROOT[dataset_name]) + # If --root was passed but not --dataset, infer dataset from path so TUAB root => TUAB task + if args.root is not None and "--dataset" not in sys.argv: + root_str = str(root).lower() + if "abnormal" in root_str or "tuab" in root_str: + dataset_name = "tuab" + elif "events" in root_str or "tuev" in root_str: + dataset_name = "tuev" if not root.exists(): - raise FileNotFoundError( - f"TUEV root not found: {root}. " - "Pass --root to point to your downloaded TUEV edf/ directory." - ) + raise FileNotFoundError(f"Dataset root not found: {root}. Set --root for {dataset_name}.") + + if args.model.lower() == "tfm": + if not getattr(args, "tfm_tokenizer_checkpoint", None): + args.tfm_tokenizer_checkpoint = "/srv/local/data/arjunc4/tfm_tokenizer_last.pth" + if not getattr(args, "tfm_classifier_checkpoint", None): + args.tfm_classifier_checkpoint = ( + "/srv/local/data/arjunc4/TFM_Tokenizer_multiple_finetuned_on_TUAB/TFM_Tokenizer_multiple_finetuned_on_TUAB_{seed}/best_model.pth" + if dataset_name == "tuab" + else "/srv/local/data/arjunc4/TFM_Tokenizer_multiple_finetuned_on_TUEV/TFM_Tokenizer_multiple_finetuned_on_TUEV_{seed}/best_model.pth" + ) - epochs = 2 if args.quick_test else args.epochs - quick_test_max_samples = 2000 # cap samples so quick-test finishes in ~5-10 min + if args.quick_test: + epochs = 2 + elif args.model.lower() == "tfm" and ( + getattr(args, "tfm_checkpoint", None) + or (getattr(args, "tfm_tokenizer_checkpoint", None) and getattr(args, "tfm_classifier_checkpoint", None)) + ): + epochs = getattr(args, "tfm_epochs", 5) + else: + epochs = args.epochs + quick_test_max_samples = 2000 if args.quick_test: print("*** QUICK TEST MODE (dev=True, 2 epochs, max 2000 samples) ***") + task_mode = "binary" if dataset_name == "tuab" else "multiclass" print("=" * 80) - print("STEP 1: Load TUEV + build task dataset") + print(f"STEP 1: Load {dataset_name.upper()} + build task dataset") print("=" * 80) - dataset = TUEVDataset(root=str(root), subset=args.subset, dev=args.quick_test) - sample_dataset = dataset.set_task(EEGEventsTUEV(), cache_dir="examples/conformal_eeg/cache") + cache_base = getattr(args, "cache_dir", None) + if dataset_name == "tuab": + cache_dir = (cache_base.rstrip("/") + "_tuab") if cache_base else "examples/conformal_eeg/cache_tuab" + dataset = TUABDataset(root=str(root), subset=args.subset, dev=args.quick_test) + sample_dataset = dataset.set_task(EEGAbnormalTUAB(), cache_dir=cache_dir) + else: + cache_dir = cache_base or "examples/conformal_eeg/cache" + dataset = TUEVDataset(root=str(root), subset=args.subset, dev=args.quick_test) + sample_dataset = dataset.set_task( + EEGEventsTUEV(resample_rate=200), cache_dir=cache_dir + ) if args.quick_test and len(sample_dataset) > quick_test_max_samples: sample_dataset = sample_dataset.subset(range(quick_test_max_samples)) print(f"Capped to {quick_test_max_samples} samples for quick-test.") + if args.model.lower() == "tfm": + # TFM tokenizer needs n_fft=200, hop_length=100 so STFT time = temporal patches + sample_dataset = AddSTFTDataset( + sample_dataset, n_fft=200, hop_length=100 + ) + print("Wrapped dataset with STFT for TFM-Tokenizer.") - print(f"Task samples: {len(sample_dataset)}") + print(f"Task samples: {len(sample_dataset)} (task_mode={task_mode})") print(f"Input schema: {sample_dataset.input_schema}") print(f"Output schema: {sample_dataset.output_schema}") @@ -313,17 +379,24 @@ def _run(args: argparse.Namespace) -> None: ratios = list(args.ratios) use_multi_seed = args.n_seeds > 1 or args.seeds is not None if use_multi_seed: - run_seeds = ( - [int(s.strip()) for s in args.seeds.split(",")] - if args.seeds - else [args.seed + i for i in range(args.n_seeds)] - ) + if args.seeds: + run_seeds = [int(s.strip()) for s in args.seeds.split(",")] + elif getattr(args, "tfm_skip_train", False) and args.model.lower() == "tfm": + run_seeds = list(range(1, 1 + args.n_seeds)) + else: + run_seeds = [args.seed + i for i in range(args.n_seeds)] n_runs = len(run_seeds) print(f" multi_seed: n_runs={n_runs}, run_seeds={run_seeds}, split_seed={args.split_seed} (fixed test set)") print(f"Multi-seed mode: {n_runs} runs (fixed test set), run seeds: {run_seeds}") if not use_multi_seed: - # Single run: original behavior + if args.model.lower() == "tfm": + ckpt = getattr(args, "tfm_checkpoint", None) + if ckpt and "{seed}" in ckpt: + args.tfm_checkpoint = ckpt.replace("{seed}", str(args.seed)) + clf = getattr(args, "tfm_classifier_checkpoint", None) + if clf and "{seed}" in clf: + args.tfm_classifier_checkpoint = clf.replace("{seed}", str(args.seed)) print("\n" + "=" * 80) print("STEP 2: Split train/val/cal/test") print("=" * 80) @@ -345,6 +418,7 @@ def _run(args: argparse.Namespace) -> None: args=args, device=device, epochs=epochs, + task_mode=task_mode, ) print("\n--- Split sizes and seed (for reporting) ---") print(f" train={len(train_ds)}, val={len(val_ds)}, cal={len(cal_ds)}, test={len(test_ds)}, seed={args.seed}") @@ -369,10 +443,16 @@ def _run(args: argparse.Namespace) -> None: print(f"Fixed test set size: {n_test}") accs, coverages, miscoverages, set_sizes = [], [], [], [] + tfm_ckpt_original = getattr(args, "tfm_checkpoint", None) + tfm_classifier_ckpt_original = getattr(args, "tfm_classifier_checkpoint", None) for run_i, run_seed in enumerate(run_seeds): print("\n" + "=" * 80) print(f"Run {run_i + 1} / {n_runs} (seed={run_seed})") print("=" * 80) + if tfm_ckpt_original and "{seed}" in tfm_ckpt_original: + args.tfm_checkpoint = tfm_ckpt_original.replace("{seed}", str(run_seed)) + if tfm_classifier_ckpt_original and "{seed}" in tfm_classifier_ckpt_original: + args.tfm_classifier_checkpoint = tfm_classifier_ckpt_original.replace("{seed}", str(run_seed)) set_seed(run_seed) train_ds, val_ds, cal_ds = _split_remainder_into_train_val_cal( sample_dataset, remainder_indices, ratios, run_seed @@ -388,8 +468,13 @@ def _run(args: argparse.Namespace) -> None: args=args, device=device, epochs=epochs, + task_mode=task_mode, return_metrics=True, ) + if tfm_ckpt_original and "{seed}" in tfm_ckpt_original: + args.tfm_checkpoint = tfm_ckpt_original + if tfm_classifier_ckpt_original and "{seed}" in tfm_classifier_ckpt_original: + args.tfm_classifier_checkpoint = tfm_classifier_ckpt_original accs.append(metrics["accuracy"]) coverages.append(metrics["coverage"]) miscoverages.append(metrics["miscoverage"]) diff --git a/examples/eeg/tuh_eeg/tuev_eeg_event_classification.ipynb b/examples/eeg/tuh_eeg/tuev_eeg_event_classification.ipynb index 95d7082a6..781b8a8c6 100644 --- a/examples/eeg/tuh_eeg/tuev_eeg_event_classification.ipynb +++ b/examples/eeg/tuh_eeg/tuev_eeg_event_classification.ipynb @@ -16,10 +16,19 @@ "metadata": {}, "outputs": [ { - "name": "stdout", - "output_type": "stream", - "text": [ - "Running on device: cuda\n" + "ename": "AttributeError", + "evalue": "module 'torch' has no attribute 'uint16'", + "output_type": "error", + "traceback": [ + "\u001b[31m---------------------------------------------------------------------------\u001b[39m", + "\u001b[31mAttributeError\u001b[39m Traceback (most recent call last)", + "\u001b[36mCell\u001b[39m\u001b[36m \u001b[39m\u001b[32mIn[1]\u001b[39m\u001b[32m, line 6\u001b[39m\n\u001b[32m 3\u001b[39m \u001b[38;5;28;01mimport\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[34;01mnumpy\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[38;5;28;01mas\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[34;01mnp\u001b[39;00m\n\u001b[32m 4\u001b[39m \u001b[38;5;28;01mimport\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[34;01mtorch\u001b[39;00m\n\u001b[32m----> \u001b[39m\u001b[32m6\u001b[39m \u001b[38;5;28;01mfrom\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[34;01mpyhealth\u001b[39;00m\u001b[34;01m.\u001b[39;00m\u001b[34;01mdatasets\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[38;5;28;01mimport\u001b[39;00m TUEVDataset\n\u001b[32m 7\u001b[39m \u001b[38;5;28;01mfrom\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[34;01mpyhealth\u001b[39;00m\u001b[34;01m.\u001b[39;00m\u001b[34;01mtasks\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[38;5;28;01mimport\u001b[39;00m EEGEventsTUEV\n\u001b[32m 8\u001b[39m \u001b[38;5;28;01mfrom\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[34;01mpyhealth\u001b[39;00m\u001b[34;01m.\u001b[39;00m\u001b[34;01mdatasets\u001b[39;00m\u001b[34;01m.\u001b[39;00m\u001b[34;01msplitter\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[38;5;28;01mimport\u001b[39;00m split_by_sample\n", + "\u001b[36mFile \u001b[39m\u001b[32m~/PyHealth/pyhealth/datasets/__init__.py:49\u001b[39m\n\u001b[32m 41\u001b[39m \u001b[38;5;28;01mimport\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[34;01mwarnings\u001b[39;00m\n\u001b[32m 43\u001b[39m warnings.warn(\n\u001b[32m 44\u001b[39m \u001b[33m\"\u001b[39m\u001b[33mThe SampleSignalDataset class is deprecated and will be removed in a future version.\u001b[39m\u001b[33m\"\u001b[39m,\n\u001b[32m 45\u001b[39m \u001b[38;5;167;01mDeprecationWarning\u001b[39;00m,\n\u001b[32m 46\u001b[39m )\n\u001b[32m---> \u001b[39m\u001b[32m49\u001b[39m \u001b[38;5;28;01mfrom\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[34;01m.\u001b[39;00m\u001b[34;01mbase_dataset\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[38;5;28;01mimport\u001b[39;00m BaseDataset\n\u001b[32m 50\u001b[39m \u001b[38;5;28;01mfrom\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[34;01m.\u001b[39;00m\u001b[34;01mcardiology\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[38;5;28;01mimport\u001b[39;00m CardiologyDataset\n\u001b[32m 51\u001b[39m \u001b[38;5;28;01mfrom\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[34;01m.\u001b[39;00m\u001b[34;01mchestxray14\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[38;5;28;01mimport\u001b[39;00m ChestXray14Dataset\n", + "\u001b[36mFile \u001b[39m\u001b[32m~/PyHealth/pyhealth/datasets/base_dataset.py:18\u001b[39m\n\u001b[32m 15\u001b[39m \u001b[38;5;28;01mimport\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[34;01mmultiprocessing\u001b[39;00m\u001b[34;01m.\u001b[39;00m\u001b[34;01mqueues\u001b[39;00m\n\u001b[32m 16\u001b[39m \u001b[38;5;28;01mimport\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[34;01mshutil\u001b[39;00m\n\u001b[32m---> \u001b[39m\u001b[32m18\u001b[39m \u001b[38;5;28;01mimport\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[34;01mlitdata\u001b[39;00m\n\u001b[32m 19\u001b[39m \u001b[38;5;28;01mfrom\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[34;01mlitdata\u001b[39;00m\u001b[34;01m.\u001b[39;00m\u001b[34;01mstreaming\u001b[39;00m\u001b[34;01m.\u001b[39;00m\u001b[34;01mitem_loader\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[38;5;28;01mimport\u001b[39;00m ParquetLoader\n\u001b[32m 20\u001b[39m \u001b[38;5;28;01mfrom\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[34;01mlitdata\u001b[39;00m\u001b[34;01m.\u001b[39;00m\u001b[34;01mprocessing\u001b[39;00m\u001b[34;01m.\u001b[39;00m\u001b[34;01mdata_processor\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[38;5;28;01mimport\u001b[39;00m in_notebook\n", + "\u001b[36mFile \u001b[39m\u001b[32m/usr/local/Caskroom/miniforge/base/lib/python3.12/site-packages/litdata/__init__.py:16\u001b[39m\n\u001b[32m 13\u001b[39m \u001b[38;5;28;01mimport\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[34;01mwarnings\u001b[39;00m\n\u001b[32m 15\u001b[39m \u001b[38;5;28;01mfrom\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[34;01mlitdata\u001b[39;00m\u001b[34;01m.\u001b[39;00m\u001b[34;01m__about__\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[38;5;28;01mimport\u001b[39;00m * \u001b[38;5;66;03m# noqa: F403\u001b[39;00m\n\u001b[32m---> \u001b[39m\u001b[32m16\u001b[39m \u001b[38;5;28;01mfrom\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[34;01mlitdata\u001b[39;00m\u001b[34;01m.\u001b[39;00m\u001b[34;01mconstants\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[38;5;28;01mimport\u001b[39;00m _LIGHTNING_SDK_AVAILABLE\n\u001b[32m 17\u001b[39m \u001b[38;5;28;01mfrom\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[34;01mlitdata\u001b[39;00m\u001b[34;01m.\u001b[39;00m\u001b[34;01mprocessing\u001b[39;00m\u001b[34;01m.\u001b[39;00m\u001b[34;01mfunctions\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[38;5;28;01mimport\u001b[39;00m \u001b[38;5;28mmap\u001b[39m, merge_datasets, optimize, walk\n\u001b[32m 18\u001b[39m \u001b[38;5;28;01mfrom\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[34;01mlitdata\u001b[39;00m\u001b[34;01m.\u001b[39;00m\u001b[34;01mraw\u001b[39;00m\u001b[34;01m.\u001b[39;00m\u001b[34;01mdataset\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[38;5;28;01mimport\u001b[39;00m StreamingRawDataset\n", + "\u001b[36mFile \u001b[39m\u001b[32m/usr/local/Caskroom/miniforge/base/lib/python3.12/site-packages/litdata/constants.py:84\u001b[39m\n\u001b[32m 60\u001b[39m _LITDATA_DISABLE_VERSION_CHECK = \u001b[38;5;28mint\u001b[39m(os.getenv(\u001b[33m\"\u001b[39m\u001b[33mLITDATA_DISABLE_VERSION_CHECK\u001b[39m\u001b[33m\"\u001b[39m, \u001b[33m\"\u001b[39m\u001b[33m0\u001b[39m\u001b[33m\"\u001b[39m))\n\u001b[32m 62\u001b[39m \u001b[38;5;66;03m# DON'T CHANGE ORDER\u001b[39;00m\n\u001b[32m 63\u001b[39m _TORCH_DTYPES_MAPPING = {\n\u001b[32m 64\u001b[39m \u001b[32m0\u001b[39m: torch.float32,\n\u001b[32m 65\u001b[39m \u001b[32m1\u001b[39m: torch.float,\n\u001b[32m 66\u001b[39m \u001b[32m2\u001b[39m: torch.float64,\n\u001b[32m 67\u001b[39m \u001b[32m3\u001b[39m: torch.double,\n\u001b[32m 68\u001b[39m \u001b[32m4\u001b[39m: torch.complex64,\n\u001b[32m 69\u001b[39m \u001b[32m5\u001b[39m: torch.cfloat,\n\u001b[32m 70\u001b[39m \u001b[32m6\u001b[39m: torch.complex128,\n\u001b[32m 71\u001b[39m \u001b[32m7\u001b[39m: torch.cdouble,\n\u001b[32m 72\u001b[39m \u001b[32m8\u001b[39m: torch.float16,\n\u001b[32m 73\u001b[39m \u001b[32m9\u001b[39m: torch.half,\n\u001b[32m 74\u001b[39m \u001b[32m10\u001b[39m: torch.bfloat16, \u001b[38;5;66;03m# Not supported https://github.com/pytorch/pytorch/issues/110285\u001b[39;00m\n\u001b[32m 75\u001b[39m \u001b[32m11\u001b[39m: torch.uint8,\n\u001b[32m 76\u001b[39m \u001b[32m12\u001b[39m: torch.int8,\n\u001b[32m 77\u001b[39m \u001b[32m13\u001b[39m: torch.int16,\n\u001b[32m 78\u001b[39m \u001b[32m14\u001b[39m: torch.short,\n\u001b[32m 79\u001b[39m \u001b[32m15\u001b[39m: torch.int32,\n\u001b[32m 80\u001b[39m \u001b[32m16\u001b[39m: torch.int,\n\u001b[32m 81\u001b[39m \u001b[32m17\u001b[39m: torch.int64,\n\u001b[32m 82\u001b[39m \u001b[32m18\u001b[39m: torch.long,\n\u001b[32m 83\u001b[39m \u001b[32m19\u001b[39m: torch.bool,\n\u001b[32m---> \u001b[39m\u001b[32m84\u001b[39m \u001b[32m20\u001b[39m: \u001b[43mtorch\u001b[49m\u001b[43m.\u001b[49m\u001b[43muint16\u001b[49m,\n\u001b[32m 85\u001b[39m }\n\u001b[32m 87\u001b[39m _NUMPY_SCTYPES = [ \u001b[38;5;66;03m# All NumPy scalar types from np.core.sctypes.values()\u001b[39;00m\n\u001b[32m 88\u001b[39m np.int8,\n\u001b[32m 89\u001b[39m np.int16,\n\u001b[32m (...)\u001b[39m\u001b[32m 105\u001b[39m np.void,\n\u001b[32m 106\u001b[39m ]\n\u001b[32m 107\u001b[39m _NUMPY_DTYPES_MAPPING: \u001b[38;5;28mdict\u001b[39m[\u001b[38;5;28mint\u001b[39m, np.dtype] = {i: np.dtype(v) \u001b[38;5;28;01mfor\u001b[39;00m i, v \u001b[38;5;129;01min\u001b[39;00m \u001b[38;5;28menumerate\u001b[39m(_NUMPY_SCTYPES)}\n", + "\u001b[36mFile \u001b[39m\u001b[32m/usr/local/Caskroom/miniforge/base/lib/python3.12/site-packages/torch/__init__.py:1938\u001b[39m, in \u001b[36m__getattr__\u001b[39m\u001b[34m(name)\u001b[39m\n\u001b[32m 1935\u001b[39m \u001b[38;5;28;01mimport\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[34;01mimportlib\u001b[39;00m\n\u001b[32m 1936\u001b[39m \u001b[38;5;28;01mreturn\u001b[39;00m importlib.import_module(\u001b[33mf\u001b[39m\u001b[33m\"\u001b[39m\u001b[33m.\u001b[39m\u001b[38;5;132;01m{\u001b[39;00mname\u001b[38;5;132;01m}\u001b[39;00m\u001b[33m\"\u001b[39m, \u001b[34m__name__\u001b[39m)\n\u001b[32m-> \u001b[39m\u001b[32m1938\u001b[39m \u001b[38;5;28;01mraise\u001b[39;00m \u001b[38;5;167;01mAttributeError\u001b[39;00m(\u001b[33mf\u001b[39m\u001b[33m\"\u001b[39m\u001b[33mmodule \u001b[39m\u001b[33m'\u001b[39m\u001b[38;5;132;01m{\u001b[39;00m\u001b[34m__name__\u001b[39m\u001b[38;5;132;01m}\u001b[39;00m\u001b[33m'\u001b[39m\u001b[33m has no attribute \u001b[39m\u001b[33m'\u001b[39m\u001b[38;5;132;01m{\u001b[39;00mname\u001b[38;5;132;01m}\u001b[39;00m\u001b[33m'\u001b[39m\u001b[33m\"\u001b[39m)\n", + "\u001b[31mAttributeError\u001b[39m: module 'torch' has no attribute 'uint16'" ] } ], @@ -56,7 +65,7 @@ }, { "cell_type": "code", - "execution_count": 3, + "execution_count": null, "id": "d1230c58", "metadata": {}, "outputs": [ @@ -115,7 +124,7 @@ }, { "cell_type": "code", - "execution_count": 4, + "execution_count": null, "id": "66f68916", "metadata": {}, "outputs": [ @@ -251,7 +260,7 @@ }, { "cell_type": "code", - "execution_count": 5, + "execution_count": null, "id": "c01a076f", "metadata": {}, "outputs": [ @@ -288,7 +297,7 @@ }, { "cell_type": "code", - "execution_count": 6, + "execution_count": null, "id": "1d490449", "metadata": {}, "outputs": [ @@ -332,7 +341,7 @@ }, { "cell_type": "code", - "execution_count": 9, + "execution_count": null, "id": "7236ddc0", "metadata": {}, "outputs": [ @@ -390,7 +399,7 @@ }, { "cell_type": "code", - "execution_count": 10, + "execution_count": null, "id": "11d7f9c5", "metadata": {}, "outputs": [ @@ -427,7 +436,7 @@ }, { "cell_type": "code", - "execution_count": 11, + "execution_count": null, "id": "5521de25", "metadata": {}, "outputs": [], @@ -447,7 +456,7 @@ }, { "cell_type": "code", - "execution_count": 12, + "execution_count": null, "id": "0c14a78d", "metadata": {}, "outputs": [ @@ -518,7 +527,7 @@ }, { "cell_type": "code", - "execution_count": 13, + "execution_count": null, "id": "bbd0eb33", "metadata": {}, "outputs": [ @@ -533,8 +542,8 @@ "source": [ "model.eval()\n", "test_loss = 0.0\n", - "correct = 0\n", - "total = 0\n", + "all_y_true = []\n", + "all_y_prob = []\n", "with torch.no_grad():\n", " for batch in test_loader:\n", " signals = batch['signal'].to(device)\n", @@ -542,11 +551,22 @@ " outputs = model(signals)\n", " loss = criterion(outputs, labels)\n", " test_loss += loss.item()\n", - " predicted = torch.argmax(outputs, dim=1)\n", - " total += labels.size(0)\n", - " correct += (predicted == labels).sum().item()\n", + " probs = torch.softmax(outputs, dim=1)\n", + " all_y_true.append(labels.cpu().numpy())\n", + " all_y_prob.append(probs.cpu().numpy())\n", + "\n", + "y_true = np.concatenate(all_y_true, axis=0)\n", + "y_prob = np.concatenate(all_y_prob, axis=0)\n", + "print(f\"Test Loss: {test_loss/len(test_loader):.4f}\")\n", "\n", - "print(f\"Test Loss: {test_loss/len(test_loader):.4f}, Accuracy: {100 * correct / total:.2f}%\")" + "from pyhealth.metrics import multiclass_metrics_fn\n", + "metrics = multiclass_metrics_fn(\n", + " y_true, y_prob,\n", + " metrics=[\"accuracy\", \"balanced_accuracy\", \"f1_macro\", \"f1_micro\", \"cohen_kappa\"],\n", + ")\n", + "print(\"Test set metrics (PyHealth):\")\n", + "for name, value in metrics.items():\n", + " print(f\" {name}: {value:.4f}\")" ] }, { @@ -560,9 +580,9 @@ ], "metadata": { "kernelspec": { - "display_name": "pyhealth", + "display_name": "base", "language": "python", - "name": "pyhealth" + "name": "python3" }, "language_info": { "codemirror_mode": { @@ -574,7 +594,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.12.12" + "version": "3.12.11" } }, "nbformat": 4, diff --git a/legend_header.png b/legend_header.png new file mode 100644 index 000000000..4fcbebba0 Binary files /dev/null and b/legend_header.png differ diff --git a/plottfm.py b/plottfm.py new file mode 100644 index 000000000..e8d6cbaa5 --- /dev/null +++ b/plottfm.py @@ -0,0 +1,114 @@ +import matplotlib.pyplot as plt +import matplotlib.lines as mlines +import numpy as np + +# ========================================== +# 1. Configuration and Data Setup +# ========================================== + +# --- LATEX & MATPLOTLIB STYLE --- +plt.rcParams.update({ + "text.usetex": True, + "font.family": "serif", + "font.serif": ["Computer Modern Roman"], + "axes.labelsize": 14, + "font.size": 12, + "xtick.labelsize": 12, + "ytick.labelsize": 12, + "axes.titlesize": 16, + "legend.fontsize": 13, + "axes.grid": True, + "grid.alpha": 0.3, + "grid.linestyle": "--", + "lines.linewidth": 2.5, + "lines.markersize": 8, +}) + +# --- DATA POPULATION --- +alphas = [0.01, 0.05, 0.1, 0.2] +datasets_ordered = ["TUEV", "TUAB"] +methods_ordered = ["NCP", "KDE CP", "Naive CP", "KMeans CP"] + +# Extracted from provided logs +full_data = { + "TUEV": { + "KDE CP": {0.01: (0.9723, 0.0088, 1.90, 0.56), 0.05: (0.8718, 0.0161, 1.37, 0.22), 0.1: (0.7457, 0.0072, 1.11, 0.14), 0.2: (0.5415, 0.0297, 0.89, 0.10)}, + "KMeans CP": {0.01: (0.9721, 0.0095, 1.79, 0.44), 0.05: (0.8549, 0.0219, 1.33, 0.27), 0.1: (0.7488, 0.0152, 1.23, 0.24), 0.2: (0.6328, 0.0114, 0.93, 0.16)}, + "Naive CP": {0.01: (0.9744, 0.0069, 1.79, 0.22), 0.05: (0.8749, 0.0139, 1.31, 0.22), 0.1: (0.7461, 0.0334, 1.11, 0.21), 0.2: (0.5275, 0.0249, 0.90, 0.12)}, + "NCP": {0.01: (0.9617, 0.0057, 1.35, 0.16), 0.05: (0.9365, 0.0083, 1.28, 0.15), 0.1: (0.9152, 0.0117, 1.25, 0.13), 0.2: (0.8722, 0.0052, 1.22, 0.17)} + }, + "TUAB": { + "KDE CP": {0.01: (0.9729, 0.0065, 1.82, 0.36), 0.05: (0.8656, 0.0163, 1.30, 0.21), 0.1: (0.7447, 0.0217, 1.11, 0.19), 0.2: (0.5517, 0.0162, 0.87, 0.08)}, + "KMeans CP": {0.01: (0.9703, 0.0110, 1.89, 0.48), 0.05: (0.8652, 0.0072, 1.36, 0.25), 0.1: (0.7262, 0.0178, 1.17, 0.25), 0.2: (0.6161, 0.0307, 0.94, 0.15)}, + "Naive CP": {0.01: (0.9715, 0.0090, 1.97, 0.37), 0.05: (0.8691, 0.0189, 1.36, 0.27), 0.1: (0.7532, 0.0165, 1.11, 0.18), 0.2: (0.5380, 0.0248, 0.90, 0.10)}, + "NCP": {0.01: (0.9659, 0.0050, 1.38, 0.20), 0.05: (0.9348, 0.0113, 1.30, 0.18), 0.1: (0.9155, 0.0094, 1.24, 0.18), 0.2: (0.8713, 0.0098, 1.18, 0.13)} + } +} + +# --- STYLING (Solid lines for methods, dashed for target) --- +style_map = { + "NCP": {'color': '#4e6386', 'marker': 'o'}, # Blue-Gray + "KDE CP": {'color': '#7393B3', 'marker': 's'}, # Lighter Blue-Gray + "Naive CP": {'color': '#D55E00', 'marker': '^'}, # Vermillion Orange + "KMeans CP": {'color': '#E69F00', 'marker': 'D'} # Light Orange +} +target_style = {'color': 'black', 'ls': '--', 'lw': 2.0} + + +# ========================================== +# 2. Main Plotting Function +# ========================================== +def generate_1x2_plot(metric_idx, ylabel, filename, include_target=False): + fig, axes = plt.subplots(1, 2, figsize=(14, 7), sharey=True, sharex=True) + + for i, dset in enumerate(datasets_ordered): + ax = axes[i] + + # 1. Target Line (Dashed) + if include_target: + target_cov = [1 - a for a in alphas] + ax.plot(alphas, target_cov, **target_style, zorder=1) + + # 2. CP Methods (All Solid) + for method in methods_ordered: + s = style_map[method] + means = np.array([full_data[dset][method][a][metric_idx] for a in alphas]) + stds = np.array([full_data[dset][method][a][metric_idx + 1] for a in alphas]) + + ax.fill_between(alphas, means - stds, means + stds, + color=s['color'], alpha=0.15, zorder=2) + + # Using linestyle='-' for all methods + ax.plot(alphas, means, color=s['color'], linestyle='-', + marker=s['marker'], label=method, zorder=3) + + ax.set_title(rf"\textbf{{{dset} (ContraWR)}}") + if i == 0: ax.set_ylabel(ylabel) + ax.set_xlabel(r"Significance Level ($\alpha$)") + ax.set_xticks(alphas) + + # 3. Legend Header + handles = [] + if include_target: + handles.append(mlines.Line2D([], [], **target_style, label=r'Target ($1-\alpha$)')) + + for method in methods_ordered: + s = style_map[method] + handles.append(mlines.Line2D([], [], color=s['color'], linestyle='-', + marker=s['marker'], markersize=10, + linewidth=3, label=method)) + + fig.legend(handles=handles, loc='lower center', bbox_to_anchor=(0.5, 0.91), + ncol=len(handles), frameon=False, handlelength=3) + + plt.tight_layout(rect=[0, 0.0, 1, 0.91]) + plt.savefig(filename, dpi=300, bbox_inches='tight') + print(f"Generated: {filename}") + +if __name__ == "__main__": + # Coverage Plot + generate_1x2_plot(0, "Empirical Coverage", "contrawr_coverage_solid.png", True) + # Set Size Plot + generate_1x2_plot(2, "Avg. Prediction Set Size", "contrawr_setsize_solid.png", False) + + plt.show() \ No newline at end of file diff --git a/pyhealth/calib/predictionset/base_conformal/__init__.py b/pyhealth/calib/predictionset/base_conformal/__init__.py index 8e37c0d6d..5b12d4c3b 100644 --- a/pyhealth/calib/predictionset/base_conformal/__init__.py +++ b/pyhealth/calib/predictionset/base_conformal/__init__.py @@ -153,9 +153,9 @@ def __init__( ) -> None: super().__init__(model, **kwargs) - if model.mode != "multiclass": + if model.mode not in ("multiclass", "binary"): raise NotImplementedError( - "BaseConformal only supports multiclass classification" + "BaseConformal only supports multiclass or binary classification" ) self.mode = self.model.mode @@ -190,10 +190,16 @@ def _compute_conformity_scores( Conformity scores of shape (N,) """ N = len(y_true) + # Ensure integer indices (y_true can be float e.g. 0.0/1.0 from binary tasks) + y_true = np.asarray(y_true, dtype=np.int64) if self.score_type == "aps" or self.score_type == "threshold": # Use probability of true class as conformity score - # Higher score = more conforming (better prediction) - scores = y_prob[np.arange(N), y_true] + if y_prob.shape[1] == 1: + # Binary: y_prob is (N, 1) for positive class; P(y=0)=1-p, P(y=1)=p + p1 = np.asarray(y_prob[:, 0], dtype=np.float64).ravel() + scores = np.where(y_true == 1, p1, 1.0 - p1) + else: + scores = y_prob[np.arange(N), y_true] else: raise ValueError(f"Unknown score_type: {self.score_type}") @@ -222,8 +228,8 @@ def calibrate(self, cal_dataset: IterableDataset): # Compute quantile thresholds if isinstance(self.alpha, float): - # Marginal coverage: single threshold - t = _query_quantile(conformity_scores, self.alpha) + # Marginal coverage: single scalar threshold + t = float(_query_quantile(np.asarray(conformity_scores).ravel(), self.alpha)) else: # Class-conditional coverage: one threshold per class if len(self.alpha) != K: @@ -266,10 +272,18 @@ def forward(self, **kwargs) -> Dict[str, torch.Tensor]: ) pred = self.model(**kwargs) - - # Construct prediction set by thresholding probabilities - # Include classes with probability >= threshold - pred["y_predset"] = pred["y_prob"] >= self.t + y_prob = pred["y_prob"] + + # Binary models output (N, 1) for positive class; expand to (N, 2) for set construction + if y_prob.shape[-1] == 1: + p1 = y_prob.squeeze(-1).clamp(0.0, 1.0) + y_prob = torch.stack([1.0 - p1, p1], dim=-1) + + # Broadcast threshold to (1, K) so (batch, K) >= (1, K) works; scalar stays scalar + th = self.t.to(device=y_prob.device, dtype=y_prob.dtype) + if th.dim() >= 1 and th.numel() > 1: + th = th.view(1, -1) + pred["y_predset"] = y_prob >= th return pred diff --git a/pyhealth/calib/predictionset/cluster/cluster_label.py b/pyhealth/calib/predictionset/cluster/cluster_label.py index a29bdb854..60ccaaab1 100644 --- a/pyhealth/calib/predictionset/cluster/cluster_label.py +++ b/pyhealth/calib/predictionset/cluster/cluster_label.py @@ -102,9 +102,9 @@ def __init__( ) -> None: super().__init__(model, **kwargs) - if model.mode != "multiclass": + if model.mode not in ("multiclass", "binary"): raise NotImplementedError( - "ClusterLabel only supports multiclass classification" + "ClusterLabel only supports multiclass or binary classification" ) self.mode = self.model.mode @@ -176,6 +176,16 @@ def calibrate( y_true = cal_dataset_dict["y_true"] N, K = y_prob.shape + # Binary: model outputs (N, 1); treat as K=2 for conformity and thresholds + if K == 1: + y_true = np.asarray(y_true).ravel().astype(np.intp) + p1 = np.asarray(y_prob[:, 0], dtype=np.float64).ravel() + conformity_scores = np.where(y_true == 1, p1, 1.0 - p1) + K = 2 + else: + y_true = np.asarray(y_true).ravel().astype(np.intp) + conformity_scores = y_prob[np.arange(N), y_true] + # Extract embeddings if not provided if cal_embeddings is None: print("Extracting embeddings from calibration set...") @@ -193,19 +203,31 @@ def calibrate( else: train_embeddings = np.asarray(train_embeddings) + # Flatten to 2D (n_samples, n_features) so KMeans works with 3D embeddings (e.g. TFM) + def _flatten_emb(emb): + emb = np.asarray(emb) + if emb.ndim <= 2: + return emb.reshape(emb.shape[0], -1) if emb.ndim == 2 else emb.reshape(-1, 1) + return emb.reshape(emb.shape[0], -1) + + train_embeddings = _flatten_emb(train_embeddings) + cal_embeddings = _flatten_emb(cal_embeddings) + # Combine embeddings for clustering print(f"Combining embeddings: train={train_embeddings.shape}, cal={cal_embeddings.shape}") all_embeddings = np.concatenate([train_embeddings, cal_embeddings], axis=0) print(f"Total embeddings for clustering: {all_embeddings.shape}") - # Fit K-means on combined embeddings - print(f"Fitting K-means with {self.n_clusters} clusters...") + # Fit K-means on combined embeddings (verbose=1 so long runs show progress) + print(f"Fitting K-means with {self.n_clusters} clusters (n_init=10, may take a while for large N)...") self.kmeans_model = KMeans( n_clusters=self.n_clusters, random_state=self.random_state, n_init=10, + verbose=1, ) self.kmeans_model.fit(all_embeddings) + print(" K-means fit done.") # Assign calibration samples to clusters # Note: cal_embeddings start at index len(train_embeddings) in all_embeddings @@ -214,10 +236,10 @@ def calibrate( print(f"Cluster assignments: {np.bincount(cal_cluster_labels)}") - # Compute conformity scores (probabilities of true class) - conformity_scores = y_prob[np.arange(N), y_true] + # Conformity scores already set above (with binary handling) # Compute cluster-specific thresholds + print(f"Computing cluster-specific thresholds for {self.n_clusters} clusters...") self.cluster_thresholds = {} for cluster_id in range(self.n_clusters): cluster_mask = cal_cluster_labels == cluster_id @@ -262,6 +284,8 @@ def calibrate( t.append(t_k) self.cluster_thresholds[cluster_id] = np.array(t) + print(" Cluster thresholds computed.") + if self.debug: print(f"Cluster thresholds: {self.cluster_thresholds}") @@ -289,9 +313,12 @@ def forward(self, **kwargs) -> Dict[str, torch.Tensor]: "embed=True flag in its forward() method." ) - # Ensure embeddings are 2D (batch_size, embedding_dim) + # Flatten to 2D (batch_size, n_features) so KMeans.predict works with 3D embeddings sample_embedding = pred["embed"].detach().cpu().numpy() - sample_embedding = np.atleast_2d(sample_embedding) + if sample_embedding.ndim == 1: + sample_embedding = sample_embedding.reshape(1, -1) + elif sample_embedding.ndim > 2: + sample_embedding = sample_embedding.reshape(sample_embedding.shape[0], -1) # Predict cluster for each sample in the batch cluster_ids = self.kmeans_model.predict(sample_embedding) @@ -300,20 +327,27 @@ def forward(self, **kwargs) -> Dict[str, torch.Tensor]: cluster_thresholds = np.array( [self.cluster_thresholds[cid] for cid in cluster_ids] ) + y_prob = pred["y_prob"] + + # Binary: expand (batch, 1) to (batch, 2) only for set construction; keep pred["y_prob"] as-is + if y_prob.shape[-1] == 1: + p1 = y_prob.squeeze(-1).clamp(0.0, 1.0) + y_prob_2 = torch.stack([1.0 - p1, p1], dim=-1) + else: + y_prob_2 = y_prob + cluster_thresholds = torch.as_tensor( - cluster_thresholds, device=self.device, dtype=pred["y_prob"].dtype + cluster_thresholds, device=self.device, dtype=y_prob_2.dtype ) # Broadcast thresholds to match y_prob shape (batch_size, n_classes). - # Marginal: thresholds are (batch_size,) -> view to (batch_size, 1, ...). - # Class-conditional: thresholds are already (batch_size, K), no view. - if pred["y_prob"].ndim > 1 and cluster_thresholds.ndim == 1: + if y_prob_2.ndim > 1 and cluster_thresholds.ndim == 1: view_shape = (cluster_thresholds.shape[0],) + (1,) * ( - pred["y_prob"].ndim - 1 + y_prob_2.ndim - 1 ) cluster_thresholds = cluster_thresholds.view(view_shape) - pred["y_predset"] = pred["y_prob"] >= cluster_thresholds + pred["y_predset"] = y_prob_2 >= cluster_thresholds pred.pop("embed", None) # do not expose internal embedding to caller return pred diff --git a/pyhealth/calib/predictionset/cluster/neighborhood_label.py b/pyhealth/calib/predictionset/cluster/neighborhood_label.py index 2d9f2dc6d..c9ea2e2c6 100644 --- a/pyhealth/calib/predictionset/cluster/neighborhood_label.py +++ b/pyhealth/calib/predictionset/cluster/neighborhood_label.py @@ -74,9 +74,9 @@ def __init__( ) -> None: super().__init__(model, **kwargs) - if model.mode != "multiclass": + if model.mode not in ("multiclass", "binary"): raise NotImplementedError( - "NeighborhoodLabel only supports multiclass classification" + "NeighborhoodLabel only supports multiclass or binary classification" ) self.mode = self.model.mode @@ -134,6 +134,14 @@ def calibrate( y_prob = cal_dict["y_prob"] y_true = cal_dict["y_true"] N = y_prob.shape[0] + y_true = np.asarray(y_true).ravel().astype(np.intp) + + # Binary: model outputs (N, 1); conformity = prob of true class + if y_prob.shape[1] == 1: + p1 = np.asarray(y_prob[:, 0], dtype=np.float64).ravel() + conformity_scores = np.where(y_true == 1, p1, 1.0 - p1) + else: + conformity_scores = y_prob[np.arange(N), y_true] if cal_embeddings is None: cal_embeddings = extract_embeddings( @@ -148,16 +156,24 @@ def calibrate( f"cal_dataset size {N}" ) - conformity_scores = y_prob[np.arange(N), y_true] + # Flatten to 2D (n_samples, n_features) so NearestNeighbors works with 3D embeddings (e.g. TFM) + cal_embeddings = np.asarray(cal_embeddings) + if cal_embeddings.ndim > 2: + cal_embeddings = cal_embeddings.reshape(cal_embeddings.shape[0], -1) + elif cal_embeddings.ndim == 1: + cal_embeddings = cal_embeddings.reshape(-1, 1) k = min(self.k_neighbors, N) + print(" Fitting k-NN on calibration set (can be slow for large N)...") self._nn = NearestNeighbors(n_neighbors=k, metric="euclidean").fit( - np.atleast_2d(cal_embeddings) + cal_embeddings ) - self.cal_embeddings_ = np.atleast_2d(cal_embeddings) - self.cal_conformity_scores_ = np.asarray(conformity_scores, dtype=np.float64) + self.cal_embeddings_ = cal_embeddings + self.cal_conformity_scores_ = np.asarray( + conformity_scores, dtype=np.float64 + ).ravel() - # this is the ncp calibration step + print(" Computing k-NN neighbors...") distances_cal, indices_cal = self._nn.kneighbors( self.cal_embeddings_, n_neighbors=k ) @@ -174,13 +190,16 @@ def _empirical_coverage(alpha_tilde_cand: float) -> float: ) return float(np.mean(self.cal_conformity_scores_ >= t_all)) + print(" Calibrating alpha_tilde (binary search, 50 iters)...") low, high = 0.0, 1.0 - for _ in range(50): + for it in range(50): mid = (low + high) / 2 if _empirical_coverage(mid) >= 1.0 - self.alpha: low = mid else: high = mid + if (it + 1) % 10 == 0: + print(f" iteration {it + 1}/50, alpha_tilde in [{low:.4f}, {high:.4f}]") self.alpha_tilde_ = float(low) def forward(self, **kwargs) -> Dict[str, torch.Tensor]: @@ -203,7 +222,11 @@ def forward(self, **kwargs) -> Dict[str, torch.Tensor]: ) test_emb = pred["embed"].detach().cpu().numpy() - test_emb = np.atleast_2d(test_emb) + # Flatten to 2D (batch_size, n_features) to match calibration embeddings + if test_emb.ndim == 1: + test_emb = test_emb.reshape(1, -1) + elif test_emb.ndim > 2: + test_emb = test_emb.reshape(test_emb.shape[0], -1) batch_size = test_emb.shape[0] n_cal = self.cal_conformity_scores_.shape[0] k = min(self.k_neighbors, n_cal) @@ -218,16 +241,24 @@ def forward(self, **kwargs) -> Dict[str, torch.Tensor]: scores_i, self.alpha_tilde_, w ) + y_prob = pred["y_prob"] + # Binary: expand (batch, 1) to (batch, 2) only for set construction; keep pred["y_prob"] as-is + if y_prob.shape[-1] == 1: + p1 = y_prob.squeeze(-1).clamp(0.0, 1.0) + y_prob_2 = torch.stack([1.0 - p1, p1], dim=-1) + else: + y_prob_2 = y_prob + th = torch.as_tensor( - thresholds, device=self.device, dtype=pred["y_prob"].dtype + thresholds, device=self.device, dtype=y_prob_2.dtype ) - if pred["y_prob"].ndim > 1: - th = th.view(-1, *([1] * (pred["y_prob"].ndim - 1))) - y_predset = pred["y_prob"] >= th + if y_prob_2.ndim > 1: + th = th.view(-1, *([1] * (y_prob_2.ndim - 1))) + y_predset = y_prob_2 >= th # if threshold is high, include at least argmax empty = y_predset.sum(dim=1) == 0 if empty.any(): - argmax_idx = pred["y_prob"].argmax(dim=1) + argmax_idx = y_prob_2.argmax(dim=1) y_predset[empty, argmax_idx[empty]] = True pred["y_predset"] = y_predset pred.pop("embed", None) diff --git a/pyhealth/calib/predictionset/covariate/covariate_label.py b/pyhealth/calib/predictionset/covariate/covariate_label.py index f90430be8..ce3e0f5c1 100644 --- a/pyhealth/calib/predictionset/covariate/covariate_label.py +++ b/pyhealth/calib/predictionset/covariate/covariate_label.py @@ -96,55 +96,68 @@ def fit_kde( if kernel != "rbf": raise ValueError(f"Only 'rbf' kernel supported, got {kernel}") - # Calculate bandwidth if needed + # Calculate bandwidth if needed (embeddings may be 1D or 3D; flatten to (n_samples, n_features)) def get_bandwidth(embeddings, bw): if isinstance(bw, str): - n_samples, n_features = embeddings.shape + emb = np.asarray(embeddings) + if emb.ndim == 1: + emb = emb.reshape(-1, 1) + else: + emb = emb.reshape(emb.shape[0], -1) + n_samples, n_features = emb.shape if bw == "scott": return n_samples ** (-1.0 / (n_features + 4)) else: raise ValueError(f"Unknown bandwidth method: {bw}") return bw - # Convert to torch tensors - cal_emb_torch = torch.from_numpy(cal_embeddings).float() - test_emb_torch = torch.from_numpy(test_embeddings).float() + # Convert to torch tensors; flatten to (n_samples, n_features) for KDE + def _flatten_emb(emb): + emb = np.asarray(emb) + if emb.ndim == 1: + return emb.reshape(-1, 1) + return emb.reshape(emb.shape[0], -1) + + cal_emb_2d = _flatten_emb(cal_embeddings) + test_emb_2d = _flatten_emb(test_embeddings) + n_cal, n_test = cal_emb_2d.shape[0], test_emb_2d.shape[0] + print(f" Calibration embeddings: {n_cal} x {cal_emb_2d.shape[1]}, test: {n_test} x {test_emb_2d.shape[1]}") + cal_emb_torch = torch.from_numpy(cal_emb_2d).float() + test_emb_torch = torch.from_numpy(test_emb_2d).float() # Fit KDE on calibration embeddings + print(" Computing bandwidth and building calibration KDE...") cal_bw = get_bandwidth(cal_embeddings, bandwidth) kern_cal = RBFKernelMean(h=cal_bw) # Fit KDE on test embeddings + print(" Computing bandwidth and building test KDE...") test_bw = get_bandwidth(test_embeddings, bandwidth) kern_test = RBFKernelMean(h=test_bw) - # Create callable functions that compute density - def kde_cal(data): - """Compute density using calibration KDE.""" - if not isinstance(data, torch.Tensor): - data = torch.from_numpy(np.array(data)).float() + # Create callable functions that compute density (flatten to 2D so 3D embeddings work) + def _to_2d_tensor(data): + data = np.asarray(data) if data.ndim == 1: - data = data.unsqueeze(0) + data = data.reshape(-1, 1) + else: + data = data.reshape(data.shape[0], -1) + return torch.from_numpy(data).float() - # Compute kernel values and average (density estimate) + def kde_cal(data): + """Compute density using calibration KDE.""" + data = _to_2d_tensor(data) with torch.no_grad(): K = kern_cal(data, cal_emb_torch) # (n_query, n_cal) density = K.mean(dim=1) # Average over calibration points - return density.numpy() def kde_test(data): """Compute density using test KDE.""" - if not isinstance(data, torch.Tensor): - data = torch.from_numpy(np.array(data)).float() - if data.ndim == 1: - data = data.unsqueeze(0) - - # Compute kernel values and average (density estimate) + data = _to_2d_tensor(data) with torch.no_grad(): K = kern_test(data, test_emb_torch) # (n_query, n_test) density = K.mean(dim=1) # Average over test points - return density.numpy() return kde_cal, kde_test @@ -319,9 +332,9 @@ def __init__( ) -> None: super().__init__(model, **kwargs) - if model.mode != "multiclass": + if model.mode not in ("multiclass", "binary"): raise NotImplementedError( - "CovariateLabel only supports multiclass classification" + "CovariateLabel only supports multiclass or binary classification" ) self.mode = self.model.mode @@ -414,6 +427,16 @@ def calibrate( y_true = cal_dataset_dict["y_true"] N, K = y_prob.shape + # Binary: model outputs (N, 1) for positive class; treat as K=2 + if K == 1: + y_true = np.asarray(y_true).ravel().astype(np.intp) + p1 = np.asarray(y_prob[:, 0], dtype=np.float64).ravel() + conformity_scores = np.where(y_true == 1, p1, 1.0 - p1) + K = 2 + else: + y_true = np.asarray(y_true).ravel().astype(np.intp) + conformity_scores = y_prob[np.arange(N), y_true] + # Determine weights: either custom or KDE-based if cal_weights is not None: # Use custom weights provided by user @@ -452,18 +475,19 @@ def calibrate( "Please provide cal_embeddings and test_embeddings." ) - # Compute likelihood ratios using KDE - print("Computing likelihood ratios via KDE...") + # Compute likelihood ratios using KDE (can be slow for large N) + n_cal = X.shape[0] if hasattr(X, "shape") else len(X) + print(f"Computing likelihood ratios via KDE (evaluating on {n_cal} calibration points)...") likelihood_ratios = _compute_likelihood_ratio( self.kde_test, self.kde_cal, X ) + print(" Likelihood ratios computed.") # Normalize weights weights = likelihood_ratios / np.sum(likelihood_ratios) self._sum_cal_weights = np.sum(likelihood_ratios) - # Extract conformity scores (probabilities of true class) - conformity_scores = y_prob[np.arange(N), y_true] + # Conformity scores already set above (with binary handling) # Compute weighted quantile thresholds if isinstance(self.alpha, float): @@ -498,10 +522,20 @@ def forward(self, **kwargs) -> Dict[str, torch.Tensor]: are in the prediction set """ pred = self.model(**kwargs) + y_prob = pred["y_prob"] - # Construct prediction set by thresholding probabilities - pred["y_predset"] = pred["y_prob"] > self.t + # Binary: expand (batch, 1) to (batch, 2) only for set construction; keep pred["y_prob"] as-is for metrics + if y_prob.shape[-1] == 1: + p1 = y_prob.squeeze(-1).clamp(0.0, 1.0) + y_prob_2 = torch.stack([1.0 - p1, p1], dim=-1) + else: + y_prob_2 = y_prob + # Broadcast threshold for (batch, K) + th = self.t.to(device=y_prob_2.device, dtype=y_prob_2.dtype) + if th.dim() >= 1 and th.numel() > 1: + th = th.view(1, -1) + pred["y_predset"] = y_prob_2 > th return pred diff --git a/pyhealth/calib/utils.py b/pyhealth/calib/utils.py index fd7ffffc0..ebc0ca094 100644 --- a/pyhealth/calib/utils.py +++ b/pyhealth/calib/utils.py @@ -134,11 +134,29 @@ def extract_embeddings(model, dataset, batch_size=32, device="cpu"): loader = datautils.get_dataloader(dataset, batch_size=batch_size, shuffle=False) all_embeddings = [] + try: + n_batches = len(loader) + except TypeError: + n_batches = None + model.eval() model.to(device) + # Print immediately so logs show we're not stuck (first batch can be slow) + if n_batches is not None: + print(f" Extracting embeddings over {n_batches} batches...", flush=True) + else: + print(" Extracting embeddings (processing batches)...", flush=True) + + # Progress every N batches so logs show movement when stdout is redirected + log_interval = 25 + if n_batches is not None and n_batches <= 100: + log_interval = max(5, n_batches // 10) + with torch.no_grad(): - for batch in loader: + for batch_idx, batch in enumerate( + tqdm.tqdm(loader, desc="Extracting embeddings", leave=True) + ): # Move batch to device batch_device = { k: v.to(device) if isinstance(v, torch.Tensor) else v @@ -161,4 +179,14 @@ def extract_embeddings(model, dataset, batch_size=32, device="cpu"): embeddings = output["embed"].cpu().numpy() all_embeddings.append(embeddings) + # Periodic print so redirected logs show progress (tqdm often doesn't) + if batch_idx == 0 or (batch_idx + 1) % log_interval == 0: + if n_batches is not None: + print( + f" ... extracted {batch_idx + 1}/{n_batches} batches", + flush=True, + ) + else: + print(f" ... extracted {batch_idx + 1} batches", flush=True) + return np.concatenate(all_embeddings, axis=0) diff --git a/pyhealth/datasets/base_dataset.py b/pyhealth/datasets/base_dataset.py index 7faffef60..b9850bd6b 100644 --- a/pyhealth/datasets/base_dataset.py +++ b/pyhealth/datasets/base_dataset.py @@ -379,7 +379,11 @@ def clean_tmpdir(self) -> None: """Cleans up the temporary directory within the cache.""" tmp_dir = self.cache_dir / "tmp" if tmp_dir.exists(): - shutil.rmtree(tmp_dir) + try: + shutil.rmtree(tmp_dir) + except (FileNotFoundError, OSError): + # Ignore if already removed by another process (e.g. parallel grid jobs) + pass def _scan_csv_tsv_gz( self, source_path: str diff --git a/pyhealth/metrics/binary.py b/pyhealth/metrics/binary.py index ea7d125f7..878693f4f 100644 --- a/pyhealth/metrics/binary.py +++ b/pyhealth/metrics/binary.py @@ -1,9 +1,10 @@ -from typing import Dict, List, Optional +from typing import Dict, List, Optional, Union import numpy as np import sklearn.metrics as sklearn_metrics import pyhealth.metrics.calibration as calib +import pyhealth.metrics.prediction_set as pset def binary_metrics_fn( @@ -11,7 +12,8 @@ def binary_metrics_fn( y_prob: np.ndarray, metrics: Optional[List[str]] = None, threshold: float = 0.5, -) -> Dict[str, float]: + y_predset: Optional[np.ndarray] = None, +) -> Dict[str, Union[float, np.ndarray]]: """Computes metrics for binary classification. User can specify which metrics to compute by passing a list of metric names. @@ -28,6 +30,11 @@ def binary_metrics_fn( - jaccard: Jaccard similarity coefficient score - ECE: Expected Calibration Error (with 20 equal-width bins). Check :func:`pyhealth.metrics.calibration.ece_confidence_binary`. - ECE_adapt: adaptive ECE (with 20 equal-size bins). Check :func:`pyhealth.metrics.calibration.ece_confidence_binary`. + + The following prediction-set metrics are accepted but ignored if y_predset is None: + - rejection_rate, set_size, miscoverage_ps, miscoverage_overall_ps, + error_ps, error_overall_ps (see :mod:`pyhealth.metrics.prediction_set`). + If no metrics are specified, pr_auc, roc_auc and f1 are computed by default. This function calls sklearn.metrics functions to compute the metrics. For @@ -39,6 +46,7 @@ def binary_metrics_fn( y_prob: Predicted probabilities of shape (n_samples,). metrics: List of metrics to compute. Default is ["pr_auc", "roc_auc", "f1"]. threshold: Threshold for binary classification. Default is 0.5. + y_predset: Optional (n_samples, 2) boolean prediction sets for conformal metrics. Returns: Dictionary of metrics whose keys are the metric names and values are @@ -54,6 +62,21 @@ def binary_metrics_fn( if metrics is None: metrics = ["pr_auc", "roc_auc", "f1"] + # Normalize to 1D so sklearn and pset get consistent shapes (e.g. from (N,1) tensors) + y_true = np.asarray(y_true).ravel() + y_prob = np.asarray(y_prob).ravel() + + prediction_set_metrics = [ + "rejection_rate", + "set_size", + "miscoverage_mean_ps", + "miscoverage_ps", + "miscoverage_overall_ps", + "error_mean_ps", + "error_ps", + "error_overall_ps", + ] + y_pred = y_prob.copy() y_pred[y_pred >= threshold] = 1 y_pred[y_pred < threshold] = 0 @@ -72,9 +95,10 @@ def binary_metrics_fn( elif metric == "balanced_accuracy": balanced_accuracy = sklearn_metrics.balanced_accuracy_score(y_true, y_pred) output["balanced_accuracy"] = balanced_accuracy - elif metric == "f1": + elif metric in ("f1", "f1_weighted"): + # f1_weighted alias for script compatibility with multiclass (binary has one class) f1 = sklearn_metrics.f1_score(y_true, y_pred) - output["f1"] = f1 + output["f1" if metric == "f1" else "f1_weighted"] = f1 elif metric == "precision": precision = sklearn_metrics.precision_score(y_true, y_pred) output["precision"] = precision @@ -91,6 +115,34 @@ def binary_metrics_fn( output[metric] = calib.ece_confidence_binary( y_prob, y_true, bins=20, adaptive=metric.endswith("_adapt") ) + elif metric in prediction_set_metrics: + if y_predset is None: + continue + y_predset_np = np.asarray(y_predset, dtype=bool) + if y_predset_np.ndim == 1: + y_predset_np = y_predset_np.reshape(-1, 1) + if y_predset_np.shape[1] == 1: + y_predset_np = np.concatenate( + [1 - y_predset_np, y_predset_np], axis=1 + ) + # pset uses y_true as integer class indices; ensure 1D int + y_true_flat = np.asarray(y_true).ravel().astype(np.intp) + if metric == "rejection_rate": + output[metric] = pset.rejection_rate(y_predset_np) + elif metric == "set_size": + output[metric] = pset.size(y_predset_np) + elif metric == "miscoverage_mean_ps": + output[metric] = pset.miscoverage_ps(y_predset_np, y_true_flat).mean() + elif metric == "miscoverage_ps": + output[metric] = pset.miscoverage_ps(y_predset_np, y_true_flat) + elif metric == "miscoverage_overall_ps": + output[metric] = pset.miscoverage_overall_ps(y_predset_np, y_true_flat) + elif metric == "error_mean_ps": + output[metric] = pset.error_ps(y_predset_np, y_true_flat).mean() + elif metric == "error_ps": + output[metric] = pset.error_ps(y_predset_np, y_true_flat) + elif metric == "error_overall_ps": + output[metric] = pset.error_overall_ps(y_predset_np, y_true_flat) else: raise ValueError(f"Unknown metric for binary classification: {metric}") return output diff --git a/pyhealth/metrics/prediction_set.py b/pyhealth/metrics/prediction_set.py index 2b6f71705..99451dfa4 100644 --- a/pyhealth/metrics/prediction_set.py +++ b/pyhealth/metrics/prediction_set.py @@ -26,8 +26,9 @@ def _missrate(y_pred:np.ndarray, y_true:np.ndarray, ignore_rejected=False): # currently handles multilabel and multiclass K = y_pred.shape[1] if len(y_true.shape) == 1: - y_true, _ = np.zeros((len(y_true),K), dtype=bool), y_true - y_true[np.arange(len(y_true)), _] = 1 + labels = np.asarray(y_true).ravel().astype(np.intp) + y_true = np.zeros((len(labels), K), dtype=bool) + y_true[np.arange(len(labels)), labels] = 1 y_true = y_true.astype(bool) keep_msk = (y_pred.sum(1) == 1) if ignore_rejected else np.ones(len(y_true), dtype=bool) @@ -94,7 +95,7 @@ def miscoverage_overall_ps(y_pred:np.ndarray, y_true:np.ndarray): The 2-th prediction set is {0,1} and the label is 1 (covered). Thus the miscoverage rate is 1/3. """ - assert len(y_true.shape) == 1 + y_true = np.asarray(y_true).ravel().astype(np.intp) truth_pred = y_pred[np.arange(len(y_true)), y_true] return 1 - np.mean(truth_pred) if len(truth_pred) > 0 else 0.0 @@ -114,7 +115,7 @@ def error_overall_ps(y_pred:np.ndarray, y_true:np.ndarray): The 1-th sample is not rejected and incurs on error. The 2-th sample is rejected, thus excluded from the computation. """ - assert len(y_true.shape) == 1 + y_true = np.asarray(y_true).ravel().astype(np.intp) truth_pred = y_pred[np.arange(len(y_true)), y_true] truth_pred = truth_pred[y_pred.sum(1) == 1] return 1 - np.mean(truth_pred) if len(truth_pred) > 0 else 0.0 diff --git a/pyhealth/models/tfm_tokenizer.py b/pyhealth/models/tfm_tokenizer.py index fb8b40661..b9ef2c9d0 100644 --- a/pyhealth/models/tfm_tokenizer.py +++ b/pyhealth/models/tfm_tokenizer.py @@ -418,6 +418,14 @@ def tokenize(self, x, x_temporal): x_temporal = self.temporal_patch_embedding(x_temporal) x_temporal = rearrange(x_temporal, "B E T -> B T E") + # Align time dimensions (can differ with variable-length batching or STFT center) + T_f, T_t = x.size(1), x_temporal.size(1) + if T_f != T_t: + if T_f < T_t: + x = F.pad(x, (0, 0, 0, T_t - T_f), value=0) + else: + x_temporal = F.pad(x_temporal, (0, 0, 0, T_f - T_t), value=0) + x = torch.cat((x, x_temporal), dim=-1) x = self.trans_temporal_encoder(x) @@ -781,6 +789,10 @@ def __init__( def forward(self, **kwargs) -> Dict[str, torch.Tensor]: """Forward propagation. + Accepts either per-channel (TFM-Tokenizer standard) or legacy single-stream: + - Per-channel: stft (B, C, F, T), signal (B, C, T) -> tokenizer sees (B*C, F, T) / (B*C, T), classifier gets (B, C, T). + - Legacy: stft (B, F, T), signal (B, T) -> classifier gets (B, 1, T). + Args: **kwargs: keyword arguments containing 'stft', 'signal', and label key. @@ -793,28 +805,45 @@ def forward(self, **kwargs) -> Dict[str, torch.Tensor]: if stft is None or signal is None: raise ValueError("Both 'stft' and 'signal' must be provided in inputs") - stft = stft.to(self.device) - signal = signal.to(self.device) + stft = stft.to(self.device, dtype=torch.float32) + signal = signal.to(self.device, dtype=torch.float32) + + per_channel = stft.dim() == 4 + if per_channel: + B, C, n_freq, T = stft.shape + stft_flat = rearrange(stft, "B C F T -> (B C) F T") + signal_flat = rearrange(signal, "B C T -> (B C) T") + else: + stft_flat = stft + signal_flat = signal + + reconstructed, tokens, quant_out, quant_in = self.tokenizer(stft_flat, signal_flat) - reconstructed, tokens, quant_out, quant_in = self.tokenizer(stft, signal) + if per_channel: + recon_loss = F.mse_loss(reconstructed, stft_flat) + tokens_reshaped = rearrange(tokens, "(B C) T -> B C T", B=B, C=C) + else: + recon_loss = F.mse_loss(reconstructed, stft_flat) + tokens_reshaped = tokens.unsqueeze(1) - recon_loss = F.mse_loss(reconstructed, stft) vq_loss, _, _ = self.tokenizer.vec_quantizer_loss(quant_in, quant_out) results = { "recon_loss": recon_loss, "vq_loss": vq_loss, - "tokens": tokens, + "tokens": tokens_reshaped, "embeddings": quant_out, } + if kwargs.get("embed", False): + if per_channel: + results["embed"] = quant_out.reshape(B, C, -1, quant_out.size(-1)).mean(dim=1) + else: + results["embed"] = quant_out.mean(dim=1) if self.use_classifier and len(self.label_keys) > 0: label_key = self.label_keys[0] y_true = kwargs[label_key].to(self.device) - # Reshape tokens to (B, C, T) for multi-channel classifier - # tokens shape: (B, T) -> (B, 1, T) - tokens_reshaped = tokens.unsqueeze(1) logits = self.classifier(tokens_reshaped) loss_fn = self.get_loss_function() cls_loss = loss_fn(logits, y_true) @@ -837,12 +866,7 @@ def forward(self, **kwargs) -> Dict[str, torch.Tensor]: def get_embeddings(self, dataloader) -> torch.Tensor: """Extract continuous embeddings for all samples in a dataloader. - - Args: - dataloader: PyHealth dataloader. - - Returns: - tensor of shape (n_samples, seq_len, emb_size). + With per-channel input (stft 4D, signal 3D), returns (n_samples, seq_len, emb_size) by mean-pooling over channels. """ self.eval() all_embeddings = [] @@ -851,19 +875,23 @@ def get_embeddings(self, dataloader) -> torch.Tensor: for batch in dataloader: stft = batch.get("stft").to(self.device) signal = batch.get("signal").to(self.device) - _, _, quant_out, _ = self.tokenizer(stft, signal) + per_channel = stft.dim() == 4 + if per_channel: + B, C, n_freq, T = stft.shape + stft_flat = rearrange(stft, "B C F T -> (B C) F T") + signal_flat = rearrange(signal, "B C T -> (B C) T") + else: + stft_flat, signal_flat = stft, signal + _, _, quant_out, _ = self.tokenizer(stft_flat, signal_flat) + if per_channel: + quant_out = quant_out.reshape(B, C, -1, quant_out.size(-1)).mean(dim=1) all_embeddings.append(quant_out.cpu()) return torch.cat(all_embeddings, dim=0) def get_tokens(self, dataloader) -> torch.Tensor: """Extract discrete tokens for all samples in a dataloader. - - Args: - dataloader: PyHealth dataloader. - - Returns: - tensor of shape (n_samples, seq_len). + With per-channel input, returns (n_samples, n_channels, seq_len). """ self.eval() all_tokens = [] @@ -872,7 +900,16 @@ def get_tokens(self, dataloader) -> torch.Tensor: for batch in dataloader: stft = batch.get("stft").to(self.device) signal = batch.get("signal").to(self.device) - _, tokens, _, _ = self.tokenizer(stft, signal) + per_channel = stft.dim() == 4 + if per_channel: + B, C, n_freq, T = stft.shape + stft_flat = rearrange(stft, "B C F T -> (B C) F T") + signal_flat = rearrange(signal, "B C T -> (B C) T") + else: + stft_flat, signal_flat = stft, signal + _, tokens, _, _ = self.tokenizer(stft_flat, signal_flat) + if per_channel: + tokens = rearrange(tokens, "(B C) T -> B C T", B=B, C=C) all_tokens.append(tokens.cpu()) return torch.cat(all_tokens, dim=0) diff --git a/pyhealth/utils.py b/pyhealth/utils.py index b4af8980a..efbee3977 100644 --- a/pyhealth/utils.py +++ b/pyhealth/utils.py @@ -21,8 +21,7 @@ def set_seed(seed): def create_directory(directory): - if not os.path.exists(directory): - os.makedirs(directory) + os.makedirs(directory, exist_ok=True) def load_pickle(filename): diff --git a/set_size_plot.png b/set_size_plot.png new file mode 100644 index 000000000..61c638378 Binary files /dev/null and b/set_size_plot.png differ