Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 7 additions & 0 deletions docs/api/datasets/pyhealth.datasets.sleepqa.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
SleepQA
========

.. automodule:: pyhealth.datasets.sleepqa
:members:
:undoc-members:
:show-inheritance:
7 changes: 7 additions & 0 deletions docs/api/tasks/pyhealth.tasks.sleepqa_extractive_qa.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
Extractive QA (SleepQA)
=======================

.. automodule:: pyhealth.tasks.sleepqa_extractive_qa
:members:
:undoc-members:
:show-inheritance:
70 changes: 70 additions & 0 deletions examples/sleepqa_extractive_pipeline_biobert.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,70 @@
"""SleepQA Pipeline Ablation Study.

This script demonstrates a full pipeline replication:
1. Dataset: Loading SleepQA data via PyHealth.
2. Task: Mapping to Extractive QA.
3. Ablation: Comparing a specialized medical reader (BioBERT)
against a general-purpose reader (Standard BERT) to demonstrate
the performance gap in health-coaching contexts.

Contributor: Jeffrey Yan
"""
import torch
from pyhealth.datasets.sleepqa import SleepQADataset
from pyhealth.tasks.sleepqa_extractive_qa import SleepQAExtractiveQA
from pyhealth.models.sleepqa_biobert import SleepQABioBERT


def run_ablation_comparison():
    """Compare a specialized medical reader (BioBERT) against a general
    BERT reader on a single SleepQA sample.

    Loads the SleepQA dataset, maps it to the extractive-QA task, runs
    both readers on the first sample, and prints each predicted answer
    span next to the ground truth.
    """
    print("=== SleepQA: Specialized vs. General Model Ablation ===")

    # 1. Pipeline setup. download=True ensures reproducibility on any machine.
    dataset = SleepQADataset(root="./data", download=True)
    qa_dataset = dataset.set_task(SleepQAExtractiveQA())

    # 2. Model initializations.
    # Specialized medical model.
    biobert_model = SleepQABioBERT(
        dataset=qa_dataset,
        model_name="dmis-lab/biobert-base-cased-v1.1-squad"
    )

    # General-purpose model (general BERT ablation).
    general_bert = SleepQABioBERT(
        dataset=qa_dataset,
        model_name="deepset/bert-base-cased-squad2"
    )

    # 3. Qualitative comparison (ablation output): how each model "sees"
    # the medical answer in the same context.
    sample = qa_dataset[0]
    passage = sample["passage"]
    question = sample["question"]
    ground_truth = sample["answer_text"]

    print(f"\nContext: {passage}")
    print(f"Question: {question}")
    print(f"Expected Answer: {ground_truth}\n")

    for name, model in [("Specialized BioBERT", biobert_model), ("General BERT", general_bert)]:
        batch = {"passage": [passage], "question": [question]}
        with torch.no_grad():
            out = model(**batch)

        # Extract the predicted span from the logits (batch size is 1, so
        # a flat argmax indexes directly into the token sequence).
        start_idx = int(torch.argmax(out["start_logits"]))
        end_idx = int(torch.argmax(out["end_logits"]))

        # FIX: re-tokenize with the same truncation setting the model uses
        # in forward(); otherwise the argmax indices can point into the
        # wrong tokens when the pair was truncated.
        tokens = model.tokenizer.encode(question, passage, truncation=True)

        # FIX: guard against a degenerate span (end before start, or start
        # past the sequence) instead of decoding garbage.
        if end_idx < start_idx or start_idx >= len(tokens):
            pred_text = ""
        else:
            # skip_special_tokens keeps [CLS]/[SEP] out of the answer text.
            pred_text = model.tokenizer.decode(
                tokens[start_idx: end_idx + 1], skip_special_tokens=True)

        print(f"[{name}] Predicted: '{pred_text}'")

    print("\nDocumentation: The general model often fails to capture the precise")
    print("medical span compared to the specialized BioBERT checkpoint.")


if __name__ == "__main__":
    run_ablation_comparison()
3 changes: 2 additions & 1 deletion pyhealth/datasets/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,7 @@ def __init__(self, *args, **kwargs):
from .physionet_deid import PhysioNetDeIDDataset
from .sample_dataset import SampleBuilder, SampleDataset, create_sample_dataset
from .shhs import SHHSDataset
from .sleepqa import SleepQADataset
from .sleepedf import SleepEDFDataset
from .bmd_hs import BMDHSDataset
from .support2 import Support2Dataset
Expand All @@ -90,4 +91,4 @@ def __init__(self, *args, **kwargs):
load_processors,
save_processors,
)
from .collate import collate_temporal
from .collate import collate_temporal
13 changes: 13 additions & 0 deletions pyhealth/datasets/configs/sleepqa.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
# Author: Jeffrey Yan (jeffreyyan23)
version: "1.0"
tables:
sleepqa:
file_path: "sleepqa-metadata-pyhealth.csv"
patient_id: "patient_id"
timestamp: null
attributes:
- "visit_id"
- "question"
- "passage"
- "answer_text"
- "answer_start"
96 changes: 96 additions & 0 deletions pyhealth/datasets/sleepqa.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,96 @@
import json
import logging
import os
import urllib.request
from pathlib import Path
from typing import Optional
import pandas as pd

