From ebe15f0cecb00ab67b32978cfe17404093c450d0 Mon Sep 17 00:00:00 2001 From: leeethinh Date: Wed, 22 Apr 2026 20:09:02 -0500 Subject: [PATCH] Final submission: temporal evaluation + risk prediction + tests + examples --- docs/api/tasks.rst | 2 + .../pyhealth.tasks.temporal_evaluation.rst | 9 + ...yhealth.tasks.temporal_risk_prediction.rst | 7 + docs/api/tasks/temporal_risk_prediction.py | 126 +++++ .../mimic4_temporal_mortality_logistic.py | 55 +++ .../synthetic_temporal_mortality_logistic.py | 66 +++ examples/synthetic_temporal_shift_demo.py | 86 ++++ pyhealth/tasks/__init__.py | 1 + pyhealth/tasks/temporal_evaluation.py | 453 ++++++++++++++++++ pyhealth/tasks/temporal_risk_prediction.py | 126 +++++ tests/test_temporal_evaluation.py | 115 +++++ tests/test_temporal_risk_prediction.py | 98 ++++ 12 files changed, 1144 insertions(+) create mode 100644 docs/api/tasks/pyhealth.tasks.temporal_evaluation.rst create mode 100644 docs/api/tasks/pyhealth.tasks.temporal_risk_prediction.rst create mode 100644 docs/api/tasks/temporal_risk_prediction.py create mode 100644 examples/mimic4_temporal_mortality_logistic.py create mode 100644 examples/synthetic_temporal_mortality_logistic.py create mode 100644 examples/synthetic_temporal_shift_demo.py create mode 100644 pyhealth/tasks/temporal_evaluation.py create mode 100644 pyhealth/tasks/temporal_risk_prediction.py create mode 100644 tests/test_temporal_evaluation.py create mode 100644 tests/test_temporal_risk_prediction.py diff --git a/docs/api/tasks.rst b/docs/api/tasks.rst index 23a4e06e5..1d9395e43 100644 --- a/docs/api/tasks.rst +++ b/docs/api/tasks.rst @@ -230,3 +230,5 @@ Available Tasks Mutation Pathogenicity (COSMIC) Cancer Survival Prediction (TCGA) Cancer Mutation Burden (TCGA) + Temporal Evaluation + Temporal Mortality (MIMIC-IV) \ No newline at end of file diff --git a/docs/api/tasks/pyhealth.tasks.temporal_evaluation.rst b/docs/api/tasks/pyhealth.tasks.temporal_evaluation.rst new file mode 100644 index 000000000..090063785 --- /dev/null +++ b/docs/api/tasks/pyhealth.tasks.temporal_evaluation.rst @@ -0,0 +1,9 @@ +pyhealth.tasks.temporal_evaluation +================================= + +.. automodule:: pyhealth.tasks.temporal_evaluation + :members: + :undoc-members: + :show-inheritance: + + \ No newline at end of file diff --git a/docs/api/tasks/pyhealth.tasks.temporal_risk_prediction.rst b/docs/api/tasks/pyhealth.tasks.temporal_risk_prediction.rst new file mode 100644 index 000000000..beb064b3e --- /dev/null +++ b/docs/api/tasks/pyhealth.tasks.temporal_risk_prediction.rst @@ -0,0 +1,7 @@ +pyhealth.tasks.temporal_risk_prediction +====================================== + +.. automodule:: pyhealth.tasks.temporal_risk_prediction + :members: + :undoc-members: + :show-inheritance: \ No newline at end of file diff --git a/docs/api/tasks/temporal_risk_prediction.py b/docs/api/tasks/temporal_risk_prediction.py new file mode 100644 index 000000000..a107d4a75 --- /dev/null +++ b/docs/api/tasks/temporal_risk_prediction.py @@ -0,0 +1,126 @@ +"""Temporal risk prediction task for PyHealth datasets.""" + +from __future__ import annotations + +from typing import Any, Dict, List, Optional + +from pyhealth.data import Patient +from pyhealth.tasks import BaseTask + + +class TemporalMortalityMIMIC4(BaseTask): + """Temporal mortality prediction task for MIMIC-IV style EHR data. + + This task creates one sample per admission. Each sample contains: + 1. a lightweight numeric feature vector summarizing patient history + 2. the admission year for temporal splitting + 3. a binary mortality label + + The task is designed to pair with temporal evaluation utilities that + compare deployment-like temporal splits against random train/test splits. + + Attributes: + task_name: Name of the task. + input_schema: Schema for model inputs. + output_schema: Schema for model outputs. + """ + + task_name: str = "TemporalMortalityMIMIC4" + + input_schema: Dict[str, str] = { + "features": "tensor", + "year": "tensor", + } + + output_schema: Dict[str, str] = { + "label": "binary", + } + + def __init__(self, min_history_events: int = 1) -> None: + """Initializes the task. + + Args: + min_history_events: Minimum number of historical events required + to emit a sample. + """ + self.min_history_events = min_history_events + + def _safe_year(self, event: Any) -> Optional[int]: + """Extracts the year from an event timestamp. + + Args: + event: Event object with a timestamp attribute. + + Returns: + The year if available, otherwise None. + """ + timestamp = getattr(event, "timestamp", None) + if timestamp is None: + return None + return int(timestamp.year) + + def __call__(self, patient: Patient) -> List[Dict[str, Any]]: + """Processes one patient into temporal prediction samples. + + Args: + patient: A PyHealth Patient object. + + Returns: + A list of sample dictionaries. + """ + samples: List[Dict[str, Any]] = [] + + admissions = patient.get_events("admissions") + for admission in admissions: + year = self._safe_year(admission) + if year is None: + continue + + end_time = admission.timestamp + + diagnoses = patient.get_events("diagnoses_icd", end=end_time) + procedures = patient.get_events("procedures_icd", end=end_time) + prescriptions = patient.get_events("prescriptions", end=end_time) + + diagnosis_codes = [ + getattr(event, "icd_code", None) + for event in diagnoses + if getattr(event, "icd_code", None) is not None + ] + procedure_codes = [ + getattr(event, "icd_code", None) + for event in procedures + if getattr(event, "icd_code", None) is not None + ] + drug_names = [ + getattr(event, "drug", None) + for event in prescriptions + if getattr(event, "drug", None) is not None + ] + + history_count = ( + len(diagnosis_codes) + len(procedure_codes) + len(drug_names) + ) + if history_count < self.min_history_events: + continue + + features = [ + float(len(set(diagnosis_codes))), + float(len(set(procedure_codes))), + float(len(set(drug_names))), + float(history_count), + ] + + label = int( + getattr(admission, "hospital_expire_flag", "0") == "1" + ) + + samples.append( + { + "features": features, + "year": [float(year)], + "label": label, + } + ) + + return samples \ No newline at end of file diff --git a/examples/mimic4_temporal_mortality_logistic.py b/examples/mimic4_temporal_mortality_logistic.py new file mode 100644 index 000000000..5bada16dc --- /dev/null +++ b/examples/mimic4_temporal_mortality_logistic.py @@ -0,0 +1,55 @@ +"""Example: temporal vs random evaluation on MIMIC-IV task samples.""" + +from pyhealth.datasets import MIMIC4EHRDataset +from pyhealth.tasks.temporal_evaluation import ( + run_random_experiment, + run_temporal_experiment, + sample_dataset_to_temporal_records, +) +from pyhealth.tasks.temporal_risk_prediction import TemporalMortalityMIMIC4 + + +def safe_metric(value): + if value is None: + return "None" + return f"{value:.3f}" + + +def main() -> None: + dataset = MIMIC4EHRDataset( + root="YOUR_DATA_ROOT", + tables=["admissions", "diagnoses_icd", "procedures_icd", "prescriptions"], + dev=True, + ) + + sample_dataset = dataset.set_task(TemporalMortalityMIMIC4()) + records = sample_dataset_to_temporal_records(sample_dataset) + + temporal_result = run_temporal_experiment(records, split_year=2017) + random_result = run_random_experiment(records, random_state=42) + + print("=== Temporal split ===") + print( + f"train={temporal_result.train_size} | " + f"test={temporal_result.test_size} | " + f"accuracy={temporal_result.accuracy:.3f} | " + f"auroc={safe_metric(temporal_result.auroc)} | " + f"auprc={safe_metric(temporal_result.auprc)} | " + f"brier={safe_metric(temporal_result.brier)} | " + f"f1={temporal_result.f1:.3f}" + ) + + print("\n=== Random split ===") + print( + f"train={random_result.train_size} | " + f"test={random_result.test_size} | " + f"accuracy={random_result.accuracy:.3f} | " + f"auroc={safe_metric(random_result.auroc)} | " + f"auprc={safe_metric(random_result.auprc)} | " + f"brier={safe_metric(random_result.brier)} | " + f"f1={random_result.f1:.3f}" + ) + + +if __name__ == "__main__": + main() \ No newline at end of file diff --git a/examples/synthetic_temporal_mortality_logistic.py b/examples/synthetic_temporal_mortality_logistic.py new file mode 100644 index 000000000..864f40ae2 --- /dev/null +++ b/examples/synthetic_temporal_mortality_logistic.py @@ -0,0 +1,66 @@ +"""Example: temporal vs random evaluation on synthetic clinical data. + +This example demonstrates the core EMDOT-style idea: +train on earlier years and test on later years, then compare against +a random split baseline. + +It is intentionally synthetic and lightweight so it runs anywhere. +""" + +from pyhealth.tasks.temporal_evaluation import ( + run_ablation, + run_random_experiment, +) + + +SYNTHETIC_DATA = [ + {"patient_id": 1, "year": 2010, "features": [0.20, 0.10], "label": 0}, + {"patient_id": 2, "year": 2011, "features": [0.25, 0.15], "label": 0}, + {"patient_id": 3, "year": 2012, "features": [0.70, 0.60], "label": 1}, + {"patient_id": 4, "year": 2013, "features": [0.40, 0.35], "label": 0}, + {"patient_id": 5, "year": 2014, "features": [0.55, 0.45], "label": 1}, + {"patient_id": 6, "year": 2015, "features": [0.60, 0.50], "label": 1}, + {"patient_id": 7, "year": 2016, "features": [0.65, 0.55], "label": 1}, + {"patient_id": 8, "year": 2017, "features": [0.35, 0.25], "label": 0}, + {"patient_id": 9, "year": 2018, "features": [0.75, 0.60], "label": 1}, + {"patient_id": 10, "year": 2019, "features": [0.15, 0.10], "label": 0}, + {"patient_id": 11, "year": 2020, "features": [0.82, 0.70], "label": 1}, + {"patient_id": 12, "year": 2021, "features": [0.18, 0.12], "label": 0}, +] + + +def safe_metric(value): + if value is None: + return "None" + return f"{value:.3f}" + + +def main() -> None: + print("=== Temporal ablation ===") + for result in run_ablation(SYNTHETIC_DATA, split_years=[2013, 2015, 2017]): + print( + f"split_year={result.split_year} | " + f"train={result.train_size} | " + f"test={result.test_size} | " + f"accuracy={result.accuracy:.3f} | " + f"auroc={safe_metric(result.auroc)} | " + f"auprc={safe_metric(result.auprc)} | " + f"brier={safe_metric(result.brier)} | " + f"f1={result.f1:.3f}" + ) + + print("\n=== Random baseline ===") + random_result = run_random_experiment(SYNTHETIC_DATA, random_state=42) + print( + f"train={random_result.train_size} | " + f"test={random_result.test_size} | " + f"accuracy={random_result.accuracy:.3f} | " + f"auroc={safe_metric(random_result.auroc)} | " + f"auprc={safe_metric(random_result.auprc)} | " + f"brier={safe_metric(random_result.brier)} | " + f"f1={random_result.f1:.3f}" + ) + + +if __name__ == "__main__": + main() \ No newline at end of file diff --git a/examples/synthetic_temporal_shift_demo.py b/examples/synthetic_temporal_shift_demo.py new file mode 100644 index 000000000..4d252e32f --- /dev/null +++ b/examples/synthetic_temporal_shift_demo.py @@ -0,0 +1,86 @@ +"""Demo: temporal evaluation under realistic synthetic clinical shift. + +This example generates synthetic clinical data with temporal drift and compares: +1. multiple temporal cutoffs +2. a random split baseline + +The goal is to simulate the EMDOT idea more realistically than a tiny hand-made +toy dataset, while still remaining lightweight and reproducible. +""" + +from pyhealth.tasks.temporal_evaluation import ( + generate_synthetic_temporal_shift_data, + run_ablation, + run_random_experiment, +) + + +def format_result(prefix: str, result) -> str: + return ( + f"{prefix} | " + f"train={result.train_size} | " + f"test={result.test_size} | " + f"accuracy={result.accuracy:.3f} | " + f"auroc={result.auroc:.3f if result.auroc is not None else 'None'} | " + f"auprc={result.auprc:.3f if result.auprc is not None else 'None'} | " + f"brier={result.brier:.3f if result.brier is not None else 'None'} | " + f"f1={result.f1:.3f}" + ) + + +def safe_metric(value): + if value is None: + return "None" + return f"{value:.3f}" + + +def main() -> None: + dataset = generate_synthetic_temporal_shift_data( + n_patients_per_year=40, + start_year=2010, + end_year=2021, + seed=42, + ) + + print("=== Synthetic temporal shift dataset ===") + print(f"Total records: {len(dataset)}") + print(f"Years: {dataset[0]['year']} to {dataset[-1]['year']}") + print() + + print("=== Temporal ablation ===") + temporal_results = run_ablation(dataset, split_years=[2013, 2015, 2017, 2019]) + for result in temporal_results: + print( + f"split_year={result.split_year} | " + f"train={result.train_size} | " + f"test={result.test_size} | " + f"accuracy={result.accuracy:.3f} | " + f"auroc={safe_metric(result.auroc)} | " + f"auprc={safe_metric(result.auprc)} | " + f"brier={safe_metric(result.brier)} | " + f"f1={result.f1:.3f}" + ) + + print() + print("=== Random baseline ===") + random_result = run_random_experiment(dataset, random_state=42) + print( + f"train={random_result.train_size} | " + f"test={random_result.test_size} | " + f"accuracy={random_result.accuracy:.3f} | " + f"auroc={safe_metric(random_result.auroc)} | " + f"auprc={safe_metric(random_result.auprc)} | " + f"brier={safe_metric(random_result.brier)} | " + f"f1={random_result.f1:.3f}" + ) + + print() + print("Interpretation:") + print( + "This synthetic dataset introduces temporal drift, so temporal splits " + "better reflect future deployment conditions than random splits." + ) + + +if __name__ == "__main__": + main() \ No newline at end of file diff --git a/pyhealth/tasks/__init__.py b/pyhealth/tasks/__init__.py index a32618f9c..e226ff786 100644 --- a/pyhealth/tasks/__init__.py +++ b/pyhealth/tasks/__init__.py @@ -1,3 +1,4 @@ +from .temporal_evaluation import * from .base_task import BaseTask from .benchmark_ehrshot import BenchmarkEHRShot from .cancer_survival import CancerMutationBurden, CancerSurvivalPrediction diff --git a/pyhealth/tasks/temporal_evaluation.py b/pyhealth/tasks/temporal_evaluation.py new file mode 100644 index 000000000..9c3c79afb --- /dev/null +++ b/pyhealth/tasks/temporal_evaluation.py @@ -0,0 +1,453 @@ +"""Temporal evaluation utilities for clinical prediction. + +This module implements a lightweight temporal evaluation pipeline inspired by: + +Zhou, H., Chen, Y., and Lipton, Z. C. +Evaluating Model Performance in Medical Datasets Over Time. + +The main idea is to evaluate models in a deployment-like setting: +train on records up to a time cutoff and test on records after that cutoff. + +This file is intentionally lightweight so it can be tested entirely with +synthetic data in milliseconds. +""" + +from __future__ import annotations + +from dataclasses import dataclass +from typing import Any, Dict, Iterable, List, Sequence, Tuple + +import numpy as np +from sklearn.linear_model import LogisticRegression +from sklearn.metrics import ( + accuracy_score, + average_precision_score, + brier_score_loss, + f1_score, + roc_auc_score, +) +from sklearn.model_selection import train_test_split + + +Record = Dict[str, Any] + + +@dataclass +class TemporalExperimentResult: + """Container for experiment outputs. + + Attributes: + experiment_type: Either "temporal" or "random". + split_year: Temporal cutoff year, or None for random split. + train_size: Number of training examples. + test_size: Number of testing examples. + accuracy: Accuracy on the test split. + auroc: AUROC on the test split, or None if unavailable. + auprc: AUPRC on the test split, or None if unavailable. + brier: Brier score on the test split, or None if unavailable. + f1: F1 score on the test split. + """ + + experiment_type: str + split_year: int | None + train_size: int + test_size: int + accuracy: float + auroc: float | None + auprc: float | None + brier: float | None + f1: float + + +def validate_dataset(dataset: Sequence[Record]) -> None: + """Validates the dataset format. + + Expected keys for each record: + - "year" + - "label" + - "features" + + Args: + dataset: Sequence of record dictionaries. + + Raises: + ValueError: If the dataset is empty or malformed. + """ + if not dataset: + raise ValueError("Dataset must not be empty.") + + required_keys = {"year", "label", "features"} + for i, row in enumerate(dataset): + missing = required_keys - set(row.keys()) + if missing: + raise ValueError( + f"Record at index {i} is missing required keys: {sorted(missing)}" + ) + + if not isinstance(row["features"], (list, tuple)): + raise ValueError( + f"Record at index {i} must have a list or tuple in 'features'." + ) + + if len(row["features"]) == 0: + raise ValueError( + f"Record at index {i} must contain at least one feature." + ) + + try: + int(row["year"]) + int(row["label"]) + [float(value) for value in row["features"]] + except (TypeError, ValueError) as exc: + raise ValueError( + f"Record at index {i} contains non-numeric year, label, or features." + ) from exc + + +def prepare_data(dataset: Sequence[Record]) -> Tuple[List[List[float]], List[int]]: + """Converts record dictionaries into feature matrix and label vector. + + Args: + dataset: Sequence of records with "features" and "label". + + Returns: + Tuple of (X, y). + """ + validate_dataset(dataset) + x = [[float(value) for value in row["features"]] for row in dataset] + y = [int(row["label"]) for row in dataset] + return x, y + + +def temporal_split( + dataset: Sequence[Record], split_year: int +) -> Tuple[List[Record], List[Record]]: + """Splits records into past-vs-future subsets. + + Train split contains years <= split_year. + Test split contains years > split_year. + + Args: + dataset: Sequence of records containing a "year" key. + split_year: Temporal cutoff year. + + Returns: + Tuple of (train_records, test_records). + + Raises: + ValueError: If either split is empty. + """ + validate_dataset(dataset) + train = [row for row in dataset if int(row["year"]) <= split_year] + test = [row for row in dataset if int(row["year"]) > split_year] + + if not train: + raise ValueError("Temporal split produced an empty training set.") + if not test: + raise ValueError("Temporal split produced an empty testing set.") + + return train, test + + +def _check_binary_labels(y_train: Sequence[int], y_test: Sequence[int]) -> None: + """Ensures both train and test splits are valid for binary classification.""" + if len(set(y_train)) < 2: + raise ValueError("Training set must contain at least two classes.") + if len(set(y_test)) < 2: + raise ValueError("Testing set must contain at least two classes.") + + +def train_logistic_regression( + x_train: Sequence[Sequence[float]], + y_train: Sequence[int], + max_iter: int = 1000, +) -> LogisticRegression: + """Trains a logistic regression model. + + Args: + x_train: Training features. + y_train: Training labels. + max_iter: Maximum number of optimizer iterations. + + Returns: + Trained LogisticRegression model. + + Raises: + ValueError: If training labels contain fewer than two classes. + """ + if len(set(y_train)) < 2: + raise ValueError("Training labels must contain at least two classes.") + + model = LogisticRegression(max_iter=max_iter) + model.fit(x_train, y_train) + return model + + +def evaluate_model( + model: LogisticRegression, + x_test: Sequence[Sequence[float]], + y_test: Sequence[int], +) -> Tuple[float, float | None, float | None, float | None, float]: + """Evaluates a binary classifier on a test set. + + Args: + model: Trained logistic regression model. + x_test: Testing features. + y_test: Testing labels. + + Returns: + Tuple of (accuracy, auroc, auprc, brier, f1). + AUROC/AUPRC/Brier may be None if unavailable. + """ + predictions = model.predict(x_test) + accuracy = float(accuracy_score(y_test, predictions)) + f1 = float(f1_score(y_test, predictions)) + + auroc: float | None + auprc: float | None + brier: float | None + + try: + probabilities = model.predict_proba(x_test)[:, 1] + auroc = float(roc_auc_score(y_test, probabilities)) + auprc = float(average_precision_score(y_test, probabilities)) + brier = float(brier_score_loss(y_test, probabilities)) + except Exception: + auroc = None + auprc = None + brier = None + + return accuracy, auroc, auprc, brier, f1 + + +def run_temporal_experiment( + dataset: Sequence[Record], split_year: int +) -> TemporalExperimentResult: + """Runs one temporal train/test experiment. + + Args: + dataset: Sequence of patient records. + split_year: Train on years <= split_year, test on years > split_year. + + Returns: + TemporalExperimentResult with summary metrics. + """ + train_records, test_records = temporal_split(dataset, split_year) + x_train, y_train = prepare_data(train_records) + x_test, y_test = prepare_data(test_records) + + _check_binary_labels(y_train, y_test) + + model = train_logistic_regression(x_train, y_train) + accuracy, auroc, auprc, brier, f1 = evaluate_model(model, x_test, y_test) + + return TemporalExperimentResult( + experiment_type="temporal", + split_year=split_year, + train_size=len(train_records), + test_size=len(test_records), + accuracy=accuracy, + auroc=auroc, + auprc=auprc, + brier=brier, + f1=f1, + ) + + +def run_random_experiment( + dataset: Sequence[Record], + test_size: float = 0.4, + random_state: int = 42, +) -> TemporalExperimentResult: + """Runs a random train/test baseline experiment. + + Args: + dataset: Sequence of patient records. + test_size: Fraction used for testing. + random_state: Random seed. + + Returns: + TemporalExperimentResult with summary metrics. + + Raises: + ValueError: If the dataset does not contain both classes. + """ + x, y = prepare_data(dataset) + + if len(set(y)) < 2: + raise ValueError("Dataset must contain at least two classes.") + + x_train, x_test, y_train, y_test = train_test_split( + x, + y, + test_size=test_size, + random_state=random_state, + stratify=y, + ) + + _check_binary_labels(y_train, y_test) + + model = train_logistic_regression(x_train, y_train) + accuracy, auroc, auprc, brier, f1 = evaluate_model(model, x_test, y_test) + + return TemporalExperimentResult( + experiment_type="random", + split_year=None, + train_size=len(x_train), + test_size=len(x_test), + accuracy=accuracy, + auroc=auroc, + auprc=auprc, + brier=brier, + f1=f1, + ) + + +def run_ablation( + dataset: Sequence[Record], + split_years: Iterable[int], +) -> List[TemporalExperimentResult]: + """Runs multiple temporal cutoffs for an ablation study. + + Args: + dataset: Sequence of patient records. + split_years: Iterable of temporal cutoff years. + + Returns: + List of TemporalExperimentResult objects. + """ + results: List[TemporalExperimentResult] = [] + for year in split_years: + results.append(run_temporal_experiment(dataset, split_year=year)) + return results + + +def generate_synthetic_temporal_shift_data( + n_patients_per_year: int = 40, + start_year: int = 2010, + end_year: int = 2021, + seed: int = 42, +) -> List[Record]: + """Generates realistic synthetic clinical data with temporal drift. + + Features: + - age_like + - comorbidity_like + - lab_like + - utilization_like + + The data distribution shifts over time, which makes temporal evaluation + meaningfully different from random splits. + + Args: + n_patients_per_year: Number of synthetic patients generated per year. + start_year: First year. + end_year: Last year. + seed: Random seed. + + Returns: + List of synthetic patient records. + """ + rng = np.random.default_rng(seed) + records: List[Record] = [] + patient_id = 1 + + for year in range(start_year, end_year + 1): + drift = (year - start_year) / max(1, (end_year - start_year)) + + for _ in range(n_patients_per_year): + age_like = np.clip(rng.normal(loc=55 + 8 * drift, scale=10), 18, 90) + comorbidity_like = np.clip( + rng.normal(loc=1.5 + 1.2 * drift, scale=0.9), 0, 6 + ) + lab_like = np.clip( + rng.normal(loc=0.45 + 0.20 * drift, scale=0.15), 0.05, 1.20 + ) + utilization_like = np.clip( + rng.normal(loc=1.2 + 0.8 * drift, scale=0.7), 0, 5 + ) + + # The label relationship changes slightly over time. + score = ( + -6.0 + + 0.035 * age_like + + 0.55 * comorbidity_like + + 3.2 * lab_like + + 0.35 * utilization_like + - 0.9 * drift + + rng.normal(0, 0.35) + ) + + probability = 1.0 / (1.0 + np.exp(-score)) + label = int(rng.random() < probability) + + records.append( + { + "patient_id": patient_id, + "year": year, + "features": [ + round(float(age_like), 3), + round(float(comorbidity_like), 3), + round(float(lab_like), 3), + round(float(utilization_like), 3), + ], + "label": label, + } + ) + patient_id += 1 + + return records + +def _to_python_scalar(value: Any) -> Any: + """Converts tensor-like objects to Python values.""" + if hasattr(value, "detach"): + value = value.detach().cpu() + if hasattr(value, "numpy"): + value = value.numpy() + if hasattr(value, "tolist"): + value = value.tolist() + return value + + +def sample_dataset_to_temporal_records(sample_dataset: Sequence[Any]) -> List[Record]: + """Converts a PyHealth SampleDataset into temporal evaluation records. + + Args: + sample_dataset: Sequence of PyHealth samples. + + Returns: + List of records compatible with temporal evaluation utilities. + """ + records: List[Record] = [] + + for i, sample in enumerate(sample_dataset): + features = _to_python_scalar(sample["features"]) + year = _to_python_scalar(sample["year"]) + label = _to_python_scalar(sample["label"]) + + if isinstance(year, list): + year = year[0] + if isinstance(label, list): + label = label[0] + + records.append( + { + "patient_id": i, + "year": int(year), + "features": [float(x) for x in features], + "label": int(label), + } + ) + + return records + +if __name__ == "__main__": + dataset = generate_synthetic_temporal_shift_data() + results = run_ablation(dataset, [2013, 2015, 2017]) + + print("=== Temporal ablation ===") + for r in results: + print(r) + + print("\n=== Random baseline ===") + print(run_random_experiment(dataset)) \ No newline at end of file diff --git a/pyhealth/tasks/temporal_risk_prediction.py b/pyhealth/tasks/temporal_risk_prediction.py new file mode 100644 index 000000000..a107d4a75 --- /dev/null +++ b/pyhealth/tasks/temporal_risk_prediction.py @@ -0,0 +1,126 @@ +"""Temporal risk prediction task for PyHealth datasets.""" + +from __future__ import annotations + +from typing import Any, Dict, List, Optional + +from pyhealth.data import Patient +from pyhealth.tasks import BaseTask + + +class TemporalMortalityMIMIC4(BaseTask): + """Temporal mortality prediction task for MIMIC-IV style EHR data. + + This task creates one sample per admission. Each sample contains: + 1. a lightweight numeric feature vector summarizing patient history + 2. the admission year for temporal splitting + 3. a binary mortality label + + The task is designed to pair with temporal evaluation utilities that + compare deployment-like temporal splits against random train/test splits. + + Attributes: + task_name: Name of the task. + input_schema: Schema for model inputs. + output_schema: Schema for model outputs. + """ + + task_name: str = "TemporalMortalityMIMIC4" + + input_schema: Dict[str, str] = { + "features": "tensor", + "year": "tensor", + } + + output_schema: Dict[str, str] = { + "label": "binary", + } + + def __init__(self, min_history_events: int = 1) -> None: + """Initializes the task. + + Args: + min_history_events: Minimum number of historical events required + to emit a sample. + """ + self.min_history_events = min_history_events + + def _safe_year(self, event: Any) -> Optional[int]: + """Extracts the year from an event timestamp. + + Args: + event: Event object with a timestamp attribute. + + Returns: + The year if available, otherwise None. + """ + timestamp = getattr(event, "timestamp", None) + if timestamp is None: + return None + return int(timestamp.year) + + def __call__(self, patient: Patient) -> List[Dict[str, Any]]: + """Processes one patient into temporal prediction samples. + + Args: + patient: A PyHealth Patient object. + + Returns: + A list of sample dictionaries. + """ + samples: List[Dict[str, Any]] = [] + + admissions = patient.get_events("admissions") + for admission in admissions: + year = self._safe_year(admission) + if year is None: + continue + + end_time = admission.timestamp + + diagnoses = patient.get_events("diagnoses_icd", end=end_time) + procedures = patient.get_events("procedures_icd", end=end_time) + prescriptions = patient.get_events("prescriptions", end=end_time) + + diagnosis_codes = [ + getattr(event, "icd_code", None) + for event in diagnoses + if getattr(event, "icd_code", None) is not None + ] + procedure_codes = [ + getattr(event, "icd_code", None) + for event in procedures + if getattr(event, "icd_code", None) is not None + ] + drug_names = [ + getattr(event, "drug", None) + for event in prescriptions + if getattr(event, "drug", None) is not None + ] + + history_count = ( + len(diagnosis_codes) + len(procedure_codes) + len(drug_names) + ) + if history_count < self.min_history_events: + continue + + features = [ + float(len(set(diagnosis_codes))), + float(len(set(procedure_codes))), + float(len(set(drug_names))), + float(history_count), + ] + + label = int( + getattr(admission, "hospital_expire_flag", "0") == "1" + ) + + samples.append( + { + "features": features, + "year": [float(year)], + "label": label, + } + ) + + return samples \ No newline at end of file diff --git a/tests/test_temporal_evaluation.py b/tests/test_temporal_evaluation.py new file mode 100644 index 000000000..f8d5c81cb --- /dev/null +++ b/tests/test_temporal_evaluation.py @@ -0,0 +1,115 @@ +import pytest + +from pyhealth.tasks.temporal_evaluation import ( + generate_synthetic_temporal_shift_data, + prepare_data, + run_ablation, + run_random_experiment, + run_temporal_experiment, + temporal_split, + validate_dataset, +) + + +@pytest.fixture +def synthetic_data(): + return [ + {"patient_id": 1, "year": 2010, "features": [0.20, 0.10], "label": 0}, + {"patient_id": 2, "year": 2011, "features": [0.25, 0.15], "label": 0}, + {"patient_id": 3, "year": 2012, "features": [0.70, 0.60], "label": 1}, + {"patient_id": 4, "year": 2013, "features": [0.40, 0.35], "label": 0}, + {"patient_id": 5, "year": 2014, "features": [0.55, 0.45], "label": 1}, + {"patient_id": 6, "year": 2015, "features": [0.60, 0.50], "label": 1}, + {"patient_id": 7, "year": 2016, "features": [0.65, 0.55], "label": 1}, + {"patient_id": 8, "year": 2017, "features": [0.35, 0.25], "label": 0}, + {"patient_id": 9, "year": 2018, "features": [0.75, 0.60], "label": 1}, + {"patient_id": 10, "year": 2019, "features": [0.15, 0.10], "label": 0}, + {"patient_id": 11, "year": 2020, "features": [0.82, 0.70], "label": 1}, + {"patient_id": 12, "year": 2021, "features": [0.18, 0.12], "label": 0}, + ] + + +def test_validate_dataset_passes(synthetic_data): + validate_dataset(synthetic_data) + + +def test_validate_dataset_empty(): + with pytest.raises(ValueError, match="Dataset must not be empty"): + validate_dataset([]) + + +def test_validate_dataset_missing_key(): + bad_data = [ + {"patient_id": 1, "year": 2010, "label": 0}, + ] + with pytest.raises(ValueError, match="missing required keys"): + validate_dataset(bad_data) + + +def test_prepare_data(synthetic_data): + x, y = prepare_data(synthetic_data[:2]) + assert x == [[0.20, 0.10], [0.25, 0.15]] + assert y == [0, 0] + + +def test_temporal_split(synthetic_data): + train, test = temporal_split(synthetic_data, 2015) + assert len(train) == 6 + assert len(test) == 6 + assert all(row["year"] <= 2015 for row in train) + assert all(row["year"] > 2015 for row in test) + + +def test_temporal_split_empty_train(synthetic_data): + with pytest.raises(ValueError, match="empty training set"): + temporal_split(synthetic_data, 2000) + + +def test_temporal_split_empty_test(synthetic_data): + with pytest.raises(ValueError, match="empty testing set"): + temporal_split(synthetic_data, 2025) + + +def test_run_temporal_experiment(synthetic_data): + result = run_temporal_experiment(synthetic_data, 2015) + assert result.experiment_type == "temporal" + assert result.split_year == 2015 + assert result.train_size == 6 + assert result.test_size == 6 + assert 0.0 <= result.accuracy <= 1.0 + assert result.auroc is None or 0.0 <= result.auroc <= 1.0 + assert result.auprc is None or 0.0 <= result.auprc <= 1.0 + assert result.brier is None or 0.0 <= result.brier <= 1.0 + assert 0.0 <= result.f1 <= 1.0 + + +def test_run_random_experiment(synthetic_data): + result = run_random_experiment(synthetic_data, random_state=42) + assert result.experiment_type == "random" + assert result.split_year is None + assert result.train_size + result.test_size == len(synthetic_data) + assert 0.0 <= result.accuracy <= 1.0 + assert result.auroc is None or 0.0 <= result.auroc <= 1.0 + assert result.auprc is None or 0.0 <= result.auprc <= 1.0 + assert result.brier is None or 0.0 <= result.brier <= 1.0 + assert 0.0 <= result.f1 <= 1.0 + + +def test_run_ablation(synthetic_data): + results = run_ablation(synthetic_data, split_years=[2013, 2015, 2017]) + assert len(results) == 3 + assert [result.split_year for result in results] == [2013, 2015, 2017] + + +def test_generate_synthetic_temporal_shift_data(): + data = generate_synthetic_temporal_shift_data( + n_patients_per_year=5, + start_year=2010, + end_year=2012, + seed=123, + ) + assert len(data) == 15 + assert all("year" in row and "features" in row and "label" in row for row in data) + assert all(len(row["features"]) == 4 for row in data) + years = sorted(set(row["year"] for row in data)) + assert years == [2010, 2011, 2012] \ No newline at end of file diff --git a/tests/test_temporal_risk_prediction.py b/tests/test_temporal_risk_prediction.py new file mode 100644 index 000000000..f2b498b6e --- /dev/null +++ b/tests/test_temporal_risk_prediction.py @@ -0,0 +1,98 @@ +from pyhealth.tasks.temporal_risk_prediction import TemporalMortalityMIMIC4 + + +class MockTimestamp: + def __init__(self, year): + self.year = year + + +class MockEvent: + def __init__( + self, + timestamp=None, + icd_code=None, + drug=None, + hospital_expire_flag="0", + ): + self.timestamp = timestamp + self.icd_code = icd_code + self.drug = drug + self.hospital_expire_flag = hospital_expire_flag + + +class MockPatient: + """Mock patient object for testing temporal task behavior.""" + def __init__(self, event_map): + self.event_map = event_map + + def get_events(self, table, end=None): + return self.event_map.get(table, []) + + +def test_temporal_task_builds_samples(): + """Tests that the task generates valid samples with features, year, and label.""" + patient = MockPatient( + { + "admissions": [ + MockEvent( + timestamp=MockTimestamp(2018), + hospital_expire_flag="1", + ), + MockEvent( + timestamp=MockTimestamp(2020), + hospital_expire_flag="0", + ), + ], + "diagnoses_icd": [ + MockEvent(timestamp=MockTimestamp(2017), icd_code="A"), + MockEvent(timestamp=MockTimestamp(2018), icd_code="B"), + ], + "procedures_icd": [ + MockEvent(timestamp=MockTimestamp(2018), icd_code="P1"), + ], + "prescriptions": [ + MockEvent(timestamp=MockTimestamp(2018), drug="drug_a"), + ], + } + ) + + task = TemporalMortalityMIMIC4() + samples = task(patient) + + assert len(samples) == 2 + assert all("features" in s and "year" in s and "label" in s for s in samples) + assert all(len(s["features"]) == 4 for s in samples) + assert samples[0]["year"] == [2018.0] + assert samples[0]["label"] in [0, 1] + + +def test_temporal_task_skips_missing_timestamp(): + patient = MockPatient( + { + "admissions": [ + MockEvent(timestamp=None, hospital_expire_flag="0"), + ], + "diagnoses_icd": [], + "procedures_icd": [], + "prescriptions": [], + } + ) + + task = TemporalMortalityMIMIC4() + assert task(patient) == [] + + +def test_temporal_task_respects_min_history_events(): + patient = MockPatient( + { + "admissions": [ + MockEvent(timestamp=MockTimestamp(2018), hospital_expire_flag="0"), + ], + "diagnoses_icd": [], + "procedures_icd": [], + "prescriptions": [], + } + ) + + task = TemporalMortalityMIMIC4(min_history_events=1) + assert task(patient) == [] \ No newline at end of file