diff --git a/examples/benchmark_perf/benchmark_workers_1.py b/examples/benchmark_perf/benchmark_workers_1.py index 1106ac007..7549a1805 100644 --- a/examples/benchmark_perf/benchmark_workers_1.py +++ b/examples/benchmark_perf/benchmark_workers_1.py @@ -117,6 +117,7 @@ def main(): print("\n[1/2] Loading MIMIC-IV base dataset...") dataset_start = time.time() + base_cache_dir = f"{cache_root}/base_dataset" base_dataset = MIMIC4Dataset( ehr_root="/srv/local/data/physionet.org/files/mimiciv/2.2/", ehr_tables=[ @@ -127,7 +128,7 @@ def main(): "labevents", ], dev=dev, - cache_dir=f"{cache_root}/base_dataset", + cache_dir=base_cache_dir, ) dataset_time = time.time() - dataset_start @@ -140,7 +141,6 @@ def main(): sample_dataset = base_dataset.set_task( MortalityPredictionStageNetMIMIC4(), num_workers=1, - cache_dir=f"{cache_root}/task_samples", ) task_time = time.time() - task_start @@ -148,12 +148,10 @@ def main(): # Measure cache sizes print("\n[3/3] Measuring cache sizes...") - base_cache_dir = f"{cache_root}/base_dataset" - task_cache_dir = f"{cache_root}/task_samples" - base_cache_size = get_directory_size(base_cache_dir) - task_cache_size = get_directory_size(task_cache_dir) - total_cache_size = base_cache_size + task_cache_size + total_cache_size = get_directory_size(base_cache_dir) + task_cache_size = get_directory_size(f"{base_cache_dir}/tasks") + base_cache_size = total_cache_size - task_cache_size print(f"✓ Base dataset cache: {format_size(base_cache_size)}") print(f"✓ Task samples cache: {format_size(task_cache_size)}") diff --git a/examples/benchmark_perf/benchmark_workers_12.py b/examples/benchmark_perf/benchmark_workers_12.py index 01302fa02..207522d97 100644 --- a/examples/benchmark_perf/benchmark_workers_12.py +++ b/examples/benchmark_perf/benchmark_workers_12.py @@ -117,6 +117,7 @@ def main(): print("\n[1/2] Loading MIMIC-IV base dataset...") dataset_start = time.time() + base_cache_dir = f"{cache_root}/base_dataset" base_dataset = MIMIC4Dataset( ehr_root="/srv/local/data/physionet.org/files/mimiciv/2.2/", ehr_tables=[ @@ -127,7 +128,7 @@ def main(): "labevents", ], dev=dev, - cache_dir=f"{cache_root}/base_dataset", + cache_dir=base_cache_dir, ) dataset_time = time.time() - dataset_start @@ -140,7 +141,6 @@ def main(): sample_dataset = base_dataset.set_task( MortalityPredictionStageNetMIMIC4(), num_workers=12, - cache_dir=f"{cache_root}/task_samples", ) task_time = time.time() - task_start @@ -148,12 +148,10 @@ def main(): # Measure cache sizes print("\n[3/3] Measuring cache sizes...") - base_cache_dir = f"{cache_root}/base_dataset" - task_cache_dir = f"{cache_root}/task_samples" - base_cache_size = get_directory_size(base_cache_dir) - task_cache_size = get_directory_size(task_cache_dir) - total_cache_size = base_cache_size + task_cache_size + total_cache_size = get_directory_size(base_cache_dir) + task_cache_size = get_directory_size(f"{base_cache_dir}/tasks") + base_cache_size = total_cache_size - task_cache_size print(f"✓ Base dataset cache: {format_size(base_cache_size)}") print(f"✓ Task samples cache: {format_size(task_cache_size)}") diff --git a/examples/benchmark_perf/benchmark_workers_4.py b/examples/benchmark_perf/benchmark_workers_4.py index 82b56059a..bcae29835 100644 --- a/examples/benchmark_perf/benchmark_workers_4.py +++ b/examples/benchmark_perf/benchmark_workers_4.py @@ -117,6 +117,7 @@ def main(): print("\n[1/2] Loading MIMIC-IV base dataset...") dataset_start = time.time() + base_cache_dir = f"{cache_root}/base_dataset" base_dataset = MIMIC4Dataset( ehr_root="/srv/local/data/physionet.org/files/mimiciv/2.2/", ehr_tables=[ @@ -127,7 +128,7 @@ def main(): "labevents", ], dev=dev, - # cache_dir=f"{cache_root}/base_dataset", + cache_dir=base_cache_dir, ) dataset_time = time.time() - dataset_start @@ -140,7 +141,6 @@ def main(): sample_dataset = base_dataset.set_task( MortalityPredictionStageNetMIMIC4(), num_workers=4, - cache_dir=f"{cache_root}/task_samples_old", ) task_time = time.time() - task_start @@ -148,12 +148,10 @@ def main(): # Measure cache sizes print("\n[3/3] Measuring cache sizes...") - base_cache_dir = f"{cache_root}/base_dataset" - task_cache_dir = f"{cache_root}/task_samples" - base_cache_size = get_directory_size(base_cache_dir) - task_cache_size = get_directory_size(task_cache_dir) - total_cache_size = base_cache_size + task_cache_size + total_cache_size = get_directory_size(base_cache_dir) + task_cache_size = get_directory_size(f"{base_cache_dir}/tasks") + base_cache_size = total_cache_size - task_cache_size print(f"✓ Base dataset cache: {format_size(base_cache_size)}") print(f"✓ Task samples cache: {format_size(task_cache_size)}") diff --git a/examples/benchmark_perf/benchmark_workers_n.py b/examples/benchmark_perf/benchmark_workers_n.py index 08454e16b..6b53091cb 100644 --- a/examples/benchmark_perf/benchmark_workers_n.py +++ b/examples/benchmark_perf/benchmark_workers_n.py @@ -159,13 +159,6 @@ def parse_workers(value: str) -> list[int]: return workers -def ensure_empty_dir(path: str | Path) -> None: - p = Path(path) - if p.exists(): - shutil.rmtree(p) - p.mkdir(parents=True, exist_ok=True) - - def remove_dir(path: str | Path, retries: int = 3, delay: float = 1.0) -> None: """Remove a directory with retry logic for busy file handles.""" p = Path(path) @@ -282,11 +275,8 @@ def main() -> None: print("\n[1/1] Sweeping num_workers (each run reloads dataset + task)...") for w in args.workers: for r in range(args.repeats): - task_cache_dir = cache_root / f"task_samples_w{w}" - # Ensure no cache artifacts before this run. remove_dir(base_cache_dir) - ensure_empty_dir(task_cache_dir) tracker.reset() run_start = time.time() @@ -311,13 +301,13 @@ def main() -> None: sample_dataset = base_dataset.set_task( MortalityPredictionStageNetMIMIC4(), num_workers=w, - cache_dir=str(task_cache_dir), ) task_process_s = time.time() - task_start total_s = time.time() - run_start peak_rss_bytes = tracker.peak_bytes() - task_cache_bytes = get_directory_size(task_cache_dir) + tasks_dir = base_cache_dir / "tasks" + task_cache_bytes = get_directory_size(tasks_dir) # Capture sample count BEFORE cleaning up the cache (litdata needs it). num_samples = len(sample_dataset) @@ -327,7 +317,6 @@ def main() -> None: del base_dataset # Clean up to avoid disk growth across an overnight sweep. - remove_dir(task_cache_dir) remove_dir(base_cache_dir) results.append( diff --git a/examples/benchmark_perf/benchmark_workers_n_drug_recommendation.py b/examples/benchmark_perf/benchmark_workers_n_drug_recommendation.py index 7cfe46e46..2d12f746e 100644 --- a/examples/benchmark_perf/benchmark_workers_n_drug_recommendation.py +++ b/examples/benchmark_perf/benchmark_workers_n_drug_recommendation.py @@ -159,13 +159,6 @@ def parse_workers(value: str) -> list[int]: return workers -def ensure_empty_dir(path: str | Path) -> None: - p = Path(path) - if p.exists(): - shutil.rmtree(p) - p.mkdir(parents=True, exist_ok=True) - - def remove_dir(path: str | Path, retries: int = 3, delay: float = 1.0) -> None: """Remove a directory with retry logic for busy file handles.""" p = Path(path) @@ -284,11 +277,8 @@ def main() -> None: print("\n[1/1] Sweeping num_workers (each run reloads dataset + task)...") for w in args.workers: for r in range(args.repeats): - task_cache_dir = cache_root / f"task_samples_drug_rec_w{w}" - # Ensure no cache artifacts before this run. remove_dir(base_cache_dir) - ensure_empty_dir(task_cache_dir) tracker.reset() run_start = time.time() @@ -313,13 +303,13 @@ def main() -> None: sample_dataset = base_dataset.set_task( DrugRecommendationMIMIC4(), num_workers=w, - cache_dir=str(task_cache_dir), ) task_process_s = time.time() - task_start total_s = time.time() - run_start peak_rss_bytes = tracker.peak_bytes() - task_cache_bytes = get_directory_size(task_cache_dir) + tasks_dir = base_cache_dir / "tasks" + task_cache_bytes = get_directory_size(tasks_dir) # Capture sample count BEFORE cleaning up the cache (litdata needs it). num_samples = len(sample_dataset) @@ -329,7 +319,6 @@ def main() -> None: del base_dataset # Clean up to avoid disk growth across an overnight sweep. - remove_dir(task_cache_dir) remove_dir(base_cache_dir) results.append( diff --git a/examples/benchmark_perf/benchmark_workers_n_length_of_stay.py b/examples/benchmark_perf/benchmark_workers_n_length_of_stay.py index b8603e01f..79c2a6f23 100644 --- a/examples/benchmark_perf/benchmark_workers_n_length_of_stay.py +++ b/examples/benchmark_perf/benchmark_workers_n_length_of_stay.py @@ -159,13 +159,6 @@ def parse_workers(value: str) -> list[int]: return workers -def ensure_empty_dir(path: str | Path) -> None: - p = Path(path) - if p.exists(): - shutil.rmtree(p) - p.mkdir(parents=True, exist_ok=True) - - def remove_dir(path: str | Path, retries: int = 3, delay: float = 1.0) -> None: """Remove a directory with retry logic for busy file handles.""" p = Path(path) @@ -284,11 +277,8 @@ def main() -> None: print("\n[1/1] Sweeping num_workers (each run reloads dataset + task)...") for w in args.workers: for r in range(args.repeats): - task_cache_dir = cache_root / f"task_samples_los_w{w}" - # Ensure no cache artifacts before this run. remove_dir(base_cache_dir) - ensure_empty_dir(task_cache_dir) tracker.reset() run_start = time.time() @@ -313,13 +303,13 @@ def main() -> None: sample_dataset = base_dataset.set_task( LengthOfStayPredictionMIMIC4(), num_workers=w, - cache_dir=str(task_cache_dir), ) task_process_s = time.time() - task_start total_s = time.time() - run_start peak_rss_bytes = tracker.peak_bytes() - task_cache_bytes = get_directory_size(task_cache_dir) + tasks_dir = base_cache_dir / "tasks" + task_cache_bytes = get_directory_size(tasks_dir) # Capture sample count BEFORE cleaning up the cache (litdata needs it). num_samples = len(sample_dataset) @@ -329,7 +319,6 @@ def main() -> None: del base_dataset # Clean up to avoid disk growth across an overnight sweep. - remove_dir(task_cache_dir) remove_dir(base_cache_dir) results.append( diff --git a/examples/clinical_tasks/dka_mimic4.py b/examples/clinical_tasks/dka_mimic4.py index e83f0dbbb..d33a7c0ac 100644 --- a/examples/clinical_tasks/dka_mimic4.py +++ b/examples/clinical_tasks/dka_mimic4.py @@ -31,18 +31,17 @@ def main(): """Main function to run DKA prediction pipeline on general population.""" - + # Configuration MIMIC4_ROOT = "/srv/local/data/physionet.org/files/mimiciv/2.2/" DATASET_CACHE_DIR = "/shared/rsaas/pyhealth/cache/mimic4_dataset" - TASK_CACHE_DIR = "/shared/rsaas/pyhealth/cache/mimic4_dka_general_stagenet" PROCESSOR_DIR = "/shared/rsaas/pyhealth/processors/stagenet_dka_general_mimic4" DEVICE = "cuda:5" if torch.cuda.is_available() else "cpu" - + print("=" * 60) print("DKA PREDICTION (GENERAL POPULATION) WITH STAGENET ON MIMIC-IV") print("=" * 60) - + # STEP 1: Load MIMIC-IV base dataset print("\n=== Step 1: Loading MIMIC-IV Dataset ===") base_dataset = MIMIC4Dataset( @@ -56,30 +55,29 @@ def main(): cache_dir=DATASET_CACHE_DIR, # dev=True, # Uncomment for faster development iteration ) - + print("Dataset initialized, proceeding to task processing...") - + # STEP 2: Apply DKA prediction task (general population) print("\n=== Step 2: Applying DKA Prediction Task (General Population) ===") - + # Create task with padding for unseen sequences # No T1DM filtering - includes ALL patients dka_task = DKAPredictionMIMIC4(padding=10) - + print(f"Task: {dka_task.task_name}") print(f"Input schema: {list(dka_task.input_schema.keys())}") print(f"Output schema: {list(dka_task.output_schema.keys())}") print("Note: This includes ALL patients (not just diabetics)") - + # Check for pre-fitted processors if os.path.exists(os.path.join(PROCESSOR_DIR, "input_processors.pkl")): print("\nLoading pre-fitted processors...") input_processors, output_processors = load_processors(PROCESSOR_DIR) - + sample_dataset = base_dataset.set_task( dka_task, num_workers=4, - cache_dir=TASK_CACHE_DIR, input_processors=input_processors, output_processors=output_processors, ) @@ -88,25 +86,24 @@ def main(): sample_dataset = base_dataset.set_task( dka_task, num_workers=4, - cache_dir=TASK_CACHE_DIR, ) - + # Save processors for future runs print("Saving processors...") os.makedirs(PROCESSOR_DIR, exist_ok=True) save_processors(sample_dataset, PROCESSOR_DIR) - + print(f"\nTotal samples: {len(sample_dataset)}") - + # Count label distribution label_counts = {0: 0, 1: 0} for sample in sample_dataset: label_counts[int(sample["label"].item())] += 1 - + print(f"Label distribution:") print(f" No DKA (0): {label_counts[0]} ({100*label_counts[0]/len(sample_dataset):.1f}%)") print(f" Has DKA (1): {label_counts[1]} ({100*label_counts[1]/len(sample_dataset):.1f}%)") - + # Inspect a sample sample = sample_dataset[0] print("\nSample structure:") @@ -114,22 +111,22 @@ def main(): print(f" ICD codes (diagnoses + procedures): {sample['icd_codes'][1].shape} (visits x codes)") print(f" Labs: {sample['labs'][0].shape} (timesteps x features)") print(f" Label: {sample['label']}") - + # STEP 3: Split dataset print("\n=== Step 3: Splitting Dataset ===") train_dataset, val_dataset, test_dataset = split_by_patient( sample_dataset, [0.8, 0.1, 0.1] ) - + print(f"Train: {len(train_dataset)} samples") print(f"Validation: {len(val_dataset)} samples") print(f"Test: {len(test_dataset)} samples") - + # Create dataloaders train_loader = get_dataloader(train_dataset, batch_size=256, shuffle=True) val_loader = get_dataloader(val_dataset, batch_size=256, shuffle=False) test_loader = get_dataloader(test_dataset, batch_size=256, shuffle=False) - + # STEP 4: Initialize StageNet model print("\n=== Step 4: Initializing StageNet Model ===") model = StageNet( @@ -139,10 +136,10 @@ def main(): levels=3, dropout=0.3, ) - + num_params = sum(p.numel() for p in model.parameters()) print(f"Model parameters: {num_params:,}") - + # STEP 5: Train the model print("\n=== Step 5: Training Model ===") trainer = Trainer( @@ -150,7 +147,7 @@ def main(): device=DEVICE, metrics=["pr_auc", "roc_auc", "accuracy", "f1"], ) - + trainer.train( train_dataloader=train_loader, val_dataloader=val_loader, @@ -158,27 +155,27 @@ def main(): monitor="roc_auc", optimizer_params={"lr": 1e-5}, ) - + # STEP 6: Evaluate on test set print("\n=== Step 6: Evaluation ===") results = trainer.evaluate(test_loader) print("\nTest Results:") for metric, value in results.items(): print(f" {metric}: {value:.4f}") - + # STEP 7: Inspect model predictions print("\n=== Step 7: Sample Predictions ===") sample_batch = next(iter(test_loader)) with torch.no_grad(): output = model(**sample_batch) - + print(f"Predicted probabilities: {output['y_prob'][:5]}") print(f"True labels: {output['y_true'][:5]}") - + print("\n" + "=" * 60) print("DKA PREDICTION (GENERAL POPULATION) TRAINING COMPLETED!") print("=" * 60) - + return results diff --git a/examples/clinical_tasks/dka_mimic4_stageattn.py b/examples/clinical_tasks/dka_mimic4_stageattn.py index 852b7afce..2f2f5a5a8 100644 --- a/examples/clinical_tasks/dka_mimic4_stageattn.py +++ b/examples/clinical_tasks/dka_mimic4_stageattn.py @@ -43,7 +43,6 @@ # - This ensures consistent encoding and saves computation time # - Processors include vocabulary mappings and sequence length statistics processor_dir = "/home/yongdaf2/dka_sa/processors" - cache_dir = "/home/yongdaf2/dka_sa/cache" if os.path.exists(os.path.join(processor_dir, "input_processors.pkl")): print("\n=== Loading Pre-fitted Processors ===") @@ -52,7 +51,6 @@ sample_dataset = base_dataset.set_task( DKAPredictionMIMIC4(padding=20), num_workers=16, - cache_dir=cache_dir, input_processors=input_processors, output_processors=output_processors, ) @@ -61,7 +59,6 @@ sample_dataset = base_dataset.set_task( DKAPredictionMIMIC4(padding=20), num_workers=16, - cache_dir=cache_dir, ) # Save processors for future runs diff --git a/examples/clinical_tasks/dka_mimic4_stagenet.py b/examples/clinical_tasks/dka_mimic4_stagenet.py index 7ba68b353..40a30d03f 100644 --- a/examples/clinical_tasks/dka_mimic4_stagenet.py +++ b/examples/clinical_tasks/dka_mimic4_stagenet.py @@ -43,7 +43,6 @@ # - This ensures consistent encoding and saves computation time # - Processors include vocabulary mappings and sequence length statistics processor_dir = "/home/yongdaf2/dka_sn/processors" - cache_dir = "/home/yongdaf2/dka_sn/cache" if os.path.exists(os.path.join(processor_dir, "input_processors.pkl")): print("\n=== Loading Pre-fitted Processors ===") @@ -52,7 +51,6 @@ sample_dataset = base_dataset.set_task( DKAPredictionMIMIC4(padding=20), num_workers=16, - cache_dir=cache_dir, input_processors=input_processors, output_processors=output_processors, ) @@ -61,7 +59,6 @@ sample_dataset = base_dataset.set_task( DKAPredictionMIMIC4(padding=20), num_workers=16, - cache_dir=cache_dir, ) # Save processors for future runs diff --git a/examples/clinical_tasks/dka_mimic4_transformer.py b/examples/clinical_tasks/dka_mimic4_transformer.py index 65056ef3c..c17af6376 100644 --- a/examples/clinical_tasks/dka_mimic4_transformer.py +++ b/examples/clinical_tasks/dka_mimic4_transformer.py @@ -43,7 +43,6 @@ # - This ensures consistent encoding and saves computation time # - Processors include vocabulary mappings and sequence length statistics processor_dir = "/home/yongdaf2/dka_tf/processors" - cache_dir = "/home/yongdaf2/dka_tf/cache" if os.path.exists(os.path.join(processor_dir, "input_processors.pkl")): print("\n=== Loading Pre-fitted Processors ===") @@ -52,7 +51,6 @@ sample_dataset = base_dataset.set_task( DKAPredictionMIMIC4(padding=20), num_workers=16, - cache_dir=cache_dir, input_processors=input_processors, output_processors=output_processors, ) @@ -61,7 +59,6 @@ sample_dataset = base_dataset.set_task( DKAPredictionMIMIC4(padding=20), num_workers=16, - cache_dir=cache_dir, ) # Save processors for future runs diff --git a/examples/clinical_tasks/t1d_mimic4_stageattn.py b/examples/clinical_tasks/t1d_mimic4_stageattn.py index 21682d95f..d53b17a85 100644 --- a/examples/clinical_tasks/t1d_mimic4_stageattn.py +++ b/examples/clinical_tasks/t1d_mimic4_stageattn.py @@ -43,7 +43,6 @@ # - This ensures consistent encoding and saves computation time # - Processors include vocabulary mappings and sequence length statistics processor_dir = "/home/yongdaf2/t1d_sa/processors" - cache_dir = "/home/yongdaf2/t1d_sa/cache" if os.path.exists(os.path.join(processor_dir, "input_processors.pkl")): print("\n=== Loading Pre-fitted Processors ===") @@ -52,7 +51,6 @@ sample_dataset = base_dataset.set_task( T1DDKAPredictionMIMIC4(padding=20), num_workers=16, - cache_dir=cache_dir, input_processors=input_processors, output_processors=output_processors, ) @@ -61,7 +59,6 @@ sample_dataset = base_dataset.set_task( T1DDKAPredictionMIMIC4(padding=20), num_workers=16, - cache_dir=cache_dir, ) # Save processors for future runs diff --git a/examples/clinical_tasks/t1d_mimic4_stagenet.py b/examples/clinical_tasks/t1d_mimic4_stagenet.py index 6e3f3724a..3364b59b7 100644 --- a/examples/clinical_tasks/t1d_mimic4_stagenet.py +++ b/examples/clinical_tasks/t1d_mimic4_stagenet.py @@ -43,7 +43,6 @@ # - This ensures consistent encoding and saves computation time # - Processors include vocabulary mappings and sequence length statistics processor_dir = "/home/yongdaf2/t1d_sn/processors" - cache_dir = "/home/yongdaf2/t1d_sn/cache" if os.path.exists(os.path.join(processor_dir, "input_processors.pkl")): print("\n=== Loading Pre-fitted Processors ===") @@ -52,7 +51,6 @@ sample_dataset = base_dataset.set_task( T1DDKAPredictionMIMIC4(padding=20), num_workers=16, - cache_dir=cache_dir, input_processors=input_processors, output_processors=output_processors, ) @@ -61,7 +59,6 @@ sample_dataset = base_dataset.set_task( T1DDKAPredictionMIMIC4(padding=20), num_workers=16, - cache_dir=cache_dir, ) # Save processors for future runs diff --git a/examples/clinical_tasks/t1dka_mimic4.py b/examples/clinical_tasks/t1dka_mimic4.py index 9d3705650..8bc56fc27 100644 --- a/examples/clinical_tasks/t1dka_mimic4.py +++ b/examples/clinical_tasks/t1dka_mimic4.py @@ -29,18 +29,17 @@ def main(): """Main function to run T1D DKA prediction pipeline.""" - + # Configuration MIMIC4_ROOT = "/srv/local/data/physionet.org/files/mimiciv/2.2/" DATASET_CACHE_DIR = "/shared/rsaas/pyhealth/cache/mimic4_dataset" - TASK_CACHE_DIR = "/shared/rsaas/pyhealth/cache/mimic4_t1d_dka_stagenet_v2" PROCESSOR_DIR = "/shared/rsaas/pyhealth/processors/stagenet_t1d_dka_mimic4_v2" DEVICE = "cuda:5" if torch.cuda.is_available() else "cpu" - + print("=" * 60) print("T1D DKA PREDICTION WITH STAGENET ON MIMIC-IV") print("=" * 60) - + # STEP 1: Load MIMIC-IV base dataset print("\n=== Step 1: Loading MIMIC-IV Dataset ===") base_dataset = MIMIC4Dataset( @@ -54,29 +53,28 @@ def main(): cache_dir=DATASET_CACHE_DIR, # dev=True, # Uncomment for faster development iteration ) - + print("Dataset initialized, proceeding to task processing...") - + # STEP 2: Apply T1D DKA prediction task print("\n=== Step 2: Applying T1D DKA Prediction Task ===") - + # Create task with 90-day DKA window and padding for unseen sequences dka_task = T1DDKAPredictionMIMIC4(dka_window_days=90, padding=20) - + print(f"Task: {dka_task.task_name}") print(f"DKA window: {dka_task.dka_window_days} days") print(f"Input schema: {list(dka_task.input_schema.keys())}") print(f"Output schema: {list(dka_task.output_schema.keys())}") - + # Check for pre-fitted processors if os.path.exists(os.path.join(PROCESSOR_DIR, "input_processors.pkl")): print("\nLoading pre-fitted processors...") input_processors, output_processors = load_processors(PROCESSOR_DIR) - + sample_dataset = base_dataset.set_task( dka_task, num_workers=4, - cache_dir=TASK_CACHE_DIR, input_processors=input_processors, output_processors=output_processors, ) @@ -85,25 +83,24 @@ def main(): sample_dataset = base_dataset.set_task( dka_task, num_workers=4, - cache_dir=TASK_CACHE_DIR, ) - + # Save processors for future runs print("Saving processors...") os.makedirs(PROCESSOR_DIR, exist_ok=True) save_processors(sample_dataset, PROCESSOR_DIR) - + print(f"\nTotal samples: {len(sample_dataset)}") - + # Count label distribution label_counts = {0: 0, 1: 0} for sample in sample_dataset: label_counts[int(sample["label"].item())] += 1 - + print(f"Label distribution:") print(f" No DKA (0): {label_counts[0]} ({100*label_counts[0]/len(sample_dataset):.1f}%)") print(f" Has DKA (1): {label_counts[1]} ({100*label_counts[1]/len(sample_dataset):.1f}%)") - + # Inspect a sample sample = sample_dataset[0] print("\nSample structure:") @@ -111,22 +108,22 @@ def main(): print(f" ICD codes (diagnoses + procedures): {sample['icd_codes'][1].shape} (visits x codes)") print(f" Labs: {sample['labs'][0].shape} (timesteps x features)") print(f" Label: {sample['label']}") - + # STEP 3: Split dataset print("\n=== Step 3: Splitting Dataset ===") train_dataset, val_dataset, test_dataset = split_by_patient( sample_dataset, [0.8, 0.1, 0.1] ) - + print(f"Train: {len(train_dataset)} samples") print(f"Validation: {len(val_dataset)} samples") print(f"Test: {len(test_dataset)} samples") - + # Create dataloaders train_loader = get_dataloader(train_dataset, batch_size=256, shuffle=True) val_loader = get_dataloader(val_dataset, batch_size=256, shuffle=False) test_loader = get_dataloader(test_dataset, batch_size=256, shuffle=False) - + # STEP 4: Initialize StageNet model print("\n=== Step 4: Initializing StageNet Model ===") model = StageNet( @@ -136,10 +133,10 @@ def main(): levels=3, dropout=0.3, ) - + num_params = sum(p.numel() for p in model.parameters()) print(f"Model parameters: {num_params:,}") - + # STEP 5: Train the model print("\n=== Step 5: Training Model ===") trainer = Trainer( @@ -147,7 +144,7 @@ def main(): device=DEVICE, metrics=["pr_auc", "roc_auc", "accuracy", "f1"], ) - + trainer.train( train_dataloader=train_loader, val_dataloader=val_loader, @@ -155,27 +152,27 @@ def main(): monitor="roc_auc", optimizer_params={"lr": 1e-5}, ) - + # STEP 6: Evaluate on test set print("\n=== Step 6: Evaluation ===") results = trainer.evaluate(test_loader) print("\nTest Results:") for metric, value in results.items(): print(f" {metric}: {value:.4f}") - + # STEP 7: Inspect model predictions print("\n=== Step 7: Sample Predictions ===") sample_batch = next(iter(test_loader)) with torch.no_grad(): output = model(**sample_batch) - + print(f"Predicted probabilities: {output['y_prob'][:5]}") print(f"True labels: {output['y_true'][:5]}") - + print("\n" + "=" * 60) print("T1D DKA PREDICTION TRAINING COMPLETED!") print("=" * 60) - + return results diff --git a/examples/concare_mimic4_example.ipynb b/examples/concare_mimic4_example.ipynb index 188858ecb..5d4afef0a 100644 --- a/examples/concare_mimic4_example.ipynb +++ b/examples/concare_mimic4_example.ipynb @@ -14,12 +14,12 @@ "This example demonstrates how to train the ConCare model for in-hospital mortality\n", "prediction using the MIMIC-IV dataset.\n", "\n", - "ConCare (Concare: Personalized clinical feature embedding via capturing the \n", - "healthcare context) is a model that uses channel-wise GRUs and multi-head \n", + "ConCare (Concare: Personalized clinical feature embedding via capturing the\n", + "healthcare context) is a model that uses channel-wise GRUs and multi-head\n", "self-attention to capture feature correlations and temporal patterns in EHR data.\n", "\n", "Reference:\n", - " Liantao Ma et al. Concare: Personalized clinical feature embedding via \n", + " Liantao Ma et al. Concare: Personalized clinical feature embedding via\n", " capturing the healthcare context. AAAI 2020.\n", "\"\"\"" ] @@ -72,9 +72,8 @@ "\n", "# Apply task to dataset and create samples\n", "samples = dataset.set_task(\n", - " task, \n", - " num_workers=10, \n", - " cache_dir=\"./cache_concare_mortality_m4\"\n", + " task,\n", + " num_workers=10,\n", ")" ] }, @@ -117,7 +116,7 @@ "\n", "# Split dataset into train, validation, and test sets\n", "train_dataset, val_dataset, test_dataset = split_by_sample(\n", - " dataset=samples, \n", + " dataset=samples,\n", " ratios=[0.7, 0.1, 0.2]\n", ")\n", "\n", @@ -197,7 +196,7 @@ "\n", "# Initialize trainer with ROC-AUC metric\n", "trainer = Trainer(\n", - " model=model, \n", + " model=model,\n", " metrics=[\"roc_auc\", \"pr_auc\", \"accuracy\"]\n", ")\n", "\n", @@ -340,7 +339,7 @@ ")\n", "\n", "task = InHospitalMortalityMIMIC3()\n", - "samples = dataset.set_task(task, num_workers=10, cache_dir=\"./cache_concare_mortality_m3\")\n", + "samples = dataset.set_task(task, num_workers=10)\n", "\n", "# The rest of the code remains the same\n", "\"\"\"" diff --git a/examples/conformal_eeg/tuev_conventional_conformal.py b/examples/conformal_eeg/tuev_conventional_conformal.py index 3f067b2e0..b62dc73d1 100644 --- a/examples/conformal_eeg/tuev_conventional_conformal.py +++ b/examples/conformal_eeg/tuev_conventional_conformal.py @@ -37,7 +37,7 @@ def parse_args() -> argparse.Namespace: default="downloads/tuev/v2.0.1/edf", help="Path to TUEV edf/ folder.", ) - parser.add_argument("--subset", type=str, default="both", choices=["train", "eval", "both"]) + parser.add_argument("--subset", type=str, default="both", choices=["train", "eval", "both"]) parser.add_argument("--seed", type=int, default=42) parser.add_argument("--batch-size", type=int, default=32) parser.add_argument("--epochs", type=int, default=2) @@ -84,7 +84,7 @@ def main() -> None: print("STEP 1: Load TUEV + build task dataset") print("=" * 80) dataset = TUEVDataset(root=str(root), subset=args.subset) - sample_dataset = dataset.set_task(EEGEventsTUEV(), cache_dir="examples/conformal_eeg/cache") + sample_dataset = dataset.set_task(EEGEventsTUEV()) print(f"Task samples: {len(sample_dataset)}") print(f"Input schema: {sample_dataset.input_schema}") diff --git a/examples/conformal_eeg/tuev_covariate_shift_conformal.py b/examples/conformal_eeg/tuev_covariate_shift_conformal.py index 1afcdf08a..41a356a79 100644 --- a/examples/conformal_eeg/tuev_covariate_shift_conformal.py +++ b/examples/conformal_eeg/tuev_covariate_shift_conformal.py @@ -42,7 +42,7 @@ def parse_args() -> argparse.Namespace: default="downloads/tuev/v2.0.1/edf", help="Path to TUEV edf/ folder.", ) - parser.add_argument("--subset", type=str, default="both", choices=["train", "eval", "both"]) + parser.add_argument("--subset", type=str, default="both", choices=["train", "eval", "both"]) parser.add_argument("--seed", type=int, default=42) parser.add_argument("--batch-size", type=int, default=32) parser.add_argument("--epochs", type=int, default=3) @@ -89,7 +89,7 @@ def main() -> None: print("STEP 1: Load TUEV + build task dataset") print("=" * 80) dataset = TUEVDataset(root=str(root), subset=args.subset) - sample_dataset = dataset.set_task(EEGEventsTUEV(), cache_dir="examples/conformal_eeg/cache") + sample_dataset = dataset.set_task(EEGEventsTUEV()) print(f"Task samples: {len(sample_dataset)}") print(f"Input schema: {sample_dataset.input_schema}") diff --git a/examples/conformal_eeg/tuev_kmeans_conformal.py b/examples/conformal_eeg/tuev_kmeans_conformal.py index b5a043dad..906883d72 100644 --- a/examples/conformal_eeg/tuev_kmeans_conformal.py +++ b/examples/conformal_eeg/tuev_kmeans_conformal.py @@ -149,7 +149,7 @@ def _run(args: argparse.Namespace) -> None: print("STEP 1: Load TUEV + build task dataset") print("=" * 80) dataset = TUEVDataset(root=str(root), subset=args.subset, dev=args.quick_test) - sample_dataset = dataset.set_task(EEGEventsTUEV(), cache_dir="examples/conformal_eeg/cache") + sample_dataset = dataset.set_task(EEGEventsTUEV()) if args.quick_test and len(sample_dataset) > quick_test_max_samples: sample_dataset = sample_dataset.subset(range(quick_test_max_samples)) print(f"Capped to {quick_test_max_samples} samples for quick-test.") diff --git a/examples/conformal_eeg/tuev_ncp_conformal.py b/examples/conformal_eeg/tuev_ncp_conformal.py index 9b51a7756..c5e207a6a 100644 --- a/examples/conformal_eeg/tuev_ncp_conformal.py +++ b/examples/conformal_eeg/tuev_ncp_conformal.py @@ -290,7 +290,7 @@ def _run(args: argparse.Namespace) -> None: print("STEP 1: Load TUEV + build task dataset") print("=" * 80) dataset = TUEVDataset(root=str(root), subset=args.subset, dev=args.quick_test) - sample_dataset = dataset.set_task(EEGEventsTUEV(), cache_dir="examples/conformal_eeg/cache") + sample_dataset = dataset.set_task(EEGEventsTUEV()) if args.quick_test and len(sample_dataset) > quick_test_max_samples: sample_dataset = sample_dataset.subset(range(quick_test_max_samples)) print(f"Capped to {quick_test_max_samples} samples for quick-test.") diff --git a/examples/cxr/covid19cxr_conformal.py b/examples/cxr/covid19cxr_conformal.py index 4adb2dd8d..1d1bf8508 100644 --- a/examples/cxr/covid19cxr_conformal.py +++ b/examples/cxr/covid19cxr_conformal.py @@ -35,7 +35,7 @@ root = "/home/johnwu3/projects/PyHealth_Branch_Testing/datasets/COVID-19_Radiography_Dataset" base_dataset = COVID19CXRDataset(root) -sample_dataset = base_dataset.set_task(cache_dir="../../covid19cxr_cache") +sample_dataset = base_dataset.set_task() print(f"Total samples: {len(sample_dataset)}") print(f"Task mode: {sample_dataset.output_schema}") diff --git a/examples/cxr/covid19cxr_tutorial.ipynb b/examples/cxr/covid19cxr_tutorial.ipynb index 4eaf64aeb..2a04844c5 100644 --- a/examples/cxr/covid19cxr_tutorial.ipynb +++ b/examples/cxr/covid19cxr_tutorial.ipynb @@ -142,11 +142,10 @@ "source": [ "root = \"/home/johnwu3/projects/PyHealth_Branch_Testing/datasets/COVID-19_Radiography_Dataset\"\n", "base_cache = \"/home/johnwu3/projects/covid19cxr_base_cache\"\n", - "task_cache = \"/home/johnwu3/projects/covid19cxr_task_cache\"\n", "model_checkpoint = \"/home/johnwu3/projects/covid19cxr_vit_model.ckpt\"\n", "\n", "base_dataset = COVID19CXRDataset(root, cache_dir=base_cache, num_workers=4)\n", - "sample_dataset = base_dataset.set_task(cache_dir=task_cache, num_workers=4)\n", + "sample_dataset = base_dataset.set_task(num_workers=4)\n", "\n", "print(f\"Total samples: {len(sample_dataset)}\")\n", "print(f\"Task mode: {sample_dataset.output_schema}\")\n", @@ -1278,7 +1277,7 @@ }, { "cell_type": "code", - "execution_count": 10, + "execution_count": null, "id": "adfd29de", "metadata": {}, "outputs": [ @@ -1344,7 +1343,7 @@ " **batch\n", " )\n", " attr_map = result[\"image\"] # Keyed by task schema's feature key\n", - " \n", + "\n", " img_display, vit_attr_display, attention_overlay = visualize_image_attr(\n", " image=image[0],\n", " attribution=attr_map[0, 0],\n", diff --git a/examples/cxr/covid19cxr_tutorial.py b/examples/cxr/covid19cxr_tutorial.py index f0b11457b..0f24f4b58 100644 --- a/examples/cxr/covid19cxr_tutorial.py +++ b/examples/cxr/covid19cxr_tutorial.py @@ -26,7 +26,6 @@ DATA_ROOT = "/home/johnwu3/projects/PyHealth_Branch_Testing/datasets" ROOT = f"{DATA_ROOT}/COVID-19_Radiography_Dataset" CACHE = "/home/johnwu3/projects/covid19cxr_base_cache" -TASK_CACHE = "/home/johnwu3/projects/covid19cxr_task_cache" CKPT = "/home/johnwu3/projects/covid19cxr_vit_model.ckpt" SEED = 42 @@ -41,7 +40,7 @@ # Load dataset and create train/val/calibration/test splits dataset = COVID19CXRDataset(ROOT, cache_dir=CACHE, num_workers=8) - sample_dataset = dataset.set_task(cache_dir=TASK_CACHE, num_workers=8) + sample_dataset = dataset.set_task(num_workers=8) train_data, val_data, cal_data, test_data = split_by_sample_conformal( sample_dataset, ratios=[0.6, 0.1, 0.15, 0.15] ) diff --git a/examples/cxr/covid19cxr_tutorial_display.py b/examples/cxr/covid19cxr_tutorial_display.py index 093ff286f..3f6a33b82 100644 --- a/examples/cxr/covid19cxr_tutorial_display.py +++ b/examples/cxr/covid19cxr_tutorial_display.py @@ -26,7 +26,6 @@ DATA_ROOT = "/home/johnwu3/projects/PyHealth_Branch_Testing/datasets" ROOT = f"{DATA_ROOT}/COVID-19_Radiography_Dataset" CACHE = "/home/johnwu3/projects/covid19cxr_base_cache" -TASK_CACHE = "/home/johnwu3/projects/covid19cxr_task_cache" CKPT = "/home/johnwu3/projects/covid19cxr_vit_model.ckpt" SEED = 42 @@ -41,7 +40,7 @@ # Load dataset and create train/val/calibration/test splits dataset = COVID19CXRDataset(ROOT, cache_dir=CACHE, num_workers=8) - sample_dataset = dataset.set_task(cache_dir=TASK_CACHE, num_workers=8) + sample_dataset = dataset.set_task(num_workers=8) train_data, val_data, cal_data, test_data = split_by_sample_conformal( sample_dataset, ratios=[0.6, 0.1, 0.15, 0.15] ) diff --git a/examples/drug_recommendation/drug_recommendation_mimic4_retain.py b/examples/drug_recommendation/drug_recommendation_mimic4_retain.py index 688178a06..496edc7e3 100644 --- a/examples/drug_recommendation/drug_recommendation_mimic4_retain.py +++ b/examples/drug_recommendation/drug_recommendation_mimic4_retain.py @@ -35,7 +35,6 @@ sample_dataset = base_dataset.set_task( DrugRecommendationMIMIC4(), num_workers=4, - cache_dir="../../mimic4_drug_rec_cache", ) print(f"Total samples: {len(sample_dataset)}") diff --git a/examples/interpretability/gim_stagenet_mimic4.py b/examples/interpretability/gim_stagenet_mimic4.py index 8d45ad9a7..b38eb7a87 100644 --- a/examples/interpretability/gim_stagenet_mimic4.py +++ b/examples/interpretability/gim_stagenet_mimic4.py @@ -31,7 +31,6 @@ sample_dataset = dataset.set_task( MortalityPredictionStageNetMIMIC4(), - cache_dir="~/.cache/pyhealth/mimic4_stagenet_mortality", input_processors=input_processors, output_processors=output_processors, ) diff --git a/examples/interpretability/gim_transformer_mimic4.py b/examples/interpretability/gim_transformer_mimic4.py index 374d00d64..884be99a5 100644 --- a/examples/interpretability/gim_transformer_mimic4.py +++ b/examples/interpretability/gim_transformer_mimic4.py @@ -56,7 +56,6 @@ def maybe_load_processors(resource_dir: str, task): sample_dataset = dataset.set_task( task, - cache_dir="~/.cache/pyhealth/mimic4_transformer_mortality", input_processors=input_processors, output_processors=output_processors, ) diff --git a/examples/interpretability/integrated_gradients_benchmark_stagenet.py b/examples/interpretability/integrated_gradients_benchmark_stagenet.py index 82dbca218..c1cdbf0fd 100644 --- a/examples/interpretability/integrated_gradients_benchmark_stagenet.py +++ b/examples/interpretability/integrated_gradients_benchmark_stagenet.py @@ -37,7 +37,6 @@ "20260131-184735/best.ckpt" ) PROCESSOR_DIR = "../output/processors/stagenet_mortality_mimic4" -CACHE_DIR = "../../mimic4_stagenet_cache" def main(): @@ -85,7 +84,6 @@ def main(): sample_dataset = base_dataset.set_task( MortalityPredictionStageNetMIMIC4(), num_workers=8, - cache_dir=CACHE_DIR, input_processors=input_processors, output_processors=output_processors, ) diff --git a/examples/interpretability/integrated_gradients_mortality_mimic4_stagenet.py b/examples/interpretability/integrated_gradients_mortality_mimic4_stagenet.py index b55a8524d..32bab62d5 100644 --- a/examples/interpretability/integrated_gradients_mortality_mimic4_stagenet.py +++ b/examples/interpretability/integrated_gradients_mortality_mimic4_stagenet.py @@ -319,7 +319,6 @@ def main(): sample_dataset = base_dataset.set_task( MortalityPredictionStageNetMIMIC4(), num_workers=8, - cache_dir="../../mimic4_stagenet_cache", input_processors=input_processors, output_processors=output_processors, ) diff --git a/examples/interpretability/interpretability_metrics.py b/examples/interpretability/interpretability_metrics.py index 7adfb55a6..723f63da1 100644 --- a/examples/interpretability/interpretability_metrics.py +++ b/examples/interpretability/interpretability_metrics.py @@ -44,7 +44,6 @@ def main(): sample_dataset = base_dataset.set_task( MortalityPredictionStageNetMIMIC4(), num_workers=4, - cache_dir="../../mimic4_stagenet_cache", ) print(f"✓ Loaded {len(sample_dataset)} samples") diff --git a/examples/interpretability/shap_stagenet_mimic4.ipynb b/examples/interpretability/shap_stagenet_mimic4.ipynb index a871d7326..f2f4e2efd 100644 --- a/examples/interpretability/shap_stagenet_mimic4.ipynb +++ b/examples/interpretability/shap_stagenet_mimic4.ipynb @@ -13,8 +13,8 @@ }, "outputs": [ { - "output_type": "stream", "name": "stdout", + "output_type": "stream", "text": [ "PyTorch version: 2.9.0+cu126\n", "CUDA available: True\n", @@ -64,8 +64,8 @@ }, "outputs": [ { - "output_type": "stream", "name": "stdout", + "output_type": "stream", "text": [ "Collecting git+https://github.com/naveenkcb/PyHealth.git\n", " Cloning https://github.com/naveenkcb/PyHealth.git to /tmp/pip-req-build-u5cek8co\n", @@ -273,24 +273,24 @@ ] }, { - "output_type": "display_data", "data": { "application/vnd.colab-display-data+json": { + "id": "a21cf9550c9c4d1a916e50ccaf894bf2", "pip_warning": { "packages": [ "numpy", "torch", "torchgen" ] - }, - "id": "a21cf9550c9c4d1a916e50ccaf894bf2" + } } }, - "metadata": {} + "metadata": {}, + "output_type": "display_data" }, { - "output_type": "stream", "name": "stdout", + "output_type": "stream", "text": [ "Requirement already satisfied: torch in /usr/local/lib/python3.12/dist-packages (2.7.1)\n", "Requirement already satisfied: scikit-learn in /usr/local/lib/python3.12/dist-packages (1.7.2)\n", @@ -372,8 +372,8 @@ }, "outputs": [ { - "output_type": "stream", "name": "stdout", + "output_type": "stream", "text": [ "Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount(\"/content/drive\", force_remount=True).\n", "Data directory: /content/mimic-iv-demo/2.2\n", @@ -430,8 +430,8 @@ }, "outputs": [ { - "output_type": "stream", "name": "stdout", + "output_type": "stream", "text": [ "Model checkpoint should be at: /content/resources/best.ckpt\n", "Checkpoint exists: True\n" @@ -477,205 +477,205 @@ }, "outputs": [ { - "output_type": "stream", "name": "stdout", + "output_type": "stream", "text": [ "Using default EHR config: /usr/local/lib/python3.12/dist-packages/pyhealth/datasets/configs/mimic4_ehr.yaml\n" ] }, { - "output_type": "stream", "name": "stderr", + "output_type": "stream", "text": [ "INFO:pyhealth.datasets.mimic4:Using default EHR config: /usr/local/lib/python3.12/dist-packages/pyhealth/datasets/configs/mimic4_ehr.yaml\n" ] }, { - "output_type": "stream", "name": "stdout", + "output_type": "stream", "text": [ "Memory usage Before initializing mimic4_ehr: 1574.5 MB\n" ] }, { - "output_type": "stream", "name": "stderr", + "output_type": "stream", "text": [ "INFO:pyhealth.datasets.mimic4:Memory usage Before initializing mimic4_ehr: 1574.5 MB\n" ] }, { - "output_type": "stream", "name": "stdout", + "output_type": "stream", "text": [ "Duplicate table names in tables list. Removing duplicates.\n" ] }, { - "output_type": "stream", "name": "stderr", + "output_type": "stream", "text": [ "WARNING:pyhealth.datasets.base_dataset:Duplicate table names in tables list. Removing duplicates.\n" ] }, { - "output_type": "stream", "name": "stdout", + "output_type": "stream", "text": [ "Initializing mimic4_ehr dataset from https://physionet.org/files/mimic-iv-demo/2.2/ (dev mode: False)\n" ] }, { - "output_type": "stream", "name": "stderr", + "output_type": "stream", "text": [ "INFO:pyhealth.datasets.base_dataset:Initializing mimic4_ehr dataset from https://physionet.org/files/mimic-iv-demo/2.2/ (dev mode: False)\n" ] }, { - "output_type": "stream", "name": "stdout", + "output_type": "stream", "text": [ "Scanning table: admissions from https://physionet.org/files/mimic-iv-demo/2.2/hosp/admissions.csv.gz\n" ] }, { - "output_type": "stream", "name": "stderr", + "output_type": "stream", "text": [ "INFO:pyhealth.datasets.base_dataset:Scanning table: admissions from https://physionet.org/files/mimic-iv-demo/2.2/hosp/admissions.csv.gz\n" ] }, { - "output_type": "stream", "name": "stdout", + "output_type": "stream", "text": [ "Scanning table: procedures_icd from https://physionet.org/files/mimic-iv-demo/2.2/hosp/procedures_icd.csv.gz\n" ] }, { - "output_type": "stream", "name": "stderr", + "output_type": "stream", "text": [ "INFO:pyhealth.datasets.base_dataset:Scanning table: procedures_icd from https://physionet.org/files/mimic-iv-demo/2.2/hosp/procedures_icd.csv.gz\n" ] }, { - "output_type": "stream", "name": "stdout", + "output_type": "stream", "text": [ "Joining with table: https://physionet.org/files/mimic-iv-demo/2.2/hosp/admissions.csv.gz\n" ] }, { - "output_type": "stream", "name": "stderr", + "output_type": "stream", "text": [ "INFO:pyhealth.datasets.base_dataset:Joining with table: https://physionet.org/files/mimic-iv-demo/2.2/hosp/admissions.csv.gz\n" ] }, { - "output_type": "stream", "name": "stdout", + "output_type": "stream", "text": [ "Scanning table: diagnoses_icd from https://physionet.org/files/mimic-iv-demo/2.2/hosp/diagnoses_icd.csv.gz\n" ] }, { - "output_type": "stream", "name": "stderr", + "output_type": "stream", "text": [ "INFO:pyhealth.datasets.base_dataset:Scanning table: diagnoses_icd from https://physionet.org/files/mimic-iv-demo/2.2/hosp/diagnoses_icd.csv.gz\n" ] }, { - "output_type": "stream", "name": "stdout", + "output_type": "stream", "text": [ "Joining with table: https://physionet.org/files/mimic-iv-demo/2.2/hosp/admissions.csv.gz\n" ] }, { - "output_type": "stream", "name": "stderr", + "output_type": "stream", "text": [ "INFO:pyhealth.datasets.base_dataset:Joining with table: https://physionet.org/files/mimic-iv-demo/2.2/hosp/admissions.csv.gz\n" ] }, { - "output_type": "stream", "name": "stdout", + "output_type": "stream", "text": [ "Scanning table: icustays from https://physionet.org/files/mimic-iv-demo/2.2/icu/icustays.csv.gz\n" ] }, { - "output_type": "stream", "name": "stderr", + "output_type": "stream", "text": [ "INFO:pyhealth.datasets.base_dataset:Scanning table: icustays from https://physionet.org/files/mimic-iv-demo/2.2/icu/icustays.csv.gz\n" ] }, { - "output_type": "stream", "name": "stdout", + "output_type": "stream", "text": [ "Scanning table: patients from https://physionet.org/files/mimic-iv-demo/2.2/hosp/patients.csv.gz\n" ] }, { - "output_type": "stream", "name": "stderr", + "output_type": "stream", "text": [ "INFO:pyhealth.datasets.base_dataset:Scanning table: patients from https://physionet.org/files/mimic-iv-demo/2.2/hosp/patients.csv.gz\n" ] }, { - "output_type": "stream", "name": "stdout", + "output_type": "stream", "text": [ "Scanning table: labevents from https://physionet.org/files/mimic-iv-demo/2.2/hosp/labevents.csv.gz\n" ] }, { - "output_type": "stream", "name": "stderr", + "output_type": "stream", "text": [ "INFO:pyhealth.datasets.base_dataset:Scanning table: labevents from https://physionet.org/files/mimic-iv-demo/2.2/hosp/labevents.csv.gz\n" ] }, { - "output_type": "stream", "name": "stdout", + "output_type": "stream", "text": [ "Joining with table: https://physionet.org/files/mimic-iv-demo/2.2/hosp/d_labitems.csv.gz\n" ] }, { - "output_type": "stream", "name": "stderr", + "output_type": "stream", "text": [ "INFO:pyhealth.datasets.base_dataset:Joining with table: https://physionet.org/files/mimic-iv-demo/2.2/hosp/d_labitems.csv.gz\n" ] }, { - "output_type": "stream", "name": "stdout", + "output_type": "stream", "text": [ "Memory usage After initializing mimic4_ehr: 1574.9 MB\n" ] }, { - "output_type": "stream", "name": "stderr", + "output_type": "stream", "text": [ "INFO:pyhealth.datasets.mimic4:Memory usage After initializing mimic4_ehr: 1574.9 MB\n" ] }, { - "output_type": "error", "ename": "TypeError", "evalue": "object of type 'MIMIC4EHRDataset' has no len()", + "output_type": "error", "traceback": [ "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", "\u001b[0;31mTypeError\u001b[0m Traceback (most recent call last)", @@ -722,19 +722,19 @@ }, { "cell_type": "code", - "execution_count": 6, + "execution_count": null, "id": "1a4d785d", "metadata": { - "id": "1a4d785d", "colab": { "base_uri": "https://localhost:8080/" }, + "id": "1a4d785d", "outputId": "25686b72-cfd7-4766-c527-3e5e920da519" }, "outputs": [ { - "output_type": "stream", "name": "stdout", + "output_type": "stream", "text": [ "✓ Loaded input processors from /content/resources/input_processors.pkl\n", "✓ Loaded output processors from /content/resources/output_processors.pkl\n", @@ -742,103 +742,103 @@ ] }, { - "output_type": "stream", "name": "stderr", + "output_type": "stream", "text": [ "INFO:pyhealth.datasets.base_dataset:Setting task MortalityPredictionStageNetMIMIC4 for mimic4_ehr base dataset...\n" ] }, { - "output_type": "stream", "name": "stdout", + "output_type": "stream", "text": [ "Generating samples with 1 worker(s)...\n" ] }, { - "output_type": "stream", "name": "stderr", + "output_type": "stream", "text": [ "INFO:pyhealth.datasets.base_dataset:Generating samples with 1 worker(s)...\n" ] }, { - "output_type": "stream", "name": "stdout", + "output_type": "stream", "text": [ "Collecting global event dataframe...\n" ] }, { - "output_type": "stream", "name": "stderr", + "output_type": "stream", "text": [ "INFO:pyhealth.datasets.base_dataset:Collecting global event dataframe...\n" ] }, { - "output_type": "stream", "name": "stdout", + "output_type": "stream", "text": [ "Collected dataframe with shape: (113470, 39)\n" ] }, { - "output_type": "stream", "name": "stderr", + "output_type": "stream", "text": [ "INFO:pyhealth.datasets.base_dataset:Collected dataframe with shape: (113470, 39)\n", "Generating samples for MortalityPredictionStageNetMIMIC4 with 1 worker: 100%|██████████| 100/100 [00:16<00:00, 6.18it/s]" ] }, { - "output_type": "stream", "name": "stdout", + "output_type": "stream", "text": [ "Caching samples to /content/.cache/pyhealth/mimic4_stagenet_mortality/MortalityPredictionStageNetMIMIC4.parquet\n" ] }, { - "output_type": "stream", "name": "stderr", + "output_type": "stream", "text": [ "\n", "INFO:pyhealth.datasets.base_dataset:Caching samples to /content/.cache/pyhealth/mimic4_stagenet_mortality/MortalityPredictionStageNetMIMIC4.parquet\n" ] }, { - "output_type": "stream", "name": "stdout", + "output_type": "stream", "text": [ "Failed to cache samples: failed to determine supertype of list[f64] and list[list[str]]\n" ] }, { - "output_type": "stream", "name": "stderr", + "output_type": "stream", "text": [ "WARNING:pyhealth.datasets.base_dataset:Failed to cache samples: failed to determine supertype of list[f64] and list[list[str]]\n", "Processing samples: 100%|██████████| 100/100 [00:00<00:00, 1923.08it/s]" ] }, { - "output_type": "stream", "name": "stdout", + "output_type": "stream", "text": [ "Generated 100 samples for task MortalityPredictionStageNetMIMIC4\n" ] }, { - "output_type": "stream", "name": "stderr", + "output_type": "stream", "text": [ "\n", "INFO:pyhealth.datasets.base_dataset:Generated 100 samples for task MortalityPredictionStageNetMIMIC4\n" ] }, { - "output_type": "stream", "name": "stdout", + "output_type": "stream", "text": [ "Total samples: 100\n" ] @@ -850,7 +850,6 @@ "\n", "sample_dataset = dataset.set_task(\n", " MortalityPredictionStageNetMIMIC4(),\n", - " cache_dir=\"/content/.cache/pyhealth/mimic4_stagenet_mortality\",\n", " input_processors=input_processors,\n", " output_processors=output_processors,\n", ")\n", @@ -872,16 +871,16 @@ "execution_count": 7, "id": "4594eea4", "metadata": { - "id": "4594eea4", "colab": { "base_uri": "https://localhost:8080/" }, + "id": "4594eea4", "outputId": "a5f32db8-60fb-4fcb-c024-f1a3e6ce861e" }, "outputs": [ { - "output_type": "stream", "name": "stdout", + "output_type": "stream", "text": [ "Loaded 0 ICD code descriptions\n" ] @@ -938,16 +937,16 @@ "execution_count": 8, "id": "22f70a91", "metadata": { - "id": "22f70a91", "colab": { "base_uri": "https://localhost:8080/" }, + "id": "22f70a91", "outputId": "4ce86eee-b4f6-4dd6-bc64-900dfbda1f52" }, "outputs": [ { - "output_type": "stream", "name": "stdout", + "output_type": "stream", "text": [ "Using device: cuda:0\n", "\n", @@ -995,16 +994,16 @@ "execution_count": 9, "id": "6cbda428", "metadata": { - "id": "6cbda428", "colab": { "base_uri": "https://localhost:8080/" }, + "id": "6cbda428", "outputId": "54f18aa7-fd73-4acc-bda9-4bc28658d998" }, "outputs": [ { - "output_type": "stream", "name": "stdout", + "output_type": "stream", "text": [ "Test samples: 20\n" ] @@ -1147,16 +1146,16 @@ "execution_count": 11, "id": "5ae57044", "metadata": { - "id": "5ae57044", "colab": { "base_uri": "https://localhost:8080/" }, + "id": "5ae57044", "outputId": "0a6bd6e9-8168-467a-c5d1-6134b2be9c3b" }, "outputs": [ { - "output_type": "stream", "name": "stdout", + "output_type": "stream", "text": [ "================================================================================\n", "Initializing SHAP Explainer\n", @@ -1202,16 +1201,16 @@ "execution_count": 12, "id": "3cb63b98", "metadata": { - "id": "3cb63b98", "colab": { "base_uri": "https://localhost:8080/" }, + "id": "3cb63b98", "outputId": "695af029-b984-4957-e6de-4ceb92974ed4" }, "outputs": [ { - "output_type": "stream", "name": "stdout", + "output_type": "stream", "text": [ "icd_codes: device=cuda:0\n", "labs: device=cuda:0\n", @@ -1281,16 +1280,16 @@ "execution_count": 13, "id": "c65de0c3", "metadata": { - "id": "c65de0c3", "colab": { "base_uri": "https://localhost:8080/" }, + "id": "c65de0c3", "outputId": "2225e28a-0653-4441-a8c0-611f36e8bd73" }, "outputs": [ { - "output_type": "stream", "name": "stdout", + "output_type": "stream", "text": [ "\n", "================================================================================\n", @@ -1341,16 +1340,16 @@ "execution_count": 14, "id": "93490ab1", "metadata": { - "id": "93490ab1", "colab": { "base_uri": "https://localhost:8080/" }, + "id": "93490ab1", "outputId": "4bff2e31-e7d5-42ad-d63c-ea7e1c02c20b" }, "outputs": [ { - "output_type": "stream", "name": "stdout", + "output_type": "stream", "text": [ "\n", "================================================================================\n", @@ -1433,16 +1432,16 @@ "execution_count": 15, "id": "c7b02451", "metadata": { - "id": "c7b02451", "colab": { "base_uri": "https://localhost:8080/" }, + "id": "c7b02451", "outputId": "5f2b6f99-5432-4db4-f2fa-4227bd335858" }, "outputs": [ { - "output_type": "stream", "name": "stdout", + "output_type": "stream", "text": [ "\n", "================================================================================\n", @@ -1541,16 +1540,16 @@ "execution_count": 16, "id": "4a5c098c", "metadata": { - "id": "4a5c098c", "colab": { "base_uri": "https://localhost:8080/" }, + "id": "4a5c098c", "outputId": "c1d93796-238e-42c1-ac0c-7ae9013030b4" }, "outputs": [ { - "output_type": "stream", "name": "stdout", + "output_type": "stream", "text": [ "\n", "================================================================================\n", @@ -1597,16 +1596,16 @@ "execution_count": 17, "id": "69867127", "metadata": { - "id": "69867127", "colab": { "base_uri": "https://localhost:8080/" }, + "id": "69867127", "outputId": "d366b86c-e7f4-40d7-881f-cf7710acc812" }, "outputs": [ { - "output_type": "stream", "name": "stdout", + "output_type": "stream", "text": [ "\n", "================================================================================\n", @@ -1675,19 +1674,19 @@ } ], "metadata": { - "language_info": { - "name": "python" - }, + "accelerator": "GPU", "colab": { - "provenance": [], - "gpuType": "T4" + "gpuType": "T4", + "provenance": [] }, - "accelerator": "GPU", "kernelspec": { - "name": "python3", - "display_name": "Python 3" + "display_name": "Python 3", + "name": "python3" + }, + "language_info": { + "name": "python" } }, "nbformat": 4, "nbformat_minor": 5 -} \ No newline at end of file +} diff --git a/examples/interpretability/shap_stagenet_mimic4.py b/examples/interpretability/shap_stagenet_mimic4.py index f376bc62e..a06a9300d 100644 --- a/examples/interpretability/shap_stagenet_mimic4.py +++ b/examples/interpretability/shap_stagenet_mimic4.py @@ -98,7 +98,7 @@ def print_top_attributions( """Print top-k most important features from SHAP attributions.""" if icd_code_to_desc is None: icd_code_to_desc = {} - + for feature_key, attr in attributions.items(): attr_cpu = attr.detach().cpu() if attr_cpu.dim() == 0 or attr_cpu.size(0) == 0: @@ -117,7 +117,7 @@ def print_top_attributions( print(f" Shape: {attr_cpu[0].shape}") print(f" Total attribution sum: {flattened.sum().item():+.6f}") print(f" Mean attribution: {flattened.mean().item():+.6f}") - + k = min(top_k, flattened.numel()) top_values, top_indices = torch.topk(flattened.abs(), k=k) processor = processors.get(feature_key) if processors else None @@ -169,7 +169,6 @@ def main(): sample_dataset = dataset.set_task( MortalityPredictionStageNetMIMIC4(), - cache_dir="~/.cache/pyhealth/mimic4_stagenet_mortality", input_processors=input_processors, output_processors=output_processors, ) @@ -222,7 +221,7 @@ def main(): probs = output["y_prob"] label_key = model.label_key true_label = sample_batch_device[label_key] - + # Handle binary classification (single probability output) if probs.shape[-1] == 1: prob_death = probs[0].item() diff --git a/examples/length_of_stay/length_of_stay_mimic4_ehrmamba.py b/examples/length_of_stay/length_of_stay_mimic4_ehrmamba.py index a5105a39d..26314db9b 100644 --- a/examples/length_of_stay/length_of_stay_mimic4_ehrmamba.py +++ b/examples/length_of_stay/length_of_stay_mimic4_ehrmamba.py @@ -41,7 +41,6 @@ EPOCHS = 20 DATASET_CACHE = os.path.join(CACHE_BASE, "mimic4_ehr_los") -TASK_CACHE = os.path.join(CACHE_BASE, "mimic4_los_ehrmamba") def main(): @@ -54,9 +53,6 @@ def main(): dataset_cache = os.path.join( CACHE_BASE, "mimic4_ehr_los_quick" if quick_test else "mimic4_ehr_los" ) - task_cache = os.path.join( - CACHE_BASE, "mimic4_los_ehrmamba_quick" if quick_test else "mimic4_los_ehrmamba" - ) num_workers = 1 if quick_test else 4 print("EHRMamba – Length of stay (full MIMIC-IV)") @@ -66,7 +62,7 @@ def main(): print("gpu:", gpu_id, "(CUDA_VISIBLE_DEVICES)") print("device:", DEVICE) print("ehr_root:", EHR_ROOT) - print("cache: dataset", dataset_cache, "| task", task_cache) + print("cache:", dataset_cache) print("seed:", SEED, "| batch_size:", BATCH_SIZE, "| epochs:", epochs) t0 = time.perf_counter() @@ -83,7 +79,6 @@ def main(): sample_dataset = dataset.set_task( task, num_workers=num_workers, - cache_dir=task_cache, ) print(f"Task set in {time.perf_counter() - t1:.1f}s | samples: {len(sample_dataset)}") diff --git a/examples/length_of_stay/length_of_stay_mimic4_stageattn.py b/examples/length_of_stay/length_of_stay_mimic4_stageattn.py index 70d15024a..514f7750c 100644 --- a/examples/length_of_stay/length_of_stay_mimic4_stageattn.py +++ b/examples/length_of_stay/length_of_stay_mimic4_stageattn.py @@ -43,7 +43,6 @@ # - This ensures consistent encoding and saves computation time # - Processors include vocabulary mappings and sequence length statistics processor_dir = "/home/yongdaf2/los_sa/processors" - cache_dir = "/home/yongdaf2/los_sa/cache" if os.path.exists(os.path.join(processor_dir, "input_processors.pkl")): print("\n=== Loading Pre-fitted Processors ===") @@ -52,7 +51,6 @@ sample_dataset = base_dataset.set_task( LengthOfStayStageNetMIMIC4(padding=20), num_workers=16, - cache_dir=cache_dir, input_processors=input_processors, output_processors=output_processors, ) @@ -61,7 +59,6 @@ sample_dataset = base_dataset.set_task( LengthOfStayStageNetMIMIC4(padding=20), num_workers=16, - cache_dir=cache_dir, ) # Save processors for future runs diff --git a/examples/length_of_stay/length_of_stay_mimic4_stagenet.py b/examples/length_of_stay/length_of_stay_mimic4_stagenet.py index 149db7e0b..3c560fe5b 100644 --- a/examples/length_of_stay/length_of_stay_mimic4_stagenet.py +++ b/examples/length_of_stay/length_of_stay_mimic4_stagenet.py @@ -43,7 +43,6 @@ # - This ensures consistent encoding and saves computation time # - Processors include vocabulary mappings and sequence length statistics processor_dir = "/home/yongdaf2/los_sn/processors" - cache_dir = "/home/yongdaf2/los_sn/cache" if os.path.exists(os.path.join(processor_dir, "input_processors.pkl")): print("\n=== Loading Pre-fitted Processors ===") @@ -52,7 +51,6 @@ sample_dataset = base_dataset.set_task( LengthOfStayStageNetMIMIC4(padding=20), num_workers=16, - cache_dir=cache_dir, input_processors=input_processors, output_processors=output_processors, ) @@ -61,7 +59,6 @@ sample_dataset = base_dataset.set_task( LengthOfStayStageNetMIMIC4(padding=20), num_workers=16, - cache_dir=cache_dir, ) # Save processors for future runs diff --git a/examples/length_of_stay/length_of_stay_mimic4_transformer.py b/examples/length_of_stay/length_of_stay_mimic4_transformer.py index 094be8965..3d7b5aa16 100644 --- a/examples/length_of_stay/length_of_stay_mimic4_transformer.py +++ b/examples/length_of_stay/length_of_stay_mimic4_transformer.py @@ -43,7 +43,6 @@ # - This ensures consistent encoding and saves computation time # - Processors include vocabulary mappings and sequence length statistics processor_dir = "/home/yongdaf2/los_tf/processors" - cache_dir = "/home/yongdaf2/los_tf/cache" if os.path.exists(os.path.join(processor_dir, "input_processors.pkl")): print("\n=== Loading Pre-fitted Processors ===") @@ -52,7 +51,6 @@ sample_dataset = base_dataset.set_task( LengthOfStayStageNetMIMIC4(padding=20), num_workers=16, - cache_dir=cache_dir, input_processors=input_processors, output_processors=output_processors, ) @@ -61,7 +59,6 @@ sample_dataset = base_dataset.set_task( LengthOfStayStageNetMIMIC4(padding=20), num_workers=16, - cache_dir=cache_dir, ) # Save processors for future runs diff --git a/examples/lime_stagenet_mimic4.py b/examples/lime_stagenet_mimic4.py index fe1a69922..b3b54085d 100644 --- a/examples/lime_stagenet_mimic4.py +++ b/examples/lime_stagenet_mimic4.py @@ -99,7 +99,7 @@ def print_top_attributions( """Print top-k most important features from LIME attributions.""" if icd_code_to_desc is None: icd_code_to_desc = {} - + for feature_key, attr in attributions.items(): attr_cpu = attr.detach().cpu() if attr_cpu.dim() == 0 or attr_cpu.size(0) == 0: @@ -118,7 +118,7 @@ def print_top_attributions( print(f" Shape: {attr_cpu[0].shape}") print(f" Total attribution sum: {flattened.sum().item():+.6f}") print(f" Mean attribution: {flattened.mean().item():+.6f}") - + k = min(top_k, flattened.numel()) top_values, top_indices = torch.topk(flattened.abs(), k=k) processor = processors.get(feature_key) if processors else None @@ -170,7 +170,6 @@ def main(): sample_dataset = dataset.set_task( MortalityPredictionStageNetMIMIC4(), - cache_dir="~/.cache/pyhealth/mimic4_stagenet_mortality", input_processors=input_processors, output_processors=output_processors, ) @@ -225,7 +224,7 @@ def main(): probs = output["y_prob"] label_key = model.label_key true_label = sample_batch_device[label_key] - + # Handle binary classification (single probability output) if probs.shape[-1] == 1: prob_death = probs[0].item() @@ -262,10 +261,10 @@ def main(): print("Positive values increase the mortality prediction, negative values decrease it.") print_top_attributions( - attributions, - sample_batch_device, - input_processors, - top_k=15, + attributions, + sample_batch_device, + input_processors, + top_k=15, icd_code_to_desc=ICD_CODE_TO_DESC, method_name="LIME" ) @@ -336,11 +335,11 @@ def main(): if key in attributions: flat_lasso = attributions[key][0].flatten().abs() flat_ridge = attr_ridge[key][0].flatten().abs() - + k = min(5, flat_lasso.numel()) top_lasso = torch.topk(flat_lasso, k=k) top_ridge = torch.topk(flat_ridge, k=k) - + print(f"\n {key}:") print(f" Lasso non-zero features: {(flat_lasso > 1e-6).sum().item()}/{flat_lasso.numel()}") print(f" Ridge non-zero features: {(flat_ridge > 1e-6).sum().item()}/{flat_ridge.numel()}") diff --git a/examples/mortality_prediction/ehrmamba_mimic4_full.py b/examples/mortality_prediction/ehrmamba_mimic4_full.py index a77e92c60..aabe7fd59 100644 --- a/examples/mortality_prediction/ehrmamba_mimic4_full.py +++ b/examples/mortality_prediction/ehrmamba_mimic4_full.py @@ -41,7 +41,6 @@ EPOCHS = 20 DATASET_CACHE = os.path.join(CACHE_BASE, "mimic4_ehr") -TASK_CACHE = os.path.join(CACHE_BASE, "mimic4_ihm_ehrmamba") def main(): @@ -52,9 +51,6 @@ def main(): dev = quick_test epochs = 2 if quick_test else EPOCHS dataset_cache = os.path.join(CACHE_BASE, "mimic4_ehr_quick" if quick_test else "mimic4_ehr") - task_cache = os.path.join( - CACHE_BASE, "mimic4_ihm_ehrmamba_quick" if quick_test else "mimic4_ihm_ehrmamba" - ) num_workers = 1 if quick_test else 4 print("EHRMamba – In-hospital mortality (full MIMIC-IV)") @@ -64,7 +60,7 @@ def main(): print("gpu:", gpu_id, "(CUDA_VISIBLE_DEVICES)") print("device:", DEVICE) print("ehr_root:", EHR_ROOT) - print("cache: dataset", dataset_cache, "| task", task_cache) + print("cache:", dataset_cache) print("seed:", SEED, "| batch_size:", BATCH_SIZE, "| epochs:", epochs) t0 = time.perf_counter() @@ -81,7 +77,6 @@ def main(): sample_dataset = dataset.set_task( task, num_workers=num_workers, - cache_dir=task_cache, ) print(f"Task set in {time.perf_counter() - t1:.1f}s | samples: {len(sample_dataset)}") diff --git a/examples/mortality_prediction/mimic3_mortality_prediction_cached.ipynb b/examples/mortality_prediction/mimic3_mortality_prediction_cached.ipynb index 8f908bf56..102a87df1 100644 --- a/examples/mortality_prediction/mimic3_mortality_prediction_cached.ipynb +++ b/examples/mortality_prediction/mimic3_mortality_prediction_cached.ipynb @@ -247,7 +247,7 @@ "from pyhealth.tasks.mortality_prediction import MortalityPredictionMIMIC3\n", "from pyhealth.datasets import split_by_patient, get_dataloader\n", "mimic3_mortality_prediction = MortalityPredictionMIMIC3Heterogeneous()\n", - "samples = dataset.set_task(mimic3_mortality_prediction, num_workers=1, cache_dir=\"cache/\") # use default task" + "samples = dataset.set_task(mimic3_mortality_prediction, num_workers=1) # use default task" ] }, { diff --git a/examples/mortality_prediction/mortality_mimic4_stageattn.py b/examples/mortality_prediction/mortality_mimic4_stageattn.py index 18649832b..f302a8600 100644 --- a/examples/mortality_prediction/mortality_mimic4_stageattn.py +++ b/examples/mortality_prediction/mortality_mimic4_stageattn.py @@ -43,7 +43,6 @@ # - This ensures consistent encoding and saves computation time # - Processors include vocabulary mappings and sequence length statistics processor_dir = "/home/yongdaf2/stageattn/processors" - cache_dir = "/home/yongdaf2/stageattn/cache" if os.path.exists(os.path.join(processor_dir, "input_processors.pkl")): print("\n=== Loading Pre-fitted Processors ===") @@ -52,7 +51,6 @@ sample_dataset = base_dataset.set_task( MortalityPredictionStageNetMIMIC4(padding=20), num_workers=16, - cache_dir=cache_dir, input_processors=input_processors, output_processors=output_processors, ) @@ -61,7 +59,6 @@ sample_dataset = base_dataset.set_task( MortalityPredictionStageNetMIMIC4(padding=20), num_workers=16, - cache_dir=cache_dir, ) # Save processors for future runs diff --git a/examples/mortality_prediction/mortality_mimic4_stagenet_v2.py b/examples/mortality_prediction/mortality_mimic4_stagenet_v2.py index 0e267ae1f..acf9598d0 100644 --- a/examples/mortality_prediction/mortality_mimic4_stagenet_v2.py +++ b/examples/mortality_prediction/mortality_mimic4_stagenet_v2.py @@ -160,7 +160,6 @@ def generate_holdout_set( # - This ensures consistent encoding and saves computation time # - Processors include vocabulary mappings and sequence length statistics processor_dir = "../../output/processors/stagenet_mortality_mimic4" - cache_dir = "../../mimic4_stagenet_cache_v3" if os.path.exists(os.path.join(processor_dir, "input_processors.pkl")): print("\n=== Loading Pre-fitted Processors ===") @@ -169,7 +168,6 @@ def generate_holdout_set( sample_dataset = base_dataset.set_task( MortalityPredictionStageNetMIMIC4(padding=20), num_workers=4, - cache_dir=cache_dir, input_processors=input_processors, output_processors=output_processors, ) @@ -178,7 +176,6 @@ def generate_holdout_set( sample_dataset = base_dataset.set_task( MortalityPredictionStageNetMIMIC4(padding=20), num_workers=4, - cache_dir=cache_dir, ) # Save processors for future runs diff --git a/examples/mortality_prediction/mortality_mimic4_transformer.py b/examples/mortality_prediction/mortality_mimic4_transformer.py index 838496849..d2afc585d 100644 --- a/examples/mortality_prediction/mortality_mimic4_transformer.py +++ b/examples/mortality_prediction/mortality_mimic4_transformer.py @@ -43,7 +43,6 @@ # - This ensures consistent encoding and saves computation time # - Processors include vocabulary mappings and sequence length statistics processor_dir = "/home/yongdaf2/mp_tf/processors" - cache_dir = "/home/yongdaf2/mp_tf/cache" if os.path.exists(os.path.join(processor_dir, "input_processors.pkl")): print("\n=== Loading Pre-fitted Processors ===") @@ -52,7 +51,6 @@ sample_dataset = base_dataset.set_task( MortalityPredictionStageNetMIMIC4(padding=20), num_workers=16, - cache_dir=cache_dir, input_processors=input_processors, output_processors=output_processors, ) @@ -61,7 +59,6 @@ sample_dataset = base_dataset.set_task( MortalityPredictionStageNetMIMIC4(padding=20), num_workers=16, - cache_dir=cache_dir, ) # Save processors for future runs diff --git a/examples/mortality_prediction/multimodal_mimic4_demo.py b/examples/mortality_prediction/multimodal_mimic4_demo.py index cec5d610f..bce91e63e 100644 --- a/examples/mortality_prediction/multimodal_mimic4_demo.py +++ b/examples/mortality_prediction/multimodal_mimic4_demo.py @@ -75,14 +75,6 @@ def get_directory_size(path: str | Path) -> int: return total -def ensure_empty_dir(path: str | Path) -> None: - """Ensure directory exists and is empty.""" - p = Path(path) - if p.exists(): - shutil.rmtree(p) - p.mkdir(parents=True, exist_ok=True) - - def remove_dir(path: str | Path, retries: int = 3, delay: float = 1.0) -> None: """Remove a directory with retry logic.""" p = Path(path) @@ -558,7 +550,7 @@ def nested_indices_to_codes(nested_indices, processor): visit_info = lookup_icd_codes(visit_codes, "ICD10CM", max_display=10) if all(info["name"] == "[Unknown code]" for info in visit_info): visit_info = lookup_icd_codes(visit_codes, "ICD9CM", max_display=10) - + codes_preview = ", ".join(info['code'] for info in visit_info) if len(visit_codes) > 10: codes_preview += f" (+{len(visit_codes) - 10} more)" @@ -578,7 +570,7 @@ def nested_indices_to_codes(nested_indices, processor): if all(info["name"] in ["[Unknown code]", "[ICD10PROC unavailable]"] for info in visit_info): visit_info = lookup_icd_codes(visit_codes, "ICD9PROC", max_display=10) - + codes_preview = ", ".join(info['code'] for info in visit_info) if len(visit_codes) > 10: codes_preview += f" (+{len(visit_codes) - 10} more)" @@ -651,7 +643,7 @@ def nested_indices_to_codes(nested_indices, processor): # - lab_times: list of floats from raw processor lab_values_tensor = sample.get("lab_values") lab_times_raw = sample.get("lab_times", []) - + labs_data = None if lab_values_tensor is not None and len(lab_times_raw) > 0: # Convert tensor to list of lists for display functions @@ -870,7 +862,7 @@ def main(): from pyhealth.tasks import MultimodalMortalityPredictionMIMIC4 cache_root = Path(args.cache_dir) - + # Create cache directory name based on configuration cache_suffix = "_dev" if args.dev else "" if args.no_notes and args.no_cxr: @@ -881,9 +873,8 @@ def main(): cache_suffix += "_ehr_notes" else: cache_suffix += "_full" - + base_cache_dir = cache_root / f"base_dataset{cache_suffix}" - task_cache_dir = cache_root / f"task_samples{cache_suffix}" # Initialize memory tracker tracker = PeakMemoryTracker(poll_interval_s=0.1) @@ -946,15 +937,14 @@ def main(): sample_dataset = base_dataset.set_task( task, num_workers=num_workers, - cache_dir=str(task_cache_dir), ) task_process_s = time.time() - task_start total_s = time.time() - run_start peak_rss_bytes = tracker.peak_bytes() - + print(f" ✓ Task completed in {task_process_s:.2f}s", flush=True) - + # Get sample count first (faster, uses cached count if available) print(" Getting sample count...", flush=True) try: @@ -967,12 +957,13 @@ def main(): except Exception as e: print(f" ✗ Error getting sample count: {e}") num_samples = 0 - + # Calculate cache size (can be slow for large directories) if not args.skip_benchmark: print(" Calculating cache size...", flush=True) cache_calc_start = time.time() - task_cache_bytes = get_directory_size(task_cache_dir) + tasks_dir = base_cache_dir / "tasks" + task_cache_bytes = get_directory_size(tasks_dir) cache_calc_time = time.time() - cache_calc_start print(f" ✓ Task cache size: {format_size(task_cache_bytes)} " f"(calculated in {cache_calc_time:.1f}s)") diff --git a/examples/mortality_prediction/multimodal_mimic4_minimal.py b/examples/mortality_prediction/multimodal_mimic4_minimal.py index 7fef80819..093e7764b 100644 --- a/examples/mortality_prediction/multimodal_mimic4_minimal.py +++ b/examples/mortality_prediction/multimodal_mimic4_minimal.py @@ -26,7 +26,7 @@ # Apply multimodal task task = MultimodalMortalityPredictionMIMIC4() - samples = dataset.set_task(task, cache_dir=f"{CACHE_DIR}/task", num_workers=8) + samples = dataset.set_task(task, num_workers=8) # Get and print sample sample = samples[0] diff --git a/examples/mortality_prediction/timeseries_mimic4.ipynb b/examples/mortality_prediction/timeseries_mimic4.ipynb index 3700a33b1..f976bab1b 100644 --- a/examples/mortality_prediction/timeseries_mimic4.ipynb +++ b/examples/mortality_prediction/timeseries_mimic4.ipynb @@ -187,7 +187,7 @@ " from pyhealth.tasks import InHospitalMortalityMIMIC4\n", "\n", " task = InHospitalMortalityMIMIC4()\n", - " samples = dataset.set_task(task, num_workers=4, cache_dir=\"../benchmark_cache/mimic4_ihm_w_pre2/\")\n", + " samples = dataset.set_task(task, num_workers=4)\n", "\n", " from pyhealth.datasets import split_by_sample\n", "\n", diff --git a/examples/mortality_prediction/timeseries_mimic4.py b/examples/mortality_prediction/timeseries_mimic4.py index 85df3c884..f3c675eec 100644 --- a/examples/mortality_prediction/timeseries_mimic4.py +++ b/examples/mortality_prediction/timeseries_mimic4.py @@ -10,7 +10,7 @@ from pyhealth.tasks import InHospitalMortalityMIMIC4 task = InHospitalMortalityMIMIC4() - samples = dataset.set_task(task, num_workers=2, cache_dir="../benchmark_cache/mimic4_ihm/") + samples = dataset.set_task(task, num_workers=2) from pyhealth.datasets import split_by_sample diff --git a/examples/transformer_mimic4.ipynb b/examples/transformer_mimic4.ipynb index f7a435449..37c098521 100644 --- a/examples/transformer_mimic4.ipynb +++ b/examples/transformer_mimic4.ipynb @@ -97,7 +97,7 @@ }, { "cell_type": "code", - "execution_count": 3, + "execution_count": null, "id": "46bfc544", "metadata": {}, "outputs": [ @@ -161,7 +161,6 @@ "task = InHospitalMortalityMIMIC4()\n", "sample_dataset = dataset.set_task(\n", " task,\n", - " cache_dir=\"../../test_cache_transformer_m4\"\n", ")\n", "train_dataset, val_dataset, test_dataset = split_by_sample(sample_dataset, ratios=[0.7, 0.1, 0.2])" ] diff --git a/examples/tutorial_stagenet_comprehensive.ipynb b/examples/tutorial_stagenet_comprehensive.ipynb index e6ce98399..7e1a9a3fd 100644 --- a/examples/tutorial_stagenet_comprehensive.ipynb +++ b/examples/tutorial_stagenet_comprehensive.ipynb @@ -891,7 +891,7 @@ }, { "cell_type": "code", - "execution_count": 4, + "execution_count": null, "id": "8e01f7ec", "metadata": {}, "outputs": [ @@ -953,9 +953,8 @@ ], "source": [ "from pyhealth.datasets.utils import save_processors, load_processors\n", - "import os \n", + "import os\n", "processor_dir = \"../../output/processors/stagenet_mortality_mimic4\"\n", - "cache_dir = \"../../mimic4_stagenet_cache_v3\"\n", "\n", "if os.path.exists(os.path.join(processor_dir, \"input_processors.pkl\")):\n", " print(\"\\n=== Loading Pre-fitted Processors ===\")\n", @@ -964,7 +963,6 @@ " sample_dataset = base_dataset.set_task(\n", " MortalityPredictionStageNetMIMIC4(padding=20),\n", " num_workers=1,\n", - " cache_dir=cache_dir,\n", " input_processors=input_processors,\n", " output_processors=output_processors,\n", " )\n", @@ -973,7 +971,6 @@ " sample_dataset = base_dataset.set_task(\n", " MortalityPredictionStageNetMIMIC4(padding=20),\n", " num_workers=1,\n", - " cache_dir=cache_dir,\n", " )\n", "\n", " # Save processors for future runs\n", diff --git a/pyhealth/datasets/base_dataset.py b/pyhealth/datasets/base_dataset.py index 7faffef60..dcd194bc4 100644 --- a/pyhealth/datasets/base_dataset.py +++ b/pyhealth/datasets/base_dataset.py @@ -125,17 +125,17 @@ def _litdata_merge(cache_dir: Path) -> None: """ from litdata.streaming.writer import _INDEX_FILENAME files = os.listdir(cache_dir) - + # Return if the index already exists if _INDEX_FILENAME in files: return index_files = [f for f in files if f.endswith(_INDEX_FILENAME)] - + # Return if there are no index files to merge if len(index_files) == 0: raise ValueError("There are zero samples in the dataset, please check the task and processors.") - + BinaryWriter(cache_dir=str(cache_dir), chunk_bytes="64MB").merge(num_workers=len(index_files)) @@ -309,6 +309,14 @@ def __init__( tables (List[str]): List of table names to load. dataset_name (Optional[str]): Name of the dataset. Defaults to class name. config_path (Optional[str]): Path to the configuration YAML file. + cache_dir (Optional[str | Path]): Directory for caching processed data. + Behavior depends on the type passed: + + - **None** (default): Auto-generates a cache path under the default + pyhealth cache directory. + - **str** or **Path**: Used as the root cache directory path. A UUID + is appended to the provided path to capture dataset configuration. + num_workers (int): Number of worker processes for parallel operations. dev (bool): Whether to run in dev mode (limits to 1000 patients). """ if len(set(tables)) != len(tables): @@ -326,44 +334,55 @@ def __init__( ) # Cached attributes - self._cache_dir = cache_dir + self.cache_dir = self._init_cache_dir(cache_dir) self._global_event_df = None self._unique_patient_ids = None - @property - def cache_dir(self) -> Path: + def _init_cache_dir(self, cache_dir: str | Path | None) -> Path: """Returns the cache directory path. - The cache structure is as follows:: - tmp/ # Temporary files during processing - global_event_df.parquet/ # Cached global event dataframe - tasks/ # Cached task-specific data, please see set_task method + The cache directory is determined by the type of ``cache_dir`` passed + to ``__init__``: + + - **None**: Auto-generated under default pyhealth cache directory. + - **str** or **Path: Used as the root cache directory path. A UUID + is appended to the provided path to capture dataset configuration. + + The cache structure within the directory is:: + + {dataset_uuid}/ # Cache files for this dataset configuration + tmp/ # Temporary files during processing + global_event_df.parquet/ # Cached global event dataframe + tasks/ # Cached task-specific data + {task_name}_{task_uuid}/ # Cached data for specific task based on task name, schema, and args + task_df.ld/ # Intermediate task dataframe based on schema + samples_{proc_uuid}.ld/ # Final processed samples after applying processors Returns: - Path: The cache directory path. + Path: The resolved cache directory path. """ - if self._cache_dir is None: - id_str = json.dumps( - { - "root": self.root, - "tables": sorted(self.tables), - "dataset_name": self.dataset_name, - "dev": self.dev, - }, - sort_keys=True, - ) - cache_dir = Path(platformdirs.user_cache_dir(appname="pyhealth")) / str( - uuid.uuid5(uuid.NAMESPACE_DNS, id_str) - ) + id_str = json.dumps( + { + "root": str(self.root), + "tables": sorted(self.tables), + "dataset_name": self.dataset_name, + "dev": self.dev, + }, + sort_keys=True, + ) + + id = str(uuid.uuid5(uuid.NAMESPACE_DNS, id_str)) + + if cache_dir is None: + cache_dir = Path(platformdirs.user_cache_dir(appname="pyhealth")) / id cache_dir.mkdir(parents=True, exist_ok=True) logger.info(f"No cache_dir provided. Using default cache dir: {cache_dir}") - self._cache_dir = cache_dir else: - # Ensure the explicitly provided cache_dir exists - cache_dir = Path(self._cache_dir) + # Ensure separate cache directories for different table configurations by appending a UUID suffix + cache_dir = Path(cache_dir) / id cache_dir.mkdir(parents=True, exist_ok=True) - self._cache_dir = cache_dir - return Path(self._cache_dir) + logger.info(f"Using provided cache_dir: {cache_dir}") + return Path(cache_dir) def create_tmpdir(self) -> Path: """Creates and returns a new temporary directory within the cache. @@ -501,7 +520,10 @@ def global_event_df(self) -> pl.LazyFrame: if self._global_event_df is None: ret_path = self.cache_dir / "global_event_df.parquet" if not ret_path.exists(): + logger.info(f"No cached event dataframe found. Creating: {ret_path}") self._event_transform(ret_path) + else: + logger.info(f"Found cached event dataframe: {ret_path}") self._global_event_df = ret_path return pl.scan_parquet( @@ -814,27 +836,21 @@ def set_task( self, task: Optional[BaseTask] = None, num_workers: Optional[int] = None, - cache_dir: str | Path | None = None, - cache_format: str = "parquet", input_processors: Optional[Dict[str, FeatureProcessor]] = None, output_processors: Optional[Dict[str, FeatureProcessor]] = None, ) -> SampleDataset: """Processes the base dataset to generate the task-specific sample dataset. The cache structure is as follows:: - task_df.ld/ # Intermediate task dataframe after task transformation - samples_{uuid}.ld/ # Final processed samples after applying processors - schema.pkl # Saved SampleBuilder schema - *.bin # Processed sample files - samples_{uuid}.ld/ - ... + {task_name}_{task_uuid}/ # Cached data for specific task based on task name, schema, and args + task_df.ld/ # Intermediate task dataframe based on schema + samples_{proc_uuid}.ld/ # Final processed samples after applying processors + schema.pkl # Saved SampleBuilder schema + *.bin # Processed sample files Args: task (Optional[BaseTask]): The task to set. Uses default task if None. num_workers (int): Number of workers for multi-threading. Default is `self.num_workers`. - cache_dir (Optional[str]): Directory to cache samples after task transformation, - but without applying processors. Default is {self.cache_dir}/tasks/{task_name}_{uuid5(vars(task))}. - cache_format (str): Deprecated. Only "parquet" is supported now. input_processors (Optional[Dict[str, FeatureProcessor]]): Pre-fitted input processors. If provided, these will be used instead of creating new ones from task's input_schema. Defaults to None. @@ -857,31 +873,25 @@ def set_task( if num_workers is None: num_workers = self.num_workers - if cache_format != "parquet": - logger.warning("Only 'parquet' cache_format is supported now. ") - logger.info( f"Setting task {task.task_name} for {self.dataset_name} base dataset..." ) task_params = json.dumps( - vars(task), + { + **vars(task), + "input_schema": task.input_schema, + "output_schema": task.output_schema, + }, sort_keys=True, default=str ) - if cache_dir is None: - cache_dir = self.cache_dir / "tasks" / f"{task.task_name}_{uuid.uuid5(uuid.NAMESPACE_DNS, task_params)}" - cache_dir.mkdir(parents=True, exist_ok=True) - else: - # Ensure the explicitly provided cache_dir exists - cache_dir = Path(cache_dir) - cache_dir.mkdir(parents=True, exist_ok=True) + cache_dir = self.cache_dir / "tasks" / f"{task.task_name}_{uuid.uuid5(uuid.NAMESPACE_DNS, task_params)}" + cache_dir.mkdir(parents=True, exist_ok=True) proc_params = json.dumps( { - "input_schema": task.input_schema, - "output_schema": task.output_schema, "input_processors": ( { f"{k}_{v.__class__.__name__}": vars(v) @@ -906,9 +916,11 @@ def set_task( task_df_path = Path(cache_dir) / "task_df.ld" samples_path = Path(cache_dir) / f"samples_{uuid.uuid5(uuid.NAMESPACE_DNS, proc_params)}.ld" + logger.info(f"Task cache paths: task_df={task_df_path}, samples={samples_path}") + task_df_path.mkdir(parents=True, exist_ok=True) samples_path.mkdir(parents=True, exist_ok=True) - + if not (samples_path / "index.json").exists(): # Check if index.json exists to verify cache integrity, this # is the standard file for litdata.StreamingDataset diff --git a/pyhealth/datasets/sample_dataset.py b/pyhealth/datasets/sample_dataset.py index c7b719f62..134bede35 100644 --- a/pyhealth/datasets/sample_dataset.py +++ b/pyhealth/datasets/sample_dataset.py @@ -296,6 +296,7 @@ def __init__( """ super().__init__(path, **kwargs) + self.path = path self.dataset_name = "" if dataset_name is None else dataset_name self.task_name = "" if task_name is None else task_name diff --git a/tests/core/test_caching.py b/tests/core/test_caching.py index fa832c31d..4b7f1d512 100644 --- a/tests/core/test_caching.py +++ b/tests/core/test_caching.py @@ -20,16 +20,39 @@ class MockTask(BaseTask): input_schema = {"test_attribute": "raw"} output_schema = {"test_label": "binary"} - def __init__(self, param=None): - self.call_count = 0 - if param: - self.param = param + def __init__(self, param=0): + self.param = param + + def __call__(self, patient): + """Return mock samples based on patient data.""" + # Extract patient's test data from the patient's data source + patient_data = patient.data_source + + samples = [] + for row in patient_data.iter_rows(named=True): + sample = { + "test_attribute": row["test/test_attribute"], + "test_label": row["test/test_label"], + "patient_id": row["patient_id"], + } + samples.append(sample) + + return samples + + +class MockTask2(BaseTask): + """Second mock task with a different output schema than the first""" + task_name = "test_task" + input_schema = {"test_attribute": "raw"} + output_schema = {"test_label": "multiclass"} + + def __init__(self, param=0): + self.param = param def __call__(self, patient): """Return mock samples based on patient data.""" # Extract patient's test data from the patient's data source patient_data = patient.data_source - self.call_count += 1 samples = [] for row in patient_data.iter_rows(named=True): @@ -46,13 +69,13 @@ def __call__(self, patient): class MockDataset(BaseDataset): """Mock dataset for testing purposes.""" - def __init__(self, cache_dir: str | Path | None = None): + def __init__(self, root: str = "", tables = [], dataset_name = "TestDataset", cache_dir: str | Path | None = None, dev = False): super().__init__( - root="", - tables=[], - dataset_name="TestDataset", + root=root, + tables=tables, + dataset_name=dataset_name, cache_dir=cache_dir, - dev=False, + dev=dev, ) def load_data(self) -> dd.DataFrame: @@ -79,19 +102,10 @@ def load_data(self) -> dd.DataFrame: class TestCachingFunctionality(BaseTestCase): """Test cases for caching functionality in BaseDataset.set_task().""" - - @classmethod - def setUpClass(cls): - cls.temp_dir = tempfile.TemporaryDirectory() - cls.dataset = MockDataset(cache_dir=cls.temp_dir.name) - def setUp(self): + self.temp_dir = tempfile.TemporaryDirectory() + self.dataset = MockDataset(cache_dir=self.temp_dir.name) self.task = MockTask() - self.cache_dir = Path(self.temp_dir.name) / "task_cache" - self.cache_dir.mkdir() - - def tearDown(self): - shutil.rmtree(self.cache_dir) def test_set_task_signature(self): """Test that set_task has the correct method signature.""" @@ -104,8 +118,6 @@ def test_set_task_signature(self): "self", "task", "num_workers", - "cache_dir", - "cache_format", "input_processors", "output_processors", ] @@ -114,24 +126,25 @@ def test_set_task_signature(self): # Check default values self.assertEqual(sig.parameters["task"].default, None) self.assertEqual(sig.parameters["num_workers"].default, None) - self.assertEqual(sig.parameters["cache_dir"].default, None) - self.assertEqual(sig.parameters["cache_format"].default, "parquet") self.assertEqual(sig.parameters["input_processors"].default, None) self.assertEqual(sig.parameters["output_processors"].default, None) def test_set_task_writes_cache_and_metadata(self): """Ensure set_task materializes cache files and schema metadata.""" - with self.dataset.set_task( - self.task, cache_dir=self.cache_dir, cache_format="parquet" - ) as sample_dataset: + with self.dataset.set_task(self.task) as sample_dataset: self.assertIsInstance(sample_dataset, SampleDataset) self.assertEqual(sample_dataset.dataset_name, "TestDataset") self.assertEqual(sample_dataset.task_name, self.task.task_name) self.assertEqual(len(sample_dataset), 4) - self.assertEqual(self.task.call_count, 2) - # Ensure intermediate cache files are created - self.assertTrue((self.cache_dir / "task_df.ld" / "index.json").exists()) + # Ensure intermediate cache files are created in default location + task_params = json.dumps( + {"input_schema": {"test_attribute": "raw"}, "output_schema": {"test_label": "binary"}, "param": 0}, + sort_keys=True, + default=str + ) + task_cache_dir = self.dataset.cache_dir / "tasks" / f"{self.task.task_name}_{uuid.uuid5(uuid.NAMESPACE_DNS, task_params)}" + self.assertTrue((task_cache_dir / "task_df.ld" / "index.json").exists()) # Cache artifacts should be present for StreamingDataset assert sample_dataset.input_dir.path is not None @@ -156,13 +169,13 @@ def test_set_task_writes_cache_and_metadata(self): self.assertFalse((sample_dir / "index.json").exists()) self.assertFalse((sample_dir / "schema.pkl").exists()) # Ensure intermediate cache files are still present - self.assertTrue((self.cache_dir / "task_df.ld" / "index.json").exists()) + self.assertTrue((task_cache_dir / "task_df.ld" / "index.json").exists()) def test_default_cache_dir_is_used(self): """When cache_dir is omitted, default cache dir should be used.""" task_params = json.dumps( - {"call_count": 0}, + {"input_schema": {"test_attribute": "raw"}, "output_schema": {"test_label": "binary"}, "param": 0}, sort_keys=True, default=str ) @@ -179,18 +192,14 @@ def test_default_cache_dir_is_used(self): def test_reuses_existing_cache_without_regeneration(self): """Second call should reuse cached samples instead of recomputing.""" - sample_dataset = self.dataset.set_task(self.task, cache_dir=self.cache_dir) - self.assertEqual(self.task.call_count, 2) + sample_dataset = self.dataset.set_task(self.task) with patch.object( - self.task, "__call__", side_effect=AssertionError("Task should not rerun") + type(self.task), "__call__", side_effect=AssertionError("Task should not rerun") ): - cached_dataset = self.dataset.set_task( - self.task, cache_dir=self.cache_dir, cache_format="parquet" - ) + cached_dataset = self.dataset.set_task(self.task) self.assertEqual(len(cached_dataset), 4) - self.assertEqual(self.task.call_count, 2) sample_dataset.close() cached_dataset.close() @@ -199,14 +208,48 @@ def test_tasks_with_diff_param_values_get_diff_caches(self): sample_dataset1 = self.dataset.set_task(MockTask(param=1)) sample_dataset2 = self.dataset.set_task(MockTask(param=2)) + self.assertNotEqual(sample_dataset1.path, sample_dataset2.path) + + task_params1 = json.dumps( + {"input_schema": {"test_attribute": "raw"}, "output_schema": {"test_label": "binary"}, "param": 1}, + sort_keys=True, + default=str + ) + + task_params2 = json.dumps( + {"input_schema": {"test_attribute": "raw"}, "output_schema": {"test_label": "binary"}, "param": 2}, + sort_keys=True, + default=str + ) + + task_cache1 = self.dataset.cache_dir / "tasks" / f"{self.task.task_name}_{uuid.uuid5(uuid.NAMESPACE_DNS, task_params1)}" + task_cache2 = self.dataset.cache_dir / "tasks" / f"{self.task.task_name}_{uuid.uuid5(uuid.NAMESPACE_DNS, task_params2)}" + + self.assertTrue(task_cache1.exists()) + self.assertTrue(task_cache2.exists()) + self.assertTrue((task_cache1 / "task_df.ld" / "index.json").exists()) + self.assertTrue((task_cache2 / "task_df.ld" / "index.json").exists()) + self.assertTrue((self.dataset.cache_dir / "global_event_df.parquet").exists()) + self.assertEqual(len(sample_dataset1), 4) + self.assertEqual(len(sample_dataset2), 4) + + sample_dataset1.close() + sample_dataset2.close() + + def test_tasks_with_diff_output_schemas_get_diff_caches(self): + sample_dataset1 = self.dataset.set_task(MockTask()) + sample_dataset2 = self.dataset.set_task(MockTask2()) + + self.assertNotEqual(sample_dataset1.path, sample_dataset2.path) + task_params1 = json.dumps( - {"call_count": 0, "param": 2}, + {"input_schema": {"test_attribute": "raw"}, "output_schema": {"test_label": "binary"}, "param": 0}, sort_keys=True, default=str ) task_params2 = json.dumps( - {"call_count": 0, "param": 2}, + {"input_schema": {"test_attribute": "raw"}, "output_schema": {"test_label": "multiclass"}, "param": 0}, sort_keys=True, default=str ) @@ -225,6 +268,33 @@ def test_tasks_with_diff_param_values_get_diff_caches(self): sample_dataset1.close() sample_dataset2.close() + def test_datasets_with_diff_roots_get_diff_caches(self): + dataset1 = MockDataset(root=tempfile.TemporaryDirectory().name, cache_dir=self.temp_dir.name) + dataset2 = MockDataset(root=tempfile.TemporaryDirectory().name, cache_dir=self.temp_dir.name) + + self.assertNotEqual(dataset1.cache_dir, dataset2.cache_dir) + + def test_datasets_with_diff_tables_get_diff_caches(self): + dataset1 = MockDataset(tables=["one", "two", ], cache_dir=self.temp_dir.name) + dataset2 = MockDataset(tables=["one", "two", "three"], cache_dir=self.temp_dir.name) + dataset3 = MockDataset(tables=["one", "three"], cache_dir=self.temp_dir.name) + dataset4 = MockDataset(tables=[], cache_dir=self.temp_dir.name) + + caches = [dataset1.cache_dir, dataset2.cache_dir, dataset3.cache_dir, dataset4.cache_dir] + + self.assertEqual(len(caches), len(set(caches))) + + def test_datasets_with_diff_names_get_diff_caches(self): + dataset1 = MockDataset(dataset_name="one", cache_dir=self.temp_dir.name) + dataset2 = MockDataset(dataset_name="two", cache_dir=self.temp_dir.name) + + self.assertNotEqual(dataset1.cache_dir, dataset2.cache_dir) + + def test_datasets_with_diff_dev_values_get_diff_caches(self): + dataset1 = MockDataset(dev=True, cache_dir=self.temp_dir.name) + dataset2 = MockDataset(dev=False, cache_dir=self.temp_dir.name) + + self.assertNotEqual(dataset1.cache_dir, dataset2.cache_dir) if __name__ == "__main__": unittest.main()