Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,9 @@ resource*
debug_entry*
playground.ipynb

# External TFM-Tokenizer (do not commit)
TFM-Tokenizer/

# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
Expand Down
Binary file added contrawr_coverage_1x2.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added contrawr_coverage_solid.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added contrawr_results_1x2.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added contrawr_setsize_1x2.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added contrawr_setsize_solid.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added coverage_plot.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
41 changes: 41 additions & 0 deletions examples/conformal_eeg/RUN_TFM_32_COMMANDS.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
# TFM: 32 jobs in parallel (same style as ContraWR). Run from repo root.
# Logs go to logs/ (bare filenames are written under repo root logs/).
# Default tokenizer + TUEV/TUAB classifier paths are set in the scripts; no need to pass them.
# Every command is backgrounded with '&'; the final 'wait' blocks until all 32 finish.
# GPU layout within each dataset: naive -> GPUs 0-3, kde -> 4-7, kmeans -> 0-3, ncp -> 4-7,
# one alpha in {0.2, 0.1, 0.05, 0.01} per GPU — i.e. 4 concurrent jobs per GPU overall.

# --- Events (TUEV), 16 jobs ---
# Naive CP (GPUs 0-3):
CUDA_VISIBLE_DEVICES=0 python examples/conformal_eeg/tuev_naive_cp_conformal.py --dataset tuev --model tfm --tfm-skip-train --seeds 1,2,3,4,5 --split-seed 0 --alpha 0.2 --log-file naive_cp_alpha02_events_seeds5.log &
CUDA_VISIBLE_DEVICES=1 python examples/conformal_eeg/tuev_naive_cp_conformal.py --dataset tuev --model tfm --tfm-skip-train --seeds 1,2,3,4,5 --split-seed 0 --alpha 0.1 --log-file naive_cp_alpha01_events_seeds5.log &
CUDA_VISIBLE_DEVICES=2 python examples/conformal_eeg/tuev_naive_cp_conformal.py --dataset tuev --model tfm --tfm-skip-train --seeds 1,2,3,4,5 --split-seed 0 --alpha 0.05 --log-file naive_cp_alpha005_events_seeds5.log &
CUDA_VISIBLE_DEVICES=3 python examples/conformal_eeg/tuev_naive_cp_conformal.py --dataset tuev --model tfm --tfm-skip-train --seeds 1,2,3,4,5 --split-seed 0 --alpha 0.01 --log-file naive_cp_alpha001_events_seeds5.log &
# KDE CP (GPUs 4-7):
CUDA_VISIBLE_DEVICES=4 python examples/conformal_eeg/tuev_kde_cp_conformal.py --dataset tuev --model tfm --tfm-skip-train --seeds 1,2,3,4,5 --split-seed 0 --alpha 0.2 --log-file kde_cp_alpha02_events_seeds5.log &
CUDA_VISIBLE_DEVICES=5 python examples/conformal_eeg/tuev_kde_cp_conformal.py --dataset tuev --model tfm --tfm-skip-train --seeds 1,2,3,4,5 --split-seed 0 --alpha 0.1 --log-file kde_cp_alpha01_events_seeds5.log &
CUDA_VISIBLE_DEVICES=6 python examples/conformal_eeg/tuev_kde_cp_conformal.py --dataset tuev --model tfm --tfm-skip-train --seeds 1,2,3,4,5 --split-seed 0 --alpha 0.05 --log-file kde_cp_alpha005_events_seeds5.log &
CUDA_VISIBLE_DEVICES=7 python examples/conformal_eeg/tuev_kde_cp_conformal.py --dataset tuev --model tfm --tfm-skip-train --seeds 1,2,3,4,5 --split-seed 0 --alpha 0.01 --log-file kde_cp_alpha001_events_seeds5.log &
# K-means CP (GPUs 0-3):
CUDA_VISIBLE_DEVICES=0 python examples/conformal_eeg/tuev_kmeans_conformal.py --dataset tuev --model tfm --tfm-skip-train --seeds 1,2,3,4,5 --split-seed 0 --alpha 0.2 --log-file kmeans_cp_alpha02_events_seeds5.log &
CUDA_VISIBLE_DEVICES=1 python examples/conformal_eeg/tuev_kmeans_conformal.py --dataset tuev --model tfm --tfm-skip-train --seeds 1,2,3,4,5 --split-seed 0 --alpha 0.1 --log-file kmeans_cp_alpha01_events_seeds5.log &
CUDA_VISIBLE_DEVICES=2 python examples/conformal_eeg/tuev_kmeans_conformal.py --dataset tuev --model tfm --tfm-skip-train --seeds 1,2,3,4,5 --split-seed 0 --alpha 0.05 --log-file kmeans_cp_alpha005_events_seeds5.log &
CUDA_VISIBLE_DEVICES=3 python examples/conformal_eeg/tuev_kmeans_conformal.py --dataset tuev --model tfm --tfm-skip-train --seeds 1,2,3,4,5 --split-seed 0 --alpha 0.01 --log-file kmeans_cp_alpha001_events_seeds5.log &
# NCP (GPUs 4-7):
CUDA_VISIBLE_DEVICES=4 python examples/conformal_eeg/tuev_ncp_conformal.py --dataset tuev --model tfm --tfm-skip-train --seeds 1,2,3,4,5 --split-seed 0 --alpha 0.2 --log-file ncp_alpha02_events_seeds5.log &
CUDA_VISIBLE_DEVICES=5 python examples/conformal_eeg/tuev_ncp_conformal.py --dataset tuev --model tfm --tfm-skip-train --seeds 1,2,3,4,5 --split-seed 0 --alpha 0.1 --log-file ncp_alpha01_events_seeds5.log &
CUDA_VISIBLE_DEVICES=6 python examples/conformal_eeg/tuev_ncp_conformal.py --dataset tuev --model tfm --tfm-skip-train --seeds 1,2,3,4,5 --split-seed 0 --alpha 0.05 --log-file ncp_alpha005_events_seeds5.log &
CUDA_VISIBLE_DEVICES=7 python examples/conformal_eeg/tuev_ncp_conformal.py --dataset tuev --model tfm --tfm-skip-train --seeds 1,2,3,4,5 --split-seed 0 --alpha 0.01 --log-file ncp_alpha001_events_seeds5.log &

