Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
206 changes: 206 additions & 0 deletions examples/mistrust_prediction/mistrust_mimic3_logistic_regression.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,206 @@
"""
Medical Mistrust Prediction on MIMIC-III
=========================================
End-to-end example reproducing the interpersonal-feature mistrust classifiers
from Boag et al. 2018 "Racial Disparities and Mistrust in End-of-Life Care"
using the PyHealth LogisticRegression model with L1 regularisation.

Two tasks are demonstrated:
1. Noncompliance prediction — label from "noncompliant" in NOTEEVENTS
2. Autopsy-consent prediction — label from autopsy consent/decline in NOTEEVENTS

Both use the same interpersonal CHARTEVENTS feature representation, mirroring
the original trust.ipynb pipeline.

Paper: https://arxiv.org/abs/1808.03827
GitHub: https://github.com/wboag/eol-mistrust

Requirements
------------
- MIMIC-III v1.4 access via PhysioNet
- pyhealth installed (pip install pyhealth)

Usage
-----
# With real MIMIC-III data:
python mistrust_mimic3_logistic_regression.py \\
--mimic3_root /path/to/physionet.org/files/mimiciii/1.4

# Smoke-test with synthetic MIMIC-III (no data access needed):
python mistrust_mimic3_logistic_regression.py --synthetic
"""

import argparse
import tempfile

from pyhealth.datasets import MIMIC3Dataset, split_by_patient, get_dataloader
from pyhealth.models import LogisticRegression
from pyhealth.tasks import (
MistrustNoncomplianceMIMIC3,
MistrustAutopsyMIMIC3,
build_interpersonal_itemids,
)
from pyhealth.trainer import Trainer


# ---------------------------------------------------------------------------
# L1 lambda equivalence to sklearn C=0.1:
#   l1_lambda = 1 / (C * n_train)  ≈  10 / n_train   (for C = 0.1)
# The values below are fixed, derived from the paper's reported 70% training
# splits; retune if your actual training-set size differs.
# ---------------------------------------------------------------------------
L1_LAMBDA_NONCOMPLIANCE = 2.62e-4  # 10 / 38_157 (paper's 70% of 54,510)
L1_LAMBDA_AUTOPSY = 1.43e-2  # 10 / 697 (paper's 70% of 1,009)
EMBEDDING_DIM = 128  # width of the learned feature embeddings
BATCH_SIZE = 256  # dataloader batch size for train/val/test
EPOCHS = 50  # training epochs (Trainer monitors validation roc_auc)


def run_task(task_name: str, sample_dataset, l1_lambda: float) -> None:
    """Train and evaluate one mistrust classifier.

    Performs a patient-level 70/15/15 split of ``sample_dataset``, fits a
    PyHealth ``LogisticRegression`` with the given L1 penalty (monitoring
    validation ROC-AUC), and prints the held-out test metrics.

    Args:
        task_name: human-readable label used in the printed report.
        sample_dataset: a PyHealth SampleDataset produced by ``set_task``.
        l1_lambda: L1 penalty coefficient passed to the model.
    """
    banner = "=" * 60
    print(f"\n{banner}")
    print(f"Task: {task_name} | samples: {len(sample_dataset)}")
    print(f"  l1_lambda = {l1_lambda:.2e} (equiv. sklearn C = {1/l1_lambda:.1f} / n_train)")
    print(f"{banner}")

    # Patient-level split avoids leaking the same patient across partitions.
    splits = split_by_patient(sample_dataset, [0.7, 0.15, 0.15])
    train_ds, val_ds, test_ds = splits

    # Only the training loader shuffles.
    train_loader, val_loader, test_loader = (
        get_dataloader(ds, batch_size=BATCH_SIZE, shuffle=shuffle)
        for ds, shuffle in zip(splits, (True, False, False))
    )

    print(f"  Train / Val / Test : {len(train_ds)} / {len(val_ds)} / {len(test_ds)}")

    model = LogisticRegression(
        dataset=sample_dataset,
        embedding_dim=EMBEDDING_DIM,
        l1_lambda=l1_lambda,
    )
    n_params = sum(p.numel() for p in model.parameters())
    print(f"  Model parameters   : {n_params:,}")

    trainer = Trainer(model=model)
    trainer.train(
        train_dataloader=train_loader,
        val_dataloader=val_loader,
        epochs=EPOCHS,
        monitor="roc_auc",
    )

    test_metrics = trainer.evaluate(test_loader)
    print(f"\n  Test metrics ({task_name}):")
    for metric_name, metric_value in test_metrics.items():
        print(f"    {metric_name}: {metric_value:.4f}")


