Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions docs/api/tasks.rst
Original file line number Diff line number Diff line change
Expand Up @@ -230,3 +230,5 @@ Available Tasks
Mutation Pathogenicity (COSMIC) <tasks/pyhealth.tasks.MutationPathogenicityPrediction>
Cancer Survival Prediction (TCGA) <tasks/pyhealth.tasks.CancerSurvivalPrediction>
Cancer Mutation Burden (TCGA) <tasks/pyhealth.tasks.CancerMutationBurden>
Temporal Evaluation <tasks/pyhealth.tasks.temporal_evaluation>
Temporal Mortality (MIMIC-IV) <tasks/pyhealth.tasks.temporal_risk_prediction>
9 changes: 9 additions & 0 deletions docs/api/tasks/pyhealth.tasks.temporal_evaluation.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
pyhealth.tasks.temporal_evaluation
==================================

.. automodule:: pyhealth.tasks.temporal_evaluation
    :members:
    :undoc-members:
    :show-inheritance:


7 changes: 7 additions & 0 deletions docs/api/tasks/pyhealth.tasks.temporal_risk_prediction.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
pyhealth.tasks.temporal_risk_prediction
=======================================

.. automodule:: pyhealth.tasks.temporal_risk_prediction
    :members:
    :undoc-members:
    :show-inheritance:
126 changes: 126 additions & 0 deletions docs/api/tasks/temporal_risk_prediction.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,126 @@
"""Temporal risk prediction task for PyHealth datasets."""

from __future__ import annotations

from typing import Any, Dict, List, Optional

from pyhealth.data import Patient
from pyhealth.tasks import BaseTask


class TemporalMortalityMIMIC4(BaseTask):
    """Temporal mortality prediction task for MIMIC-IV style EHR data.

    This task creates one sample per admission. Each sample contains:
        1. a lightweight numeric feature vector summarizing patient history
        2. the admission year for temporal splitting
        3. a binary mortality label

    The feature vector holds, in order: the number of distinct diagnosis
    codes, the number of distinct procedure codes, the number of distinct
    drug names, and the total count of history events that carried a usable
    code or drug name.

    The task is designed to pair with temporal evaluation utilities that
    compare deployment-like temporal splits against random train/test splits.

    Attributes:
        task_name: Name of the task.
        input_schema: Schema for model inputs.
        output_schema: Schema for model outputs.
    """

    task_name: str = "TemporalMortalityMIMIC4"

    input_schema: Dict[str, str] = {
        "features": "tensor",
        "year": "tensor",
    }

    output_schema: Dict[str, str] = {
        "label": "binary",
    }

    def __init__(self, min_history_events: int = 1) -> None:
        """Initializes the task.

        Args:
            min_history_events: Minimum number of historical events required
                to emit a sample.
        """
        # Let the base class run whatever setup it defines, if any.
        super().__init__()
        self.min_history_events = min_history_events

    def _safe_year(self, event: Any) -> Optional[int]:
        """Extracts the year from an event timestamp.

        Args:
            event: Event object with a timestamp attribute.

        Returns:
            The year if available, otherwise None.
        """
        timestamp = getattr(event, "timestamp", None)
        if timestamp is None:
            return None
        return int(timestamp.year)

    @staticmethod
    def _collect_values(events: Any, attribute: str) -> List[Any]:
        """Gathers the non-missing values of one attribute across events.

        Args:
            events: Iterable of event objects.
            attribute: Name of the attribute to read from each event.

        Returns:
            The attribute values in event order, skipping events where the
            attribute is absent or None.
        """
        values = (getattr(event, attribute, None) for event in events)
        return [value for value in values if value is not None]

    @staticmethod
    def _mortality_label(admission: Any) -> int:
        """Converts an admission's ``hospital_expire_flag`` to a 0/1 label.

        The flag may arrive as a string or as a number depending on how the
        table was loaded, so both forms are accepted instead of only the
        exact string ``"1"``.

        Args:
            admission: Admission event, optionally carrying the flag.

        Returns:
            1 when the flag marks an in-hospital death, otherwise 0
            (including when the flag is missing).
        """
        flag = getattr(admission, "hospital_expire_flag", "0")
        if isinstance(flag, str):
            return int(flag.strip() in ("1", "1.0"))
        # Numeric and boolean flags: 1, 1.0 and True all compare equal to 1.
        return int(flag == 1)

    def __call__(self, patient: Patient) -> List[Dict[str, Any]]:
        """Processes one patient into temporal prediction samples.

        Admissions without a timestamp, or with fewer than
        ``min_history_events`` usable history events, produce no sample.

        Args:
            patient: A PyHealth Patient object.

        Returns:
            A list of sample dictionaries with keys ``features``, ``year``
            and ``label``.
        """
        samples: List[Dict[str, Any]] = []

        for admission in patient.get_events("admissions"):
            year = self._safe_year(admission)
            if year is None:
                continue

            # History queries are bounded by this admission's own timestamp.
            end_time = admission.timestamp

            diagnosis_codes = self._collect_values(
                patient.get_events("diagnoses_icd", end=end_time), "icd_code"
            )
            procedure_codes = self._collect_values(
                patient.get_events("procedures_icd", end=end_time), "icd_code"
            )
            drug_names = self._collect_values(
                patient.get_events("prescriptions", end=end_time), "drug"
            )

            history_count = (
                len(diagnosis_codes) + len(procedure_codes) + len(drug_names)
            )
            if history_count < self.min_history_events:
                continue

            features = [
                float(len(set(diagnosis_codes))),
                float(len(set(procedure_codes))),
                float(len(set(drug_names))),
                float(history_count),
            ]

            samples.append(
                {
                    "features": features,
                    "year": [float(year)],
                    "label": self._mortality_label(admission),
                }
            )

        return samples
