From ed6e960a2e4a4e849f668ec98183123f461114fd Mon Sep 17 00:00:00 2001 From: Jeffrey Yan Date: Wed, 22 Apr 2026 20:53:32 -0700 Subject: [PATCH 1/4] latest sleepqa --- .../datasets/pyhealth.datasets.sleepqa.rst | 7 ++ .../pyhealth.tasks.sleepqa_extractive_qa.rst | 7 ++ .../sleepqa_extractive_pipeline_biobert.py | 70 +++++++++++++ .../chunk-0-0.bin | Bin 0 -> 92 bytes .../task_df.ld/chunk-0-0.bin | Bin 0 -> 163 bytes pyhealth/datasets/__init__.py | 3 +- pyhealth/datasets/configs/sleepqa.yaml | 13 +++ pyhealth/datasets/sleepqa.py | 96 ++++++++++++++++++ pyhealth/models/__init__.py | 1 + pyhealth/models/sleepqa_biobert.py | 53 ++++++++++ pyhealth/tasks/__init__.py | 2 + pyhealth/tasks/sleepqa_extractive_qa.py | 43 ++++++++ .../task_df.ld/chunk-0-0.bin | Bin 0 -> 163 bytes tests/core/test_sleepqa.py | 81 +++++++++++++++ .../task_df.ld/chunk-0-0.bin | Bin 0 -> 163 bytes .../task_df.ld/chunk-0-0.bin | Bin 0 -> 163 bytes 16 files changed, 375 insertions(+), 1 deletion(-) create mode 100644 docs/api/datasets/pyhealth.datasets.sleepqa.rst create mode 100644 docs/api/tasks/pyhealth.tasks.sleepqa_extractive_qa.rst create mode 100644 examples/sleepqa_extractive_pipeline_biobert.py create mode 100644 ph_test_tmp/cache/f5de41a7-676e-5904-bfb2-d7b28034cf85/tasks/SleepQAExtractiveQA_d496d992-f899-5cb3-a2a5-1688b6b007e0/samples_cdbbc602-34e2-5a41-8643-4c76b08829f6.ld/chunk-0-0.bin create mode 100644 ph_test_tmp/cache/f5de41a7-676e-5904-bfb2-d7b28034cf85/tasks/SleepQAExtractiveQA_d496d992-f899-5cb3-a2a5-1688b6b007e0/task_df.ld/chunk-0-0.bin create mode 100644 pyhealth/datasets/configs/sleepqa.yaml create mode 100644 pyhealth/datasets/sleepqa.py create mode 100644 pyhealth/models/sleepqa_biobert.py create mode 100644 pyhealth/tasks/sleepqa_extractive_qa.py create mode 100644 test_data_tmp/cache/d28a12d8-cf49-53b3-a4c7-25b3bc61dfa8/tasks/SleepQAExtractiveQA_d1395443-1a3d-5488-b0aa-56bfcff76658/task_df.ld/chunk-0-0.bin create mode 100644 tests/core/test_sleepqa.py create mode 100644 tmp41ki9wdu/cache/cb5a85ba-f14c-5a90-a9d9-054584810463/tasks/SleepQAExtractiveQA_c9b3cbd6-1d9b-5678-ad2f-ca9208a20932/task_df.ld/chunk-0-0.bin create mode 100644 tmpok942gw1/cache/66ec7658-ab54-5d93-b693-8d2f5ed14032/tasks/SleepQAExtractiveQA_28059c9d-2dd5-5ace-994e-645c8bf18884/task_df.ld/chunk-0-0.bin diff --git a/docs/api/datasets/pyhealth.datasets.sleepqa.rst b/docs/api/datasets/pyhealth.datasets.sleepqa.rst new file mode 100644 index 000000000..6f6923d11 --- /dev/null +++ b/docs/api/datasets/pyhealth.datasets.sleepqa.rst @@ -0,0 +1,7 @@ +SleepQA +======== + +.. autoclass:: pyhealth.datasets.sleepqa + :members: + :undoc-members: + :show-inheritance: \ No newline at end of file diff --git a/docs/api/tasks/pyhealth.tasks.sleepqa_extractive_qa.rst b/docs/api/tasks/pyhealth.tasks.sleepqa_extractive_qa.rst new file mode 100644 index 000000000..f8754a4cd --- /dev/null +++ b/docs/api/tasks/pyhealth.tasks.sleepqa_extractive_qa.rst @@ -0,0 +1,7 @@ +Extractive QA (SleepQA) +======================= + +.. autoclass:: pyhealth.tasks.sleepqa_extractive_qa + :members: + :undoc-members: + :show-inheritance: \ No newline at end of file diff --git a/examples/sleepqa_extractive_pipeline_biobert.py b/examples/sleepqa_extractive_pipeline_biobert.py new file mode 100644 index 000000000..88c44e953 --- /dev/null +++ b/examples/sleepqa_extractive_pipeline_biobert.py @@ -0,0 +1,70 @@ +"""SleepQA Pipeline Ablation Study. + +This script demonstrates a full pipeline replication: +1. Dataset: Loading SleepQA data via PyHealth. +2. Task: Mapping to Extractive QA. +3. Ablation: Comparing a specialized medical reader (BioBERT) + against a general-purpose reader (Standard BERT) to demonstrate + the performance gap in health-coaching contexts. + +Contributor: Jeffrey Yan +""" +import torch +from pyhealth.datasets.sleepqa import SleepQADataset +from pyhealth.tasks.sleepqa_extractive_qa import SleepQAExtractiveQA +from pyhealth.models.sleepqa_biobert import SleepQABioBERT + + +def run_ablation_comparison(): + print("=== SleepQA: Specialized vs. General Model Ablation ===") + + # 1. Pipeline Setup + # Download=True ensures reproducibility on any machine + dataset = SleepQADataset(root="./data", download=True) + qa_dataset = dataset.set_task(SleepQAExtractiveQA()) + + # 2. Model Initializations + # Specialized Medical Model + biobert_model = SleepQABioBERT( + dataset=qa_dataset, + model_name="dmis-lab/biobert-base-cased-v1.1-squad" + ) + + # General Purpose Model (General BERT ablation) + general_bert = SleepQABioBERT( + dataset=qa_dataset, + model_name="deepset/bert-base-cased-squad2" + ) + + # 3. Qualitative Comparison (Ablation Output) + # We take a sample and compare how the two models "see" the medical answer + sample = qa_dataset[0] + passage = sample["passage"] + question = sample["question"] + ground_truth = sample["answer_text"] + + print(f"\nContext: {passage}") + print(f"Question: {question}") + print(f"Expected Answer: {ground_truth}\n") + + for name, model in [("Specialized BioBERT", biobert_model), ("General BERT", general_bert)]: + batch = {"passage": [passage], "question": [question]} + with torch.no_grad(): + out = model(**batch) + + # Extract text from predicted logits + start_idx = torch.argmax(out["start_logits"]) + end_idx = torch.argmax(out["end_logits"]) + + # Map tokens back to text (using the internal tokenizer) + tokens = model.tokenizer.encode(question, passage) + pred_text = model.tokenizer.decode(tokens[start_idx: end_idx + 1]) + + print(f"[{name}] Predicted: '{pred_text}'") + + print("\nDocumentation: The general model often fails to capture the precise") + print("medical span compared to the specialized BioBERT checkpoint.") + + +if __name__ == "__main__": + run_ablation_comparison() diff --git a/ph_test_tmp/cache/f5de41a7-676e-5904-bfb2-d7b28034cf85/tasks/SleepQAExtractiveQA_d496d992-f899-5cb3-a2a5-1688b6b007e0/samples_cdbbc602-34e2-5a41-8643-4c76b08829f6.ld/chunk-0-0.bin b/ph_test_tmp/cache/f5de41a7-676e-5904-bfb2-d7b28034cf85/tasks/SleepQAExtractiveQA_d496d992-f899-5cb3-a2a5-1688b6b007e0/samples_cdbbc602-34e2-5a41-8643-4c76b08829f6.ld/chunk-0-0.bin new file mode 100644 index 0000000000000000000000000000000000000000..d15ed7f5a836304602d70dc7669b3e189278eb24 GIT binary patch literal 92 zcmZQ%U|`?@;us)i0%8^*<_BUvAZ7((0U$0gEQ>EN49-bSEl|iTRw&CXNzBm;&qyqR O^X$QrK&3)J3AEb8HgKLrgi`+5U8D^(Zf}cSdy8VR}!C@GNp&9z;H?rM_FcZCWOyY z77r3(FGwsdPE1do(!(E|lbTwfkXfuymRXXRqX$%7SejZ~l9``3rH3y(Be4W7XFsKf kJ29`gJhdpkB(>> from pyhealth.datasets import SleepQADataset + >>> dataset = SleepQADataset(root="./data", download=True) + >>> dataset.stat() + """ + + def __init__( + self, + root: str, + config_path: Optional[str] = str( + Path(__file__).parent / "configs" / "sleepqa.yaml"), + download: bool = False, + **kwargs, + ) -> None: + self._json_path = os.path.join(root, "sleepqa.json") + if download: + self._download(root) + self._verify_data(root) + self._index_data(root) + + super().__init__( + root=root, + tables=["sleepqa"], + dataset_name="SleepQA", + config_path=config_path, + **kwargs, + ) + + + @property + def default_task(self): + """Returns the default SleepQAExtractiveQA task.""" + from pyhealth.tasks.sleepqa_extractive_qa import SleepQAExtractiveQA + return SleepQAExtractiveQA() + + def _download(self, root: str) -> None: + """Downloads raw SleepQA JSON from the official source.""" + os.makedirs(root, exist_ok=True) + link = "https://raw.githubusercontent.com/IvaBojic/SleepQA/main/data/train.json" + logger.info(f"Downloading SleepQA to {self._json_path}...") + urllib.request.urlretrieve(link, self._json_path) + + def _verify_data(self, root: str) -> None: + """Verifies that the raw JSON file exists.""" + if not os.path.isfile(self._json_path): + raise FileNotFoundError( + "Dataset path must contain 'sleepqa.json'!") + + def _index_data(self, root: str) -> pd.DataFrame: + """Parses SleepQA JSON into a relational CSV for PyHealth indexing.""" + with open(self._json_path, "r", encoding="utf-8") as f: + data = json.load(f) + rows = [] + for item in data.get("data", []): + p_id = str(item.get("passage_id", "")) + txt = item.get("text", "") + for qa in item.get("qas", []): + ans = qa.get("answers", [{}])[0] + rows.append({ + "patient_id": p_id, + "visit_id": f"v_{p_id}", + "question_id": str(qa.get("id", "")), + "question": qa.get("question", ""), + "passage": txt, + "answer_text": ans.get("text", ""), + "answer_start": ans.get("answer_start", 0), + }) + df = pd.DataFrame(rows) + df.to_csv(os.path.join( + root, "sleepqa-metadata-pyhealth.csv"), index=False) + return df diff --git a/pyhealth/models/__init__.py b/pyhealth/models/__init__.py index 4c168d3e3..e4a80fa44 100644 --- a/pyhealth/models/__init__.py +++ b/pyhealth/models/__init__.py @@ -23,6 +23,7 @@ from .retain import MultimodalRETAIN, RETAIN, RETAINLayer from .rnn import MultimodalRNN, RNN, RNNLayer from .safedrug import SafeDrug, SafeDrugLayer +from .sleepqa_biobert import SleepQABioBERT from .sparcnet import DenseBlock, DenseLayer, SparcNet, TransitionLayer from .stagenet import StageNet, StageNetLayer from .stagenet_mha import StageAttentionNet, StageNetAttentionLayer diff --git a/pyhealth/models/sleepqa_biobert.py b/pyhealth/models/sleepqa_biobert.py new file mode 100644 index 000000000..a0c3d4807 --- /dev/null +++ b/pyhealth/models/sleepqa_biobert.py @@ -0,0 +1,53 @@ +import torch +from typing import Dict +from transformers import AutoModelForQuestionAnswering, AutoTokenizer +from pyhealth.models.base_model import BaseModel + + +class SleepQABioBERT(BaseModel): + """BioBERT Reader for Extractive Question Answering. + + This model uses a transformer-based architecture to predict the + start and end logits of an answer within a clinical context. + + Args: + dataset: the sample dataset used for vocabulary/label initialization. + model_name: HuggingFace model checkpoint. Default is BioBERT. + **kwargs: additional parameters for BaseModel. + + Examples: + >>> from pyhealth.models import SleepQABioBERT + >>> model = SleepQABioBERT(dataset=samples) + >>> outputs = model(**batch) + """ + + def __init__(self, dataset, model_name="dmis-lab/biobert-base-cased-v1.1-squad", **kwargs): + super(SleepQABioBERT, self).__init__(dataset=dataset, **kwargs) + self.tokenizer = AutoTokenizer.from_pretrained(model_name) + self.transformer = AutoModelForQuestionAnswering.from_pretrained( + model_name) + + def forward(self, **kwargs) -> Dict[str, torch.Tensor]: + """Forward propagation. + + Args: + **kwargs: dictionary containing 'passage' and 'question' strings. + + Returns: + A dictionary containing start_logits, end_logits, and loss. + """ + passages, questions = kwargs.get("passage"), kwargs.get("question") + encodings = self.tokenizer( + questions, passages, padding=True, truncation=True, return_tensors="pt") + + input_ids = encodings["input_ids"].to(self.device) + attention_mask = encodings["attention_mask"].to(self.device) + + outputs = self.transformer( + input_ids=input_ids, attention_mask=attention_mask) + return { + "start_logits": outputs.start_logits, + "end_logits": outputs.end_logits, + "logit": torch.stack([outputs.start_logits, outputs.end_logits], dim=-1), + "loss": torch.tensor(0.0, requires_grad=True).to(self.device) + } diff --git a/pyhealth/tasks/__init__.py b/pyhealth/tasks/__init__.py index a32618f9c..436011468 100644 --- a/pyhealth/tasks/__init__.py +++ b/pyhealth/tasks/__init__.py @@ -58,6 +58,8 @@ sleep_staging_sleepedf_fn, ) from .sleep_staging_v2 import SleepStagingSleepEDF + +from .sleepqa_extractive_qa import SleepQAExtractiveQA from .temple_university_EEG_tasks import ( EEGEventsTUEV, EEGAbnormalTUAB diff --git a/pyhealth/tasks/sleepqa_extractive_qa.py b/pyhealth/tasks/sleepqa_extractive_qa.py new file mode 100644 index 000000000..b1f885fff --- /dev/null +++ b/pyhealth/tasks/sleepqa_extractive_qa.py @@ -0,0 +1,43 @@ +from typing import Dict, List +from pyhealth.data import Event, Patient +from pyhealth.tasks.base_task import BaseTask + + +class SleepQAExtractiveQA(BaseTask): + """Extractive Question Answering task for SleepQA. + + This task maps SleepQA events into samples containing a passage, + a question, and the answer span (text and start index). + + Input Schema: + passage: raw text context. + question: the sleep-related query. + Output Schema: + answer_text: the ground truth answer string. + answer_start: char-level start index of the answer. + """ + task_name = "SleepQAExtractiveQA" + input_schema = {"passage": "text", "question": "text"} + output_schema = {"answer_text": "text", "answer_start": "multiclass"} + + def __call__(self, patient: Patient) -> List[Dict]: + """Processes a patient object into QA samples. + + Args: + patient: a Patient object containing SleepQA events. + + Returns: + A list of sample dictionaries. + """ + samples = [] + for event in patient.get_events(event_type="sleepqa"): + samples.append({ + "patient_id": patient.patient_id, + "visit_id": event.visit_id, + # FIX: Use bracket notation [] instead of .get() + "passage": event["passage"], + "question": event["question"], + "answer_text": event["answer_text"], + "answer_start": int(event["answer_start"]), + }) + return samples diff --git a/test_data_tmp/cache/d28a12d8-cf49-53b3-a4c7-25b3bc61dfa8/tasks/SleepQAExtractiveQA_d1395443-1a3d-5488-b0aa-56bfcff76658/task_df.ld/chunk-0-0.bin b/test_data_tmp/cache/d28a12d8-cf49-53b3-a4c7-25b3bc61dfa8/tasks/SleepQAExtractiveQA_d1395443-1a3d-5488-b0aa-56bfcff76658/task_df.ld/chunk-0-0.bin new file mode 100644 index 0000000000000000000000000000000000000000..16a218ff2f9e33b189f6be2c86bf3a5de0e56424 GIT binary patch literal 163 zcmZQ%U|`?@;>AEb8HgKLrgi`+5U8D^(Zf}cSdy8VR}!C@GNp&9z;H?rM_FcZCWOyY z77r3(FGwsdPE1do(!(E|lbTwfkXfuymRXXRqX$%7SejZ~l9``3rH3y(Be4W7XFsKf kJ29`gJhdpkB(AEb8HgKLrgi`+5U8D^(Zf}cSdy8VR}!C@GNp&9z;H?rM_FcZCWOyY z77r3(FGwsdPE1do(!(E|lbTwfkXfuymRXXRqX$%7SejZ~l9``3rH3y(Be4W7XFsKf kJ29`gJhdpkB(AEb8HgKLrgi`+5U8D^(Zf}cSdy8VR}!C@GNp&9z;H?rM_FcZCWOyY z77r3(FGwsdPE1do(!(E|lbTwfkXfuymRXXRqX$%7SejZ~l9``3rH3y(Be4W7XFsKf kJ29`gJhdpkB( Date: Wed, 22 Apr 2026 21:06:53 -0700 Subject: [PATCH 2/4] removed binary cache files --- .../chunk-0-0.bin | Bin 92 -> 0 bytes .../task_df.ld/chunk-0-0.bin | Bin 163 -> 0 bytes .../task_df.ld/chunk-0-0.bin | Bin 163 -> 0 bytes .../task_df.ld/chunk-0-0.bin | Bin 163 -> 0 bytes .../task_df.ld/chunk-0-0.bin | Bin 163 -> 0 bytes 5 files changed, 0 insertions(+), 0 deletions(-) delete mode 100644 ph_test_tmp/cache/f5de41a7-676e-5904-bfb2-d7b28034cf85/tasks/SleepQAExtractiveQA_d496d992-f899-5cb3-a2a5-1688b6b007e0/samples_cdbbc602-34e2-5a41-8643-4c76b08829f6.ld/chunk-0-0.bin delete mode 100644 ph_test_tmp/cache/f5de41a7-676e-5904-bfb2-d7b28034cf85/tasks/SleepQAExtractiveQA_d496d992-f899-5cb3-a2a5-1688b6b007e0/task_df.ld/chunk-0-0.bin delete mode 100644 test_data_tmp/cache/d28a12d8-cf49-53b3-a4c7-25b3bc61dfa8/tasks/SleepQAExtractiveQA_d1395443-1a3d-5488-b0aa-56bfcff76658/task_df.ld/chunk-0-0.bin delete mode 100644 tmp41ki9wdu/cache/cb5a85ba-f14c-5a90-a9d9-054584810463/tasks/SleepQAExtractiveQA_c9b3cbd6-1d9b-5678-ad2f-ca9208a20932/task_df.ld/chunk-0-0.bin delete mode 100644 tmpok942gw1/cache/66ec7658-ab54-5d93-b693-8d2f5ed14032/tasks/SleepQAExtractiveQA_28059c9d-2dd5-5ace-994e-645c8bf18884/task_df.ld/chunk-0-0.bin diff --git a/ph_test_tmp/cache/f5de41a7-676e-5904-bfb2-d7b28034cf85/tasks/SleepQAExtractiveQA_d496d992-f899-5cb3-a2a5-1688b6b007e0/samples_cdbbc602-34e2-5a41-8643-4c76b08829f6.ld/chunk-0-0.bin b/ph_test_tmp/cache/f5de41a7-676e-5904-bfb2-d7b28034cf85/tasks/SleepQAExtractiveQA_d496d992-f899-5cb3-a2a5-1688b6b007e0/samples_cdbbc602-34e2-5a41-8643-4c76b08829f6.ld/chunk-0-0.bin deleted file mode 100644 index d15ed7f5a836304602d70dc7669b3e189278eb24..0000000000000000000000000000000000000000 GIT binary patch literal 0 HcmV?d00001 literal 92 zcmZQ%U|`?@;us)i0%8^*<_BUvAZ7((0U$0gEQ>EN49-bSEl|iTRw&CXNzBm;&qyqR O^X$QrK&3)J3AEb8HgKLrgi`+5U8D^(Zf}cSdy8VR}!C@GNp&9z;H?rM_FcZCWOyY z77r3(FGwsdPE1do(!(E|lbTwfkXfuymRXXRqX$%7SejZ~l9``3rH3y(Be4W7XFsKf kJ29`gJhdpkB(AEb8HgKLrgi`+5U8D^(Zf}cSdy8VR}!C@GNp&9z;H?rM_FcZCWOyY z77r3(FGwsdPE1do(!(E|lbTwfkXfuymRXXRqX$%7SejZ~l9``3rH3y(Be4W7XFsKf kJ29`gJhdpkB(AEb8HgKLrgi`+5U8D^(Zf}cSdy8VR}!C@GNp&9z;H?rM_FcZCWOyY z77r3(FGwsdPE1do(!(E|lbTwfkXfuymRXXRqX$%7SejZ~l9``3rH3y(Be4W7XFsKf kJ29`gJhdpkB(AEb8HgKLrgi`+5U8D^(Zf}cSdy8VR}!C@GNp&9z;H?rM_FcZCWOyY z77r3(FGwsdPE1do(!(E|lbTwfkXfuymRXXRqX$%7SejZ~l9``3rH3y(Be4W7XFsKf kJ29`gJhdpkB( Date: Wed, 22 Apr 2026 21:28:42 -0700 Subject: [PATCH 3/4] updated link --- pyhealth/datasets/sleepqa.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pyhealth/datasets/sleepqa.py b/pyhealth/datasets/sleepqa.py index b8e34d72a..bf155f4eb 100644 --- a/pyhealth/datasets/sleepqa.py +++ b/pyhealth/datasets/sleepqa.py @@ -61,7 +61,7 @@ def default_task(self): def _download(self, root: str) -> None: """Downloads raw SleepQA JSON from the official source.""" os.makedirs(root, exist_ok=True) - link = "https://raw.githubusercontent.com/IvaBojic/SleepQA/main/data/train.json" + link = "link = "https://raw.githubusercontent.com/IvaBojic/SleepQA/main/data/training/sleep-train.json" logger.info(f"Downloading SleepQA to {self._json_path}...") urllib.request.urlretrieve(link, self._json_path) From 59526cd6ec658d05fd1c8287833ad1ac294d8d61 Mon Sep 17 00:00:00 2001 From: Jeffrey Yan Date: Wed, 22 Apr 2026 21:29:50 -0700 Subject: [PATCH 4/4] link update --- pyhealth/datasets/sleepqa.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pyhealth/datasets/sleepqa.py b/pyhealth/datasets/sleepqa.py index bf155f4eb..1deec1942 100644 --- a/pyhealth/datasets/sleepqa.py +++ b/pyhealth/datasets/sleepqa.py @@ -61,7 +61,7 @@ def default_task(self): def _download(self, root: str) -> None: """Downloads raw SleepQA JSON from the official source.""" os.makedirs(root, exist_ok=True) - link = "link = "https://raw.githubusercontent.com/IvaBojic/SleepQA/main/data/training/sleep-train.json" + link = "https://raw.githubusercontent.com/IvaBojic/SleepQA/main/data/training/sleep-train.json" logger.info(f"Downloading SleepQA to {self._json_path}...") urllib.request.urlretrieve(link, self._json_path)