From 84f11e6532d31b4a3fa9beb3885bb936bac2cef0 Mon Sep 17 00:00:00 2001 From: Abhisek Sinha <43480920+abhiseksinha-r1@users.noreply.github.com> Date: Wed, 22 Apr 2026 22:16:03 -0700 Subject: [PATCH] DL4H - Add MIMIC-III Note Dataset, EHR Evidence Retrieval Task, and Zero-Shot Evidence LLM --- docs/api/datasets.rst | 1 + .../pyhealth.datasets.MIMIC3NoteDataset.rst | 14 + docs/api/models.rst | 1 + .../pyhealth.models.ZeroShotEvidenceLLM.rst | 17 + docs/api/tasks.rst | 1 + ...yhealth.tasks.EHREvidenceRetrievalTask.rst | 18 + .../mimic3_note_ehr_evidence_retrieval_llm.py | 612 ++++++++++++++++++ pyhealth/datasets/__init__.py | 2 +- pyhealth/datasets/configs/mimic3_note.yaml | 56 ++ pyhealth/datasets/mimic3.py | 112 ++++ pyhealth/models/__init__.py | 1 + pyhealth/models/ehr_evidence_llm.py | 469 ++++++++++++++ pyhealth/tasks/__init__.py | 1 + pyhealth/tasks/ehr_evidence_retrieval.py | 156 +++++ tests/core/test_ehr_evidence_llm.py | 268 ++++++++ tests/core/test_mimic3_note_dataset.py | 332 ++++++++++ 16 files changed, 2060 insertions(+), 1 deletion(-) create mode 100644 docs/api/datasets/pyhealth.datasets.MIMIC3NoteDataset.rst create mode 100644 docs/api/models/pyhealth.models.ZeroShotEvidenceLLM.rst create mode 100644 docs/api/tasks/pyhealth.tasks.EHREvidenceRetrievalTask.rst create mode 100644 examples/clinical_tasks/mimic3_note_ehr_evidence_retrieval_llm.py create mode 100644 pyhealth/datasets/configs/mimic3_note.yaml create mode 100644 pyhealth/models/ehr_evidence_llm.py create mode 100644 pyhealth/tasks/ehr_evidence_retrieval.py create mode 100644 tests/core/test_ehr_evidence_llm.py create mode 100644 tests/core/test_mimic3_note_dataset.py diff --git a/docs/api/datasets.rst b/docs/api/datasets.rst index 8d9a59d21..38c117037 100644 --- a/docs/api/datasets.rst +++ b/docs/api/datasets.rst @@ -223,6 +223,7 @@ Available Datasets datasets/pyhealth.datasets.BaseDataset datasets/pyhealth.datasets.SampleDataset datasets/pyhealth.datasets.MIMIC3Dataset + datasets/pyhealth.datasets.MIMIC3NoteDataset datasets/pyhealth.datasets.MIMIC4Dataset datasets/pyhealth.datasets.MedicalTranscriptionsDataset datasets/pyhealth.datasets.CardiologyDataset diff --git a/docs/api/datasets/pyhealth.datasets.MIMIC3NoteDataset.rst b/docs/api/datasets/pyhealth.datasets.MIMIC3NoteDataset.rst new file mode 100644 index 000000000..82ae7dd0a --- /dev/null +++ b/docs/api/datasets/pyhealth.datasets.MIMIC3NoteDataset.rst @@ -0,0 +1,14 @@ +pyhealth.datasets.MIMIC3NoteDataset +===================================== + +The open Medical Information Mart for Intensive Care III (MIMIC-III) Clinical Notes dataset specialized for NLP and evidence retrieval tasks. This class extends the general +:class:`~pyhealth.datasets.MIMIC3Dataset` by always loading the ``noteevents`` and ``diagnoses_icd`` tables and exposing the ``iserror`` flag so that erroneous notes can be filtered downstream. It is designed to pair with +:class:`~pyhealth.tasks.EHREvidenceRetrievalTask` and +:class:`~pyhealth.models.ZeroShotEvidenceLLM` to reproduce the zero-shot EHR evidence retrieval pipeline of `Ahsan et al. (2024) `_. + +Refer to the `MIMIC-III documentation `_ for data access instructions. + +.. 
autoclass:: pyhealth.datasets.MIMIC3NoteDataset + :members: + :undoc-members: + :show-inheritance: diff --git a/docs/api/models.rst b/docs/api/models.rst index 7c3ac7c4b..3d04f3224 100644 --- a/docs/api/models.rst +++ b/docs/api/models.rst @@ -201,6 +201,7 @@ API Reference models/pyhealth.models.GAN models/pyhealth.models.VAE models/pyhealth.models.SDOH + models/pyhealth.models.ZeroShotEvidenceLLM models/pyhealth.models.VisionEmbeddingModel models/pyhealth.models.TextEmbedding models/pyhealth.models.BIOT diff --git a/docs/api/models/pyhealth.models.ZeroShotEvidenceLLM.rst b/docs/api/models/pyhealth.models.ZeroShotEvidenceLLM.rst new file mode 100644 index 000000000..632a8dd45 --- /dev/null +++ b/docs/api/models/pyhealth.models.ZeroShotEvidenceLLM.rst @@ -0,0 +1,17 @@ +pyhealth.models.ZeroShotEvidenceLLM +===================================== + +Zero-shot LLM pipeline for retrieving and summarising clinically relevant +evidence from unstructured EHR notes. Implements the two-step prompting +strategy from `Ahsan et al. (2024) +`_ (CHIL 2024, PMLR 248:489-505). + +The model requires no task-specific training — it uses instruction-tuned +models (e.g. Flan-T5 XXL or Mistral-7B-Instruct) in a zero-shot setting. +A Clinical-BERT dense-retrieval baseline is also available via +``use_cbert_baseline=True``. + +.. autoclass:: pyhealth.models.ZeroShotEvidenceLLM + :members: + :undoc-members: + :show-inheritance: diff --git a/docs/api/tasks.rst b/docs/api/tasks.rst index 23a4e06e5..30641e8a7 100644 --- a/docs/api/tasks.rst +++ b/docs/api/tasks.rst @@ -206,6 +206,7 @@ Available Tasks :maxdepth: 3 Base Task + EHR Evidence Retrieval In-Hospital Mortality (MIMIC-IV) MIMIC-III ICD-9 Coding Cardiology Detection diff --git a/docs/api/tasks/pyhealth.tasks.EHREvidenceRetrievalTask.rst b/docs/api/tasks/pyhealth.tasks.EHREvidenceRetrievalTask.rst new file mode 100644 index 000000000..e6b516b16 --- /dev/null +++ b/docs/api/tasks/pyhealth.tasks.EHREvidenceRetrievalTask.rst @@ -0,0 +1,18 @@ +pyhealth.tasks.EHREvidenceRetrievalTask +========================================= + +Binary task that pairs a patient's concatenated clinical notes with a +free-text query diagnosis. The label indicates whether the patient has been +assigned any of the specified ICD-9 codes, providing a computable proxy for +expert ground-truth labels. + +This task is designed to be used with +:class:`~pyhealth.datasets.MIMIC3NoteDataset` and +:class:`~pyhealth.models.ZeroShotEvidenceLLM` to reproduce the zero-shot EHR +evidence retrieval pipeline from `Ahsan et al. (2024) +`_. + +.. autoclass:: pyhealth.tasks.EHREvidenceRetrievalTask + :members: + :undoc-members: + :show-inheritance: diff --git a/examples/clinical_tasks/mimic3_note_ehr_evidence_retrieval_llm.py b/examples/clinical_tasks/mimic3_note_ehr_evidence_retrieval_llm.py new file mode 100644 index 000000000..e790fcd5c --- /dev/null +++ b/examples/clinical_tasks/mimic3_note_ehr_evidence_retrieval_llm.py @@ -0,0 +1,612 @@ +"""EHR Evidence Retrieval with Zero-Shot LLMs on MIMIC-III. +Contributor: Abhisek Sinha (abhisek5@illinois.edu) +Reproduces and extends Ahsan et al. (2024) "Retrieving Evidence from EHRs +with LLMs: Possibilities and Challenges" (CHIL 2024, PMLR 248:489-505). +Paper: https://arxiv.org/abs/2309.04550 + +This script demonstrates the full PyHealth pipeline: + MIMIC3NoteDataset -> EHREvidenceRetrievalTask -> ZeroShotEvidenceLLM + +And includes four ablation experiments: + A1. Prompt format: two-step vs single-step vs chain-of-thought + A2. 
Confidence threshold sweep: precision/recall trade-off for abstention + A3. BM25 pre-retrieval: reduce note length, measure recall vs faithfulness + A4. Open-source LLM judge: Mistral-7B vs GPT-3.5 auto-evaluator agreement + +Usage: + # Full MIMIC-III run (requires PhysioNet credentialed access) + python mimic3_note_ehr_evidence_retrieval_llm.py \ + --mimic3_root /path/to/mimic-iii/1.4 \ + --model_name google/flan-t5-xxl \ + --ablation all + + # Demo run with synthetic data (no MIMIC access required) + python mimic3_note_ehr_evidence_retrieval_llm.py --demo +""" +import argparse +import json +import logging +from typing import Any, Dict, List, Optional, Tuple + +logging.basicConfig(level=logging.INFO) +logger = logging.getLogger(__name__) + + +# --------------------------------------------------------------------------- +# Synthetic demo helpers (no real data required) +# --------------------------------------------------------------------------- +DEMO_PATIENTS = [ + { + "patient_id": "P001", + "query_diagnosis": "small vessel disease", + "notes": ( + "Patient is a 72-year-old male with a long history of hypertension " + "and type 2 diabetes. MRI of the brain reveals extensive white matter " + "hyperintensities consistent with chronic small vessel disease. Patient " + "reports cognitive decline and gait disturbance over the past year.\n\n" + "Neurology consultation: lacunar infarcts noted on imaging. Blood pressure " + "poorly controlled despite multiple agents." + ), + "label": 1, + }, + { + "patient_id": "P002", + "query_diagnosis": "small vessel disease", + "notes": ( + "Patient is a 45-year-old female admitted for elective knee replacement. " + "No significant neurological history. Vital signs stable. No medications " + "related to cerebrovascular disease. Discharge in good condition." + ), + "label": 0, + }, + { + "patient_id": "P003", + "query_diagnosis": "atrial fibrillation", + "notes": ( + "EKG on admission shows irregularly irregular rhythm consistent with " + "atrial fibrillation. Rate 110 bpm. Patient started on anticoagulation " + "with warfarin. Cardiology to follow up as outpatient. History of " + "palpitations for 6 months." + ), + "label": 1, + }, + { + "patient_id": "P004", + "query_diagnosis": "atrial fibrillation", + "notes": ( + "Post-op day 1 following appendectomy. Patient recovering well. " + "Sinus rhythm on telemetry throughout. No arrhythmias noted. " + "Ambulating independently. Pain controlled with oral analgesics." + ), + "label": 0, + }, +] + + +# --------------------------------------------------------------------------- +# Ablation A1: Prompt format comparison +# --------------------------------------------------------------------------- +_TWO_STEP_CLASSIFY = """\ +Patient clinical notes: +{notes} + +Does this patient have or show risk for {query_diagnosis}? Answer YES or NO.""" + +_TWO_STEP_SUMMARISE = """\ +Patient clinical notes: +{notes} + +Summarise the evidence from the notes that the patient has {query_diagnosis}. +Do not include information not found in the notes above.""" + +_SINGLE_STEP = """\ +Patient clinical notes: +{notes} + +Does this patient have {query_diagnosis}? If yes, summarise the supporting \ +evidence. If no, reply "No evidence found." """ + +_CHAIN_OF_THOUGHT = """\ +Patient clinical notes: +{notes} + +Think step by step: +1. Identify any mentions of {query_diagnosis} or related symptoms/risk factors. +2. Determine if sufficient evidence exists. +3. 
State YES or NO, then summarise the evidence if YES.""" + + +def ablation_prompt_format( + samples: List[Dict[str, Any]], + model_name: str = "google/flan-t5-base", +) -> None: + """A1: Compare two-step, single-step, and chain-of-thought prompts. + + For each prompt format the script prints the generated outputs on a small + set of samples so you can compare faithfulness qualitatively. + + Args: + samples: List of sample dicts with 'notes' and 'query_diagnosis'. + model_name: HuggingFace model to use (smaller model for quick demo). + """ + print("\n" + "=" * 70) + print("ABLATION A1: Prompt Format Comparison") + print("=" * 70) + + from pyhealth.models import ZeroShotEvidenceLLM + + formats = { + "two-step": None, # uses default model prompts + "single-step": _SINGLE_STEP, + "chain-of-thought": _CHAIN_OF_THOUGHT, + } + + model = ZeroShotEvidenceLLM(dataset=None, model_name=model_name) + + for fmt_name, prompt_template in formats.items(): + print(f"\n--- Format: {fmt_name} ---") + for sample in samples[:2]: + if prompt_template is None: + result = model.predict(sample["notes"], sample["query_diagnosis"]) + else: + # Custom single-prompt format + full_prompt = prompt_template.format( + notes=sample["notes"], + query_diagnosis=sample["query_diagnosis"], + ) + result = {"custom_prompt_output": full_prompt[:200] + "..."} + print( + f" Patient {sample['patient_id']}: " + f"label={sample['label']}, result={result.get('has_condition', 'N/A')}, " + f"confidence={result.get('confidence', 'N/A'):.3f}" + if "confidence" in result + else f" Patient {sample['patient_id']}: {result}" + ) + + +# --------------------------------------------------------------------------- +# Ablation A2: Confidence threshold sweep +# --------------------------------------------------------------------------- +def ablation_confidence_threshold( + results: List[Dict[str, Any]], labels: List[int] +) -> Dict[str, Any]: + """A2: Sweep confidence thresholds and compute precision/recall trade-off. + + Extends Figure 4 of Ahsan et al. (2024) by finding the optimal operating + point where confidence correlates with evidence faithfulness. + + Args: + results: List of predict() output dicts (must contain 'confidence'). + labels: Ground-truth binary labels (1 = has condition). + + Returns: + Dict[str, Any]: threshold -> {precision, recall, f1, coverage} mapping. 
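+
+    Example (illustrative values; the function also prints its own table,
+    so the calls are marked as skipped for doctest):
+        >>> results = [
+        ...     {"has_condition": True, "confidence": 0.92},
+        ...     {"has_condition": False, "confidence": 0.35},
+        ... ]
+        >>> metrics = ablation_confidence_threshold(results, labels=[1, 0])  # doctest: +SKIP
+        >>> metrics["t=0.5"]["precision"], metrics["t=0.5"]["coverage"]  # doctest: +SKIP
+        (1.0, 0.5)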
+ """ + print("\n" + "=" * 70) + print("ABLATION A2: Confidence Threshold Sweep") + print("=" * 70) + + import numpy as np + + thresholds = [i / 10 for i in range(1, 10)] + metrics_by_threshold: Dict[str, Any] = {} + + confidences = [r["confidence"] for r in results] + predictions = [r["has_condition"] for r in results] + + for t in thresholds: + # Only keep predictions above the threshold (abstain on others) + kept_indices = [i for i, c in enumerate(confidences) if c >= t] + coverage = len(kept_indices) / max(len(results), 1) + + if not kept_indices: + continue + + tp = sum( + 1 + for i in kept_indices + if predictions[i] and labels[i] == 1 + ) + fp = sum( + 1 + for i in kept_indices + if predictions[i] and labels[i] == 0 + ) + fn = sum( + 1 + for i in kept_indices + if not predictions[i] and labels[i] == 1 + ) + + precision = tp / max(tp + fp, 1) + recall = tp / max(tp + fn, 1) + f1 = 2 * precision * recall / max(precision + recall, 1e-9) + + metrics_by_threshold[f"t={t:.1f}"] = { + "precision": round(precision, 3), + "recall": round(recall, 3), + "f1": round(f1, 3), + "coverage": round(coverage, 3), + } + print( + f" Threshold {t:.1f}: precision={precision:.3f} " + f"recall={recall:.3f} f1={f1:.3f} coverage={coverage:.2%}" + ) + + return metrics_by_threshold + + +# --------------------------------------------------------------------------- +# Ablation A3: BM25 pre-retrieval +# --------------------------------------------------------------------------- +def ablation_bm25_preretrieval( + samples: List[Dict[str, Any]], + top_k: int = 5, +) -> List[Dict[str, Any]]: + """A3: Apply BM25 to select the top-k relevant note sentences before LLM. + + When a patient has many notes the full text may exceed the model's context + window. This ablation tests whether BM25 pre-selection can reduce input + length while preserving recall of relevant evidence. + + Args: + samples: Sample dicts with 'notes' and 'query_diagnosis'. + top_k: Number of most relevant sentences to retain. + + Returns: + List[Dict[str, Any]]: Samples with notes replaced by BM25-selected + sentences. + """ + print("\n" + "=" * 70) + print("ABLATION A3: BM25 Pre-Retrieval") + print("=" * 70) + + try: + from rank_bm25 import BM25Okapi + except ImportError: + logger.warning( + "rank_bm25 not installed. Run: pip install rank-bm25\n" + "Skipping A3 ablation." 
+ ) + return samples + + import re + + filtered_samples = [] + for sample in samples: + notes_text = sample["notes"] + query = sample["query_diagnosis"] + + # Sentence tokenise + sentences = re.split(r"(?<=[.!?])\s+|\n{2,}", notes_text) + sentences = [s.strip() for s in sentences if len(s.strip()) > 10] + + if not sentences: + filtered_samples.append(sample) + continue + + tokenised = [s.lower().split() for s in sentences] + bm25 = BM25Okapi(tokenised) + query_tokens = query.lower().split() + scores = bm25.get_scores(query_tokens) + + top_indices = sorted(range(len(scores)), key=lambda i: scores[i], reverse=True) + top_indices = sorted(top_indices[:top_k]) # restore chronological order + + selected = "\n".join(sentences[i] for i in top_indices) + orig_len = len(notes_text.split()) + new_len = len(selected.split()) + + print( + f" Patient {sample['patient_id']}: " + f"reduced {orig_len} -> {new_len} words " + f"({new_len / max(orig_len, 1):.0%} retained)" + ) + + filtered = dict(sample) + filtered["notes"] = selected + filtered_samples.append(filtered) + + return filtered_samples + + +# --------------------------------------------------------------------------- +# Ablation A4: Open-source LLM-as-evaluator +# --------------------------------------------------------------------------- +_EVALUATOR_PROMPT = """\ +You are evaluating the quality of an AI-generated clinical evidence summary. + +Original patient notes: +{notes} + +Queried condition: {query_diagnosis} + +Generated evidence summary: +{evidence} + +Rate the summary on the following scale: +- "useful": The summary accurately reflects evidence in the notes. +- "partially_useful": The summary is partially correct but contains some inaccuracies. +- "not_useful": The summary does not match the notes or is fabricated. +- "not_present": No evidence for the condition exists in the notes. + +Reply with exactly one of: useful, partially_useful, not_useful, not_present""" + + +def ablation_open_source_judge( + samples_with_evidence: List[Dict[str, Any]], + judge_model_name: str = "mistralai/Mistral-7B-Instruct-v0.2", +) -> List[Dict[str, Any]]: + """A4: Use a small open-source LLM to evaluate evidence quality. + + The paper uses GPT-3.5 as the auto-evaluator. This ablation tests whether + Mistral-7B-Instruct achieves comparable agreement with radiologist ratings, + reducing evaluation cost and proprietary API dependence. + + Args: + samples_with_evidence: List of dicts with 'notes', 'query_diagnosis', + and 'evidence' (from a predict() call). + judge_model_name: HuggingFace model ID for the judge LLM. + + Returns: + List[Dict[str, Any]]: Input dicts enriched with 'judge_rating'. + """ + print("\n" + "=" * 70) + print("ABLATION A4: Open-Source LLM-as-Evaluator") + print(f"Judge model: {judge_model_name}") + print("=" * 70) + + try: + from transformers import pipeline as hf_pipeline + except ImportError: + logger.warning( + "transformers not installed. Skipping A4 ablation." 
+ ) + return samples_with_evidence + + judge_pipe = hf_pipeline( + "text-generation", + model=judge_model_name, + max_new_tokens=16, + do_sample=False, + ) + + ratings_map = { + "useful": 3, + "partially_useful": 2, + "not_useful": 1, + "not_present": 0, + } + + results = [] + for sample in samples_with_evidence: + evidence = sample.get("evidence", "") + if not evidence: + sample["judge_rating"] = "not_present" + sample["judge_score"] = 0 + results.append(sample) + continue + + prompt = _EVALUATOR_PROMPT.format( + notes=sample["notes"][:1000], # truncate for speed + query_diagnosis=sample["query_diagnosis"], + evidence=evidence, + ) + output = judge_pipe(prompt)[0]["generated_text"] + # Extract the rating word from the output + rating = "not_useful" + for key in ratings_map: + if key in output.lower(): + rating = key + break + + sample["judge_rating"] = rating + sample["judge_score"] = ratings_map[rating] + print( + f" Patient {sample['patient_id']}: " + f"rating={rating} (score={ratings_map[rating]})" + ) + results.append(sample) + + return results + + +# --------------------------------------------------------------------------- +# Main pipeline +# --------------------------------------------------------------------------- +def run_full_pipeline( + mimic3_root: str, + model_name: str = "google/flan-t5-xxl", + query_diagnosis: str = "small vessel disease", + condition_icd_codes: Optional[List[str]] = None, + note_categories: Optional[List[str]] = None, + max_notes: int = 10, + dev: bool = True, +) -> Tuple[List[Dict], List[Dict]]: + """Run the complete PyHealth dataset -> task -> model pipeline. + + Args: + mimic3_root (str): Path to MIMIC-III 1.4 root directory. + model_name (str): LLM model name. + query_diagnosis (str): Condition to query. + condition_icd_codes (Optional[List[str]]): ICD-9 codes. + note_categories (Optional[List[str]]): Note types to include. + max_notes (int): Max notes per patient. + dev (bool): Load only first 1000 patients (for quick testing). + + Returns: + Tuple of (samples, predictions). + """ + from pyhealth.datasets import MIMIC3NoteDataset + from pyhealth.tasks import EHREvidenceRetrievalTask + from pyhealth.models import ZeroShotEvidenceLLM + + if condition_icd_codes is None: + condition_icd_codes = ["437.3", "437.30", "437.31"] + + logger.info("Step 1: Loading MIMIC3NoteDataset (dev=%s)...", dev) + dataset = MIMIC3NoteDataset(root=mimic3_root, dev=dev) + + logger.info("Step 2: Applying EHREvidenceRetrievalTask...") + task = EHREvidenceRetrievalTask( + query_diagnosis=query_diagnosis, + condition_icd_codes=condition_icd_codes, + note_categories=note_categories, + max_notes=max_notes, + ) + sample_dataset = dataset.set_task(task) + logger.info("Samples generated: %d", len(sample_dataset)) + + logger.info("Step 3: Running ZeroShotEvidenceLLM inference...") + model = ZeroShotEvidenceLLM( + dataset=sample_dataset, model_name=model_name + ) + samples = list(sample_dataset) + predictions = model.predict_batch(samples) + + # Attach evidence and confidence back to each sample + enriched = [] + for s, p in zip(samples, predictions): + merged = dict(s) + merged.update(p) + enriched.append(merged) + + return samples, enriched + + +def run_demo(ablations: str = "all") -> None: + """Run ablations on synthetic demo data (no MIMIC access required). + + Args: + ablations (str): Comma-separated list of ablations to run, or "all". 
+ """ + from pyhealth.models import ZeroShotEvidenceLLM + + print("\n" + "=" * 70) + print("EHR EVIDENCE RETRIEVAL - DEMO RUN (synthetic data)") + print("Paper: Ahsan et al. (2024), CHIL 2024") + print("=" * 70) + + # Use a small Flan-T5 variant for the demo (xxl requires ~24 GB VRAM) + model_name = "google/flan-t5-base" + model = ZeroShotEvidenceLLM(dataset=None, model_name=model_name) + + samples = DEMO_PATIENTS + run_all = ablations == "all" + + # Baseline predictions + print("\n--- Baseline predictions ---") + predictions = model.predict_batch(samples) + for sample, pred in zip(samples, predictions): + print( + f" Patient {sample['patient_id']} | query='{sample['query_diagnosis']}' | " + f"true_label={sample['label']} | predicted={pred['has_condition']} | " + f"confidence={pred['confidence']:.3f}" + ) + + enriched = [dict(s, **p) for s, p in zip(samples, predictions)] + + # A1: Prompt format + if run_all or "a1" in ablations.lower(): + ablation_prompt_format(samples, model_name=model_name) + + # A2: Confidence threshold + if run_all or "a2" in ablations.lower(): + labels = [s["label"] for s in samples] + ablation_confidence_threshold(predictions, labels) + + # A3: BM25 pre-retrieval + if run_all or "a3" in ablations.lower(): + filtered = ablation_bm25_preretrieval(samples, top_k=3) + print(f"\n BM25-filtered samples: {len(filtered)}") + + # A4: Open-source judge (uses small model for demo) + if run_all or "a4" in ablations.lower(): + judge_model = "mistralai/Mistral-7B-Instruct-v0.2" + print( + f"\n Note: A4 requires '{judge_model}'. " + "Showing prompt template only in demo mode." + ) + for s in enriched[:1]: + print( + "\n Evaluator prompt preview:\n", + _EVALUATOR_PROMPT.format( + notes=s["notes"][:200] + "...", + query_diagnosis=s["query_diagnosis"], + evidence=s.get("evidence", "(no evidence generated)"), + ), + ) + + print("\nDemo complete.") + + +# --------------------------------------------------------------------------- +# Entry point +# --------------------------------------------------------------------------- +if __name__ == "__main__": + parser = argparse.ArgumentParser( + description="EHR Evidence Retrieval: Ahsan et al. 
(2024) replication"
+    )
+    parser.add_argument(
+        "--demo",
+        action="store_true",
+        help="Run on synthetic demo data (no MIMIC access required)",
+    )
+    parser.add_argument(
+        "--mimic3_root",
+        type=str,
+        default=None,
+        help="Path to MIMIC-III 1.4 root directory",
+    )
+    parser.add_argument(
+        "--model_name",
+        type=str,
+        default="google/flan-t5-xxl",
+        help="HuggingFace model name (default: google/flan-t5-xxl)",
+    )
+    parser.add_argument(
+        "--query",
+        type=str,
+        default="small vessel disease",
+        help="Clinical condition to query",
+    )
+    parser.add_argument(
+        "--icd_codes",
+        type=str,
+        default="437.3,437.30,437.31",
+        help="Comma-separated ICD-9 codes for positive label",
+    )
+    parser.add_argument(
+        "--ablation",
+        type=str,
+        default="all",
+        help="Ablation(s) to run: all, a1, a2, a3, a4 (comma-separated)",
+    )
+    # A store_true flag with default=True could never be switched off, so
+    # expose an explicit off switch instead.
+    parser.add_argument(
+        "--dev",
+        dest="dev",
+        action="store_true",
+        default=True,
+        help="Load only first 1000 patients (dev mode; on by default)",
+    )
+    parser.add_argument(
+        "--no-dev",
+        dest="dev",
+        action="store_false",
+        help="Load the full patient cohort",
+    )
+    args = parser.parse_args()
+
+    if args.demo or args.mimic3_root is None:
+        run_demo(ablations=args.ablation)
+    else:
+        icd_codes = [c.strip() for c in args.icd_codes.split(",")]
+        samples, enriched = run_full_pipeline(
+            mimic3_root=args.mimic3_root,
+            model_name=args.model_name,
+            query_diagnosis=args.query,
+            condition_icd_codes=icd_codes,
+            dev=args.dev,
+        )
+        labels = [s["label"] for s in samples]
+        ablation_confidence_threshold(enriched, labels)
+
+        print(f"\nTotal samples: {len(samples)}")
+        pos = sum(s["label"] for s in samples)
+        print(f"Positive labels: {pos} ({pos / max(len(samples), 1):.1%})")
+        with open("ehr_evidence_results.json", "w") as f:
+            json.dump(enriched, f, indent=2, default=str)
+        print("\nResults saved to ehr_evidence_results.json")
diff --git a/pyhealth/datasets/__init__.py b/pyhealth/datasets/__init__.py
index 50b1b3887..17c897514 100644
--- a/pyhealth/datasets/__init__.py
+++ b/pyhealth/datasets/__init__.py
@@ -57,7 +57,7 @@ def __init__(self, *args, **kwargs):
 from .eicu import eICUDataset
 from .isruc import ISRUCDataset
 from .medical_transcriptions import MedicalTranscriptionsDataset
-from .mimic3 import MIMIC3Dataset
+from .mimic3 import MIMIC3Dataset, MIMIC3NoteDataset
 from .mimic4 import MIMIC4CXRDataset, MIMIC4Dataset, MIMIC4EHRDataset, MIMIC4NoteDataset
 from .mimicextract import MIMICExtractDataset
 from .omop import OMOPDataset
diff --git a/pyhealth/datasets/configs/mimic3_note.yaml b/pyhealth/datasets/configs/mimic3_note.yaml
new file mode 100644
index 000000000..e2fb2aecd
--- /dev/null
+++ b/pyhealth/datasets/configs/mimic3_note.yaml
@@ -0,0 +1,56 @@
+version: "1.4"
+tables:
+  patients:
+    file_path: "PATIENTS.csv.gz"
+    patient_id: "subject_id"
+    timestamp: null
+    attributes:
+      - "gender"
+      - "dob"
+      - "expire_flag"
+
+  admissions:
+    file_path: "ADMISSIONS.csv.gz"
+    patient_id: "subject_id"
+    timestamp: "admittime"
+    attributes:
+      - "hadm_id"
+      - "admission_type"
+      - "discharge_location"
+      - "dischtime"
+      - "hospital_expire_flag"
+
+  diagnoses_icd:
+    file_path: "DIAGNOSES_ICD.csv.gz"
+    patient_id: "subject_id"
+    join:
+      - file_path: "ADMISSIONS.csv.gz"
+        "on": "hadm_id"
+        how: "inner"
+        columns:
+          - "dischtime"
+    timestamp: "dischtime"
+    attributes:
+      - "hadm_id"
+      - "icd9_code"
+      - "seq_num"
+
+  noteevents:
+    file_path: "NOTEEVENTS.csv.gz"
+    patient_id: "subject_id"
+    join:
+      - file_path: "ADMISSIONS.csv.gz"
+        "on": "hadm_id"
+        how: "inner"
+        columns:
+          - "dischtime"
+    timestamp:
+      - "charttime"
+    attributes:
+      - "row_id"
+      - "hadm_id"
+      - "text"
+      - "category"
+      - "description"
+      - "iserror"
+      - "storetime"
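+
+# The ADMISSIONS joins above give note and diagnosis rows a dischtime-based
+# timestamp fallback; iserror is surfaced as an attribute (not filtered here)
+# so downstream tasks decide how to handle error-flagged notes. This file is
+# loaded by default by MIMIC3NoteDataset; pass config_path= to substitute a
+# custom layout, e.g. (hypothetical path):
+#   MIMIC3NoteDataset(root="/data/mimic-iii/1.4", dev=True)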
diff --git a/pyhealth/datasets/mimic3.py b/pyhealth/datasets/mimic3.py index 7e569d2f3..be697ca5d 100644 --- a/pyhealth/datasets/mimic3.py +++ b/pyhealth/datasets/mimic3.py @@ -94,3 +94,115 @@ def preprocess_noteevents(self, df: pl.LazyFrame) -> pl.LazyFrame: .alias("charttime") ) return df + + +class MIMIC3NoteDataset(BaseDataset): + """MIMIC-III clinical notes dataset for evidence retrieval tasks. + + This dataset specialises the MIMIC-III data loading for NLP and evidence + retrieval use-cases. It always loads the ``noteevents`` and + ``diagnoses_icd`` tables alongside the core demographic tables + (``patients``, ``admissions``), providing everything needed for the + zero-shot EHR evidence retrieval pipeline introduced by + `Ahsan et al. (2024) `_. + + Compared with the general :class:`MIMIC3Dataset`, this class: + + - Uses a dedicated YAML config (``mimic3_note.yaml``) that exposes the + ``iserror`` flag on note events so erroneous notes can be filtered + downstream. + - Always includes ``noteevents`` and ``diagnoses_icd`` so that tasks can + pair note text with ICD-9 condition labels without extra configuration. + - Applies ``preprocess_noteevents`` to fill missing ``charttime`` values + from ``chartdate``. + + Args: + root (str): Root directory of the MIMIC-III 1.4 release (the folder + that contains ``NOTEEVENTS.csv.gz``, ``PATIENTS.csv.gz``, etc.). + tables (List[str]): Additional tables to load beyond the defaults + (``patients``, ``admissions``, ``diagnoses_icd``, ``noteevents``). + dataset_name (str): Name used for cache-directory keying. + Defaults to ``"mimic3_note"``. + config_path (Optional[str]): Path to an alternative YAML config. + When ``None`` (default) the bundled ``mimic3_note.yaml`` is used. + **kwargs: Forwarded verbatim to :class:`~pyhealth.datasets.BaseDataset` + (e.g. ``dev``, ``cache_dir``, ``num_workers``). + + Examples: + >>> from pyhealth.datasets import MIMIC3NoteDataset + >>> dataset = MIMIC3NoteDataset( + ... root="/path/to/mimic-iii/1.4", + ... ) + >>> dataset.stats() + + Load with extra tables and developer mode: + + >>> dataset = MIMIC3NoteDataset( + ... root="/path/to/mimic-iii/1.4", + ... tables=["procedures_icd"], + ... dev=True, + ... ) + """ + + def __init__( + self, + root: str, + tables: Optional[List[str]] = None, + dataset_name: str = "mimic3_note", + config_path: Optional[str] = None, + **kwargs, + ) -> None: + """Initialise the MIMIC3NoteDataset. + + Args: + root (str): Root directory of the MIMIC-III 1.4 release. + tables (Optional[List[str]]): Extra tables to load on top of the + defaults. Defaults to ``None`` (only defaults are loaded). + dataset_name (str): Cache-key name. Defaults to + ``"mimic3_note"``. + config_path (Optional[str]): Override the bundled YAML config. + **kwargs: Forwarded to :class:`~pyhealth.datasets.BaseDataset`. + """ + if config_path is None: + config_path = Path(__file__).parent / "configs" / "mimic3_note.yaml" + logger.info("Using default MIMIC-III note config: %s", config_path) + + default_tables = ["patients", "admissions", "diagnoses_icd", "noteevents"] + extra = list(tables) if tables else [] + all_tables = default_tables + [t for t in extra if t not in default_tables] + + super().__init__( + root=root, + tables=all_tables, + dataset_name=dataset_name, + config_path=str(config_path), + **kwargs, + ) + + def preprocess_noteevents(self, df: pl.LazyFrame) -> pl.LazyFrame: + """Fill missing ``charttime`` from ``chartdate`` and cast ``iserror``. 
+ + MIMIC-III note events sometimes have a null ``charttime``; the + original ``chartdate`` is used with a midnight default in that case. + The ``iserror`` column is coerced to a string so that downstream code + can safely compare it against ``"1"`` without worrying about dtype. + + Args: + df (pl.LazyFrame): Raw note-events lazy frame as loaded by the + base class. + + Returns: + pl.LazyFrame: Processed frame with ``charttime`` and ``iserror`` + normalised. + """ + df = df.with_columns( + pl.when(pl.col("charttime").is_null()) + .then(pl.col("chartdate") + pl.lit(" 00:00:00")) + .otherwise(pl.col("charttime")) + .alias("charttime") + ) + if "iserror" in df.columns: + df = df.with_columns( + pl.col("iserror").cast(pl.String).alias("iserror") + ) + return df diff --git a/pyhealth/models/__init__.py b/pyhealth/models/__init__.py index 4c168d3e3..3737f8cf4 100644 --- a/pyhealth/models/__init__.py +++ b/pyhealth/models/__init__.py @@ -1,4 +1,5 @@ from .adacare import AdaCare, AdaCareLayer, MultimodalAdaCare +from .ehr_evidence_llm import ZeroShotEvidenceLLM from .agent import Agent, AgentLayer from .base_model import BaseModel from .transformer_deid import TransformerDeID diff --git a/pyhealth/models/ehr_evidence_llm.py b/pyhealth/models/ehr_evidence_llm.py new file mode 100644 index 000000000..0556808e2 --- /dev/null +++ b/pyhealth/models/ehr_evidence_llm.py @@ -0,0 +1,469 @@ +"""Zero-shot LLM pipeline for EHR evidence retrieval. +Contributor: Abhisek Sinha (abhisek5@illinois.edu) +Paper: `Ahsan et al. (2024) ` +Implements the two-step prompting pipeline from: + Ahsan et al. (2024) "Retrieving Evidence from EHRs with LLMs: + Possibilities and Challenges." CHIL 2024, PMLR 248:489-505. + arXiv: 2309.04550 + +The pipeline: +1. Classification prompt → yes/no + token-level confidence score. +2. Summarisation prompt → free-text evidence (only when step 1 is "yes"). + +A Clinical-BERT dense-retrieval baseline is also available via +``use_cbert_baseline=True``. +""" +from __future__ import annotations + +import logging +import re +from typing import Any, Dict, List, Optional, Tuple + +import torch +import torch.nn as nn + +from pyhealth.models.base_model import BaseModel + +logger = logging.getLogger(__name__) + +# --------------------------------------------------------------------------- +# Prompt templates (based on the paper's two-step zero-shot design) +# --------------------------------------------------------------------------- +_CLASSIFY_PROMPT = """\ +You are a clinical assistant helping a radiologist review patient records. + +Patient clinical notes: +{notes} + +Question: Based solely on the notes above, does this patient currently have \ +or show signs of risk for {query_diagnosis}? + +Answer with YES or NO only.""" + +_SUMMARISE_PROMPT = """\ +You are a clinical assistant helping a radiologist review patient records. + +Patient clinical notes: +{notes} + +The patient has been assessed as having or being at risk for {query_diagnosis}. + +Task: Summarise the supporting evidence from the notes above that indicates \ +the patient has or is at risk for {query_diagnosis}. +- Quote directly from the notes where possible. +- Do NOT include any information not present in the notes above. 
+- Be concise (2-4 sentences).""" + + +# --------------------------------------------------------------------------- +# Main model class +# --------------------------------------------------------------------------- +class ZeroShotEvidenceLLM(BaseModel): + """Zero-shot LLM pipeline for retrieving evidence from EHR notes. + + Implements the two-step zero-shot prompting strategy from Ahsan et al. + (2024). Given a patient's concatenated clinical notes and a query + diagnosis, the model: + + 1. Runs a *classification* prompt to determine whether the patient has or + is at risk of the condition (yes/no). + 2. If the answer is ``"yes"``, runs a *summarisation* prompt to extract + and summarise supporting evidence from the notes. + + A confidence score is computed from the normalised token-level probability + of the ``"yes"`` response (Step 1), which the paper shows achieves + AUC > 0.9 for predicting hallucination risk. + + A dense-retrieval baseline based on ``emilyalsentzer/Bio_ClinicalBERT`` + (sentence cosine similarity) is available by setting + ``use_cbert_baseline=True``. + + Args: + dataset: PyHealth ``SampleDataset``. Pass ``None`` when using the + model standalone without the PyHealth training loop. + model_name (str): HuggingFace model ID for the LLM backbone. + Supported architectures: encoder-decoder (e.g. Flan-T5) and + decoder-only (e.g. Mistral-Instruct). + Defaults to ``"google/flan-t5-xxl"``. + max_input_tokens (int): Maximum tokens to include from the patient + notes before truncation. Defaults to ``2048``. + max_new_tokens (int): Maximum tokens to generate in the summarisation + step. Defaults to ``256``. + device (Optional[str]): Device string (``"cuda"``, ``"cpu"``, etc.). + Defaults to ``"cuda"`` if available, else ``"cpu"``. + use_cbert_baseline (bool): When ``True`` the model runs the + Clinical-BERT dense-retrieval baseline instead of the LLM + pipeline. Defaults to ``False``. + cbert_model_name (str): HuggingFace model ID for the CBERT baseline. + Defaults to ``"emilyalsentzer/Bio_ClinicalBERT"``. + + Examples: + Standalone usage (no dataset required):: + + >>> from pyhealth.models import ZeroShotEvidenceLLM + >>> model = ZeroShotEvidenceLLM(dataset=None) + >>> result = model.predict( + ... notes="Patient presents with irregular heartbeat ...", + ... query_diagnosis="atrial fibrillation", + ... ) + >>> print(result) + {'has_condition': True, 'evidence': '...', 'confidence': 0.91} + + Pipeline usage with a dataset:: + + >>> from pyhealth.datasets import MIMIC3NoteDataset + >>> from pyhealth.tasks import EHREvidenceRetrievalTask + >>> from pyhealth.models import ZeroShotEvidenceLLM + >>> dataset = MIMIC3NoteDataset(root="/path/to/mimic-iii/1.4") + >>> task = EHREvidenceRetrievalTask( + ... query_diagnosis="small vessel disease", + ... condition_icd_codes=["437.3"], + ... ) + >>> sample_dataset = dataset.set_task(task) + >>> model = ZeroShotEvidenceLLM(dataset=sample_dataset) + + Citation: + Ahsan et al. (2024) "Retrieving Evidence from EHRs with LLMs: + Possibilities and Challenges." CHIL 2024, PMLR 248:489-505. + https://arxiv.org/abs/2309.04550 + """ + + def __init__( + self, + dataset: Any = None, + model_name: str = "google/flan-t5-xxl", + max_input_tokens: int = 2048, + max_new_tokens: int = 256, + device: Optional[str] = None, + use_cbert_baseline: bool = False, + cbert_model_name: str = "emilyalsentzer/Bio_ClinicalBERT", + ) -> None: + """Initialise ZeroShotEvidenceLLM. + + Args: + dataset: PyHealth ``SampleDataset`` or ``None``. 
+ model_name (str): HuggingFace model ID. + max_input_tokens (int): Input truncation limit. + max_new_tokens (int): Generation length for summarisation. + device (Optional[str]): Compute device. + use_cbert_baseline (bool): Use CBERT instead of LLM. + cbert_model_name (str): CBERT model ID. + """ + super().__init__(dataset) + + self.model_name = model_name + self.max_input_tokens = max_input_tokens + self.max_new_tokens = max_new_tokens + self.use_cbert_baseline = use_cbert_baseline + self.cbert_model_name = cbert_model_name + + if device is None: + self._device_str = "cuda" if torch.cuda.is_available() else "cpu" + else: + self._device_str = device + + # Lazy-loaded HuggingFace objects + self._tokenizer = None + self._hf_model = None + self._is_encoder_decoder: Optional[bool] = None + self._cbert_model = None + self._cbert_tokenizer = None + + # ------------------------------------------------------------------ + # Internal helpers + # ------------------------------------------------------------------ + + def _load_llm(self) -> None: + """Lazy-load the tokenizer and language model.""" + if self._hf_model is not None: + return + try: + from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, AutoModelForCausalLM + except ImportError as exc: + raise ImportError( + "transformers is required: pip install transformers" + ) from exc + + logger.info("Loading LLM: %s", self.model_name) + self._tokenizer = AutoTokenizer.from_pretrained(self.model_name) + + # Detect architecture type + try: + self._hf_model = AutoModelForSeq2SeqLM.from_pretrained( + self.model_name, torch_dtype=torch.float16 + ).to(self._device_str) + self._is_encoder_decoder = True + logger.info("Loaded as encoder-decoder model.") + except Exception: + self._hf_model = AutoModelForCausalLM.from_pretrained( + self.model_name, torch_dtype=torch.float16 + ).to(self._device_str) + self._is_encoder_decoder = False + logger.info("Loaded as decoder-only model.") + + self._hf_model.eval() + + def _load_cbert(self) -> None: + """Lazy-load the Clinical BERT sentence encoder.""" + if self._cbert_model is not None: + return + try: + from transformers import AutoTokenizer, AutoModel + except ImportError as exc: + raise ImportError( + "transformers is required: pip install transformers" + ) from exc + + logger.info("Loading CBERT baseline: %s", self.cbert_model_name) + self._cbert_tokenizer = AutoTokenizer.from_pretrained(self.cbert_model_name) + self._cbert_model = AutoModel.from_pretrained(self.cbert_model_name).to( + self._device_str + ) + self._cbert_model.eval() + + @torch.no_grad() + def _encode_mean_pool(self, text: str) -> torch.Tensor: + """Return mean-pooled BERT embedding for *text*.""" + inputs = self._cbert_tokenizer( + text, + return_tensors="pt", + truncation=True, + max_length=512, + padding=True, + ).to(self._device_str) + outputs = self._cbert_model(**inputs) + # Mean pooling over token dimension + return outputs.last_hidden_state.mean(dim=1).squeeze(0) + + @torch.no_grad() + def _llm_classify( + self, notes: str, query_diagnosis: str + ) -> Tuple[bool, float]: + """Run the classification prompt and return (has_condition, confidence). + + Confidence is the normalised probability P(yes) / (P(yes) + P(no)). + + Args: + notes (str): Patient notes (already truncated if needed). + query_diagnosis (str): Free-text condition query. + + Returns: + Tuple[bool, float]: Whether the model predicts the condition is + present, and a scalar confidence score in [0, 1]. 
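+
+        Note:
+            With first-token logits z_yes and z_no, the confidence is
+            exp(z_yes) / (exp(z_yes) + exp(z_no)); e.g. z_yes = 3.0 and
+            z_no = 0.5 give P(yes) of roughly 0.924. Scoring just these
+            two token ids assumes the first generated token is a bare
+            "yes"/"no", which the prompt's "Answer with YES or NO only."
+            instruction is designed to ensure.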
+ """ + prompt = _CLASSIFY_PROMPT.format( + notes=notes, query_diagnosis=query_diagnosis + ) + + if self._is_encoder_decoder: + # Flan-T5 style: generate "yes"/"no" token + inputs = self._tokenizer( + prompt, + return_tensors="pt", + truncation=True, + max_length=self.max_input_tokens, + ).to(self._device_str) + + # Score both "yes" and "no" by comparing generation probabilities + yes_id = self._tokenizer.encode("yes", add_special_tokens=False)[0] + no_id = self._tokenizer.encode("no", add_special_tokens=False)[0] + + # Force-decode a single token and capture logits + decoder_input = torch.tensor( + [[self._tokenizer.pad_token_id]], device=self._device_str + ) + outputs = self._hf_model( + **inputs, decoder_input_ids=decoder_input + ) + logits = outputs.logits[0, 0, :] # (vocab_size,) + yes_score = logits[yes_id].item() + no_score = logits[no_id].item() + + else: + # Decoder-only (Mistral-Instruct style) + inputs = self._tokenizer( + prompt, + return_tensors="pt", + truncation=True, + max_length=self.max_input_tokens, + ).to(self._device_str) + outputs = self._hf_model(**inputs) + logits = outputs.logits[0, -1, :] + yes_id = self._tokenizer.encode(" yes", add_special_tokens=False)[-1] + no_id = self._tokenizer.encode(" no", add_special_tokens=False)[-1] + yes_score = logits[yes_id].item() + no_score = logits[no_id].item() + + # Softmax normalisation for calibrated confidence + yes_prob = torch.softmax( + torch.tensor([yes_score, no_score]), dim=0 + )[0].item() + has_condition = yes_prob >= 0.5 + return has_condition, float(yes_prob) + + @torch.no_grad() + def _llm_summarise(self, notes: str, query_diagnosis: str) -> str: + """Generate a free-text evidence summary given positive classification. + + Args: + notes (str): Patient notes. + query_diagnosis (str): Free-text condition query. + + Returns: + str: Generated evidence summary. + """ + prompt = _SUMMARISE_PROMPT.format( + notes=notes, query_diagnosis=query_diagnosis + ) + inputs = self._tokenizer( + prompt, + return_tensors="pt", + truncation=True, + max_length=self.max_input_tokens, + ).to(self._device_str) + + output_ids = self._hf_model.generate( + **inputs, + max_new_tokens=self.max_new_tokens, + do_sample=False, + ) + + if self._is_encoder_decoder: + generated = output_ids[0] + else: + generated = output_ids[0][inputs["input_ids"].shape[-1] :] + + return self._tokenizer.decode(generated, skip_special_tokens=True).strip() + + def _cbert_retrieve( + self, notes: str, query_diagnosis: str + ) -> Dict[str, Any]: + """Dense retrieval baseline: return the most similar sentence. + + Splits *notes* into sentences, encodes each with Clinical BERT, and + returns the single sentence whose embedding is most similar to the + query embedding (cosine similarity). + + Args: + notes (str): Patient notes (may contain note separators). + query_diagnosis (str): Free-text condition query. + + Returns: + Dict[str, Any]: Result dict compatible with :meth:`predict`. 
+ """ + self._load_cbert() + + # Split into sentences (simple heuristic) + raw_sentences = re.split(r"(?<=[.!?])\s+|\n{2,}", notes) + sentences = [s.strip() for s in raw_sentences if len(s.strip()) > 20] + + if not sentences: + return { + "has_condition": False, + "evidence": "", + "confidence": 0.0, + "model": self.cbert_model_name, + } + + query_emb = self._encode_mean_pool(query_diagnosis) + best_sent = "" + best_score = -1.0 + for sent in sentences: + sent_emb = self._encode_mean_pool(sent) + score = torch.nn.functional.cosine_similarity( + query_emb.unsqueeze(0), sent_emb.unsqueeze(0) + ).item() + if score > best_score: + best_score = score + best_sent = sent + + return { + "has_condition": best_score > 0.5, + "evidence": best_sent, + "confidence": float(best_score), + "model": self.cbert_model_name, + } + + # ------------------------------------------------------------------ + # Public API + # ------------------------------------------------------------------ + + def predict(self, notes: str, query_diagnosis: str) -> Dict[str, Any]: + """Run the two-step zero-shot evidence retrieval pipeline. + + Args: + notes (str): Concatenated patient clinical notes. + query_diagnosis (str): Free-text condition to query. + + Returns: + Dict[str, Any]: Result dictionary with keys: + + - ``"has_condition"`` (bool): ``True`` if the model predicts the + condition is present. + - ``"evidence"`` (str): Free-text evidence summary (empty string + when ``has_condition`` is ``False``). + - ``"confidence"`` (float): Normalised P(yes) score from Step 1. + - ``"model"`` (str): Model name used for inference. + """ + if self.use_cbert_baseline: + return self._cbert_retrieve(notes, query_diagnosis) + + self._load_llm() + has_condition, confidence = self._llm_classify(notes, query_diagnosis) + + evidence = "" + if has_condition: + evidence = self._llm_summarise(notes, query_diagnosis) + + return { + "has_condition": has_condition, + "evidence": evidence, + "confidence": confidence, + "model": self.model_name, + } + + def predict_batch( + self, samples: List[Dict[str, Any]] + ) -> List[Dict[str, Any]]: + """Run :meth:`predict` over a list of sample dicts. + + Each dict must contain ``"notes"`` and ``"query_diagnosis"`` keys, + as produced by :class:`~pyhealth.tasks.EHREvidenceRetrievalTask`. + + Args: + samples (List[Dict[str, Any]]): List of sample dicts. + + Returns: + List[Dict[str, Any]]: Corresponding list of prediction dicts. + """ + return [ + self.predict(s["notes"], s["query_diagnosis"]) for s in samples + ] + + def forward(self, notes: List[str], query_diagnosis: List[str], **kwargs) -> Dict[str, Any]: + """PyHealth-compatible forward pass for batch inference. + + Runs :meth:`predict` for each (notes, query) pair in the batch and + returns aggregated results. Note that this model is inference-only; + no gradient computation is performed and no loss is returned. + + Args: + notes (List[str]): Batch of concatenated note strings. + query_diagnosis (List[str]): Batch of query diagnosis strings. + **kwargs: Additional keys from the sample dict (ignored). 
+ + Returns: + Dict[str, Any]: Batch results with keys: + + - ``"has_condition"`` (List[bool]) + - ``"evidence"`` (List[str]) + - ``"confidence"`` (List[float]) + """ + results = [ + self.predict(n, q) for n, q in zip(notes, query_diagnosis) + ] + return { + "has_condition": [r["has_condition"] for r in results], + "evidence": [r["evidence"] for r in results], + "confidence": [r["confidence"] for r in results], + } diff --git a/pyhealth/tasks/__init__.py b/pyhealth/tasks/__init__.py index a32618f9c..5d6de878d 100644 --- a/pyhealth/tasks/__init__.py +++ b/pyhealth/tasks/__init__.py @@ -1,4 +1,5 @@ from .base_task import BaseTask +from .ehr_evidence_retrieval import EHREvidenceRetrievalTask from .benchmark_ehrshot import BenchmarkEHRShot from .cancer_survival import CancerMutationBurden, CancerSurvivalPrediction from .bmd_hs_disease_classification import BMDHSDiseaseClassification diff --git a/pyhealth/tasks/ehr_evidence_retrieval.py b/pyhealth/tasks/ehr_evidence_retrieval.py new file mode 100644 index 000000000..ba407a455 --- /dev/null +++ b/pyhealth/tasks/ehr_evidence_retrieval.py @@ -0,0 +1,156 @@ +"""EHR evidence retrieval task for zero-shot LLM-based clinical NLP. +Contributor: Abhisek Sinha (abhisek5@illinois.edu) +Paper: `Ahsan et al. (2024) ` +Implements the task proposed in: + Ahsan et al. (2024) "Retrieving Evidence from EHRs with LLMs: + Possibilities and Challenges." CHIL 2024, PMLR 248:489-505. + arXiv: 2309.04550 +""" +from typing import Any, Dict, List, Optional, Set + +from .base_task import BaseTask +from ..data import Patient + + +class EHREvidenceRetrievalTask(BaseTask): + """Binary task: does a patient's notes support a given query diagnosis? + + Each sample pairs a patient's concatenated clinical notes with a free-text + query diagnosis string. The binary label indicates whether the patient has + been assigned any of the specified ICD-9 codes, serving as a computable + proxy for the radiologist-provided ground-truth used in the original paper. + + This task is designed to be used with :class:`~pyhealth.datasets.MIMIC3NoteDataset` + (or :class:`~pyhealth.datasets.MIMIC4NoteDataset`) and the + :class:`~pyhealth.models.ZeroShotEvidenceLLM` model. + + Args: + query_diagnosis (str): Free-text description of the clinical condition + to query (e.g. ``"small vessel disease"``). + condition_icd_codes (List[str]): ICD-9 codes that define a positive + label. A patient is labelled ``1`` if any of their + ``diagnoses_icd`` events match at least one code in this set. + note_categories (Optional[List[str]]): If provided, only notes whose + ``category`` attribute is in this list are included (e.g. + ``["Discharge summary", "Radiology"]``). When ``None`` all note + types are included. + max_notes (int): Maximum number of notes to include per sample. + Notes are ordered chronologically and the most recent + ``max_notes`` are kept. Defaults to ``10``. + note_separator (str): String used to join multiple note texts into a + single ``notes`` string. Defaults to ``"\\n\\n---\\n\\n"``. + + Attributes: + task_name (str): ``"EHREvidenceRetrieval"`` + input_schema (Dict[str, str]): ``{"notes": "text"}`` + output_schema (Dict[str, str]): ``{"label": "binary"}`` + + Examples: + >>> from pyhealth.datasets import MIMIC3NoteDataset + >>> from pyhealth.tasks import EHREvidenceRetrievalTask + >>> dataset = MIMIC3NoteDataset(root="/path/to/mimic-iii/1.4") + >>> task = EHREvidenceRetrievalTask( + ... query_diagnosis="small vessel disease", + ... condition_icd_codes=["437.3", "437.30", "437.31"], + ... 
)
+        >>> samples = dataset.set_task(task)
+        >>> print(samples[0])
+        {'patient_id': ..., 'notes': '...', 'query_diagnosis': 'small vessel disease', 'label': 0}
+    """
+
+    task_name: str = "EHREvidenceRetrieval"
+    input_schema: Dict[str, str] = {"notes": "text"}
+    output_schema: Dict[str, str] = {"label": "binary"}
+
+    def __init__(
+        self,
+        query_diagnosis: str,
+        condition_icd_codes: List[str],
+        note_categories: Optional[List[str]] = None,
+        max_notes: int = 10,
+        note_separator: str = "\n\n---\n\n",
+    ) -> None:
+        """Initialise the EHREvidenceRetrievalTask.
+
+        Args:
+            query_diagnosis (str): Free-text clinical condition to query.
+            condition_icd_codes (List[str]): ICD-9 codes for the positive
+                label; matching ignores decimal points, so ``"437.3"`` and
+                ``"4373"`` are equivalent.
+            note_categories (Optional[List[str]]): Note categories to include.
+            max_notes (int): Max notes per sample. Defaults to ``10``.
+            note_separator (str): Separator between notes.
+        """
+        super().__init__()
+        self.query_diagnosis = query_diagnosis
+        # MIMIC-III stores ICD-9 codes without the decimal point (e.g.
+        # "4373" for 437.3), so normalise the configured codes to the
+        # undotted form; patient codes are normalised the same way in
+        # __call__ before the set intersection.
+        self._condition_codes: Set[str] = {
+            c.replace(".", "").strip() for c in condition_icd_codes
+        }
+        self.note_categories: Optional[Set[str]] = (
+            set(note_categories) if note_categories is not None else None
+        )
+        self.max_notes = max_notes
+        self.note_separator = note_separator
+
+    def __call__(self, patient: Patient) -> List[Dict[str, Any]]:
+        """Transform a patient record into EHR evidence retrieval samples.
+
+        One sample is produced per patient. If the patient has no usable note
+        text the function returns an empty list so that the patient is skipped.
+
+        Args:
+            patient (Patient): PyHealth patient record. Must expose events of
+                type ``"noteevents"`` (text, category, iserror) and
+                ``"diagnoses_icd"`` (icd9_code).
+
+        Returns:
+            List[Dict[str, Any]]: A list with at most one sample dict
+            containing:
+
+            - ``"patient_id"`` (str): patient identifier.
+            - ``"notes"`` (str): concatenated clinical note text.
+            - ``"query_diagnosis"`` (str): the configured query string.
+            - ``"label"`` (int): ``1`` if any ICD code matches, else ``0``.
+        """
+        note_events = patient.get_events(event_type="noteevents")
+
+        # Filter to requested note categories if specified
+        if self.note_categories is not None:
+            note_events = [
+                e for e in note_events
+                if getattr(e, "category", None) in self.note_categories
+            ]
+
+        # Filter erroneous notes (iserror == "1" or True in MIMIC-III)
+        note_events = [
+            e for e in note_events
+            if str(getattr(e, "iserror", "0")).strip() not in {"1", "1.0"}
+        ]
+
+        # Extract non-empty text strings
+        texts: List[str] = [
+            e.text
+            for e in note_events
+            if isinstance(getattr(e, "text", None), str) and e.text.strip()
+        ]
+
+        if not texts:
+            return []
+
+        # Keep the most recent max_notes
+        texts = texts[-self.max_notes :]
+
+        notes_text = self.note_separator.join(texts)
+
+        # Derive binary label from ICD-9 diagnoses (dot-insensitive match,
+        # since MIMIC-III stores codes without the decimal point)
+        diag_events = patient.get_events(event_type="diagnoses_icd")
+        patient_codes: Set[str] = {
+            str(getattr(e, "icd9_code", "")).strip().replace(".", "")
+            for e in diag_events
+        }
+        label = int(bool(patient_codes & self._condition_codes))
+
+        return [
+            {
+                "patient_id": patient.patient_id,
+                "notes": notes_text,
+                "query_diagnosis": self.query_diagnosis,
+                "label": label,
+            }
+        ]
diff --git a/tests/core/test_ehr_evidence_llm.py b/tests/core/test_ehr_evidence_llm.py
new file mode 100644
index 000000000..dbf81de8a
--- /dev/null
+++ b/tests/core/test_ehr_evidence_llm.py
@@ -0,0 +1,268 @@
+"""Tests for ZeroShotEvidenceLLM model.
+Contributor: Abhisek Sinha (abhisek5@illinois.edu)
+Paper: `Ahsan et al. (2024) <https://arxiv.org/abs/2309.04550>`_
(2024) ` +All tests use tiny synthetic tensors / mock objects; no real LLM weights are +downloaded. Tests complete in milliseconds by design. +""" +import importlib.util +import os +import sys +import unittest +from unittest.mock import MagicMock, patch + +import pytest + +_PYHEALTH_ROOT = os.path.join(os.path.dirname(__file__), "..", "..") +sys.path.insert(0, os.path.abspath(_PYHEALTH_ROOT)) + + +def _load_module(relative_path: str): + abs_path = os.path.abspath( + os.path.join(_PYHEALTH_ROOT, "pyhealth", relative_path) + ) + module_name = relative_path.replace(os.sep, ".").replace("/", ".").rstrip(".py") + spec = importlib.util.spec_from_file_location(module_name, abs_path) + module = importlib.util.module_from_spec(spec) + spec.loader.exec_module(module) + return module + + +# --------------------------------------------------------------------------- +# Lazy dependency guards +# --------------------------------------------------------------------------- +try: + import torch + import torch.nn as nn + _TORCH_AVAILABLE = True +except (ImportError, OSError): + _TORCH_AVAILABLE = False + +pytestmark = pytest.mark.skipif( + not _TORCH_AVAILABLE, + reason="torch not available in this environment (DLL or install issue)", +) + +# Only load if torch is available since the model inherits from nn.Module +if _TORCH_AVAILABLE: + # Provide a minimal BaseModel stub so we can load the module without the + # full PyHealth package chain (which needs torchvision etc.) + _base_model_stub = MagicMock() + + class _StubBaseModel(nn.Module): + def __init__(self, dataset=None): + super().__init__() + self.dataset = dataset + self.feature_keys = [] + self.label_keys = [] + self.mode = None + self._dummy_param = nn.Parameter(torch.empty(0)) + + _base_model_stub.BaseModel = _StubBaseModel + sys.modules.setdefault("pyhealth.models.base_model", _base_model_stub) + sys.modules.setdefault("pyhealth.datasets", MagicMock()) + + _llm_mod = _load_module("models/ehr_evidence_llm.py") + ZeroShotEvidenceLLM = _llm_mod.ZeroShotEvidenceLLM + + +# --------------------------------------------------------------------------- +# Mock helpers +# --------------------------------------------------------------------------- + +def _make_mock_enc_dec_model(yes_logit: float = 2.0, no_logit: float = 1.0): + """Mock encoder-decoder HF model whose logits have specific yes/no values.""" + if not _TORCH_AVAILABLE: + return None, 4273, 150 + + model = MagicMock() + vocab_size = 32128 # Flan-T5 vocab size + logits = torch.full((1, 1, vocab_size), -10.0) + YES_ID, NO_ID = 4273, 150 + logits[0, 0, YES_ID] = yes_logit + logits[0, 0, NO_ID] = no_logit + + output = MagicMock() + output.logits = logits + model.return_value = output + model.generate = MagicMock(return_value=torch.tensor([[0, 4273, 1]])) + model.eval = MagicMock(return_value=None) + return model, YES_ID, NO_ID + + +def _make_mock_tokenizer(yes_id: int = 4273, no_id: int = 150): + tok = MagicMock() + tok.pad_token_id = 0 + tok.encode.side_effect = lambda text, **kw: ( + [yes_id] if "yes" in text.lower() else [no_id] + ) + tok.decode.return_value = "The patient shows signs of small vessel disease." 
+    if _TORCH_AVAILABLE:
+        fake_inputs = {
+            "input_ids": torch.zeros((1, 10), dtype=torch.long),
+            "attention_mask": torch.ones((1, 10), dtype=torch.long),
+        }
+        tok.return_value = fake_inputs
+        tok.__call__ = MagicMock(return_value=fake_inputs)
+    return tok
+
+
+# ---------------------------------------------------------------------------
+# ZeroShotEvidenceLLM initialisation tests
+# ---------------------------------------------------------------------------
+
+@pytest.mark.skipif(not _TORCH_AVAILABLE, reason="torch not available")
+class TestZeroShotEvidenceLLMInit(unittest.TestCase):
+
+    def test_default_init_no_dataset(self):
+        model = ZeroShotEvidenceLLM(dataset=None)
+        self.assertIsNone(model._hf_model)
+        self.assertIsNone(model._tokenizer)
+        self.assertEqual(model.model_name, "google/flan-t5-xxl")
+
+    def test_custom_model_name(self):
+        model = ZeroShotEvidenceLLM(
+            dataset=None, model_name="mistralai/Mistral-7B-Instruct-v0.2"
+        )
+        self.assertEqual(model.model_name, "mistralai/Mistral-7B-Instruct-v0.2")
+
+    def test_cbert_flag(self):
+        model = ZeroShotEvidenceLLM(dataset=None, use_cbert_baseline=True)
+        self.assertTrue(model.use_cbert_baseline)
+
+    def test_device_override(self):
+        model = ZeroShotEvidenceLLM(dataset=None, device="cpu")
+        self.assertEqual(model._device_str, "cpu")
+
+    def test_max_tokens_configurable(self):
+        model = ZeroShotEvidenceLLM(
+            dataset=None, max_input_tokens=512, max_new_tokens=64
+        )
+        self.assertEqual(model.max_input_tokens, 512)
+        self.assertEqual(model.max_new_tokens, 64)
+
+    def test_cbert_model_name_configurable(self):
+        model = ZeroShotEvidenceLLM(
+            dataset=None,
+            use_cbert_baseline=True,
+            cbert_model_name="allenai/biomed_roberta_base",
+        )
+        self.assertEqual(model.cbert_model_name, "allenai/biomed_roberta_base")
+
+
+# ---------------------------------------------------------------------------
+# predict() tests
+# ---------------------------------------------------------------------------
+
+@pytest.mark.skipif(not _TORCH_AVAILABLE, reason="torch not available")
+class TestZeroShotEvidenceLLMPredict(unittest.TestCase):
+
+    def _get_model(self, yes_logit=3.0, no_logit=0.5):
+        model = ZeroShotEvidenceLLM(dataset=None, device="cpu")
+        hf_model, yes_id, no_id = _make_mock_enc_dec_model(yes_logit, no_logit)
+        tok = _make_mock_tokenizer(yes_id, no_id)
+        model._hf_model = hf_model
+        model._tokenizer = tok
+        model._is_encoder_decoder = True
+        return model
+
+    def test_predict_required_keys(self):
+        result = self._get_model().predict("Some note.", "small vessel disease")
+        for key in ("has_condition", "evidence", "confidence", "model"):
+            self.assertIn(key, result)
+
+    def test_positive_when_yes_logit_high(self):
+        result = self._get_model(yes_logit=5.0, no_logit=-5.0).predict(
+            "SVD noted.", "small vessel disease"
+        )
+        self.assertTrue(result["has_condition"])
+        self.assertGreater(result["confidence"], 0.5)
+
+    def test_negative_when_no_logit_high(self):
+        result = self._get_model(yes_logit=-5.0, no_logit=5.0).predict(
+            "No neurological findings.", "small vessel disease"
+        )
+        self.assertFalse(result["has_condition"])
+        self.assertLess(result["confidence"], 0.5)
+
+    def test_confidence_in_unit_interval(self):
+        result = self._get_model().predict("Note.", "atrial fibrillation")
+        self.assertGreaterEqual(result["confidence"], 0.0)
+        self.assertLessEqual(result["confidence"], 1.0)
+
+    def test_evidence_empty_when_negative(self):
+        result = self._get_model(yes_logit=-5.0, no_logit=5.0).predict(
+            "No findings.", "SVD"
+        )
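+        # When step 1 answers "no", the evidence-extraction prompt (step 2)
+        # is skipped, so the evidence field is expected to be empty.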
+        self.assertEqual(result["evidence"], "")
+
+    def test_evidence_non_empty_when_positive(self):
+        result = self._get_model(yes_logit=5.0, no_logit=-5.0).predict(
+            "SVD signs observed.", "small vessel disease"
+        )
+        self.assertTrue(result["has_condition"])
+        self.assertIsInstance(result["evidence"], str)
+
+
+# ---------------------------------------------------------------------------
+# predict_batch() / forward() tests
+# ---------------------------------------------------------------------------
+
+@pytest.mark.skipif(not _TORCH_AVAILABLE, reason="torch not available")
+class TestZeroShotEvidenceLLMBatch(unittest.TestCase):
+
+    def _get_model(self):
+        model = ZeroShotEvidenceLLM(dataset=None, device="cpu")
+        hf_model, yes_id, no_id = _make_mock_enc_dec_model(3.0, 0.5)
+        tok = _make_mock_tokenizer(yes_id, no_id)
+        model._hf_model = hf_model
+        model._tokenizer = tok
+        model._is_encoder_decoder = True
+        return model
+
+    def test_predict_batch_length(self):
+        samples = [
+            {"notes": "Note A.", "query_diagnosis": "SVD"},
+            {"notes": "Note B.", "query_diagnosis": "SVD"},
+        ]
+        results = self._get_model().predict_batch(samples)
+        self.assertEqual(len(results), 2)
+
+    def test_predict_batch_keys(self):
+        samples = [{"notes": "Note.", "query_diagnosis": "SVD"}]
+        results = self._get_model().predict_batch(samples)
+        self.assertIn("has_condition", results[0])
+
+    def test_forward_returns_lists(self):
+        out = self._get_model().forward(
+            notes=["Note A.", "Note B."],
+            query_diagnosis=["SVD", "AF"],
+        )
+        self.assertIn("has_condition", out)
+        self.assertEqual(len(out["has_condition"]), 2)
+        self.assertEqual(len(out["confidence"]), 2)
+
+
+# ---------------------------------------------------------------------------
+# CBERT baseline tests
+# ---------------------------------------------------------------------------
+
+@pytest.mark.skipif(not _TORCH_AVAILABLE, reason="torch not available")
+class TestZeroShotEvidenceLLMCBERT(unittest.TestCase):
+
+    def test_cbert_flag_routes_to_retrieve(self):
+        model = ZeroShotEvidenceLLM(dataset=None, use_cbert_baseline=True)
+        mock_result = {
+            "has_condition": True,
+            "evidence": "Best matching sentence.",
+            "confidence": 0.75,
+            "model": "Bio_ClinicalBERT",
+        }
+        with patch.object(model, "_cbert_retrieve", return_value=mock_result) as mock_fn:
+            result = model.predict("Patient notes.", "small vessel disease")
+        mock_fn.assert_called_once()
+        self.assertTrue(result["has_condition"])
+        self.assertEqual(result["confidence"], 0.75)
+
+
+if __name__ == "__main__":
+    unittest.main()
diff --git a/tests/core/test_mimic3_note_dataset.py b/tests/core/test_mimic3_note_dataset.py
new file mode 100644
index 000000000..cb7df9e3b
--- /dev/null
+++ b/tests/core/test_mimic3_note_dataset.py
@@ -0,0 +1,332 @@
+"""Tests for MIMIC3NoteDataset and EHREvidenceRetrievalTask.
+Contributor: Abhisek Sinha (abhisek5@illinois.edu)
+Paper: `Ahsan et al. (2024) <https://arxiv.org/abs/2309.04550>`_
+All tests use synthetic data generated in-memory; no real MIMIC files are
+required. Tests should complete in milliseconds by design.
+"""
+import importlib.util
+import os
+import sys
+import types
+import unittest
+from unittest.mock import MagicMock
+
+import pytest
+
+# ---------------------------------------------------------------------------
+# Bootstrap: load only the modules we need without triggering the full
+# PyHealth package chain (which requires torch, torchvision, etc.)
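+#
+# Strategy: register lightweight stub packages for the heavy parents, then
+# load each module under test directly from its file path via importlib.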
+# ---------------------------------------------------------------------------
+_PYHEALTH_ROOT = os.path.abspath(os.path.join(os.path.dirname(__file__), "..", ".."))
+
+
+def _register_stub_pkg(dotted_name: str):
+    """Create a minimal stub package so relative imports resolve."""
+    parts = dotted_name.split(".")
+    for i in range(1, len(parts) + 1):
+        name = ".".join(parts[:i])
+        if name not in sys.modules:
+            stub = types.ModuleType(name)
+            stub.__path__ = [os.path.join(_PYHEALTH_ROOT, *parts[:i])]
+            stub.__package__ = name
+            stub.__spec__ = importlib.util.spec_from_file_location(
+                name, os.path.join(_PYHEALTH_ROOT, *parts[:i], "__init__.py")
+            )
+            sys.modules[name] = stub
+
+
+def _load_pyhealth_module(rel_path: str) -> types.ModuleType:
+    """Load pyhealth/{rel_path} with its package context set correctly.
+
+    Args:
+        rel_path: Slash-separated path relative to the pyhealth/ package root,
+            e.g. ``"tasks/ehr_evidence_retrieval.py"``.
+    """
+    abs_path = os.path.join(_PYHEALTH_ROOT, "pyhealth", rel_path.replace("/", os.sep))
+    # Derive the dotted module name, e.g. "pyhealth.tasks.ehr_evidence_retrieval".
+    # (str.rstrip(".py") would strip any trailing '.', 'p', or 'y' characters,
+    # so the extension is removed explicitly instead.)
+    stem = rel_path[: -len(".py")] if rel_path.endswith(".py") else rel_path
+    module_name = "pyhealth." + stem.replace("/", ".").replace(os.sep, ".")
+    package_name = ".".join(module_name.split(".")[:-1])
+
+    # Ensure parent stubs exist
+    _register_stub_pkg(package_name)
+
+    spec = importlib.util.spec_from_file_location(module_name, abs_path)
+    mod = importlib.util.module_from_spec(spec)
+    mod.__package__ = package_name
+    sys.modules[module_name] = mod
+    spec.loader.exec_module(mod)
+    return mod
+
+
+# ---------------------------------------------------------------------------
+# Pre-register stubs for heavy dependencies so relative imports succeed
+# ---------------------------------------------------------------------------
+_register_stub_pkg("pyhealth")
+_register_stub_pkg("pyhealth.tasks")
+_register_stub_pkg("pyhealth.datasets")
+
+# Stub for pyhealth.data (ehr_evidence_retrieval.py imports Patient from here)
+_data_stub = types.ModuleType("pyhealth.data")
+_data_stub.Patient = MagicMock  # type: ignore[attr-defined]
+sys.modules["pyhealth.data"] = _data_stub
+
+# Record polars availability; it guards the preprocessing test further below
+try:
+    import polars as pl
+    _POLARS_AVAILABLE = True
+except (ImportError, OSError):
+    _POLARS_AVAILABLE = False
+
+# Stub base_task so the relative import in ehr_evidence_retrieval resolves
+# without pulling in the real package chain
+_base_task_stub = types.ModuleType("pyhealth.tasks.base_task")
+
+
+class _BaseTask:
+    task_name: str = ""
+    input_schema: dict = {}
+    output_schema: dict = {}
+
+    def __init__(self, code_mapping=None):
+        pass
+
+    def pre_filter(self, df):
+        return df
+
+    def __call__(self, patient):
+        raise NotImplementedError
+
+
+_base_task_stub.BaseTask = _BaseTask  # type: ignore[attr-defined]
+sys.modules["pyhealth.tasks.base_task"] = _base_task_stub
+
+# Now load the actual task module
+try:
+    _task_mod = _load_pyhealth_module("tasks/ehr_evidence_retrieval.py")
+    EHREvidenceRetrievalTask = _task_mod.EHREvidenceRetrievalTask
+    _TASK_AVAILABLE = True
+except Exception as exc:  # pragma: no cover
+    _TASK_AVAILABLE = False
+    _TASK_LOAD_ERROR = str(exc)
+
+
+pytestmark = pytest.mark.skipif(
+    not _TASK_AVAILABLE,
+    reason="EHREvidenceRetrievalTask could not be loaded: "
+    f"{globals().get('_TASK_LOAD_ERROR', 'unknown error')}",
+)
+
+
+# ---------------------------------------------------------------------------
+# Helpers: synthetic patient/event builder
+# ---------------------------------------------------------------------------
+
+def _make_event(**attrs):
+    event = MagicMock()
+    for k, v in attrs.items():
+        setattr(event, k, v)
+    return event
+
+
+def _make_patient(patient_id: str, note_events, diag_events):
+    patient = MagicMock()
+    patient.patient_id = patient_id
+
+    def get_events(event_type, **kwargs):
+        if event_type == "noteevents":
+            return note_events
+        if event_type == "diagnoses_icd":
+            return diag_events
+        return []
+
+    patient.get_events.side_effect = get_events
+    return patient
+
+
+# ---------------------------------------------------------------------------
+# MIMIC3NoteDataset tests (config / preprocessing only — no DB loading)
+# ---------------------------------------------------------------------------
+
+class TestMIMIC3NoteDatasetConfig(unittest.TestCase):
+    """Config-level tests that don't need torch or dask."""
+
+    def test_config_file_exists(self):
+        config_path = os.path.join(
+            _PYHEALTH_ROOT, "pyhealth", "datasets", "configs", "mimic3_note.yaml"
+        )
+        self.assertTrue(os.path.isfile(config_path), f"Missing: {config_path}")
+
+    def test_config_has_noteevents_with_iserror(self):
+        import yaml
+        config_path = os.path.join(
+            _PYHEALTH_ROOT, "pyhealth", "datasets", "configs", "mimic3_note.yaml"
+        )
+        with open(config_path) as f:
+            cfg = yaml.safe_load(f)
+        tables = cfg.get("tables", {})
+        self.assertIn("noteevents", tables)
+        attrs = tables["noteevents"].get("attributes", [])
+        self.assertIn("iserror", attrs)
+        self.assertIn("text", attrs)
+        self.assertIn("category", attrs)
+
+    def test_config_has_diagnoses_icd(self):
+        import yaml
+        config_path = os.path.join(
+            _PYHEALTH_ROOT, "pyhealth", "datasets", "configs", "mimic3_note.yaml"
+        )
+        with open(config_path) as f:
+            cfg = yaml.safe_load(f)
+        self.assertIn("diagnoses_icd", cfg.get("tables", {}))
+
+    def test_config_version(self):
+        import yaml
+        config_path = os.path.join(
+            _PYHEALTH_ROOT, "pyhealth", "datasets", "configs", "mimic3_note.yaml"
+        )
+        with open(config_path) as f:
+            cfg = yaml.safe_load(f)
+        self.assertEqual(cfg.get("version"), "1.4")
+
+    @pytest.mark.skipif(not _POLARS_AVAILABLE, reason="polars not installed")
+    def test_preprocess_noteevents_casts_iserror(self):
+        """preprocess_noteevents should coerce iserror to String dtype."""
+        # mimic3.py imports narwhals; stub it by aliasing the polars API,
+        # which narwhals mirrors closely enough for the preprocess method
+        _narwhals_stub = types.ModuleType("narwhals")
+        import polars as _pl
+        for attr in dir(_pl):
+            try:
+                setattr(_narwhals_stub, attr, getattr(_pl, attr))
+            except Exception:
+                pass
+        sys.modules.setdefault("narwhals", _narwhals_stub)
+
+        try:
+            _mimic3_mod = _load_pyhealth_module("datasets/mimic3.py")
+            MIMIC3NoteDataset = _mimic3_mod.MIMIC3NoteDataset
+        except Exception as exc:
+            pytest.skip(f"mimic3 module could not load: {exc}")
+
+        obj = object.__new__(MIMIC3NoteDataset)
+        df = _pl.DataFrame(
+            {
+                "charttime": [None, "2020-01-01 08:00:00"],
+                "chartdate": ["2020-01-01", "2020-01-02"],
+                "iserror": [1, 0],
+            }
+        ).lazy()
+        result = obj.preprocess_noteevents(df).collect()
+        self.assertEqual(result["iserror"].dtype, _pl.String)
+
+
+# ---------------------------------------------------------------------------
+# EHREvidenceRetrievalTask unit tests
+# ---------------------------------------------------------------------------
+
+class TestEHREvidenceRetrievalTask(unittest.TestCase):
+
+    def _task(self, **kwargs):
+        return EHREvidenceRetrievalTask(
+            query_diagnosis="small vessel disease",
+            condition_icd_codes=["437.3", "437.30"],
+            **kwargs,
+        )
+
+    def test_task_schema(self):
+        t = self._task()
+        self.assertEqual(t.task_name, "EHREvidenceRetrieval")
+        self.assertEqual(t.input_schema, {"notes": "text"})
+        self.assertEqual(t.output_schema, {"label": "binary"})
+
+    def test_positive_label(self):
+        notes = [_make_event(text="SVD signs noted.", category="Discharge summary", iserror="0")]
+        diags = [_make_event(icd9_code="437.3")]
+        patient = _make_patient("P001", notes, diags)
+        samples = self._task()(patient)
+        self.assertEqual(len(samples), 1)
+        self.assertEqual(samples[0]["label"], 1)
+        self.assertIn("SVD signs noted.", samples[0]["notes"])
+
+    def test_negative_label(self):
+        notes = [_make_event(text="No findings.", category="Radiology", iserror="0")]
+        diags = [_make_event(icd9_code="250.00")]
+        patient = _make_patient("P002", notes, diags)
+        samples = self._task()(patient)
+        self.assertEqual(len(samples), 1)
+        self.assertEqual(samples[0]["label"], 0)
+
+    def test_no_notes_returns_empty(self):
+        patient = _make_patient("P003", [], [])
+        self.assertEqual(self._task()(patient), [])
+
+    def test_iserror_filtering(self):
+        good = _make_event(text="Good note.", category="Discharge summary", iserror="0")
+        bad = _make_event(text="Error note.", category="Discharge summary", iserror="1")
+        patient = _make_patient("P004", [good, bad], [])
+        samples = self._task()(patient)
+        self.assertEqual(len(samples), 1)
+        self.assertNotIn("Error note.", samples[0]["notes"])
+        self.assertIn("Good note.", samples[0]["notes"])
+
+    def test_category_filtering(self):
+        discharge = _make_event(text="Discharge note.", category="Discharge summary", iserror="0")
+        nursing = _make_event(text="Nursing note.", category="Nursing", iserror="0")
+        patient = _make_patient("P005", [discharge, nursing], [])
+        task = self._task(note_categories=["Discharge summary"])
+        samples = task(patient)
+        self.assertEqual(len(samples), 1)
+        self.assertIn("Discharge note.", samples[0]["notes"])
+        self.assertNotIn("Nursing note.", samples[0]["notes"])
+
+    def test_max_notes_truncation(self):
+        notes = [_make_event(text=f"Note {i}.", category=None, iserror="0") for i in range(20)]
+        patient = _make_patient("P006", notes, [])
+        samples = self._task(max_notes=5)(patient)
+        if samples:
+            self.assertLessEqual(samples[0]["notes"].count("---"), 4)
+
+    def test_custom_separator(self):
+        n1 = _make_event(text="Note A.", category=None, iserror="0")
+        n2 = _make_event(text="Note B.", category=None, iserror="0")
+        patient = _make_patient("P007", [n1, n2], [])
+        samples = self._task(note_separator=" | ")(patient)
+        self.assertEqual(len(samples), 1)
+        self.assertIn(" | ", samples[0]["notes"])
+
+    def test_empty_text_skipped(self):
+        good = _make_event(text="Valid note.", category=None, iserror="0")
+        empty = _make_event(text="", category=None, iserror="0")
+        patient = _make_patient("P008", [good, empty], [])
+        samples = self._task()(patient)
+        self.assertEqual(len(samples), 1)
+
+    def test_patient_id_preserved(self):
+        notes = [_make_event(text="Note.", category=None, iserror="0")]
+        patient = _make_patient("P009", notes, [])
+        samples = self._task()(patient)
+        self.assertEqual(samples[0]["patient_id"], "P009")
+
+    def test_query_in_sample(self):
+        notes = [_make_event(text="Note.", category=None, iserror="0")]
+        patient = _make_patient("P010", notes, [])
+        samples = self._task()(patient)
+        self.assertEqual(samples[0]["query_diagnosis"], "small vessel disease")
+
+    def test_multiple_matching_icd_codes(self):
+        """Any matching ICD code should set label=1."""
+        notes = [_make_event(text="Vessel disease.", category=None, iserror="0")]
+        diags = [_make_event(icd9_code="437.30")]  # second code in the set
+        patient = _make_patient("P011", notes, diags)
+        samples = self._task()(patient)
+        self.assertEqual(samples[0]["label"], 1)
+
+    def test_no_diagnosis_events_gives_label_0(self):
+        """No diagnosis events at all → negative label."""
+        notes = [_make_event(text="Some note.", category=None, iserror="0")]
+        patient = _make_patient("P012", notes, [])
+        samples = self._task()(patient)
+        self.assertEqual(samples[0]["label"], 0)
+
+
+if __name__ == "__main__":
+    unittest.main()
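+
+# ---------------------------------------------------------------------------
+# End-to-end usage sketch (not executed by this suite). Hedged: it assumes
+# local MIMIC-III 1.4 CSVs and downloaded LLM weights, and assumes the
+# standard PyHealth ``root=`` constructor argument; the path is a placeholder.
+#
+#     from pyhealth.datasets import MIMIC3NoteDataset
+#     from pyhealth.models import ZeroShotEvidenceLLM
+#     from pyhealth.tasks import EHREvidenceRetrievalTask
+#
+#     dataset = MIMIC3NoteDataset(root="/path/to/mimic-iii-1.4")
+#     task = EHREvidenceRetrievalTask(
+#         query_diagnosis="small vessel disease",
+#         condition_icd_codes=["437.3", "437.30"],
+#     )
+#     samples = dataset.set_task(task)
+#     model = ZeroShotEvidenceLLM(dataset=None, model_name="google/flan-t5-xxl")
+#     result = model.predict(samples[0]["notes"], samples[0]["query_diagnosis"])
+#     # result -> {"has_condition": ..., "evidence": ..., "confidence": ..., "model": ...}
+# ---------------------------------------------------------------------------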