diff --git a/docs/api/datasets.rst b/docs/api/datasets.rst index b02439d26..165dd50a9 100644 --- a/docs/api/datasets.rst +++ b/docs/api/datasets.rst @@ -230,6 +230,7 @@ Available Datasets datasets/pyhealth.datasets.ISRUCDataset datasets/pyhealth.datasets.MIMICExtractDataset datasets/pyhealth.datasets.OMOPDataset + datasets/pyhealth.datasets.EEGGCNNDataset datasets/pyhealth.datasets.DREAMTDataset datasets/pyhealth.datasets.SHHSDataset datasets/pyhealth.datasets.SleepEDFDataset diff --git a/docs/api/datasets/pyhealth.datasets.EEGGCNNDataset.rst b/docs/api/datasets/pyhealth.datasets.EEGGCNNDataset.rst new file mode 100644 index 000000000..512816f7a --- /dev/null +++ b/docs/api/datasets/pyhealth.datasets.EEGGCNNDataset.rst @@ -0,0 +1,12 @@ +pyhealth.datasets.EEGGCNNDataset +=================================== + +Dataset for the EEG-GCNN paper (Wagh & Varatharajah, ML4H @ NeurIPS 2020). +Pools the TUAB normal-subset (patient EEGs) and MPI LEMON (healthy controls). + +Paper: https://proceedings.mlr.press/v136/wagh20a.html + +.. autoclass:: pyhealth.datasets.EEGGCNNDataset + :members: + :undoc-members: + :show-inheritance: diff --git a/docs/api/tasks.rst b/docs/api/tasks.rst index 399b8f1aa..852d6b6d0 100644 --- a/docs/api/tasks.rst +++ b/docs/api/tasks.rst @@ -212,6 +212,9 @@ Available Tasks COVID-19 CXR Classification DKA Prediction (MIMIC-IV) Drug Recommendation + EEG-GCNN Disease Detection + EEG Abnormal + EEG Events Length of Stay Prediction Medical Transcriptions Classification Mortality Prediction (Next Visit) diff --git a/docs/api/tasks/pyhealth.tasks.eeg_gcnn_nd_detection.rst b/docs/api/tasks/pyhealth.tasks.eeg_gcnn_nd_detection.rst new file mode 100644 index 000000000..de85c8213 --- /dev/null +++ b/docs/api/tasks/pyhealth.tasks.eeg_gcnn_nd_detection.rst @@ -0,0 +1,16 @@ +pyhealth.tasks.eeg_gcnn_nd_detection +========================================= + +Neurological disease detection task from the EEG-GCNN paper (Wagh & Varatharajah, ML4H @ NeurIPS 2020). 
+ +Binary classification: patient-normal (TUAB) vs healthy-control (LEMON). + +The task extracts PSD band-power features and graph adjacency matrices from EEG recordings, +supporting configurable adjacency types, frequency bands, and connectivity measures for ablation studies. + +Paper: https://proceedings.mlr.press/v136/wagh20a.html + +.. autoclass:: pyhealth.tasks.eeg_gcnn_nd_detection.EEGGCNNDiseaseDetection + :members: + :undoc-members: + :show-inheritance: diff --git a/examples/eeg_gcnn_combined_pipeline.py b/examples/eeg_gcnn_combined_pipeline.py new file mode 100644 index 000000000..9590a515e --- /dev/null +++ b/examples/eeg_gcnn_combined_pipeline.py @@ -0,0 +1,660 @@ +"""Combined EEG-GCNN Pipeline — Dataset/Task + Model Integration. + +Self-contained end-to-end pipeline combining both team contributions: + - Dataset & Task (Option 1 — jburhan): + EEG preprocessing, PSD feature extraction, graph adjacency matrices + - Model (Option 2 — racoffey): + EEGGraphConvNet (GCN, paper baseline) and EEGGATConvNet (GAT, novel) + +This script runs standalone — no PyHealth installation required for demo +mode. For real-data mode, PyHealth with the EEG-GCNN contribution must +be installed (pip install -e . from the PyHealth fork). + +Paper: + Wagh, N. & Varatharajah, Y. (2020). EEG-GCNN: Augmenting + Electroencephalogram-based Neurological Disease Diagnosis using a + Domain-guided Graph Convolutional Neural Network. ML4H @ NeurIPS 2020. 
+ https://proceedings.mlr.press/v136/wagh20a.html + +Usage (demo mode — synthetic data, no downloads needed): + python eeg_gcnn_combined_pipeline.py --demo + +Usage (demo with GAT model): + python eeg_gcnn_combined_pipeline.py --demo --model gat + +Usage (demo, single experiment): + python eeg_gcnn_combined_pipeline.py --demo --experiment model + +Usage (with real EEG data — requires PyHealth with EEG-GCNN contribution): + python eeg_gcnn_combined_pipeline.py --root /path/to/eeg-gcnn-data + +Dependencies (demo mode): + torch, torch_geometric, numpy, scikit-learn +""" + +import argparse +import json +import logging +import warnings +from collections import defaultdict +from typing import Any, Dict, List, Optional, Tuple + +import numpy as np +import torch +import torch.nn as nn +import torch.nn.functional as F +from sklearn.metrics import roc_auc_score +from torch.utils.data import DataLoader, Dataset + +from torch_geometric.nn import ( + BatchNorm, + GATConv, + GCNConv, + global_add_pool, +) +from torch_geometric.utils import dense_to_sparse + +logging.basicConfig(level=logging.INFO) +logger = logging.getLogger(__name__) + + +# =================================================================== +# Constants (from EEG-GCNN paper, Section 4) +# =================================================================== + +NUM_CHANNELS = 8 +NUM_NODES = 8 + +DEFAULT_BANDS: Dict[str, Tuple[float, float]] = { + "delta": (0.5, 4.0), + "theta": (4.0, 8.0), + "alpha": (8.0, 12.0), + "lower_beta": (12.0, 20.0), + "higher_beta": (20.0, 30.0), + "gamma": (30.0, 50.0), +} + + +# =================================================================== +# Models (Option 2 contribution — racoffey) +# =================================================================== + + +class EEGGraphConvNet(nn.Module): + """Deep EEG-GCNN model using GCNConv layers. + + Follows the architecture from the paper: 4 GCN layers with + leaky ReLU, global add pooling, and a 3-layer FC classifier. 
+ + Args: + num_node_features: Number of PSD bands per node. Default: 6. + output_size: Number of output classes. Default: 1 (binary). + """ + + def __init__(self, num_node_features: int = 6, output_size: int = 1): + super().__init__() + + self.conv1 = GCNConv(num_node_features, 16, improved=True, + cached=False, normalize=False) + self.conv2 = GCNConv(16, 32, improved=True, cached=False, + normalize=False) + self.conv3 = GCNConv(32, 64, improved=True, cached=False, + normalize=False) + self.conv4 = GCNConv(64, 50, improved=True, cached=False, + normalize=False) + self.conv4_bn = BatchNorm(50) + + self.fc1 = nn.Linear(50, 30) + self.fc2 = nn.Linear(30, 20) + self.fc3 = nn.Linear(20, output_size) + + for fc in [self.fc1, self.fc2, self.fc3]: + nn.init.xavier_normal_(fc.weight, gain=1) + + def forward(self, x, edge_index, edge_weight, batch): + x = F.leaky_relu(self.conv1(x, edge_index, edge_weight)) + x = F.leaky_relu(self.conv2(x, edge_index, edge_weight)) + x = F.leaky_relu(self.conv3(x, edge_index, edge_weight)) + x = F.leaky_relu(self.conv4_bn( + self.conv4(x, edge_index, edge_weight))) + out = global_add_pool(x, batch=batch) + out = F.leaky_relu(self.fc1(out), negative_slope=0.01) + out = F.dropout(out, p=0.2, training=self.training) + out = F.leaky_relu(self.fc2(out), negative_slope=0.01) + return self.fc3(out) + + +class EEGGATConvNet(nn.Module): + """EEG Graph Attention Network — novel model contribution. + + Replaces GCNConv with multi-head GATConv layers to learn + attention-weighted edge importance, rather than relying on + fixed adjacency weights. + + Args: + num_node_features: Number of PSD bands per node. Default: 6. + output_size: Number of output classes. Default: 1 (binary). 
+ """ + + def __init__(self, num_node_features: int = 6, output_size: int = 1): + super().__init__() + + self.conv1 = GATConv(num_node_features, 16, heads=4, concat=True, + negative_slope=0.2, dropout=0.0, + add_self_loops=False, edge_dim=1) + self.conv2 = GATConv(64, 32, heads=4, concat=True, + negative_slope=0.2, dropout=0.0, + add_self_loops=False, edge_dim=1) + self.conv3 = GATConv(128, 16, heads=4, concat=True, + negative_slope=0.2, dropout=0.0, + add_self_loops=False, edge_dim=1) + self.conv4 = GATConv(64, 50, heads=1, concat=False, + negative_slope=0.2, dropout=0.0, + add_self_loops=False, edge_dim=1) + self.conv4_bn = BatchNorm(50) + + self.fc1 = nn.Linear(50, 30) + self.fc2 = nn.Linear(30, 20) + self.fc3 = nn.Linear(20, output_size) + + for fc in [self.fc1, self.fc2, self.fc3]: + nn.init.xavier_normal_(fc.weight, gain=1) + + def forward(self, x, edge_index, edge_weight, batch): + edge_attr = edge_weight.unsqueeze(-1) if edge_weight.dim() == 1 \ + else edge_weight + x = F.leaky_relu(self.conv1(x, edge_index, edge_attr=edge_attr)) + x = F.leaky_relu(self.conv2(x, edge_index, edge_attr=edge_attr)) + x = F.leaky_relu(self.conv3(x, edge_index, edge_attr=edge_attr)) + x = F.leaky_relu(self.conv4_bn( + self.conv4(x, edge_index, edge_attr=edge_attr))) + out = global_add_pool(x, batch=batch) + out = F.leaky_relu(self.fc1(out), negative_slope=0.01) + out = F.dropout(out, p=0.2, training=self.training) + out = F.leaky_relu(self.fc2(out), negative_slope=0.01) + return self.fc3(out) + + +# =================================================================== +# Graph batching helper +# =================================================================== + +def build_graph_batch(node_features_batch, adj_matrix_batch, device): + """Convert batched dense tensors into a single PyG-style graph batch. 
+ + Args: + node_features_batch: (B, 8, n_bands) tensor + adj_matrix_batch: (B, 8, 8) tensor + device: torch device + + Returns: + (x, edge_index, edge_weight, batch) tensors on device + """ + all_edge_index = [] + all_edge_weight = [] + all_x = [] + batch_ids = [] + + offset = 0 + for i in range(node_features_batch.shape[0]): + x = node_features_batch[i].float() + adj = adj_matrix_batch[i].float() + ei, ew = dense_to_sparse(adj) + all_edge_index.append(ei + offset) + all_edge_weight.append(ew) + all_x.append(x) + batch_ids.extend([i] * x.shape[0]) + offset += x.shape[0] + + return ( + torch.cat(all_x, dim=0).to(device), + torch.cat(all_edge_index, dim=1).to(device), + torch.cat(all_edge_weight).to(device), + torch.tensor(batch_ids, dtype=torch.long, device=device), + ) + + +# =================================================================== +# Lightweight dataset utilities (standalone, no PyHealth needed) +# =================================================================== + +class DictDataset(Dataset): + """Map-style dataset wrapping a list of sample dicts.""" + + def __init__(self, samples: List[Dict]): + self._samples = samples + + def __len__(self): + return len(self._samples) + + def __getitem__(self, idx): + return self._samples[idx] + + +def collate_dict_batch(batch: List[Dict]) -> Dict: + """Collate a list of sample dicts into a batched dict. + + Stacks tensors, collects strings into lists, and stacks ints/floats. + """ + keys = batch[0].keys() + collated = {} + for k in keys: + vals = [s[k] for s in batch] + if isinstance(vals[0], torch.Tensor): + collated[k] = torch.stack(vals) + elif isinstance(vals[0], (int, float)): + collated[k] = torch.tensor(vals, dtype=torch.float32) + else: + collated[k] = vals + return collated + + +def split_by_patient( + samples: List[Dict], ratios: List[float], seed: int = 42, +) -> Tuple[List[Dict], List[Dict], List[Dict]]: + """Split samples into train/val/test by patient_id. 
+ + Ensures all windows from the same patient stay in the same split. + + Args: + samples: List of sample dicts with "patient_id" key. + ratios: [train, val, test] fractions summing to 1.0. + seed: Random seed for reproducibility. + + Returns: + (train_samples, val_samples, test_samples) + """ + rng = np.random.RandomState(seed) + patient_ids = sorted(set(s["patient_id"] for s in samples)) + rng.shuffle(patient_ids) + + n = len(patient_ids) + n_train = int(n * ratios[0]) + n_val = int(n * ratios[1]) + + train_pids = set(patient_ids[:n_train]) + val_pids = set(patient_ids[n_train:n_train + n_val]) + test_pids = set(patient_ids[n_train + n_val:]) + + train = [s for s in samples if s["patient_id"] in train_pids] + val = [s for s in samples if s["patient_id"] in val_pids] + test = [s for s in samples if s["patient_id"] in test_pids] + return train, val, test + + +def make_loader( + samples: List[Dict], batch_size: int, shuffle: bool = False, +) -> DataLoader: + """Create a DataLoader from a list of sample dicts.""" + return DataLoader( + DictDataset(samples), + batch_size=batch_size, + shuffle=shuffle, + collate_fn=collate_dict_batch, + ) + + +# =================================================================== +# Synthetic demo data (Option 1 — same schema as real task output) +# =================================================================== + +def generate_demo_samples( + n_patients: int = 40, + windows_per_patient: int = 5, + n_bands: int = 6, + adjacency_type: str = "combined", + seed: int = 42, +) -> List[Dict]: + """Generate synthetic samples matching real task output schema. + + Creates ``n_patients`` patients (half label-0, half label-1) each + with ``windows_per_patient`` windows. PSD features and adjacency + matrices are random but reproducible. 
+ + The output schema matches what the real EEGGCNNDiseaseDetection task + produces: + - node_features: (8, n_bands) float32 tensor + - adj_matrix: (8, 8) float32 tensor + - label: int (0 = patient-normal, 1 = healthy-control) + """ + rng = np.random.RandomState(seed) + samples = [] + for p in range(n_patients): + pid = f"demo_{p:03d}" + label = 0 if p < n_patients // 2 else 1 + for _ in range(windows_per_patient): + psd = rng.randn(NUM_CHANNELS, n_bands).astype(np.float32) + # Shift class-1 features so the model can learn separation + if label == 1: + psd += 0.5 + + # Build adjacency based on type + if adjacency_type == "none": + adj = np.eye(NUM_CHANNELS, dtype=np.float32) + else: + adj = np.eye(NUM_CHANNELS, dtype=np.float32) + off = rng.uniform(0.1, 0.5, (NUM_CHANNELS, NUM_CHANNELS)) + off = (off + off.T) / 2.0 + np.fill_diagonal(off, 0.0) + adj = adj + off.astype(np.float32) + + samples.append({ + "patient_id": pid, + "node_features": torch.FloatTensor(psd), + "adj_matrix": torch.FloatTensor(adj), + "label": label, + }) + return samples + + +# =================================================================== +# Training / evaluation +# =================================================================== + +def train_one_epoch(model, loader, optimizer, loss_fn, device): + """Train for one epoch, return average loss.""" + model.train() + total_loss = 0.0 + n_batches = 0 + for batch in loader: + node_features = batch["node_features"].to(device) + adj_matrix = batch["adj_matrix"].to(device) + labels = batch["label"].to(device).float() + + x, edge_index, edge_weight, batch_ids = build_graph_batch( + node_features, adj_matrix, device + ) + + optimizer.zero_grad() + logits = model(x, edge_index, edge_weight, batch_ids) + loss = loss_fn(logits.squeeze(-1), labels.squeeze(-1)) + loss.backward() + optimizer.step() + total_loss += loss.item() + n_batches += 1 + return total_loss / max(n_batches, 1) + + +@torch.no_grad() +def evaluate(model, loader, loss_fn, device): + 
"""Evaluate model, return AUC and loss.""" + model.eval() + all_probs, all_labels = [], [] + total_loss = 0.0 + n_batches = 0 + for batch in loader: + node_features = batch["node_features"].to(device) + adj_matrix = batch["adj_matrix"].to(device) + labels = batch["label"].to(device).float() + + x, edge_index, edge_weight, batch_ids = build_graph_batch( + node_features, adj_matrix, device + ) + + logits = model(x, edge_index, edge_weight, batch_ids) + loss = loss_fn(logits.squeeze(-1), labels.squeeze(-1)) + total_loss += loss.item() + n_batches += 1 + + probs = torch.sigmoid(logits.squeeze(-1)).cpu().numpy() + all_probs.append(probs) + all_labels.append(labels.squeeze(-1).cpu().numpy()) + + all_probs = np.concatenate(all_probs) + all_labels = np.concatenate(all_labels) + try: + with warnings.catch_warnings(): + warnings.simplefilter("ignore") + auc = float(roc_auc_score(all_labels, all_probs)) + except (ValueError, IndexError): + auc = 0.5 + return { + "auc": auc, + "loss": total_loss / max(n_batches, 1), + } + + +def run_experiment( + task_config: Dict[str, Any], + model_type: str = "gcn", + epochs: int = 20, + batch_size: int = 32, + lr: float = 1e-3, + demo: bool = True, + dataset: Any = None, +) -> Dict[str, float]: + """Run a single experiment: generate data, split, train, evaluate. + + Args: + task_config: Dict with keys like adjacency_type, bands, + connectivity_measure — mirrors EEGGCNNDiseaseDetection init args. + model_type: "gcn" or "gat". + epochs: Training epochs. + batch_size: Batch size. + lr: Learning rate. + demo: If True, use synthetic data. If False, use real data. + dataset: PyHealth EEGGCNNDataset instance (only needed if demo=False). 
+ """ + device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + + bands = task_config.get("bands", DEFAULT_BANDS) + n_bands = len(bands) + adjacency_type = task_config.get("adjacency_type", "combined") + + if demo: + samples = generate_demo_samples( + n_bands=n_bands, adjacency_type=adjacency_type, + ) + else: + # Real data path — requires PyHealth with EEG-GCNN contribution + from pyhealth.tasks import EEGGCNNDiseaseDetection + task = EEGGCNNDiseaseDetection(**task_config) + sample_dataset = dataset.set_task(task) + samples = [sample_dataset[i] for i in range(len(sample_dataset))] + + train_samples, val_samples, test_samples = split_by_patient( + samples, [0.7, 0.1, 0.2] + ) + + train_loader = make_loader(train_samples, batch_size, shuffle=True) + val_loader = make_loader(val_samples, batch_size) + test_loader = make_loader(test_samples, batch_size) + + if model_type == "gat": + model = EEGGATConvNet( + num_node_features=n_bands, output_size=1 + ).to(device) + else: + model = EEGGraphConvNet( + num_node_features=n_bands, output_size=1 + ).to(device) + + loss_fn = nn.BCEWithLogitsLoss() + optimizer = torch.optim.Adam(model.parameters(), lr=lr) + + best_val_auc = 0.0 + best_state = None + for epoch in range(epochs): + train_loss = train_one_epoch( + model, train_loader, optimizer, loss_fn, device) + val_metrics = evaluate(model, val_loader, loss_fn, device) + logger.info( + "Epoch %d/%d — train_loss=%.4f val_auc=%.4f", + epoch + 1, epochs, train_loss, val_metrics["auc"], + ) + if val_metrics["auc"] > best_val_auc: + best_val_auc = val_metrics["auc"] + best_state = { + k: v.cpu().clone() for k, v in model.state_dict().items() + } + + if best_state is not None: + model.load_state_dict(best_state) + test_metrics = evaluate(model, test_loader, loss_fn, device) + logger.info("Test AUC: %.4f", test_metrics["auc"]) + return test_metrics + + +# =================================================================== +# Ablation experiments +# 
=================================================================== + +def ablation_adjacency(**kwargs) -> Dict[str, Dict[str, float]]: + """Experiment 1: Adjacency type ablation.""" + results = {} + for adj_type in ("combined", "spatial", "functional", "none"): + logger.info("=== Adjacency type: %s ===", adj_type) + results[adj_type] = run_experiment( + task_config={"adjacency_type": adj_type}, **kwargs) + return results + + +def ablation_frequency_bands(**kwargs) -> Dict[str, Dict[str, float]]: + """Experiment 2: Frequency band ablation. + + Tests individual bands and progressive combinations. + """ + band_names = list(DEFAULT_BANDS.keys()) + results = {} + + # Individual bands + for name in band_names: + logger.info("=== Single band: %s ===", name) + results[name] = run_experiment( + task_config={"bands": {name: DEFAULT_BANDS[name]}}, **kwargs) + + # Progressive combinations + for k in range(2, len(band_names) + 1): + combo_names = band_names[:k] + combo_key = "+".join(combo_names) + logger.info("=== Band combination: %s ===", combo_key) + combo_bands = {n: DEFAULT_BANDS[n] for n in combo_names} + results[combo_key] = run_experiment( + task_config={"bands": combo_bands}, **kwargs) + + return results + + +def ablation_connectivity(**kwargs) -> Dict[str, Dict[str, float]]: + """Experiment 3: Connectivity measure ablation.""" + results = {} + for measure in ("coherence", "wpli"): + logger.info("=== Connectivity: %s ===", measure) + results[measure] = run_experiment( + task_config={ + "adjacency_type": "functional", + "connectivity_measure": measure, + }, **kwargs) + return results + + +def ablation_model_comparison(**kwargs) -> Dict[str, Dict[str, float]]: + """Experiment 4: Model comparison — GCN vs GAT (novel contribution).""" + results = {} + for model_type in ("gcn", "gat"): + logger.info("=== Model: %s ===", model_type.upper()) + results[model_type] = run_experiment( + task_config={}, model_type=model_type, **kwargs) + return results + + +# 
=================================================================== +# Main +# =================================================================== + +def main(): + parser = argparse.ArgumentParser( + description="EEG-GCNN Combined Pipeline — Dataset/Task + Model" + ) + parser.add_argument( + "--root", type=str, default=None, + help="Root directory of EEG-GCNN data (TUAB + LEMON)", + ) + parser.add_argument( + "--demo", action="store_true", + help="Run with synthetic data (no downloads needed)", + ) + parser.add_argument( + "--model", type=str, default="gcn", choices=["gcn", "gat"], + help="Model architecture: gcn (paper default) or gat (novel)", + ) + parser.add_argument("--epochs", type=int, default=20) + parser.add_argument("--batch_size", type=int, default=32) + parser.add_argument("--lr", type=float, default=1e-3) + parser.add_argument( + "--experiment", type=str, default="all", + choices=["all", "adjacency", "bands", "connectivity", "model"], + help="Which ablation experiment to run", + ) + args = parser.parse_args() + + if not args.demo and args.root is None: + parser.error("--root is required unless --demo is set") + + dataset = None + if not args.demo: + from pyhealth.datasets import EEGGCNNDataset + dataset = EEGGCNNDataset(root=args.root) + dataset.stats() + + if args.demo: + logger.info( + "Running in DEMO mode with synthetic data " + "(results are illustrative, not meaningful)" + ) + + run_kwargs = dict( + epochs=args.epochs, + batch_size=args.batch_size, + lr=args.lr, + model_type=args.model, + demo=args.demo, + dataset=dataset, + ) + + all_results = {} + + if args.experiment in ("all", "adjacency"): + logger.info("\n" + "=" * 60) + logger.info("EXPERIMENT 1: Adjacency Type Ablation") + logger.info("=" * 60) + all_results["adjacency"] = ablation_adjacency(**run_kwargs) + + if args.experiment in ("all", "bands"): + logger.info("\n" + "=" * 60) + logger.info("EXPERIMENT 2: Frequency Band Ablation") + logger.info("=" * 60) + all_results["bands"] = 
ablation_frequency_bands(**run_kwargs) + + if args.experiment in ("all", "connectivity"): + logger.info("\n" + "=" * 60) + logger.info("EXPERIMENT 3: Connectivity Measure Ablation") + logger.info("=" * 60) + all_results["connectivity"] = ablation_connectivity(**run_kwargs) + + if args.experiment in ("all", "model"): + logger.info("\n" + "=" * 60) + logger.info("EXPERIMENT 4: Model Comparison (GCN vs GAT)") + logger.info("=" * 60) + # Model comparison tests both; remove model_type from kwargs + kw = dict(run_kwargs) + kw.pop("model_type") + all_results["model"] = ablation_model_comparison(**kw) + + # Print summary + logger.info("\n" + "=" * 60) + logger.info("ABLATION RESULTS SUMMARY") + logger.info("=" * 60) + for exp_name, exp_results in all_results.items(): + logger.info("\n--- %s ---", exp_name) + for config, metrics in exp_results.items(): + logger.info(" %-30s AUC=%.4f", config, metrics["auc"]) + + output_path = "eeg_gcnn_combined_results.json" + with open(output_path, "w") as f: + json.dump(all_results, f, indent=2) + logger.info("\nResults saved to %s", output_path) + + +if __name__ == "__main__": + main() diff --git a/examples/eeg_gcnn_nd_detection_gcn.py b/examples/eeg_gcnn_nd_detection_gcn.py new file mode 100644 index 000000000..12c8deb4e --- /dev/null +++ b/examples/eeg_gcnn_nd_detection_gcn.py @@ -0,0 +1,445 @@ +"""EEG-GCNN Neurological Disease Detection — Ablation Study. + +This script reproduces the ablation experiments from: + + Wagh, N. & Varatharajah, Y. (2020). EEG-GCNN: Augmenting + Electroencephalogram-based Neurological Disease Diagnosis using a + Domain-guided Graph Convolutional Neural Network. ML4H @ NeurIPS 2020. + https://proceedings.mlr.press/v136/wagh20a.html + +Three ablation experiments are included: + 1. Adjacency type: combined vs spatial vs functional vs none + 2. Frequency band ablation: individual bands & progressive combinations + 3. 
Connectivity measure: coherence vs WPLI + +Usage (with real data): + python examples/eeg_gcnn_nd_detection_gcn.py --root /path/to/data + +Usage (demo mode — synthetic data, no downloads needed): + python examples/eeg_gcnn_nd_detection_gcn.py --demo +""" + +import argparse +import json +import logging +import warnings +from collections import defaultdict +from typing import Dict, List, Tuple + +import numpy as np +import torch +from sklearn.metrics import ( + balanced_accuracy_score, + f1_score, + precision_score, + recall_score, + roc_auc_score, + roc_curve, +) +from torch.utils.data import DataLoader + +from pyhealth.datasets import ( + EEGGCNNDataset, + create_sample_dataset, + get_dataloader, + split_by_patient, +) +from pyhealth.models.gnn import GCN +from pyhealth.tasks import EEGGCNNDiseaseDetection +from pyhealth.tasks.eeg_gcnn_nd_detection import ( + DEFAULT_BANDS, + NUM_CHANNELS, +) + +logging.basicConfig(level=logging.INFO) +logger = logging.getLogger(__name__) + + +# ------------------------------------------------------------------- +# Synthetic demo data +# ------------------------------------------------------------------- + +def generate_demo_samples( + n_patients: int = 40, + windows_per_patient: int = 5, + n_bands: int = 6, + seed: int = 42, +) -> List[Dict]: + """Generate synthetic samples that mirror real task output. + + Creates ``n_patients`` patients (half label-0, half label-1) each + with ``windows_per_patient`` 10-second windows. PSD features and + adjacency matrices are random but reproducible. + + Args: + n_patients: Total number of synthetic patients. + windows_per_patient: Windows (samples) per patient. + n_bands: Number of frequency bands in PSD features. + seed: Random seed for reproducibility. + + Returns: + List of sample dicts with keys ``patient_id``, + ``node_features``, ``adj_matrix``, and ``label``. 
+ """ + rng = np.random.RandomState(seed) + samples = [] + for p in range(n_patients): + pid = f"demo_{p:03d}" + label = 0 if p < n_patients // 2 else 1 + for _ in range(windows_per_patient): + psd = rng.randn(NUM_CHANNELS, n_bands).astype(np.float32) + # Shift class-1 features slightly so the model can learn + if label == 1: + psd += 0.5 + adj = np.eye(NUM_CHANNELS, dtype=np.float32) + off = rng.uniform(0.1, 0.5, (NUM_CHANNELS, NUM_CHANNELS)) + off = (off + off.T) / 2.0 + np.fill_diagonal(off, 0.0) + adj = adj + off.astype(np.float32) + samples.append( + { + "patient_id": pid, + "node_features": torch.FloatTensor(psd), + "adj_matrix": torch.FloatTensor(adj), + "label": label, + } + ) + return samples + + +def build_demo_dataset(task: EEGGCNNDiseaseDetection): + """Wrap synthetic samples in a SampleDataset compatible with GCN.""" + n_bands = len(task.bands) + samples = generate_demo_samples(n_bands=n_bands) + return create_sample_dataset( + samples=samples, + input_schema=task.input_schema, + output_schema=task.output_schema, + dataset_name="eeg_gcnn_demo", + task_name=task.task_name, + ) + + +# ------------------------------------------------------------------- +# Training / evaluation helpers +# ------------------------------------------------------------------- + +def train_one_epoch( + model: torch.nn.Module, + loader: DataLoader, + optimizer: torch.optim.Optimizer, + device: torch.device, +) -> float: + """Train for one epoch, return average loss.""" + model.train() + total_loss = 0.0 + n_batches = 0 + for batch in loader: + batch = {k: v.to(device) if isinstance(v, torch.Tensor) else v + for k, v in batch.items()} + optimizer.zero_grad() + output = model(**batch) + loss = output["loss"] + loss.backward() + optimizer.step() + total_loss += loss.item() + n_batches += 1 + return total_loss / max(n_batches, 1) + + +@torch.no_grad() +def evaluate( + model: torch.nn.Module, + loader: DataLoader, + device: torch.device, +) -> Dict[str, float]: + """Evaluate the 
model; return AUC, Precision, Recall, F1, Balanced Acc. + + Threshold is chosen via Youden's J statistic (maximises sensitivity + + specificity), matching the evaluation protocol in the paper (Section 6). + Precision, Recall, F1, and Balanced Accuracy are computed per sample + (window), not per subject, with scikit-learn's default positive label (1); + compare with Table 2 of Wagh & Varatharajah (2020) accordingly. + """ + model.eval() + all_probs, all_labels = [], [] + total_loss = 0.0 + n_batches = 0 + for batch in loader: + batch = {k: v.to(device) if isinstance(v, torch.Tensor) else v + for k, v in batch.items()} + output = model(**batch) + total_loss += output["loss"].item() + n_batches += 1 + probs = output["y_prob"].cpu().numpy() + labels = output["y_true"].cpu().numpy() + all_probs.append(probs) + all_labels.append(labels) + + all_probs = np.concatenate(all_probs) + all_labels = np.concatenate(all_labels) + + # Binary output may be shape (N,1) or (N,2) + if all_probs.ndim == 2 and all_probs.shape[1] >= 2: + scores = all_probs[:, 1] + else: + scores = all_probs.ravel() + + fallback = { + "auc": 0.5, "precision": 0.0, "recall": 0.0, + "f1": 0.0, "balanced_accuracy": 0.5, "threshold": 0.5, + "loss": total_loss / max(n_batches, 1), + } + try: + with warnings.catch_warnings(): + warnings.simplefilter("ignore") + auc = float(roc_auc_score(all_labels, scores)) + + # Youden's J: threshold that maximises (sensitivity + specificity - 1) + fpr, tpr, thresholds = roc_curve(all_labels, scores) + j_idx = int(np.argmax(tpr - fpr)) + threshold = float(thresholds[j_idx]) + + preds = (scores >= threshold).astype(int) + precision = float(precision_score(all_labels, preds, zero_division=0)) + recall = float(recall_score(all_labels, preds, zero_division=0)) + f1 = float(f1_score(all_labels, preds, zero_division=0)) + bal_acc = float(balanced_accuracy_score(all_labels, preds)) + except (ValueError, IndexError): + return fallback + + return { + "auc": auc, + "precision": precision, + "recall": recall, + "f1": 
f1, + "balanced_accuracy": bal_acc, + "threshold": threshold, + "loss": total_loss / max(n_batches, 1), + } + + +def run_experiment( + dataset: EEGGCNNDataset, + task: EEGGCNNDiseaseDetection, + epochs: int = 20, + batch_size: int = 32, + lr: float = 1e-3, + demo: bool = False, +) -> Dict[str, float]: + """Run a single experiment: set_task, split, train, evaluate.""" + device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + + if demo: + sample_dataset = build_demo_dataset(task) + else: + sample_dataset = dataset.set_task(task) + train_ds, val_ds, test_ds = split_by_patient( + sample_dataset, [0.7, 0.1, 0.2] + ) + + train_loader = get_dataloader(train_ds, batch_size=batch_size, shuffle=True) + val_loader = get_dataloader(val_ds, batch_size=batch_size, shuffle=False) + test_loader = get_dataloader(test_ds, batch_size=batch_size, shuffle=False) + + model = GCN( + dataset=sample_dataset, + embedding_dim=64, + nhid=32, + dropout=0.5, + num_layers=2, + ).to(device) + + optimizer = torch.optim.Adam(model.parameters(), lr=lr) + + best_val_auc = 0.0 + best_state = None + for epoch in range(epochs): + train_loss = train_one_epoch(model, train_loader, optimizer, device) + val_metrics = evaluate(model, val_loader, device) + logger.info( + "Epoch %d/%d — train_loss=%.4f val_auc=%.4f", + epoch + 1, epochs, train_loss, val_metrics["auc"], + ) + if val_metrics["auc"] > best_val_auc: + best_val_auc = val_metrics["auc"] + best_state = {k: v.cpu().clone() for k, v in model.state_dict().items()} + + if best_state is not None: + model.load_state_dict(best_state) + test_metrics = evaluate(model, test_loader, device) + logger.info( + "Test — AUC=%.4f Prec=%.4f Rec=%.4f F1=%.4f BalAcc=%.4f", + test_metrics["auc"], + test_metrics.get("precision", float("nan")), + test_metrics.get("recall", float("nan")), + test_metrics.get("f1", float("nan")), + test_metrics.get("balanced_accuracy", float("nan")), + ) + return test_metrics + + +# 
def ablation_adjacency(
    dataset: EEGGCNNDataset, **kwargs
) -> Dict[str, Dict[str, float]]:
    """Experiment 1: compare the adjacency-matrix variants.

    Runs one experiment per adjacency type (``combined``, ``spatial``,
    ``functional``, ``none``) and returns the test metrics keyed by type.
    ``kwargs`` are forwarded unchanged to ``run_experiment``.
    """
    adjacency_types = ("combined", "spatial", "functional", "none")
    results = {}
    for kind in adjacency_types:
        logger.info("=== Adjacency type: %s ===", kind)
        task = EEGGCNNDiseaseDetection(adjacency_type=kind)
        results[kind] = run_experiment(dataset, task, **kwargs)
    return results


def ablation_frequency_bands(
    dataset: EEGGCNNDataset, **kwargs
) -> Dict[str, Dict[str, float]]:
    """Experiment 2: frequency band ablation.

    First every band in ``DEFAULT_BANDS`` is evaluated on its own, then
    growing prefixes of the band list (first two bands, first three, ...)
    are evaluated together. Prefix results are keyed by the band names
    joined with ``+``. ``kwargs`` are forwarded to ``run_experiment``.
    """
    ordered_bands = list(DEFAULT_BANDS.items())
    results = {}

    # Each band in isolation.
    for band, limits in ordered_bands:
        logger.info("=== Single band: %s ===", band)
        task = EEGGCNNDiseaseDetection(bands={band: limits})
        results[band] = run_experiment(dataset, task, **kwargs)

    # Cumulative prefixes of the band list, starting with two bands.
    for size in range(2, len(ordered_bands) + 1):
        prefix = dict(ordered_bands[:size])
        label = "+".join(prefix)
        logger.info("=== Band combination: %s ===", label)
        task = EEGGCNNDiseaseDetection(bands=prefix)
        results[label] = run_experiment(dataset, task, **kwargs)

    return results


def ablation_connectivity(
    dataset: EEGGCNNDataset, **kwargs
) -> Dict[str, Dict[str, float]]:
    """Experiment 3: compare connectivity measures.

    Uses the purely functional adjacency so that the connectivity measure
    (``coherence`` vs ``wpli``) is the only thing that varies. ``kwargs``
    are forwarded to ``run_experiment``.
    """
    measures = ("coherence", "wpli")
    results = {}
    for name in measures:
        logger.info("=== Connectivity: %s ===", name)
        results[name] = run_experiment(
            dataset,
            EEGGCNNDiseaseDetection(
                adjacency_type="functional",
                connectivity_measure=name,
            ),
            **kwargs,
        )
    return results
def main():
    """Command-line entry point for the EEG-GCNN ablation study.

    Parses the CLI options, loads the real dataset (or announces demo
    mode), runs the selected ablation experiment(s), logs a summary table
    and writes all metrics to ``eeg_gcnn_ablation_results.json`` in the
    current working directory.
    """
    parser = argparse.ArgumentParser(
        description="EEG-GCNN ablation study"
    )
    parser.add_argument(
        "--root", type=str, default=None,
        help="Root directory of EEG-GCNN data (TUAB + LEMON)",
    )
    parser.add_argument(
        "--demo", action="store_true",
        help="Run with synthetic data (no downloads needed)",
    )
    parser.add_argument("--epochs", type=int, default=20)
    parser.add_argument("--batch_size", type=int, default=32)
    parser.add_argument("--lr", type=float, default=1e-3)
    parser.add_argument(
        "--experiment",
        type=str,
        default="all",
        choices=["all", "adjacency", "bands", "connectivity"],
        help="Which ablation experiment to run",
    )
    args = parser.parse_args()

    if args.root is None and not args.demo:
        parser.error("--root is required unless --demo is set")

    if args.demo:
        dataset = None
        logger.info(
            "Running in DEMO mode with synthetic data "
            "(results are illustrative, not meaningful)"
        )
    else:
        dataset = EEGGCNNDataset(root=args.root)
        dataset.stats()

    train_kwargs = {
        "epochs": args.epochs,
        "batch_size": args.batch_size,
        "lr": args.lr,
        "demo": args.demo,
    }

    # (CLI choice / result key, banner text, runner) in execution order.
    experiments = (
        ("adjacency", "EXPERIMENT 1: Adjacency Type Ablation",
         ablation_adjacency),
        ("bands", "EXPERIMENT 2: Frequency Band Ablation",
         ablation_frequency_bands),
        ("connectivity", "EXPERIMENT 3: Connectivity Measure Ablation",
         ablation_connectivity),
    )
    rule = "=" * 60

    all_results = {}
    for key, banner, runner in experiments:
        if args.experiment not in ("all", key):
            continue
        logger.info("\n" + rule)
        logger.info(banner)
        logger.info(rule)
        all_results[key] = runner(dataset, **train_kwargs)

    # Print summary (mirrors Table 2 of Wagh & Varatharajah, 2020)
    logger.info("\n" + rule)
    logger.info("ABLATION RESULTS SUMMARY")
    logger.info(
        " %-30s %6s %6s %6s %6s %8s",
        "Config", "AUC", "Prec", "Recall", "F1", "BalAcc",
    )
    logger.info(rule)
    metric_keys = ("auc", "precision", "recall", "f1", "balanced_accuracy")
    for exp_name, exp_results in all_results.items():
        logger.info("\n--- %s ---", exp_name)
        for config, metrics in exp_results.items():
            row = [metrics.get(key, float("nan")) for key in metric_keys]
            logger.info(
                " %-30s %.4f %.4f %.4f %.4f %.4f", config, *row
            )

    # Save results to JSON
    output_path = "eeg_gcnn_ablation_results.json"
    with open(output_path, "w") as f:
        json.dump(all_results, f, indent=2)
    logger.info("\nResults saved to %s", output_path)
class EEGGCNNDataset(BaseDataset):
    """EEG-GCNN dataset pooling TUAB normal-subset and MPI LEMON controls.

    This dataset supports the EEG-GCNN paper (Wagh & Varatharajah, ML4H @
    NeurIPS 2020) which distinguishes "normal-appearing" patient EEGs (from
    TUAB) from truly healthy EEGs (from MPI LEMON).

    **TUAB (normal subset):** The Temple University EEG Abnormal Corpus
    provides EDF recordings labelled normal/abnormal. Only the *normal*
    recordings are used here — these are the "patient" class (label 0).

    **MPI LEMON:** The Leipzig Study for Mind-Body-Emotion Interactions
    provides BrainVision EEG recordings from healthy controls — these form
    the "healthy" class (label 1).

    Paper:
        Wagh, N. & Varatharajah, Y. (2020). EEG-GCNN: Augmenting
        Electroencephalogram-based Neurological Disease Diagnosis using a
        Domain-guided Graph Convolutional Neural Network. *Proceedings of
        Machine Learning for Health (ML4H) at NeurIPS 2020*, PMLR 136.
        https://proceedings.mlr.press/v136/wagh20a.html

    Authors' code: https://github.com/neerajwagh/eeg-gcnn

    Args:
        root: Root directory containing TUAB and/or LEMON data.
            TUAB recordings are searched recursively under::

                <root>/train/normal/01_tcp_ar/**/*.edf
                <root>/eval/normal/01_tcp_ar/**/*.edf

            LEMON recordings are searched under::

                <root>/lemon/<subject>/*.vhdr

        dataset_name: Name of the dataset. Defaults to ``"eeg_gcnn"``.
        config_path: Path to the YAML config. Defaults to the built-in
            ``eeg_gcnn.yaml``.
        subset: Which data source(s) to load. One of ``"tuab"``,
            ``"lemon"``, or ``"both"`` (default).
        **kwargs: Forwarded unchanged to ``BaseDataset.__init__``
            (for example a ``dev`` flag, if the base class supports one).

    Raises:
        ValueError: If ``subset`` is not one of the three accepted values.

    Examples:
        >>> from pyhealth.datasets import EEGGCNNDataset
        >>> from pyhealth.tasks import EEGGCNNDiseaseDetection
        >>> dataset = EEGGCNNDataset(
        ...     root="/data/eeg-gcnn/",
        ... )
        >>> dataset.stats()
        >>> sample_dataset = dataset.set_task(EEGGCNNDiseaseDetection())
        >>> sample = sample_dataset[0]
        >>> print(sample["node_features"].shape)  # (8, 6)
    """

    def __init__(
        self,
        root: str,
        dataset_name: Optional[str] = None,
        config_path: Optional[str] = None,
        subset: Optional[str] = "both",
        **kwargs,
    ) -> None:
        if config_path is None:
            config_path = (
                Path(__file__).parent / "configs" / "eeg_gcnn.yaml"
            )

        # prepare_metadata() reads self.root, so set it before the base
        # class initialiser runs.
        self.root = root

        if subset == "tuab":
            tables = ["tuab"]
        elif subset == "lemon":
            tables = ["lemon"]
        elif subset == "both":
            tables = ["tuab", "lemon"]
        else:
            raise ValueError(
                "subset must be one of 'tuab', 'lemon', or 'both'"
            )

        self.prepare_metadata()

        root_path = Path(root)
        cache_dir = Path.home() / ".cache" / "pyhealth" / "eeg_gcnn"
        csv_names = [f"eeg_gcnn-{table}-pyhealth.csv" for table in tables]

        # Fall back to the cache as soon as one requested table is only
        # available there (i.e. root was not writable when it was built).
        use_cache = any(
            not (root_path / name).exists() and (cache_dir / name).exists()
            for name in csv_names
        )

        if use_cache:
            import shutil

            # All tables are loaded from a single root. Mirror any CSV that
            # exists only under the data root into the cache directory so
            # that switching the root does not lose that table.
            for name in csv_names:
                shared_csv = root_path / name
                cache_csv = cache_dir / name
                if shared_csv.exists() and not cache_csv.exists():
                    shutil.copyfile(shared_csv, cache_csv)
            # NOTE(review): the cache directory is not keyed by ``root``, so
            # metadata cached for one data root is reused for another root
            # that has no CSVs of its own — confirm this is intended.
            logger.info("Using cached metadata from %s", cache_dir)
            root = str(cache_dir)

        super().__init__(
            root=root,
            tables=tables,
            dataset_name=dataset_name or "eeg_gcnn",
            config_path=config_path,
            **kwargs,
        )

    def prepare_metadata(self) -> None:
        """Build and save metadata CSVs for TUAB normal subset and LEMON.

        A table is (re)built only when its CSV exists neither under the
        data root nor in the cache directory, and it is written only when
        at least one recording was found.

        Writes:
            - ``<root>/eeg_gcnn-tuab-pyhealth.csv``
            - ``<root>/eeg_gcnn-lemon-pyhealth.csv``

        (or the same file names under ``~/.cache/pyhealth/eeg_gcnn`` when
        the data root is not writable).

        TUAB file stems are split on ``_``: the first token becomes the
        patient id, the second the record id. LEMON recordings are the
        ``*.vhdr`` BrainVision headers inside each subject directory.
        """
        root = Path(self.root)
        cache_dir = Path.home() / ".cache" / "pyhealth" / "eeg_gcnn"

        collectors = (
            ("tuab", self._collect_tuab_rows),
            ("lemon", self._collect_lemon_rows),
        )
        for table, collect in collectors:
            csv_name = f"eeg_gcnn-{table}-pyhealth.csv"
            shared_csv = root / csv_name
            if shared_csv.exists() or (cache_dir / csv_name).exists():
                continue

            rows = collect(root)
            if not rows:
                continue

            df = pd.DataFrame(rows)
            df.sort_values(
                ["patient_id", "record_id"],
                inplace=True,
                na_position="last",
            )
            df.reset_index(drop=True, inplace=True)
            self._write_csv(df, shared_csv, cache_dir, table)

    @staticmethod
    def _collect_tuab_rows(root: Path) -> list:
        """Return one metadata row per normal TUAB EDF file under ``root``."""
        rows = []
        for split in ("train", "eval"):
            normal_dir = root / split / "normal" / "01_tcp_ar"
            if not normal_dir.is_dir():
                logger.debug(
                    "TUAB normal dir not found: %s", normal_dir
                )
                continue
            for edf in sorted(normal_dir.rglob("*.edf")):
                parts = edf.stem.split("_")
                rows.append(
                    {
                        "patient_id": f"tuab_{parts[0]}",
                        "record_id": parts[1] if len(parts) > 1 else "0",
                        "signal_file": str(edf),
                        "source": "tuab",
                        "label": 0,
                    }
                )
        return rows

    @staticmethod
    def _collect_lemon_rows(root: Path) -> list:
        """Return one metadata row per LEMON ``.vhdr`` file under ``root``."""
        rows = []
        lemon_dir = root / "lemon"
        if not lemon_dir.is_dir():
            return rows
        for subject_dir in sorted(lemon_dir.iterdir()):
            if not subject_dir.is_dir():
                continue
            for vhdr in sorted(subject_dir.glob("*.vhdr")):
                rows.append(
                    {
                        "patient_id": f"lemon_{subject_dir.name}",
                        "record_id": vhdr.stem,
                        "signal_file": str(vhdr),
                        "source": "lemon",
                        "label": 1,
                    }
                )
        return rows

    @staticmethod
    def _write_csv(
        df: "pd.DataFrame",
        shared_path: Path,
        cache_dir: Path,
        table_name: str,
    ) -> None:
        """Write CSV to shared location, falling back to cache.

        Args:
            df: Metadata table to persist.
            shared_path: Preferred destination (inside the data root).
            cache_dir: Directory used when ``shared_path`` cannot be written.
            table_name: Table name, used only in log messages.
        """
        try:
            shared_path.parent.mkdir(parents=True, exist_ok=True)
            df.to_csv(shared_path, index=False)
            logger.info("Wrote %s metadata to %s", table_name, shared_path)
        except OSError:
            # PermissionError is a subclass of OSError, so read-only data
            # roots end up here as well.
            cache_dir.mkdir(parents=True, exist_ok=True)
            cache_path = cache_dir / shared_path.name
            df.to_csv(cache_path, index=False)
            logger.info(
                "Wrote %s metadata to cache: %s", table_name, cache_path
            )

    @property
    def default_task(self):
        """Returns the default task for the EEG-GCNN dataset.

        Returns:
            EEGGCNNDiseaseDetection: The default task instance.
        """
        # Imported lazily, at call time rather than module import time.
        from pyhealth.tasks import EEGGCNNDiseaseDetection

        return EEGGCNNDiseaseDetection()
root="/srv/local/data/TUH/tuh_eeg_abnormal/v3.0.0/edf/", download=True, + ... ) + >>> from pyhealth.tasks import EEG_isabnormal_fn + >>> EEG_abnormal_ds = isabnormal.set_task(EEG_isAbnormal_fn) + >>> EEG_abnormal_ds.samples[0] + { + 'patient_id': 'aaaaamye', + 'visit_id': 's001', + 'record_id': '1', + 'epoch_path': '/home/zhenlin4/.cache/pyhealth/datasets/832afe6e6e8a5c9ea5505b47e7af8125/10-1/1/0.pkl', + 'label': 1 + } + """ + + samples = [] + for visit in record: + root, pid, visit_id, signal, label, save_path = ( + visit["load_from_path"], + visit["patient_id"], + visit["visit_id"], + visit["signal_file"], + visit["label_file"], + visit["save_to_path"], + ) + + raw = mne.io.read_raw_edf(os.path.join(root, signal), preload=True) + raw.resample(200) + ch_name = raw.ch_names + raw_data = raw.get_data() + channeled_data = raw_data.copy()[:16] + try: + channeled_data[0] = ( + raw_data[ch_name.index("EEG FP1-REF")] + - raw_data[ch_name.index("EEG F7-REF")] + ) + channeled_data[1] = ( + raw_data[ch_name.index("EEG F7-REF")] + - raw_data[ch_name.index("EEG T3-REF")] + ) + channeled_data[2] = ( + raw_data[ch_name.index("EEG T3-REF")] + - raw_data[ch_name.index("EEG T5-REF")] + ) + channeled_data[3] = ( + raw_data[ch_name.index("EEG T5-REF")] + - raw_data[ch_name.index("EEG O1-REF")] + ) + channeled_data[4] = ( + raw_data[ch_name.index("EEG FP2-REF")] + - raw_data[ch_name.index("EEG F8-REF")] + ) + channeled_data[5] = ( + raw_data[ch_name.index("EEG F8-REF")] + - raw_data[ch_name.index("EEG T4-REF")] + ) + channeled_data[6] = ( + raw_data[ch_name.index("EEG T4-REF")] + - raw_data[ch_name.index("EEG T6-REF")] + ) + channeled_data[7] = ( + raw_data[ch_name.index("EEG T6-REF")] + - raw_data[ch_name.index("EEG O2-REF")] + ) + channeled_data[8] = ( + raw_data[ch_name.index("EEG FP1-REF")] + - raw_data[ch_name.index("EEG F3-REF")] + ) + channeled_data[9] = ( + raw_data[ch_name.index("EEG F3-REF")] + - raw_data[ch_name.index("EEG C3-REF")] + ) + channeled_data[10] = ( + 
# Bipolar montage used for TUEV: every output row is the first channel
# minus the second. The trailing number is the row index in the result.
_TCP_MONTAGE_PAIRS = (
    ("EEG FP1-REF", "EEG F7-REF"),  # 0
    ("EEG F7-REF", "EEG T3-REF"),  # 1
    ("EEG T3-REF", "EEG T5-REF"),  # 2
    ("EEG T5-REF", "EEG O1-REF"),  # 3
    ("EEG FP2-REF", "EEG F8-REF"),  # 4
    ("EEG F8-REF", "EEG T4-REF"),  # 5
    ("EEG T4-REF", "EEG T6-REF"),  # 6
    ("EEG T6-REF", "EEG O2-REF"),  # 7
    ("EEG FP1-REF", "EEG F3-REF"),  # 8
    ("EEG F3-REF", "EEG C3-REF"),  # 9
    ("EEG C3-REF", "EEG P3-REF"),  # 10
    ("EEG P3-REF", "EEG O1-REF"),  # 11
    ("EEG FP2-REF", "EEG F4-REF"),  # 12
    ("EEG F4-REF", "EEG C4-REF"),  # 13
    ("EEG C4-REF", "EEG P4-REF"),  # 14
    ("EEG P4-REF", "EEG O2-REF"),  # 15
)


def EEG_events_fn(record):
    """Processes a single patient for the EEG events task on TUEV.

    This task aims at annotating of EEG segments as one of six classes: (1) spike and sharp wave (SPSW), (2) generalized periodic epileptiform discharges (GPED), (3) periodic lateralized epileptiform discharges (PLED), (4) eye movement (EYEM), (5) artifact (ARTF) and (6) background (BCKG).

    Args:
        record: a singleton list of one subject from the TUEVDataset.
            The (single) record is a dictionary with the following keys:
                load_from_path, patient_id, visit_id, signal_file, label_file, save_to_path

    Returns:
        samples: a list of samples, each sample is a dict with patient_id, visit_id, record_id, label, offending_channel,
            and epoch_path (the path to the saved epoch {"signal": signal, "label": label} as key.

    Recordings whose EDF/.rec cannot be read or converted (``ValueError`` or
    ``KeyError``) are reported on stdout and skipped.

    Note that we define the task as a multiclass classification task.

    Examples:
        >>> from pyhealth.datasets import TUEVDataset
        >>> EEGevents = TUEVDataset(
        ...         root="/srv/local/data/TUH/tuh_eeg_events/v2.0.0/edf/", download=True,
        ...     )
        >>> from pyhealth.tasks import EEG_events_fn
        >>> EEG_events_ds = EEGevents.set_task(EEG_events_fn)
        >>> EEG_events_ds.samples[0]
        {
            'patient_id': '0_00002265',
            'visit_id': '00000001',
            'record_id': 0,
            'epoch_path': '/Users/liyanjing/.cache/pyhealth/datasets/d8f3cb92cc444d481444d3414fb5240c/0_00002265_00000001_0.pkl',
            'label': 6,
            'offending_channel': array([4.])
        }
    """

    samples = []
    for visit in record:
        root = visit["load_from_path"]
        pid = visit["patient_id"]
        visit_id = visit["visit_id"]
        signal_file = visit["signal_file"]
        save_path = visit["save_to_path"]
        edf_path = os.path.join(root, signal_file)

        # load data
        try:
            # event is the .rec file in the form of an array
            signals, times, event, Rawdata = readEDF(edf_path)
            signals = convert_signals(signals, Rawdata)
        except (ValueError, KeyError):
            print("something funky happened in " + edf_path)
            continue
        epochs, offending_channels, labels = BuildEvents(signals, times, event)

        for idx, (epoch, offending_channel, epoch_label) in enumerate(
            zip(epochs, offending_channels, labels)
        ):
            dump_path = os.path.join(
                save_path, pid + "_" + visit_id + "_" + str(idx) + ".pkl"
            )
            label_id = int(epoch_label[0])

            # Context manager so the file handle is closed for every epoch.
            with open(dump_path, "wb") as f:
                pickle.dump({"signal": epoch, "label": label_id}, f)

            samples.append(
                {
                    "patient_id": pid,
                    "visit_id": visit_id,
                    "record_id": idx,
                    "epoch_path": dump_path,
                    "label": label_id,
                    "offending_channel": offending_channel,
                }
            )

    return samples


def BuildEvents(signals, times, EventData, fs=250.0):
    """Cut one fixed-length epoch around every annotated event.

    Args:
        signals: array of shape ``(n_channels, n_points)``.
        times: 1-D array with the time stamp (seconds) of every sample.
        EventData: rows of ``(channel, start, end, label)`` from the .rec
            file. A single event given as a 1-D row and an empty array are
            both accepted.
        fs: sampling rate in Hz used to size the epochs. Default ``250.0``.

    Returns:
        ``[features, offending_channel, labels]`` with shapes
        ``(n_events, n_channels, 5 * int(fs))``, ``(n_events, 1)`` and
        ``(n_events, 1)``.
    """
    rate = int(fs)
    numChan = signals.shape[0]
    EventData = np.asarray(EventData, dtype=float)

    if EventData.size == 0:
        # No annotations at all: return empty, correctly shaped arrays.
        return [
            np.zeros([0, numChan, rate * 5]),
            np.zeros([0, 1]),
            np.zeros([0, 1]),
        ]

    # np.genfromtxt yields a 1-D array for a one-line .rec file.
    EventData = np.atleast_2d(EventData)
    numEvents = EventData.shape[0]  # equal to # of rows of the .rec file

    features = np.zeros([numEvents, numChan, rate * 5])
    offending_channel = np.zeros([numEvents, 1])  # channel that had the detected thing
    labels = np.zeros([numEvents, 1])

    # Triple the recording so that the 2-second context before/after an
    # event near either edge wraps around instead of running off the array.
    offset = signals.shape[1]
    padded = np.concatenate([signals, signals, signals], axis=1)
    for i in range(numEvents):  # for each event
        start = np.where(times >= EventData[i, 1])[0][0]
        end = np.where(times >= EventData[i, 2])[0][0]
        # The assignment only fits when the event spans int(fs) samples
        # (one second); together with 2 s of context on both sides that
        # gives the 5 * int(fs) samples of one epoch.
        features[i, :] = padded[
            :, offset + start - 2 * rate : offset + end + 2 * rate
        ]
        offending_channel[i, :] = int(EventData[i, 0])
        labels[i, :] = int(EventData[i, 3])
    return [features, offending_channel, labels]


def convert_signals(signals, Rawdata):
    """Convert referential channels to the 16-channel bipolar montage.

    Args:
        signals: array of shape ``(n_channels, n_points)`` whose rows follow
            ``Rawdata.info["ch_names"]``.
        Rawdata: object exposing ``info["ch_names"]`` (the MNE raw reader).

    Returns:
        Array of shape ``(16, n_points)``; row order follows
        ``_TCP_MONTAGE_PAIRS``.

    Raises:
        KeyError: If a required channel is missing from the recording.
    """
    index_of = {
        name: idx for idx, name in enumerate(Rawdata.info["ch_names"])
    }
    return np.vstack(
        [
            signals[index_of[first]] - signals[index_of[second]]
            for first, second in _TCP_MONTAGE_PAIRS
        ]
    )


def readEDF(fileName):
    """Read an EDF recording and the ``.rec`` annotation file next to it.

    The annotation path is ``fileName`` with its last three characters
    replaced by ``rec``.

    Returns:
        ``[signals, times, eventData, Rawdata]``; ``Rawdata`` has already
        been closed and is returned for its metadata.
    """
    Rawdata = mne.io.read_raw_edf(fileName)
    try:
        signals, times = Rawdata[:]
        RecFile = fileName[0:-3] + "rec"
        eventData = np.genfromtxt(RecFile, delimiter=",")
    finally:
        # Release the reader even when the .rec file is missing or broken.
        Rawdata.close()
    return [signals, times, eventData, Rawdata]
.eeg_gcnn_nd_detection import EEGGCNNDiseaseDetection from .in_hospital_mortality_mimic4 import InHospitalMortalityMIMIC4 from .length_of_stay_prediction import ( LengthOfStayPredictioneICU, diff --git a/pyhealth/tasks/eeg_gcnn_nd_detection.py b/pyhealth/tasks/eeg_gcnn_nd_detection.py new file mode 100644 index 000000000..16302623a --- /dev/null +++ b/pyhealth/tasks/eeg_gcnn_nd_detection.py @@ -0,0 +1,360 @@ +"""EEG-GCNN neurological disease detection task. + +Implements the preprocessing and feature-extraction pipeline from: + + Wagh, N. & Varatharajah, Y. (2020). EEG-GCNN: Augmenting + Electroencephalogram-based Neurological Disease Diagnosis using a + Domain-guided Graph Convolutional Neural Network. ML4H @ NeurIPS 2020. + https://proceedings.mlr.press/v136/wagh20a.html +""" + +import logging +from typing import Any, Dict, List, Optional, Tuple + +import mne +import numpy as np +import torch + +from pyhealth.tasks.base_task import BaseTask + +logger = logging.getLogger(__name__) + +# --------------------------------------------------------------------------- +# Constants from the EEG-GCNN paper (Section 4) +# --------------------------------------------------------------------------- + +# 8 bipolar channels used in the paper. +# Each entry is (anode_ref_name, cathode_ref_name, canonical_name). +BIPOLAR_CHANNELS: List[Tuple[str, str, str]] = [ + ("EEG F7-REF", "EEG F3-REF", "F7-F3"), + ("EEG F8-REF", "EEG F4-REF", "F8-F4"), + ("EEG T3-REF", "EEG C3-REF", "T7-C3"), + ("EEG T4-REF", "EEG C4-REF", "T8-C4"), + ("EEG T5-REF", "EEG P3-REF", "P7-P3"), + ("EEG T6-REF", "EEG P4-REF", "P8-P4"), + ("EEG O1-REF", "EEG P3-REF", "O1-P3"), + ("EEG O2-REF", "EEG P4-REF", "O2-P4"), +] + +# 3D MNI coordinates for the 8 bipolar channel mid-points (approximate). +# Used to build the spatial adjacency matrix. 
class EEGGCNNDiseaseDetection(BaseTask):
    """Binary classification: patient-normal (0) vs healthy-control (1).

    For each EEG recording the task:

    1. Reads the EDF / BrainVision file via MNE.
    2. Applies a ``highpass_freq`` Hz high-pass and a ``notch_freq`` Hz
       notch filter, then resamples to ``resample_rate`` Hz.
    3. Computes the 8 bipolar channels defined in the paper.
    4. Segments the signal into non-overlapping ``window_sec``-second windows.
    5. For each window, extracts PSD band-power features (8 channels x
       ``len(bands)`` bands) via Welch's method.
    6. Computes an 8x8 graph adjacency matrix (spatial, functional, combined,
       or identity) to accompany each sample.

    Args:
        resample_rate: Target sampling rate in Hz. Default ``250``.
        highpass_freq: High-pass filter cutoff in Hz. Default ``1.0``.
        notch_freq: Notch filter frequency in Hz. Default ``50.0``.
        window_sec: Window length in seconds. Default ``10``.
        bands: Frequency bands to extract. Defaults to all 6 paper bands.
            Pass a subset dict to run a band-ablation experiment.
        adjacency_type: One of ``"combined"`` (default), ``"spatial"``,
            ``"functional"``, or ``"none"`` (identity matrix).
        connectivity_measure: ``"coherence"`` (default) or ``"wpli"``.

    Raises:
        ValueError: If ``adjacency_type`` or ``connectivity_measure`` is not
            one of the accepted values.

    Examples:
        >>> from pyhealth.datasets import EEGGCNNDataset
        >>> from pyhealth.tasks import EEGGCNNDiseaseDetection
        >>> dataset = EEGGCNNDataset(root="/data/eeg-gcnn/")
        >>> task = EEGGCNNDiseaseDetection(adjacency_type="spatial")
        >>> sample_dataset = dataset.set_task(task)
        >>> sample = sample_dataset[0]
        >>> print(sample["node_features"].shape)  # (8, 6)
        >>> print(sample["adj_matrix"].shape)      # (8, 8)
    """

    task_name: str = "eeg_gcnn_nd_detection"
    input_schema: Dict[str, str] = {
        "node_features": "tensor",
        "adj_matrix": "tensor",
    }
    output_schema: Dict[str, str] = {"label": "binary"}

    # Older 10-20 electrode labels and their 10-10 equivalents, so that a
    # recording labelled either way resolves to the same electrode.
    _ELECTRODE_ALIASES = {"T3": "T7", "T4": "T8", "T5": "P7", "T6": "P8"}

    def __init__(
        self,
        resample_rate: int = 250,
        highpass_freq: float = 1.0,
        notch_freq: float = 50.0,
        window_sec: int = 10,
        bands: Optional[Dict[str, Tuple[float, float]]] = None,
        adjacency_type: str = "combined",
        connectivity_measure: str = "coherence",
    ) -> None:
        super().__init__()
        self.resample_rate = resample_rate
        self.highpass_freq = highpass_freq
        self.notch_freq = notch_freq
        self.window_sec = window_sec
        # Copy the defaults so per-instance edits never leak into the
        # module-level constant.
        self.bands = bands if bands is not None else dict(DEFAULT_BANDS)
        if adjacency_type not in ("combined", "spatial", "functional", "none"):
            raise ValueError(
                f"adjacency_type must be 'combined', 'spatial', "
                f"'functional', or 'none', got '{adjacency_type}'"
            )
        self.adjacency_type = adjacency_type
        if connectivity_measure not in ("coherence", "wpli"):
            raise ValueError(
                f"connectivity_measure must be 'coherence' or 'wpli', "
                f"got '{connectivity_measure}'"
            )
        self.connectivity_measure = connectivity_measure

        # Pre-compute spatial adjacency (constant across recordings).
        self._spatial_adj = self._build_spatial_adjacency()

    # ------------------------------------------------------------------
    # Signal I/O and preprocessing
    # ------------------------------------------------------------------

    def _read_eeg(self, filepath: str) -> mne.io.BaseRaw:
        """Read an EEG file (EDF or BrainVision) into MNE Raw.

        The reader is chosen from the (case-insensitive) file extension.

        Raises:
            ValueError: If the extension is neither ``.edf`` nor ``.vhdr``.
        """
        filepath_lower = filepath.lower()
        if filepath_lower.endswith(".edf"):
            return mne.io.read_raw_edf(
                filepath, preload=True, verbose="error"
            )
        if filepath_lower.endswith(".vhdr"):
            return mne.io.read_raw_brainvision(
                filepath, preload=True, verbose="error"
            )
        raise ValueError(f"Unsupported EEG format: {filepath}")

    def _preprocess(self, raw: mne.io.BaseRaw) -> mne.io.BaseRaw:
        """High-pass, notch-filter, then resample (in that order).

        Both filters run before resampling, i.e. at the recording's
        original sampling rate. ``raw`` is modified in place and returned.
        """
        raw.filter(
            l_freq=self.highpass_freq,
            h_freq=None,
            verbose="error",
        )
        raw.notch_filter(self.notch_freq, verbose="error")
        raw.resample(self.resample_rate, verbose="error")
        return raw

    @staticmethod
    def _canonical_electrode(ch_name: str) -> str:
        """Reduce a channel label to a bare, upper-case electrode name.

        ``"EEG T3-REF"`` and ``"T7"`` both become ``"T7"``: an ``"EEG "``
        prefix and anything after the first ``-`` are dropped, and the old
        T3/T4/T5/T6 labels are mapped to T7/T8/P7/P8.
        """
        label = ch_name.strip().upper()
        if label.startswith("EEG "):
            label = label[4:]
        label = label.split("-", 1)[0].strip()
        return EEGGCNNDiseaseDetection._ELECTRODE_ALIASES.get(label, label)

    @staticmethod
    def _compute_bipolar(raw: mne.io.BaseRaw) -> np.ndarray:
        """Compute the 8 bipolar channels from reference montage.

        Channels are looked up by their exact ``"EEG X-REF"`` name first.
        If that fails, the lookup falls back to the bare electrode name so
        that recordings labelled ``F7``, ``T7``, ... are usable as well.

        Returns:
            np.ndarray of shape ``(8, n_samples)``.

        Raises:
            KeyError: If a required electrode is missing from the recording.
        """
        data = raw.get_data()
        exact = {name: idx for idx, name in enumerate(raw.ch_names)}
        canonical: Dict[str, int] = {}
        for idx, name in enumerate(raw.ch_names):
            # First occurrence wins if two labels reduce to one electrode.
            canonical.setdefault(
                EEGGCNNDiseaseDetection._canonical_electrode(name), idx
            )

        def locate(ref_name: str) -> int:
            if ref_name in exact:
                return exact[ref_name]
            electrode = EEGGCNNDiseaseDetection._canonical_electrode(ref_name)
            if electrode not in canonical:
                raise KeyError(
                    f"channel '{ref_name}' (electrode {electrode}) "
                    f"not found in recording"
                )
            return canonical[electrode]

        bipolar = np.zeros((NUM_CHANNELS, data.shape[1]))
        for i, (anode, cathode, _) in enumerate(BIPOLAR_CHANNELS):
            bipolar[i] = data[locate(anode)] - data[locate(cathode)]
        return bipolar

    # ------------------------------------------------------------------
    # Feature extraction
    # ------------------------------------------------------------------

    def _extract_psd_features(
        self, window: np.ndarray
    ) -> np.ndarray:
        """Extract PSD band-power features for one window.

        Args:
            window: shape ``(8, n_samples)``.

        Returns:
            np.ndarray of shape ``(8, n_bands)`` — log10 average PSD per
            channel per band via Welch's method. A band that contains no
            Welch frequency bin keeps the value ``0.0``.
        """
        from scipy.signal import welch

        fs = self.resample_rate
        band_list = list(self.bands.values())
        features = np.zeros((window.shape[0], len(band_list)))

        # Two-second segments, capped at the window length.
        nperseg = min(int(fs * 2), window.shape[1])
        freqs, pxx = welch(window, fs=fs, nperseg=nperseg)

        for b_idx, (fmin, fmax) in enumerate(band_list):
            band_mask = (freqs >= fmin) & (freqs < fmax)
            if band_mask.any():
                # The small constant keeps log10 finite for zero power.
                features[:, b_idx] = np.log10(
                    pxx[:, band_mask].mean(axis=1) + 1e-10
                )

        return features

    # ------------------------------------------------------------------
    # Adjacency matrices
    # ------------------------------------------------------------------

    @staticmethod
    def _build_spatial_adjacency() -> np.ndarray:
        """Build spatial adjacency from inverse Euclidean distance.

        Off-diagonal entries are the inverse pairwise distances between
        the rows of ``CHANNEL_POSITIONS_MNI``, normalised so each row sums
        to one; the diagonal is then set to 1.

        Returns:
            np.ndarray of shape ``(8, 8)`` with self-loops set to 1.
        """
        pos = CHANNEL_POSITIONS_MNI
        dist = np.linalg.norm(
            pos[:, np.newaxis, :] - pos[np.newaxis, :, :], axis=-1
        )
        # Inverse distance; zero-distance pairs (the diagonal) stay 0.
        adj = np.zeros_like(dist)
        np.divide(1.0, dist, out=adj, where=dist > 0)
        # Row-normalise
        adj = adj / (adj.sum(axis=1, keepdims=True) + 1e-10)
        # Self-loops
        np.fill_diagonal(adj, 1.0)
        return adj

    def _build_functional_adjacency(
        self, window: np.ndarray
    ) -> np.ndarray:
        """Build functional adjacency from inter-channel connectivity.

        Connectivity is estimated between 0.5 and 50 Hz, averaged over
        frequency and taken as an absolute value.

        Args:
            window: shape ``(8, n_samples)``.

        Returns:
            np.ndarray of shape ``(8, 8)``, symmetric, with unit diagonal.
        """
        from mne_connectivity import spectral_connectivity_epochs

        fs = self.resample_rate
        # spectral_connectivity_epochs expects (n_epochs, n_channels, n_times)
        # NOTE(review): only a single epoch is passed — confirm the chosen
        # estimator is meaningful for one epoch (wPLI in particular).
        data_3d = window[np.newaxis, :, :]

        method_map = {"coherence": "coh", "wpli": "wpli"}
        conn = spectral_connectivity_epochs(
            data_3d,
            method=method_map[self.connectivity_measure],
            sfreq=fs,
            fmin=0.5,
            fmax=50.0,
            verbose="error",
        )
        conn_data = conn.get_data(output="dense")
        # Average across frequencies → (8, 8)
        adj = np.abs(conn_data.mean(axis=-1))
        # Symmetrise with an element-wise maximum rather than an average:
        # if the dense matrix is filled on one triangle only, averaging
        # with the transpose would halve every connectivity value. For an
        # already symmetric matrix both forms agree.
        adj = np.maximum(adj, adj.T)
        np.fill_diagonal(adj, 1.0)
        return adj

    def _build_adjacency(self, window: np.ndarray) -> np.ndarray:
        """Build the adjacency matrix according to ``adjacency_type``.

        Args:
            window: shape ``(8, n_samples)``.

        Returns:
            np.ndarray of shape ``(8, 8)``.
        """
        if self.adjacency_type == "none":
            return np.eye(NUM_CHANNELS)
        if self.adjacency_type == "spatial":
            # Copy so callers cannot mutate the cached matrix.
            return self._spatial_adj.copy()

        functional = self._build_functional_adjacency(window)
        if self.adjacency_type == "functional":
            return functional

        # combined: plain average of both graphs, self-loops reset to 1.
        combined = (self._spatial_adj + functional) / 2.0
        np.fill_diagonal(combined, 1.0)
        return combined

    # ------------------------------------------------------------------
    # Main entry point
    # ------------------------------------------------------------------

    def __call__(self, patient: Any) -> List[Dict[str, Any]]:
        """Process one patient and return a list of sample dicts.

        Recordings that cannot be read, preprocessed or converted to the
        bipolar montage are logged and skipped. Trailing samples that do
        not fill a whole window are dropped.

        Each sample contains:
            - ``patient_id``: str
            - ``signal_file``: path of the recording the window came from
            - ``node_features``: torch.FloatTensor of shape ``(8, n_bands)``
            - ``adj_matrix``: torch.FloatTensor of shape ``(8, 8)``
            - ``label``: int (0 = patient-normal, 1 = healthy-control)
        """
        pid = patient.patient_id
        samples: List[Dict[str, Any]] = []
        win_samples = int(self.window_sec * self.resample_rate)

        for table in ("tuab", "lemon"):
            for event in patient.get_events(table):
                filepath = event.signal_file
                label = int(event.label)

                try:
                    raw = self._read_eeg(filepath)
                    raw = self._preprocess(raw)
                    bipolar = self._compute_bipolar(raw)
                except (ValueError, KeyError, RuntimeError, OSError) as exc:
                    # OSError also covers FileNotFoundError, so a missing or
                    # moved recording does not abort the remaining events.
                    logger.warning(
                        "Skipping %s for patient %s: %s",
                        filepath, pid, exc,
                    )
                    continue

                n_windows = bipolar.shape[1] // win_samples
                for w in range(n_windows):
                    start = w * win_samples
                    window = bipolar[:, start:start + win_samples]

                    psd_feat = self._extract_psd_features(window)
                    adj = self._build_adjacency(window)

                    samples.append(
                        {
                            "patient_id": pid,
                            "signal_file": filepath,
                            "node_features": torch.FloatTensor(psd_feat),
                            "adj_matrix": torch.FloatTensor(adj),
                            "label": label,
                        }
                    )

        return samples
+""" + +import sys +import tempfile +from pathlib import Path +from unittest.mock import MagicMock, patch + +import numpy as np +import pytest +import torch + +from pyhealth.tasks.eeg_gcnn_nd_detection import ( + BIPOLAR_CHANNELS, + DEFAULT_BANDS, + NUM_CHANNELS, + EEGGCNNDiseaseDetection, +) + + +# --------------------------------------------------------------- +# Fixtures +# --------------------------------------------------------------- + +@pytest.fixture +def task(): + """Default task instance.""" + return EEGGCNNDiseaseDetection() + + +@pytest.fixture +def task_spatial(): + """Task with spatial-only adjacency.""" + return EEGGCNNDiseaseDetection(adjacency_type="spatial") + + +@pytest.fixture +def task_none(): + """Task with identity adjacency.""" + return EEGGCNNDiseaseDetection(adjacency_type="none") + + +@pytest.fixture +def synthetic_bipolar_window(): + """Synthetic 10-second bipolar window at 250 Hz (8, 2500).""" + rng = np.random.RandomState(42) + return rng.randn(NUM_CHANNELS, 2500) + + +@pytest.fixture +def synthetic_bipolar_signal(): + """Synthetic 30-second bipolar signal at 250 Hz (8, 7500).""" + rng = np.random.RandomState(42) + return rng.randn(NUM_CHANNELS, 7500) + + +def _make_mock_raw(n_seconds=30, sfreq=250): + """Create a mock MNE Raw object with all required TUAB channels.""" + n_samples = int(n_seconds * sfreq) + rng = np.random.RandomState(0) + + ch_names_needed = set() + for anode, cathode, _ in BIPOLAR_CHANNELS: + ch_names_needed.add(anode) + ch_names_needed.add(cathode) + ch_names = sorted(ch_names_needed) + n_ch = len(ch_names) + + data = rng.randn(n_ch, n_samples) * 1e-5 # realistic EEG scale + + raw = MagicMock() + raw.ch_names = ch_names + raw.get_data.return_value = data + raw.filter.return_value = raw + raw.notch_filter.return_value = raw + raw.resample.return_value = raw + return raw + + +# --------------------------------------------------------------- +# Task initialization tests +# 
# ---------------------------------------------------------------

class TestTaskInit:
    """Constructor defaults, argument validation and schema declarations."""

    def test_default_params(self, task):
        """A task built with no arguments exposes the documented defaults."""
        expected_defaults = {
            "resample_rate": 250,
            "highpass_freq": 1.0,
            "notch_freq": 50.0,
            "window_sec": 10,
            "adjacency_type": "combined",
            "connectivity_measure": "coherence",
        }
        for attr_name, expected_value in expected_defaults.items():
            assert getattr(task, attr_name) == expected_value
        assert len(task.bands) == 6

    def test_custom_bands(self):
        """A user-supplied band mapping replaces the default bands."""
        two_bands = {"alpha": (8.0, 12.0), "theta": (4.0, 8.0)}
        configured = EEGGCNNDiseaseDetection(bands=two_bands)
        assert len(configured.bands) == 2

    def test_invalid_adjacency_type(self):
        """An unknown adjacency type is rejected at construction time."""
        with pytest.raises(ValueError, match="adjacency_type"):
            EEGGCNNDiseaseDetection(adjacency_type="invalid")

    def test_invalid_connectivity(self):
        """An unknown connectivity measure is rejected at construction time."""
        with pytest.raises(ValueError, match="connectivity_measure"):
            EEGGCNNDiseaseDetection(connectivity_measure="plv")

    def test_task_schemas(self, task):
        """Input/output schemas and the task name match the public contract."""
        for input_key in ("node_features", "adj_matrix"):
            assert input_key in task.input_schema
        assert "label" in task.output_schema
        assert task.output_schema["label"] == "binary"
        assert task.task_name == "eeg_gcnn_nd_detection"


# ---------------------------------------------------------------
# PSD feature extraction tests
# ---------------------------------------------------------------

class TestPSDExtraction:
    """Shape and sanity checks for the band-power feature matrix."""

    def test_shape(self, task, synthetic_bipolar_window):
        """Default bands give one column per band for every channel."""
        features = task._extract_psd_features(synthetic_bipolar_window)
        assert features.shape == (NUM_CHANNELS, len(DEFAULT_BANDS))

    def test_finite_values(self, task, synthetic_bipolar_window):
        """No NaN or infinite entries appear in the features."""
        features = task._extract_psd_features(synthetic_bipolar_window)
        assert np.all(np.isfinite(features))

    def test_custom_bands_shape(self, synthetic_bipolar_window):
        """A single custom band yields a single feature column."""
        alpha_only = EEGGCNNDiseaseDetection(bands={"alpha": (8.0, 12.0)})
        features = alpha_only._extract_psd_features(synthetic_bipolar_window)
        assert features.shape == (NUM_CHANNELS, 1)


# ---------------------------------------------------------------
# Adjacency
# matrix tests
# ---------------------------------------------------------------

class TestAdjacency:
    """Spatial and identity adjacency construction."""

    def test_spatial_shape(self, task):
        """The spatial matrix is square over all channels."""
        spatial = task._build_spatial_adjacency()
        assert spatial.shape == (NUM_CHANNELS, NUM_CHANNELS)

    def test_spatial_diagonal(self, task):
        """Self-loops are exactly one."""
        diagonal = np.diag(task._build_spatial_adjacency())
        np.testing.assert_array_equal(diagonal, np.ones(NUM_CHANNELS))

    def test_spatial_positive(self, task):
        """No spatial weight is negative."""
        assert np.all(task._build_spatial_adjacency() >= 0)

    def test_none_is_identity(self, task_none, synthetic_bipolar_window):
        """Adjacency type ``none`` produces the identity matrix."""
        produced = task_none._build_adjacency(synthetic_bipolar_window)
        np.testing.assert_array_equal(produced, np.eye(NUM_CHANNELS))

    def test_spatial_adjacency_type(
        self, task_spatial, synthetic_bipolar_window
    ):
        """Adjacency type ``spatial`` yields non-zero off-diagonal weights."""
        produced = task_spatial._build_adjacency(synthetic_bipolar_window)
        assert produced.shape == (NUM_CHANNELS, NUM_CHANNELS)
        # Should NOT be identity — off-diagonal elements > 0
        first_row_off_diagonal = produced[0, 1:]
        assert np.any(first_row_off_diagonal > 0)


# ---------------------------------------------------------------
# Bipolar channel computation tests
# ---------------------------------------------------------------

class TestBipolarComputation:
    """Derivation of bipolar channels from a referential recording."""

    def test_compute_bipolar(self):
        """Ten seconds at 250 Hz yields an ``(8, 2500)`` array."""
        recording = _make_mock_raw(n_seconds=10, sfreq=250)
        derived = EEGGCNNDiseaseDetection._compute_bipolar(recording)
        assert derived.shape == (NUM_CHANNELS, 2500)

    def test_bipolar_is_difference(self):
        """The first bipolar row equals anode minus cathode."""
        recording = _make_mock_raw(n_seconds=1, sfreq=250)
        referential = recording.get_data()
        row_of = {name: idx for idx, name in enumerate(recording.ch_names)}
        derived = EEGGCNNDiseaseDetection._compute_bipolar(recording)

        first_anode, first_cathode, _ = BIPOLAR_CHANNELS[0]
        np.testing.assert_array_almost_equal(
            derived[0],
            referential[row_of[first_anode]]
            - referential[row_of[first_cathode]],
        )


# ---------------------------------------------------------------
# End-to-end __call__ test with mocked I/O
# ---------------------------------------------------------------

def _make_mock_event(signal_file, label):
    """Return a mock event carrying ``signal_file`` and ``label``."""
    mock_event = MagicMock()
    mock_event.signal_file = signal_file
    mock_event.label = label
    return mock_event


def _make_mock_patient(patient_id, events_by_table):
    """Return a mock patient whose ``get_events`` reads ``events_by_table``.

    Tables missing from the mapping yield an empty list.
    """
    mock_patient = MagicMock()
    mock_patient.patient_id = patient_id
    mock_patient.get_events.side_effect = (
        lambda table: events_by_table.get(table, [])
    )
    return mock_patient


class TestTaskCall:
    """End-to-end checks of ``__call__`` with file I/O patched out."""

    def test_call_produces_samples(self):
        """Verify __call__ produces correct samples with mocked EEG I/O."""
        detector = EEGGCNNDiseaseDetection(
            adjacency_type="none",
            window_sec=10,
        )
        recording = _make_mock_raw(n_seconds=30, sfreq=250)

        # Build a mock patient with one TUAB event
        subject = _make_mock_patient(
            "test_001", {"tuab": [_make_mock_event("/fake/path.edf", 0)]}
        )

        with patch.object(detector, "_read_eeg", return_value=recording), \
                patch.object(detector, "_preprocess", return_value=recording):
            produced = detector(subject)

        # 30s / 10s = 3 windows
        assert len(produced) == 3
        for sample in produced:
            assert sample["patient_id"] == "test_001"
            assert isinstance(sample["node_features"], torch.Tensor)
            assert sample["node_features"].shape == (8, 6)
            assert isinstance(sample["adj_matrix"], torch.Tensor)
            assert sample["adj_matrix"].shape == (8, 8)
            assert sample["label"] == 0

    def test_call_skips_bad_file(self):
        """Verify __call__ gracefully skips unreadable files."""
        detector = EEGGCNNDiseaseDetection(adjacency_type="none")
        subject = _make_mock_patient(
            "test_002", {"tuab": [_make_mock_event("/bad/path.edf", 0)]}
        )

        with patch.object(
            detector, "_read_eeg", side_effect=ValueError("corrupt file")
        ):
            produced = detector(subject)

        assert len(produced) == 0

    def test_call_both_sources(self):
        """Verify samples from both TUAB and LEMON are collected."""
        detector = EEGGCNNDiseaseDetection(
            adjacency_type="none", window_sec=10
        )
        recording = _make_mock_raw(n_seconds=10, sfreq=250)

        subject = _make_mock_patient(
            "test_003",
            {
                "tuab": [_make_mock_event("/fake/tuab.edf", 0)],
                "lemon": [_make_mock_event("/fake/lemon.vhdr", 1)],
            },
        )

        with patch.object(detector, "_read_eeg", return_value=recording), \
                patch.object(detector, "_preprocess", return_value=recording):
            produced = detector(subject)

        assert len(produced) == 2
        assert {sample["label"] for sample in produced} == {0, 1}


# ---------------------------------------------------------------
# Dataset metadata tests
# ---------------------------------------------------------------

class TestDatasetMetadata:
    """CSV metadata generation from synthetic directory trees."""

    def test_prepare_tuab_csv(self):
        """Verify TUAB CSV is generated from synthetic directory structure."""
        from pyhealth.datasets.eeg_gcnn import EEGGCNNDataset

        with tempfile.TemporaryDirectory() as tmpdir:
            root = Path(tmpdir)
            # Create synthetic TUAB directory structure
            session_dir = root.joinpath(
                "train", "normal", "01_tcp_ar", "000", "00000001"
            )
            session_dir.mkdir(parents=True)
            for edf_name in (
                "00000001_00000001_01.edf",
                "00000001_00000002_01.edf",
            ):
                (session_dir / edf_name).touch()

            # BaseDataset cannot be fully constructed without CSV content,
            # so bypass __init__ and exercise prepare_metadata directly.
            dataset = EEGGCNNDataset.__new__(EEGGCNNDataset)
            dataset.root = tmpdir
            dataset.prepare_metadata()

            tuab_csv = root / "eeg_gcnn-tuab-pyhealth.csv"
            assert tuab_csv.exists()

            import pandas as pd

            frame = pd.read_csv(tuab_csv)
            assert len(frame) == 2
            for column in ("patient_id", "signal_file", "label"):
                assert column in frame.columns
            assert all(frame["label"] == 0)
            assert all(frame["source"] == "tuab")

    def test_prepare_lemon_csv(self):
        """Verify LEMON CSV is generated from synthetic directory structure."""
        from pyhealth.datasets.eeg_gcnn import EEGGCNNDataset

        with tempfile.TemporaryDirectory() as tmpdir:
            root = Path(tmpdir)
            subject_dir = root / "lemon" / "sub-010002"
            subject_dir.mkdir(parents=True)
            (subject_dir / "sub-010002.vhdr").touch()

            dataset = EEGGCNNDataset.__new__(EEGGCNNDataset)
            dataset.root = tmpdir
            dataset.prepare_metadata()

            lemon_csv = root / "eeg_gcnn-lemon-pyhealth.csv"
            assert lemon_csv.exists()

            import pandas as pd

            frame = pd.read_csv(lemon_csv)
            assert len(frame) == 1
            only_row = frame.iloc[0]
            assert only_row["label"] == 1
            assert only_row["source"] == "lemon"


# ---------------------------------------------------------------
# Constants / module-level tests
# ---------------------------------------------------------------

class TestConstants:
    """Sanity checks on module-level constants."""

    def test_bipolar_channel_count(self):
        assert len(BIPOLAR_CHANNELS) == 8

    def test_default_bands_count(self):
        assert len(DEFAULT_BANDS) == 6

    def test_num_channels(self):
        assert NUM_CHANNELS == 8


# ---------------------------------------------------------------
# Functional and combined adjacency tests
# (mne_connectivity is mocked so no real EEG dependency)
# ---------------------------------------------------------------

def _make_mock_mne_connectivity(n_channels: int = 8, n_freqs: int = 10):
    """Return a (mock mne_connectivity module, mock connectivity result).

    ``conn.get_data(output="dense")`` returns shape
    ``(n_channels, n_channels, n_freqs)`` with values in [0, 1].
    """
    rng = np.random.RandomState(7)
    dense_values = rng.uniform(0.0, 1.0, (n_channels, n_channels, n_freqs))

    fake_result = MagicMock()
    fake_result.get_data.return_value = dense_values

    fake_module = MagicMock()
    fake_module.spectral_connectivity_epochs.return_value = fake_result
    return fake_module, fake_result


class TestFunctionalAdjacency:
    """Tests for _build_functional_adjacency and combined adjacency mode.

    mne_connectivity is injected via sys.modules so the tests run in
    milliseconds with no network or file I/O.
    """

    @pytest.fixture(autouse=True)
    def _patch_mne_conn(self, synthetic_bipolar_window):
        """Inject a mock mne_connectivity for every test in this class."""
        fake_module, self._mock_result = _make_mock_mne_connectivity()
        with patch.dict(sys.modules, {"mne_connectivity": fake_module}):
            self._mock_module = fake_module
            yield

    def _method_passed_to_mne(self):
        """Return the ``method`` keyword of the latest connectivity call."""
        _, call_kwargs = (
            self._mock_module.spectral_connectivity_epochs.call_args
        )
        return call_kwargs["method"]

    # --- _build_functional_adjacency ---

    def test_functional_shape(self, synthetic_bipolar_window):
        detector = EEGGCNNDiseaseDetection(adjacency_type="functional")
        produced = detector._build_functional_adjacency(
            synthetic_bipolar_window
        )
        assert produced.shape == (NUM_CHANNELS, NUM_CHANNELS)

    def test_functional_diagonal_is_one(self, synthetic_bipolar_window):
        detector = EEGGCNNDiseaseDetection(adjacency_type="functional")
        produced = detector._build_functional_adjacency(
            synthetic_bipolar_window
        )
        np.testing.assert_array_equal(
            np.diag(produced), np.ones(NUM_CHANNELS)
        )

    def test_functional_symmetric(self, synthetic_bipolar_window):
        detector = EEGGCNNDiseaseDetection(adjacency_type="functional")
        produced = detector._build_functional_adjacency(
            synthetic_bipolar_window
        )
        np.testing.assert_array_almost_equal(produced, produced.T)

    def test_functional_values_finite(self, synthetic_bipolar_window):
        detector = EEGGCNNDiseaseDetection(adjacency_type="functional")
        produced = detector._build_functional_adjacency(
            synthetic_bipolar_window
        )
        assert np.all(np.isfinite(produced))

    def test_functional_coherence_method_passed(self, synthetic_bipolar_window):
        """spectral_connectivity_epochs must be called with method='coh'."""
        detector = EEGGCNNDiseaseDetection(
            adjacency_type="functional", connectivity_measure="coherence"
        )
        detector._build_functional_adjacency(synthetic_bipolar_window)
        assert self._method_passed_to_mne() == "coh"

    def test_functional_wpli_method_passed(self, synthetic_bipolar_window):
        """spectral_connectivity_epochs must be called with method='wpli'."""
        detector = EEGGCNNDiseaseDetection(
            adjacency_type="functional", connectivity_measure="wpli"
        )
        detector._build_functional_adjacency(synthetic_bipolar_window)
        assert self._method_passed_to_mne() == "wpli"

    # --- combined adjacency via _build_adjacency ---

    def test_combined_shape(self, synthetic_bipolar_window):
        detector = EEGGCNNDiseaseDetection(adjacency_type="combined")
        produced = detector._build_adjacency(synthetic_bipolar_window)
        assert produced.shape == (NUM_CHANNELS, NUM_CHANNELS)

    def test_combined_diagonal_is_one(self, synthetic_bipolar_window):
        detector = EEGGCNNDiseaseDetection(adjacency_type="combined")
        produced = detector._build_adjacency(synthetic_bipolar_window)
        np.testing.assert_array_equal(
            np.diag(produced), np.ones(NUM_CHANNELS)
        )

    def test_combined_is_mean_of_spatial_and_functional(
        self, synthetic_bipolar_window
    ):
        """Combined adjacency = (spatial + functional) / 2, diag set to 1."""
        detector = EEGGCNNDiseaseDetection(adjacency_type="combined")
        produced = detector._build_adjacency(synthetic_bipolar_window)

        functional_part = detector._build_functional_adjacency(
            synthetic_bipolar_window
        )
        spatial_part = detector._spatial_adj

        reference = (spatial_part + functional_part) / 2.0
        np.fill_diagonal(reference, 1.0)
        np.testing.assert_array_almost_equal(produced, reference)

    def test_combined_off_diagonal_differs_from_identity(
        self, synthetic_bipolar_window
    ):
        """Combined adjacency must not collapse to identity matrix."""
        detector = EEGGCNNDiseaseDetection(adjacency_type="combined")
        produced = detector._build_adjacency(synthetic_bipolar_window)
        off_diagonal_mask = ~np.eye(NUM_CHANNELS, dtype=bool)
        assert np.any(produced[off_diagonal_mask] != 0.0)