Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Binary file added test-resources/core/mimic3demo/NOTEEVENTS.csv.gz
Binary file not shown.
70 changes: 40 additions & 30 deletions tests/core/test_mimic3_mortality_prediction.py
Original file line number Diff line number Diff line change
Expand Up @@ -131,7 +131,6 @@ def test_mortality_prediction_mimic3_set_task(self):
traceback.print_exc()
self.fail(f"Failed to use set_task with MortalityPredictionMIMIC3: {e}")

@unittest.skip("Skipping multimodal test - noteevents not included in test resources")
def test_multimodal_mortality_prediction_mimic3_set_task(self):
"""Test MultimodalMortalityPredictionMIMIC3 task with set_task() method."""
task = MultimodalMortalityPredictionMIMIC3()
Expand All @@ -144,44 +143,55 @@ def test_multimodal_mortality_prediction_mimic3_set_task(self):
self.assertIn("clinical_notes", task.input_schema)
self.assertIn("mortality", task.output_schema)

# Load dataset with noteevents for multimodal testing
multimodal_dataset = MIMIC3Dataset(
root=self.demo_dataset_path,
tables=["diagnoses_icd", "procedures_icd", "prescriptions", "noteevents"],
)

# Test using set_task method
try:
sample_dataset = self.dataset.set_task(task)
sample_dataset = multimodal_dataset.set_task(task)
self.assertIsNotNone(sample_dataset, "set_task should return a dataset")
self.assertTrue(
hasattr(sample_dataset, "samples"), "Sample dataset should have samples"
)

# Verify we got some samples
self.assertGreater(
len(sample_dataset.samples), 0, "Should generate at least one sample"
)
num_samples = len(sample_dataset)
self.assertGreater(num_samples, 0, "Should generate at least one sample")

# Test sample structure
if len(sample_dataset.samples) > 0:
sample = sample_dataset.samples[0]
required_keys = [
"hadm_id",
"patient_id",
"conditions",
"procedures",
"drugs",
"clinical_notes",
"mortality",
]
for key in required_keys:
self.assertIn(key, sample, f"Sample should contain key: {key}")
sample = sample_dataset[0]
required_keys = [
"hadm_id",
"patient_id",
"conditions",
"procedures",
"drugs",
"clinical_notes",
"mortality",
]
for key in required_keys:
self.assertIn(key, sample, f"Sample should contain key: {key}")

# Verify data types
self.assertIsInstance(
sample["clinical_notes"], str, "clinical_notes should be a string"
)
self.assertIn(
int(sample["mortality"]), [0, 1], "Mortality label should be 0 or 1"
)

# Verify data types
self.assertIsInstance(
sample["clinical_notes"], str, "clinical_notes should be a string"
)
self.assertIn(
sample["mortality"], [0, 1], "Mortality label should be 0 or 1"
)
# Verify that at least one sample has non-empty clinical notes,
# proving that NOTEEVENTS data was actually loaded
has_notes = any(
len(sample_dataset[i]["clinical_notes"]) > 0
for i in range(num_samples)
)
self.assertTrue(
has_notes, "At least one sample should have non-empty clinical notes"
)

print(f"Generated {len(sample_dataset.samples)} multimodal samples")
print(f"Clinical notes length: {len(sample['clinical_notes'])}")
print(f"Generated {num_samples} multimodal samples")
print(f"Clinical notes length: {len(sample['clinical_notes'])}")

except Exception as e:
self.fail(
Expand Down
Loading