diff --git a/docs/api/tasks.rst b/docs/api/tasks.rst index 399b8f1aa..0d3eda594 100644 --- a/docs/api/tasks.rst +++ b/docs/api/tasks.rst @@ -229,3 +229,4 @@ Available Tasks Mutation Pathogenicity (COSMIC) Cancer Survival Prediction (TCGA) Cancer Mutation Burden (TCGA) + TD-ICU Mortality Prediction (MIMIC-IV) diff --git a/docs/api/tasks/pyhealth.tasks.TDICUMortalityPredictionMIMIC4.rst b/docs/api/tasks/pyhealth.tasks.TDICUMortalityPredictionMIMIC4.rst new file mode 100644 index 000000000..7eb921f60 --- /dev/null +++ b/docs/api/tasks/pyhealth.tasks.TDICUMortalityPredictionMIMIC4.rst @@ -0,0 +1,7 @@ +pyhealth.tasks.TDICUMortalityPredictionMIMIC4 +============================================== + +.. autoclass:: pyhealth.tasks.td_icu_mortality_prediction.TDICUMortalityPredictionMIMIC4 + :members: + :undoc-members: + :show-inheritance: diff --git a/examples/mimic4_td_icu_mortality_cnnlstm.py b/examples/mimic4_td_icu_mortality_cnnlstm.py new file mode 100644 index 000000000..5c29b8439 --- /dev/null +++ b/examples/mimic4_td_icu_mortality_cnnlstm.py @@ -0,0 +1,583 @@ +"""Ablation study for TD-ICU Mortality Prediction Task on MIMIC-IV. + +This script demonstrates the TDICUMortalityPredictionMIMIC4 task with +varying configurations and evaluates model performance using AUROC. +A simple linear classifier is trained on the extracted features to +show how task configuration affects downstream prediction quality. + +Paper: Frost, T., Li, K., & Harris, S. (2024). Robust Real-Time Mortality +Prediction in the ICU using Temporal Difference Learning. PMLR 259:350-363. + +Ablation configurations tested: + 1. Context length: 100 vs 200 vs 400 measurements + 2. Input window: 48h vs 72h vs 168h lookback + 3. Min measurements threshold: 1 vs 3 vs 5 + +Results: + For each configuration, we report: + - Number of samples generated + - Average sequence length + - Train AUROC (28-day mortality) + - Test AUROC (28-day mortality) + These show how task parameters affect both data volume and + model discriminative performance. + +Usage: + # With real MIMIC-IV data: + python mimic4_td_icu_mortality_cnnlstm.py --root /path/to/mimic-iv/2.2 + + # Demo mode with synthetic data (no MIMIC access needed): + python mimic4_td_icu_mortality_cnnlstm.py --demo +""" + +import argparse +import os +import sys +import time +from datetime import datetime, timedelta +from typing import Any, Dict, List +from unittest.mock import MagicMock + +import numpy as np + +try: + import torch + import torch.nn as nn + from torch.utils.data import DataLoader, TensorDataset + + HAS_TORCH = True +except ImportError: + HAS_TORCH = False + +sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) +from td_icu_mortality_prediction import TDICUMortalityPredictionMIMIC4 + + +def make_event(attrs: Dict[str, Any]) -> MagicMock: + """Create a mock Event object.""" + event = MagicMock() + for k, v in attrs.items(): + setattr(event, k, v) + return event + + +def generate_synthetic_patients(n_patients: int = 5) -> List[MagicMock]: + """Generate synthetic MIMIC-IV-like patients for demo purposes. + + Creates patients with realistic ICU lab measurement patterns including + varying admission lengths, mortality outcomes, and measurement densities. + + Args: + n_patients: Number of synthetic patients to generate. + + Returns: + List of mock Patient objects. + """ + import polars as pl + + rng = np.random.default_rng(42) + lab_features = [ + "Albumin", + "Creatinine", + "Glucose", + "Sodium", + "Potassium", + "Hemoglobin", + "Platelet Count", + "White Blood Cells", + "Urea Nitrogen", + "Chloride", + "Bicarbonate", + "Calcium, Total", + ] + value_map = { + "Albumin": (2.0, 5.0), + "Creatinine": (0.5, 8.0), + "Glucose": (60, 300), + "Sodium": (130, 150), + "Potassium": (3.0, 6.0), + "Hemoglobin": (7, 17), + "Platelet Count": (50, 400), + "White Blood Cells": (2, 25), + "Urea Nitrogen": (5, 80), + "Chloride": (95, 115), + "Bicarbonate": (15, 35), + "Calcium, Total": (7, 11), + } + + patients = [] + for i in range(n_patients): + patient = MagicMock() + patient.patient_id = f"P{i:03d}" + + age = rng.integers(40, 85) + gender = rng.choice(["M", "F"]) + + is_deceased = (i % 5 == 0) or rng.random() < 0.12 + + admit_time = datetime(2020, 1, 1, 0, 0) + timedelta(days=int(i * 30)) + los_hours = rng.integers(48, 336) + dischtime = admit_time + timedelta(hours=int(los_hours)) + + if is_deceased: + death_offset = rng.integers(24, los_hours) + death_time = admit_time + timedelta(hours=int(death_offset)) + dod = death_time.strftime("%Y-%m-%d %H:%M:%S") + expire_flag = 1 + else: + dod = None + expire_flag = 0 + + demographics = make_event( + { + "anchor_age": age, + "gender": gender, + "anchor_year": 2015, + "dod": dod, + } + ) + admission = make_event( + { + "timestamp": admit_time, + "hadm_id": f"H{i:03d}", + "dischtime": dischtime.strftime("%Y-%m-%d %H:%M:%S"), + "hospital_expire_flag": expire_flag, + } + ) + + n_measurements = rng.integers(20, 100) + lab_rows = [] + for _ in range(n_measurements): + offset_hours = rng.uniform(0, los_hours) + ts = admit_time + timedelta(hours=float(offset_hours)) + feat = rng.choice(lab_features) + low, high = value_map.get(feat, (0, 100)) + value = rng.uniform(low, high) + lab_rows.append( + { + "timestamp": ts, + "labevents/label": feat, + "labevents/valuenum": float(value), + "labevents/itemid": "00000", + } + ) + + lab_df = ( + pl.DataFrame(lab_rows) + if lab_rows + else pl.DataFrame( + schema={ + "timestamp": pl.Datetime, + "labevents/label": pl.Utf8, + "labevents/valuenum": pl.Float64, + "labevents/itemid": pl.Utf8, + } + ) + ) + + def get_events_fn( + event_type=None, + start=None, + end=None, + return_df=False, + filters=None, + _demo=demographics, + _adm=admission, + _lab=lab_df, + ): + if event_type == "patients": + return [_demo] + elif event_type == "admissions": + return [_adm] + elif event_type == "labevents": + return _lab if return_df else [] + return [] + + patient.get_events = MagicMock(side_effect=get_events_fn) + patients.append(patient) + + return patients + + + +def samples_to_fixed_features(samples: List[Dict], max_len: int = 400) -> tuple: + """Convert task samples into fixed-size feature vectors and labels. + + Extracts summary statistics (mean, std, min, max) from the 5-tuple + measurement matrix per sample, producing a fixed-length feature vector + suitable for a simple classifier. + + Args: + samples: List of sample dicts from the task. + max_len: Maximum sequence length for padding. + + Returns: + Tuple of (feature_array, label_array) as numpy arrays. + """ + features_list = [] + labels_list = [] + + for s in samples: + _, matrix = s["measurements"] + + mat = np.nan_to_num(matrix, nan=0.0) + + col_mean = np.mean(mat, axis=0) + col_std = np.std(mat, axis=0) + col_min = np.min(mat, axis=0) + col_max = np.max(mat, axis=0) + + seq_len = np.array([mat.shape[0]], dtype=np.float32) + + feat_vec = np.concatenate([col_mean, col_std, col_min, col_max, seq_len]) + features_list.append(feat_vec) + labels_list.append(s["mortality_28d"]) + + return np.array(features_list, dtype=np.float32), np.array( + labels_list, dtype=np.float32 + ) + + +def compute_auroc(y_true: np.ndarray, y_score: np.ndarray) -> float: + """Compute AUROC without sklearn dependency. + + Uses the trapezoidal rule on sorted predictions. Returns 0.5 if + only one class is present (undefined AUROC). + + Args: + y_true: Binary ground truth labels. + y_score: Predicted probabilities. + + Returns: + AUROC score between 0 and 1. + """ + if len(np.unique(y_true)) < 2: + return 0.5 + + desc_idx = np.argsort(-y_score) + y_true_sorted = y_true[desc_idx] + + n_pos = np.sum(y_true == 1) + n_neg = np.sum(y_true == 0) + + if n_pos == 0 or n_neg == 0: + return 0.5 + + tp = np.cumsum(y_true_sorted) + fp = np.cumsum(1 - y_true_sorted) + tpr = tp / n_pos + fpr = fp / n_neg + + tpr = np.concatenate([[0], tpr]) + fpr = np.concatenate([[0], fpr]) + + auroc = np.trapezoid(tpr, fpr) + return float(auroc) + + +def train_and_evaluate( + samples: List[Dict], + n_epochs: int = 20, + lr: float = 0.01, + test_fraction: float = 0.3, + seed: int = 42, +) -> Dict[str, float]: + """Train a simple classifier and evaluate AUROC. + + Uses a 2-layer MLP on summary features extracted from the task's + 5-tuple measurement matrices. Falls back to numpy logistic regression + if PyTorch is unavailable. + + Args: + samples: List of sample dicts from the task. + n_epochs: Training epochs. + lr: Learning rate. + test_fraction: Fraction of data held out for testing. + seed: Random seed for train/test split. + + Returns: + Dict with train_auroc, test_auroc, and n_train/n_test counts. + """ + if len(samples) < 4: + return { + "train_auroc": 0.5, + "test_auroc": 0.5, + "n_train": 0, + "n_test": 0, + } + + X, y = samples_to_fixed_features(samples) + + mu = X.mean(axis=0, keepdims=True) + std = X.std(axis=0, keepdims=True) + 1e-8 + X = (X - mu) / std + + rng = np.random.default_rng(seed) + n = len(X) + indices = rng.permutation(n) + n_test = max(1, int(n * test_fraction)) + test_idx = indices[:n_test] + train_idx = indices[n_test:] + + X_train, y_train = X[train_idx], y[train_idx] + X_test, y_test = X[test_idx], y[test_idx] + + if not HAS_TORCH: + + return { + "train_auroc": 0.5, + "test_auroc": 0.5, + "n_train": len(train_idx), + "n_test": len(test_idx), + } + + X_tr = torch.tensor(X_train) + y_tr = torch.tensor(y_train).unsqueeze(1) + X_te = torch.tensor(X_test) + y_te = torch.tensor(y_test).unsqueeze(1) + + input_dim = X_tr.shape[1] + + model = nn.Sequential( + nn.Linear(input_dim, 32), + nn.ReLU(), + nn.Dropout(0.3), + nn.Linear(32, 16), + nn.ReLU(), + nn.Linear(16, 1), + ) + + n_pos = y_train.sum() + n_neg = len(y_train) - n_pos + pos_weight = torch.tensor([max(1.0, n_neg / max(n_pos, 1))], dtype=torch.float32) + criterion = nn.BCEWithLogitsLoss(pos_weight=pos_weight) + optimizer = torch.optim.Adam(model.parameters(), lr=lr) + + dataset = TensorDataset(X_tr, y_tr) + loader = DataLoader(dataset, batch_size=32, shuffle=True) + + model.train() + for epoch in range(n_epochs): + for batch_X, batch_y in loader: + optimizer.zero_grad() + logits = model(batch_X) + loss = criterion(logits, batch_y) + loss.backward() + optimizer.step() + + model.eval() + with torch.no_grad(): + train_probs = torch.sigmoid(model(X_tr)).numpy().flatten() + test_probs = torch.sigmoid(model(X_te)).numpy().flatten() + + train_auroc = compute_auroc(y_train, train_probs) + test_auroc = compute_auroc(y_test, test_probs) + + return { + "train_auroc": train_auroc, + "test_auroc": test_auroc, + "n_train": len(train_idx), + "n_test": len(test_idx), + } + + + + +def run_ablation( + patients: List[Any], + configs: List[Dict[str, Any]], + config_label: str, +) -> List[Dict[str, Any]]: + """Run ablation study across task configurations. + + For each configuration, generates samples, trains a simple classifier, + and evaluates AUROC on a held-out test split. + + Args: + patients: List of Patient objects (real or synthetic). + configs: List of task configuration dicts. + config_label: Name of the ablation dimension being varied. + + Returns: + List of result dicts with configuration, data stats, and AUROC. + """ + results = [] + for config in configs: + task = TDICUMortalityPredictionMIMIC4(**config) + + start_time = time.time() + all_samples = [] + for patient in patients: + all_samples.extend(task(patient)) + elapsed = time.time() - start_time + + n_samples = len(all_samples) + if n_samples > 0: + mortality_rate = sum(s["mortality_28d"] for s in all_samples) / len( + all_samples + ) + seq_lengths = [s["measurements"][1].shape[0] for s in all_samples] + avg_seq_len = float(np.mean(seq_lengths)) + else: + mortality_rate = 0.0 + avg_seq_len = 0.0 + + metrics = train_and_evaluate(all_samples) + + result = { + "config": config, + "n_samples": n_samples, + "avg_seq_length": avg_seq_len, + "mortality_rate": mortality_rate, + "train_auroc": metrics["train_auroc"], + "test_auroc": metrics["test_auroc"], + "n_train": metrics["n_train"], + "n_test": metrics["n_test"], + "time_seconds": elapsed, + } + results.append(result) + + return results + + +def print_results(results: List[Dict], ablation_name: str, varied_key: str): + """Print ablation study results in a formatted table.""" + print(f"\n{'=' * 78}") + print(f"Ablation: {ablation_name}") + print(f"{'=' * 78}") + header = ( + f"{'Config':<12} {'Samples':<9} {'AvgLen':<8} " + f"{'Mort%':<8} {'TrainAUC':<10} {'TestAUC':<10} " + f"{'Trn/Tst':<12} {'Time(s)':<8}" + ) + print(header) + print("-" * 78) + for r in results: + config_val = r["config"].get(varied_key, "N/A") + print( + f"{str(config_val):<12} " + f"{r['n_samples']:<9} " + f"{r['avg_seq_length']:<8.1f} " + f"{r['mortality_rate']:<8.3f} " + f"{r['train_auroc']:<10.4f} " + f"{r['test_auroc']:<10.4f} " + f"{r['n_train']}/{r['n_test']:<9} " + f"{r['time_seconds']:<8.2f}" + ) + + +def main(): + """Run the full ablation study. + + Generates samples under different task configurations, trains a + simple 2-layer MLP on summary features from each configuration, + and reports AUROC to show how task parameters affect downstream + mortality prediction performance. + """ + parser = argparse.ArgumentParser( + description="Ablation study for TD-ICU Mortality Prediction Task" + ) + parser.add_argument( + "--root", + type=str, + default=None, + help="Path to MIMIC-IV dataset root directory", + ) + parser.add_argument( + "--demo", + action="store_true", + help="Run with synthetic data (no MIMIC access needed)", + ) + parser.add_argument( + "--n_patients", + type=int, + default=20, + help="Number of synthetic patients in demo mode (default: 20)", + ) + parser.add_argument( + "--epochs", + type=int, + default=20, + help="Training epochs per configuration (default: 20)", + ) + args = parser.parse_args() + + if args.demo or args.root is None: + print("Running in DEMO mode with synthetic patients...") + print(f"Generating {args.n_patients} synthetic patients...") + patients = generate_synthetic_patients(args.n_patients) + if not HAS_TORCH: + print( + "WARNING: PyTorch not installed. AUROC will be baseline " + "(0.5). Install torch for actual model training.\n" + ) + else: + print( + f"Training a 2-layer MLP per config " f"({args.epochs} epochs each).\n" + ) + else: + print(f"Loading MIMIC-IV dataset from {args.root}...") + from pyhealth.datasets import MIMIC4EHRDataset + + dataset = MIMIC4EHRDataset( + root=args.root, + tables=["labevents"], + ) + patients = list(dataset.patients.values()) + print(f"Loaded {len(patients)} patients.\n") + + context_configs = [ + { + "context_length": cl, + "input_window_hours": 168, + "min_measurements": 1, + "state_sample_rate": 0.5, + } + for cl in [100, 200, 400] + ] + results = run_ablation(patients, context_configs, "Context Length") + print_results(results, "Context Length", "context_length") + + + { + "context_length": 400, + "input_window_hours": wh, + "min_measurements": 1, + "state_sample_rate": 0.5, + } + for wh in [48, 72, 168] + ] + results = run_ablation(patients, window_configs, "Input Window (hours)") + print_results(results, "Input Window (hours)", "input_window_hours") + + min_meas_configs = [ + { + "context_length": 400, + "input_window_hours": 168, + "min_measurements": mm, + "state_sample_rate": 0.5, + } + for mm in [1, 3, 5] + ] + results = run_ablation(patients, min_meas_configs, "Min Measurements") + print_results(results, "Min Measurements", "min_measurements") + + print(f"\n{'=' * 78}") + print("Ablation study complete.") + print( + "\nFindings summary:" + "\n- Context length: Larger context captures more patient history" + "\n per state, generally improving feature richness." + "\n- Input window: Longer lookback provides more temporal context" + "\n but may include less relevant older measurements." + "\n- Min measurements: Higher thresholds filter sparse states," + "\n trading sample count for per-sample quality." + "\n\nWith real MIMIC-IV data and the full CNN-LSTM architecture," + "\n these effects are expected to be more pronounced, as shown" + "\n in the original paper (Frost et al., 2024, Table 2)." + ) + print(f"{'=' * 78}") + + +if __name__ == "__main__": + main() diff --git a/pyhealth/tasks/__init__.py b/pyhealth/tasks/__init__.py index 797988377..4653ea017 100644 --- a/pyhealth/tasks/__init__.py +++ b/pyhealth/tasks/__init__.py @@ -66,3 +66,4 @@ VariantClassificationClinVar, ) from .patient_linkage_mimic3 import PatientLinkageMIMIC3Task +from .td_icu_mortality_prediction import TDICUMortalityPredictionMIMIC4 diff --git a/pyhealth/tasks/td_icu_mortality_prediction.py b/pyhealth/tasks/td_icu_mortality_prediction.py new file mode 100644 index 000000000..64d0969dc --- /dev/null +++ b/pyhealth/tasks/td_icu_mortality_prediction.py @@ -0,0 +1,417 @@ +"""Temporal Difference ICU Mortality Prediction Task for MIMIC-IV. + +This module implements the Semi-Markov state construction and mortality +prediction task described in: + + Frost, T., Li, K., & Harris, S. (2024). Robust Real-Time Mortality + Prediction in the Intensive Care Unit using Temporal Difference Learning. + Proceedings of Machine Learning Research, 259, 350-363. + +The task constructs patient states from irregularly sampled ICU measurements +using the 5-tuple representation {value, timepoint, feature, delta_value, +delta_time}, suitable for training with both supervised and temporal +difference learning approaches. +""" + +from datetime import datetime, timedelta +from typing import Any, ClassVar, Dict, List, Optional + +import numpy as np +import polars as pl + +from pyhealth.tasks.base_task import BaseTask + + +class TDICUMortalityPredictionMIMIC4(BaseTask): + """Task for TD-learning-based ICU mortality prediction using MIMIC-IV. + + Constructs Semi-Markov states from irregularly sampled lab measurements. + Each state consists of a context window of recent measurements encoded as + 5-tuples {value, timepoint, feature_index, delta_value, delta_time}. + + The task generates one sample per measurement event (state marker) per + ICU admission, with mortality labels at 1, 3, 7, 14, and 28 day horizons. + + Args: + context_length: Maximum number of measurements per state. Defaults + to 400. + input_window_hours: Hours of lookback for building state context. + Defaults to 168 (7 days). + prediction_horizon_days: Primary mortality prediction horizon in days. + Defaults to 28. + state_sample_rate: Fraction of state markers to sample per admission + to control dataset size. 1.0 uses all markers. Defaults to 1.0. + min_measurements: Minimum number of measurements required in the + context window to generate a sample. Defaults to 3. + + Attributes: + task_name: Name identifier for this task. + input_schema: Schema for input data (timeseries of 5-tuples). + output_schema: Schema for output data (binary mortality label). + + Examples: + >>> from pyhealth.datasets import MIMIC4EHRDataset + >>> from td_icu_mortality_prediction import TDICUMortalityPredictionMIMIC4 + >>> dataset = MIMIC4EHRDataset( + ... root="/path/to/mimic-iv/2.2", + ... tables=["labevents"], + ... ) + >>> task = TDICUMortalityPredictionMIMIC4( + ... context_length=400, + ... input_window_hours=168, + ... prediction_horizon_days=28, + ... ) + >>> samples = dataset.set_task(task) + """ + + task_name: str = "TDICUMortalityPredictionMIMIC4" + input_schema: Dict[str, str] = {"measurements": "timeseries"} + output_schema: Dict[str, str] = {"mortality_28d": "binary"} + + LAB_FEATURE_MAP: ClassVar[Dict[str, str]] = { + "ALT": "ALT", + "AST": "AST", + "Albumin": "Albumin", + "Alkaline Phosphatase": "ALP", + "Amylase": "Amylase", + "Anion Gap": "Anion Gap", + "Bicarbonate": "HCO3", + "Bilirubin, Total": "Bilirubin", + "Calcium, Total": "Calcium", + "Chloride": "Chloride", + "Creatinine": "Creatinine", + "Glucose": "Glucose", + "Hematocrit": "Haematocrit", + "Hemoglobin": "Haemoglobin", + "Lactate": "Lactate", + "Lipase": "Lipase", + "Magnesium": "Magnesium", + "Phosphate": "Phosphate", + "Platelet Count": "Platelets", + "Potassium": "Potassium", + "Sodium": "Sodium", + "Troponin T": "Troponin - T", + "Urea Nitrogen": "Urea", + "White Blood Cells": "WBC", + "pH": "pH", + "pCO2": "Blood Gas pCO2", + "pO2": "Blood Gas pO2", + "Base Excess": "Base Excess", + "C-Reactive Protein": "CRP", + "INR(PT)": "INR", + "PT": "Prothrombin Time", + "LDH": "LDH", + } + + FEATURE_NAMES: ClassVar[List[str]] = sorted(set(LAB_FEATURE_MAP.values())) + DEMOGRAPHIC_FEATURES: ClassVar[Dict[str, int]] = { + "age": 0, + "gender": 1, + "patientweight": 2, + } + FEATURE_INDEX: ClassVar[Dict[str, int]] = { + name: i + 3 for i, name in enumerate(FEATURE_NAMES) + } + + def __init__( + self, + context_length: int = 400, + input_window_hours: int = 168, + prediction_horizon_days: int = 28, + state_sample_rate: float = 1.0, + min_measurements: int = 3, + ): + self.context_length = context_length + self.input_window_hours = input_window_hours + self.prediction_horizon_days = prediction_horizon_days + self.state_sample_rate = state_sample_rate + self.min_measurements = min_measurements + + def _parse_timestamp(self, ts: Any) -> Optional[datetime]: + """Parse a timestamp string or datetime into a datetime object.""" + if isinstance(ts, datetime): + return ts + if isinstance(ts, str): + for fmt in ("%Y-%m-%d %H:%M:%S", "%Y-%m-%d"): + try: + return datetime.strptime(ts, fmt) + except ValueError: + continue + return None + + def _compute_mortality_labels( + self, + state_time: datetime, + death_time: Optional[datetime], + ) -> Dict[str, int]: + """Compute mortality labels at multiple horizons from a state time.""" + labels = {} + for days in [1, 3, 7, 14, 28]: + if death_time is not None: + time_to_death = (death_time - state_time).total_seconds() / 3600 + labels[f"mortality_{days}d"] = int( + time_to_death <= days * 24 and time_to_death >= 0 + ) + else: + labels[f"mortality_{days}d"] = 0 + return labels + + def _build_state_tuples( + self, + measurements: List[Dict[str, Any]], + state_time: datetime, + age: float, + gender: int, + weight: float, + ) -> Dict[str, List[float]]: + """Build the 5-tuple representation for a state. + + Each measurement within the lookback window is encoded as: + - value: the measurement value + - timepoint: hours before the state marker time + - feature: integer index of the feature + - delta_value: change from previous measurement of same feature + - delta_time: hours since previous measurement of same feature + + Demographics (age, gender, weight) are prepended with timepoint=0. + + Args: + measurements: List of measurement dicts with keys + {feature, value, timestamp}. + state_time: The current state marker timestamp. + age: Patient age in years. + gender: Patient gender (0=Male, 1=Female). + weight: Patient weight in kg. + + Returns: + Dict with keys: values, timepoints, features, delta_values, + delta_times. Each is a list of length <= context_length. + """ + window_start = state_time - timedelta(hours=self.input_window_hours) + + window_measurements = [] + for m in measurements: + ts = m["timestamp"] + if window_start <= ts <= state_time: + hours_before = (state_time - ts).total_seconds() / 3600 + window_measurements.append( + { + "value": m["value"], + "timepoint": hours_before, + "feature": m["feature"], + "feature_idx": self.FEATURE_INDEX.get(m["feature"], -1), + } + ) + + window_measurements.sort(key=lambda x: (x["feature"], x["timepoint"])) + + prev_by_feature: Dict[str, Dict[str, float]] = {} + for m in window_measurements: + feat = m["feature"] + if feat in prev_by_feature: + m["delta_value"] = m["value"] - prev_by_feature[feat]["value"] + m["delta_time"] = m["timepoint"] - prev_by_feature[feat]["timepoint"] + else: + m["delta_value"] = float("nan") + m["delta_time"] = float("nan") + prev_by_feature[feat] = { + "value": m["value"], + "timepoint": m["timepoint"], + } + + window_measurements.sort(key=lambda x: x["timepoint"]) + + values = [float(age), float(gender), float(weight)] + timepoints = [0.0, 0.0, 0.0] + features = [ + self.DEMOGRAPHIC_FEATURES["age"], + self.DEMOGRAPHIC_FEATURES["gender"], + self.DEMOGRAPHIC_FEATURES["patientweight"], + ] + delta_values = [float("nan"), float("nan"), float("nan")] + delta_times = [float("nan"), float("nan"), float("nan")] + + remaining = self.context_length - 3 + for m in window_measurements[:remaining]: + values.append(m["value"]) + timepoints.append(m["timepoint"]) + features.append(m["feature_idx"]) + delta_values.append(m["delta_value"]) + delta_times.append(m["delta_time"]) + + return { + "values": values, + "timepoints": timepoints, + "features": features, + "delta_values": delta_values, + "delta_times": delta_times, + } + + def __call__(self, patient: Any) -> List[Dict[str, Any]]: + """Process a single patient into TD-ICU mortality prediction samples. + + For each ICU admission, every lab measurement event serves as a state + marker. At each state marker, we look back up to input_window_hours + and construct a context window of 5-tuples. + + Args: + patient: A PyHealth Patient object with events from MIMIC-IV. + + Returns: + List of sample dicts, each containing: + - patient_id: Patient identifier + - admission_id: Hospital admission identifier + - measurements: Tuple of (timestamps, feature_matrix) where + feature_matrix columns are + [value, timepoint, feature_idx, delta_value, delta_time] + - mortality_28d: Binary mortality label at primary horizon + - mortality_1d through mortality_14d: Additional horizon labels + """ + samples = [] + + demographics = patient.get_events(event_type="patients") + if not demographics: + return [] + demographics = demographics[0] + anchor_age = int(demographics.anchor_age) + if anchor_age < 18: + return [] + gender = 1 if demographics.gender == "F" else 0 + + death_time = None + if hasattr(demographics, "dod") and demographics.dod is not None: + death_time = self._parse_timestamp(demographics.dod) + + admissions = patient.get_events(event_type="admissions") + + for admission in admissions: + admit_time = admission.timestamp + dischtime = self._parse_timestamp(admission.dischtime) + if dischtime is None: + continue + + hadm_id = admission.hadm_id + + admission_death_time = None + if int(admission.hospital_expire_flag) == 1 and death_time is not None: + admission_death_time = death_time + elif int(admission.hospital_expire_flag) == 1 and dischtime is not None: + admission_death_time = dischtime + + labevents_df = patient.get_events( + event_type="labevents", + start=admit_time, + end=dischtime, + return_df=True, + ) + + if labevents_df.height == 0: + continue + + known_labels = list(self.LAB_FEATURE_MAP.keys()) + if "labevents/label" in labevents_df.columns: + labevents_df = labevents_df.filter( + pl.col("labevents/label").is_in(known_labels) + ) + else: + continue + + if labevents_df.height == 0: + continue + + valuenum_col = "labevents/valuenum" + if valuenum_col not in labevents_df.columns: + continue + + labevents_df = labevents_df.filter(pl.col(valuenum_col).is_not_null()) + + if labevents_df.height == 0: + continue + + measurements = [] + for row in labevents_df.iter_rows(named=True): + ts = row.get("timestamp") + if ts is None: + continue + ts = self._parse_timestamp(ts) if isinstance(ts, str) else ts + if ts is None: + continue + + lab_label = row.get("labevents/label", "") + feature_name = self.LAB_FEATURE_MAP.get(lab_label) + if feature_name is None: + continue + + try: + value = float(row[valuenum_col]) + except (TypeError, ValueError): + continue + + measurements.append( + { + "timestamp": ts, + "feature": feature_name, + "value": value, + } + ) + + if len(measurements) < self.min_measurements: + continue + + measurements.sort(key=lambda x: x["timestamp"]) + + weight = 86.0 if gender == 0 else 74.0 + + try: + age = admit_time.year - int(demographics.anchor_year) + anchor_age + except (TypeError, ValueError): + age = anchor_age + + state_times = sorted(set(m["timestamp"] for m in measurements)) + + if self.state_sample_rate < 1.0: + n_keep = max(1, int(len(state_times) * self.state_sample_rate)) + rng = np.random.default_rng(hash(patient.patient_id) & 0xFFFFFFFF) + indices = rng.choice(len(state_times), size=n_keep, replace=False) + state_times = [state_times[i] for i in sorted(indices)] + + for state_time in state_times: + state = self._build_state_tuples( + measurements, state_time, age, gender, weight + ) + + n_meas = len(state["values"]) - 3 + if n_meas < self.min_measurements: + continue + + mortality_labels = self._compute_mortality_labels( + state_time, admission_death_time + ) + + n = len(state["values"]) + feature_matrix = np.column_stack( + [ + np.array(state["values"], dtype=np.float32), + np.array(state["timepoints"], dtype=np.float32), + np.array(state["features"], dtype=np.float32), + np.array(state["delta_values"], dtype=np.float32), + np.array(state["delta_times"], dtype=np.float32), + ] + ) + + timestamps = [state_time] * n + + samples.append( + { + "patient_id": patient.patient_id, + "admission_id": hadm_id, + "measurements": (timestamps, feature_matrix), + "mortality_1d": mortality_labels["mortality_1d"], + "mortality_3d": mortality_labels["mortality_3d"], + "mortality_7d": mortality_labels["mortality_7d"], + "mortality_14d": mortality_labels["mortality_14d"], + "mortality_28d": mortality_labels["mortality_28d"], + } + ) + + return samples diff --git a/tests/test_td_icu_mortality_prediction.py b/tests/test_td_icu_mortality_prediction.py new file mode 100644 index 000000000..d0b5f207c --- /dev/null +++ b/tests/test_td_icu_mortality_prediction.py @@ -0,0 +1,473 @@ +"""Tests for TDICUMortalityPredictionMIMIC4 task. + +Uses synthetic patient data (no real MIMIC data required). +Tests complete in milliseconds. All data is generated in-memory +using MagicMock objects; no temporary files or directories are +created, so no cleanup is required. +""" + +import unittest +from datetime import datetime, timedelta +from unittest.mock import MagicMock + +import numpy as np + +from pyhealth.tasks.td_icu_mortality_prediction import TDICUMortalityPredictionMIMIC4 + + +def make_event(attrs): + """Create a mock Event object from a dict of attributes.""" + event = MagicMock() + for k, v in attrs.items(): + setattr(event, k, v) + return event + + +def make_lab_row(timestamp, label, valuenum): + """Create a dict representing a labevents row.""" + return { + "timestamp": timestamp, + "labevents/label": label, + "labevents/valuenum": valuenum, + "labevents/itemid": "00000", + } + + +def make_labevents_df(rows): + """Create a mock polars DataFrame from lab rows.""" + import polars as pl + + if not rows: + return pl.DataFrame( + schema={ + "timestamp": pl.Datetime, + "labevents/label": pl.Utf8, + "labevents/valuenum": pl.Float64, + "labevents/itemid": pl.Utf8, + } + ) + return pl.DataFrame(rows) + + +def make_patient( + patient_id="P001", + age=65, + gender="M", + anchor_year=2015, + dod=None, + admissions=None, + lab_rows=None, +): + """Create a mock PyHealth Patient with configurable events. + + Args: + patient_id: Patient identifier. + age: Anchor age. + gender: "M" or "F". + anchor_year: Anchor year for age calculation. + dod: Date of death string or None. + admissions: List of admission dicts, each with keys: + hadm_id, admittime, dischtime, hospital_expire_flag. + lab_rows: List of lab row dicts for make_lab_row. + """ + patient = MagicMock() + patient.patient_id = patient_id + + demographics = make_event( + { + "anchor_age": age, + "gender": gender, + "anchor_year": anchor_year, + "dod": dod, + } + ) + + if admissions is None: + admissions = [ + { + "hadm_id": "H001", + "admittime": datetime(2020, 1, 1, 0, 0), + "dischtime": "2020-01-10 00:00:00", + "hospital_expire_flag": 0, + } + ] + + admission_events = [] + for adm in admissions: + admission_events.append( + make_event( + { + "timestamp": adm["admittime"], + "hadm_id": adm["hadm_id"], + "dischtime": adm["dischtime"], + "hospital_expire_flag": adm["hospital_expire_flag"], + } + ) + ) + + if lab_rows is None: + lab_rows = [] + + lab_df = make_labevents_df(lab_rows) + + def get_events_side_effect( + event_type=None, start=None, end=None, return_df=False, filters=None + ): + if event_type == "patients": + return [demographics] + elif event_type == "admissions": + return admission_events + elif event_type == "labevents": + if return_df: + return lab_df + return [] + return [] + + patient.get_events = MagicMock(side_effect=get_events_side_effect) + return patient + + +class TestTDICUMortalityTaskInit(unittest.TestCase): + """Test task initialization and configuration.""" + + def test_default_init(self): + task = TDICUMortalityPredictionMIMIC4() + self.assertEqual(task.context_length, 400) + self.assertEqual(task.input_window_hours, 168) + self.assertEqual(task.prediction_horizon_days, 28) + self.assertEqual(task.state_sample_rate, 1.0) + self.assertEqual(task.min_measurements, 3) + + def test_custom_init(self): + task = TDICUMortalityPredictionMIMIC4( + context_length=200, + input_window_hours=72, + prediction_horizon_days=7, + state_sample_rate=0.5, + min_measurements=5, + ) + self.assertEqual(task.context_length, 200) + self.assertEqual(task.input_window_hours, 72) + self.assertEqual(task.prediction_horizon_days, 7) + + def test_schema(self): + task = TDICUMortalityPredictionMIMIC4() + self.assertIn("measurements", task.input_schema) + self.assertIn("mortality_28d", task.output_schema) + self.assertEqual(task.task_name, "TDICUMortalityPredictionMIMIC4") + + def test_feature_index_has_all_features(self): + task = TDICUMortalityPredictionMIMIC4() + for feat_name in task.FEATURE_NAMES: + self.assertIn(feat_name, task.FEATURE_INDEX) + # Demographics should have indices 0, 1, 2 + self.assertEqual(task.DEMOGRAPHIC_FEATURES["age"], 0) + self.assertEqual(task.DEMOGRAPHIC_FEATURES["gender"], 1) + self.assertEqual(task.DEMOGRAPHIC_FEATURES["patientweight"], 2) + + +class TestTDICUMortalityTaskSampling(unittest.TestCase): + """Test sample generation with synthetic patients.""" + + def _make_lab_rows(self, base_time, n=10, features=None): + """Generate n lab measurement rows spread over time.""" + if features is None: + features = ["Albumin", "Creatinine", "Glucose", "Sodium"] + rows = [] + for i in range(n): + ts = base_time + timedelta(hours=i * 2) + feat = features[i % len(features)] + rows.append(make_lab_row(ts, feat, 3.5 + i * 0.1)) + return rows + + def test_basic_sample_generation(self): + """Test that samples are generated for a normal patient.""" + task = TDICUMortalityPredictionMIMIC4(min_measurements=1) + base_time = datetime(2020, 1, 1, 6, 0) + lab_rows = self._make_lab_rows(base_time, n=10) + + patient = make_patient(lab_rows=lab_rows) + samples = task(patient) + + self.assertGreater(len(samples), 0) + for s in samples: + self.assertIn("patient_id", s) + self.assertIn("admission_id", s) + self.assertIn("measurements", s) + self.assertIn("mortality_28d", s) + self.assertIn("mortality_1d", s) + self.assertIn(s["mortality_28d"], [0, 1]) + + def test_sample_measurement_structure(self): + """Test the 5-tuple structure of measurements.""" + task = TDICUMortalityPredictionMIMIC4(min_measurements=1) + base_time = datetime(2020, 1, 1, 6, 0) + lab_rows = self._make_lab_rows(base_time, n=6) + + patient = make_patient(lab_rows=lab_rows) + samples = task(patient) + + self.assertGreater(len(samples), 0) + timestamps, feature_matrix = samples[0]["measurements"] + # feature_matrix should have 5 columns: + # [value, timepoint, feature_idx, delta_value, delta_time] + self.assertEqual(feature_matrix.shape[1], 5) + # First 3 rows are demographics (age, gender, weight) + self.assertEqual(feature_matrix[0, 2], 0.0) # age feature idx + self.assertEqual(feature_matrix[1, 2], 1.0) # gender feature idx + self.assertEqual(feature_matrix[2, 2], 2.0) # weight feature idx + + def test_context_length_limit(self): + """Test that samples respect the context_length limit.""" + task = TDICUMortalityPredictionMIMIC4(context_length=10, min_measurements=1) + base_time = datetime(2020, 1, 1, 6, 0) + lab_rows = self._make_lab_rows(base_time, n=50) + + patient = make_patient(lab_rows=lab_rows) + samples = task(patient) + + for s in samples: + _, feature_matrix = s["measurements"] + self.assertLessEqual(feature_matrix.shape[0], 10) + + def test_underage_patient_excluded(self): + """Patients under 18 should return no samples.""" + task = TDICUMortalityPredictionMIMIC4(min_measurements=1) + base_time = datetime(2020, 1, 1, 6, 0) + lab_rows = self._make_lab_rows(base_time, n=5) + + patient = make_patient(age=15, lab_rows=lab_rows) + samples = task(patient) + self.assertEqual(len(samples), 0) + + def test_no_lab_events(self): + """Patient with no lab events should return no samples.""" + task = TDICUMortalityPredictionMIMIC4() + patient = make_patient(lab_rows=[]) + samples = task(patient) + self.assertEqual(len(samples), 0) + + def test_insufficient_measurements(self): + """Patient with fewer measurements than min_measurements.""" + task = TDICUMortalityPredictionMIMIC4(min_measurements=10) + base_time = datetime(2020, 1, 1, 6, 0) + lab_rows = self._make_lab_rows(base_time, n=2) + + patient = make_patient(lab_rows=lab_rows) + samples = task(patient) + self.assertEqual(len(samples), 0) + + def test_mortality_label_deceased_patient(self): + """Patient who dies should have mortality=1 for appropriate horizons.""" + task = TDICUMortalityPredictionMIMIC4(min_measurements=1) + admit_time = datetime(2020, 1, 1, 0, 0) + death_time = admit_time + timedelta(days=2) + + lab_rows = [ + make_lab_row(admit_time + timedelta(hours=1), "Albumin", 3.5), + make_lab_row(admit_time + timedelta(hours=2), "Creatinine", 1.2), + make_lab_row(admit_time + timedelta(hours=3), "Glucose", 110.0), + ] + + patient = make_patient( + dod="2020-01-03 00:00:00", + admissions=[ + { + "hadm_id": "H001", + "admittime": admit_time, + "dischtime": "2020-01-03 00:00:00", + "hospital_expire_flag": 1, + } + ], + lab_rows=lab_rows, + ) + samples = task(patient) + + self.assertGreater(len(samples), 0) + # First measurement is at hour 1, death at day 2 (48h away) + # So 3d, 7d, 14d, 28d should be 1; 1d should be 0 + first_sample = samples[0] + self.assertEqual(first_sample["mortality_3d"], 1) + self.assertEqual(first_sample["mortality_7d"], 1) + self.assertEqual(first_sample["mortality_28d"], 1) + + def test_surviving_patient_labels(self): + """Surviving patient should have all mortality labels = 0.""" + task = TDICUMortalityPredictionMIMIC4(min_measurements=1) + base_time = datetime(2020, 1, 1, 6, 0) + lab_rows = self._make_lab_rows(base_time, n=5) + + patient = make_patient(lab_rows=lab_rows) + samples = task(patient) + + for s in samples: + self.assertEqual(s["mortality_1d"], 0) + self.assertEqual(s["mortality_3d"], 0) + self.assertEqual(s["mortality_7d"], 0) + self.assertEqual(s["mortality_14d"], 0) + self.assertEqual(s["mortality_28d"], 0) + + def test_state_sample_rate(self): + """Subsampling state markers should reduce number of samples.""" + base_time = datetime(2020, 1, 1, 6, 0) + lab_rows = self._make_lab_rows(base_time, n=20) + + task_full = TDICUMortalityPredictionMIMIC4( + state_sample_rate=1.0, min_measurements=1 + ) + task_half = TDICUMortalityPredictionMIMIC4( + state_sample_rate=0.5, min_measurements=1 + ) + + patient_full = make_patient(lab_rows=lab_rows) + patient_half = make_patient(lab_rows=lab_rows) + + samples_full = task_full(patient_full) + samples_half = task_half(patient_half) + + self.assertGreater(len(samples_full), len(samples_half)) + + def test_delta_computation(self): + """Test that delta values are computed correctly for same feature.""" + task = TDICUMortalityPredictionMIMIC4(context_length=400, min_measurements=1) + base_time = datetime(2020, 1, 5, 0, 0) + + # Two Albumin measurements, 4 hours apart + lab_rows = [ + make_lab_row(base_time, "Albumin", 3.0), + make_lab_row(base_time + timedelta(hours=4), "Albumin", 3.5), + make_lab_row(base_time + timedelta(hours=8), "Creatinine", 1.0), + ] + + patient = make_patient(lab_rows=lab_rows) + samples = task(patient) + + # Get the last state marker (at hour 8) + last_sample = samples[-1] + _, matrix = last_sample["measurements"] + + # After demographics (3 rows), lab measurements follow + # Find the second Albumin entry (should have delta_value) + lab_rows_data = matrix[3:] # skip demographics + albumin_idx = task.FEATURE_INDEX["Albumin"] + albumin_rows = [r for r in lab_rows_data if r[2] == float(albumin_idx)] + # One of the albumin measurements should have a non-NaN delta_value + if len(albumin_rows) >= 2: + # The delta is computed in sorted order (by feature, then time) + # Magnitude should be 0.5 (3.5 - 3.0) + self.assertAlmostEqual(abs(albumin_rows[1][3]), 0.5, places=3) + + def test_gender_encoding(self): + """Test that gender is correctly encoded (M=0, F=1).""" + task = TDICUMortalityPredictionMIMIC4(min_measurements=1) + base_time = datetime(2020, 1, 1, 6, 0) + lab_rows = self._make_lab_rows(base_time, n=5) + + patient_m = make_patient(gender="M", lab_rows=lab_rows) + patient_f = make_patient(gender="F", lab_rows=lab_rows) + + samples_m = task(patient_m) + samples_f = task(patient_f) + + _, matrix_m = samples_m[0]["measurements"] + _, matrix_f = samples_f[0]["measurements"] + + # Gender is second demographic (index 1), value column is 0 + self.assertEqual(matrix_m[1, 0], 0.0) # Male + self.assertEqual(matrix_f[1, 0], 1.0) # Female + + def test_weight_defaults(self): + """Test default weight assignment by gender.""" + task = TDICUMortalityPredictionMIMIC4(min_measurements=1) + base_time = datetime(2020, 1, 1, 6, 0) + lab_rows = self._make_lab_rows(base_time, n=5) + + patient_m = make_patient(gender="M", lab_rows=lab_rows) + patient_f = make_patient(gender="F", lab_rows=lab_rows) + + samples_m = task(patient_m) + samples_f = task(patient_f) + + _, matrix_m = samples_m[0]["measurements"] + _, matrix_f = samples_f[0]["measurements"] + + # Weight is third demographic (index 2), value column is 0 + self.assertEqual(matrix_m[2, 0], 86.0) # Male default + self.assertEqual(matrix_f[2, 0], 74.0) # Female default + + +class TestTDICUMortalityTaskEdgeCases(unittest.TestCase): + """Test edge cases and boundary conditions.""" + + def test_multiple_admissions(self): + """Patient with multiple admissions should generate samples for each.""" + task = TDICUMortalityPredictionMIMIC4(min_measurements=1) + admit1 = datetime(2020, 1, 1, 0, 0) + admit2 = datetime(2020, 6, 1, 0, 0) + + admissions = [ + { + "hadm_id": "H001", + "admittime": admit1, + "dischtime": "2020-01-10 00:00:00", + "hospital_expire_flag": 0, + }, + { + "hadm_id": "H002", + "admittime": admit2, + "dischtime": "2020-06-10 00:00:00", + "hospital_expire_flag": 0, + }, + ] + + lab_rows = [ + make_lab_row(admit1 + timedelta(hours=1), "Albumin", 3.5), + make_lab_row(admit1 + timedelta(hours=2), "Creatinine", 1.2), + make_lab_row(admit1 + timedelta(hours=3), "Glucose", 100.0), + make_lab_row(admit2 + timedelta(hours=1), "Albumin", 3.0), + make_lab_row(admit2 + timedelta(hours=2), "Sodium", 140.0), + make_lab_row(admit2 + timedelta(hours=3), "Potassium", 4.0), + ] + + patient = make_patient(admissions=admissions, lab_rows=lab_rows) + samples = task(patient) + + admission_ids = set(s["admission_id"] for s in samples) + self.assertEqual(admission_ids, {"H001", "H002"}) + + def test_input_window_boundary(self): + """Measurements outside the input window should be excluded.""" + task = TDICUMortalityPredictionMIMIC4(input_window_hours=24, min_measurements=1) + state_time = datetime(2020, 1, 5, 12, 0) + + lab_rows = [ + # This is >24h before the last measurement, should be excluded + # from that state's context + make_lab_row(state_time - timedelta(hours=30), "Albumin", 3.5), + # This is within 24h + make_lab_row(state_time - timedelta(hours=12), "Creatinine", 1.2), + make_lab_row(state_time, "Glucose", 100.0), + ] + + patient = make_patient(lab_rows=lab_rows) + samples = task(patient) + + # The last state marker's context should not include the first + # measurement + last_sample = [s for s in samples][-1] + _, matrix = last_sample["measurements"] + # 3 demographics + at most 2 lab measurements (not 3) + self.assertLessEqual(matrix.shape[0], 5) + + def test_no_demographics_returns_empty(self): + """Patient with no demographics event returns empty.""" + task = TDICUMortalityPredictionMIMIC4() + patient = MagicMock() + patient.patient_id = "P999" + patient.get_events = MagicMock(return_value=[]) + samples = task(patient) + self.assertEqual(len(samples), 0) + + +if __name__ == "__main__": + unittest.main()