# --- Abnormal (TUAB), 16 jobs. Use --dataset tuab so task and default classifier path are correct. ---
# Naive CP (GPUs 0-3):
CUDA_VISIBLE_DEVICES=0 python examples/conformal_eeg/tuev_naive_cp_conformal.py --dataset tuab --model tfm --tfm-skip-train --seeds 1,2,3,4,5 --split-seed 0 --alpha 0.2 --log-file naive_cp_alpha02_abnormal_seeds5.log &
CUDA_VISIBLE_DEVICES=1 python examples/conformal_eeg/tuev_naive_cp_conformal.py --dataset tuab --model tfm --tfm-skip-train --seeds 1,2,3,4,5 --split-seed 0 --alpha 0.1 --log-file naive_cp_alpha01_abnormal_seeds5.log &
CUDA_VISIBLE_DEVICES=2 python examples/conformal_eeg/tuev_naive_cp_conformal.py --dataset tuab --model tfm --tfm-skip-train --seeds 1,2,3,4,5 --split-seed 0 --alpha 0.05 --log-file naive_cp_alpha005_abnormal_seeds5.log &
CUDA_VISIBLE_DEVICES=3 python examples/conformal_eeg/tuev_naive_cp_conformal.py --dataset tuab --model tfm --tfm-skip-train --seeds 1,2,3,4,5 --split-seed 0 --alpha 0.01 --log-file naive_cp_alpha001_abnormal_seeds5.log &
# KDE CP (GPUs 4-7):
CUDA_VISIBLE_DEVICES=4 python examples/conformal_eeg/tuev_kde_cp_conformal.py --dataset tuab --model tfm --tfm-skip-train --seeds 1,2,3,4,5 --split-seed 0 --alpha 0.2 --log-file kde_cp_alpha02_abnormal_seeds5.log &
CUDA_VISIBLE_DEVICES=5 python examples/conformal_eeg/tuev_kde_cp_conformal.py --dataset tuab --model tfm --tfm-skip-train --seeds 1,2,3,4,5 --split-seed 0 --alpha 0.1 --log-file kde_cp_alpha01_abnormal_seeds5.log &
CUDA_VISIBLE_DEVICES=6 python examples/conformal_eeg/tuev_kde_cp_conformal.py --dataset tuab --model tfm --tfm-skip-train --seeds 1,2,3,4,5 --split-seed 0 --alpha 0.05 --log-file kde_cp_alpha005_abnormal_seeds5.log &
CUDA_VISIBLE_DEVICES=7 python examples/conformal_eeg/tuev_kde_cp_conformal.py --dataset tuab --model tfm --tfm-skip-train --seeds 1,2,3,4,5 --split-seed 0 --alpha 0.01 --log-file kde_cp_alpha001_abnormal_seeds5.log &
# K-means CP (GPUs 0-3):
CUDA_VISIBLE_DEVICES=0 python examples/conformal_eeg/tuev_kmeans_conformal.py --dataset tuab --model tfm --tfm-skip-train --seeds 1,2,3,4,5 --split-seed 0 --alpha 0.2 --log-file kmeans_cp_alpha02_abnormal_seeds5.log &
CUDA_VISIBLE_DEVICES=1 python examples/conformal_eeg/tuev_kmeans_conformal.py --dataset tuab --model tfm --tfm-skip-train --seeds 1,2,3,4,5 --split-seed 0 --alpha 0.1 --log-file kmeans_cp_alpha01_abnormal_seeds5.log &
CUDA_VISIBLE_DEVICES=2 python examples/conformal_eeg/tuev_kmeans_conformal.py --dataset tuab --model tfm --tfm-skip-train --seeds 1,2,3,4,5 --split-seed 0 --alpha 0.05 --log-file kmeans_cp_alpha005_abnormal_seeds5.log &
CUDA_VISIBLE_DEVICES=3 python examples/conformal_eeg/tuev_kmeans_conformal.py --dataset tuab --model tfm --tfm-skip-train --seeds 1,2,3,4,5 --split-seed 0 --alpha 0.01 --log-file kmeans_cp_alpha001_abnormal_seeds5.log &
# NCP (GPUs 4-7):
CUDA_VISIBLE_DEVICES=4 python examples/conformal_eeg/tuev_ncp_conformal.py --dataset tuab --model tfm --tfm-skip-train --seeds 1,2,3,4,5 --split-seed 0 --alpha 0.2 --log-file ncp_alpha02_abnormal_seeds5.log &
CUDA_VISIBLE_DEVICES=5 python examples/conformal_eeg/tuev_ncp_conformal.py --dataset tuab --model tfm --tfm-skip-train --seeds 1,2,3,4,5 --split-seed 0 --alpha 0.1 --log-file ncp_alpha01_abnormal_seeds5.log &
CUDA_VISIBLE_DEVICES=6 python examples/conformal_eeg/tuev_ncp_conformal.py --dataset tuab --model tfm --tfm-skip-train --seeds 1,2,3,4,5 --split-seed 0 --alpha 0.05 --log-file ncp_alpha005_abnormal_seeds5.log &
CUDA_VISIBLE_DEVICES=7 python examples/conformal_eeg/tuev_ncp_conformal.py --dataset tuab --model tfm --tfm-skip-train --seeds 1,2,3,4,5 --split-seed 0 --alpha 0.01 --log-file ncp_alpha001_abnormal_seeds5.log &

