1 change: 1 addition & 0 deletions docs/api/datasets.rst
@@ -225,6 +225,7 @@ Available Datasets
datasets/pyhealth.datasets.MIMIC3Dataset
datasets/pyhealth.datasets.MIMIC4Dataset
datasets/pyhealth.datasets.MedicalTranscriptionsDataset
datasets/pyhealth.datasets.MedLingoDataset
datasets/pyhealth.datasets.CardiologyDataset
datasets/pyhealth.datasets.eICUDataset
datasets/pyhealth.datasets.ISRUCDataset
7 changes: 7 additions & 0 deletions docs/api/datasets/pyhealth.datasets.medlingo.rst
@@ -0,0 +1,7 @@
pyhealth.datasets.MedLingoDataset
===================================

.. autoclass:: pyhealth.datasets.medlingo.MedLingoDataset
:members:
:undoc-members:
:show-inheritance:
1 change: 1 addition & 0 deletions docs/api/tasks.rst
@@ -214,6 +214,7 @@ Available Tasks
Drug Recommendation <tasks/pyhealth.tasks.drug_recommendation>
Length of Stay Prediction <tasks/pyhealth.tasks.length_of_stay_prediction>
Medical Transcriptions Classification <tasks/pyhealth.tasks.MedicalTranscriptionsClassification>
MedLingo Jargon Expansion <tasks/pyhealth.tasks.medlingo_jargon_expansion>
Mortality Prediction (Next Visit) <tasks/pyhealth.tasks.mortality_prediction>
Mortality Prediction (StageNet MIMIC-IV) <tasks/pyhealth.tasks.mortality_prediction_stagenet_mimic4>
Patient Linkage (MIMIC-III) <tasks/pyhealth.tasks.patient_linkage_mimic3_fn>
7 changes: 7 additions & 0 deletions docs/api/tasks/pyhealth.tasks.medlingo_jargon_expansion.rst
@@ -0,0 +1,7 @@
pyhealth.tasks.medlingo_jargon_expansion
========================================

.. autoclass:: pyhealth.tasks.medlingo_jargon_expansion.MedLingoJargonExpansionTask
:members:
:undoc-members:
:show-inheritance:
119 changes: 119 additions & 0 deletions examples/medlingo_medlingo_jargon_expansion_transformersmodel.py
@@ -0,0 +1,119 @@
"""
MedLingo jargon expansion with :class:`~pyhealth.models.TransformersModel`.

**Paper:** Jia, Sontag & Agrawal — *Diagnosing our datasets* (CHIL 2025),
https://arxiv.org/abs/2505.15024. Public CSV: ``questions.csv`` (columns
``word1``, ``word2``, ``question``, ``answer``) from the MedLingo export in
Flora-jia-jfr/diagnosing_our_datasets — place that file under the directory you
pass as ``root`` below.

**Ablation (two task configs):**

- ``MedLingoJargonExpansionTask(shot_mode="one_shot")`` — ``prompt`` is the
released ``question`` string (matches the distributed MedLingo item).
- ``MedLingoJargonExpansionTask(shot_mode="zero_shot")`` — ``prompt`` is rebuilt
from ``word1`` and ``word2`` only; the CSV ``question`` field is not used, so
any one-shot / ICL demo in that column is stripped by construction.

**Limitation vs the paper:** this PyHealth task uses **multiclass classification
on the string ``answer``** (via ``TransformersModel`` + Hugging Face encoders).
The paper evaluates **open-ended** generations with an LLM judge; this script
does not reproduce that protocol.

**Smoke run (no Hugging Face download):** by default this script only builds the
dataset, runs ``set_task`` for both shot modes, and prints sample counts. To
also run one forward pass with a **tiny** BERT (a small one-time download
unless cached), set the environment variable ``PYHEALTH_MEDLINGO_RUN_MODEL=1``::

PYHEALTH_MEDLINGO_RUN_MODEL=1 python examples/medlingo_medlingo_jargon_expansion_transformersmodel.py

Optional: ``PYHEALTH_MEDLINGO_MODEL=<hf_model_id>`` overrides the tiny default
(``hf-internal-testing/tiny-random-bert``).

Run from the repository root after ``pip install -e .``, or set
``PYTHONPATH`` to the repo root so ``import pyhealth`` resolves.
"""

from __future__ import annotations

import logging
import os
import tempfile
from pathlib import Path

import pandas as pd

logging.basicConfig(level=logging.WARNING)
for _name in ("pyhealth", "pyhealth.datasets", "pyhealth.datasets.base_dataset"):
logging.getLogger(_name).setLevel(logging.WARNING)
logger = logging.getLogger(__name__)
logger.setLevel(logging.INFO)


def _write_synthetic_questions_csv(path: Path) -> None:
"""Tiny stand-in for ``datasets/MedLingo/questions.csv`` (no secrets)."""
rows = [
{
"word1": "MI",
"word2": "STEMI",
"question": "ICL_STUB What is MI vs STEMI in one sentence?",
"answer": "types of heart attack",
},
{
"word1": "HTN",
"word2": "BP",
"question": "ICL_STUB Define HTN.",
"answer": "high blood pressure",
},
]
pd.DataFrame(rows).to_csv(path, index=False)


def main() -> None:
from pyhealth.datasets import MedLingoDataset, get_dataloader
from pyhealth.tasks import MedLingoJargonExpansionTask

tmp = Path(tempfile.mkdtemp(prefix="pyhealth_medlingo_"))
root = tmp / "root"
root.mkdir()
cache = tmp / "cache"
_write_synthetic_questions_csv(root / "questions.csv")

base = MedLingoDataset(root=str(root), cache_dir=cache, num_workers=1)
logger.info("Patients: %s", len(base.unique_patient_ids))

for shot in ("one_shot", "zero_shot"):
task = MedLingoJargonExpansionTask(shot_mode=shot)
samples = base.set_task(task=task, num_workers=1)
logger.info("shot_mode=%s -> %s samples", shot, len(samples))
if len(samples):
s0 = samples[0]
logger.info("First keys: %s", sorted(s0.keys()))

if os.environ.get("PYHEALTH_MEDLINGO_RUN_MODEL") != "1":
logger.info(
"Skipping TransformersModel forward (no download). "
"Set PYHEALTH_MEDLINGO_RUN_MODEL=1 to run a tiny HF model on one batch."
)
return

    import torch

    from pyhealth.models import TransformersModel

model_name = os.environ.get(
"PYHEALTH_MEDLINGO_MODEL", "hf-internal-testing/tiny-random-bert"
)
task = MedLingoJargonExpansionTask(shot_mode="one_shot")
samples = base.set_task(task=task, num_workers=1)
loader = get_dataloader(samples, batch_size=2, shuffle=False)
model = TransformersModel(dataset=samples, model_name=model_name)
model.eval()
    batch = next(iter(loader))

with torch.no_grad():
out = model(**batch)
logger.info("Forward ok; loss=%s", out.get("loss"))


if __name__ == "__main__":
main()
1 change: 1 addition & 0 deletions pyhealth/datasets/__init__.py
@@ -57,6 +57,7 @@ def __init__(self, *args, **kwargs):
from .eicu import eICUDataset
from .isruc import ISRUCDataset
from .medical_transcriptions import MedicalTranscriptionsDataset
from .medlingo import MedLingoDataset
from .mimic3 import MIMIC3Dataset
from .mimic4 import MIMIC4CXRDataset, MIMIC4Dataset, MIMIC4EHRDataset, MIMIC4NoteDataset
from .mimicextract import MIMICExtractDataset
11 changes: 11 additions & 0 deletions pyhealth/datasets/configs/medlingo.yaml
@@ -0,0 +1,11 @@
version: "1.0"
tables:
questions:
file_path: "questions.csv"
patient_id: null
timestamp: null
attributes:
- "word1"
- "word2"
- "question"
- "answer"
82 changes: 82 additions & 0 deletions pyhealth/datasets/medlingo.py
@@ -0,0 +1,82 @@
import logging
from pathlib import Path
from typing import Any

import narwhals as nw

from ..tasks.medlingo_jargon_expansion import MedLingoJargonExpansionTask
from .base_dataset import BaseDataset

logger = logging.getLogger(__name__)

# Expected public export from Flora-jia-jfr/diagnosing_our_datasets:
# datasets/MedLingo/questions.csv with columns word1, word2, question, answer.
_REQUIRED_QUESTION_COLUMNS = frozenset({"word1", "word2", "question", "answer"})


class MedLingoDataset(BaseDataset):
"""MedLingo jargon QA rows from the *Diagnosing our datasets* line of work.

Public MedLingo data (e.g. ``questions.csv``) is released with the paper
*Diagnosing our datasets* (Jia, Sontag & Agrawal, CHIL 2025,
https://arxiv.org/abs/2505.15024). Place ``questions.csv`` under ``root``
(same layout as ``datasets/MedLingo/questions.csv`` in the paper's data
repo). Each CSV row becomes one synthetic patient with a single
``questions`` event; attributes are ``word1``, ``word2``, ``question``,
and ``answer`` (column names are matched case-insensitively after load).

Args:
root: Directory containing ``questions.csv``.
dataset_name: Optional override for the dataset name.
config_path: YAML config path; defaults to ``configs/medlingo.yaml``.
cache_dir: Optional cache root (see :class:`BaseDataset`).
num_workers: Workers for task/sample transforms.
dev: If True, limits to the first 1000 patients (see ``BaseDataset``).

Note:
:meth:`default_task` uses ``MedLingoJargonExpansionTask(shot_mode=
\"one_shot\")`` so ``set_task()`` matches the released CSV prompts.
Pass ``MedLingoJargonExpansionTask(shot_mode=\"zero_shot\")`` for the
ablation that rebuilds the prompt from ``word1``/``word2`` only.
"""

def __init__(
self,
root: str,
dataset_name: str | None = None,
config_path: str | Path | None = None,
        cache_dir: str | Path | None = None,
num_workers: int = 1,
dev: bool = False,
) -> None:
if config_path is None:
logger.info("No config path provided, using default MedLingo config")
config_path = Path(__file__).parent / "configs" / "medlingo.yaml"
default_tables = ["questions"]
super().__init__(
root=root,
tables=default_tables,
dataset_name=dataset_name or "medlingo",
config_path=str(config_path),
cache_dir=cache_dir,
num_workers=num_workers,
dev=dev,
)

@property
def default_task(self) -> MedLingoJargonExpansionTask:
"""Default MedLingo task using the released one-shot ``question`` text."""
return MedLingoJargonExpansionTask(shot_mode="one_shot")

def preprocess_questions(self, df: Any) -> Any:
"""Ensure required MedLingo columns exist after lowercasing names."""
lf = nw.from_native(df)
names = set(lf.columns)
missing = _REQUIRED_QUESTION_COLUMNS - names
if missing:
raise ValueError(
"questions.csv is missing required column(s): "
f"{sorted(missing)}. Expected columns: "
f"{sorted(_REQUIRED_QUESTION_COLUMNS)} (case-insensitive)."
)
return lf
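Hedged usage sketch for the new dataset class (the ``root`` path is a
placeholder; assumes ``questions.csv`` is already in place as described in
the class docstring)::

    from pyhealth.datasets import MedLingoDataset
    from pyhealth.tasks import MedLingoJargonExpansionTask

    ds = MedLingoDataset(root="/path/to/MedLingo")
    # Zero-shot ablation: prompts rebuilt from word1/word2 only.
    samples = ds.set_task(MedLingoJargonExpansionTask(shot_mode="zero_shot"))
    print(len(samples), samples[0]["prompt"])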
1 change: 1 addition & 0 deletions pyhealth/tasks/__init__.py
@@ -32,6 +32,7 @@
from .length_of_stay_stagenet_mimic4 import LengthOfStayStageNetMIMIC4
from .medical_coding import MIMIC3ICD9Coding
from .medical_transcriptions_classification import MedicalTranscriptionsClassification
from .medlingo_jargon_expansion import MedLingoJargonExpansionTask
from .mortality_prediction import (
MortalityPredictionEICU,
MortalityPredictionEICU2,
108 changes: 108 additions & 0 deletions pyhealth/tasks/medlingo_jargon_expansion.py
@@ -0,0 +1,108 @@
"""MedLingo jargon expansion task (plain-language answer from a prompt).

Tied to *Diagnosing our datasets* (Jia, Sontag & Agrawal, CHIL 2025;
https://arxiv.org/abs/2505.15024). This task is a **multiclass shortcut** over
the string ``answer`` column; it does not reproduce the paper's open-ended
generation plus LLM-as-judge setup.
"""

from __future__ import annotations

from typing import Any, Dict, List, Literal, Optional, Tuple

from ..data import Event, Patient
from .base_task import BaseTask

ShotMode = Literal["zero_shot", "one_shot"]


def _as_str(value: Any) -> Optional[str]:
"""Return a clean string or None if the value is unusable."""
if value is None:
return None
text = str(value).strip()
if not text or text.lower() == "nan":
return None
return text


class MedLingoJargonExpansionTask(BaseTask):
"""Map each MedLingo row to a text prompt and a plain-language ``answer``.

    Ablation over ``shot_mode``:

- **one_shot**: Use the ``question`` field verbatim as ``prompt``. This
matches the **released** MedLingo item (including any in-context demo
baked into that string).
- **zero_shot**: Do **not** use ``question``. Rebuild a minimal instruction
from ``word1`` and ``word2`` only so the model never sees the released
one-shot prompt (ICL demonstration stripped by construction).

Attributes:
task_name: Includes ``shot_mode`` so caches differ per configuration.
shot_mode: Either ``\"zero_shot\"`` or ``\"one_shot\"``.
input_schema: Single ``\"text\"`` field ``prompt`` for encoder models.
output_schema: ``answer`` as ``\"multiclass\"`` over distinct strings.
"""

input_schema: Dict[str, str] = {"prompt": "text"}
output_schema: Dict[str, str] = {"answer": "multiclass"}

def __init__(
self,
shot_mode: ShotMode = "one_shot",
code_mapping: Optional[Dict[str, Tuple[str, str]]] = None,
) -> None:
if shot_mode not in ("zero_shot", "one_shot"):
raise ValueError(
f"shot_mode must be 'zero_shot' or 'one_shot', got {shot_mode!r}"
)
super().__init__(code_mapping=code_mapping)
self.shot_mode: ShotMode = shot_mode
self.task_name = f"MedLingoJargonExpansionTask/{shot_mode}"

def _build_prompt(self, event: Event) -> Optional[str]:
"""Build model input text for the current ``shot_mode``."""
word1 = _as_str(event.word1)
word2 = _as_str(event.word2)
question = _as_str(event.question)

if self.shot_mode == "one_shot":
# Released conditioning: full CSV ``question`` (demo + query as
# distributed).
return question

# zero_shot: ignore ``question`` entirely; ICL is not present by design.
if word1 is None or word2 is None:
return None
return (
"In plain language, define the medical jargon that connects "
f'"{word1}" and "{word2}". Respond with the plain-language '
"definition only."
)

def __call__(self, patient: Patient) -> List[Dict[str, Any]]:
"""Emit one sample per patient when fields are valid.

Args:
patient: Synthetic patient with a single ``questions`` event.

Returns:
A one-element list with ``id``, ``prompt``, and ``answer``, or
empty if required fields are missing.
"""
events = patient.get_events(event_type="questions")
if len(events) != 1:
return []
event = events[0]
answer = _as_str(event.answer)
prompt = self._build_prompt(event)
if prompt is None or answer is None:
return []
return [
{
"id": patient.patient_id,
"prompt": prompt,
"answer": answer,
}
]
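
For illustration, one emitted sample looks like this in the zero-shot mode
(the ``id`` value is a placeholder for whatever ``patient_id`` the loader
assigned; the prompt string is the exact ``_build_prompt`` template)::

    {
        "id": "0",
        "prompt": (
            "In plain language, define the medical jargon that connects "
            '"HTN" and "BP". Respond with the plain-language '
            "definition only."
        ),
        "answer": "high blood pressure",
    }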