def main(mimic3_root: str, synthetic: bool) -> None:
    """Run the end-to-end mistrust-prediction pipeline.

    Loads MIMIC-III (real or synthetic), builds the interpersonal ITEMID
    map, then runs the noncompliance and autopsy-consent tasks in turn.

    Args:
        mimic3_root: path to a local MIMIC-III v1.4 directory; may be None
            when ``synthetic`` is True.
        synthetic: if True, use the public synthetic MIMIC-III mirror —
            a pipeline smoke-test only (interpersonal features are empty).
    """
    # ------------------------------------------------------------------
    # STEP 1: Load MIMIC-III dataset
    # ------------------------------------------------------------------
    if synthetic:
        print("Loading synthetic MIMIC-III (no PhysioNet access needed) ...")
        root = "https://storage.googleapis.com/pyhealth/Synthetic_MIMIC-III"
        cache_dir = tempfile.mkdtemp()
        dev = True
    else:
        root, cache_dir, dev = mimic3_root, None, False

    base_dataset = MIMIC3Dataset(
        root=root,
        tables=["CHARTEVENTS", "NOTEEVENTS"],
        cache_dir=cache_dir,
        dev=dev,
    )
    base_dataset.stats()

    # ------------------------------------------------------------------
    # STEP 2: Build interpersonal itemid → label mapping from D_ITEMS
    # ------------------------------------------------------------------
    if not synthetic:
        d_items_path = f"{mimic3_root}/D_ITEMS.csv.gz"
        print(f"\nBuilding interpersonal itemid map from {d_items_path} ...")
        itemid_to_label = build_interpersonal_itemids(d_items_path)
        print(f"  Matched {len(itemid_to_label)} interpersonal ITEMIDs")
    else:
        # Synthetic dataset has no D_ITEMS; with an empty map the features
        # will be absent and most samples will be empty (smoke-test only).
        print("\nWARNING: Synthetic mode — interpersonal features will be empty.")
        print("         This is a pipeline smoke-test only, not a valid experiment.")
        itemid_to_label = {}

    # ------------------------------------------------------------------
    # STEP 3: Noncompliance task
    # ------------------------------------------------------------------
    nc_task = MistrustNoncomplianceMIMIC3(
        itemid_to_label=itemid_to_label,
        min_features=1,
    )
    nc_dataset = base_dataset.set_task(nc_task)

    if len(nc_dataset) > 0:
        run_task("NoncompliantMistrust", nc_dataset, l1_lambda=L1_LAMBDA_NONCOMPLIANCE)
    else:
        print("\nNoncompliance task: no samples generated (expected in synthetic mode)")

    # ------------------------------------------------------------------
    # STEP 4: Autopsy-consent task
    # ------------------------------------------------------------------
    au_task = MistrustAutopsyMIMIC3(
        itemid_to_label=itemid_to_label,
        min_features=1,
    )
    au_dataset = base_dataset.set_task(au_task)

    if len(au_dataset) > 0:
        run_task("AutopsyConsentMistrust", au_dataset, l1_lambda=L1_LAMBDA_AUTOPSY)
    else:
        print("\nAutopsy task: no samples generated (expected in synthetic mode)")

    # ------------------------------------------------------------------
    # STEP 5: Paper-equivalent evaluation notes
    # ------------------------------------------------------------------
    banner = "=" * 60
    print("\n" + banner)
    print("Paper-equivalent evaluation notes")
    print(banner)
    print("""
Boag et al. 2018 used sklearn LogisticRegression(C=0.1, penalty='l1')
trained on 54,510 patients (all with interpersonal chartevents).
Equivalent PyHealth setup:

    model = LogisticRegression(
        dataset=sample_dataset,
        embedding_dim=128,
        l1_lambda=10 / len(train_dataset),  # = 1/(C * n_train), C=0.1
    )

Expected test AUC-ROC (paper Table 4 / PROGRESS.md):
    Noncompliance : 0.667
    Autopsy       : 0.531

Higher AUC than sklearn is possible because PyHealth uses learned
embeddings (128-dim) rather than 1-hot DictVectorizer features,
giving the model richer representations of the feature vocabulary.
""")