# Block until every backgrounded job above has finished.
wait
158 changes: 158 additions & 0 deletions examples/conformal_eeg/model_utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,158 @@
"""Shared helpers for TUEV conformal scripts: model choice (ContraWR vs TFM-Tokenizer) and STFT dataset wrapper.

TFM runs use the same experimental protocol as ContraWR (split_seed, run_seeds, alpha,
ratios, fixed test set) so results are directly comparable.

TFM loading options:
- Single checkpoint: --tfm-checkpoint (full model or tokenizer-only; use {seed} for per-seed).
- Two checkpoints (pretrained tokenizer + finetuned classifier): --tfm-tokenizer-checkpoint
and --tfm-classifier-checkpoint (use {seed} in classifier path for per-seed). Use
--tfm-skip-train to run calibration + inference only.

Scripts support --dataset tuev|tuab (same protocol; TUAB uses binary task, TUEV uses multiclass).
"""

from __future__ import annotations

import numpy as np
import torch

from pyhealth.models import ContraWR, TFMTokenizer


RESAMPLING_RATE = 200  # TFM-Tokenizer standard; n_fft=200, hop_length=100, 100 freq bins


def get_stft_torch(X: torch.Tensor, resampling_rate: int = RESAMPLING_RATE) -> torch.Tensor:
    """Per-channel magnitude STFT matching the TFM-Tokenizer repo.

    Args:
        X: Real signal batch of shape (B, C, T).
        resampling_rate: Used as n_fft; the hop length is n_fft // 2.

    Returns:
        Tensor of shape (B, C, n_fft // 2, T'): magnitudes of the first
        n_fft // 2 frequency bins (the Nyquist bin is dropped).
    """
    batch, channels, n_samples = X.shape
    half = resampling_rate // 2
    flat = X.reshape(batch * channels, n_samples)
    win = torch.hann_window(resampling_rate, device=X.device, dtype=X.dtype)
    spectrum = torch.stft(
        flat,
        n_fft=resampling_rate,
        hop_length=half,
        window=win,
        onesided=True,
        return_complex=True,
        center=False,
    )
    # (B*C, n_fft//2 + 1, T'): take magnitudes, keep the first n_fft//2 bins.
    magnitudes = spectrum.abs()[:, :half, :]
    return magnitudes.reshape(batch, channels, half, -1)


