"""Tests for the ``DrugRecommendationMIMIC3`` task on the MIMIC-III demo data."""

from pathlib import Path
import tempfile
import unittest

from pyhealth.datasets import MIMIC3Dataset
from pyhealth.tasks import DrugRecommendationMIMIC3


class TestDrugRecommendationMIMIC3(unittest.TestCase):
    """Run ``DrugRecommendationMIMIC3`` once and check the generated samples."""

    @classmethod
    def setUpClass(cls):
        """Build the sample set shared by every test in this class."""
        cls.cache_dir = tempfile.TemporaryDirectory()
        # Register the cleanup before any dataset work: class cleanups run
        # after tearDownClass and also when setUpClass itself raises, so the
        # temporary cache directory is removed on every path.
        cls.addClassCleanup(cls.cache_dir.cleanup)

        dataset = MIMIC3Dataset(
            root=str(
                Path(__file__).parent.parent.parent
                / "test-resources"
                / "core"
                / "mimic3demo"
            ),
            tables=["diagnoses_icd", "procedures_icd", "prescriptions"],
            cache_dir=cls.cache_dir.name,
        )

        cls.samples = dataset.set_task(DrugRecommendationMIMIC3())

    @classmethod
    def tearDownClass(cls):
        """Close the sample set before the cache directory is removed."""
        cls.samples.close()

    def test_task_schema(self):
        """The task class declares its name and its input/output schemas."""
        task = DrugRecommendationMIMIC3

        # vars() only sees the class's own namespace, so these checks fail
        # if an attribute is merely inherited from a base class.
        for attr in ("task_name", "input_schema", "output_schema"):
            self.assertIn(attr, vars(task))

        self.assertEqual("DrugRecommendationMIMIC3", task.task_name)

        for key in ("conditions", "procedures", "drugs_hist"):
            self.assertIn(key, task.input_schema)
            self.assertEqual(task.input_schema[key], "nested_sequence")

        self.assertIn("drugs", task.output_schema)
        self.assertEqual(task.output_schema["drugs"], "multilabel")

    def test_sample_schema(self):
        """Every sample carries the identifier, feature and label keys."""
        expected_keys = (
            "patient_id",
            "visit_id",
            "conditions",
            "procedures",
            "drugs_hist",
            "drugs",
        )
        for sample in self.samples:
            for key in expected_keys:
                self.assertIn(key, sample)

    def test_conditions_are_nested(self):
        """Conditions should be a 2-D tensor (visits x codes)."""
        for sample in self.samples:
            cond = sample["conditions"]
            self.assertEqual(
                cond.dim(),
                2,
                "conditions should be a 2-D tensor (nested_sequence)",
            )

    def test_single_visit_patients_excluded(self):
        """Patient 10006 has only 1 visit (142345).

        Drug recommendation requires at least 2 visits.
        """
        patients = [s["patient_id"] for s in self.samples]
        visits = [s["visit_id"] for s in self.samples]

        self.assertNotIn("10006", patients)
        self.assertNotIn("142345", visits)

    def test_visit_without_procedures_excluded(self):
        """Patient 41795: visit 118192 has no procedures.

        Visits missing any of conditions, procedures, or drugs
        are excluded by the task.
        """
        visits = [s["visit_id"] for s in self.samples]
        self.assertNotIn("118192", visits)

    def test_multi_visit_patient_produces_samples(self):
        """Patient 10088 has 3 visits, all with diag+proc+rx.

        Should produce samples for this patient.
        """
        patients = [s["patient_id"] for s in self.samples]
        self.assertIn("10088", patients)


if __name__ == "__main__":
    unittest.main()
"""Tests for the ``MortalityPredictionMIMIC3`` task on the MIMIC-III demo data."""

from pathlib import Path
import tempfile
import unittest

from pyhealth.datasets import MIMIC3Dataset
from pyhealth.tasks import MortalityPredictionMIMIC3