if __name__ == "__main__":
    # CLI entry point: either a real MIMIC-III root or --synthetic is required.
    parser = argparse.ArgumentParser(
        description="Mistrust prediction with PyHealth LogisticRegression + L1"
    )
    parser.add_argument(
        "--mimic3_root",
        type=str,
        default=None,
        help="Path to MIMIC-III v1.4 directory (required unless --synthetic)",
    )
    parser.add_argument(
        "--synthetic",
        action="store_true",
        help="Use synthetic MIMIC-III for pipeline smoke-test (no PhysioNet access needed)",
    )
    args = parser.parse_args()

    if args.mimic3_root is None and not args.synthetic:
        parser.error("Provide --mimic3_root or pass --synthetic for smoke-test mode")

    main(mimic3_root=args.mimic3_root, synthetic=args.synthetic)
27 changes: 22 additions & 5 deletions pyhealth/models/logistic_regression.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,22 +10,33 @@


class LogisticRegression(BaseModel):
"""Logistic/Linear regression baseline model.
"""Logistic/Linear regression baseline model with optional L1 regularization.

This model uses embeddings from different input features and applies a single
linear transformation (no hidden layers or non-linearity) to produce predictions.

- For classification tasks: acts as logistic regression
- For regression tasks: acts as linear regression

The model automatically handles different input types through the EmbeddingModel,
pools sequence dimensions, concatenates all feature embeddings, and applies a
final linear layer.

L1 regularization (``l1_lambda > 0``) adds a sparsity-inducing penalty to the
weight vector during training, equivalent to scikit-learn's
``LogisticRegression(penalty='l1', C=C)`` with ``l1_lambda = 1 / (C * n_train)``.
This is the formulation used in Boag et al. (2018) "Racial Disparities and
Mistrust in End-of-Life Care" (MLHC 2018) to train interpersonal-feature
mistrust classifiers on MIMIC-III.

Args:
dataset: the dataset to train the model. It is used to query certain
information such as the set of all tokens.
embedding_dim: the embedding dimension. Default is 128.
l1_lambda: coefficient for the L1 weight penalty added to the loss.
``loss = BCE + l1_lambda * ||W||_1``. Set to 0.0 (default) to
disable regularization (backward-compatible). Equivalent to
``1 / (C * n_train)`` for sklearn's C-parameterised formulation.
**kwargs: other parameters (for compatibility).

Examples:
Expand Down Expand Up @@ -55,7 +66,7 @@ class LogisticRegression(BaseModel):
... dataset_name="test")
>>>
>>> from pyhealth.models import LogisticRegression
>>> model = LogisticRegression(dataset=dataset)
>>> model = LogisticRegression(dataset=dataset, l1_lambda=1e-4)
>>>
>>> from pyhealth.datasets import get_dataloader
>>> train_loader = get_dataloader(dataset, batch_size=2, shuffle=True)
Expand All @@ -64,7 +75,7 @@ class LogisticRegression(BaseModel):
>>> ret = model(**data_batch)
>>> print(ret)
{
'loss': tensor(0.6931, grad_fn=<BinaryCrossEntropyWithLogitsBackward0>),
'loss': tensor(0.6931, grad_fn=<AddBackward0>),
'y_prob': tensor([[0.5123],
[0.4987]], grad_fn=<SigmoidBackward0>),
'y_true': tensor([[1.],
Expand All @@ -80,10 +91,12 @@ def __init__(
self,
dataset: SampleDataset,
embedding_dim: int = 128,
l1_lambda: float = 0.0,
**kwargs,
):
super(LogisticRegression, self).__init__(dataset)
self.embedding_dim = embedding_dim
self.l1_lambda = l1_lambda

assert len(self.label_keys) == 1, "Only one label key is supported"
self.label_key = self.label_keys[0]
Expand Down Expand Up @@ -197,6 +210,10 @@ def forward(self, **kwargs) -> Dict[str, torch.Tensor]:
# Obtain y_true, loss, y_prob
y_true = kwargs[self.label_key].to(self.device)
loss = self.get_loss_function()(logits, y_true)
# L1 regularization on the final linear layer's weights (bias excluded),
# equivalent to sklearn's penalty='l1' with C = 1 / (l1_lambda * n_train).
if self.l1_lambda > 0.0:
loss = loss + self.l1_lambda * self.fc.weight.abs().sum()
y_prob = self.prepare_y_prob(logits)

results = {
Expand Down
5 changes: 5 additions & 0 deletions pyhealth/tasks/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,3 +66,8 @@
VariantClassificationClinVar,
)
from .patient_linkage_mimic3 import PatientLinkageMIMIC3Task
from .mistrust_mimic3 import (
MistrustNoncomplianceMIMIC3,
MistrustAutopsyMIMIC3,
build_interpersonal_itemids,
)
Loading