class AddSTFTDataset:
    """Dataset wrapper that augments TUEV/TUAB task samples with a per-channel
    magnitude STFT under the 'stft' key, as expected by TFM-Tokenizer.

    'signal' stays (C, T) (normalized, float32); 'stft' is (C, n_fft//2, T')
    computed with n_fft=200, hop_length=100 by default, matching the original
    TFM-Tokenizer training pipeline (16 token sequences per sample).
    """

    def __init__(self, base, n_fft: int = 200, hop_length: int = 100):
        self._base = base
        self.n_fft = n_fft
        self.hop_length = hop_length
        # Advertise the extra 'stft' tensor alongside the base schema.
        self.input_schema = {**getattr(base, "input_schema", {}), "stft": "tensor"}
        self.output_schema = getattr(base, "output_schema", {})

    @property
    def output_processors(self):
        """Forward so BaseModel.get_output_size() works (TFMTokenizer/ContraWR)."""
        return getattr(self._base, "output_processors", {})

    @property
    def input_processors(self):
        """Forward in case the model reads input_processors."""
        return getattr(self._base, "input_processors", {})

    def __len__(self) -> int:
        return len(self._base)

    def __getitem__(self, i: int):
        item = dict(self._base[i])
        raw = np.asarray(item["signal"], dtype=np.float32)
        if raw.ndim == 1:
            raw = raw.reshape(1, -1)
        # Per-channel robust scaling: divide by the 95th percentile of |x|
        # along time (plus epsilon), matching TFM training.
        denom = np.quantile(
            np.abs(raw), q=0.95, axis=-1, method="linear", keepdims=True
        ) + 1e-8
        normalized = np.asarray(raw / denom, dtype=np.float32)
        tensor = torch.from_numpy(normalized)  # (C, T), float32 for the tokenizer
        item["signal"] = tensor
        # Per-channel STFT: (1, C, T) -> (1, C, n_fft//2, T'), then drop batch dim.
        spec = get_stft_torch(tensor.unsqueeze(0), resampling_rate=self.n_fft)
        item["stft"] = spec.squeeze(0)
        return item

    def subset(self, indices):
        """Return a wrapped subset of the base dataset."""
        return AddSTFTDataset(self._base.subset(indices), self.n_fft, self.hop_length)

    def set_shuffle(self, shuffle: bool) -> None:
        """Forward to base dataset so get_dataloader() works (pyhealth.datasets.utils)."""
        if hasattr(self._base, "set_shuffle"):
            self._base.set_shuffle(shuffle)


def _load_tfm_checkpoint(model, checkpoint_path: str, device: str):
"""Load TFM checkpoint: full model state_dict or tokenizer-only (legacy)."""
ckpt = torch.load(checkpoint_path, map_location=device)
state = ckpt.get("model_state_dict", ckpt.get("state_dict", ckpt))
if not isinstance(state, dict):
model.load_pretrained_weights(checkpoint_path, map_location=device)
return
keys = list(state.keys())
if any(str(k).startswith("tokenizer.") or str(k).startswith("classifier.") for k in keys):
model.load_state_dict(state, strict=False)
print(f" Loaded full model from {checkpoint_path}")
else:
model.load_pretrained_weights(checkpoint_path, map_location=device)


