From 32ea8f069a7836b90f817558ab0babbd58151644 Mon Sep 17 00:00:00 2001 From: Your Name Date: Tue, 31 Mar 2026 17:40:14 +0200 Subject: [PATCH 1/3] feat: add EEG-GCNN dataset and neurological disease detection task MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Implements the EEG-GCNN paper (Wagh & Varatharajah, ML4H @ NeurIPS 2020) as a PyHealth 2.0 contribution with dataset, task, tests, docs, and ablation study script. New files: - pyhealth/datasets/eeg_gcnn.py — EEGGCNNDataset (TUAB normal + MPI LEMON) - pyhealth/datasets/configs/eeg_gcnn.yaml — YAML config - pyhealth/tasks/eeg_gcnn_nd_detection.py — EEGGCNNDiseaseDetection task (PSD features, spatial/functional/combined adjacency, configurable bands) - tests/test_eeg_gcnn.py — 23 tests using synthetic data - examples/eeg_gcnn_nd_detection_gcn.py — 3 ablation experiments - RST docs for dataset and task Co-Authored-By: Claude Opus 4.6 --- docs/api/datasets.rst | 1 + .../pyhealth.datasets.EEGGCNNDataset.rst | 12 + docs/api/tasks.rst | 1 + .../pyhealth.tasks.eeg_gcnn_nd_detection.rst | 16 + examples/eeg_gcnn_nd_detection_gcn.py | 284 +++++++++++++ pyhealth/datasets/__init__.py | 1 + pyhealth/datasets/configs/eeg_gcnn.yaml | 20 + pyhealth/datasets/eeg_gcnn.py | 236 +++++++++++ pyhealth/tasks/__init__.py | 1 + pyhealth/tasks/eeg_gcnn_nd_detection.py | 359 +++++++++++++++++ tests/test_eeg_gcnn.py | 372 ++++++++++++++++++ 11 files changed, 1303 insertions(+) create mode 100644 docs/api/datasets/pyhealth.datasets.EEGGCNNDataset.rst create mode 100644 docs/api/tasks/pyhealth.tasks.eeg_gcnn_nd_detection.rst create mode 100644 examples/eeg_gcnn_nd_detection_gcn.py create mode 100644 pyhealth/datasets/configs/eeg_gcnn.yaml create mode 100644 pyhealth/datasets/eeg_gcnn.py create mode 100644 pyhealth/tasks/eeg_gcnn_nd_detection.py create mode 100644 tests/test_eeg_gcnn.py diff --git a/docs/api/datasets.rst b/docs/api/datasets.rst index b02439d26..165dd50a9 100644 --- a/docs/api/datasets.rst +++ b/docs/api/datasets.rst @@ -230,6 +230,7 @@ Available Datasets datasets/pyhealth.datasets.ISRUCDataset datasets/pyhealth.datasets.MIMICExtractDataset datasets/pyhealth.datasets.OMOPDataset + datasets/pyhealth.datasets.EEGGCNNDataset datasets/pyhealth.datasets.DREAMTDataset datasets/pyhealth.datasets.SHHSDataset datasets/pyhealth.datasets.SleepEDFDataset diff --git a/docs/api/datasets/pyhealth.datasets.EEGGCNNDataset.rst b/docs/api/datasets/pyhealth.datasets.EEGGCNNDataset.rst new file mode 100644 index 000000000..512816f7a --- /dev/null +++ b/docs/api/datasets/pyhealth.datasets.EEGGCNNDataset.rst @@ -0,0 +1,12 @@ +pyhealth.datasets.EEGGCNNDataset +=================================== + +Dataset for the EEG-GCNN paper (Wagh & Varatharajah, ML4H @ NeurIPS 2020). +Pools the TUAB normal-subset (patient EEGs) and MPI LEMON (healthy controls). + +Paper: https://proceedings.mlr.press/v136/wagh20a.html + +.. autoclass:: pyhealth.datasets.EEGGCNNDataset + :members: + :undoc-members: + :show-inheritance: diff --git a/docs/api/tasks.rst b/docs/api/tasks.rst index d85d04bc3..852d6b6d0 100644 --- a/docs/api/tasks.rst +++ b/docs/api/tasks.rst @@ -212,6 +212,7 @@ Available Tasks COVID-19 CXR Classification DKA Prediction (MIMIC-IV) Drug Recommendation + EEG-GCNN Disease Detection EEG Abnormal EEG Events Length of Stay Prediction diff --git a/docs/api/tasks/pyhealth.tasks.eeg_gcnn_nd_detection.rst b/docs/api/tasks/pyhealth.tasks.eeg_gcnn_nd_detection.rst new file mode 100644 index 000000000..de85c8213 --- /dev/null +++ b/docs/api/tasks/pyhealth.tasks.eeg_gcnn_nd_detection.rst @@ -0,0 +1,16 @@ +pyhealth.tasks.eeg_gcnn_nd_detection +========================================= + +Neurological disease detection task from the EEG-GCNN paper (Wagh & Varatharajah, ML4H @ NeurIPS 2020). + +Binary classification: patient-normal (TUAB) vs healthy-control (LEMON). + +The task extracts PSD band-power features and graph adjacency matrices from EEG recordings, +supporting configurable adjacency types, frequency bands, and connectivity measures for ablation studies. + +Paper: https://proceedings.mlr.press/v136/wagh20a.html + +.. autoclass:: pyhealth.tasks.eeg_gcnn_nd_detection.EEGGCNNDiseaseDetection + :members: + :undoc-members: + :show-inheritance: diff --git a/examples/eeg_gcnn_nd_detection_gcn.py b/examples/eeg_gcnn_nd_detection_gcn.py new file mode 100644 index 000000000..23fa27d09 --- /dev/null +++ b/examples/eeg_gcnn_nd_detection_gcn.py @@ -0,0 +1,284 @@ +"""EEG-GCNN Neurological Disease Detection — Ablation Study. + +This script reproduces the ablation experiments from: + + Wagh, N. & Varatharajah, Y. (2020). EEG-GCNN: Augmenting + Electroencephalogram-based Neurological Disease Diagnosis using a + Domain-guided Graph Convolutional Neural Network. ML4H @ NeurIPS 2020. + https://proceedings.mlr.press/v136/wagh20a.html + +Three ablation experiments are included: + 1. Adjacency type: combined vs spatial vs functional vs none + 2. Frequency band ablation: individual bands & progressive combinations + 3. Connectivity measure: coherence vs WPLI + +Usage: + python examples/eeg_gcnn_nd_detection_gcn.py --root /path/to/eeg-gcnn-data +""" + +import argparse +import json +import logging +from collections import defaultdict +from typing import Dict, List, Tuple + +import numpy as np +import torch +from sklearn.metrics import roc_auc_score +from torch.utils.data import DataLoader + +from pyhealth.datasets import EEGGCNNDataset, get_dataloader, split_by_patient +from pyhealth.models.gnn import GCN +from pyhealth.tasks import EEGGCNNDiseaseDetection +from pyhealth.tasks.eeg_gcnn_nd_detection import DEFAULT_BANDS + +logging.basicConfig(level=logging.INFO) +logger = logging.getLogger(__name__) + + +# ------------------------------------------------------------------- +# Training / evaluation helpers +# ------------------------------------------------------------------- + +def train_one_epoch( + model: torch.nn.Module, + loader: DataLoader, + optimizer: torch.optim.Optimizer, + device: torch.device, +) -> float: + """Train for one epoch, return average loss.""" + model.train() + total_loss = 0.0 + n_batches = 0 + for batch in loader: + batch = {k: v.to(device) if isinstance(v, torch.Tensor) else v + for k, v in batch.items()} + optimizer.zero_grad() + output = model(**batch) + loss = output["loss"] + loss.backward() + optimizer.step() + total_loss += loss.item() + n_batches += 1 + return total_loss / max(n_batches, 1) + + +@torch.no_grad() +def evaluate( + model: torch.nn.Module, + loader: DataLoader, + device: torch.device, +) -> Dict[str, float]: + """Evaluate the model and return AUC and loss.""" + model.eval() + all_probs, all_labels = [], [] + total_loss = 0.0 + n_batches = 0 + for batch in loader: + batch = {k: v.to(device) if isinstance(v, torch.Tensor) else v + for k, v in batch.items()} + output = model(**batch) + total_loss += output["loss"].item() + n_batches += 1 + probs = output["y_prob"].cpu().numpy() + labels = output["y_true"].cpu().numpy() + all_probs.append(probs) + all_labels.append(labels) + + all_probs = np.concatenate(all_probs) + all_labels = np.concatenate(all_labels) + try: + auc = roc_auc_score(all_labels, all_probs[:, 1]) + except ValueError: + auc = 0.5 + return { + "auc": auc, + "loss": total_loss / max(n_batches, 1), + } + + +def run_experiment( + dataset: EEGGCNNDataset, + task: EEGGCNNDiseaseDetection, + epochs: int = 20, + batch_size: int = 32, + lr: float = 1e-3, +) -> Dict[str, float]: + """Run a single experiment: set_task, split, train, evaluate.""" + device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + + sample_dataset = dataset.set_task(task) + train_ds, val_ds, test_ds = split_by_patient( + sample_dataset, [0.7, 0.1, 0.2] + ) + + train_loader = get_dataloader(train_ds, batch_size=batch_size, shuffle=True) + val_loader = get_dataloader(val_ds, batch_size=batch_size, shuffle=False) + test_loader = get_dataloader(test_ds, batch_size=batch_size, shuffle=False) + + model = GCN( + dataset=sample_dataset, + embedding_dim=64, + nhid=32, + dropout=0.5, + num_layers=2, + ).to(device) + + optimizer = torch.optim.Adam(model.parameters(), lr=lr) + + best_val_auc = 0.0 + best_state = None + for epoch in range(epochs): + train_loss = train_one_epoch(model, train_loader, optimizer, device) + val_metrics = evaluate(model, val_loader, device) + logger.info( + "Epoch %d/%d — train_loss=%.4f val_auc=%.4f", + epoch + 1, epochs, train_loss, val_metrics["auc"], + ) + if val_metrics["auc"] > best_val_auc: + best_val_auc = val_metrics["auc"] + best_state = {k: v.cpu().clone() for k, v in model.state_dict().items()} + + if best_state is not None: + model.load_state_dict(best_state) + test_metrics = evaluate(model, test_loader, device) + logger.info("Test AUC: %.4f", test_metrics["auc"]) + return test_metrics + + +# ------------------------------------------------------------------- +# Ablation experiments +# ------------------------------------------------------------------- + +def ablation_adjacency( + dataset: EEGGCNNDataset, **kwargs +) -> Dict[str, Dict[str, float]]: + """Experiment 1: Adjacency type ablation.""" + results = {} + for adj_type in ("combined", "spatial", "functional", "none"): + logger.info("=== Adjacency type: %s ===", adj_type) + task = EEGGCNNDiseaseDetection(adjacency_type=adj_type) + results[adj_type] = run_experiment(dataset, task, **kwargs) + return results + + +def ablation_frequency_bands( + dataset: EEGGCNNDataset, **kwargs +) -> Dict[str, Dict[str, float]]: + """Experiment 2: Frequency band ablation. + + Tests individual bands and progressive combinations. + """ + band_names = list(DEFAULT_BANDS.keys()) + results = {} + + # Individual bands + for name in band_names: + logger.info("=== Single band: %s ===", name) + task = EEGGCNNDiseaseDetection( + bands={name: DEFAULT_BANDS[name]} + ) + results[name] = run_experiment(dataset, task, **kwargs) + + # Progressive combinations: delta, delta+theta, delta+theta+alpha, ... + for k in range(2, len(band_names) + 1): + combo_names = band_names[:k] + combo_key = "+".join(combo_names) + logger.info("=== Band combination: %s ===", combo_key) + combo_bands = {n: DEFAULT_BANDS[n] for n in combo_names} + task = EEGGCNNDiseaseDetection(bands=combo_bands) + results[combo_key] = run_experiment(dataset, task, **kwargs) + + return results + + +def ablation_connectivity( + dataset: EEGGCNNDataset, **kwargs +) -> Dict[str, Dict[str, float]]: + """Experiment 3: Connectivity measure ablation.""" + results = {} + for measure in ("coherence", "wpli"): + logger.info("=== Connectivity: %s ===", measure) + task = EEGGCNNDiseaseDetection( + adjacency_type="functional", + connectivity_measure=measure, + ) + results[measure] = run_experiment(dataset, task, **kwargs) + return results + + +# ------------------------------------------------------------------- +# Main +# ------------------------------------------------------------------- + +def main(): + parser = argparse.ArgumentParser( + description="EEG-GCNN ablation study" + ) + parser.add_argument( + "--root", type=str, required=True, + help="Root directory of EEG-GCNN data (TUAB + LEMON)", + ) + parser.add_argument("--epochs", type=int, default=20) + parser.add_argument("--batch_size", type=int, default=32) + parser.add_argument("--lr", type=float, default=1e-3) + parser.add_argument( + "--experiment", + type=str, + default="all", + choices=["all", "adjacency", "bands", "connectivity"], + help="Which ablation experiment to run", + ) + args = parser.parse_args() + + dataset = EEGGCNNDataset(root=args.root) + dataset.stats() + + train_kwargs = dict( + epochs=args.epochs, + batch_size=args.batch_size, + lr=args.lr, + ) + + all_results = {} + + if args.experiment in ("all", "adjacency"): + logger.info("\n" + "=" * 60) + logger.info("EXPERIMENT 1: Adjacency Type Ablation") + logger.info("=" * 60) + all_results["adjacency"] = ablation_adjacency(dataset, **train_kwargs) + + if args.experiment in ("all", "bands"): + logger.info("\n" + "=" * 60) + logger.info("EXPERIMENT 2: Frequency Band Ablation") + logger.info("=" * 60) + all_results["bands"] = ablation_frequency_bands( + dataset, **train_kwargs + ) + + if args.experiment in ("all", "connectivity"): + logger.info("\n" + "=" * 60) + logger.info("EXPERIMENT 3: Connectivity Measure Ablation") + logger.info("=" * 60) + all_results["connectivity"] = ablation_connectivity( + dataset, **train_kwargs + ) + + # Print summary + logger.info("\n" + "=" * 60) + logger.info("ABLATION RESULTS SUMMARY") + logger.info("=" * 60) + for exp_name, exp_results in all_results.items(): + logger.info("\n--- %s ---", exp_name) + for config, metrics in exp_results.items(): + logger.info(" %-30s AUC=%.4f", config, metrics["auc"]) + + # Save results to JSON + output_path = "eeg_gcnn_ablation_results.json" + with open(output_path, "w") as f: + json.dump(all_results, f, indent=2) + logger.info("\nResults saved to %s", output_path) + + +if __name__ == "__main__": + main() diff --git a/pyhealth/datasets/__init__.py b/pyhealth/datasets/__init__.py index 7ac05f259..a5e54f58c 100644 --- a/pyhealth/datasets/__init__.py +++ b/pyhealth/datasets/__init__.py @@ -77,6 +77,7 @@ def __init__(self, *args, **kwargs): split_by_visit, split_by_visit_conformal, ) +from .eeg_gcnn import EEGGCNNDataset from .tuab import TUABDataset from .tuev import TUEVDataset from .utils import ( diff --git a/pyhealth/datasets/configs/eeg_gcnn.yaml b/pyhealth/datasets/configs/eeg_gcnn.yaml new file mode 100644 index 000000000..210cb3908 --- /dev/null +++ b/pyhealth/datasets/configs/eeg_gcnn.yaml @@ -0,0 +1,20 @@ +version: "1.0.0" +tables: + tuab: + file_path: "eeg_gcnn-tuab-pyhealth.csv" + patient_id: "patient_id" + timestamp: null + attributes: + - "record_id" + - "signal_file" + - "source" + - "label" + lemon: + file_path: "eeg_gcnn-lemon-pyhealth.csv" + patient_id: "patient_id" + timestamp: null + attributes: + - "record_id" + - "signal_file" + - "source" + - "label" diff --git a/pyhealth/datasets/eeg_gcnn.py b/pyhealth/datasets/eeg_gcnn.py new file mode 100644 index 000000000..746ae422b --- /dev/null +++ b/pyhealth/datasets/eeg_gcnn.py @@ -0,0 +1,236 @@ +import logging +from pathlib import Path +from typing import Optional + +import pandas as pd + +from .base_dataset import BaseDataset + +logger = logging.getLogger(__name__) + + +class EEGGCNNDataset(BaseDataset): + """EEG-GCNN dataset pooling TUAB normal-subset and MPI LEMON controls. + + This dataset supports the EEG-GCNN paper (Wagh & Varatharajah, ML4H @ + NeurIPS 2020) which distinguishes "normal-appearing" patient EEGs (from + TUAB) from truly healthy EEGs (from MPI LEMON). + + **TUAB (normal subset):** The Temple University EEG Abnormal Corpus + provides EDF recordings labelled normal/abnormal. Only the *normal* + recordings are used here — these are the "patient" class (label 0). + + **MPI LEMON:** The Leipzig Study for Mind-Body-Emotion Interactions + provides BrainVision EEG recordings from healthy controls — these form + the "healthy" class (label 1). + + Paper: + Wagh, N. & Varatharajah, Y. (2020). EEG-GCNN: Augmenting + Electroencephalogram-based Neurological Disease Diagnosis using a + Domain-guided Graph Convolutional Neural Network. *Proceedings of + Machine Learning for Health (ML4H) at NeurIPS 2020*, PMLR 136. + https://proceedings.mlr.press/v136/wagh20a.html + + Authors' code: https://github.com/neerajwagh/eeg-gcnn + + Args: + root: Root directory containing TUAB and/or LEMON data. + Expected structure for TUAB:: + + /train/normal/01_tcp_ar//*.edf + /eval/normal/01_tcp_ar//*.edf + + Expected structure for LEMON:: + + /lemon//*.vhdr + + dataset_name: Name of the dataset. Defaults to ``"eeg_gcnn"``. + config_path: Path to the YAML config. Defaults to the built-in + ``eeg_gcnn.yaml``. + subset: Which data source(s) to load. One of ``"tuab"``, + ``"lemon"``, or ``"both"`` (default). + dev: If ``True``, limit to a small subset for quick iteration. + + Attributes: + task: Optional task name after ``set_task()`` is called. + samples: Sample list after task is set. + patient_to_index: Maps patient IDs to sample indices. + visit_to_index: Maps visit/record IDs to sample indices. + + Examples: + >>> from pyhealth.datasets import EEGGCNNDataset + >>> from pyhealth.tasks import EEGGCNNDiseaseDetection + >>> dataset = EEGGCNNDataset( + ... root="/data/eeg-gcnn/", + ... ) + >>> dataset.stats() + >>> sample_dataset = dataset.set_task(EEGGCNNDiseaseDetection()) + >>> sample = sample_dataset[0] + >>> print(sample["psd_features"].shape) # (8, 6) + """ + + def __init__( + self, + root: str, + dataset_name: Optional[str] = None, + config_path: Optional[str] = None, + subset: Optional[str] = "both", + **kwargs, + ) -> None: + if config_path is None: + config_path = ( + Path(__file__).parent / "configs" / "eeg_gcnn.yaml" + ) + + self.root = root + + if subset == "tuab": + tables = ["tuab"] + elif subset == "lemon": + tables = ["lemon"] + elif subset == "both": + tables = ["tuab", "lemon"] + else: + raise ValueError( + "subset must be one of 'tuab', 'lemon', or 'both'" + ) + + self.prepare_metadata() + + root_path = Path(root) + cache_dir = Path.home() / ".cache" / "pyhealth" / "eeg_gcnn" + + use_cache = False + for table in tables: + shared_csv = root_path / f"eeg_gcnn-{table}-pyhealth.csv" + cache_csv = cache_dir / f"eeg_gcnn-{table}-pyhealth.csv" + if not shared_csv.exists() and cache_csv.exists(): + use_cache = True + break + + if use_cache: + logger.info("Using cached metadata from %s", cache_dir) + root = str(cache_dir) + + super().__init__( + root=root, + tables=tables, + dataset_name=dataset_name or "eeg_gcnn", + config_path=config_path, + **kwargs, + ) + + def prepare_metadata(self) -> None: + """Build and save metadata CSVs for TUAB normal subset and LEMON. + + Writes: + - ``/eeg_gcnn-tuab-pyhealth.csv`` + - ``/eeg_gcnn-lemon-pyhealth.csv`` + + TUAB filenames: ``__.edf`` + LEMON filenames: ``sub-.vhdr`` (BrainVision header) + """ + root = Path(self.root) + cache_dir = Path.home() / ".cache" / "pyhealth" / "eeg_gcnn" + + # --- TUAB normal subset --- + shared_csv = root / "eeg_gcnn-tuab-pyhealth.csv" + cache_csv = cache_dir / "eeg_gcnn-tuab-pyhealth.csv" + + if not shared_csv.exists() and not cache_csv.exists(): + tuab_rows = [] + for split in ("train", "eval"): + normal_dir = root / split / "normal" / "01_tcp_ar" + if not normal_dir.is_dir(): + logger.debug( + "TUAB normal dir not found: %s", normal_dir + ) + continue + for edf in sorted(normal_dir.rglob("*.edf")): + parts = edf.stem.split("_") + patient_id = f"tuab_{parts[0]}" + record_id = parts[1] if len(parts) > 1 else "0" + tuab_rows.append( + { + "patient_id": patient_id, + "record_id": record_id, + "signal_file": str(edf), + "source": "tuab", + "label": 0, + } + ) + + if tuab_rows: + df = pd.DataFrame(tuab_rows) + df.sort_values( + ["patient_id", "record_id"], + inplace=True, + na_position="last", + ) + df.reset_index(drop=True, inplace=True) + self._write_csv(df, shared_csv, cache_dir, "tuab") + + # --- LEMON healthy controls --- + shared_csv = root / "eeg_gcnn-lemon-pyhealth.csv" + cache_csv = cache_dir / "eeg_gcnn-lemon-pyhealth.csv" + + if not shared_csv.exists() and not cache_csv.exists(): + lemon_rows = [] + lemon_dir = root / "lemon" + if lemon_dir.is_dir(): + for subject_dir in sorted(lemon_dir.iterdir()): + if not subject_dir.is_dir(): + continue + for vhdr in sorted(subject_dir.glob("*.vhdr")): + patient_id = f"lemon_{subject_dir.name}" + record_id = vhdr.stem + lemon_rows.append( + { + "patient_id": patient_id, + "record_id": record_id, + "signal_file": str(vhdr), + "source": "lemon", + "label": 1, + } + ) + + if lemon_rows: + df = pd.DataFrame(lemon_rows) + df.sort_values( + ["patient_id", "record_id"], + inplace=True, + na_position="last", + ) + df.reset_index(drop=True, inplace=True) + self._write_csv(df, shared_csv, cache_dir, "lemon") + + @staticmethod + def _write_csv( + df: "pd.DataFrame", + shared_path: Path, + cache_dir: Path, + table_name: str, + ) -> None: + """Write CSV to shared location, falling back to cache.""" + try: + shared_path.parent.mkdir(parents=True, exist_ok=True) + df.to_csv(shared_path, index=False) + logger.info("Wrote %s metadata to %s", table_name, shared_path) + except (PermissionError, OSError): + cache_dir.mkdir(parents=True, exist_ok=True) + cache_path = cache_dir / shared_path.name + df.to_csv(cache_path, index=False) + logger.info( + "Wrote %s metadata to cache: %s", table_name, cache_path + ) + + @property + def default_task(self): + """Returns the default task for the EEG-GCNN dataset. + + Returns: + EEGGCNNDiseaseDetection: The default task instance. + """ + from pyhealth.tasks import EEGGCNNDiseaseDetection + + return EEGGCNNDiseaseDetection() diff --git a/pyhealth/tasks/__init__.py b/pyhealth/tasks/__init__.py index 2f4294a19..f2bea8645 100644 --- a/pyhealth/tasks/__init__.py +++ b/pyhealth/tasks/__init__.py @@ -23,6 +23,7 @@ ) from .EEG_abnormal import EEG_isAbnormal_fn from .EEG_events import EEG_events_fn +from .eeg_gcnn_nd_detection import EEGGCNNDiseaseDetection from .in_hospital_mortality_mimic4 import InHospitalMortalityMIMIC4 from .length_of_stay_prediction import ( LengthOfStayPredictioneICU, diff --git a/pyhealth/tasks/eeg_gcnn_nd_detection.py b/pyhealth/tasks/eeg_gcnn_nd_detection.py new file mode 100644 index 000000000..e463cff77 --- /dev/null +++ b/pyhealth/tasks/eeg_gcnn_nd_detection.py @@ -0,0 +1,359 @@ +"""EEG-GCNN neurological disease detection task. + +Implements the preprocessing and feature-extraction pipeline from: + + Wagh, N. & Varatharajah, Y. (2020). EEG-GCNN: Augmenting + Electroencephalogram-based Neurological Disease Diagnosis using a + Domain-guided Graph Convolutional Neural Network. ML4H @ NeurIPS 2020. + https://proceedings.mlr.press/v136/wagh20a.html +""" + +import logging +from typing import Any, Dict, List, Optional, Tuple + +import mne +import numpy as np +import torch + +from pyhealth.tasks.base_task import BaseTask + +logger = logging.getLogger(__name__) + +# --------------------------------------------------------------------------- +# Constants from the EEG-GCNN paper (Section 4) +# --------------------------------------------------------------------------- + +# 8 bipolar channels used in the paper. +# Each entry is (anode_ref_name, cathode_ref_name, canonical_name). +BIPOLAR_CHANNELS: List[Tuple[str, str, str]] = [ + ("EEG F7-REF", "EEG F3-REF", "F7-F3"), + ("EEG F8-REF", "EEG F4-REF", "F8-F4"), + ("EEG T3-REF", "EEG C3-REF", "T7-C3"), + ("EEG T4-REF", "EEG C4-REF", "T8-C4"), + ("EEG T5-REF", "EEG P3-REF", "P7-P3"), + ("EEG T6-REF", "EEG P4-REF", "P8-P4"), + ("EEG O1-REF", "EEG P3-REF", "O1-P3"), + ("EEG O2-REF", "EEG P4-REF", "O2-P4"), +] + +# 3D MNI coordinates for the 8 bipolar channel mid-points (approximate). +# Used to build the spatial adjacency matrix. +CHANNEL_POSITIONS_MNI: np.ndarray = np.array( + [ + [-0.054, 0.044, 0.038], # F7-F3 + [0.054, 0.044, 0.038], # F8-F4 + [-0.069, -0.014, 0.034], # T7-C3 + [0.069, -0.014, 0.034], # T8-C4 + [-0.059, -0.067, 0.034], # P7-P3 + [0.059, -0.067, 0.034], # P8-P4 + [-0.037, -0.094, 0.020], # O1-P3 + [0.037, -0.094, 0.020], # O2-P4 + ], + dtype=np.float64, +) + +# Frequency bands (Hz) from the paper. +DEFAULT_BANDS: Dict[str, Tuple[float, float]] = { + "delta": (0.5, 4.0), + "theta": (4.0, 8.0), + "alpha": (8.0, 12.0), + "lower_beta": (12.0, 20.0), + "higher_beta": (20.0, 30.0), + "gamma": (30.0, 50.0), +} + +NUM_CHANNELS = 8 + + +class EEGGCNNDiseaseDetection(BaseTask): + """Binary classification: patient-normal (0) vs healthy-control (1). + + For each EEG recording the task: + + 1. Reads the EDF / BrainVision file via MNE. + 2. Resamples to ``resample_rate`` Hz, applies a 1 Hz high-pass and a + ``notch_freq`` Hz notch filter. + 3. Computes the 8 bipolar channels defined in the paper. + 4. Segments the signal into non-overlapping ``window_sec``-second windows. + 5. For each window, extracts PSD band-power features (8 channels x + ``len(bands)`` bands) via Welch's method. + 6. Computes an 8x8 graph adjacency matrix (spatial, functional, combined, + or identity) to accompany each sample. + + Args: + resample_rate: Target sampling rate in Hz. Default ``250``. + highpass_freq: High-pass filter cutoff in Hz. Default ``1.0``. + notch_freq: Notch filter frequency in Hz. Default ``50.0``. + window_sec: Window length in seconds. Default ``10``. + bands: Frequency bands to extract. Defaults to all 6 paper bands. + Pass a subset dict to run a band-ablation experiment. + adjacency_type: One of ``"combined"`` (default), ``"spatial"``, + ``"functional"``, or ``"none"`` (identity matrix). + connectivity_measure: ``"coherence"`` (default) or ``"wpli"``. + + Examples: + >>> from pyhealth.datasets import EEGGCNNDataset + >>> from pyhealth.tasks import EEGGCNNDiseaseDetection + >>> dataset = EEGGCNNDataset(root="/data/eeg-gcnn/") + >>> task = EEGGCNNDiseaseDetection(adjacency_type="spatial") + >>> sample_dataset = dataset.set_task(task) + >>> sample = sample_dataset[0] + >>> print(sample["psd_features"].shape) # (8, 6) + >>> print(sample["adjacency"].shape) # (8, 8) + """ + + task_name: str = "eeg_gcnn_nd_detection" + input_schema: Dict[str, str] = { + "psd_features": "tensor", + "adjacency": "tensor", + } + output_schema: Dict[str, str] = {"label": "binary"} + + def __init__( + self, + resample_rate: int = 250, + highpass_freq: float = 1.0, + notch_freq: float = 50.0, + window_sec: int = 10, + bands: Optional[Dict[str, Tuple[float, float]]] = None, + adjacency_type: str = "combined", + connectivity_measure: str = "coherence", + ) -> None: + super().__init__() + self.resample_rate = resample_rate + self.highpass_freq = highpass_freq + self.notch_freq = notch_freq + self.window_sec = window_sec + self.bands = bands if bands is not None else dict(DEFAULT_BANDS) + if adjacency_type not in ("combined", "spatial", "functional", "none"): + raise ValueError( + f"adjacency_type must be 'combined', 'spatial', " + f"'functional', or 'none', got '{adjacency_type}'" + ) + self.adjacency_type = adjacency_type + if connectivity_measure not in ("coherence", "wpli"): + raise ValueError( + f"connectivity_measure must be 'coherence' or 'wpli', " + f"got '{connectivity_measure}'" + ) + self.connectivity_measure = connectivity_measure + + # Pre-compute spatial adjacency (constant across recordings). + self._spatial_adj = self._build_spatial_adjacency() + + # ------------------------------------------------------------------ + # Signal I/O and preprocessing + # ------------------------------------------------------------------ + + def _read_eeg(self, filepath: str) -> mne.io.BaseRaw: + """Read an EEG file (EDF or BrainVision) into MNE Raw.""" + filepath_lower = filepath.lower() + if filepath_lower.endswith(".edf"): + raw = mne.io.read_raw_edf(filepath, preload=True, verbose="error") + elif filepath_lower.endswith(".vhdr"): + raw = mne.io.read_raw_brainvision( + filepath, preload=True, verbose="error" + ) + else: + raise ValueError(f"Unsupported EEG format: {filepath}") + return raw + + def _preprocess(self, raw: mne.io.BaseRaw) -> mne.io.BaseRaw: + """Resample, high-pass, and notch-filter.""" + raw.filter( + l_freq=self.highpass_freq, + h_freq=None, + verbose="error", + ) + raw.notch_filter(self.notch_freq, verbose="error") + raw.resample(self.resample_rate, verbose="error") + return raw + + @staticmethod + def _compute_bipolar(raw: mne.io.BaseRaw) -> np.ndarray: + """Compute the 8 bipolar channels from reference montage. + + Returns: + np.ndarray of shape ``(8, n_samples)``. + """ + data = raw.get_data() + ch_map = { + name: idx for idx, name in enumerate(raw.ch_names) + } + bipolar = np.zeros((NUM_CHANNELS, data.shape[1])) + for i, (anode, cathode, _) in enumerate(BIPOLAR_CHANNELS): + bipolar[i] = data[ch_map[anode]] - data[ch_map[cathode]] + return bipolar + + # ------------------------------------------------------------------ + # Feature extraction + # ------------------------------------------------------------------ + + def _extract_psd_features( + self, window: np.ndarray + ) -> np.ndarray: + """Extract PSD band-power features for one window. + + Args: + window: shape ``(8, n_samples)``. + + Returns: + np.ndarray of shape ``(8, n_bands)`` — log10 average PSD per + channel per band via Welch's method. + """ + from scipy.signal import welch + + fs = self.resample_rate + band_list = list(self.bands.values()) + n_bands = len(band_list) + features = np.zeros((NUM_CHANNELS, n_bands)) + + freqs, pxx = welch(window, fs=fs, nperseg=min(fs * 2, window.shape[1])) + + for b_idx, (fmin, fmax) in enumerate(band_list): + band_mask = (freqs >= fmin) & (freqs < fmax) + if band_mask.any(): + features[:, b_idx] = np.log10( + pxx[:, band_mask].mean(axis=1) + 1e-10 + ) + + return features + + # ------------------------------------------------------------------ + # Adjacency matrices + # ------------------------------------------------------------------ + + @staticmethod + def _build_spatial_adjacency() -> np.ndarray: + """Build spatial adjacency from inverse Euclidean distance. + + Returns: + np.ndarray of shape ``(8, 8)`` with self-loops set to 1. + """ + pos = CHANNEL_POSITIONS_MNI + n = pos.shape[0] + dist = np.zeros((n, n)) + for i in range(n): + for j in range(n): + dist[i, j] = np.linalg.norm(pos[i] - pos[j]) + # Inverse distance (avoid div-by-zero on diagonal) + with np.errstate(divide="ignore"): + adj = np.where(dist > 0, 1.0 / dist, 0.0) + # Row-normalise + row_sum = adj.sum(axis=1, keepdims=True) + adj = adj / (row_sum + 1e-10) + # Self-loops + np.fill_diagonal(adj, 1.0) + return adj + + def _build_functional_adjacency( + self, window: np.ndarray + ) -> np.ndarray: + """Build functional adjacency from inter-channel connectivity. + + Args: + window: shape ``(8, n_samples)``. + + Returns: + np.ndarray of shape ``(8, 8)``. + """ + from mne_connectivity import spectral_connectivity_epochs + + fs = self.resample_rate + # spectral_connectivity_epochs expects (n_epochs, n_channels, n_times) + data_3d = window[np.newaxis, :, :] + + conn = spectral_connectivity_epochs( + data_3d, + method=self.connectivity_measure, + sfreq=fs, + fmin=0.5, + fmax=50.0, + verbose="error", + ) + conn_data = conn.get_data(output="dense") + # Average across frequencies → (8, 8) + adj = conn_data.mean(axis=-1) + adj = np.abs(adj) + # Symmetrise and set diagonal to 1 + adj = (adj + adj.T) / 2.0 + np.fill_diagonal(adj, 1.0) + return adj + + def _build_adjacency(self, window: np.ndarray) -> np.ndarray: + """Build the adjacency matrix according to ``adjacency_type``. + + Args: + window: shape ``(8, n_samples)``. + + Returns: + np.ndarray of shape ``(8, 8)``. + """ + if self.adjacency_type == "none": + return np.eye(NUM_CHANNELS) + elif self.adjacency_type == "spatial": + return self._spatial_adj.copy() + elif self.adjacency_type == "functional": + return self._build_functional_adjacency(window) + else: # combined + spatial = self._spatial_adj + functional = self._build_functional_adjacency(window) + combined = (spatial + functional) / 2.0 + np.fill_diagonal(combined, 1.0) + return combined + + # ------------------------------------------------------------------ + # Main entry point + # ------------------------------------------------------------------ + + def __call__(self, patient: Any) -> List[Dict[str, Any]]: + """Process one patient and return a list of sample dicts. + + Each sample contains: + - ``patient_id``: str + - ``psd_features``: torch.FloatTensor of shape ``(8, n_bands)`` + - ``adjacency``: torch.FloatTensor of shape ``(8, 8)`` + - ``label``: int (0 = patient-normal, 1 = healthy-control) + """ + pid = patient.patient_id + samples: List[Dict[str, Any]] = [] + fs = self.resample_rate + win_samples = int(self.window_sec * fs) + + for table in ("tuab", "lemon"): + events = patient.get_events(table) + for event in events: + filepath = event.signal_file + label = int(event.label) + + try: + raw = self._read_eeg(filepath) + raw = self._preprocess(raw) + bipolar = self._compute_bipolar(raw) + except (ValueError, KeyError, RuntimeError) as exc: + logger.warning( + "Skipping %s for patient %s: %s", + filepath, pid, exc, + ) + continue + + n_windows = bipolar.shape[1] // win_samples + for w in range(n_windows): + start = w * win_samples + end = start + win_samples + window = bipolar[:, start:end] + + psd_feat = self._extract_psd_features(window) + adj = self._build_adjacency(window) + + samples.append( + { + "patient_id": pid, + "signal_file": filepath, + "psd_features": torch.FloatTensor(psd_feat), + "adjacency": torch.FloatTensor(adj), + "label": label, + } + ) + + return samples diff --git a/tests/test_eeg_gcnn.py b/tests/test_eeg_gcnn.py new file mode 100644 index 000000000..f39559278 --- /dev/null +++ b/tests/test_eeg_gcnn.py @@ -0,0 +1,372 @@ +"""Tests for the EEG-GCNN dataset and task classes. + +All tests use synthetic data — no real EEG files are required. +Tests complete in milliseconds. +""" + +import os +import tempfile +from pathlib import Path +from unittest.mock import MagicMock, patch + +import numpy as np +import pytest +import torch + +from pyhealth.tasks.eeg_gcnn_nd_detection import ( + BIPOLAR_CHANNELS, + DEFAULT_BANDS, + NUM_CHANNELS, + EEGGCNNDiseaseDetection, +) + + +# --------------------------------------------------------------- +# Fixtures +# --------------------------------------------------------------- + +@pytest.fixture +def task(): + """Default task instance.""" + return EEGGCNNDiseaseDetection() + + +@pytest.fixture +def task_spatial(): + """Task with spatial-only adjacency.""" + return EEGGCNNDiseaseDetection(adjacency_type="spatial") + + +@pytest.fixture +def task_none(): + """Task with identity adjacency.""" + return EEGGCNNDiseaseDetection(adjacency_type="none") + + +@pytest.fixture +def synthetic_bipolar_window(): + """Synthetic 10-second bipolar window at 250 Hz (8, 2500).""" + rng = np.random.RandomState(42) + return rng.randn(NUM_CHANNELS, 2500) + + +@pytest.fixture +def synthetic_bipolar_signal(): + """Synthetic 30-second bipolar signal at 250 Hz (8, 7500).""" + rng = np.random.RandomState(42) + return rng.randn(NUM_CHANNELS, 7500) + + +def _make_mock_raw(n_seconds=30, sfreq=250): + """Create a mock MNE Raw object with all required TUAB channels.""" + n_samples = int(n_seconds * sfreq) + rng = np.random.RandomState(0) + + ch_names_needed = set() + for anode, cathode, _ in BIPOLAR_CHANNELS: + ch_names_needed.add(anode) + ch_names_needed.add(cathode) + ch_names = sorted(ch_names_needed) + n_ch = len(ch_names) + + data = rng.randn(n_ch, n_samples) * 1e-5 # realistic EEG scale + + raw = MagicMock() + raw.ch_names = ch_names + raw.get_data.return_value = data + raw.filter.return_value = raw + raw.notch_filter.return_value = raw + raw.resample.return_value = raw + return raw + + +# --------------------------------------------------------------- +# Task initialization tests +# --------------------------------------------------------------- + +class TestTaskInit: + + def test_default_params(self, task): + assert task.resample_rate == 250 + assert task.highpass_freq == 1.0 + assert task.notch_freq == 50.0 + assert task.window_sec == 10 + assert task.adjacency_type == "combined" + assert task.connectivity_measure == "coherence" + assert len(task.bands) == 6 + + def test_custom_bands(self): + custom = {"alpha": (8.0, 12.0), "theta": (4.0, 8.0)} + task = EEGGCNNDiseaseDetection(bands=custom) + assert len(task.bands) == 2 + + def test_invalid_adjacency_type(self): + with pytest.raises(ValueError, match="adjacency_type"): + EEGGCNNDiseaseDetection(adjacency_type="invalid") + + def test_invalid_connectivity(self): + with pytest.raises(ValueError, match="connectivity_measure"): + EEGGCNNDiseaseDetection(connectivity_measure="plv") + + def test_task_schemas(self, task): + assert "psd_features" in task.input_schema + assert "adjacency" in task.input_schema + assert "label" in task.output_schema + assert task.output_schema["label"] == "binary" + assert task.task_name == "eeg_gcnn_nd_detection" + + +# --------------------------------------------------------------- +# PSD feature extraction tests +# --------------------------------------------------------------- + +class TestPSDExtraction: + + def test_shape(self, task, synthetic_bipolar_window): + psd = task._extract_psd_features(synthetic_bipolar_window) + assert psd.shape == (NUM_CHANNELS, len(DEFAULT_BANDS)) + + def test_finite_values(self, task, synthetic_bipolar_window): + psd = task._extract_psd_features(synthetic_bipolar_window) + assert np.all(np.isfinite(psd)) + + def test_custom_bands_shape(self, synthetic_bipolar_window): + custom = {"alpha": (8.0, 12.0)} + task = EEGGCNNDiseaseDetection(bands=custom) + psd = task._extract_psd_features(synthetic_bipolar_window) + assert psd.shape == (NUM_CHANNELS, 1) + + +# --------------------------------------------------------------- +# Adjacency matrix tests +# --------------------------------------------------------------- + +class TestAdjacency: + + def test_spatial_shape(self, task): + adj = task._build_spatial_adjacency() + assert adj.shape == (NUM_CHANNELS, NUM_CHANNELS) + + def test_spatial_diagonal(self, task): + adj = task._build_spatial_adjacency() + np.testing.assert_array_equal(np.diag(adj), np.ones(NUM_CHANNELS)) + + def test_spatial_positive(self, task): + adj = task._build_spatial_adjacency() + assert np.all(adj >= 0) + + def test_none_is_identity(self, task_none, synthetic_bipolar_window): + adj = task_none._build_adjacency(synthetic_bipolar_window) + np.testing.assert_array_equal(adj, np.eye(NUM_CHANNELS)) + + def test_spatial_adjacency_type( + self, task_spatial, synthetic_bipolar_window + ): + adj = task_spatial._build_adjacency(synthetic_bipolar_window) + assert adj.shape == (NUM_CHANNELS, NUM_CHANNELS) + # Should NOT be identity — off-diagonal elements > 0 + assert np.any(adj[0, 1:] > 0) + + +# --------------------------------------------------------------- +# Bipolar channel computation tests +# --------------------------------------------------------------- + +class TestBipolarComputation: + + def test_compute_bipolar(self): + raw = _make_mock_raw(n_seconds=10, sfreq=250) + bipolar = EEGGCNNDiseaseDetection._compute_bipolar(raw) + assert bipolar.shape == (NUM_CHANNELS, 2500) + + def test_bipolar_is_difference(self): + raw = _make_mock_raw(n_seconds=1, sfreq=250) + data = raw.get_data() + ch_map = {name: idx for idx, name in enumerate(raw.ch_names)} + bipolar = EEGGCNNDiseaseDetection._compute_bipolar(raw) + + anode, cathode, _ = BIPOLAR_CHANNELS[0] + expected = data[ch_map[anode]] - data[ch_map[cathode]] + np.testing.assert_array_almost_equal(bipolar[0], expected) + + +# --------------------------------------------------------------- +# End-to-end __call__ test with mocked I/O +# --------------------------------------------------------------- + +class TestTaskCall: + + def test_call_produces_samples(self): + """Verify __call__ produces correct samples with mocked EEG I/O.""" + task = EEGGCNNDiseaseDetection( + adjacency_type="none", + window_sec=10, + ) + + mock_raw = _make_mock_raw(n_seconds=30, sfreq=250) + + # Build a mock patient with one TUAB event + event = MagicMock() + event.signal_file = "/fake/path.edf" + event.label = 0 + + patient = MagicMock() + patient.patient_id = "test_001" + patient.get_events.side_effect = ( + lambda table: [event] if table == "tuab" else [] + ) + + with patch.object(task, "_read_eeg", return_value=mock_raw), \ + patch.object(task, "_preprocess", return_value=mock_raw): + samples = task(patient) + + # 30s / 10s = 3 windows + assert len(samples) == 3 + for s in samples: + assert s["patient_id"] == "test_001" + assert isinstance(s["psd_features"], torch.Tensor) + assert s["psd_features"].shape == (8, 6) + assert isinstance(s["adjacency"], torch.Tensor) + assert s["adjacency"].shape == (8, 8) + assert s["label"] == 0 + + def test_call_skips_bad_file(self): + """Verify __call__ gracefully skips unreadable files.""" + task = EEGGCNNDiseaseDetection(adjacency_type="none") + + event = MagicMock() + event.signal_file = "/bad/path.edf" + event.label = 0 + + patient = MagicMock() + patient.patient_id = "test_002" + patient.get_events.side_effect = ( + lambda table: [event] if table == "tuab" else [] + ) + + with patch.object( + task, "_read_eeg", side_effect=ValueError("corrupt file") + ): + samples = task(patient) + + assert len(samples) == 0 + + def test_call_both_sources(self): + """Verify samples from both TUAB and LEMON are collected.""" + task = EEGGCNNDiseaseDetection( + adjacency_type="none", window_sec=10 + ) + + mock_raw = _make_mock_raw(n_seconds=10, sfreq=250) + + tuab_event = MagicMock() + tuab_event.signal_file = "/fake/tuab.edf" + tuab_event.label = 0 + + lemon_event = MagicMock() + lemon_event.signal_file = "/fake/lemon.vhdr" + lemon_event.label = 1 + + patient = MagicMock() + patient.patient_id = "test_003" + + def get_events(table): + if table == "tuab": + return [tuab_event] + elif table == "lemon": + return [lemon_event] + return [] + + patient.get_events.side_effect = get_events + + with patch.object(task, "_read_eeg", return_value=mock_raw), \ + patch.object(task, "_preprocess", return_value=mock_raw): + samples = task(patient) + + assert len(samples) == 2 + labels = {s["label"] for s in samples} + assert labels == {0, 1} + + +# --------------------------------------------------------------- +# Dataset metadata tests +# --------------------------------------------------------------- + +class TestDatasetMetadata: + + def test_prepare_tuab_csv(self): + """Verify TUAB CSV is generated from synthetic directory structure.""" + from pyhealth.datasets.eeg_gcnn import EEGGCNNDataset + + with tempfile.TemporaryDirectory() as tmpdir: + # Create synthetic TUAB directory structure + subj_dir = ( + Path(tmpdir) + / "train" + / "normal" + / "01_tcp_ar" + / "000" + / "00000001" + ) + subj_dir.mkdir(parents=True) + (subj_dir / "00000001_00000001_01.edf").touch() + (subj_dir / "00000001_00000002_01.edf").touch() + + # Instantiate dataset just to trigger prepare_metadata + # We can't fully instantiate BaseDataset without CSV content, + # so we test prepare_metadata directly. + ds = EEGGCNNDataset.__new__(EEGGCNNDataset) + ds.root = tmpdir + ds.prepare_metadata() + + csv_path = Path(tmpdir) / "eeg_gcnn-tuab-pyhealth.csv" + assert csv_path.exists() + + import pandas as pd + + df = pd.read_csv(csv_path) + assert len(df) == 2 + assert "patient_id" in df.columns + assert "signal_file" in df.columns + assert "label" in df.columns + assert all(df["label"] == 0) + assert all(df["source"] == "tuab") + + def test_prepare_lemon_csv(self): + """Verify LEMON CSV is generated from synthetic directory structure.""" + from pyhealth.datasets.eeg_gcnn import EEGGCNNDataset + + with tempfile.TemporaryDirectory() as tmpdir: + subj_dir = Path(tmpdir) / "lemon" / "sub-010002" + subj_dir.mkdir(parents=True) + (subj_dir / "sub-010002.vhdr").touch() + + ds = EEGGCNNDataset.__new__(EEGGCNNDataset) + ds.root = tmpdir + ds.prepare_metadata() + + csv_path = Path(tmpdir) / "eeg_gcnn-lemon-pyhealth.csv" + assert csv_path.exists() + + import pandas as pd + + df = pd.read_csv(csv_path) + assert len(df) == 1 + assert df.iloc[0]["label"] == 1 + assert df.iloc[0]["source"] == "lemon" + + +# --------------------------------------------------------------- +# Constants / module-level tests +# --------------------------------------------------------------- + +class TestConstants: + + def test_bipolar_channel_count(self): + assert len(BIPOLAR_CHANNELS) == 8 + + def test_default_bands_count(self): + assert len(DEFAULT_BANDS) == 6 + + def test_num_channels(self): + assert NUM_CHANNELS == 8 From 4ce853eff41f1b2b9323c3f0ad0fac6a96f26e4d Mon Sep 17 00:00:00 2001 From: Your Name Date: Wed, 1 Apr 2026 12:47:21 +0200 Subject: [PATCH 2/3] fix: map connectivity_measure to mne_connectivity method names "coherence" must be passed as "coh" to spectral_connectivity_epochs. Co-Authored-By: Claude Opus 4.6 --- pyhealth/tasks/eeg_gcnn_nd_detection.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/pyhealth/tasks/eeg_gcnn_nd_detection.py b/pyhealth/tasks/eeg_gcnn_nd_detection.py index e463cff77..26db50b38 100644 --- a/pyhealth/tasks/eeg_gcnn_nd_detection.py +++ b/pyhealth/tasks/eeg_gcnn_nd_detection.py @@ -263,9 +263,10 @@ def _build_functional_adjacency( # spectral_connectivity_epochs expects (n_epochs, n_channels, n_times) data_3d = window[np.newaxis, :, :] + method_map = {"coherence": "coh", "wpli": "wpli"} conn = spectral_connectivity_epochs( data_3d, - method=self.connectivity_measure, + method=method_map[self.connectivity_measure], sfreq=fs, fmin=0.5, fmax=50.0, From c13efd656317248ba774ed8d40ec8c215cff0c14 Mon Sep 17 00:00:00 2001 From: Your Name Date: Wed, 1 Apr 2026 13:02:46 +0200 Subject: [PATCH 3/3] feat: add --demo mode to ablation script for synthetic data runs The ablation script can now run without real TUAB/LEMON data using `--demo`. Generates 40 synthetic patients with reproducible random PSD features and adjacency matrices, runs all 3 experiments through the full GCN training pipeline. Also fixes a bug where roc_auc_score failed on single-column y_prob output from binary classification. Co-Authored-By: Claude Opus 4.6 --- examples/eeg_gcnn_nd_detection_gcn.py | 126 +++++++++++++++++++++++--- 1 file changed, 115 insertions(+), 11 deletions(-) diff --git a/examples/eeg_gcnn_nd_detection_gcn.py b/examples/eeg_gcnn_nd_detection_gcn.py index 23fa27d09..f6113d8c0 100644 --- a/examples/eeg_gcnn_nd_detection_gcn.py +++ b/examples/eeg_gcnn_nd_detection_gcn.py @@ -12,13 +12,17 @@ 2. Frequency band ablation: individual bands & progressive combinations 3. Connectivity measure: coherence vs WPLI -Usage: - python examples/eeg_gcnn_nd_detection_gcn.py --root /path/to/eeg-gcnn-data +Usage (with real data): + python examples/eeg_gcnn_nd_detection_gcn.py --root /path/to/data + +Usage (demo mode — synthetic data, no downloads needed): + python examples/eeg_gcnn_nd_detection_gcn.py --demo """ import argparse import json import logging +import warnings from collections import defaultdict from typing import Dict, List, Tuple @@ -27,15 +31,88 @@ from sklearn.metrics import roc_auc_score from torch.utils.data import DataLoader -from pyhealth.datasets import EEGGCNNDataset, get_dataloader, split_by_patient +from pyhealth.datasets import ( + EEGGCNNDataset, + create_sample_dataset, + get_dataloader, + split_by_patient, +) from pyhealth.models.gnn import GCN from pyhealth.tasks import EEGGCNNDiseaseDetection -from pyhealth.tasks.eeg_gcnn_nd_detection import DEFAULT_BANDS +from pyhealth.tasks.eeg_gcnn_nd_detection import ( + DEFAULT_BANDS, + NUM_CHANNELS, +) logging.basicConfig(level=logging.INFO) logger = logging.getLogger(__name__) +# ------------------------------------------------------------------- +# Synthetic demo data +# ------------------------------------------------------------------- + +def generate_demo_samples( + n_patients: int = 40, + windows_per_patient: int = 5, + n_bands: int = 6, + seed: int = 42, +) -> List[Dict]: + """Generate synthetic samples that mirror real task output. + + Creates ``n_patients`` patients (half label-0, half label-1) each + with ``windows_per_patient`` 10-second windows. PSD features and + adjacency matrices are random but reproducible. + + Args: + n_patients: Total number of synthetic patients. + windows_per_patient: Windows (samples) per patient. + n_bands: Number of frequency bands in PSD features. + seed: Random seed for reproducibility. + + Returns: + List of sample dicts with keys ``patient_id``, + ``psd_features``, ``adjacency``, and ``label``. + """ + rng = np.random.RandomState(seed) + samples = [] + for p in range(n_patients): + pid = f"demo_{p:03d}" + label = 0 if p < n_patients // 2 else 1 + for _ in range(windows_per_patient): + psd = rng.randn(NUM_CHANNELS, n_bands).astype(np.float32) + # Shift class-1 features slightly so the model can learn + if label == 1: + psd += 0.5 + adj = np.eye(NUM_CHANNELS, dtype=np.float32) + off = rng.uniform(0.1, 0.5, (NUM_CHANNELS, NUM_CHANNELS)) + off = (off + off.T) / 2.0 + np.fill_diagonal(off, 0.0) + adj = adj + off.astype(np.float32) + samples.append( + { + "patient_id": pid, + "psd_features": torch.FloatTensor(psd), + "adjacency": torch.FloatTensor(adj), + "label": label, + } + ) + return samples + + +def build_demo_dataset(task: EEGGCNNDiseaseDetection): + """Wrap synthetic samples in a SampleDataset compatible with GCN.""" + n_bands = len(task.bands) + samples = generate_demo_samples(n_bands=n_bands) + return create_sample_dataset( + samples=samples, + input_schema=task.input_schema, + output_schema=task.output_schema, + dataset_name="eeg_gcnn_demo", + task_name=task.task_name, + ) + + # ------------------------------------------------------------------- # Training / evaluation helpers # ------------------------------------------------------------------- @@ -88,8 +165,15 @@ def evaluate( all_probs = np.concatenate(all_probs) all_labels = np.concatenate(all_labels) try: - auc = roc_auc_score(all_labels, all_probs[:, 1]) - except ValueError: + # Binary output may be shape (N,1) or (N,2) + if all_probs.ndim == 2 and all_probs.shape[1] >= 2: + scores = all_probs[:, 1] + else: + scores = all_probs.ravel() + with warnings.catch_warnings(): + warnings.simplefilter("ignore") + auc = float(roc_auc_score(all_labels, scores)) + except (ValueError, IndexError): auc = 0.5 return { "auc": auc, @@ -103,11 +187,15 @@ def run_experiment( epochs: int = 20, batch_size: int = 32, lr: float = 1e-3, + demo: bool = False, ) -> Dict[str, float]: """Run a single experiment: set_task, split, train, evaluate.""" device = torch.device("cuda" if torch.cuda.is_available() else "cpu") - sample_dataset = dataset.set_task(task) + if demo: + sample_dataset = build_demo_dataset(task) + else: + sample_dataset = dataset.set_task(task) train_ds, val_ds, test_ds = split_by_patient( sample_dataset, [0.7, 0.1, 0.2] ) @@ -180,7 +268,7 @@ def ablation_frequency_bands( ) results[name] = run_experiment(dataset, task, **kwargs) - # Progressive combinations: delta, delta+theta, delta+theta+alpha, ... + # Progressive combinations for k in range(2, len(band_names) + 1): combo_names = band_names[:k] combo_key = "+".join(combo_names) @@ -216,9 +304,13 @@ def main(): description="EEG-GCNN ablation study" ) parser.add_argument( - "--root", type=str, required=True, + "--root", type=str, default=None, help="Root directory of EEG-GCNN data (TUAB + LEMON)", ) + parser.add_argument( + "--demo", action="store_true", + help="Run with synthetic data (no downloads needed)", + ) parser.add_argument("--epochs", type=int, default=20) parser.add_argument("--batch_size", type=int, default=32) parser.add_argument("--lr", type=float, default=1e-3) @@ -231,13 +323,25 @@ def main(): ) args = parser.parse_args() - dataset = EEGGCNNDataset(root=args.root) - dataset.stats() + if not args.demo and args.root is None: + parser.error("--root is required unless --demo is set") + + dataset = None + if not args.demo: + dataset = EEGGCNNDataset(root=args.root) + dataset.stats() + + if args.demo: + logger.info( + "Running in DEMO mode with synthetic data " + "(results are illustrative, not meaningful)" + ) train_kwargs = dict( epochs=args.epochs, batch_size=args.batch_size, lr=args.lr, + demo=args.demo, ) all_results = {}