diff --git a/examples/mistrust_prediction/mistrust_mimic3_logistic_regression.py b/examples/mistrust_prediction/mistrust_mimic3_logistic_regression.py new file mode 100644 index 000000000..b06acf13f --- /dev/null +++ b/examples/mistrust_prediction/mistrust_mimic3_logistic_regression.py @@ -0,0 +1,206 @@ +""" +Medical Mistrust Prediction on MIMIC-III +========================================= +End-to-end example reproducing the interpersonal-feature mistrust classifiers +from Boag et al. 2018 "Racial Disparities and Mistrust in End-of-Life Care" +using the PyHealth LogisticRegression model with L1 regularisation. + +Two tasks are demonstrated: + 1. Noncompliance prediction — label from "noncompliant" in NOTEEVENTS + 2. Autopsy-consent prediction — label from autopsy consent/decline in NOTEEVENTS + +Both use the same interpersonal CHARTEVENTS feature representation, mirroring +the original trust.ipynb pipeline. + +Paper: https://arxiv.org/abs/1808.03827 +GitHub: https://github.com/wboag/eol-mistrust + +Requirements +------------ + - MIMIC-III v1.4 access via PhysioNet + - pyhealth installed (pip install pyhealth) + +Usage +----- + # With real MIMIC-III data: + python mistrust_mimic3_logistic_regression.py \\ + --mimic3_root /path/to/physionet.org/files/mimiciii/1.4 + + # Smoke-test with synthetic MIMIC-III (no data access needed): + python mistrust_mimic3_logistic_regression.py --synthetic +""" + +import argparse +import tempfile + +from pyhealth.datasets import MIMIC3Dataset, split_by_patient, get_dataloader +from pyhealth.models import LogisticRegression +from pyhealth.tasks import ( + MistrustNoncomplianceMIMIC3, + MistrustAutopsyMIMIC3, + build_interpersonal_itemids, +) +from pyhealth.trainer import Trainer + + +# --------------------------------------------------------------------------- +# L1 lambda equivalence to sklearn C=0.1: +# l1_lambda = 1 / (C * n_train) ≈ 10 / n_train +# We use a fixed value here; tune based on actual training set size. 
+# --------------------------------------------------------------------------- +L1_LAMBDA_NONCOMPLIANCE = 2.62e-4 # 10 / 38_157 (paper's 70% of 54,510) +L1_LAMBDA_AUTOPSY = 1.43e-2 # 10 / 697 (paper's 70% of 1,009) +EMBEDDING_DIM = 128 +BATCH_SIZE = 256 +EPOCHS = 50 + + +def run_task(task_name: str, sample_dataset, l1_lambda: float) -> None: + """Split, train, and evaluate one mistrust task.""" + print(f"\n{'='*60}") + print(f"Task: {task_name} | samples: {len(sample_dataset)}") + print(f" l1_lambda = {l1_lambda:.2e} (equiv. sklearn C = {1/l1_lambda:.1f} / n_train)") + print(f"{'='*60}") + + train_ds, val_ds, test_ds = split_by_patient(sample_dataset, [0.7, 0.15, 0.15]) + + train_loader = get_dataloader(train_ds, batch_size=BATCH_SIZE, shuffle=True) + val_loader = get_dataloader(val_ds, batch_size=BATCH_SIZE, shuffle=False) + test_loader = get_dataloader(test_ds, batch_size=BATCH_SIZE, shuffle=False) + + print(f" Train / Val / Test : {len(train_ds)} / {len(val_ds)} / {len(test_ds)}") + + model = LogisticRegression( + dataset=sample_dataset, + embedding_dim=EMBEDDING_DIM, + l1_lambda=l1_lambda, + ) + print(f" Model parameters : {sum(p.numel() for p in model.parameters()):,}") + + trainer = Trainer(model=model) + trainer.train( + train_dataloader=train_loader, + val_dataloader=val_loader, + epochs=EPOCHS, + monitor="roc_auc", + ) + + metrics = trainer.evaluate(test_loader) + print(f"\n Test metrics ({task_name}):") + for k, v in metrics.items(): + print(f" {k}: {v:.4f}") + + +def main(mimic3_root: str, synthetic: bool) -> None: + # ------------------------------------------------------------------ + # STEP 1: Load MIMIC-III dataset + # ------------------------------------------------------------------ + if synthetic: + print("Loading synthetic MIMIC-III (no PhysioNet access needed) ...") + root = "https://storage.googleapis.com/pyhealth/Synthetic_MIMIC-III" + cache_dir = tempfile.mkdtemp() + dev = True + else: + root = mimic3_root + cache_dir = None + dev = False + + 
base_dataset = MIMIC3Dataset( + root=root, + tables=["CHARTEVENTS", "NOTEEVENTS"], + cache_dir=cache_dir, + dev=dev, + ) + base_dataset.stats() + + # ------------------------------------------------------------------ + # STEP 2: Build interpersonal itemid → label mapping from D_ITEMS + # ------------------------------------------------------------------ + if synthetic: + # Synthetic dataset has no D_ITEMS; use an empty dict — features + # will be absent and most samples will be empty (smoke-test only). + print("\nWARNING: Synthetic mode — interpersonal features will be empty.") + print(" This is a pipeline smoke-test only, not a valid experiment.") + itemid_to_label = {} + else: + d_items_path = f"{mimic3_root}/D_ITEMS.csv.gz" + print(f"\nBuilding interpersonal itemid map from {d_items_path} ...") + itemid_to_label = build_interpersonal_itemids(d_items_path) + print(f" Matched {len(itemid_to_label)} interpersonal ITEMIDs") + + # ------------------------------------------------------------------ + # STEP 3: Noncompliance task + # ------------------------------------------------------------------ + nc_task = MistrustNoncomplianceMIMIC3( + itemid_to_label=itemid_to_label, + min_features=1, + ) + nc_dataset = base_dataset.set_task(nc_task) + + if len(nc_dataset) == 0: + print("\nNoncompliance task: no samples generated (expected in synthetic mode)") + else: + run_task("NoncompliantMistrust", nc_dataset, l1_lambda=L1_LAMBDA_NONCOMPLIANCE) + + # ------------------------------------------------------------------ + # STEP 4: Autopsy-consent task + # ------------------------------------------------------------------ + au_task = MistrustAutopsyMIMIC3( + itemid_to_label=itemid_to_label, + min_features=1, + ) + au_dataset = base_dataset.set_task(au_task) + + if len(au_dataset) == 0: + print("\nAutopsy task: no samples generated (expected in synthetic mode)") + else: + run_task("AutopsyConsentMistrust", au_dataset, l1_lambda=L1_LAMBDA_AUTOPSY) + + # 
------------------------------------------------------------------ + # STEP 5: Paper-equivalent evaluation notes + # ------------------------------------------------------------------ + print("\n" + "="*60) + print("Paper-equivalent evaluation notes") + print("="*60) + print(""" + Boag et al. 2018 used sklearn LogisticRegression(C=0.1, penalty='l1') + trained on 54,510 patients (all with interpersonal chartevents). + Equivalent PyHealth setup: + + model = LogisticRegression( + dataset=sample_dataset, + embedding_dim=128, + l1_lambda=10 / len(train_dataset), # = 1/(C * n_train), C=0.1 + ) + + Expected test AUC-ROC (paper Table 4 / PROGRESS.md): + Noncompliance : 0.667 + Autopsy : 0.531 + + Higher AUC than sklearn is possible because PyHealth uses learned + embeddings (128-dim) rather than 1-hot DictVectorizer features, + giving the model richer representations of the feature vocabulary. + """) + + +if __name__ == "__main__": + parser = argparse.ArgumentParser( + description="Mistrust prediction with PyHealth LogisticRegression + L1" + ) + parser.add_argument( + "--mimic3_root", + type=str, + default=None, + help="Path to MIMIC-III v1.4 directory (required unless --synthetic)", + ) + parser.add_argument( + "--synthetic", + action="store_true", + help="Use synthetic MIMIC-III for pipeline smoke-test (no PhysioNet access needed)", + ) + args = parser.parse_args() + + if not args.synthetic and args.mimic3_root is None: + parser.error("Provide --mimic3_root or pass --synthetic for smoke-test mode") + + main(mimic3_root=args.mimic3_root, synthetic=args.synthetic) diff --git a/pyhealth/models/logistic_regression.py b/pyhealth/models/logistic_regression.py index 8155d101f..d421427e5 100644 --- a/pyhealth/models/logistic_regression.py +++ b/pyhealth/models/logistic_regression.py @@ -10,22 +10,33 @@ class LogisticRegression(BaseModel): - """Logistic/Linear regression baseline model. + """Logistic/Linear regression baseline model with optional L1 regularization. 
This model uses embeddings from different input features and applies a single linear transformation (no hidden layers or non-linearity) to produce predictions. - + - For classification tasks: acts as logistic regression - For regression tasks: acts as linear regression - + The model automatically handles different input types through the EmbeddingModel, pools sequence dimensions, concatenates all feature embeddings, and applies a final linear layer. + L1 regularization (``l1_lambda > 0``) adds a sparsity-inducing penalty to the + weight vector during training, equivalent to scikit-learn's + ``LogisticRegression(penalty='l1', C=C)`` with ``l1_lambda = 1 / (C * n_train)``. + This is the formulation used in Boag et al. (2018) "Racial Disparities and + Mistrust in End-of-Life Care" (MLHC 2018) to train interpersonal-feature + mistrust classifiers on MIMIC-III. + Args: dataset: the dataset to train the model. It is used to query certain information such as the set of all tokens. embedding_dim: the embedding dimension. Default is 128. + l1_lambda: coefficient for the L1 weight penalty added to the loss. + ``loss = BCE + l1_lambda * ||W||_1``. Set to 0.0 (default) to + disable regularization (backward-compatible). Equivalent to + ``1 / (C * n_train)`` for sklearn's C-parameterised formulation. **kwargs: other parameters (for compatibility). Examples: @@ -55,7 +66,7 @@ class LogisticRegression(BaseModel): ... 
dataset_name="test") >>> >>> from pyhealth.models import LogisticRegression - >>> model = LogisticRegression(dataset=dataset) + >>> model = LogisticRegression(dataset=dataset, l1_lambda=1e-4) >>> >>> from pyhealth.datasets import get_dataloader >>> train_loader = get_dataloader(dataset, batch_size=2, shuffle=True) @@ -64,7 +75,7 @@ class LogisticRegression(BaseModel): >>> ret = model(**data_batch) >>> print(ret) { - 'loss': tensor(0.6931, grad_fn=), + 'loss': tensor(0.6931, grad_fn=), 'y_prob': tensor([[0.5123], [0.4987]], grad_fn=), 'y_true': tensor([[1.], @@ -80,10 +91,12 @@ def __init__( self, dataset: SampleDataset, embedding_dim: int = 128, + l1_lambda: float = 0.0, **kwargs, ): super(LogisticRegression, self).__init__(dataset) self.embedding_dim = embedding_dim + self.l1_lambda = l1_lambda assert len(self.label_keys) == 1, "Only one label key is supported" self.label_key = self.label_keys[0] @@ -197,6 +210,10 @@ def forward(self, **kwargs) -> Dict[str, torch.Tensor]: # Obtain y_true, loss, y_prob y_true = kwargs[self.label_key].to(self.device) loss = self.get_loss_function()(logits, y_true) + # L1 regularization on the final linear layer's weights (bias excluded), + # equivalent to sklearn's penalty='l1' with C = 1 / (l1_lambda * n_train). 
+ if self.l1_lambda > 0.0: + loss = loss + self.l1_lambda * self.fc.weight.abs().sum() y_prob = self.prepare_y_prob(logits) results = { diff --git a/pyhealth/tasks/__init__.py b/pyhealth/tasks/__init__.py index 797988377..60d7717a3 100644 --- a/pyhealth/tasks/__init__.py +++ b/pyhealth/tasks/__init__.py @@ -66,3 +66,8 @@ VariantClassificationClinVar, ) from .patient_linkage_mimic3 import PatientLinkageMIMIC3Task +from .mistrust_mimic3 import ( + MistrustNoncomplianceMIMIC3, + MistrustAutopsyMIMIC3, + build_interpersonal_itemids, +) diff --git a/pyhealth/tasks/mistrust_mimic3.py b/pyhealth/tasks/mistrust_mimic3.py new file mode 100644 index 000000000..c9bebf797 --- /dev/null +++ b/pyhealth/tasks/mistrust_mimic3.py @@ -0,0 +1,482 @@ +""" +Medical Mistrust Tasks for MIMIC-III +===================================== +Implements two binary classification tasks that serve as computational proxies +for medical mistrust, as described in: + + Boag et al. "Racial Disparities and Mistrust in End-of-Life Care." + MLHC 2018. https://arxiv.org/abs/1808.03827 + +Both tasks extract interpersonal interaction features from CHARTEVENTS and +derive binary labels from free-text NOTEEVENTS. The resulting samples are +intended for use with ``pyhealth.models.LogisticRegression`` (with +``l1_lambda > 0`` for the paper-equivalent L1 regularisation). + +Tasks +----- +MistrustNoncomplianceMIMIC3 + Predicts whether a hospital admission contains documented patient + noncompliance (search string: "noncompliant"). + Label 1 = noncompliant (mistrustful), 0 = compliant (trusting). + +MistrustAutopsyMIMIC3 + Predicts whether the family consented to a post-mortem autopsy. + Autopsy consent is treated as a signal of distrust in the quality of + care received. + Label 1 = consented (mistrustful), 0 = declined (trusting). + Admissions with ambiguous signals (both consent and decline) are excluded. 
+
+Input features
+--------------
+Both tasks produce ``interpersonal_features``: a *list of feature-key strings*
+extracted from CHARTEVENTS (schema type ``"sequence"``). Each key has the
+form ``"<category>||<value>"``, mirroring the normalisation rules
+in trust.ipynb / script 02_chartevents_features.py.
+
+The vocabulary is learned automatically by the PyHealth tokeniser during
+``dataset.set_task()``, so no external DictVectorizer is required.
+
+Usage
+-----
+    >>> from pyhealth.datasets import MIMIC3Dataset
+    >>> from pyhealth.tasks import MistrustNoncomplianceMIMIC3
+    >>> from pyhealth.models import LogisticRegression
+    >>> from pyhealth.trainer import Trainer
+    >>>
+    >>> base_dataset = MIMIC3Dataset(
+    ...     root="/path/to/mimic-iii/1.4",
+    ...     tables=["CHARTEVENTS", "NOTEEVENTS"],
+    ... )
+    >>> task = MistrustNoncomplianceMIMIC3(
+    ...     itemid_to_label={720: "ventilator mode", ...} # from D_ITEMS
+    ... )
+    >>> sample_dataset = base_dataset.set_task(task)
+    >>> model = LogisticRegression(dataset=sample_dataset, l1_lambda=1e-4)
+    >>> trainer = Trainer(model=model)
+    >>> trainer.train(train_dataloader=..., val_dataloader=..., epochs=50)
+
+Helper
+------
+  ``build_interpersonal_itemids(d_items_path)`` — reads D_ITEMS.csv.gz and
+  returns a ``{itemid: label}`` dict filtered to interpersonal keywords,
+  ready to pass to either task.
+""" + +from pathlib import Path +from typing import Any, Dict, List, Optional + +from pyhealth.tasks.base_task import BaseTask + + +# --------------------------------------------------------------------------- +# Keywords that define "interpersonal" CHARTEVENTS items (trust.ipynb cell 4) +# --------------------------------------------------------------------------- +_INTERPERSONAL_KEYWORDS = [ + "family communication", "follows commands", "education barrier", + "education learner", "education method", "education readiness", + "education topic", "pain", "pain level", "pain level (rest)", + "pain assess method", "restraint", "spiritual support", "support systems", + "state", "behavior", "behavioral state", "reason for restraint", + "stress", "safety", "safety measures", "family", "patient/family informed", + "pt./family informed", "health care proxy", "bath", "bed bath", "bedbath", + "chg bath", "skin care", "judgement", "family meeting", + "emotional / physical / sexual harm", "verbal response", "side rails", + "orientation", "rsbi deferred", "richmond-ras scale", "riker-sas scale", + "status and comfort", "teaching directed toward", "consults", + "social work consult", "sitter", "security", "observer", "informed", +] + +# Autopsy keyword sets +_AUTOPSY_CONSENT_WORDS = ("consent", "agree", "request") +_AUTOPSY_DECLINE_WORDS = ("decline", "not consent", "refuse", "denied") + + +# --------------------------------------------------------------------------- +# Public helper: build itemid→label dict from D_ITEMS.csv.gz +# --------------------------------------------------------------------------- + +def build_interpersonal_itemids(d_items_path: str) -> Dict[int, str]: + """Build an ``{itemid: label}`` dict for interpersonal CHARTEVENTS items. + + Reads ``D_ITEMS.csv.gz`` (or uncompressed ``D_ITEMS.csv``) and filters rows + whose ``LABEL`` contains any of the interpersonal keywords used in + Boag et al. 2018. 
Pass the result to ``MistrustNoncomplianceMIMIC3`` or + ``MistrustAutopsyMIMIC3`` as ``itemid_to_label``. + + Args: + d_items_path: Path to ``D_ITEMS.csv.gz`` (or ``.csv``). + + Returns: + Dict mapping ``itemid (int)`` to ``label (str)`` for all matched rows + where ``LINKSTO == 'chartevents'``. + + Example: + >>> from pyhealth.tasks import build_interpersonal_itemids + >>> itemid_to_label = build_interpersonal_itemids( + ... "/path/to/mimic-iii/1.4/D_ITEMS.csv.gz" + ... ) + >>> len(itemid_to_label) + 168 + """ + import pandas as pd + + df = pd.read_csv(d_items_path, usecols=["ITEMID", "LABEL", "LINKSTO"]) + df = df[df["LINKSTO"] == "chartevents"].copy() + + def _matches(label: str) -> bool: + lo = str(label).lower() + return any(k in lo for k in _INTERPERSONAL_KEYWORDS) + + df = df[df["LABEL"].apply(_matches)] + return dict(zip(df["ITEMID"].astype(int), df["LABEL"].astype(str))) + + +# --------------------------------------------------------------------------- +# Feature normalisation — mirrors trust.ipynb cell 7 +# --------------------------------------------------------------------------- + +def _restraint_reason(v: str) -> str: + if v in ("not applicable", "none", ""): + return "none" + if "threat" in v or "acute risk of" in v: + return "threat of harm" + if "confusion" in v or "delirium" in v or v == "impaired judgment" or v == "sundowning": + return "confusion/delirium" + if "occurence" in v or v == "severe physical agitation" or v == "violent/self des": + return "presence of violence" + if v in ("ext/txinterfere", "protection of lines and tubes", "treatment interference"): + return "treatment interference" + if "risk for fall" in v or "risk for falling" in v: + return "risk for falls" + return v + + +def _restraint_location(v: str) -> str: + if v in ("none", ""): + return "none" + if "4 point" in v or "4point" in v: + return "4 point restraint" + return "some restraint" + + +def _restraint_device(v: str) -> str: + if "sitter" in v: + return "sitter" + if 
"limb" in v:
+        return "limb"
+    return v
+
+
+def _bath(label: str, v: str) -> str:
+    if "part" in label:
+        return "partial"
+    if "self" in v:
+        return "self"
+    if "refused" in v:
+        return "refused"
+    if "shave" in v:
+        return "shave"
+    if "hair" in v:
+        return "hair"
+    if "none" in v:
+        return "none"
+    return "done"
+
+
+def _normalise_feature(label: str, value: str) -> Optional[str]:
+    """Normalise a (label, value) pair into a feature key string.
+
+    Returns ``"<category>||<value>"`` or ``None`` to skip the row.
+    Mirrors the normalisation in ``02_chartevents_features.py`` / trust.ipynb.
+    """
+    lo = label.lower()
+    v = (value or "none").lower().strip()
+
+    if "reason for restraint" in lo:
+        return f"reason for restraint||{_restraint_reason(v)}"
+    if "restraint location" in lo:
+        return f"restraint location||{_restraint_location(v)}"
+    if "restraint device" in lo:
+        return f"restraint device||{_restraint_device(v)}"
+    if "bath" in lo:
+        return f"bath||{_bath(lo, v)}"
+
+    # Skipped categories
+    if lo in ("behavior", "behavioral state"):
+        return None
+    if lo.startswith("pain management") or lo.startswith("pain type") \
+            or lo.startswith("pain cause") or lo.startswith("pain location"):
+        return None
+
+    # Categories kept as-is
+    for prefix in ("pain level", "education topic", "safety measures",
+                   "side rails", "status and comfort"):
+        if lo.startswith(prefix):
+            return f"{prefix}||{v}"
+
+    if "informed" in lo:
+        return f"informed||{v}"
+
+    return f"{lo}||{v}"
+
+
+# ---------------------------------------------------------------------------
+# Shared extraction helpers
+# ---------------------------------------------------------------------------
+
+def _extract_interpersonal_features(
+    chartevents: List[Any],
+    itemid_to_label: Dict[int, str],
+) -> List[str]:
+    """Return a deduplicated list of interpersonal feature-key strings.
+
+    Args:
+        chartevents: list of chartevents Event objects for one admission.
+ itemid_to_label: ``{itemid: label}`` dict (from ``build_interpersonal_itemids``). + + Returns: + Sorted list of unique ``"category||value"`` strings. + """ + seen = set() + for ev in chartevents: + itemid = getattr(ev, "itemid", None) + if itemid is None: + continue + try: + itemid = int(itemid) + except (ValueError, TypeError): + continue + if itemid not in itemid_to_label: + continue + label = itemid_to_label[itemid] + value = str(getattr(ev, "value", "") or "") + fkey = _normalise_feature(label, value) + if fkey is not None: + seen.add(fkey) + return sorted(seen) + + +def _extract_noncompliance_label(noteevents: List[Any]) -> int: + """Return 1 if any note contains 'noncompliant', else 0.""" + for ev in noteevents: + text = str(getattr(ev, "text", "") or "").lower() + if "noncompliant" in text: + return 1 + return 0 + + +def _extract_autopsy_label(noteevents: List[Any]) -> Optional[int]: + """Return 1 (consent/mistrust), 0 (decline/trust), or None (ambiguous/absent).""" + consented = False + declined = False + for ev in noteevents: + text = str(getattr(ev, "text", "") or "").lower() + if "autopsy" not in text: + continue + for line in text.split("\n"): + if "autopsy" not in line: + continue + if any(w in line for w in _AUTOPSY_DECLINE_WORDS): + declined = True + if any(w in line for w in _AUTOPSY_CONSENT_WORDS): + consented = True + if consented and declined: + return None # ambiguous — exclude + if consented: + return 1 + if declined: + return 0 + return None # no autopsy mention + + +# --------------------------------------------------------------------------- +# Task 1: Noncompliance mistrust +# --------------------------------------------------------------------------- + +class MistrustNoncomplianceMIMIC3(BaseTask): + """Predict documented noncompliance as a proxy for medical mistrust. 
+ + For each hospital admission the task produces one sample: + + - ``interpersonal_features``: deduplicated list of normalised CHARTEVENTS + feature-key strings (schema: ``"sequence"``). + - ``noncompliance``: ``1`` if any note for this admission contains the + string ``"noncompliant"``, else ``0`` (schema: ``"binary"``). + + All admissions that appear in the chartevents interpersonal-feature set + receive a label (default 0 / trust). Base rate ≈ 0.88 % in MIMIC-III v1.4. + + Args: + itemid_to_label: ``{itemid (int): label (str)}`` mapping from + ``build_interpersonal_itemids()``. Required to identify which + CHARTEVENTS rows correspond to interpersonal interaction features. + min_features: minimum number of interpersonal feature keys required + for a sample to be included. Defaults to 1. + + Examples: + >>> from pyhealth.datasets import MIMIC3Dataset + >>> from pyhealth.tasks import MistrustNoncomplianceMIMIC3, build_interpersonal_itemids + >>> itemid_to_label = build_interpersonal_itemids( + ... "/path/to/mimic-iii/1.4/D_ITEMS.csv.gz" + ... ) + >>> base_dataset = MIMIC3Dataset( + ... root="/path/to/mimic-iii/1.4", + ... tables=["CHARTEVENTS", "NOTEEVENTS"], + ... ) + >>> task = MistrustNoncomplianceMIMIC3(itemid_to_label=itemid_to_label) + >>> sample_dataset = base_dataset.set_task(task) + """ + + task_name: str = "MistrustNoncomplianceMIMIC3" + input_schema: Dict[str, str] = {"interpersonal_features": "sequence"} + output_schema: Dict[str, str] = {"noncompliance": "binary"} + + def __init__( + self, + itemid_to_label: Dict[int, str], + min_features: int = 1, + ) -> None: + self.itemid_to_label = itemid_to_label + self.min_features = min_features + + def __call__(self, patient: Any) -> List[Dict[str, Any]]: + """Process a single patient into noncompliance classification samples. + + Args: + patient: a PyHealth Patient object with ``chartevents`` and + ``noteevents`` event types loaded. 
+ + Returns: + List of dicts, one per admission, each containing: + - ``patient_id`` + - ``visit_id`` (hadm_id) + - ``interpersonal_features`` (list of str) + - ``noncompliance`` (int 0/1) + """ + samples = [] + admissions = patient.get_events(event_type="admissions") + + for admission in admissions: + hadm_id = admission.hadm_id + + chartevents = patient.get_events( + event_type="chartevents", + filters=[("hadm_id", "==", hadm_id)], + ) + features = _extract_interpersonal_features(chartevents, self.itemid_to_label) + + if len(features) < self.min_features: + continue + + noteevents = patient.get_events( + event_type="noteevents", + filters=[("hadm_id", "==", hadm_id)], + ) + label = _extract_noncompliance_label(noteevents) + + samples.append( + { + "patient_id": patient.patient_id, + "visit_id": hadm_id, + "interpersonal_features": features, + "noncompliance": label, + } + ) + + return samples + + +# --------------------------------------------------------------------------- +# Task 2: Autopsy-consent mistrust +# --------------------------------------------------------------------------- + +class MistrustAutopsyMIMIC3(BaseTask): + """Predict autopsy consent as a proxy for medical mistrust. + + Autopsy consent signals post-mortem distrust of the care received. + Black patients in MIMIC-III v1.4 consent to autopsies at ~39 % vs ~26 % + for white patients (Boag et al. 2018). + + Only admissions with an explicit, unambiguous autopsy mention in + NOTEEVENTS receive a label (consent=1 / decline=0). Admissions where + both signals appear are excluded. + + Args: + itemid_to_label: ``{itemid (int): label (str)}`` mapping from + ``build_interpersonal_itemids()``. + min_features: minimum interpersonal features required per sample. + Defaults to 1. + + Examples: + >>> from pyhealth.datasets import MIMIC3Dataset + >>> from pyhealth.tasks import MistrustAutopsyMIMIC3, build_interpersonal_itemids + >>> itemid_to_label = build_interpersonal_itemids( + ... 
"/path/to/mimic-iii/1.4/D_ITEMS.csv.gz" + ... ) + >>> base_dataset = MIMIC3Dataset( + ... root="/path/to/mimic-iii/1.4", + ... tables=["CHARTEVENTS", "NOTEEVENTS"], + ... ) + >>> task = MistrustAutopsyMIMIC3(itemid_to_label=itemid_to_label) + >>> sample_dataset = base_dataset.set_task(task) + """ + + task_name: str = "MistrustAutopsyMIMIC3" + input_schema: Dict[str, str] = {"interpersonal_features": "sequence"} + output_schema: Dict[str, str] = {"autopsy_consent": "binary"} + + def __init__( + self, + itemid_to_label: Dict[int, str], + min_features: int = 1, + ) -> None: + self.itemid_to_label = itemid_to_label + self.min_features = min_features + + def __call__(self, patient: Any) -> List[Dict[str, Any]]: + """Process a single patient into autopsy-consent classification samples. + + Args: + patient: a PyHealth Patient object with ``chartevents`` and + ``noteevents`` event types loaded. + + Returns: + List of dicts, one per admission with an explicit autopsy signal: + - ``patient_id`` + - ``visit_id`` (hadm_id) + - ``interpersonal_features`` (list of str) + - ``autopsy_consent`` (int 0/1) + """ + samples = [] + admissions = patient.get_events(event_type="admissions") + + for admission in admissions: + hadm_id = admission.hadm_id + + noteevents = patient.get_events( + event_type="noteevents", + filters=[("hadm_id", "==", hadm_id)], + ) + label = _extract_autopsy_label(noteevents) + if label is None: + continue # no explicit or ambiguous signal — skip + + chartevents = patient.get_events( + event_type="chartevents", + filters=[("hadm_id", "==", hadm_id)], + ) + features = _extract_interpersonal_features(chartevents, self.itemid_to_label) + + if len(features) < self.min_features: + continue + + samples.append( + { + "patient_id": patient.patient_id, + "visit_id": hadm_id, + "interpersonal_features": features, + "autopsy_consent": label, + } + ) + + return samples