def _load_tfm_classifier_checkpoint(model, checkpoint_path: str, device: str):
"""Load classifier-only checkpoint into model.classifier. Handles keys with or without 'classifier.' prefix."""
ckpt = torch.load(checkpoint_path, map_location=device)
state = ckpt.get("model_state_dict", ckpt.get("state_dict", ckpt))
if not isinstance(state, dict):
return
keys = list(state.keys())
if any(str(k).startswith("classifier.") for k in keys):
model.load_state_dict(state, strict=False)
print(f" Loaded classifier from {checkpoint_path}")
else:
model.classifier.load_state_dict(state, strict=False)
print(f" Loaded classifier from {checkpoint_path}")


def get_model(args, sample_dataset, device: str):
    """Build ContraWR or TFMTokenizer from ``args.model``.

    Use ``sample_dataset`` (possibly AddSTFTDataset for TFM).

    TFM loading options:
      - args.tfm_checkpoint: single path (full model or tokenizer-only).
      - args.tfm_tokenizer_checkpoint + args.tfm_classifier_checkpoint:
        pretrained tokenizer + per-seed finetuned classifier.
    """
    choice = getattr(args, "model", "contrawr").lower()
    if choice != "tfm":
        # Default path: ContraWR baseline.
        model = ContraWR(dataset=sample_dataset, n_fft=getattr(args, "n_fft", 128))
        return model.to(device)

    model = TFMTokenizer(
        dataset=sample_dataset,
        n_freq=100,
        emb_size=getattr(args, "tfm_emb_size", 64),
        code_book_size=getattr(args, "tfm_code_book_size", 8192),
    ).to(device)
    tok_path = getattr(args, "tfm_tokenizer_checkpoint", None)
    clf_path = getattr(args, "tfm_classifier_checkpoint", None)
    one_path = getattr(args, "tfm_checkpoint", None)
    if tok_path and clf_path:
        # Two-checkpoint mode: pretrained tokenizer + finetuned classifier.
        model.load_pretrained_weights(tok_path, map_location=device)
        _load_tfm_classifier_checkpoint(model, clf_path, device)
    elif one_path:
        _load_tfm_checkpoint(model, one_path, device)
    if getattr(args, "tfm_freeze_tokenizer", False):
        # Calibration/inference only: keep tokenizer weights fixed.
        for param in model.tokenizer.parameters():
            param.requires_grad = False
    return model
98 changes: 98 additions & 0 deletions examples/conformal_eeg/run_tfm_grid_8gpu.sh
Original file line number Diff line number Diff line change
@@ -0,0 +1,98 @@
#!/usr/bin/env bash
# Run full TFM conformal grid on 8 GPUs (4 alphas × 2 datasets × 4 methods = 32 jobs).
# Dynamic assignment: first free GPU takes the next job (faster when runtimes vary).
# Usage: set TOK, TUEV_CLF, TUAB_CLF below, then: bash run_tfm_grid_8gpu.sh
#   From repo root: bash examples/conformal_eeg/run_tfm_grid_8gpu.sh
#   Run 8 at a time (4 waves): WAVES_OF_8=1 bash examples/conformal_eeg/run_tfm_grid_8gpu.sh

# Strict mode: -e exit on error, -u error on unset variables (catches typos),
# pipefail so a pipeline fails if any stage fails.
set -euo pipefail
SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)"
cd "$SCRIPT_DIR/../.."

# --- PATHS (structure: .../TFM_Tokenizer_multiple_finetuned_on_TUEV/TFM_Tokenizer_multiple_finetuned_on_TUEV_{seed}/best_model.pth) ---
TOK="${TOK:-/srv/local/data/arjunc4/tfm_tokenizer_last.pth}"
# Escape } in {seed} so bash does not close ${VAR:-default} at the first }
TUEV_CLF="${TUEV_CLF:-/srv/local/data/arjunc4/TFM_Tokenizer_multiple_finetuned_on_TUEV/TFM_Tokenizer_multiple_finetuned_on_TUEV_{seed\}/best_model.pth}"
TUAB_CLF="${TUAB_CLF:-/srv/local/data/arjunc4/TFM_Tokenizer_multiple_finetuned_on_TUAB/TFM_Tokenizer_multiple_finetuned_on_TUAB_{seed\}/best_model.pth}"
# Logs go to repo root logs/ (create if missing)
REPO_ROOT="$SCRIPT_DIR/../.."
LOG_DIR="${LOG_DIR:-$REPO_ROOT/logs}"
mkdir -p "$LOG_DIR"

