From 8801221a95931487ef8e58d81bfdcae9c81f2d8d Mon Sep 17 00:00:00 2001 From: prachipradhan Date: Mon, 20 Apr 2026 20:50:48 -0400 Subject: [PATCH 1/5] feat: add SleepWakeDetectionDREAMT and SleepStagingDREAMT tasks for DREAMT dataset with feature engineering, 17 unit tests, ablation example script, and docs --- docs/api/tasks.rst | 2 + .../pyhealth.tasks.sleep_wake_dreamt.rst | 12 + examples/dreamt_sleep_wake_detection.py | 219 +++++++++ pyhealth/tasks/__init__.py | 1 + pyhealth/tasks/sleep_wake_dreamt.py | 416 ++++++++++++++++++ tests/test_sleep_wake_dreamt.py | 246 +++++++++++ 6 files changed, 896 insertions(+) create mode 100644 docs/api/tasks/pyhealth.tasks.sleep_wake_dreamt.rst create mode 100644 examples/dreamt_sleep_wake_detection.py create mode 100644 pyhealth/tasks/sleep_wake_dreamt.py create mode 100644 tests/test_sleep_wake_dreamt.py diff --git a/docs/api/tasks.rst b/docs/api/tasks.rst index 399b8f1aa..4b1442278 100644 --- a/docs/api/tasks.rst +++ b/docs/api/tasks.rst @@ -222,6 +222,7 @@ Available Tasks Sleep Staging (SleepEDF) Temple University EEG Tasks Sleep Staging v2 + Sleep/Wake Detection & Staging (DREAMT) Benchmark EHRShot ChestX-ray14 Binary Classification ChestX-ray14 Multilabel Classification @@ -229,3 +230,4 @@ Available Tasks Mutation Pathogenicity (COSMIC) Cancer Survival Prediction (TCGA) Cancer Mutation Burden (TCGA) + \ No newline at end of file diff --git a/docs/api/tasks/pyhealth.tasks.sleep_wake_dreamt.rst b/docs/api/tasks/pyhealth.tasks.sleep_wake_dreamt.rst new file mode 100644 index 000000000..a37be266c --- /dev/null +++ b/docs/api/tasks/pyhealth.tasks.sleep_wake_dreamt.rst @@ -0,0 +1,12 @@ +rstpyhealth.tasks.sleep\_wake\_dreamt +=================================== + +.. autoclass:: pyhealth.tasks.sleep_wake_dreamt.SleepWakeDetectionDREAMT + :members: + :undoc-members: + :show-inheritance: + +.. autoclass:: pyhealth.tasks.sleep_wake_dreamt.SleepStagingDREAMT + :members: + :undoc-members: + :show-inheritance: \ No newline at end of file diff --git a/examples/dreamt_sleep_wake_detection.py b/examples/dreamt_sleep_wake_detection.py new file mode 100644 index 000000000..6bc329553 --- /dev/null +++ b/examples/dreamt_sleep_wake_detection.py @@ -0,0 +1,219 @@ +""" +Sleep/Wake Detection on DREAMT Dataset +======================================= +Reproduces the binary sleep/wake classification pipeline from: + + Wang et al. (2024). Addressing wearable sleep tracking inequity: + a new dataset and novel methods for a population with sleep disorders. + CHIL 2024, PMLR 248:380-396. + + Dataset: https://physionet.org/content/dreamt/2.1.0/ + Original code: https://github.com/WillKeWang/DREAMT_FE + +This example demonstrates: + 1. Loading the DREAMT dataset via PyHealth (or generating synthetic data) + 2. Applying the SleepWakeDetectionDREAMT task to generate epoch samples + 3. Training a baseline LightGBM classifier (epoch-by-epoch) + 4. Ablation: comparing LightGBM vs LightGBM + AHI vs LightGBM + BMI + +Requirements: + pip install lightgbm scikit-learn imbalanced-learn + +Usage: + # Run with synthetic demo data (no PhysioNet access needed): + python examples/dreamt_sleep_wake_detection.py --demo + + # Run with real DREAMT data from PhysioNet: + python examples/dreamt_sleep_wake_detection.py --root /path/to/dreamt/2.1.0 +""" + +import argparse +import numpy as np +from sklearn.model_selection import GroupShuffleSplit +from sklearn.metrics import ( + f1_score, + roc_auc_score, + cohen_kappa_score, + accuracy_score, +) + + +def make_synthetic_samples(n_patients: int = 5, n_epochs_per_patient: int = 40): + """Generate synthetic DREAMT-like epoch samples for demo/testing. + + Simulates the output of SleepWakeDetectionDREAMT.__call__() without + needing real PhysioNet data. Signal values are random floats; labels + are randomly assigned with a 75/25 sleep/wake split matching the + approximate class balance in the real dataset. + + Args: + n_patients: number of synthetic patients to generate + n_epochs_per_patient: number of 30-second epochs per patient + + Returns: + List of sample dicts matching SleepWakeDetectionDREAMT output format + """ + np.random.seed(42) + samples = [] + for i in range(n_patients): + pid = f"S{i+1:03d}" + ahi = np.random.uniform(5, 40) + bmi = np.random.uniform(25, 45) + for epoch_idx in range(n_epochs_per_patient): + label = 1 if np.random.random() < 0.25 else 0 + signal = np.random.randn(96).astype(np.float32) + samples.append({ + "patient_id": pid, + "epoch_index": epoch_idx, + "signal": signal.flatten(), + "ahi": ahi, + "bmi": bmi, + "label": label, + }) + return samples + + +# Helpers + +def samples_to_arrays(samples): + """Convert task samples into numpy arrays for sklearn. + + Args: + samples: list of dicts from SleepWakeDetectionDREAMT + + Returns: + X: feature matrix (n_samples, n_features) + y: binary labels (n_samples,) + groups: patient IDs for participant-level CV (n_samples,) + ahi: AHI values (n_samples,) + bmi: BMI values (n_samples,) + """ + X, y, groups, ahi, bmi = [], [], [], [], [] + for s in samples: + X.append(np.array(s["signal"]).flatten()) + y.append(s["label"]) + groups.append(s["patient_id"]) + ahi.append(s["ahi"]) + bmi.append(s["bmi"]) + return ( + np.array(X), + np.array(y), + np.array(groups), + np.array(ahi), + np.array(bmi), + ) + + +def evaluate(y_true, y_pred, y_prob, label=""): + """Print evaluation metrics matching paper Table 2.""" + print(f"\n{'─'*55}") + print(f" {label}") + print(f"{'─'*55}") + print(f" F1 Score : {f1_score(y_true, y_pred):.3f}") + print(f" AUROC : {roc_auc_score(y_true, y_prob):.3f}") + print(f" Accuracy : {accuracy_score(y_true, y_pred):.3f}") + print(f" Kappa : {cohen_kappa_score(y_true, y_pred):.3f}") + + +def main(root: str = None, demo: bool = False): + try: + import lightgbm as lgb + except ImportError: + print("lightgbm not installed. Run: pip install lightgbm") + return + + if demo: + print("\n[1/4] Generating synthetic DREAMT-like data (demo mode)...") + print(" To use real data: --root /path/to/dreamt/2.1.0") + all_samples = make_synthetic_samples( + n_patients=5, n_epochs_per_patient=40 + ) + else: + print("\n[1/4] Loading real DREAMT dataset from PhysioNet...") + from pyhealth.datasets import DREAMTDataset + from pyhealth.tasks import SleepWakeDetectionDREAMT + dataset = DREAMTDataset(root=root) + task = SleepWakeDetectionDREAMT() + all_samples = [] + for pid in dataset.unique_patient_ids: + patient = dataset.get_patient(pid) + all_samples.extend(task(patient)) + + print(f" Total epochs : {len(all_samples)}") + wake = sum(s["label"] == 1 for s in all_samples) + sleep = sum(s["label"] == 0 for s in all_samples) + print(f" Wake (1) : {wake} ({100*wake/len(all_samples):.1f}%)") + print(f" Sleep (0) : {sleep} ({100*sleep/len(all_samples):.1f}%)") + + print("\n[2/4] Extracting features...") + X, y, groups, ahi, bmi = samples_to_arrays(all_samples) + print(f" Feature matrix: {X.shape}") + + print("\n[3/4] Splitting by participant (80/20)...") + splitter = GroupShuffleSplit(n_splits=1, test_size=0.2, random_state=42) + train_idx, test_idx = next(splitter.split(X, y, groups)) + X_train, X_test = X[train_idx], X[test_idx] + y_train, y_test = y[train_idx], y[test_idx] + ahi_train, ahi_test = ahi[train_idx], ahi[test_idx] + bmi_train, bmi_test = bmi[train_idx], bmi[test_idx] + + # SMOTE balancing on training set (paper section 3.1) + try: + from imblearn.over_sampling import SMOTE + X_train, y_train = SMOTE(random_state=42).fit_resample( + X_train, y_train + ) + print(" Applied SMOTE balancing") + except ImportError: + print(" imbalanced-learn not installed, skipping SMOTE") + + # Ablation study + # Ablation A: Baseline LightGBM + clf_a = lgb.LGBMClassifier(n_estimators=200, random_state=42, verbose=-1) + clf_a.fit(X_train, y_train) + evaluate(y_test, clf_a.predict(X_test), + clf_a.predict_proba(X_test)[:, 1], + "Ablation A: Baseline LightGBM (no clinical metadata)") + + # Ablation B: LightGBM + AHI + X_train_b = np.hstack([ + X_train, + np.tile(ahi_train, (len(X_train)//len(ahi_train)+1))[:len(X_train)].reshape(-1, 1) + ]) + X_test_b = np.hstack([X_test, ahi_test.reshape(-1, 1)]) + clf_b = lgb.LGBMClassifier(n_estimators=200, random_state=42, verbose=-1) + clf_b.fit(X_train_b, y_train) + evaluate(y_test, clf_b.predict(X_test_b), + clf_b.predict_proba(X_test_b)[:, 1], + "Ablation B: LightGBM + AHI (apnea severity)") + + # Ablation C:LightGBM + BMI + X_train_c = np.hstack([ + X_train, + np.tile(bmi_train, (len(X_train)//len(bmi_train)+1))[:len(X_train)].reshape(-1, 1) + ]) + X_test_c = np.hstack([X_test, bmi_test.reshape(-1, 1)]) + clf_c = lgb.LGBMClassifier(n_estimators=200, random_state=42, verbose=-1) + clf_c.fit(X_train_c, y_train) + evaluate(y_test, clf_c.predict(X_test_c), + clf_c.predict_proba(X_test_c)[:, 1], + "Ablation C: LightGBM + BMI (obesity)") + + +if __name__ == "__main__": + parser = argparse.ArgumentParser( + description="DREAMT sleep/wake detection — ablation study" + ) + group = parser.add_mutually_exclusive_group(required=True) + group.add_argument( + "--demo", + action="store_true", + help="Run with synthetic data (no PhysioNet access needed)", + ) + group.add_argument( + "--root", + type=str, + help="Path to real DREAMT dataset, e.g. /path/to/dreamt/2.1.0", + ) + args = parser.parse_args() + main(root=args.root, demo=args.demo) \ No newline at end of file diff --git a/pyhealth/tasks/__init__.py b/pyhealth/tasks/__init__.py index 797988377..cf2c07966 100644 --- a/pyhealth/tasks/__init__.py +++ b/pyhealth/tasks/__init__.py @@ -66,3 +66,4 @@ VariantClassificationClinVar, ) from .patient_linkage_mimic3 import PatientLinkageMIMIC3Task +from .sleep_wake_dreamt import SleepWakeDetectionDREAMT, SleepStagingDREAMT diff --git a/pyhealth/tasks/sleep_wake_dreamt.py b/pyhealth/tasks/sleep_wake_dreamt.py new file mode 100644 index 000000000..c6d934858 --- /dev/null +++ b/pyhealth/tasks/sleep_wake_dreamt.py @@ -0,0 +1,416 @@ +from typing import Any, Dict, List, Optional +import numpy as np +import pandas as pd +from .base_task import BaseTask + +# Sampling frequency of the E4 after upsampling (paper sec 2.3) +SAMPLE_RATE_HZ = 64 +EPOCH_SEC = 30 +EPOCH_SAMPLES = SAMPLE_RATE_HZ * EPOCH_SEC # 1920 samples per 30-sec epoch + +# E4 signal columns used as features (paper sec 2.3) +SIGNAL_COLS = ["BVP", "ACC_X", "ACC_Y", "ACC_Z", "EDA", "TEMP", "HR"] + +# PSG sleep stage label constants +WAKE_LABEL = "W" +SLEEP_LABELS = {"R", "N1", "N2", "N3"} +MISSING_LABEL = "Missing" + +# Fine-grained label map for multi-class staging +FINE_LABEL_MAP = {"W": 0, "N1": 1, "N2": 2, "N3": 3, "R": 4} + +# Coarse label map for binary wake/sleep detection +BINARY_LABEL_MAP = {"W": 1, "N1": 0, "N2": 0, "N3": 0, "R": 0} + + +def _butter_bandpass( + signal: np.ndarray, + low: float, + high: float, + fs: int, + order: int = 5, +) -> np.ndarray: + """Apply a Butterworth bandpass filter to a 1D signal. + + Args: + signal: 1D numpy array of signal values + low: lower cutoff frequency in Hz + high: upper cutoff frequency in Hz + fs: sampling frequency in Hz + order: filter order + + Returns: + Filtered signal as numpy array + """ + try: + from scipy.signal import butter, filtfilt + nyq = fs / 2.0 + b, a = butter(order, [low / nyq, high / nyq], btype="band") + return filtfilt(b, a, signal).astype(np.float32) + except Exception: + return signal.astype(np.float32) + + +def _butter_lowpass( + signal: np.ndarray, + cutoff: float, + fs: int, + order: int = 4, +) -> np.ndarray: + """Apply a Butterworth lowpass filter to a 1D signal. + + Args: + signal: 1D numpy array of signal values + cutoff: cutoff frequency in Hz + fs: sampling frequency in Hz + order: filter order + + Returns: + Filtered signal as numpy array + """ + try: + from scipy.signal import butter, filtfilt + nyq = fs / 2.0 + b, a = butter(order, cutoff / nyq, btype="low") + return filtfilt(b, a, signal).astype(np.float32) + except Exception: + return signal.astype(np.float32) + + +def _segment_detrend( + signal: np.ndarray, + segment_seconds: int = 5, + fs: int = 64, +) -> np.ndarray: + """Detrend signal by subtracting least-squares line from each segment. + + Follows EDA preprocessing from paper section 2.5. + + Args: + signal: 1D numpy array of EDA values + segment_seconds: length of each detrending segment in seconds + fs: sampling frequency in Hz + + Returns: + Detrended signal as numpy array + """ + seg_len = segment_seconds * fs + out = signal.copy().astype(np.float32) + for start in range(0, len(signal), seg_len): + seg = signal[start: start + seg_len] + if len(seg) < 2: + continue + x = np.arange(len(seg), dtype=np.float32) + coeffs = np.polyfit(x, seg, 1) + out[start: start + len(seg)] = seg - np.polyval(coeffs, x) + return out + + +def extract_epoch_features( + epoch_df: pd.DataFrame, + fs: int = SAMPLE_RATE_HZ, +) -> np.ndarray: + """Extract statistical and signal-processing features from one 30-sec epoch. + + Implements the feature engineering described in paper section 2.5: + - ACC: bandpass filtered (3-11 Hz), trimmed mean, max, IQR, MAD per axis + - TEMP: winsorized to 31-40C, mean, min, max, std + - BVP: bandpass filtered (0.5-20 Hz), basic HRV-proxy stats + - EDA: detrended, lowpass filtered, mean and std of phasic component + - HR: mean and std + + Args: + epoch_df: DataFrame slice for one 30-second epoch containing + BVP, ACC_X, ACC_Y, ACC_Z, EDA, TEMP, HR columns + fs: sampling frequency in Hz (default 64) + + Returns: + np.ndarray of shape (n_features,) — float32 feature vector + """ + feats: List[float] = [] + + def safe_col(col: str) -> np.ndarray: + if col in epoch_df.columns: + return epoch_df[col].to_numpy(dtype=np.float32) + return np.zeros(len(epoch_df), dtype=np.float32) + + def summary_stats(x: np.ndarray) -> List[float]: + """Mean, std, min, max of array.""" + return [ + float(np.mean(x)), + float(np.std(x)), + float(np.min(x)), + float(np.max(x)), + ] + + def trimmed_stats(x: np.ndarray) -> List[float]: + """Trimmed mean (10%), max, IQR of absolute values.""" + if len(x) == 0: + return [0.0, 0.0, 0.0] + n = max(1, int(0.1 * len(x))) + sorted_x = np.sort(np.abs(x)) + trimmed = sorted_x[n:-n] if len(sorted_x) > 2 * n else sorted_x + return [ + float(np.mean(trimmed)), + float(np.max(np.abs(x))), + float(np.percentile(np.abs(x), 75) - np.percentile(np.abs(x), 25)), + ] + + # ── ACC features (paper sec 2.5) ───────────────────────────────────────── + for col in ["ACC_X", "ACC_Y", "ACC_Z"]: + raw = safe_col(col) + # Bandpass filter 3-11 Hz (Oura method cited in paper) + filtered = _butter_bandpass(raw, low=3.0, high=11.0, fs=fs, order=5) + feats.extend(trimmed_stats(filtered)) + # MAD from vector magnitude + mag = np.sqrt(np.mean(raw ** 2)) + feats.append(float(np.mean(np.abs(raw - mag)))) + + # ── TEMP features (paper sec 2.5) ──────────────────────────────────────── + temp = np.clip(safe_col("TEMP"), 31.0, 40.0) # winsorize to 31-40C + feats.extend(summary_stats(temp)) + + # ── BVP / HRV features (paper sec 2.5) ─────────────────────────────────── + bvp = safe_col("BVP") + bvp_filt = _butter_bandpass(bvp, low=0.5, high=20.0, fs=fs, order=4) + feats.extend(summary_stats(bvp_filt)) + + # ── EDA features (paper sec 2.5) ───────────────────────────────────────── + eda = safe_col("EDA") + eda_detrended = _segment_detrend(eda, segment_seconds=5, fs=fs) + eda_filtered = _butter_lowpass(eda_detrended, cutoff=1.0, fs=fs, order=4) + feats.extend([float(np.mean(eda_filtered)), float(np.std(eda_filtered))]) + + # ── HR features ────────────────────────────────────────────────────────── + hr = safe_col("HR") + feats.extend([float(np.mean(hr)), float(np.std(hr))]) + + return np.array(feats, dtype=np.float32) + + +class SleepWakeDetectionDREAMT(BaseTask): + """Binary sleep/wake detection task for the DREAMT dataset. + + Based on: Wang et al. (2024). Addressing wearable sleep tracking inequity: + a new dataset and novel methods for a population with sleep disorders. + CHIL 2024, PMLR 248:380-396. + + This task processes overnight Empatica E4 wearable recordings from patients + with sleep disorders. Each night is sliced into non-overlapping 30-second + epochs. Each epoch is labeled as Wake (1) or Sleep (0) based on concurrent + PSG annotations. Clinical metadata (AHI, BMI) is included per epoch to + support mixed-effects modeling as described in paper section 2.6. + + Feature engineering follows paper section 2.5: + - ACC: Butterworth bandpass (3-11 Hz), trimmed mean, max, IQR, MAD + - TEMP: winsorized (31-40C), mean, std, min, max + - BVP: Chebyshev bandpass (0.5-20 Hz), summary stats + - EDA: segment detrended, lowpass filtered, mean and std + - HR: mean and std + + Attributes: + task_name (str): "SleepWakeDetectionDREAMT" + input_schema (Dict[str, str]): Input features per epoch: + - "signal": float feature vector from extract_epoch_features() + - "ahi": float Apnea-Hypopnea Index (random effect) + - "bmi": float Body Mass Index (random effect) + output_schema (Dict[str, str]): Binary label: + - "label": 1 = Wake, 0 = Sleep + + Examples: + >>> from pyhealth.datasets import DREAMTDataset + >>> from pyhealth.tasks import SleepWakeDetectionDREAMT + >>> dataset = DREAMTDataset(root="/path/to/dreamt/2.1.0") + >>> task = SleepWakeDetectionDREAMT() + >>> samples = dataset.set_task(task) + >>> samples[0] + { + 'patient_id': 'S002', + 'epoch_index': 0, + 'signal': array of shape (n_features,), + 'ahi': 22.1, + 'bmi': 33.7, + 'label': 0 + } + """ + + task_name: str = "SleepWakeDetectionDREAMT" + + input_schema: Dict[str, str] = { + # Engineered feature vector per 30-sec epoch + "signal": "float", + # Clinical metadata for mixed-effects modeling (paper sec 2.6) + "ahi": "float", + "bmi": "float", + } + + output_schema: Dict[str, str] = { + # Binary classification: 1 = Wake, 0 = Sleep + "label": "binary", + } + + def __call__(self, patient: Any) -> List[Dict[str, Any]]: + """Process one DREAMT patient into a list of 30-second epoch samples. + + Args: + patient: A patient object from DREAMTDataset. Must have a + dreamt_sleep event containing file_64hz, ahi, and bmi. + + Returns: + List of dicts, one per valid epoch. Each dict contains: + - patient_id (str): participant identifier + - epoch_index (int): index of this epoch in the night + - signal (np.ndarray): float32 feature vector + - ahi (float): Apnea-Hypopnea Index + - bmi (float): Body Mass Index + - label (int): 1 = Wake, 0 = Sleep + """ + return _process_patient( + patient=patient, + label_map=BINARY_LABEL_MAP, + ) + + +class SleepStagingDREAMT(BaseTask): + """Multi-class sleep staging task for the DREAMT dataset. + + Based on: Wang et al. (2024). Addressing wearable sleep tracking inequity: + a new dataset and novel methods for a population with sleep disorders. + CHIL 2024, PMLR 248:380-396. + + This task extends SleepWakeDetectionDREAMT to predict all five PSG-derived + sleep stages: Wake (0), N1 (1), N2 (2), N3 (3), REM (4). Feature + engineering follows the same pipeline as SleepWakeDetectionDREAMT. + + Attributes: + task_name (str): "SleepStagingDREAMT" + input_schema (Dict[str, str]): Input features per epoch: + - "signal": float feature vector from extract_epoch_features() + - "ahi": float Apnea-Hypopnea Index + - "bmi": float Body Mass Index + output_schema (Dict[str, str]): Multi-class label: + - "label": 0=Wake, 1=N1, 2=N2, 3=N3, 4=REM + + Examples: + >>> from pyhealth.datasets import DREAMTDataset + >>> from pyhealth.tasks import SleepStagingDREAMT + >>> dataset = DREAMTDataset(root="/path/to/dreamt/2.1.0") + >>> task = SleepStagingDREAMT() + >>> samples = dataset.set_task(task) + >>> samples[0] + { + 'patient_id': 'S002', + 'epoch_index': 0, + 'signal': array of shape (n_features,), + 'ahi': 22.1, + 'bmi': 33.7, + 'label': 2 + } + """ + + task_name: str = "SleepStagingDREAMT" + + input_schema: Dict[str, str] = { + "signal": "float", + "ahi": "float", + "bmi": "float", + } + + output_schema: Dict[str, str] = { + # Multi-class: 0=Wake, 1=N1, 2=N2, 3=N3, 4=REM + "label": "multiclass", + } + + def __call__(self, patient: Any) -> List[Dict[str, Any]]: + """Process one DREAMT patient into multi-class sleep stage samples. + + Args: + patient: A patient object from DREAMTDataset. + + Returns: + List of dicts, one per valid epoch, with label in {0,1,2,3,4}. + """ + return _process_patient( + patient=patient, + label_map=FINE_LABEL_MAP, + ) + + +def _process_patient( + patient: Any, + label_map: Dict[str, int], +) -> List[Dict[str, Any]]: + """Shared processing logic for both task classes. + + Reads the patient's overnight E4 CSV, slices into 30-second epochs, + extracts features, and maps PSG labels using the provided label_map. + + Args: + patient: A patient object from DREAMTDataset + label_map: dict mapping PSG stage strings to integer labels. + Use BINARY_LABEL_MAP for wake/sleep or FINE_LABEL_MAP for + 5-class staging. + + Returns: + List of sample dicts, one per valid labeled epoch. + """ + samples = [] + + events = patient.get_events(event_type="dreamt_sleep") + if not events: + return samples + event = events[0] + + file_path = event.file_64hz + if file_path is None: + return samples + + try: + ahi = float(event.ahi) if event.ahi is not None else 0.0 + bmi = float(event.bmi) if event.bmi is not None else 0.0 + except (ValueError, TypeError): + ahi, bmi = 0.0, 0.0 + + try: + df = pd.read_csv(file_path) + except Exception: + return samples + + # Verify required columns + required = SIGNAL_COLS + ["Sleep_Stage"] + if any(c not in df.columns for c in required): + return samples + + n_epochs = len(df) // EPOCH_SAMPLES + + for epoch_idx in range(n_epochs): + start = epoch_idx * EPOCH_SAMPLES + end = start + EPOCH_SAMPLES + epoch_df = df.iloc[start:end] + + # Get PSG label for this epoch + stage = epoch_df["Sleep_Stage"].iloc[-1] + if stage == MISSING_LABEL or pd.isna(stage): + continue + + label = label_map.get(str(stage).strip()) + if label is None: + continue + + # Extract engineered features + signal = extract_epoch_features(epoch_df) + + # Skip epochs with invalid features + if np.isnan(signal).any() or np.isinf(signal).any(): + continue + + samples.append({ + "patient_id": patient.patient_id, + "epoch_index": epoch_idx, + "signal": signal, + "ahi": ahi, + "bmi": bmi, + "label": label, + }) + + return samples \ No newline at end of file diff --git a/tests/test_sleep_wake_dreamt.py b/tests/test_sleep_wake_dreamt.py new file mode 100644 index 000000000..7a8ad1c20 --- /dev/null +++ b/tests/test_sleep_wake_dreamt.py @@ -0,0 +1,246 @@ +import numpy as np +import pytest +import tempfile +import os +from unittest.mock import MagicMock, patch +from pyhealth.tasks.sleep_wake_dreamt import ( + SleepWakeDetectionDREAMT, + EPOCH_SAMPLES, + SIGNAL_COLS, + WAKE_LABEL, +) + +def make_fake_patient( + patient_id: str = "S001", + ahi: float = 22.1, + bmi: float = 33.7, + n_epochs: int = 3, + has_file: bool = True, + sleep_stages: list = None, + tmp_dir: str = None, +): + """Creates a mock DREAMT patient object for testing.""" + if sleep_stages is None: + sleep_stages = ["N2", "W", "N2"] + + # Build a fake overnight DataFrame + n_rows = n_epochs * EPOCH_SAMPLES + data = {col: np.random.randn(n_rows).astype(np.float32) + for col in SIGNAL_COLS} + + # Assign sleep stage label at last row of each epoch + stages = np.full(n_rows, "N2", dtype=object) + for i, stage in enumerate(sleep_stages): + stages[(i + 1) * EPOCH_SAMPLES - 1] = stage + data["Sleep_Stage"] = stages + + import pandas as pd + fake_df = pd.DataFrame(data) + + # Write to a real temp file if tmp_dir provided + if tmp_dir and has_file: + file_path = os.path.join(tmp_dir, f"{patient_id}_whole_df.csv") + fake_df.to_csv(file_path, index=False) + else: + file_path = "/fake/path/S001_whole_df.csv" if has_file else None + + # Mock event + event = MagicMock() + event.file_64hz = file_path + event.ahi = ahi + event.bmi = bmi + + # Mock patient + patient = MagicMock() + patient.patient_id = patient_id + patient.get_events.return_value = [event] + + return patient, fake_df + + +class TestSleepWakeDetectionDREAMT: + + def test_instantiation(self): + """Task can be instantiated.""" + task = SleepWakeDetectionDREAMT() + assert task.task_name == "SleepWakeDetectionDREAMT" + + def test_schema_defined(self): + """Input and output schemas are defined correctly.""" + task = SleepWakeDetectionDREAMT() + assert "signal" in task.input_schema + assert "ahi" in task.input_schema + assert "bmi" in task.input_schema + assert "label" in task.output_schema + + def test_returns_correct_number_of_epochs(self): + """Task returns one sample per valid epoch.""" + task = SleepWakeDetectionDREAMT() + patient, fake_df = make_fake_patient( + n_epochs=3, + sleep_stages=["N2", "W", "N1"] + ) + with patch("pandas.read_csv", return_value=fake_df): + samples = task(patient) + assert len(samples) == 3 + + def test_wake_label_is_1(self): + """Wake epochs are labeled 1.""" + task = SleepWakeDetectionDREAMT() + patient, fake_df = make_fake_patient( + n_epochs=2, + sleep_stages=["W", "N2"] + ) + with patch("pandas.read_csv", return_value=fake_df): + samples = task(patient) + labels = [s["label"] for s in samples] + assert labels[0] == 1 + assert labels[1] == 0 + + def test_sleep_label_is_0(self): + """All NREM and REM epochs are labeled 0.""" + task = SleepWakeDetectionDREAMT() + patient, fake_df = make_fake_patient( + n_epochs=4, + sleep_stages=["N1", "N2", "N3", "R"] + ) + with patch("pandas.read_csv", return_value=fake_df): + samples = task(patient) + assert all(s["label"] == 0 for s in samples) + + def test_missing_epochs_are_skipped(self): + """Epochs labeled Missing are excluded from samples.""" + task = SleepWakeDetectionDREAMT() + patient, fake_df = make_fake_patient( + n_epochs=3, + sleep_stages=["N2", "Missing", "W"] + ) + with patch("pandas.read_csv", return_value=fake_df): + samples = task(patient) + assert len(samples) == 2 + + def test_signal_shape(self): + """Each epoch signal is a 1D engineered feature vector.""" + task = SleepWakeDetectionDREAMT() + patient, fake_df = make_fake_patient(n_epochs=2, + sleep_stages=["N2", "W"]) + with patch("pandas.read_csv", return_value=fake_df): + samples = task(patient) + for s in samples: + # Signal is now an engineered feature vector, not raw signal + assert isinstance(s["signal"], np.ndarray) + assert s["signal"].ndim == 1 + assert len(s["signal"]) > 0 + + def test_clinical_metadata_attached(self): + """AHI and BMI are correctly attached to each sample.""" + task = SleepWakeDetectionDREAMT() + patient, fake_df = make_fake_patient( + ahi=15.5, bmi=28.3, + n_epochs=2, sleep_stages=["N2", "W"] + ) + with patch("pandas.read_csv", return_value=fake_df): + samples = task(patient) + for s in samples: + assert s["ahi"] == pytest.approx(15.5) + assert s["bmi"] == pytest.approx(28.3) + + def test_no_file_returns_empty(self): + """Returns empty list when no file path is available.""" + task = SleepWakeDetectionDREAMT() + patient, fake_df = make_fake_patient(has_file=False) + samples = task(patient) + assert samples == [] + + def test_no_events_returns_empty(self): + """Returns empty list when patient has no events.""" + task = SleepWakeDetectionDREAMT() + patient = MagicMock() + patient.patient_id = "S999" + patient.get_events.return_value = [] + samples = task(patient) + assert samples == [] + + def test_patient_id_in_samples(self): + """Patient ID is correctly propagated to each sample.""" + task = SleepWakeDetectionDREAMT() + patient, fake_df = make_fake_patient( + patient_id="S042", + n_epochs=2, + sleep_stages=["W", "N2"] + ) + with patch("pandas.read_csv", return_value=fake_df): + samples = task(patient) + assert all(s["patient_id"] == "S042" for s in samples) + + def test_uses_temp_directory(self): + """Task correctly reads from a real temporary file.""" + task = SleepWakeDetectionDREAMT() + with tempfile.TemporaryDirectory() as tmp_dir: + patient, _ = make_fake_patient( + patient_id="S001", + n_epochs=3, + sleep_stages=["N2", "W", "N1"], + tmp_dir=tmp_dir, + ) + # No mock needed — reads real temp file + samples = task(patient) + assert len(samples) == 3 + assert all("signal" in s for s in samples) + +class TestSleepStagingDREAMT: + + def test_instantiation(self): + """SleepStagingDREAMT can be instantiated.""" + from pyhealth.tasks.sleep_wake_dreamt import SleepStagingDREAMT + task = SleepStagingDREAMT() + assert task.task_name == "SleepStagingDREAMT" + + def test_schema_defined(self): + """Input and output schemas are defined correctly.""" + from pyhealth.tasks.sleep_wake_dreamt import SleepStagingDREAMT + task = SleepStagingDREAMT() + assert "signal" in task.input_schema + assert "label" in task.output_schema + assert task.output_schema["label"] == "multiclass" + + def test_fine_labels(self): + """Each sleep stage maps to correct integer label.""" + from pyhealth.tasks.sleep_wake_dreamt import SleepStagingDREAMT + task = SleepStagingDREAMT() + patient, fake_df = make_fake_patient( + n_epochs=5, + sleep_stages=["W", "N1", "N2", "N3", "R"] + ) + with patch("pandas.read_csv", return_value=fake_df): + samples = task(patient) + labels = [s["label"] for s in samples] + assert labels == [0, 1, 2, 3, 4] + + def test_missing_skipped(self): + """Missing epochs are skipped in multi-class task too.""" + from pyhealth.tasks.sleep_wake_dreamt import SleepStagingDREAMT + task = SleepStagingDREAMT() + patient, fake_df = make_fake_patient( + n_epochs=3, + sleep_stages=["W", "Missing", "R"] + ) + with patch("pandas.read_csv", return_value=fake_df): + samples = task(patient) + assert len(samples) == 2 + + def test_signal_is_feature_vector(self): + """Signal output is a 1D engineered feature vector.""" + from pyhealth.tasks.sleep_wake_dreamt import SleepStagingDREAMT + task = SleepStagingDREAMT() + patient, fake_df = make_fake_patient( + n_epochs=2, + sleep_stages=["N2", "W"] + ) + with patch("pandas.read_csv", return_value=fake_df): + samples = task(patient) + for s in samples: + assert isinstance(s["signal"], np.ndarray) + assert s["signal"].ndim == 1 + assert len(s["signal"]) > 0 + \ No newline at end of file From 56e62c525be9305cb972986b783a6810b4975b1a Mon Sep 17 00:00:00 2001 From: yyang2002 Date: Wed, 22 Apr 2026 19:50:51 -0500 Subject: [PATCH 2/5] Fix minor issues in task, add test cases, and improve code quality --- .../pyhealth.tasks.sleep_wake_dreamt.rst | 2 +- examples/dreamt_sleep_wake_detection.py | 68 +++++++-------- pyhealth/tasks/sleep_wake_dreamt.py | 86 +++++++++++++++---- tests/test_sleep_wake_dreamt.py | 58 ++++++++++--- 4 files changed, 146 insertions(+), 68 deletions(-) diff --git a/docs/api/tasks/pyhealth.tasks.sleep_wake_dreamt.rst b/docs/api/tasks/pyhealth.tasks.sleep_wake_dreamt.rst index a37be266c..979c5d75e 100644 --- a/docs/api/tasks/pyhealth.tasks.sleep_wake_dreamt.rst +++ b/docs/api/tasks/pyhealth.tasks.sleep_wake_dreamt.rst @@ -1,4 +1,4 @@ -rstpyhealth.tasks.sleep\_wake\_dreamt +pyhealth.tasks.sleep_wake_dreamt =================================== .. autoclass:: pyhealth.tasks.sleep_wake_dreamt.SleepWakeDetectionDREAMT diff --git a/examples/dreamt_sleep_wake_detection.py b/examples/dreamt_sleep_wake_detection.py index 6bc329553..92762ac08 100644 --- a/examples/dreamt_sleep_wake_detection.py +++ b/examples/dreamt_sleep_wake_detection.py @@ -29,13 +29,12 @@ import argparse import numpy as np +import lightgbm as lgb +from imblearn.over_sampling import SMOTE from sklearn.model_selection import GroupShuffleSplit -from sklearn.metrics import ( - f1_score, - roc_auc_score, - cohen_kappa_score, - accuracy_score, -) + +from pyhealth.datasets import DREAMTDataset +from pyhealth.metrics import binary_metrics_fn def make_synthetic_samples(n_patients: int = 5, n_epochs_per_patient: int = 40): @@ -56,14 +55,14 @@ def make_synthetic_samples(n_patients: int = 5, n_epochs_per_patient: int = 40): np.random.seed(42) samples = [] for i in range(n_patients): - pid = f"S{i+1:03d}" + sid = f"S{i+1:03d}" ahi = np.random.uniform(5, 40) bmi = np.random.uniform(25, 45) for epoch_idx in range(n_epochs_per_patient): label = 1 if np.random.random() < 0.25 else 0 signal = np.random.randn(96).astype(np.float32) samples.append({ - "patient_id": pid, + "patient_id": sid, "epoch_index": epoch_idx, "signal": signal.flatten(), "ahi": ahi, @@ -73,8 +72,6 @@ def make_synthetic_samples(n_patients: int = 5, n_epochs_per_patient: int = 40): return samples -# Helpers - def samples_to_arrays(samples): """Convert task samples into numpy arrays for sklearn. @@ -105,39 +102,34 @@ def samples_to_arrays(samples): def evaluate(y_true, y_pred, y_prob, label=""): - """Print evaluation metrics matching paper Table 2.""" - print(f"\n{'─'*55}") - print(f" {label}") - print(f"{'─'*55}") - print(f" F1 Score : {f1_score(y_true, y_pred):.3f}") - print(f" AUROC : {roc_auc_score(y_true, y_prob):.3f}") - print(f" Accuracy : {accuracy_score(y_true, y_pred):.3f}") - print(f" Kappa : {cohen_kappa_score(y_true, y_pred):.3f}") + """Print binary classification metrics using PyHealth metrics.""" + metrics = binary_metrics_fn( + y_true, + y_prob, + metrics=["f1", "roc_auc", "pr_auc", "accuracy", "cohen_kappa"], + ) + print(f"\n{'─' * 55}") + print(label) + print(f"\n{'─' * 55}") + for k, v in metrics.items(): + print(f"{k:20s}: {v:.3f}") def main(root: str = None, demo: bool = False): - try: - import lightgbm as lgb - except ImportError: - print("lightgbm not installed. Run: pip install lightgbm") - return - + if lgb is None: + raise ImportError("lightgbm not installed. Run: pip install lightgbm") + print("\n[1/4] Loading data...") if demo: - print("\n[1/4] Generating synthetic DREAMT-like data (demo mode)...") + print("Using synthetic data") print(" To use real data: --root /path/to/dreamt/2.1.0") all_samples = make_synthetic_samples( n_patients=5, n_epochs_per_patient=40 ) else: - print("\n[1/4] Loading real DREAMT dataset from PhysioNet...") - from pyhealth.datasets import DREAMTDataset - from pyhealth.tasks import SleepWakeDetectionDREAMT + print("\nUsing real DREAMT dataset from PhysioNet...") dataset = DREAMTDataset(root=root) - task = SleepWakeDetectionDREAMT() - all_samples = [] - for pid in dataset.unique_patient_ids: - patient = dataset.get_patient(pid) - all_samples.extend(task(patient)) + task_dataset = dataset.set_task(SleepWakeDetectionDREAMT()) + all_samples = task_dataset.samples print(f" Total epochs : {len(all_samples)}") wake = sum(s["label"] == 1 for s in all_samples) @@ -159,7 +151,6 @@ def main(root: str = None, demo: bool = False): # SMOTE balancing on training set (paper section 3.1) try: - from imblearn.over_sampling import SMOTE X_train, y_train = SMOTE(random_state=42).fit_resample( X_train, y_train ) @@ -173,7 +164,8 @@ def main(root: str = None, demo: bool = False): clf_a.fit(X_train, y_train) evaluate(y_test, clf_a.predict(X_test), clf_a.predict_proba(X_test)[:, 1], - "Ablation A: Baseline LightGBM (no clinical metadata)") + "Ablation A: Baseline LightGBM (no clinical metadata)" + ) # Ablation B: LightGBM + AHI X_train_b = np.hstack([ @@ -185,7 +177,8 @@ def main(root: str = None, demo: bool = False): clf_b.fit(X_train_b, y_train) evaluate(y_test, clf_b.predict(X_test_b), clf_b.predict_proba(X_test_b)[:, 1], - "Ablation B: LightGBM + AHI (apnea severity)") + "Ablation B: LightGBM + AHI (apnea severity)" + ) # Ablation C:LightGBM + BMI X_train_c = np.hstack([ @@ -197,7 +190,8 @@ def main(root: str = None, demo: bool = False): clf_c.fit(X_train_c, y_train) evaluate(y_test, clf_c.predict(X_test_c), clf_c.predict_proba(X_test_c)[:, 1], - "Ablation C: LightGBM + BMI (obesity)") + "Ablation C: LightGBM + BMI (obesity)" + ) if __name__ == "__main__": diff --git a/pyhealth/tasks/sleep_wake_dreamt.py b/pyhealth/tasks/sleep_wake_dreamt.py index c6d934858..82773499e 100644 --- a/pyhealth/tasks/sleep_wake_dreamt.py +++ b/pyhealth/tasks/sleep_wake_dreamt.py @@ -1,8 +1,14 @@ +import logging from typing import Any, Dict, List, Optional + import numpy as np import pandas as pd +from scipy.signal import butter, cheby2, filtfilt + from .base_task import BaseTask +logger = logging.getLogger(__name__) + # Sampling frequency of the E4 after upsampling (paper sec 2.3) SAMPLE_RATE_HZ = 64 EPOCH_SEC = 30 @@ -43,11 +49,41 @@ def _butter_bandpass( Filtered signal as numpy array """ try: - from scipy.signal import butter, filtfilt nyq = fs / 2.0 b, a = butter(order, [low / nyq, high / nyq], btype="band") return filtfilt(b, a, signal).astype(np.float32) - except Exception: + except Exception as e: + logger.debug(f"_butter_bandpass failed: {e}") + return signal.astype(np.float32) + + +def _cheby_bandpass( + signal: np.ndarray, + low: float, + high: float, + fs: int, + order: int = 4, + rs: float = 20.0, +) -> np.ndarray: + """Apply a Chebyshev type II bandpass filter to a 1D signal. + + Args: + signal: 1D numpy array of signal values + low: lower cutoff frequency in Hz + high: upper cutoff frequency in Hz + fs: sampling frequency in Hz + order: filter order + rs: Minimum stopband attenuation in dB. + + Returns: + Filtered signal as numpy array + """ + try: + nyq = fs / 2.0 + b, a = cheby2(order, rs, [low / nyq, high / nyq], btype="bandpass") + return filtfilt(b, a, signal).astype(np.float32) + except Exception as e: + logger.debug(f"_cheby_bandpass failed: {e}") return signal.astype(np.float32) @@ -69,11 +105,11 @@ def _butter_lowpass( Filtered signal as numpy array """ try: - from scipy.signal import butter, filtfilt nyq = fs / 2.0 b, a = butter(order, cutoff / nyq, btype="low") return filtfilt(b, a, signal).astype(np.float32) - except Exception: + except Exception as e: + logger.debug(f"_butter_lowpass failed: {e}") return signal.astype(np.float32) @@ -172,7 +208,13 @@ def trimmed_stats(x: np.ndarray) -> List[float]: # ── BVP / HRV features (paper sec 2.5) ─────────────────────────────────── bvp = safe_col("BVP") - bvp_filt = _butter_bandpass(bvp, low=0.5, high=20.0, fs=fs, order=4) + bvp_filt = _cheby_bandpass( + bvp, + low=0.5, + high=20.0, + fs=fs, + order=4, + rs=20.0) feats.extend(summary_stats(bvp_filt)) # ── EDA features (paper sec 2.5) ───────────────────────────────────────── @@ -343,7 +385,8 @@ def _process_patient( """Shared processing logic for both task classes. Reads the patient's overnight E4 CSV, slices into 30-second epochs, - extracts features, and maps PSG labels using the provided label_map. + performs signal quality filtering, extracts features, and maps PSG + labels using the provided label_map. Args: patient: A patient object from DREAMTDataset @@ -354,7 +397,7 @@ def _process_patient( Returns: List of sample dicts, one per valid labeled epoch. """ - samples = [] + samples: List[Dict[str, Any]] = [] events = patient.get_events(event_type="dreamt_sleep") if not events: @@ -368,17 +411,24 @@ def _process_patient( try: ahi = float(event.ahi) if event.ahi is not None else 0.0 bmi = float(event.bmi) if event.bmi is not None else 0.0 - except (ValueError, TypeError): + except (ValueError, TypeError) as e: + logger.warning( + f"Invalid AHI/BMI for patient {patient.patient_id}: {e}" + ) ahi, bmi = 0.0, 0.0 try: df = pd.read_csv(file_path) - except Exception: + except Exception as e: + logger.warning(f"Failed to read file {file_path}: {e}") return samples # Verify required columns required = SIGNAL_COLS + ["Sleep_Stage"] if any(c not in df.columns for c in required): + logger.warning( + f"Missing required columns in file {file_path}" + ) return samples n_epochs = len(df) // EPOCH_SAMPLES @@ -404,13 +454,15 @@ def _process_patient( if np.isnan(signal).any() or np.isinf(signal).any(): continue - samples.append({ - "patient_id": patient.patient_id, - "epoch_index": epoch_idx, - "signal": signal, - "ahi": ahi, - "bmi": bmi, - "label": label, - }) + samples.append( + { + "patient_id": patient.patient_id, + "epoch_index": epoch_idx, + "signal": signal, + "ahi": ahi, + "bmi": bmi, + "label": label, + } + ) return samples \ No newline at end of file diff --git a/tests/test_sleep_wake_dreamt.py b/tests/test_sleep_wake_dreamt.py index 7a8ad1c20..1fcb1abcc 100644 --- a/tests/test_sleep_wake_dreamt.py +++ b/tests/test_sleep_wake_dreamt.py @@ -1,13 +1,17 @@ -import numpy as np -import pytest -import tempfile import os +import tempfile from unittest.mock import MagicMock, patch + +import numpy as np +import pandas as pd +import pytest + from pyhealth.tasks.sleep_wake_dreamt import ( - SleepWakeDetectionDREAMT, EPOCH_SAMPLES, SIGNAL_COLS, - WAKE_LABEL, + SleepStagingDREAMT, + SleepWakeDetectionDREAMT, + extract_epoch_features, ) def make_fake_patient( @@ -34,7 +38,6 @@ def make_fake_patient( stages[(i + 1) * EPOCH_SAMPLES - 1] = stage data["Sleep_Stage"] = stages - import pandas as pd fake_df = pd.DataFrame(data) # Write to a real temp file if tmp_dir provided @@ -122,8 +125,10 @@ def test_missing_epochs_are_skipped(self): def test_signal_shape(self): """Each epoch signal is a 1D engineered feature vector.""" task = SleepWakeDetectionDREAMT() - patient, fake_df = make_fake_patient(n_epochs=2, - sleep_stages=["N2", "W"]) + patient, fake_df = make_fake_patient( + n_epochs=2, + sleep_stages=["N2", "W"] + ) with patch("pandas.read_csv", return_value=fake_df): samples = task(patient) for s in samples: @@ -187,18 +192,48 @@ def test_uses_temp_directory(self): samples = task(patient) assert len(samples) == 3 assert all("signal" in s for s in samples) + + def test_missing_required_columns_returns_empty(self): + """Returns empty list when required columns are missing.""" + task = SleepWakeDetectionDREAMT() + patient, fake_df = make_fake_patient( + n_epochs=3, + sleep_stages=["R", "N2", "W"] + ) + fake_df = fake_df.drop(columns=["HR"]) + with patch("pandas.read_csv", return_value=fake_df): + samples = task(patient) + assert samples == [] + + def test_extract_epoch_features_output(self): + """Feature extraction returns a valid 1D feature vector.""" + n_rows = EPOCH_SAMPLES + data = { + "BVP": np.random.randn(n_rows).astype(np.float32), + "ACC_X": np.random.randn(n_rows).astype(np.float32), + "ACC_Y": np.random.randn(n_rows).astype(np.float32), + "ACC_Z": np.random.randn(n_rows).astype(np.float32), + "EDA": np.random.randn(n_rows).astype(np.float32), + "TEMP": (33.0 + 0.1 * np.random.randn(n_rows)).astype(np.float32), + "HR": (70.0 + np.random.randn(n_rows)).astype(np.float32), + } + fake_df = pd.DataFrame(data) + features = extract_epoch_features(fake_df) + assert isinstance(features, np.ndarray) + assert features.ndim == 1 + assert len(features) > 0 + # Check numerical validity + assert np.isfinite(features).all() class TestSleepStagingDREAMT: def test_instantiation(self): """SleepStagingDREAMT can be instantiated.""" - from pyhealth.tasks.sleep_wake_dreamt import SleepStagingDREAMT task = SleepStagingDREAMT() assert task.task_name == "SleepStagingDREAMT" def test_schema_defined(self): """Input and output schemas are defined correctly.""" - from pyhealth.tasks.sleep_wake_dreamt import SleepStagingDREAMT task = SleepStagingDREAMT() assert "signal" in task.input_schema assert "label" in task.output_schema @@ -206,7 +241,6 @@ def test_schema_defined(self): def test_fine_labels(self): """Each sleep stage maps to correct integer label.""" - from pyhealth.tasks.sleep_wake_dreamt import SleepStagingDREAMT task = SleepStagingDREAMT() patient, fake_df = make_fake_patient( n_epochs=5, @@ -219,7 +253,6 @@ def test_fine_labels(self): def test_missing_skipped(self): """Missing epochs are skipped in multi-class task too.""" - from pyhealth.tasks.sleep_wake_dreamt import SleepStagingDREAMT task = SleepStagingDREAMT() patient, fake_df = make_fake_patient( n_epochs=3, @@ -231,7 +264,6 @@ def test_missing_skipped(self): def test_signal_is_feature_vector(self): """Signal output is a 1D engineered feature vector.""" - from pyhealth.tasks.sleep_wake_dreamt import SleepStagingDREAMT task = SleepStagingDREAMT() patient, fake_df = make_fake_patient( n_epochs=2, From aaec24665660658bf0cfa0e86d5c408370412ab1 Mon Sep 17 00:00:00 2001 From: prachipradhan Date: Wed, 22 Apr 2026 21:04:52 -0400 Subject: [PATCH 3/5] fix: update input_schema processor types --- examples/dreamt_sleep_wake_detection.py | 8 ++++++-- pyhealth/tasks/sleep_wake_dreamt.py | 14 ++++++-------- 2 files changed, 12 insertions(+), 10 deletions(-) diff --git a/examples/dreamt_sleep_wake_detection.py b/examples/dreamt_sleep_wake_detection.py index 92762ac08..217874db4 100644 --- a/examples/dreamt_sleep_wake_detection.py +++ b/examples/dreamt_sleep_wake_detection.py @@ -34,6 +34,7 @@ from sklearn.model_selection import GroupShuffleSplit from pyhealth.datasets import DREAMTDataset +from pyhealth.tasks import SleepWakeDetectionDREAMT from pyhealth.metrics import binary_metrics_fn @@ -128,8 +129,11 @@ def main(root: str = None, demo: bool = False): else: print("\nUsing real DREAMT dataset from PhysioNet...") dataset = DREAMTDataset(root=root) - task_dataset = dataset.set_task(SleepWakeDetectionDREAMT()) - all_samples = task_dataset.samples + task = SleepWakeDetectionDREAMT() + all_samples = [] + for pid in dataset.unique_patient_ids: + patient = dataset.get_patient(pid) + all_samples.extend(task(patient)) print(f" Total epochs : {len(all_samples)}") wake = sum(s["label"] == 1 for s in all_samples) diff --git a/pyhealth/tasks/sleep_wake_dreamt.py b/pyhealth/tasks/sleep_wake_dreamt.py index 82773499e..81c2d594b 100644 --- a/pyhealth/tasks/sleep_wake_dreamt.py +++ b/pyhealth/tasks/sleep_wake_dreamt.py @@ -279,11 +279,9 @@ class SleepWakeDetectionDREAMT(BaseTask): task_name: str = "SleepWakeDetectionDREAMT" input_schema: Dict[str, str] = { - # Engineered feature vector per 30-sec epoch - "signal": "float", - # Clinical metadata for mixed-effects modeling (paper sec 2.6) - "ahi": "float", - "bmi": "float", + "signal": "tensor", + "ahi": "regression", + "bmi": "regression", } output_schema: Dict[str, str] = { @@ -353,9 +351,9 @@ class SleepStagingDREAMT(BaseTask): task_name: str = "SleepStagingDREAMT" input_schema: Dict[str, str] = { - "signal": "float", - "ahi": "float", - "bmi": "float", + "signal": "tensor", + "ahi": "regression", + "bmi": "regression", } output_schema: Dict[str, str] = { From f2db365cd820569177677c569ae6b8d2c7e67ab8 Mon Sep 17 00:00:00 2001 From: yyang2002 Date: Wed, 22 Apr 2026 20:49:19 -0500 Subject: [PATCH 4/5] Minor fix --- examples/dreamt_sleep_wake_detection.py | 22 +++++++++------------- 1 file changed, 9 insertions(+), 13 deletions(-) diff --git a/examples/dreamt_sleep_wake_detection.py b/examples/dreamt_sleep_wake_detection.py index 217874db4..1486e1ae7 100644 --- a/examples/dreamt_sleep_wake_detection.py +++ b/examples/dreamt_sleep_wake_detection.py @@ -17,7 +17,8 @@ 4. Ablation: comparing LightGBM vs LightGBM + AHI vs LightGBM + BMI Requirements: - pip install lightgbm scikit-learn imbalanced-learn + pip install lightgbm + pip install imbalanced-learn Usage: # Run with synthetic demo data (no PhysioNet access needed): @@ -102,7 +103,7 @@ def samples_to_arrays(samples): ) -def evaluate(y_true, y_pred, y_prob, label=""): +def evaluate(y_true, y_prob, label=""): """Print binary classification metrics using PyHealth metrics.""" metrics = binary_metrics_fn( y_true, @@ -117,8 +118,6 @@ def evaluate(y_true, y_pred, y_prob, label=""): def main(root: str = None, demo: bool = False): - if lgb is None: - raise ImportError("lightgbm not installed. Run: pip install lightgbm") print("\n[1/4] Loading data...") if demo: print("Using synthetic data") @@ -127,13 +126,10 @@ def main(root: str = None, demo: bool = False): n_patients=5, n_epochs_per_patient=40 ) else: - print("\nUsing real DREAMT dataset from PhysioNet...") + print("\nUsing real DREAMT dataset via PyHealth...") dataset = DREAMTDataset(root=root) - task = SleepWakeDetectionDREAMT() - all_samples = [] - for pid in dataset.unique_patient_ids: - patient = dataset.get_patient(pid) - all_samples.extend(task(patient)) + task_dataset = dataset.set_task(SleepWakeDetectionDREAMT()) + all_samples = task_dataset.samples print(f" Total epochs : {len(all_samples)}") wake = sum(s["label"] == 1 for s in all_samples) @@ -166,7 +162,7 @@ def main(root: str = None, demo: bool = False): # Ablation A: Baseline LightGBM clf_a = lgb.LGBMClassifier(n_estimators=200, random_state=42, verbose=-1) clf_a.fit(X_train, y_train) - evaluate(y_test, clf_a.predict(X_test), + evaluate(y_test, clf_a.predict_proba(X_test)[:, 1], "Ablation A: Baseline LightGBM (no clinical metadata)" ) @@ -179,7 +175,7 @@ def main(root: str = None, demo: bool = False): X_test_b = np.hstack([X_test, ahi_test.reshape(-1, 1)]) clf_b = lgb.LGBMClassifier(n_estimators=200, random_state=42, verbose=-1) clf_b.fit(X_train_b, y_train) - evaluate(y_test, clf_b.predict(X_test_b), + evaluate(y_test, clf_b.predict_proba(X_test_b)[:, 1], "Ablation B: LightGBM + AHI (apnea severity)" ) @@ -192,7 +188,7 @@ def main(root: str = None, demo: bool = False): X_test_c = np.hstack([X_test, bmi_test.reshape(-1, 1)]) clf_c = lgb.LGBMClassifier(n_estimators=200, random_state=42, verbose=-1) clf_c.fit(X_train_c, y_train) - evaluate(y_test, clf_c.predict(X_test_c), + evaluate(y_test, clf_c.predict_proba(X_test_c)[:, 1], "Ablation C: LightGBM + BMI (obesity)" ) From 5bc4bd4bf5e51fe2e2ba161718891bf0bc7c96e0 Mon Sep 17 00:00:00 2001 From: prachipradhan Date: Wed, 22 Apr 2026 23:01:32 -0400 Subject: [PATCH 5/5] fix: patient-level iteration --- examples/dreamt_sleep_wake_detection.py | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/examples/dreamt_sleep_wake_detection.py b/examples/dreamt_sleep_wake_detection.py index 1486e1ae7..f1c5a254e 100644 --- a/examples/dreamt_sleep_wake_detection.py +++ b/examples/dreamt_sleep_wake_detection.py @@ -128,8 +128,11 @@ def main(root: str = None, demo: bool = False): else: print("\nUsing real DREAMT dataset via PyHealth...") dataset = DREAMTDataset(root=root) - task_dataset = dataset.set_task(SleepWakeDetectionDREAMT()) - all_samples = task_dataset.samples + task = SleepWakeDetectionDREAMT() + all_samples = [] + for pid in dataset.unique_patient_ids: + patient = dataset.get_patient(pid) + all_samples.extend(task(patient)) print(f" Total epochs : {len(all_samples)}") wake = sum(s["label"] == 1 for s in all_samples)