from pyhealth.datasets.base_dataset import BaseDataset

logger = logging.getLogger(__name__)


class SleepQADataset(BaseDataset):
    """Dataset class for the SleepQA dataset.

    SleepQA is a health coaching dataset consisting of passages and
    corresponding question-answer pairs related to sleep hygiene.

    Args:
        root: root directory of the raw data.
        config_path: path to the configuration file. Defaults to the
            bundled ``configs/sleepqa.yaml``.
        download: whether to download the dataset. Default is False.
        **kwargs: additional arguments for BaseDataset.

    Examples:
        >>> from pyhealth.datasets import SleepQADataset
        >>> dataset = SleepQADataset(root="./data", download=True)
        >>> dataset.stat()
    """

    def __init__(
        self,
        root: str,
        config_path: Optional[str] = None,
        download: bool = False,
        **kwargs,
    ) -> None:
        # Resolve the default config lazily instead of computing it in the
        # signature; callers passing an explicit path are unaffected.
        if config_path is None:
            config_path = str(
                Path(__file__).parent / "configs" / "sleepqa.yaml")
        self._json_path = os.path.join(root, "sleepqa.json")
        if download:
            self._download(root)
        self._verify_data(root)
        # Must run before super().__init__: it writes the CSV that the
        # config file points BaseDataset at.
        self._index_data(root)

        super().__init__(
            root=root,
            tables=["sleepqa"],
            dataset_name="SleepQA",
            config_path=config_path,
            **kwargs,
        )

    @property
    def default_task(self):
        """Returns the default SleepQAExtractiveQA task."""
        # Imported lazily to avoid a circular import between the datasets
        # and tasks packages.
        from pyhealth.tasks.sleepqa_extractive_qa import SleepQAExtractiveQA
        return SleepQAExtractiveQA()

    def _download(self, root: str) -> None:
        """Downloads raw SleepQA JSON from the official source.

        NOTE(review): only the training split is fetched; it is stored
        under the generic name ``sleepqa.json``.
        """
        os.makedirs(root, exist_ok=True)
        link = "https://raw.githubusercontent.com/IvaBojic/SleepQA/main/data/training/sleep-train.json"
        logger.info(f"Downloading SleepQA to {self._json_path}...")
        urllib.request.urlretrieve(link, self._json_path)

    def _verify_data(self, root: str) -> None:
        """Verifies that the raw JSON file exists.

        Raises:
            FileNotFoundError: if ``sleepqa.json`` is missing under root.
        """
        if not os.path.isfile(self._json_path):
            raise FileNotFoundError(
                "Dataset path must contain 'sleepqa.json'!")

    def _index_data(self, root: str) -> pd.DataFrame:
        """Parses SleepQA JSON into a relational CSV for PyHealth indexing.

        Returns:
            The flattened DataFrame, one row per question-answer pair.
        """
        with open(self._json_path, "r", encoding="utf-8") as f:
            data = json.load(f)
        rows = []
        for item in data.get("data", []):
            p_id = str(item.get("passage_id", ""))
            txt = item.get("text", "")
            for qa in item.get("qas", []):
                # FIX: `.get("answers", [{}])` still raises IndexError when
                # the key is present but the list is empty; `or [{}]`
                # covers both the missing-key and empty-list cases.
                ans = (qa.get("answers") or [{}])[0]
                rows.append({
                    "patient_id": p_id,
                    "visit_id": f"v_{p_id}",
                    "question_id": str(qa.get("id", "")),
                    "question": qa.get("question", ""),
                    "passage": txt,
                    "answer_text": ans.get("text", ""),
                    "answer_start": ans.get("answer_start", 0),
                })
        df = pd.DataFrame(rows)
        df.to_csv(os.path.join(
            root, "sleepqa-metadata-pyhealth.csv"), index=False)
        return df