55 changes: 55 additions & 0 deletions examples/mimic4_temporal_mortality_logistic.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,55 @@
"""Example: temporal vs random evaluation on MIMIC-IV task samples."""

from pyhealth.datasets import MIMIC4EHRDataset
from pyhealth.tasks.temporal_evaluation import (
run_random_experiment,
run_temporal_experiment,
sample_dataset_to_temporal_records,
)
from pyhealth.tasks.temporal_risk_prediction import TemporalMortalityMIMIC4


def safe_metric(value):
    """Format a metric with three decimals, or the text ``None`` if missing."""
    return "None" if value is None else f"{value:.3f}"


def main() -> None:
    """Run a temporal and a random split on MIMIC-IV samples and print both."""
    ehr = MIMIC4EHRDataset(
        root="YOUR_DATA_ROOT",
        tables=["admissions", "diagnoses_icd", "procedures_icd", "prescriptions"],
        dev=True,
    )

    task_samples = ehr.set_task(TemporalMortalityMIMIC4())
    records = sample_dataset_to_temporal_records(task_samples)

    # Both experiments are run before anything is printed.
    reports = (
        ("=== Temporal split ===", run_temporal_experiment(records, split_year=2017)),
        ("\n=== Random split ===", run_random_experiment(records, random_state=42)),
    )

    for heading, outcome in reports:
        print(heading)
        fields = [
            f"train={outcome.train_size}",
            f"test={outcome.test_size}",
            f"accuracy={outcome.accuracy:.3f}",
            f"auroc={safe_metric(outcome.auroc)}",
            f"auprc={safe_metric(outcome.auprc)}",
            f"brier={safe_metric(outcome.brier)}",
            f"f1={outcome.f1:.3f}",
        ]
        print(" | ".join(fields))


if __name__ == "__main__":
main()
66 changes: 66 additions & 0 deletions examples/synthetic_temporal_mortality_logistic.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,66 @@
"""Example: temporal vs random evaluation on synthetic clinical data.

This example demonstrates the core EMDOT-style idea:
train on earlier years and test on later years, then compare against
a random split baseline.

It is intentionally synthetic and lightweight so it runs anywhere.
"""

from pyhealth.tasks.temporal_evaluation import (
run_ablation,
run_random_experiment,
)


# Hand-written toy cohort: one patient per year, 2010 through 2021.
# Each row is (first feature, second feature, label); patient ids run 1..12
# and the year is derived from the patient id.
_SYNTHETIC_ROWS = (
    (0.20, 0.10, 0),
    (0.25, 0.15, 0),
    (0.70, 0.60, 1),
    (0.40, 0.35, 0),
    (0.55, 0.45, 1),
    (0.60, 0.50, 1),
    (0.65, 0.55, 1),
    (0.35, 0.25, 0),
    (0.75, 0.60, 1),
    (0.15, 0.10, 0),
    (0.82, 0.70, 1),
    (0.18, 0.12, 0),
)

SYNTHETIC_DATA = [
    {
        "patient_id": index,
        "year": 2009 + index,
        "features": [first, second],
        "label": outcome,
    }
    for index, (first, second, outcome) in enumerate(_SYNTHETIC_ROWS, start=1)
]


def safe_metric(value):
    """Return ``value`` rendered to three decimals; missing values give ``"None"``."""
    if value is not None:
        return format(value, ".3f")
    return "None"