class TestMortalityPredictionMIMIC3(unittest.TestCase):
    """Run ``MortalityPredictionMIMIC3`` once and check the generated samples."""

    @classmethod
    def setUpClass(cls):
        """Build the sample set shared by every test in this class."""
        cls.cache_dir = tempfile.TemporaryDirectory()
        # Register the cleanup before any dataset work: class cleanups run
        # after tearDownClass and also when setUpClass itself raises, so the
        # temporary cache directory is removed on every path.
        cls.addClassCleanup(cls.cache_dir.cleanup)

        dataset = MIMIC3Dataset(
            root=str(
                Path(__file__).parent.parent.parent
                / "test-resources"
                / "core"
                / "mimic3demo"
            ),
            tables=["diagnoses_icd", "procedures_icd", "prescriptions"],
            cache_dir=cls.cache_dir.name,
        )

        cls.samples = dataset.set_task(MortalityPredictionMIMIC3())

    @classmethod
    def tearDownClass(cls):
        """Close the sample set before the cache directory is removed."""
        cls.samples.close()

    def test_task_schema(self):
        """The task class declares its name and its input/output schemas."""
        task = MortalityPredictionMIMIC3

        # vars() only sees the class's own namespace, so these checks fail
        # if an attribute is merely inherited from a base class.
        for attr in ("task_name", "input_schema", "output_schema"):
            self.assertIn(attr, vars(task))

        self.assertEqual("MortalityPredictionMIMIC3", task.task_name)

        for key in ("conditions", "procedures", "drugs"):
            self.assertIn(key, task.input_schema)
        self.assertIn("mortality", task.output_schema)

    def test_sample_schema(self):
        """Every sample carries the identifier, feature and label keys."""
        expected_keys = (
            "patient_id",
            "hadm_id",
            "conditions",
            "procedures",
            "drugs",
            "mortality",
        )
        for sample in self.samples:
            for key in expected_keys:
                self.assertIn(key, sample)

    def test_mortality_label_is_binary(self):
        """Each mortality label converts to the integer 0 or 1."""
        for sample in self.samples:
            label = int(sample["mortality"].item())
            self.assertIn(label, [0, 1])

    def test_mortality_label_from_next_visit(self):
        """Patient 10059: visit 142582 (expire=0) then 122098 (expire=1).

        Mortality label is derived from the NEXT visit's expire flag,
        so visit 142582 should have mortality=1.
        """
        labels = [
            int(s["mortality"].item())
            for s in self.samples
            if s["hadm_id"] == "142582"
        ]

        self.assertEqual(len(labels), 1)
        self.assertEqual(labels[0], 1)

    def test_surviving_next_visit(self):
        """Patient 10119: visit 157466 (expire=0) then 165436 (expire=0).

        Next visit also survived, so mortality=0.
        """
        labels = [
            int(s["mortality"].item())
            for s in self.samples
            if s["hadm_id"] == "157466"
        ]

        self.assertEqual(len(labels), 1)
        self.assertEqual(labels[0], 0)

    def test_last_visit_excluded(self):
        """Patient 10059: last visit 122098 should not appear.

        The task drops the last visit because there is no next visit
        to derive the mortality label from.
        """
        visits = [s["hadm_id"] for s in self.samples]

        self.assertIn("142582", visits)
        self.assertNotIn("122098", visits)

    def test_single_visit_patients_excluded(self):
        """Patient 10006 has only 1 visit (142345).

        Patients with a single visit cannot produce mortality samples.
        """
        patients = [s["patient_id"] for s in self.samples]
        visits = [s["hadm_id"] for s in self.samples]

        self.assertNotIn("10006", patients)
        self.assertNotIn("142345", visits)

    def test_visit_without_procedures_excluded(self):
        """Patient 10117: visit 187023 has no procedures or prescriptions.

        Visits missing any of conditions, procedures, or drugs are
        excluded by the task.
        """
        visits = [s["hadm_id"] for s in self.samples]
        self.assertNotIn("187023", visits)


if __name__ == "__main__":
    unittest.main()