1 change: 1 addition & 0 deletions pyhealth/models/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
from .retain import MultimodalRETAIN, RETAIN, RETAINLayer
from .rnn import MultimodalRNN, RNN, RNNLayer
from .safedrug import SafeDrug, SafeDrugLayer
from .sleepqa_biobert import SleepQABioBERT
from .sparcnet import DenseBlock, DenseLayer, SparcNet, TransitionLayer
from .stagenet import StageNet, StageNetLayer
from .stagenet_mha import StageAttentionNet, StageNetAttentionLayer
Expand Down
53 changes: 53 additions & 0 deletions pyhealth/models/sleepqa_biobert.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,53 @@
import torch
from typing import Dict
from transformers import AutoModelForQuestionAnswering, AutoTokenizer
from pyhealth.models.base_model import BaseModel


class SleepQABioBERT(BaseModel):
    """BioBERT Reader for Extractive Question Answering.

    This model uses a transformer-based architecture to predict the
    start and end logits of an answer within a clinical context.

    Args:
        dataset: the sample dataset used for vocabulary/label initialization.
        model_name: HuggingFace model checkpoint. Default is BioBERT.
        **kwargs: additional parameters for BaseModel.

    Examples:
        >>> from pyhealth.models import SleepQABioBERT
        >>> model = SleepQABioBERT(dataset=samples)
        >>> outputs = model(**batch)
    """

    def __init__(
        self,
        dataset,
        model_name: str = "dmis-lab/biobert-base-cased-v1.1-squad",
        **kwargs,
    ):
        super().__init__(dataset=dataset, **kwargs)
        self.tokenizer = AutoTokenizer.from_pretrained(model_name)
        self.transformer = AutoModelForQuestionAnswering.from_pretrained(
            model_name)

    def forward(self, **kwargs) -> Dict[str, torch.Tensor]:
        """Forward propagation.

        Args:
            **kwargs: dictionary containing 'passage' and 'question' lists
                of strings (one entry per sample in the batch).

        Returns:
            A dictionary containing start_logits, end_logits, the stacked
            logits, and a placeholder loss.
        """
        passages, questions = kwargs.get("passage"), kwargs.get("question")
        encodings = self.tokenizer(
            questions, passages, padding=True, truncation=True,
            return_tensors="pt")
        # Move every tensor the tokenizer produced to the model device.
        encodings = {k: v.to(self.device) for k, v in encodings.items()}

        # FIX: forward the full encoding — BERT-style QA heads use
        # token_type_ids to separate the question segment from the passage
        # segment; the previous code silently dropped them.
        outputs = self.transformer(**encodings)
        return {
            "start_logits": outputs.start_logits,
            "end_logits": outputs.end_logits,
            "logit": torch.stack(
                [outputs.start_logits, outputs.end_logits], dim=-1),
            # NOTE(review): placeholder loss — no supervised objective is
            # computed yet. Built as a leaf tensor directly on the device
            # (the previous `.to(device)` produced a non-leaf copy).
            "loss": torch.zeros((), device=self.device, requires_grad=True),
        }
2 changes: 2 additions & 0 deletions pyhealth/tasks/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,8 @@
sleep_staging_sleepedf_fn,
)
from .sleep_staging_v2 import SleepStagingSleepEDF

from .sleepqa_extractive_qa import SleepQAExtractiveQA
from .temple_university_EEG_tasks import (
EEGEventsTUEV,
EEGAbnormalTUAB
Expand Down
43 changes: 43 additions & 0 deletions pyhealth/tasks/sleepqa_extractive_qa.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,43 @@
from typing import Dict, List
from pyhealth.data import Event, Patient
from pyhealth.tasks.base_task import BaseTask


class SleepQAExtractiveQA(BaseTask):
    """Extractive Question Answering task for SleepQA.

    This task maps SleepQA events into samples containing a passage,
    a question, and the answer span (text and start index).

    Input Schema:
        passage: raw text context.
        question: the sleep-related query.
    Output Schema:
        answer_text: the ground truth answer string.
        answer_start: char-level start index of the answer.
    """
    task_name = "SleepQAExtractiveQA"
    input_schema = {"passage": "text", "question": "text"}
    output_schema = {"answer_text": "text", "answer_start": "multiclass"}

    def __call__(self, patient: Patient) -> List[Dict]:
        """Processes a patient object into QA samples.

        Args:
            patient: a Patient object containing SleepQA events.

        Returns:
            A list of sample dictionaries, one per "sleepqa" event.
        """
        samples = []
        for event in patient.get_events(event_type="sleepqa"):
            # Bracket access raises loudly on a missing attribute rather
            # than silently producing None samples.
            samples.append({
                "patient_id": patient.patient_id,
                "visit_id": event.visit_id,
                "passage": event["passage"],
                "question": event["question"],
                "answer_text": event["answer_text"],
                # FIX: convert via float() first — a CSV round trip can
                # surface the offset as e.g. "12.0", which int() alone
                # rejects with ValueError.
                "answer_start": int(float(event["answer_start"])),
            })
        return samples
Loading