# Launch one conformal job on a given GPU (foreground; caller backgrounds it).
#   $1 gpu id   $2 method (naive|kde|kmeans|ncp)   $3 dataset (tuev|tuab)   $4 alpha
run_one() {
  local gpu="$1" method="$2" dataset="$3" alpha="$4"
  local script clf_path
  case "$method" in
    naive)  script="tuev_naive_cp_conformal.py" ;;
    kde)    script="tuev_kde_cp_conformal.py" ;;
    kmeans) script="tuev_kmeans_conformal.py" ;;
    ncp)    script="tuev_ncp_conformal.py" ;;
    *) echo "Unknown method $method"; exit 1 ;;
  esac
  # Pick the dataset-specific finetuned classifier checkpoint.
  if [ "$dataset" = tuev ]; then
    clf_path="$TUEV_CLF"
  else
    clf_path="$TUAB_CLF"
  fi
  CUDA_VISIBLE_DEVICES=$gpu python "examples/conformal_eeg/$script" \
    --dataset "$dataset" --model tfm --alpha "$alpha" \
    --tfm-tokenizer-checkpoint "$TOK" \
    --tfm-classifier-checkpoint "$clf_path" \
    --tfm-skip-train --seeds 1,2,3,4,5 --split-seed 0 \
    --log-file "$LOG_DIR/${method}_${dataset}_alpha${alpha}.log"
}

# Enumerate the 32 jobs as "method dataset alpha" strings
# (4 methods × 2 datasets × 4 alphas).
jobs=()
for m in naive kde kmeans ncp; do
  for ds in tuev tuab; do
    for a in 0.2 0.1 0.05 0.01; do
      jobs+=("$m $ds $a")
    done
  done
done
total=${#jobs[@]}

if [[ -n "${WAVES_OF_8:-}" ]]; then
  # Run 4 waves of 8 jobs: wait for each wave to finish before starting the next
  for wave in 0 1 2 3; do
    start=$((wave * 8))
    echo "=== Wave $((wave + 1))/4 (jobs $((start + 1))–$((start + 8))) ==="
    for gpu in 0 1 2 3 4 5 6 7; do
      idx=$((start + gpu))
      read -r method dataset alpha <<< "${jobs[idx]}"
      run_one "$gpu" "$method" "$dataset" "$alpha" &
    done
    wait
    echo "Wave $((wave + 1)) done."
  done
  echo "All 32 jobs done. Logs in $LOG_DIR"
else
  # Dynamic assignment: shared job counter; the first free GPU claims the next job
  NEXT_JOB_FILE=$(mktemp)
  LOCK_FILE=$(mktemp)
  echo 0 > "$NEXT_JOB_FILE"
  cleanup() { rm -f -- "$NEXT_JOB_FILE" "$LOCK_FILE"; }
  trap cleanup EXIT

  # Atomically read-and-increment the shared counter; prints the claimed index.
  # BUG FIX: the previous `flock 200 bash -c "..."` treated "200" as a *file
  # name* (flock's fd form takes no command), so the lock was taken on a stray
  # ./200 file and $LOCK_FILE was never used. Lock fd 200 inside a redirected
  # group instead — the lock is released when the group closes the fd.
  claim_next_index() {
    local idx
    {
      flock 200
      read -r idx < "$NEXT_JOB_FILE"
      echo $((idx + 1)) > "$NEXT_JOB_FILE"
    } 200>>"$LOCK_FILE"
    printf '%s\n' "$idx"
  }

  # One worker per GPU: keep claiming jobs until the queue is exhausted.
  worker() {
    local gpu=$1
    while true; do
      local job_idx
      job_idx=$(claim_next_index)
      if [[ "$job_idx" -ge "$total" ]]; then
        break
      fi
      read -r method dataset alpha <<< "${jobs[job_idx]}"
      # Keep draining the queue even if one job fails — under `set -e` a
      # failure would otherwise kill this worker and strand its share of jobs.
      run_one "$gpu" "$method" "$dataset" "$alpha" \
        || echo "Job failed (gpu=$gpu): $method $dataset $alpha" >&2
    done
  }

  for gpu in 0 1 2 3 4 5 6 7; do
    worker "$gpu" &
  done
  wait
  echo "All 32 jobs done. Logs in $LOG_DIR"
fi
Loading