diff --git a/test-resources/core/mimic3demo/NOTEEVENTS.csv.gz b/test-resources/core/mimic3demo/NOTEEVENTS.csv.gz new file mode 100644 index 000000000..3d87d0e7a Binary files /dev/null and b/test-resources/core/mimic3demo/NOTEEVENTS.csv.gz differ diff --git a/tests/core/test_mimic3_mortality_prediction.py b/tests/core/test_mimic3_mortality_prediction.py index d7105782b..0fd6fc169 100644 --- a/tests/core/test_mimic3_mortality_prediction.py +++ b/tests/core/test_mimic3_mortality_prediction.py @@ -131,7 +131,6 @@ def test_mortality_prediction_mimic3_set_task(self): traceback.print_exc() self.fail(f"Failed to use set_task with MortalityPredictionMIMIC3: {e}") - @unittest.skip("Skipping multimodal test - noteevents not included in test resources") def test_multimodal_mortality_prediction_mimic3_set_task(self): """Test MultimodalMortalityPredictionMIMIC3 task with set_task() method.""" task = MultimodalMortalityPredictionMIMIC3() @@ -144,44 +143,55 @@ def test_multimodal_mortality_prediction_mimic3_set_task(self): self.assertIn("clinical_notes", task.input_schema) self.assertIn("mortality", task.output_schema) + # Load dataset with noteevents for multimodal testing + multimodal_dataset = MIMIC3Dataset( + root=self.demo_dataset_path, + tables=["diagnoses_icd", "procedures_icd", "prescriptions", "noteevents"], + ) + # Test using set_task method try: - sample_dataset = self.dataset.set_task(task) + sample_dataset = multimodal_dataset.set_task(task) self.assertIsNotNone(sample_dataset, "set_task should return a dataset") - self.assertTrue( - hasattr(sample_dataset, "samples"), "Sample dataset should have samples" - ) # Verify we got some samples - self.assertGreater( - len(sample_dataset.samples), 0, "Should generate at least one sample" - ) + num_samples = len(sample_dataset) + self.assertGreater(num_samples, 0, "Should generate at least one sample") # Test sample structure - if len(sample_dataset.samples) > 0: - sample = sample_dataset.samples[0] - required_keys = [ - "hadm_id", - "patient_id", - "conditions", - "procedures", - "drugs", - "clinical_notes", - "mortality", - ] - for key in required_keys: - self.assertIn(key, sample, f"Sample should contain key: {key}") + sample = sample_dataset[0] + required_keys = [ + "hadm_id", + "patient_id", + "conditions", + "procedures", + "drugs", + "clinical_notes", + "mortality", + ] + for key in required_keys: + self.assertIn(key, sample, f"Sample should contain key: {key}") + + # Verify data types + self.assertIsInstance( + sample["clinical_notes"], str, "clinical_notes should be a string" + ) + self.assertIn( + int(sample["mortality"]), [0, 1], "Mortality label should be 0 or 1" + ) - # Verify data types - self.assertIsInstance( - sample["clinical_notes"], str, "clinical_notes should be a string" - ) - self.assertIn( - sample["mortality"], [0, 1], "Mortality label should be 0 or 1" - ) + # Verify that at least one sample has non-empty clinical notes, + # proving that NOTEEVENTS data was actually loaded + has_notes = any( + len(sample_dataset[i]["clinical_notes"]) > 0 + for i in range(num_samples) + ) + self.assertTrue( + has_notes, "At least one sample should have non-empty clinical notes" + ) - print(f"Generated {len(sample_dataset.samples)} multimodal samples") - print(f"Clinical notes length: {len(sample['clinical_notes'])}") + print(f"Generated {num_samples} multimodal samples") + print(f"Clinical notes length: {len(sample['clinical_notes'])}") except Exception as e: self.fail(