def main() -> None:
    """Print the temporal ablation results, then the random-split baseline."""

    def _metrics_line(outcome) -> str:
        # Shared tail of every report line, from train size through F1.
        return " | ".join(
            [
                f"train={outcome.train_size}",
                f"test={outcome.test_size}",
                f"accuracy={outcome.accuracy:.3f}",
                f"auroc={safe_metric(outcome.auroc)}",
                f"auprc={safe_metric(outcome.auprc)}",
                f"brier={safe_metric(outcome.brier)}",
                f"f1={outcome.f1:.3f}",
            ]
        )

    print("=== Temporal ablation ===")
    ablation = run_ablation(SYNTHETIC_DATA, split_years=[2013, 2015, 2017])
    for outcome in ablation:
        print(f"split_year={outcome.split_year} | " + _metrics_line(outcome))

    print("\n=== Random baseline ===")
    baseline = run_random_experiment(SYNTHETIC_DATA, random_state=42)
    print(_metrics_line(baseline))


if __name__ == "__main__":
main()
86 changes: 86 additions & 0 deletions examples/synthetic_temporal_shift_demo.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,86 @@
"""Demo: temporal evaluation under realistic synthetic clinical shift.

This example generates synthetic clinical data with temporal drift and compares:
1. multiple temporal cutoffs
2. a random split baseline

The goal is to simulate the EMDOT idea more realistically than a tiny hand-made
toy dataset, while still remaining lightweight and reproducible.
"""

from pyhealth.tasks.temporal_evaluation import (
generate_synthetic_temporal_shift_data,
run_ablation,
run_random_experiment,
)


def format_result(prefix: str, result) -> str:
    """Build a one-line, pipe-separated summary of an experiment result.

    Args:
        prefix: Text placed at the start of the line.
        result: Object exposing ``train_size``, ``test_size``, ``accuracy``,
            ``auroc``, ``auprc``, ``brier`` and ``f1`` attributes. The
            ``auroc``, ``auprc`` and ``brier`` values may be None.

    Returns:
        The formatted summary, with numeric metrics shown to three decimals
        and missing optional metrics shown as ``None``.
    """

    def _optional(value) -> str:
        # A conditional cannot live inside an f-string format spec, so the
        # None check is done here, before any formatting is applied.
        return "None" if value is None else f"{value:.3f}"

    return (
        f"{prefix} | "
        f"train={result.train_size} | "
        f"test={result.test_size} | "
        f"accuracy={result.accuracy:.3f} | "
        f"auroc={_optional(result.auroc)} | "
        f"auprc={_optional(result.auprc)} | "
        f"brier={_optional(result.brier)} | "
        f"f1={result.f1:.3f}"
    )


def safe_metric(value):
    """Give a three-decimal string for a metric, or ``"None"`` when it is None."""
    missing = value is None
    return "None" if missing else f"{value:.3f}"


def main() -> None:
    """Generate drifting synthetic data and compare temporal and random splits."""

    def _metrics_line(outcome) -> str:
        # Shared tail of every report line, from train size through F1.
        return " | ".join(
            [
                f"train={outcome.train_size}",
                f"test={outcome.test_size}",
                f"accuracy={outcome.accuracy:.3f}",
                f"auroc={safe_metric(outcome.auroc)}",
                f"auprc={safe_metric(outcome.auprc)}",
                f"brier={safe_metric(outcome.brier)}",
                f"f1={outcome.f1:.3f}",
            ]
        )

    dataset = generate_synthetic_temporal_shift_data(
        n_patients_per_year=40,
        start_year=2010,
        end_year=2021,
        seed=42,
    )

    print("=== Synthetic temporal shift dataset ===")
    print(f"Total records: {len(dataset)}")
    print(f"Years: {dataset[0]['year']} to {dataset[-1]['year']}")
    print()

    print("=== Temporal ablation ===")
    ablation = run_ablation(dataset, split_years=[2013, 2015, 2017, 2019])
    for outcome in ablation:
        print(f"split_year={outcome.split_year} | " + _metrics_line(outcome))

    print()
    print("=== Random baseline ===")
    baseline = run_random_experiment(dataset, random_state=42)
    print(_metrics_line(baseline))

    print()
    print("Interpretation:")
    print(
        "This synthetic dataset introduces temporal drift, so temporal splits "
        "better reflect future deployment conditions than random splits."
    )


if __name__ == "__main__":
main()
1 change: 1 addition & 0 deletions pyhealth/tasks/__init__.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
from .temporal_evaluation import *
from .base_task import BaseTask
from .benchmark_ehrshot import BenchmarkEHRShot
from .cancer_survival import CancerMutationBurden, CancerSurvivalPrediction
Expand Down
Loading