From d4f8beb8e8a87ff75d7bdd007b7dfac5d6d96d82 Mon Sep 17 00:00:00 2001 From: John Wu Date: Mon, 9 Feb 2026 16:56:27 -0600 Subject: [PATCH 01/16] init commit, tldr the cache_dir or files should probably indicate table names --- pyhealth/datasets/base_dataset.py | 78 ++++++++++++++++++++++++------- 1 file changed, 60 insertions(+), 18 deletions(-) diff --git a/pyhealth/datasets/base_dataset.py b/pyhealth/datasets/base_dataset.py index 7faffef60..44748e05f 100644 --- a/pyhealth/datasets/base_dataset.py +++ b/pyhealth/datasets/base_dataset.py @@ -309,6 +309,21 @@ def __init__( tables (List[str]): List of table names to load. dataset_name (Optional[str]): Name of the dataset. Defaults to class name. config_path (Optional[str]): Path to the configuration YAML file. + cache_dir (Optional[str | Path]): Directory for caching processed data. + Behavior depends on the type passed: + + - **None** (default): Auto-generates a cache path under the default + pyhealth cache directory with a UUID subdirectory derived from the + dataset configuration (tables, root, dev mode). Different table sets + automatically get separate caches. + - **str**: Treated as a base name. A UUID suffix will be appended to + the directory name to prevent cache collisions between different table + configurations. For example, ``"/my/cache"`` becomes + ``"/my/cache_"``. A warning is logged showing the transformation. + - **Path**: Used as-is with NO modification. Use this when you want + full control over the exact cache directory path. You are responsible + for ensuring different table configurations don't share the same path. + num_workers (int): Number of worker processes for parallel operations. dev (bool): Whether to run in dev mode (limits to 1000 patients). """ if len(set(tables)) != len(tables): @@ -333,37 +348,64 @@ def __init__( @property def cache_dir(self) -> Path: """Returns the cache directory path. 
- The cache structure is as follows:: + + The cache directory is determined by the type of ``cache_dir`` passed + to ``__init__``: + + - **None**: Auto-generated under default pyhealth cache with UUID subdir. + - **str**: UUID suffix appended to directory name (e.g., ``cache_``). + - **Path**: Used exactly as-is (no UUID appended). Pass ``Path(...)`` to + opt out of automatic UUID suffixing and use an exact directory. + + The cache structure within the directory is:: tmp/ # Temporary files during processing global_event_df.parquet/ # Cached global event dataframe - tasks/ # Cached task-specific data, please see set_task method + tasks/ # Cached task-specific data Returns: - Path: The cache directory path. + Path: The resolved cache directory path. """ + # If already computed (Path object), return it directly. + # This also handles the case where the user passed Path() explicitly + # at init time -- it's used as-is with no modification. + if isinstance(self._cache_dir, Path): + return self._cache_dir + + # Generate UUID based on dataset configuration (tables, root, etc.) + # to ensure different table sets get isolated cache directories. + id_str = json.dumps( + { + "root": self.root, + "tables": sorted(self.tables), + "dataset_name": self.dataset_name, + "dev": self.dev, + }, + sort_keys=True, + ) + cache_uuid = str(uuid.uuid5(uuid.NAMESPACE_DNS, id_str)) + if self._cache_dir is None: - id_str = json.dumps( - { - "root": self.root, - "tables": sorted(self.tables), - "dataset_name": self.dataset_name, - "dev": self.dev, - }, - sort_keys=True, - ) - cache_dir = Path(platformdirs.user_cache_dir(appname="pyhealth")) / str( - uuid.uuid5(uuid.NAMESPACE_DNS, id_str) - ) + # No cache_dir provided: use default pyhealth cache with UUID subdir + cache_dir = Path(platformdirs.user_cache_dir(appname="pyhealth")) / cache_uuid cache_dir.mkdir(parents=True, exist_ok=True) logger.info(f"No cache_dir provided. 
Using default cache dir: {cache_dir}") self._cache_dir = cache_dir else: - # Ensure the explicitly provided cache_dir exists - cache_dir = Path(self._cache_dir) + # String provided: append UUID to directory name for table isolation + base_path = Path(self._cache_dir) + cache_dir = base_path.parent / f"{base_path.name}_{cache_uuid}" cache_dir.mkdir(parents=True, exist_ok=True) + logger.warning( + f"cache_dir was provided as a string: '{self._cache_dir}'. " + f"A UUID suffix has been appended for table-specific isolation: " + f"'{cache_dir}'. Different table configurations will use separate " + f"cache directories. To use an exact path with no modification, " + f"pass cache_dir=Path('{self._cache_dir}') instead." + ) self._cache_dir = cache_dir - return Path(self._cache_dir) + + return self._cache_dir def create_tmpdir(self) -> Path: """Creates and returns a new temporary directory within the cache. From 840efbc8c237621a61875be63d0e9b45a171a010 Mon Sep 17 00:00:00 2001 From: John Wu Date: Tue, 10 Feb 2026 14:08:39 -0600 Subject: [PATCH 02/16] better approach is to hash the filename rather than the directory name --- pyhealth/datasets/base_dataset.py | 120 ++++++++++++++++++------------ 1 file changed, 72 insertions(+), 48 deletions(-) diff --git a/pyhealth/datasets/base_dataset.py b/pyhealth/datasets/base_dataset.py index 44748e05f..6c663a77e 100644 --- a/pyhealth/datasets/base_dataset.py +++ b/pyhealth/datasets/base_dataset.py @@ -313,16 +313,14 @@ def __init__( Behavior depends on the type passed: - **None** (default): Auto-generates a cache path under the default - pyhealth cache directory with a UUID subdirectory derived from the - dataset configuration (tables, root, dev mode). Different table sets - automatically get separate caches. - - **str**: Treated as a base name. A UUID suffix will be appended to - the directory name to prevent cache collisions between different table - configurations. For example, ``"/my/cache"`` becomes - ``"/my/cache_"``. 
A warning is logged showing the transformation. - - **Path**: Used as-is with NO modification. Use this when you want - full control over the exact cache directory path. You are responsible - for ensuring different table configurations don't share the same path. + pyhealth cache directory. Cache files include a UUID in their + filenames (e.g., ``global_event_df_{uuid}.parquet``) derived from + the dataset configuration, so different table sets don't collide. + - **str**: Used as the cache directory path. Cache files include a + UUID in their filenames to prevent collisions between different + table configurations sharing the same directory. + - **Path**: Used as-is with NO modification. Cache files still include + UUID in their filenames for isolation. num_workers (int): Number of worker processes for parallel operations. dev (bool): Whether to run in dev mode (limits to 1000 patients). """ @@ -352,16 +350,19 @@ def cache_dir(self) -> Path: The cache directory is determined by the type of ``cache_dir`` passed to ``__init__``: - - **None**: Auto-generated under default pyhealth cache with UUID subdir. - - **str**: UUID suffix appended to directory name (e.g., ``cache_``). - - **Path**: Used exactly as-is (no UUID appended). Pass ``Path(...)`` to - opt out of automatic UUID suffixing and use an exact directory. + - **None**: Auto-generated under default pyhealth cache directory. + - **str**: Used as-is as the cache directory path. + - **Path**: Used exactly as-is (no modification). + + Cache files within the directory include UUID suffixes in their + filenames (e.g., ``global_event_df_{uuid}.parquet``) to prevent + collisions between different table configurations. 
The cache structure within the directory is:: - tmp/ # Temporary files during processing - global_event_df.parquet/ # Cached global event dataframe - tasks/ # Cached task-specific data + tmp/ # Temporary files during processing + global_event_df_{uuid}.parquet/ # Cached global event dataframe + tasks/ # Cached task-specific data Returns: Path: The resolved cache directory path. @@ -372,41 +373,44 @@ def cache_dir(self) -> Path: if isinstance(self._cache_dir, Path): return self._cache_dir - # Generate UUID based on dataset configuration (tables, root, etc.) - # to ensure different table sets get isolated cache directories. - id_str = json.dumps( - { - "root": self.root, - "tables": sorted(self.tables), - "dataset_name": self.dataset_name, - "dev": self.dev, - }, - sort_keys=True, - ) - cache_uuid = str(uuid.uuid5(uuid.NAMESPACE_DNS, id_str)) - if self._cache_dir is None: - # No cache_dir provided: use default pyhealth cache with UUID subdir - cache_dir = Path(platformdirs.user_cache_dir(appname="pyhealth")) / cache_uuid + # No cache_dir provided: use default pyhealth cache directory + cache_dir = Path(platformdirs.user_cache_dir(appname="pyhealth")) / "datasets" cache_dir.mkdir(parents=True, exist_ok=True) logger.info(f"No cache_dir provided. Using default cache dir: {cache_dir}") self._cache_dir = cache_dir else: - # String provided: append UUID to directory name for table isolation - base_path = Path(self._cache_dir) - cache_dir = base_path.parent / f"{base_path.name}_{cache_uuid}" + # String provided: use as-is (file-based isolation via UUID in filenames) + cache_dir = Path(self._cache_dir) cache_dir.mkdir(parents=True, exist_ok=True) - logger.warning( - f"cache_dir was provided as a string: '{self._cache_dir}'. " - f"A UUID suffix has been appended for table-specific isolation: " - f"'{cache_dir}'. Different table configurations will use separate " - f"cache directories. 
To use an exact path with no modification, " - f"pass cache_dir=Path('{self._cache_dir}') instead." + logger.info( + f"Using cache dir: {cache_dir} " + f"(cache files will include UUID suffix for table isolation)" ) self._cache_dir = cache_dir return self._cache_dir + def _get_cache_uuid(self) -> str: + """Get the cache UUID for this dataset configuration. + + Returns a deterministic UUID computed from tables, root, dataset_name, + and dev mode. This is used to create unique filenames within the cache + directory so that different table configurations don't collide. + """ + if not hasattr(self, '_cache_uuid') or self._cache_uuid is None: + id_str = json.dumps( + { + "root": self.root, + "tables": sorted(self.tables), + "dataset_name": self.dataset_name, + "dev": self.dev, + }, + sort_keys=True, + ) + self._cache_uuid = str(uuid.uuid5(uuid.NAMESPACE_DNS, id_str)) + return self._cache_uuid + def create_tmpdir(self) -> Path: """Creates and returns a new temporary directory within the cache. @@ -541,9 +545,12 @@ def global_event_df(self) -> pl.LazyFrame: self._main_guard(type(self).global_event_df.fget.__name__) # type: ignore if self._global_event_df is None: - ret_path = self.cache_dir / "global_event_df.parquet" + ret_path = self.cache_dir / f"global_event_df_{self._get_cache_uuid()}.parquet" if not ret_path.exists(): + logger.info(f"No cached event dataframe found. Creating: {ret_path}") self._event_transform(ret_path) + else: + logger.info(f"Found cached event dataframe: {ret_path}") self._global_event_df = ret_path return pl.scan_parquet( @@ -864,13 +871,16 @@ def set_task( """Processes the base dataset to generate the task-specific sample dataset. 
The cache structure is as follows:: - task_df.ld/ # Intermediate task dataframe after task transformation - samples_{uuid}.ld/ # Final processed samples after applying processors - schema.pkl # Saved SampleBuilder schema - *.bin # Processed sample files - samples_{uuid}.ld/ + task_df_{schema_uuid}.ld/ # Intermediate task dataframe (schema-aware) + samples_{proc_uuid}.ld/ # Final processed samples after applying processors + schema.pkl # Saved SampleBuilder schema + *.bin # Processed sample files + samples_{proc_uuid}.ld/ ... + The task_df path includes a hash of the task's input/output schemas, + so changing schemas automatically invalidates the cached task dataframe. + Args: task (Optional[BaseTask]): The task to set. Uses default task if None. num_workers (int): Number of workers for multi-threading. Default is `self.num_workers`. @@ -945,9 +955,23 @@ def set_task( default=str ) - task_df_path = Path(cache_dir) / "task_df.ld" + # Hash based ONLY on task schemas (not the task instance) to avoid + # recursion issues. This ensures task_df is invalidated when schemas change. 
+ task_schema_params = json.dumps( + { + "input_schema": task.input_schema, + "output_schema": task.output_schema, + }, + sort_keys=True, + default=str + ) + task_schema_hash = uuid.uuid5(uuid.NAMESPACE_DNS, task_schema_params) + + task_df_path = Path(cache_dir) / f"task_df_{task_schema_hash}.ld" samples_path = Path(cache_dir) / f"samples_{uuid.uuid5(uuid.NAMESPACE_DNS, proc_params)}.ld" + logger.info(f"Task cache paths: task_df={task_df_path}, samples={samples_path}") + task_df_path.mkdir(parents=True, exist_ok=True) samples_path.mkdir(parents=True, exist_ok=True) From 5e108ba359980c3fbb06c0ddf858799bd338aa59 Mon Sep 17 00:00:00 2001 From: Yongda Fan Date: Wed, 11 Feb 2026 18:49:33 -0600 Subject: [PATCH 03/16] Fix cache dir --- pyhealth/datasets/base_dataset.py | 48 +++++++++++++++++-------------- 1 file changed, 26 insertions(+), 22 deletions(-) diff --git a/pyhealth/datasets/base_dataset.py b/pyhealth/datasets/base_dataset.py index 6c663a77e..3930795f6 100644 --- a/pyhealth/datasets/base_dataset.py +++ b/pyhealth/datasets/base_dataset.py @@ -339,12 +339,11 @@ def __init__( ) # Cached attributes - self._cache_dir = cache_dir + self._cache_dir = self._init_cache_dir(cache_dir) self._global_event_df = None self._unique_patient_ids = None - @property - def cache_dir(self) -> Path: + def _init_cache_dir(self, cache_dir: str | Path | None) -> Path: """Returns the cache directory path. 
The cache directory is determined by the type of ``cache_dir`` passed @@ -360,36 +359,41 @@ def cache_dir(self) -> Path: The cache structure within the directory is:: - tmp/ # Temporary files during processing - global_event_df_{uuid}.parquet/ # Cached global event dataframe - tasks/ # Cached task-specific data + tmp/ # Temporary files during processing + {uuid}/ # Cache files for this dataset configuration + global_event_df.parquet/ # Cached global event dataframe + tasks/ # Cached task-specific data + {task_name}_{uuid}/ # Cached data for specific task based on task name and its args + task_df_{uuid}.ld/ # Intermediate task dataframe based on schema + samples_{uuid}.ld/ # Final processed samples after applying processors Returns: Path: The resolved cache directory path. """ - # If already computed (Path object), return it directly. - # This also handles the case where the user passed Path() explicitly - # at init time -- it's used as-is with no modification. - if isinstance(self._cache_dir, Path): - return self._cache_dir - - if self._cache_dir is None: - # No cache_dir provided: use default pyhealth cache directory - cache_dir = Path(platformdirs.user_cache_dir(appname="pyhealth")) / "datasets" + id_str = json.dumps( + { + "root": self.root, + "tables": sorted(self.tables), + "dataset_name": self.dataset_name, + "dev": self.dev, + }, + sort_keys=True, + ) + + if cache_dir is None: + cache_dir = Path(platformdirs.user_cache_dir(appname="pyhealth")) / str( + uuid.uuid5(uuid.NAMESPACE_DNS, id_str) + ) cache_dir.mkdir(parents=True, exist_ok=True) logger.info(f"No cache_dir provided. 
Using default cache dir: {cache_dir}") self._cache_dir = cache_dir else: - # String provided: use as-is (file-based isolation via UUID in filenames) - cache_dir = Path(self._cache_dir) + # Ensure separate cache directories for different table configurations by appending a UUID suffix + cache_dir = Path(self._cache_dir) / str(uuid.uuid5(uuid.NAMESPACE_DNS, id_str)) cache_dir.mkdir(parents=True, exist_ok=True) - logger.info( - f"Using cache dir: {cache_dir} " - f"(cache files will include UUID suffix for table isolation)" - ) self._cache_dir = cache_dir + return Path(self._cache_dir) - return self._cache_dir def _get_cache_uuid(self) -> str: """Get the cache UUID for this dataset configuration. From fa63c430a26b1b3ad0af27a994568edb21796c6f Mon Sep 17 00:00:00 2001 From: EricSchrock Date: Sun, 15 Feb 2026 19:41:38 +0000 Subject: [PATCH 04/16] Simplify new caching behavior code and add unit tests --- pyhealth/datasets/base_dataset.py | 119 +++++++++------------------- pyhealth/datasets/sample_dataset.py | 1 + tests/core/test_caching.py | 106 +++++++++++++++++++++++-- 3 files changed, 137 insertions(+), 89 deletions(-) diff --git a/pyhealth/datasets/base_dataset.py b/pyhealth/datasets/base_dataset.py index 3930795f6..e7066e264 100644 --- a/pyhealth/datasets/base_dataset.py +++ b/pyhealth/datasets/base_dataset.py @@ -125,17 +125,17 @@ def _litdata_merge(cache_dir: Path) -> None: """ from litdata.streaming.writer import _INDEX_FILENAME files = os.listdir(cache_dir) - + # Return if the index already exists if _INDEX_FILENAME in files: return index_files = [f for f in files if f.endswith(_INDEX_FILENAME)] - + # Return if there are no index files to merge if len(index_files) == 0: raise ValueError("There are zero samples in the dataset, please check the task and processors.") - + BinaryWriter(cache_dir=str(cache_dir), chunk_bytes="64MB").merge(num_workers=len(index_files)) @@ -311,16 +311,11 @@ def __init__( config_path (Optional[str]): Path to the configuration YAML file. 
cache_dir (Optional[str | Path]): Directory for caching processed data. Behavior depends on the type passed: - - - **None** (default): Auto-generates a cache path under the default - pyhealth cache directory. Cache files include a UUID in their - filenames (e.g., ``global_event_df_{uuid}.parquet``) derived from - the dataset configuration, so different table sets don't collide. - - **str**: Used as the cache directory path. Cache files include a - UUID in their filenames to prevent collisions between different - table configurations sharing the same directory. - - **Path**: Used as-is with NO modification. Cache files still include - UUID in their filenames for isolation. + + - **None** (default): Auto-generates a cache path under the default + pyhealth cache directory. + - **str** or **Path**: Used as the root cache directory path. A UUID + is appended to the provided path to capture dataset configuration. num_workers (int): Number of worker processes for parallel operations. dev (bool): Whether to run in dev mode (limits to 1000 patients). """ @@ -339,7 +334,7 @@ def __init__( ) # Cached attributes - self._cache_dir = self._init_cache_dir(cache_dir) + self.cache_dir = self._init_cache_dir(cache_dir) self._global_event_df = None self._unique_patient_ids = None @@ -350,70 +345,44 @@ def _init_cache_dir(self, cache_dir: str | Path | None) -> Path: to ``__init__``: - **None**: Auto-generated under default pyhealth cache directory. - - **str**: Used as-is as the cache directory path. - - **Path**: Used exactly as-is (no modification). - - Cache files within the directory include UUID suffixes in their - filenames (e.g., ``global_event_df_{uuid}.parquet``) to prevent - collisions between different table configurations. + - **str** or **Path: Used as the root cache directory path. A UUID + is appended to the provided path to capture dataset configuration. 
The cache structure within the directory is:: - tmp/ # Temporary files during processing - {uuid}/ # Cache files for this dataset configuration + {dataset_uuid}/ # Cache files for this dataset configuration + tmp/ # Temporary files during processing global_event_df.parquet/ # Cached global event dataframe tasks/ # Cached task-specific data - {task_name}_{uuid}/ # Cached data for specific task based on task name and its args - task_df_{uuid}.ld/ # Intermediate task dataframe based on schema - samples_{uuid}.ld/ # Final processed samples after applying processors + {task_name}_{task_uuid}/ # Cached data for specific task based on task name, schema, and args + task_df.ld/ # Intermediate task dataframe based on schema + samples.ld/ # Final processed samples after applying processors Returns: Path: The resolved cache directory path. """ id_str = json.dumps( { - "root": self.root, + "root": str(self.root), "tables": sorted(self.tables), "dataset_name": self.dataset_name, "dev": self.dev, }, sort_keys=True, ) - + + id = str(uuid.uuid5(uuid.NAMESPACE_DNS, id_str)) + if cache_dir is None: - cache_dir = Path(platformdirs.user_cache_dir(appname="pyhealth")) / str( - uuid.uuid5(uuid.NAMESPACE_DNS, id_str) - ) + cache_dir = Path(platformdirs.user_cache_dir(appname="pyhealth")) / id cache_dir.mkdir(parents=True, exist_ok=True) logger.info(f"No cache_dir provided. Using default cache dir: {cache_dir}") - self._cache_dir = cache_dir else: # Ensure separate cache directories for different table configurations by appending a UUID suffix - cache_dir = Path(self._cache_dir) / str(uuid.uuid5(uuid.NAMESPACE_DNS, id_str)) + cache_dir = Path(cache_dir) / id cache_dir.mkdir(parents=True, exist_ok=True) - self._cache_dir = cache_dir - return Path(self._cache_dir) - - - def _get_cache_uuid(self) -> str: - """Get the cache UUID for this dataset configuration. - - Returns a deterministic UUID computed from tables, root, dataset_name, - and dev mode. 
This is used to create unique filenames within the cache - directory so that different table configurations don't collide. - """ - if not hasattr(self, '_cache_uuid') or self._cache_uuid is None: - id_str = json.dumps( - { - "root": self.root, - "tables": sorted(self.tables), - "dataset_name": self.dataset_name, - "dev": self.dev, - }, - sort_keys=True, - ) - self._cache_uuid = str(uuid.uuid5(uuid.NAMESPACE_DNS, id_str)) - return self._cache_uuid + logger.info(f"Using provided cache_dir: {cache_dir}") + return Path(cache_dir) def create_tmpdir(self) -> Path: """Creates and returns a new temporary directory within the cache. @@ -549,7 +518,7 @@ def global_event_df(self) -> pl.LazyFrame: self._main_guard(type(self).global_event_df.fget.__name__) # type: ignore if self._global_event_df is None: - ret_path = self.cache_dir / f"global_event_df_{self._get_cache_uuid()}.parquet" + ret_path = self.cache_dir / f"global_event_df.parquet" if not ret_path.exists(): logger.info(f"No cached event dataframe found. Creating: {ret_path}") self._event_transform(ret_path) @@ -875,21 +844,17 @@ def set_task( """Processes the base dataset to generate the task-specific sample dataset. The cache structure is as follows:: - task_df_{schema_uuid}.ld/ # Intermediate task dataframe (schema-aware) - samples_{proc_uuid}.ld/ # Final processed samples after applying processors - schema.pkl # Saved SampleBuilder schema - *.bin # Processed sample files - samples_{proc_uuid}.ld/ - ... - - The task_df path includes a hash of the task's input/output schemas, - so changing schemas automatically invalidates the cached task dataframe. + {task_name}_{task_uuid}/ # Cached data for specific task based on task name, schema, and args + task_df.ld/ # Intermediate task dataframe based on schema + samples.ld/ # Final processed samples after applying processors + schema.pkl # Saved SampleBuilder schema + *.bin # Processed sample files Args: task (Optional[BaseTask]): The task to set. Uses default task if None. 
num_workers (int): Number of workers for multi-threading. Default is `self.num_workers`. cache_dir (Optional[str]): Directory to cache samples after task transformation, - but without applying processors. Default is {self.cache_dir}/tasks/{task_name}_{uuid5(vars(task))}. + but without applying processors. Default is {self.cache_dir}/tasks. cache_format (str): Deprecated. Only "parquet" is supported now. input_processors (Optional[Dict[str, FeatureProcessor]]): Pre-fitted input processors. If provided, these will be used @@ -921,7 +886,11 @@ def set_task( ) task_params = json.dumps( - vars(task), + { + **vars(task), + "input_schema": task.input_schema, + "output_schema": task.output_schema, + }, sort_keys=True, default=str ) @@ -959,26 +928,14 @@ def set_task( default=str ) - # Hash based ONLY on task schemas (not the task instance) to avoid - # recursion issues. This ensures task_df is invalidated when schemas change. - task_schema_params = json.dumps( - { - "input_schema": task.input_schema, - "output_schema": task.output_schema, - }, - sort_keys=True, - default=str - ) - task_schema_hash = uuid.uuid5(uuid.NAMESPACE_DNS, task_schema_params) - - task_df_path = Path(cache_dir) / f"task_df_{task_schema_hash}.ld" + task_df_path = Path(cache_dir) / "task_df.ld" samples_path = Path(cache_dir) / f"samples_{uuid.uuid5(uuid.NAMESPACE_DNS, proc_params)}.ld" logger.info(f"Task cache paths: task_df={task_df_path}, samples={samples_path}") task_df_path.mkdir(parents=True, exist_ok=True) samples_path.mkdir(parents=True, exist_ok=True) - + if not (samples_path / "index.json").exists(): # Check if index.json exists to verify cache integrity, this # is the standard file for litdata.StreamingDataset diff --git a/pyhealth/datasets/sample_dataset.py b/pyhealth/datasets/sample_dataset.py index c7b719f62..134bede35 100644 --- a/pyhealth/datasets/sample_dataset.py +++ b/pyhealth/datasets/sample_dataset.py @@ -296,6 +296,7 @@ def __init__( """ super().__init__(path, **kwargs) + self.path 
= path self.dataset_name = "" if dataset_name is None else dataset_name self.task_name = "" if task_name is None else task_name diff --git a/tests/core/test_caching.py b/tests/core/test_caching.py index fa832c31d..0a6ac7dfe 100644 --- a/tests/core/test_caching.py +++ b/tests/core/test_caching.py @@ -43,16 +43,45 @@ def __call__(self, patient): return samples +class MockTask2(BaseTask): + """Second mock task with a different output schema than the first""" + task_name = "test_task" + input_schema = {"test_attribute": "raw"} + output_schema = {"test_label": "multiclass"} + + def __init__(self, param=None): + self.call_count = 0 + if param: + self.param = param + + def __call__(self, patient): + """Return mock samples based on patient data.""" + # Extract patient's test data from the patient's data source + patient_data = patient.data_source + self.call_count += 1 + + samples = [] + for row in patient_data.iter_rows(named=True): + sample = { + "test_attribute": row["test/test_attribute"], + "test_label": row["test/test_label"], + "patient_id": row["patient_id"], + } + samples.append(sample) + + return samples + + class MockDataset(BaseDataset): """Mock dataset for testing purposes.""" - def __init__(self, cache_dir: str | Path | None = None): + def __init__(self, root: str = "", tables = [], dataset_name = "TestDataset", cache_dir: str | Path | None = None, dev = False): super().__init__( - root="", - tables=[], - dataset_name="TestDataset", + root=root, + tables=tables, + dataset_name=dataset_name, cache_dir=cache_dir, - dev=False, + dev=dev, ) def load_data(self) -> dd.DataFrame: @@ -162,7 +191,7 @@ def test_set_task_writes_cache_and_metadata(self): def test_default_cache_dir_is_used(self): """When cache_dir is omitted, default cache dir should be used.""" task_params = json.dumps( - {"call_count": 0}, + {"input_schema": {"test_attribute": "raw"}, "output_schema": {"test_label": "binary"}, "call_count": 0}, sort_keys=True, default=str ) @@ -199,14 +228,16 @@ def 
test_tasks_with_diff_param_values_get_diff_caches(self): sample_dataset1 = self.dataset.set_task(MockTask(param=1)) sample_dataset2 = self.dataset.set_task(MockTask(param=2)) + self.assertNotEqual(sample_dataset1.path, sample_dataset2.path) + task_params1 = json.dumps( - {"call_count": 0, "param": 2}, + {"input_schema": {"test_attribute": "raw"}, "output_schema": {"test_label": "binary"}, "call_count": 0, "param": 1}, sort_keys=True, default=str ) task_params2 = json.dumps( - {"call_count": 0, "param": 2}, + {"input_schema": {"test_attribute": "raw"}, "output_schema": {"test_label": "binary"}, "call_count": 0, "param": 2}, sort_keys=True, default=str ) @@ -225,6 +256,65 @@ def test_tasks_with_diff_param_values_get_diff_caches(self): sample_dataset1.close() sample_dataset2.close() + def test_tasks_with_diff_output_schemas_get_diff_caches(self): + sample_dataset1 = self.dataset.set_task(MockTask()) + sample_dataset2 = self.dataset.set_task(MockTask2()) + + self.assertNotEqual(sample_dataset1.path, sample_dataset2.path) + + task_params1 = json.dumps( + {"input_schema": {"test_attribute": "raw"}, "output_schema": {"test_label": "binary"}, "call_count": 0}, + sort_keys=True, + default=str + ) + + task_params2 = json.dumps( + {"input_schema": {"test_attribute": "raw"}, "output_schema": {"test_label": "multiclass"}, "call_count": 0}, + sort_keys=True, + default=str + ) + + task_cache1 = self.dataset.cache_dir / "tasks" / f"{self.task.task_name}_{uuid.uuid5(uuid.NAMESPACE_DNS, task_params1)}" + task_cache2 = self.dataset.cache_dir / "tasks" / f"{self.task.task_name}_{uuid.uuid5(uuid.NAMESPACE_DNS, task_params2)}" + + self.assertTrue(task_cache1.exists()) + self.assertTrue(task_cache2.exists()) + self.assertTrue((task_cache1 / "task_df.ld" / "index.json").exists()) + self.assertTrue((task_cache2 / "task_df.ld" / "index.json").exists()) + self.assertTrue((self.dataset.cache_dir / "global_event_df.parquet").exists()) + self.assertEqual(len(sample_dataset1), 4) + 
self.assertEqual(len(sample_dataset2), 4) + + sample_dataset1.close() + sample_dataset2.close() + + def test_datasets_with_diff_roots_get_diff_caches(self): + dataset1 = MockDataset(root=tempfile.TemporaryDirectory().name, cache_dir=self.temp_dir.name) + dataset2 = MockDataset(root=tempfile.TemporaryDirectory().name, cache_dir=self.temp_dir.name) + + self.assertNotEqual(dataset1.cache_dir, dataset2.cache_dir) + + def test_datasets_with_diff_tables_get_diff_caches(self): + dataset1 = MockDataset(tables=["one", "two", ], cache_dir=self.temp_dir.name) + dataset2 = MockDataset(tables=["one", "two", "three"], cache_dir=self.temp_dir.name) + dataset3 = MockDataset(tables=["one", "three"], cache_dir=self.temp_dir.name) + dataset4 = MockDataset(tables=[], cache_dir=self.temp_dir.name) + + caches = [dataset1.cache_dir, dataset2.cache_dir, dataset3.cache_dir, dataset4.cache_dir] + + self.assertEqual(len(caches), len(set(caches))) + + def test_datasets_with_diff_names_get_diff_caches(self): + dataset1 = MockDataset(dataset_name="one", cache_dir=self.temp_dir.name) + dataset2 = MockDataset(dataset_name="two", cache_dir=self.temp_dir.name) + + self.assertNotEqual(dataset1.cache_dir, dataset2.cache_dir) + + def test_datasets_with_diff_dev_values_get_diff_caches(self): + dataset1 = MockDataset(dev=True, cache_dir=self.temp_dir.name) + dataset2 = MockDataset(dev=False, cache_dir=self.temp_dir.name) + + self.assertNotEqual(dataset1.cache_dir, dataset2.cache_dir) if __name__ == "__main__": unittest.main() From a70e74b958f3eb286ce24d88cab1147502b8e336 Mon Sep 17 00:00:00 2001 From: EricSchrock Date: Sun, 15 Feb 2026 19:55:48 +0000 Subject: [PATCH 05/16] Remove unnecessary f-string --- pyhealth/datasets/base_dataset.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pyhealth/datasets/base_dataset.py b/pyhealth/datasets/base_dataset.py index e7066e264..aa260f303 100644 --- a/pyhealth/datasets/base_dataset.py +++ b/pyhealth/datasets/base_dataset.py @@ -518,7 +518,7 @@ 
def global_event_df(self) -> pl.LazyFrame: self._main_guard(type(self).global_event_df.fget.__name__) # type: ignore if self._global_event_df is None: - ret_path = self.cache_dir / f"global_event_df.parquet" + ret_path = self.cache_dir / "global_event_df.parquet" if not ret_path.exists(): logger.info(f"No cached event dataframe found. Creating: {ret_path}") self._event_transform(ret_path) From 7b4d113198b74591b1de981d233855c56eb609e9 Mon Sep 17 00:00:00 2001 From: John Wu Date: Sun, 15 Feb 2026 20:44:41 -0600 Subject: [PATCH 06/16] commit to remove cache_dir from task, will update after seeing if test cases pass --- examples/concare_mimic4_example.ipynb | 2 +- examples/cxr/covid19cxr_conformal.py | 2 +- examples/cxr/covid19cxr_tutorial.ipynb | 2 +- examples/cxr/covid19cxr_tutorial.py | 2 +- examples/cxr/covid19cxr_tutorial_display.py | 2 +- .../mimic3_mortality_prediction_cached.ipynb | 2 +- .../multimodal_mimic4_demo.py | 1 - .../multimodal_mimic4_minimal.py | 2 +- .../timeseries_mimic4.ipynb | 2 +- .../mortality_prediction/timeseries_mimic4.py | 2 +- pyhealth/datasets/base_dataset.py | 12 ++------- tests/core/test_caching.py | 26 ++++++++++++------- 12 files changed, 28 insertions(+), 29 deletions(-) diff --git a/examples/concare_mimic4_example.ipynb b/examples/concare_mimic4_example.ipynb index 188858ecb..23e8c4976 100644 --- a/examples/concare_mimic4_example.ipynb +++ b/examples/concare_mimic4_example.ipynb @@ -340,7 +340,7 @@ ")\n", "\n", "task = InHospitalMortalityMIMIC3()\n", - "samples = dataset.set_task(task, num_workers=10, cache_dir=\"./cache_concare_mortality_m3\")\n", + "samples = dataset.set_task(task, num_workers=10)\n", "\n", "# The rest of the code remains the same\n", "\"\"\"" diff --git a/examples/cxr/covid19cxr_conformal.py b/examples/cxr/covid19cxr_conformal.py index 4adb2dd8d..1d1bf8508 100644 --- a/examples/cxr/covid19cxr_conformal.py +++ b/examples/cxr/covid19cxr_conformal.py @@ -35,7 +35,7 @@ root = 
"/home/johnwu3/projects/PyHealth_Branch_Testing/datasets/COVID-19_Radiography_Dataset" base_dataset = COVID19CXRDataset(root) -sample_dataset = base_dataset.set_task(cache_dir="../../covid19cxr_cache") +sample_dataset = base_dataset.set_task() print(f"Total samples: {len(sample_dataset)}") print(f"Task mode: {sample_dataset.output_schema}") diff --git a/examples/cxr/covid19cxr_tutorial.ipynb b/examples/cxr/covid19cxr_tutorial.ipynb index 4eaf64aeb..2a5e4624e 100644 --- a/examples/cxr/covid19cxr_tutorial.ipynb +++ b/examples/cxr/covid19cxr_tutorial.ipynb @@ -146,7 +146,7 @@ "model_checkpoint = \"/home/johnwu3/projects/covid19cxr_vit_model.ckpt\"\n", "\n", "base_dataset = COVID19CXRDataset(root, cache_dir=base_cache, num_workers=4)\n", - "sample_dataset = base_dataset.set_task(cache_dir=task_cache, num_workers=4)\n", + "sample_dataset = base_dataset.set_task(num_workers=4)\n", "\n", "print(f\"Total samples: {len(sample_dataset)}\")\n", "print(f\"Task mode: {sample_dataset.output_schema}\")\n", diff --git a/examples/cxr/covid19cxr_tutorial.py b/examples/cxr/covid19cxr_tutorial.py index f0b11457b..39b4cae40 100644 --- a/examples/cxr/covid19cxr_tutorial.py +++ b/examples/cxr/covid19cxr_tutorial.py @@ -41,7 +41,7 @@ # Load dataset and create train/val/calibration/test splits dataset = COVID19CXRDataset(ROOT, cache_dir=CACHE, num_workers=8) - sample_dataset = dataset.set_task(cache_dir=TASK_CACHE, num_workers=8) + sample_dataset = dataset.set_task(num_workers=8) train_data, val_data, cal_data, test_data = split_by_sample_conformal( sample_dataset, ratios=[0.6, 0.1, 0.15, 0.15] ) diff --git a/examples/cxr/covid19cxr_tutorial_display.py b/examples/cxr/covid19cxr_tutorial_display.py index 093ff286f..cb81678de 100644 --- a/examples/cxr/covid19cxr_tutorial_display.py +++ b/examples/cxr/covid19cxr_tutorial_display.py @@ -41,7 +41,7 @@ # Load dataset and create train/val/calibration/test splits dataset = COVID19CXRDataset(ROOT, cache_dir=CACHE, num_workers=8) - sample_dataset = 
dataset.set_task(cache_dir=TASK_CACHE, num_workers=8) + sample_dataset = dataset.set_task(num_workers=8) train_data, val_data, cal_data, test_data = split_by_sample_conformal( sample_dataset, ratios=[0.6, 0.1, 0.15, 0.15] ) diff --git a/examples/mortality_prediction/mimic3_mortality_prediction_cached.ipynb b/examples/mortality_prediction/mimic3_mortality_prediction_cached.ipynb index 8f908bf56..102a87df1 100644 --- a/examples/mortality_prediction/mimic3_mortality_prediction_cached.ipynb +++ b/examples/mortality_prediction/mimic3_mortality_prediction_cached.ipynb @@ -247,7 +247,7 @@ "from pyhealth.tasks.mortality_prediction import MortalityPredictionMIMIC3\n", "from pyhealth.datasets import split_by_patient, get_dataloader\n", "mimic3_mortality_prediction = MortalityPredictionMIMIC3Heterogeneous()\n", - "samples = dataset.set_task(mimic3_mortality_prediction, num_workers=1, cache_dir=\"cache/\") # use default task" + "samples = dataset.set_task(mimic3_mortality_prediction, num_workers=1) # use default task" ] }, { diff --git a/examples/mortality_prediction/multimodal_mimic4_demo.py b/examples/mortality_prediction/multimodal_mimic4_demo.py index cec5d610f..c84a71c37 100644 --- a/examples/mortality_prediction/multimodal_mimic4_demo.py +++ b/examples/mortality_prediction/multimodal_mimic4_demo.py @@ -946,7 +946,6 @@ def main(): sample_dataset = base_dataset.set_task( task, num_workers=num_workers, - cache_dir=str(task_cache_dir), ) task_process_s = time.time() - task_start diff --git a/examples/mortality_prediction/multimodal_mimic4_minimal.py b/examples/mortality_prediction/multimodal_mimic4_minimal.py index 7fef80819..093e7764b 100644 --- a/examples/mortality_prediction/multimodal_mimic4_minimal.py +++ b/examples/mortality_prediction/multimodal_mimic4_minimal.py @@ -26,7 +26,7 @@ # Apply multimodal task task = MultimodalMortalityPredictionMIMIC4() - samples = dataset.set_task(task, cache_dir=f"{CACHE_DIR}/task", num_workers=8) + samples = dataset.set_task(task, 
num_workers=8) # Get and print sample sample = samples[0] diff --git a/examples/mortality_prediction/timeseries_mimic4.ipynb b/examples/mortality_prediction/timeseries_mimic4.ipynb index 3700a33b1..f976bab1b 100644 --- a/examples/mortality_prediction/timeseries_mimic4.ipynb +++ b/examples/mortality_prediction/timeseries_mimic4.ipynb @@ -187,7 +187,7 @@ " from pyhealth.tasks import InHospitalMortalityMIMIC4\n", "\n", " task = InHospitalMortalityMIMIC4()\n", - " samples = dataset.set_task(task, num_workers=4, cache_dir=\"../benchmark_cache/mimic4_ihm_w_pre2/\")\n", + " samples = dataset.set_task(task, num_workers=4)\n", "\n", " from pyhealth.datasets import split_by_sample\n", "\n", diff --git a/examples/mortality_prediction/timeseries_mimic4.py b/examples/mortality_prediction/timeseries_mimic4.py index 85df3c884..f3c675eec 100644 --- a/examples/mortality_prediction/timeseries_mimic4.py +++ b/examples/mortality_prediction/timeseries_mimic4.py @@ -10,7 +10,7 @@ from pyhealth.tasks import InHospitalMortalityMIMIC4 task = InHospitalMortalityMIMIC4() - samples = dataset.set_task(task, num_workers=2, cache_dir="../benchmark_cache/mimic4_ihm/") + samples = dataset.set_task(task, num_workers=2) from pyhealth.datasets import split_by_sample diff --git a/pyhealth/datasets/base_dataset.py b/pyhealth/datasets/base_dataset.py index aa260f303..8056e0c40 100644 --- a/pyhealth/datasets/base_dataset.py +++ b/pyhealth/datasets/base_dataset.py @@ -836,7 +836,6 @@ def set_task( self, task: Optional[BaseTask] = None, num_workers: Optional[int] = None, - cache_dir: str | Path | None = None, cache_format: str = "parquet", input_processors: Optional[Dict[str, FeatureProcessor]] = None, output_processors: Optional[Dict[str, FeatureProcessor]] = None, @@ -853,8 +852,6 @@ def set_task( Args: task (Optional[BaseTask]): The task to set. Uses default task if None. num_workers (int): Number of workers for multi-threading. Default is `self.num_workers`. 
-        cache_dir (Optional[str]): Directory to cache samples after task transformation,
-            but without applying processors. Default is {self.cache_dir}/tasks.
         cache_format (str): Deprecated. Only "parquet" is supported now.
         input_processors (Optional[Dict[str, FeatureProcessor]]):
             Pre-fitted input processors. If provided, these will be used
@@ -895,13 +892,8 @@ def set_task(
             default=str
         )
 
-        if cache_dir is None:
-            cache_dir = self.cache_dir / "tasks" / f"{task.task_name}_{uuid.uuid5(uuid.NAMESPACE_DNS, task_params)}"
-            cache_dir.mkdir(parents=True, exist_ok=True)
-        else:
-            # Ensure the explicitly provided cache_dir exists
-            cache_dir = Path(cache_dir)
-            cache_dir.mkdir(parents=True, exist_ok=True)
+        cache_dir = self.cache_dir / "tasks" / f"{task.task_name}_{uuid.uuid5(uuid.NAMESPACE_DNS, task_params)}"
+        cache_dir.mkdir(parents=True, exist_ok=True)
 
         proc_params = json.dumps(
             {
diff --git a/tests/core/test_caching.py b/tests/core/test_caching.py
index 0a6ac7dfe..269929a69 100644
--- a/tests/core/test_caching.py
+++ b/tests/core/test_caching.py
@@ -133,8 +133,7 @@ def test_set_task_signature(self):
                 "self",
                 "task",
                 "num_workers",
-                "cache_dir",
-                "cache_format",
+cd "cache_format",
                 "input_processors",
                 "output_processors",
             ]
@@ -143,7 +142,6 @@ def test_set_task_signature(self):
         # Check default values
         self.assertEqual(sig.parameters["task"].default, None)
         self.assertEqual(sig.parameters["num_workers"].default, None)
-        self.assertEqual(sig.parameters["cache_dir"].default, None)
         self.assertEqual(sig.parameters["cache_format"].default, "parquet")
         self.assertEqual(sig.parameters["input_processors"].default, None)
         self.assertEqual(sig.parameters["output_processors"].default, None)
@@ -151,7 +149,7 @@ def test_set_task_signature(self):
     def test_set_task_writes_cache_and_metadata(self):
         """Ensure set_task materializes cache files and schema metadata."""
         with self.dataset.set_task(
-            self.task, cache_dir=self.cache_dir, cache_format="parquet"
+            self.task, cache_format="parquet"
         ) as 
sample_dataset: self.assertIsInstance(sample_dataset, SampleDataset) self.assertEqual(sample_dataset.dataset_name, "TestDataset") @@ -159,8 +157,18 @@ def test_set_task_writes_cache_and_metadata(self): self.assertEqual(len(sample_dataset), 4) self.assertEqual(self.task.call_count, 2) - # Ensure intermediate cache files are created - self.assertTrue((self.cache_dir / "task_df.ld" / "index.json").exists()) + # Ensure intermediate cache files are created in default location + task_params = json.dumps( + { + **vars(self.task), + "input_schema": self.task.input_schema, + "output_schema": self.task.output_schema, + }, + sort_keys=True, + default=str + ) + task_cache_dir = self.dataset.cache_dir / "tasks" / f"{self.task.task_name}_{uuid.uuid5(uuid.NAMESPACE_DNS, task_params)}" + self.assertTrue((task_cache_dir / "task_df.ld" / "index.json").exists()) # Cache artifacts should be present for StreamingDataset assert sample_dataset.input_dir.path is not None @@ -185,7 +193,7 @@ def test_set_task_writes_cache_and_metadata(self): self.assertFalse((sample_dir / "index.json").exists()) self.assertFalse((sample_dir / "schema.pkl").exists()) # Ensure intermediate cache files are still present - self.assertTrue((self.cache_dir / "task_df.ld" / "index.json").exists()) + self.assertTrue((task_cache_dir / "task_df.ld" / "index.json").exists()) def test_default_cache_dir_is_used(self): @@ -208,14 +216,14 @@ def test_default_cache_dir_is_used(self): def test_reuses_existing_cache_without_regeneration(self): """Second call should reuse cached samples instead of recomputing.""" - sample_dataset = self.dataset.set_task(self.task, cache_dir=self.cache_dir) + sample_dataset = self.dataset.set_task(self.task) self.assertEqual(self.task.call_count, 2) with patch.object( self.task, "__call__", side_effect=AssertionError("Task should not rerun") ): cached_dataset = self.dataset.set_task( - self.task, cache_dir=self.cache_dir, cache_format="parquet" + self.task, cache_format="parquet" ) 
self.assertEqual(len(cached_dataset), 4) From 86115ccbfaf5c331ce09e0de0d50ef81fedc8017 Mon Sep 17 00:00:00 2001 From: EricSchrock Date: Mon, 16 Feb 2026 02:53:30 +0000 Subject: [PATCH 07/16] Add proc UUID to samples cache dir docs --- pyhealth/datasets/base_dataset.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pyhealth/datasets/base_dataset.py b/pyhealth/datasets/base_dataset.py index 8056e0c40..9e1c45e31 100644 --- a/pyhealth/datasets/base_dataset.py +++ b/pyhealth/datasets/base_dataset.py @@ -845,7 +845,7 @@ def set_task( {task_name}_{task_uuid}/ # Cached data for specific task based on task name, schema, and args task_df.ld/ # Intermediate task dataframe based on schema - samples.ld/ # Final processed samples after applying processors + samples_{proc_uuid}.ld/ # Final processed samples after applying processors schema.pkl # Saved SampleBuilder schema *.bin # Processed sample files From 2be1c8057809601a7b95940e18cb08e014833b57 Mon Sep 17 00:00:00 2001 From: EricSchrock Date: Mon, 16 Feb 2026 02:54:18 +0000 Subject: [PATCH 08/16] Remove the cache_format param from set_task --- pyhealth/datasets/base_dataset.py | 5 ----- 1 file changed, 5 deletions(-) diff --git a/pyhealth/datasets/base_dataset.py b/pyhealth/datasets/base_dataset.py index 9e1c45e31..9da8bedd9 100644 --- a/pyhealth/datasets/base_dataset.py +++ b/pyhealth/datasets/base_dataset.py @@ -836,7 +836,6 @@ def set_task( self, task: Optional[BaseTask] = None, num_workers: Optional[int] = None, - cache_format: str = "parquet", input_processors: Optional[Dict[str, FeatureProcessor]] = None, output_processors: Optional[Dict[str, FeatureProcessor]] = None, ) -> SampleDataset: @@ -852,7 +851,6 @@ def set_task( Args: task (Optional[BaseTask]): The task to set. Uses default task if None. num_workers (int): Number of workers for multi-threading. Default is `self.num_workers`. - cache_format (str): Deprecated. Only "parquet" is supported now. 
input_processors (Optional[Dict[str, FeatureProcessor]]): Pre-fitted input processors. If provided, these will be used instead of creating new ones from task's input_schema. Defaults to None. @@ -875,9 +873,6 @@ def set_task( if num_workers is None: num_workers = self.num_workers - if cache_format != "parquet": - logger.warning("Only 'parquet' cache_format is supported now. ") - logger.info( f"Setting task {task.task_name} for {self.dataset_name} base dataset..." ) From d0c1a152a0e7c35bbf5d1b556102b79f5bc64e9f Mon Sep 17 00:00:00 2001 From: EricSchrock Date: Mon, 16 Feb 2026 02:56:09 +0000 Subject: [PATCH 09/16] Remove task schema from the proc UUID (since it is now included in the task UUID) --- pyhealth/datasets/base_dataset.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/pyhealth/datasets/base_dataset.py b/pyhealth/datasets/base_dataset.py index 9da8bedd9..d11597a60 100644 --- a/pyhealth/datasets/base_dataset.py +++ b/pyhealth/datasets/base_dataset.py @@ -892,8 +892,6 @@ def set_task( proc_params = json.dumps( { - "input_schema": task.input_schema, - "output_schema": task.output_schema, "input_processors": ( { f"{k}_{v.__class__.__name__}": vars(v) From b06637662f7c027dbb6069dca50818543b8444dd Mon Sep 17 00:00:00 2001 From: EricSchrock Date: Mon, 16 Feb 2026 03:00:42 +0000 Subject: [PATCH 10/16] Add proc UUID to the _init_cache_dir docs --- pyhealth/datasets/base_dataset.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pyhealth/datasets/base_dataset.py b/pyhealth/datasets/base_dataset.py index d11597a60..dcd194bc4 100644 --- a/pyhealth/datasets/base_dataset.py +++ b/pyhealth/datasets/base_dataset.py @@ -356,7 +356,7 @@ def _init_cache_dir(self, cache_dir: str | Path | None) -> Path: tasks/ # Cached task-specific data {task_name}_{task_uuid}/ # Cached data for specific task based on task name, schema, and args task_df.ld/ # Intermediate task dataframe based on schema - samples.ld/ # Final processed samples after applying processors + 
samples_{proc_uuid}.ld/ # Final processed samples after applying processors Returns: Path: The resolved cache directory path. From 840d3af4390ea68088fdddece49e9111f838b5a7 Mon Sep 17 00:00:00 2001 From: EricSchrock Date: Mon, 16 Feb 2026 03:14:28 +0000 Subject: [PATCH 11/16] Remove cache_format param from unit tests --- tests/core/test_caching.py | 10 ++-------- 1 file changed, 2 insertions(+), 8 deletions(-) diff --git a/tests/core/test_caching.py b/tests/core/test_caching.py index 269929a69..299c2f430 100644 --- a/tests/core/test_caching.py +++ b/tests/core/test_caching.py @@ -133,7 +133,6 @@ def test_set_task_signature(self): "self", "task", "num_workers", -cd "cache_format", "input_processors", "output_processors", ] @@ -142,15 +141,12 @@ def test_set_task_signature(self): # Check default values self.assertEqual(sig.parameters["task"].default, None) self.assertEqual(sig.parameters["num_workers"].default, None) - self.assertEqual(sig.parameters["cache_format"].default, "parquet") self.assertEqual(sig.parameters["input_processors"].default, None) self.assertEqual(sig.parameters["output_processors"].default, None) def test_set_task_writes_cache_and_metadata(self): """Ensure set_task materializes cache files and schema metadata.""" - with self.dataset.set_task( - self.task, cache_format="parquet" - ) as sample_dataset: + with self.dataset.set_task(self.task) as sample_dataset: self.assertIsInstance(sample_dataset, SampleDataset) self.assertEqual(sample_dataset.dataset_name, "TestDataset") self.assertEqual(sample_dataset.task_name, self.task.task_name) @@ -222,9 +218,7 @@ def test_reuses_existing_cache_without_regeneration(self): with patch.object( self.task, "__call__", side_effect=AssertionError("Task should not rerun") ): - cached_dataset = self.dataset.set_task( - self.task, cache_format="parquet" - ) + cached_dataset = self.dataset.set_task(self.task) self.assertEqual(len(cached_dataset), 4) self.assertEqual(self.task.call_count, 2) From 
375929e3738fa3463b351a8aada7820751bc26b0 Mon Sep 17 00:00:00 2001 From: EricSchrock Date: Mon, 16 Feb 2026 03:31:39 +0000 Subject: [PATCH 12/16] Remove set_task(cache_dir) from more examples --- examples/clinical_tasks/dka_mimic4.py | 55 +++--- .../clinical_tasks/dka_mimic4_stageattn.py | 3 - .../clinical_tasks/dka_mimic4_stagenet.py | 3 - .../clinical_tasks/dka_mimic4_transformer.py | 3 - .../clinical_tasks/t1d_mimic4_stageattn.py | 3 - .../clinical_tasks/t1d_mimic4_stagenet.py | 3 - examples/clinical_tasks/t1dka_mimic4.py | 55 +++--- examples/concare_mimic4_example.ipynb | 15 +- .../tuev_conventional_conformal.py | 4 +- .../tuev_covariate_shift_conformal.py | 4 +- .../conformal_eeg/tuev_kmeans_conformal.py | 2 +- examples/conformal_eeg/tuev_ncp_conformal.py | 2 +- examples/cxr/covid19cxr_tutorial.ipynb | 5 +- examples/cxr/covid19cxr_tutorial.py | 1 - examples/cxr/covid19cxr_tutorial_display.py | 1 - .../drug_recommendation_mimic4_retain.py | 1 - .../interpretability/gim_stagenet_mimic4.py | 1 - .../gim_transformer_mimic4.py | 1 - ...integrated_gradients_benchmark_stagenet.py | 2 - ...ted_gradients_mortality_mimic4_stagenet.py | 1 - .../interpretability_metrics.py | 1 - .../shap_stagenet_mimic4.ipynb | 169 +++++++++--------- .../interpretability/shap_stagenet_mimic4.py | 7 +- .../length_of_stay_mimic4_ehrmamba.py | 7 +- .../length_of_stay_mimic4_stageattn.py | 3 - .../length_of_stay_mimic4_stagenet.py | 3 - .../length_of_stay_mimic4_transformer.py | 3 - examples/lime_stagenet_mimic4.py | 19 +- .../ehrmamba_mimic4_full.py | 7 +- .../mortality_mimic4_stageattn.py | 3 - .../mortality_mimic4_stagenet_v2.py | 3 - .../mortality_mimic4_transformer.py | 3 - examples/transformer_mimic4.ipynb | 3 +- .../tutorial_stagenet_comprehensive.ipynb | 7 +- 34 files changed, 168 insertions(+), 235 deletions(-) diff --git a/examples/clinical_tasks/dka_mimic4.py b/examples/clinical_tasks/dka_mimic4.py index e83f0dbbb..d33a7c0ac 100644 --- a/examples/clinical_tasks/dka_mimic4.py +++ 
b/examples/clinical_tasks/dka_mimic4.py @@ -31,18 +31,17 @@ def main(): """Main function to run DKA prediction pipeline on general population.""" - + # Configuration MIMIC4_ROOT = "/srv/local/data/physionet.org/files/mimiciv/2.2/" DATASET_CACHE_DIR = "/shared/rsaas/pyhealth/cache/mimic4_dataset" - TASK_CACHE_DIR = "/shared/rsaas/pyhealth/cache/mimic4_dka_general_stagenet" PROCESSOR_DIR = "/shared/rsaas/pyhealth/processors/stagenet_dka_general_mimic4" DEVICE = "cuda:5" if torch.cuda.is_available() else "cpu" - + print("=" * 60) print("DKA PREDICTION (GENERAL POPULATION) WITH STAGENET ON MIMIC-IV") print("=" * 60) - + # STEP 1: Load MIMIC-IV base dataset print("\n=== Step 1: Loading MIMIC-IV Dataset ===") base_dataset = MIMIC4Dataset( @@ -56,30 +55,29 @@ def main(): cache_dir=DATASET_CACHE_DIR, # dev=True, # Uncomment for faster development iteration ) - + print("Dataset initialized, proceeding to task processing...") - + # STEP 2: Apply DKA prediction task (general population) print("\n=== Step 2: Applying DKA Prediction Task (General Population) ===") - + # Create task with padding for unseen sequences # No T1DM filtering - includes ALL patients dka_task = DKAPredictionMIMIC4(padding=10) - + print(f"Task: {dka_task.task_name}") print(f"Input schema: {list(dka_task.input_schema.keys())}") print(f"Output schema: {list(dka_task.output_schema.keys())}") print("Note: This includes ALL patients (not just diabetics)") - + # Check for pre-fitted processors if os.path.exists(os.path.join(PROCESSOR_DIR, "input_processors.pkl")): print("\nLoading pre-fitted processors...") input_processors, output_processors = load_processors(PROCESSOR_DIR) - + sample_dataset = base_dataset.set_task( dka_task, num_workers=4, - cache_dir=TASK_CACHE_DIR, input_processors=input_processors, output_processors=output_processors, ) @@ -88,25 +86,24 @@ def main(): sample_dataset = base_dataset.set_task( dka_task, num_workers=4, - cache_dir=TASK_CACHE_DIR, ) - + # Save processors for future runs 
print("Saving processors...") os.makedirs(PROCESSOR_DIR, exist_ok=True) save_processors(sample_dataset, PROCESSOR_DIR) - + print(f"\nTotal samples: {len(sample_dataset)}") - + # Count label distribution label_counts = {0: 0, 1: 0} for sample in sample_dataset: label_counts[int(sample["label"].item())] += 1 - + print(f"Label distribution:") print(f" No DKA (0): {label_counts[0]} ({100*label_counts[0]/len(sample_dataset):.1f}%)") print(f" Has DKA (1): {label_counts[1]} ({100*label_counts[1]/len(sample_dataset):.1f}%)") - + # Inspect a sample sample = sample_dataset[0] print("\nSample structure:") @@ -114,22 +111,22 @@ def main(): print(f" ICD codes (diagnoses + procedures): {sample['icd_codes'][1].shape} (visits x codes)") print(f" Labs: {sample['labs'][0].shape} (timesteps x features)") print(f" Label: {sample['label']}") - + # STEP 3: Split dataset print("\n=== Step 3: Splitting Dataset ===") train_dataset, val_dataset, test_dataset = split_by_patient( sample_dataset, [0.8, 0.1, 0.1] ) - + print(f"Train: {len(train_dataset)} samples") print(f"Validation: {len(val_dataset)} samples") print(f"Test: {len(test_dataset)} samples") - + # Create dataloaders train_loader = get_dataloader(train_dataset, batch_size=256, shuffle=True) val_loader = get_dataloader(val_dataset, batch_size=256, shuffle=False) test_loader = get_dataloader(test_dataset, batch_size=256, shuffle=False) - + # STEP 4: Initialize StageNet model print("\n=== Step 4: Initializing StageNet Model ===") model = StageNet( @@ -139,10 +136,10 @@ def main(): levels=3, dropout=0.3, ) - + num_params = sum(p.numel() for p in model.parameters()) print(f"Model parameters: {num_params:,}") - + # STEP 5: Train the model print("\n=== Step 5: Training Model ===") trainer = Trainer( @@ -150,7 +147,7 @@ def main(): device=DEVICE, metrics=["pr_auc", "roc_auc", "accuracy", "f1"], ) - + trainer.train( train_dataloader=train_loader, val_dataloader=val_loader, @@ -158,27 +155,27 @@ def main(): monitor="roc_auc", 
optimizer_params={"lr": 1e-5}, ) - + # STEP 6: Evaluate on test set print("\n=== Step 6: Evaluation ===") results = trainer.evaluate(test_loader) print("\nTest Results:") for metric, value in results.items(): print(f" {metric}: {value:.4f}") - + # STEP 7: Inspect model predictions print("\n=== Step 7: Sample Predictions ===") sample_batch = next(iter(test_loader)) with torch.no_grad(): output = model(**sample_batch) - + print(f"Predicted probabilities: {output['y_prob'][:5]}") print(f"True labels: {output['y_true'][:5]}") - + print("\n" + "=" * 60) print("DKA PREDICTION (GENERAL POPULATION) TRAINING COMPLETED!") print("=" * 60) - + return results diff --git a/examples/clinical_tasks/dka_mimic4_stageattn.py b/examples/clinical_tasks/dka_mimic4_stageattn.py index 852b7afce..2f2f5a5a8 100644 --- a/examples/clinical_tasks/dka_mimic4_stageattn.py +++ b/examples/clinical_tasks/dka_mimic4_stageattn.py @@ -43,7 +43,6 @@ # - This ensures consistent encoding and saves computation time # - Processors include vocabulary mappings and sequence length statistics processor_dir = "/home/yongdaf2/dka_sa/processors" - cache_dir = "/home/yongdaf2/dka_sa/cache" if os.path.exists(os.path.join(processor_dir, "input_processors.pkl")): print("\n=== Loading Pre-fitted Processors ===") @@ -52,7 +51,6 @@ sample_dataset = base_dataset.set_task( DKAPredictionMIMIC4(padding=20), num_workers=16, - cache_dir=cache_dir, input_processors=input_processors, output_processors=output_processors, ) @@ -61,7 +59,6 @@ sample_dataset = base_dataset.set_task( DKAPredictionMIMIC4(padding=20), num_workers=16, - cache_dir=cache_dir, ) # Save processors for future runs diff --git a/examples/clinical_tasks/dka_mimic4_stagenet.py b/examples/clinical_tasks/dka_mimic4_stagenet.py index 7ba68b353..40a30d03f 100644 --- a/examples/clinical_tasks/dka_mimic4_stagenet.py +++ b/examples/clinical_tasks/dka_mimic4_stagenet.py @@ -43,7 +43,6 @@ # - This ensures consistent encoding and saves computation time # - Processors 
include vocabulary mappings and sequence length statistics processor_dir = "/home/yongdaf2/dka_sn/processors" - cache_dir = "/home/yongdaf2/dka_sn/cache" if os.path.exists(os.path.join(processor_dir, "input_processors.pkl")): print("\n=== Loading Pre-fitted Processors ===") @@ -52,7 +51,6 @@ sample_dataset = base_dataset.set_task( DKAPredictionMIMIC4(padding=20), num_workers=16, - cache_dir=cache_dir, input_processors=input_processors, output_processors=output_processors, ) @@ -61,7 +59,6 @@ sample_dataset = base_dataset.set_task( DKAPredictionMIMIC4(padding=20), num_workers=16, - cache_dir=cache_dir, ) # Save processors for future runs diff --git a/examples/clinical_tasks/dka_mimic4_transformer.py b/examples/clinical_tasks/dka_mimic4_transformer.py index 65056ef3c..c17af6376 100644 --- a/examples/clinical_tasks/dka_mimic4_transformer.py +++ b/examples/clinical_tasks/dka_mimic4_transformer.py @@ -43,7 +43,6 @@ # - This ensures consistent encoding and saves computation time # - Processors include vocabulary mappings and sequence length statistics processor_dir = "/home/yongdaf2/dka_tf/processors" - cache_dir = "/home/yongdaf2/dka_tf/cache" if os.path.exists(os.path.join(processor_dir, "input_processors.pkl")): print("\n=== Loading Pre-fitted Processors ===") @@ -52,7 +51,6 @@ sample_dataset = base_dataset.set_task( DKAPredictionMIMIC4(padding=20), num_workers=16, - cache_dir=cache_dir, input_processors=input_processors, output_processors=output_processors, ) @@ -61,7 +59,6 @@ sample_dataset = base_dataset.set_task( DKAPredictionMIMIC4(padding=20), num_workers=16, - cache_dir=cache_dir, ) # Save processors for future runs diff --git a/examples/clinical_tasks/t1d_mimic4_stageattn.py b/examples/clinical_tasks/t1d_mimic4_stageattn.py index 21682d95f..d53b17a85 100644 --- a/examples/clinical_tasks/t1d_mimic4_stageattn.py +++ b/examples/clinical_tasks/t1d_mimic4_stageattn.py @@ -43,7 +43,6 @@ # - This ensures consistent encoding and saves computation time # - Processors 
include vocabulary mappings and sequence length statistics processor_dir = "/home/yongdaf2/t1d_sa/processors" - cache_dir = "/home/yongdaf2/t1d_sa/cache" if os.path.exists(os.path.join(processor_dir, "input_processors.pkl")): print("\n=== Loading Pre-fitted Processors ===") @@ -52,7 +51,6 @@ sample_dataset = base_dataset.set_task( T1DDKAPredictionMIMIC4(padding=20), num_workers=16, - cache_dir=cache_dir, input_processors=input_processors, output_processors=output_processors, ) @@ -61,7 +59,6 @@ sample_dataset = base_dataset.set_task( T1DDKAPredictionMIMIC4(padding=20), num_workers=16, - cache_dir=cache_dir, ) # Save processors for future runs diff --git a/examples/clinical_tasks/t1d_mimic4_stagenet.py b/examples/clinical_tasks/t1d_mimic4_stagenet.py index 6e3f3724a..3364b59b7 100644 --- a/examples/clinical_tasks/t1d_mimic4_stagenet.py +++ b/examples/clinical_tasks/t1d_mimic4_stagenet.py @@ -43,7 +43,6 @@ # - This ensures consistent encoding and saves computation time # - Processors include vocabulary mappings and sequence length statistics processor_dir = "/home/yongdaf2/t1d_sn/processors" - cache_dir = "/home/yongdaf2/t1d_sn/cache" if os.path.exists(os.path.join(processor_dir, "input_processors.pkl")): print("\n=== Loading Pre-fitted Processors ===") @@ -52,7 +51,6 @@ sample_dataset = base_dataset.set_task( T1DDKAPredictionMIMIC4(padding=20), num_workers=16, - cache_dir=cache_dir, input_processors=input_processors, output_processors=output_processors, ) @@ -61,7 +59,6 @@ sample_dataset = base_dataset.set_task( T1DDKAPredictionMIMIC4(padding=20), num_workers=16, - cache_dir=cache_dir, ) # Save processors for future runs diff --git a/examples/clinical_tasks/t1dka_mimic4.py b/examples/clinical_tasks/t1dka_mimic4.py index 9d3705650..8bc56fc27 100644 --- a/examples/clinical_tasks/t1dka_mimic4.py +++ b/examples/clinical_tasks/t1dka_mimic4.py @@ -29,18 +29,17 @@ def main(): """Main function to run T1D DKA prediction pipeline.""" - + # Configuration MIMIC4_ROOT = 
"/srv/local/data/physionet.org/files/mimiciv/2.2/" DATASET_CACHE_DIR = "/shared/rsaas/pyhealth/cache/mimic4_dataset" - TASK_CACHE_DIR = "/shared/rsaas/pyhealth/cache/mimic4_t1d_dka_stagenet_v2" PROCESSOR_DIR = "/shared/rsaas/pyhealth/processors/stagenet_t1d_dka_mimic4_v2" DEVICE = "cuda:5" if torch.cuda.is_available() else "cpu" - + print("=" * 60) print("T1D DKA PREDICTION WITH STAGENET ON MIMIC-IV") print("=" * 60) - + # STEP 1: Load MIMIC-IV base dataset print("\n=== Step 1: Loading MIMIC-IV Dataset ===") base_dataset = MIMIC4Dataset( @@ -54,29 +53,28 @@ def main(): cache_dir=DATASET_CACHE_DIR, # dev=True, # Uncomment for faster development iteration ) - + print("Dataset initialized, proceeding to task processing...") - + # STEP 2: Apply T1D DKA prediction task print("\n=== Step 2: Applying T1D DKA Prediction Task ===") - + # Create task with 90-day DKA window and padding for unseen sequences dka_task = T1DDKAPredictionMIMIC4(dka_window_days=90, padding=20) - + print(f"Task: {dka_task.task_name}") print(f"DKA window: {dka_task.dka_window_days} days") print(f"Input schema: {list(dka_task.input_schema.keys())}") print(f"Output schema: {list(dka_task.output_schema.keys())}") - + # Check for pre-fitted processors if os.path.exists(os.path.join(PROCESSOR_DIR, "input_processors.pkl")): print("\nLoading pre-fitted processors...") input_processors, output_processors = load_processors(PROCESSOR_DIR) - + sample_dataset = base_dataset.set_task( dka_task, num_workers=4, - cache_dir=TASK_CACHE_DIR, input_processors=input_processors, output_processors=output_processors, ) @@ -85,25 +83,24 @@ def main(): sample_dataset = base_dataset.set_task( dka_task, num_workers=4, - cache_dir=TASK_CACHE_DIR, ) - + # Save processors for future runs print("Saving processors...") os.makedirs(PROCESSOR_DIR, exist_ok=True) save_processors(sample_dataset, PROCESSOR_DIR) - + print(f"\nTotal samples: {len(sample_dataset)}") - + # Count label distribution label_counts = {0: 0, 1: 0} for sample in 
sample_dataset: label_counts[int(sample["label"].item())] += 1 - + print(f"Label distribution:") print(f" No DKA (0): {label_counts[0]} ({100*label_counts[0]/len(sample_dataset):.1f}%)") print(f" Has DKA (1): {label_counts[1]} ({100*label_counts[1]/len(sample_dataset):.1f}%)") - + # Inspect a sample sample = sample_dataset[0] print("\nSample structure:") @@ -111,22 +108,22 @@ def main(): print(f" ICD codes (diagnoses + procedures): {sample['icd_codes'][1].shape} (visits x codes)") print(f" Labs: {sample['labs'][0].shape} (timesteps x features)") print(f" Label: {sample['label']}") - + # STEP 3: Split dataset print("\n=== Step 3: Splitting Dataset ===") train_dataset, val_dataset, test_dataset = split_by_patient( sample_dataset, [0.8, 0.1, 0.1] ) - + print(f"Train: {len(train_dataset)} samples") print(f"Validation: {len(val_dataset)} samples") print(f"Test: {len(test_dataset)} samples") - + # Create dataloaders train_loader = get_dataloader(train_dataset, batch_size=256, shuffle=True) val_loader = get_dataloader(val_dataset, batch_size=256, shuffle=False) test_loader = get_dataloader(test_dataset, batch_size=256, shuffle=False) - + # STEP 4: Initialize StageNet model print("\n=== Step 4: Initializing StageNet Model ===") model = StageNet( @@ -136,10 +133,10 @@ def main(): levels=3, dropout=0.3, ) - + num_params = sum(p.numel() for p in model.parameters()) print(f"Model parameters: {num_params:,}") - + # STEP 5: Train the model print("\n=== Step 5: Training Model ===") trainer = Trainer( @@ -147,7 +144,7 @@ def main(): device=DEVICE, metrics=["pr_auc", "roc_auc", "accuracy", "f1"], ) - + trainer.train( train_dataloader=train_loader, val_dataloader=val_loader, @@ -155,27 +152,27 @@ def main(): monitor="roc_auc", optimizer_params={"lr": 1e-5}, ) - + # STEP 6: Evaluate on test set print("\n=== Step 6: Evaluation ===") results = trainer.evaluate(test_loader) print("\nTest Results:") for metric, value in results.items(): print(f" {metric}: {value:.4f}") - + # STEP 7: 
Inspect model predictions print("\n=== Step 7: Sample Predictions ===") sample_batch = next(iter(test_loader)) with torch.no_grad(): output = model(**sample_batch) - + print(f"Predicted probabilities: {output['y_prob'][:5]}") print(f"True labels: {output['y_true'][:5]}") - + print("\n" + "=" * 60) print("T1D DKA PREDICTION TRAINING COMPLETED!") print("=" * 60) - + return results diff --git a/examples/concare_mimic4_example.ipynb b/examples/concare_mimic4_example.ipynb index 23e8c4976..5d4afef0a 100644 --- a/examples/concare_mimic4_example.ipynb +++ b/examples/concare_mimic4_example.ipynb @@ -14,12 +14,12 @@ "This example demonstrates how to train the ConCare model for in-hospital mortality\n", "prediction using the MIMIC-IV dataset.\n", "\n", - "ConCare (Concare: Personalized clinical feature embedding via capturing the \n", - "healthcare context) is a model that uses channel-wise GRUs and multi-head \n", + "ConCare (Concare: Personalized clinical feature embedding via capturing the\n", + "healthcare context) is a model that uses channel-wise GRUs and multi-head\n", "self-attention to capture feature correlations and temporal patterns in EHR data.\n", "\n", "Reference:\n", - " Liantao Ma et al. Concare: Personalized clinical feature embedding via \n", + " Liantao Ma et al. Concare: Personalized clinical feature embedding via\n", " capturing the healthcare context. 
AAAI 2020.\n", "\"\"\"" ] @@ -72,9 +72,8 @@ "\n", "# Apply task to dataset and create samples\n", "samples = dataset.set_task(\n", - " task, \n", - " num_workers=10, \n", - " cache_dir=\"./cache_concare_mortality_m4\"\n", + " task,\n", + " num_workers=10,\n", ")" ] }, @@ -117,7 +116,7 @@ "\n", "# Split dataset into train, validation, and test sets\n", "train_dataset, val_dataset, test_dataset = split_by_sample(\n", - " dataset=samples, \n", + " dataset=samples,\n", " ratios=[0.7, 0.1, 0.2]\n", ")\n", "\n", @@ -197,7 +196,7 @@ "\n", "# Initialize trainer with ROC-AUC metric\n", "trainer = Trainer(\n", - " model=model, \n", + " model=model,\n", " metrics=[\"roc_auc\", \"pr_auc\", \"accuracy\"]\n", ")\n", "\n", diff --git a/examples/conformal_eeg/tuev_conventional_conformal.py b/examples/conformal_eeg/tuev_conventional_conformal.py index 3f067b2e0..b62dc73d1 100644 --- a/examples/conformal_eeg/tuev_conventional_conformal.py +++ b/examples/conformal_eeg/tuev_conventional_conformal.py @@ -37,7 +37,7 @@ def parse_args() -> argparse.Namespace: default="downloads/tuev/v2.0.1/edf", help="Path to TUEV edf/ folder.", ) - parser.add_argument("--subset", type=str, default="both", choices=["train", "eval", "both"]) + parser.add_argument("--subset", type=str, default="both", choices=["train", "eval", "both"]) parser.add_argument("--seed", type=int, default=42) parser.add_argument("--batch-size", type=int, default=32) parser.add_argument("--epochs", type=int, default=2) @@ -84,7 +84,7 @@ def main() -> None: print("STEP 1: Load TUEV + build task dataset") print("=" * 80) dataset = TUEVDataset(root=str(root), subset=args.subset) - sample_dataset = dataset.set_task(EEGEventsTUEV(), cache_dir="examples/conformal_eeg/cache") + sample_dataset = dataset.set_task(EEGEventsTUEV()) print(f"Task samples: {len(sample_dataset)}") print(f"Input schema: {sample_dataset.input_schema}") diff --git a/examples/conformal_eeg/tuev_covariate_shift_conformal.py 
b/examples/conformal_eeg/tuev_covariate_shift_conformal.py index 1afcdf08a..41a356a79 100644 --- a/examples/conformal_eeg/tuev_covariate_shift_conformal.py +++ b/examples/conformal_eeg/tuev_covariate_shift_conformal.py @@ -42,7 +42,7 @@ def parse_args() -> argparse.Namespace: default="downloads/tuev/v2.0.1/edf", help="Path to TUEV edf/ folder.", ) - parser.add_argument("--subset", type=str, default="both", choices=["train", "eval", "both"]) + parser.add_argument("--subset", type=str, default="both", choices=["train", "eval", "both"]) parser.add_argument("--seed", type=int, default=42) parser.add_argument("--batch-size", type=int, default=32) parser.add_argument("--epochs", type=int, default=3) @@ -89,7 +89,7 @@ def main() -> None: print("STEP 1: Load TUEV + build task dataset") print("=" * 80) dataset = TUEVDataset(root=str(root), subset=args.subset) - sample_dataset = dataset.set_task(EEGEventsTUEV(), cache_dir="examples/conformal_eeg/cache") + sample_dataset = dataset.set_task(EEGEventsTUEV()) print(f"Task samples: {len(sample_dataset)}") print(f"Input schema: {sample_dataset.input_schema}") diff --git a/examples/conformal_eeg/tuev_kmeans_conformal.py b/examples/conformal_eeg/tuev_kmeans_conformal.py index b5a043dad..906883d72 100644 --- a/examples/conformal_eeg/tuev_kmeans_conformal.py +++ b/examples/conformal_eeg/tuev_kmeans_conformal.py @@ -149,7 +149,7 @@ def _run(args: argparse.Namespace) -> None: print("STEP 1: Load TUEV + build task dataset") print("=" * 80) dataset = TUEVDataset(root=str(root), subset=args.subset, dev=args.quick_test) - sample_dataset = dataset.set_task(EEGEventsTUEV(), cache_dir="examples/conformal_eeg/cache") + sample_dataset = dataset.set_task(EEGEventsTUEV()) if args.quick_test and len(sample_dataset) > quick_test_max_samples: sample_dataset = sample_dataset.subset(range(quick_test_max_samples)) print(f"Capped to {quick_test_max_samples} samples for quick-test.") diff --git a/examples/conformal_eeg/tuev_ncp_conformal.py 
b/examples/conformal_eeg/tuev_ncp_conformal.py index 9b51a7756..c5e207a6a 100644 --- a/examples/conformal_eeg/tuev_ncp_conformal.py +++ b/examples/conformal_eeg/tuev_ncp_conformal.py @@ -290,7 +290,7 @@ def _run(args: argparse.Namespace) -> None: print("STEP 1: Load TUEV + build task dataset") print("=" * 80) dataset = TUEVDataset(root=str(root), subset=args.subset, dev=args.quick_test) - sample_dataset = dataset.set_task(EEGEventsTUEV(), cache_dir="examples/conformal_eeg/cache") + sample_dataset = dataset.set_task(EEGEventsTUEV()) if args.quick_test and len(sample_dataset) > quick_test_max_samples: sample_dataset = sample_dataset.subset(range(quick_test_max_samples)) print(f"Capped to {quick_test_max_samples} samples for quick-test.") diff --git a/examples/cxr/covid19cxr_tutorial.ipynb b/examples/cxr/covid19cxr_tutorial.ipynb index 2a5e4624e..2a04844c5 100644 --- a/examples/cxr/covid19cxr_tutorial.ipynb +++ b/examples/cxr/covid19cxr_tutorial.ipynb @@ -142,7 +142,6 @@ "source": [ "root = \"/home/johnwu3/projects/PyHealth_Branch_Testing/datasets/COVID-19_Radiography_Dataset\"\n", "base_cache = \"/home/johnwu3/projects/covid19cxr_base_cache\"\n", - "task_cache = \"/home/johnwu3/projects/covid19cxr_task_cache\"\n", "model_checkpoint = \"/home/johnwu3/projects/covid19cxr_vit_model.ckpt\"\n", "\n", "base_dataset = COVID19CXRDataset(root, cache_dir=base_cache, num_workers=4)\n", @@ -1278,7 +1277,7 @@ }, { "cell_type": "code", - "execution_count": 10, + "execution_count": null, "id": "adfd29de", "metadata": {}, "outputs": [ @@ -1344,7 +1343,7 @@ " **batch\n", " )\n", " attr_map = result[\"image\"] # Keyed by task schema's feature key\n", - " \n", + "\n", " img_display, vit_attr_display, attention_overlay = visualize_image_attr(\n", " image=image[0],\n", " attribution=attr_map[0, 0],\n", diff --git a/examples/cxr/covid19cxr_tutorial.py b/examples/cxr/covid19cxr_tutorial.py index 39b4cae40..0f24f4b58 100644 --- a/examples/cxr/covid19cxr_tutorial.py +++ 
b/examples/cxr/covid19cxr_tutorial.py @@ -26,7 +26,6 @@ DATA_ROOT = "/home/johnwu3/projects/PyHealth_Branch_Testing/datasets" ROOT = f"{DATA_ROOT}/COVID-19_Radiography_Dataset" CACHE = "/home/johnwu3/projects/covid19cxr_base_cache" -TASK_CACHE = "/home/johnwu3/projects/covid19cxr_task_cache" CKPT = "/home/johnwu3/projects/covid19cxr_vit_model.ckpt" SEED = 42 diff --git a/examples/cxr/covid19cxr_tutorial_display.py b/examples/cxr/covid19cxr_tutorial_display.py index cb81678de..3f6a33b82 100644 --- a/examples/cxr/covid19cxr_tutorial_display.py +++ b/examples/cxr/covid19cxr_tutorial_display.py @@ -26,7 +26,6 @@ DATA_ROOT = "/home/johnwu3/projects/PyHealth_Branch_Testing/datasets" ROOT = f"{DATA_ROOT}/COVID-19_Radiography_Dataset" CACHE = "/home/johnwu3/projects/covid19cxr_base_cache" -TASK_CACHE = "/home/johnwu3/projects/covid19cxr_task_cache" CKPT = "/home/johnwu3/projects/covid19cxr_vit_model.ckpt" SEED = 42 diff --git a/examples/drug_recommendation/drug_recommendation_mimic4_retain.py b/examples/drug_recommendation/drug_recommendation_mimic4_retain.py index 688178a06..496edc7e3 100644 --- a/examples/drug_recommendation/drug_recommendation_mimic4_retain.py +++ b/examples/drug_recommendation/drug_recommendation_mimic4_retain.py @@ -35,7 +35,6 @@ sample_dataset = base_dataset.set_task( DrugRecommendationMIMIC4(), num_workers=4, - cache_dir="../../mimic4_drug_rec_cache", ) print(f"Total samples: {len(sample_dataset)}") diff --git a/examples/interpretability/gim_stagenet_mimic4.py b/examples/interpretability/gim_stagenet_mimic4.py index 8d45ad9a7..b38eb7a87 100644 --- a/examples/interpretability/gim_stagenet_mimic4.py +++ b/examples/interpretability/gim_stagenet_mimic4.py @@ -31,7 +31,6 @@ sample_dataset = dataset.set_task( MortalityPredictionStageNetMIMIC4(), - cache_dir="~/.cache/pyhealth/mimic4_stagenet_mortality", input_processors=input_processors, output_processors=output_processors, ) diff --git a/examples/interpretability/gim_transformer_mimic4.py 
b/examples/interpretability/gim_transformer_mimic4.py index 374d00d64..884be99a5 100644 --- a/examples/interpretability/gim_transformer_mimic4.py +++ b/examples/interpretability/gim_transformer_mimic4.py @@ -56,7 +56,6 @@ def maybe_load_processors(resource_dir: str, task): sample_dataset = dataset.set_task( task, - cache_dir="~/.cache/pyhealth/mimic4_transformer_mortality", input_processors=input_processors, output_processors=output_processors, ) diff --git a/examples/interpretability/integrated_gradients_benchmark_stagenet.py b/examples/interpretability/integrated_gradients_benchmark_stagenet.py index 82dbca218..c1cdbf0fd 100644 --- a/examples/interpretability/integrated_gradients_benchmark_stagenet.py +++ b/examples/interpretability/integrated_gradients_benchmark_stagenet.py @@ -37,7 +37,6 @@ "20260131-184735/best.ckpt" ) PROCESSOR_DIR = "../output/processors/stagenet_mortality_mimic4" -CACHE_DIR = "../../mimic4_stagenet_cache" def main(): @@ -85,7 +84,6 @@ def main(): sample_dataset = base_dataset.set_task( MortalityPredictionStageNetMIMIC4(), num_workers=8, - cache_dir=CACHE_DIR, input_processors=input_processors, output_processors=output_processors, ) diff --git a/examples/interpretability/integrated_gradients_mortality_mimic4_stagenet.py b/examples/interpretability/integrated_gradients_mortality_mimic4_stagenet.py index b55a8524d..32bab62d5 100644 --- a/examples/interpretability/integrated_gradients_mortality_mimic4_stagenet.py +++ b/examples/interpretability/integrated_gradients_mortality_mimic4_stagenet.py @@ -319,7 +319,6 @@ def main(): sample_dataset = base_dataset.set_task( MortalityPredictionStageNetMIMIC4(), num_workers=8, - cache_dir="../../mimic4_stagenet_cache", input_processors=input_processors, output_processors=output_processors, ) diff --git a/examples/interpretability/interpretability_metrics.py b/examples/interpretability/interpretability_metrics.py index 7adfb55a6..723f63da1 100644 --- a/examples/interpretability/interpretability_metrics.py 
+++ b/examples/interpretability/interpretability_metrics.py @@ -44,7 +44,6 @@ def main(): sample_dataset = base_dataset.set_task( MortalityPredictionStageNetMIMIC4(), num_workers=4, - cache_dir="../../mimic4_stagenet_cache", ) print(f"✓ Loaded {len(sample_dataset)} samples") diff --git a/examples/interpretability/shap_stagenet_mimic4.ipynb b/examples/interpretability/shap_stagenet_mimic4.ipynb index a871d7326..f2f4e2efd 100644 --- a/examples/interpretability/shap_stagenet_mimic4.ipynb +++ b/examples/interpretability/shap_stagenet_mimic4.ipynb @@ -13,8 +13,8 @@ }, "outputs": [ { - "output_type": "stream", "name": "stdout", + "output_type": "stream", "text": [ "PyTorch version: 2.9.0+cu126\n", "CUDA available: True\n", @@ -64,8 +64,8 @@ }, "outputs": [ { - "output_type": "stream", "name": "stdout", + "output_type": "stream", "text": [ "Collecting git+https://github.com/naveenkcb/PyHealth.git\n", " Cloning https://github.com/naveenkcb/PyHealth.git to /tmp/pip-req-build-u5cek8co\n", @@ -273,24 +273,24 @@ ] }, { - "output_type": "display_data", "data": { "application/vnd.colab-display-data+json": { + "id": "a21cf9550c9c4d1a916e50ccaf894bf2", "pip_warning": { "packages": [ "numpy", "torch", "torchgen" ] - }, - "id": "a21cf9550c9c4d1a916e50ccaf894bf2" + } } }, - "metadata": {} + "metadata": {}, + "output_type": "display_data" }, { - "output_type": "stream", "name": "stdout", + "output_type": "stream", "text": [ "Requirement already satisfied: torch in /usr/local/lib/python3.12/dist-packages (2.7.1)\n", "Requirement already satisfied: scikit-learn in /usr/local/lib/python3.12/dist-packages (1.7.2)\n", @@ -372,8 +372,8 @@ }, "outputs": [ { - "output_type": "stream", "name": "stdout", + "output_type": "stream", "text": [ "Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount(\"/content/drive\", force_remount=True).\n", "Data directory: /content/mimic-iv-demo/2.2\n", @@ -430,8 +430,8 @@ }, "outputs": [ { - "output_type": "stream", "name": 
"stdout", + "output_type": "stream", "text": [ "Model checkpoint should be at: /content/resources/best.ckpt\n", "Checkpoint exists: True\n" @@ -477,205 +477,205 @@ }, "outputs": [ { - "output_type": "stream", "name": "stdout", + "output_type": "stream", "text": [ "Using default EHR config: /usr/local/lib/python3.12/dist-packages/pyhealth/datasets/configs/mimic4_ehr.yaml\n" ] }, { - "output_type": "stream", "name": "stderr", + "output_type": "stream", "text": [ "INFO:pyhealth.datasets.mimic4:Using default EHR config: /usr/local/lib/python3.12/dist-packages/pyhealth/datasets/configs/mimic4_ehr.yaml\n" ] }, { - "output_type": "stream", "name": "stdout", + "output_type": "stream", "text": [ "Memory usage Before initializing mimic4_ehr: 1574.5 MB\n" ] }, { - "output_type": "stream", "name": "stderr", + "output_type": "stream", "text": [ "INFO:pyhealth.datasets.mimic4:Memory usage Before initializing mimic4_ehr: 1574.5 MB\n" ] }, { - "output_type": "stream", "name": "stdout", + "output_type": "stream", "text": [ "Duplicate table names in tables list. Removing duplicates.\n" ] }, { - "output_type": "stream", "name": "stderr", + "output_type": "stream", "text": [ "WARNING:pyhealth.datasets.base_dataset:Duplicate table names in tables list. 
Removing duplicates.\n" ] }, { - "output_type": "stream", "name": "stdout", + "output_type": "stream", "text": [ "Initializing mimic4_ehr dataset from https://physionet.org/files/mimic-iv-demo/2.2/ (dev mode: False)\n" ] }, { - "output_type": "stream", "name": "stderr", + "output_type": "stream", "text": [ "INFO:pyhealth.datasets.base_dataset:Initializing mimic4_ehr dataset from https://physionet.org/files/mimic-iv-demo/2.2/ (dev mode: False)\n" ] }, { - "output_type": "stream", "name": "stdout", + "output_type": "stream", "text": [ "Scanning table: admissions from https://physionet.org/files/mimic-iv-demo/2.2/hosp/admissions.csv.gz\n" ] }, { - "output_type": "stream", "name": "stderr", + "output_type": "stream", "text": [ "INFO:pyhealth.datasets.base_dataset:Scanning table: admissions from https://physionet.org/files/mimic-iv-demo/2.2/hosp/admissions.csv.gz\n" ] }, { - "output_type": "stream", "name": "stdout", + "output_type": "stream", "text": [ "Scanning table: procedures_icd from https://physionet.org/files/mimic-iv-demo/2.2/hosp/procedures_icd.csv.gz\n" ] }, { - "output_type": "stream", "name": "stderr", + "output_type": "stream", "text": [ "INFO:pyhealth.datasets.base_dataset:Scanning table: procedures_icd from https://physionet.org/files/mimic-iv-demo/2.2/hosp/procedures_icd.csv.gz\n" ] }, { - "output_type": "stream", "name": "stdout", + "output_type": "stream", "text": [ "Joining with table: https://physionet.org/files/mimic-iv-demo/2.2/hosp/admissions.csv.gz\n" ] }, { - "output_type": "stream", "name": "stderr", + "output_type": "stream", "text": [ "INFO:pyhealth.datasets.base_dataset:Joining with table: https://physionet.org/files/mimic-iv-demo/2.2/hosp/admissions.csv.gz\n" ] }, { - "output_type": "stream", "name": "stdout", + "output_type": "stream", "text": [ "Scanning table: diagnoses_icd from https://physionet.org/files/mimic-iv-demo/2.2/hosp/diagnoses_icd.csv.gz\n" ] }, { - "output_type": "stream", "name": "stderr", + "output_type": "stream", 
"text": [ "INFO:pyhealth.datasets.base_dataset:Scanning table: diagnoses_icd from https://physionet.org/files/mimic-iv-demo/2.2/hosp/diagnoses_icd.csv.gz\n" ] }, { - "output_type": "stream", "name": "stdout", + "output_type": "stream", "text": [ "Joining with table: https://physionet.org/files/mimic-iv-demo/2.2/hosp/admissions.csv.gz\n" ] }, { - "output_type": "stream", "name": "stderr", + "output_type": "stream", "text": [ "INFO:pyhealth.datasets.base_dataset:Joining with table: https://physionet.org/files/mimic-iv-demo/2.2/hosp/admissions.csv.gz\n" ] }, { - "output_type": "stream", "name": "stdout", + "output_type": "stream", "text": [ "Scanning table: icustays from https://physionet.org/files/mimic-iv-demo/2.2/icu/icustays.csv.gz\n" ] }, { - "output_type": "stream", "name": "stderr", + "output_type": "stream", "text": [ "INFO:pyhealth.datasets.base_dataset:Scanning table: icustays from https://physionet.org/files/mimic-iv-demo/2.2/icu/icustays.csv.gz\n" ] }, { - "output_type": "stream", "name": "stdout", + "output_type": "stream", "text": [ "Scanning table: patients from https://physionet.org/files/mimic-iv-demo/2.2/hosp/patients.csv.gz\n" ] }, { - "output_type": "stream", "name": "stderr", + "output_type": "stream", "text": [ "INFO:pyhealth.datasets.base_dataset:Scanning table: patients from https://physionet.org/files/mimic-iv-demo/2.2/hosp/patients.csv.gz\n" ] }, { - "output_type": "stream", "name": "stdout", + "output_type": "stream", "text": [ "Scanning table: labevents from https://physionet.org/files/mimic-iv-demo/2.2/hosp/labevents.csv.gz\n" ] }, { - "output_type": "stream", "name": "stderr", + "output_type": "stream", "text": [ "INFO:pyhealth.datasets.base_dataset:Scanning table: labevents from https://physionet.org/files/mimic-iv-demo/2.2/hosp/labevents.csv.gz\n" ] }, { - "output_type": "stream", "name": "stdout", + "output_type": "stream", "text": [ "Joining with table: https://physionet.org/files/mimic-iv-demo/2.2/hosp/d_labitems.csv.gz\n" ] }, { - 
"output_type": "stream", "name": "stderr", + "output_type": "stream", "text": [ "INFO:pyhealth.datasets.base_dataset:Joining with table: https://physionet.org/files/mimic-iv-demo/2.2/hosp/d_labitems.csv.gz\n" ] }, { - "output_type": "stream", "name": "stdout", + "output_type": "stream", "text": [ "Memory usage After initializing mimic4_ehr: 1574.9 MB\n" ] }, { - "output_type": "stream", "name": "stderr", + "output_type": "stream", "text": [ "INFO:pyhealth.datasets.mimic4:Memory usage After initializing mimic4_ehr: 1574.9 MB\n" ] }, { - "output_type": "error", "ename": "TypeError", "evalue": "object of type 'MIMIC4EHRDataset' has no len()", + "output_type": "error", "traceback": [ "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", "\u001b[0;31mTypeError\u001b[0m Traceback (most recent call last)", @@ -722,19 +722,19 @@ }, { "cell_type": "code", - "execution_count": 6, + "execution_count": null, "id": "1a4d785d", "metadata": { - "id": "1a4d785d", "colab": { "base_uri": "https://localhost:8080/" }, + "id": "1a4d785d", "outputId": "25686b72-cfd7-4766-c527-3e5e920da519" }, "outputs": [ { - "output_type": "stream", "name": "stdout", + "output_type": "stream", "text": [ "✓ Loaded input processors from /content/resources/input_processors.pkl\n", "✓ Loaded output processors from /content/resources/output_processors.pkl\n", @@ -742,103 +742,103 @@ ] }, { - "output_type": "stream", "name": "stderr", + "output_type": "stream", "text": [ "INFO:pyhealth.datasets.base_dataset:Setting task MortalityPredictionStageNetMIMIC4 for mimic4_ehr base dataset...\n" ] }, { - "output_type": "stream", "name": "stdout", + "output_type": "stream", "text": [ "Generating samples with 1 worker(s)...\n" ] }, { - "output_type": "stream", "name": "stderr", + "output_type": "stream", "text": [ "INFO:pyhealth.datasets.base_dataset:Generating samples with 1 worker(s)...\n" ] }, { - "output_type": "stream", "name": "stdout", + "output_type": "stream", 
"text": [ "Collecting global event dataframe...\n" ] }, { - "output_type": "stream", "name": "stderr", + "output_type": "stream", "text": [ "INFO:pyhealth.datasets.base_dataset:Collecting global event dataframe...\n" ] }, { - "output_type": "stream", "name": "stdout", + "output_type": "stream", "text": [ "Collected dataframe with shape: (113470, 39)\n" ] }, { - "output_type": "stream", "name": "stderr", + "output_type": "stream", "text": [ "INFO:pyhealth.datasets.base_dataset:Collected dataframe with shape: (113470, 39)\n", "Generating samples for MortalityPredictionStageNetMIMIC4 with 1 worker: 100%|██████████| 100/100 [00:16<00:00, 6.18it/s]" ] }, { - "output_type": "stream", "name": "stdout", + "output_type": "stream", "text": [ "Caching samples to /content/.cache/pyhealth/mimic4_stagenet_mortality/MortalityPredictionStageNetMIMIC4.parquet\n" ] }, { - "output_type": "stream", "name": "stderr", + "output_type": "stream", "text": [ "\n", "INFO:pyhealth.datasets.base_dataset:Caching samples to /content/.cache/pyhealth/mimic4_stagenet_mortality/MortalityPredictionStageNetMIMIC4.parquet\n" ] }, { - "output_type": "stream", "name": "stdout", + "output_type": "stream", "text": [ "Failed to cache samples: failed to determine supertype of list[f64] and list[list[str]]\n" ] }, { - "output_type": "stream", "name": "stderr", + "output_type": "stream", "text": [ "WARNING:pyhealth.datasets.base_dataset:Failed to cache samples: failed to determine supertype of list[f64] and list[list[str]]\n", "Processing samples: 100%|██████████| 100/100 [00:00<00:00, 1923.08it/s]" ] }, { - "output_type": "stream", "name": "stdout", + "output_type": "stream", "text": [ "Generated 100 samples for task MortalityPredictionStageNetMIMIC4\n" ] }, { - "output_type": "stream", "name": "stderr", + "output_type": "stream", "text": [ "\n", "INFO:pyhealth.datasets.base_dataset:Generated 100 samples for task MortalityPredictionStageNetMIMIC4\n" ] }, { - "output_type": "stream", "name": "stdout", + 
"output_type": "stream", "text": [ "Total samples: 100\n" ] @@ -850,7 +850,6 @@ "\n", "sample_dataset = dataset.set_task(\n", " MortalityPredictionStageNetMIMIC4(),\n", - " cache_dir=\"/content/.cache/pyhealth/mimic4_stagenet_mortality\",\n", " input_processors=input_processors,\n", " output_processors=output_processors,\n", ")\n", @@ -872,16 +871,16 @@ "execution_count": 7, "id": "4594eea4", "metadata": { - "id": "4594eea4", "colab": { "base_uri": "https://localhost:8080/" }, + "id": "4594eea4", "outputId": "a5f32db8-60fb-4fcb-c024-f1a3e6ce861e" }, "outputs": [ { - "output_type": "stream", "name": "stdout", + "output_type": "stream", "text": [ "Loaded 0 ICD code descriptions\n" ] @@ -938,16 +937,16 @@ "execution_count": 8, "id": "22f70a91", "metadata": { - "id": "22f70a91", "colab": { "base_uri": "https://localhost:8080/" }, + "id": "22f70a91", "outputId": "4ce86eee-b4f6-4dd6-bc64-900dfbda1f52" }, "outputs": [ { - "output_type": "stream", "name": "stdout", + "output_type": "stream", "text": [ "Using device: cuda:0\n", "\n", @@ -995,16 +994,16 @@ "execution_count": 9, "id": "6cbda428", "metadata": { - "id": "6cbda428", "colab": { "base_uri": "https://localhost:8080/" }, + "id": "6cbda428", "outputId": "54f18aa7-fd73-4acc-bda9-4bc28658d998" }, "outputs": [ { - "output_type": "stream", "name": "stdout", + "output_type": "stream", "text": [ "Test samples: 20\n" ] @@ -1147,16 +1146,16 @@ "execution_count": 11, "id": "5ae57044", "metadata": { - "id": "5ae57044", "colab": { "base_uri": "https://localhost:8080/" }, + "id": "5ae57044", "outputId": "0a6bd6e9-8168-467a-c5d1-6134b2be9c3b" }, "outputs": [ { - "output_type": "stream", "name": "stdout", + "output_type": "stream", "text": [ "================================================================================\n", "Initializing SHAP Explainer\n", @@ -1202,16 +1201,16 @@ "execution_count": 12, "id": "3cb63b98", "metadata": { - "id": "3cb63b98", "colab": { "base_uri": "https://localhost:8080/" }, + "id": "3cb63b98", 
"outputId": "695af029-b984-4957-e6de-4ceb92974ed4" }, "outputs": [ { - "output_type": "stream", "name": "stdout", + "output_type": "stream", "text": [ "icd_codes: device=cuda:0\n", "labs: device=cuda:0\n", @@ -1281,16 +1280,16 @@ "execution_count": 13, "id": "c65de0c3", "metadata": { - "id": "c65de0c3", "colab": { "base_uri": "https://localhost:8080/" }, + "id": "c65de0c3", "outputId": "2225e28a-0653-4441-a8c0-611f36e8bd73" }, "outputs": [ { - "output_type": "stream", "name": "stdout", + "output_type": "stream", "text": [ "\n", "================================================================================\n", @@ -1341,16 +1340,16 @@ "execution_count": 14, "id": "93490ab1", "metadata": { - "id": "93490ab1", "colab": { "base_uri": "https://localhost:8080/" }, + "id": "93490ab1", "outputId": "4bff2e31-e7d5-42ad-d63c-ea7e1c02c20b" }, "outputs": [ { - "output_type": "stream", "name": "stdout", + "output_type": "stream", "text": [ "\n", "================================================================================\n", @@ -1433,16 +1432,16 @@ "execution_count": 15, "id": "c7b02451", "metadata": { - "id": "c7b02451", "colab": { "base_uri": "https://localhost:8080/" }, + "id": "c7b02451", "outputId": "5f2b6f99-5432-4db4-f2fa-4227bd335858" }, "outputs": [ { - "output_type": "stream", "name": "stdout", + "output_type": "stream", "text": [ "\n", "================================================================================\n", @@ -1541,16 +1540,16 @@ "execution_count": 16, "id": "4a5c098c", "metadata": { - "id": "4a5c098c", "colab": { "base_uri": "https://localhost:8080/" }, + "id": "4a5c098c", "outputId": "c1d93796-238e-42c1-ac0c-7ae9013030b4" }, "outputs": [ { - "output_type": "stream", "name": "stdout", + "output_type": "stream", "text": [ "\n", "================================================================================\n", @@ -1597,16 +1596,16 @@ "execution_count": 17, "id": "69867127", "metadata": { - "id": "69867127", "colab": { "base_uri": 
"https://localhost:8080/" }, + "id": "69867127", "outputId": "d366b86c-e7f4-40d7-881f-cf7710acc812" }, "outputs": [ { - "output_type": "stream", "name": "stdout", + "output_type": "stream", "text": [ "\n", "================================================================================\n", @@ -1675,19 +1674,19 @@ } ], "metadata": { - "language_info": { - "name": "python" - }, + "accelerator": "GPU", "colab": { - "provenance": [], - "gpuType": "T4" + "gpuType": "T4", + "provenance": [] }, - "accelerator": "GPU", "kernelspec": { - "name": "python3", - "display_name": "Python 3" + "display_name": "Python 3", + "name": "python3" + }, + "language_info": { + "name": "python" } }, "nbformat": 4, "nbformat_minor": 5 -} \ No newline at end of file +} diff --git a/examples/interpretability/shap_stagenet_mimic4.py b/examples/interpretability/shap_stagenet_mimic4.py index f376bc62e..a06a9300d 100644 --- a/examples/interpretability/shap_stagenet_mimic4.py +++ b/examples/interpretability/shap_stagenet_mimic4.py @@ -98,7 +98,7 @@ def print_top_attributions( """Print top-k most important features from SHAP attributions.""" if icd_code_to_desc is None: icd_code_to_desc = {} - + for feature_key, attr in attributions.items(): attr_cpu = attr.detach().cpu() if attr_cpu.dim() == 0 or attr_cpu.size(0) == 0: @@ -117,7 +117,7 @@ def print_top_attributions( print(f" Shape: {attr_cpu[0].shape}") print(f" Total attribution sum: {flattened.sum().item():+.6f}") print(f" Mean attribution: {flattened.mean().item():+.6f}") - + k = min(top_k, flattened.numel()) top_values, top_indices = torch.topk(flattened.abs(), k=k) processor = processors.get(feature_key) if processors else None @@ -169,7 +169,6 @@ def main(): sample_dataset = dataset.set_task( MortalityPredictionStageNetMIMIC4(), - cache_dir="~/.cache/pyhealth/mimic4_stagenet_mortality", input_processors=input_processors, output_processors=output_processors, ) @@ -222,7 +221,7 @@ def main(): probs = output["y_prob"] label_key = 
model.label_key true_label = sample_batch_device[label_key] - + # Handle binary classification (single probability output) if probs.shape[-1] == 1: prob_death = probs[0].item() diff --git a/examples/length_of_stay/length_of_stay_mimic4_ehrmamba.py b/examples/length_of_stay/length_of_stay_mimic4_ehrmamba.py index a5105a39d..26314db9b 100644 --- a/examples/length_of_stay/length_of_stay_mimic4_ehrmamba.py +++ b/examples/length_of_stay/length_of_stay_mimic4_ehrmamba.py @@ -41,7 +41,6 @@ EPOCHS = 20 DATASET_CACHE = os.path.join(CACHE_BASE, "mimic4_ehr_los") -TASK_CACHE = os.path.join(CACHE_BASE, "mimic4_los_ehrmamba") def main(): @@ -54,9 +53,6 @@ def main(): dataset_cache = os.path.join( CACHE_BASE, "mimic4_ehr_los_quick" if quick_test else "mimic4_ehr_los" ) - task_cache = os.path.join( - CACHE_BASE, "mimic4_los_ehrmamba_quick" if quick_test else "mimic4_los_ehrmamba" - ) num_workers = 1 if quick_test else 4 print("EHRMamba – Length of stay (full MIMIC-IV)") @@ -66,7 +62,7 @@ def main(): print("gpu:", gpu_id, "(CUDA_VISIBLE_DEVICES)") print("device:", DEVICE) print("ehr_root:", EHR_ROOT) - print("cache: dataset", dataset_cache, "| task", task_cache) + print("cache:", dataset_cache) print("seed:", SEED, "| batch_size:", BATCH_SIZE, "| epochs:", epochs) t0 = time.perf_counter() @@ -83,7 +79,6 @@ def main(): sample_dataset = dataset.set_task( task, num_workers=num_workers, - cache_dir=task_cache, ) print(f"Task set in {time.perf_counter() - t1:.1f}s | samples: {len(sample_dataset)}") diff --git a/examples/length_of_stay/length_of_stay_mimic4_stageattn.py b/examples/length_of_stay/length_of_stay_mimic4_stageattn.py index 70d15024a..514f7750c 100644 --- a/examples/length_of_stay/length_of_stay_mimic4_stageattn.py +++ b/examples/length_of_stay/length_of_stay_mimic4_stageattn.py @@ -43,7 +43,6 @@ # - This ensures consistent encoding and saves computation time # - Processors include vocabulary mappings and sequence length statistics processor_dir = 
"/home/yongdaf2/los_sa/processors" - cache_dir = "/home/yongdaf2/los_sa/cache" if os.path.exists(os.path.join(processor_dir, "input_processors.pkl")): print("\n=== Loading Pre-fitted Processors ===") @@ -52,7 +51,6 @@ sample_dataset = base_dataset.set_task( LengthOfStayStageNetMIMIC4(padding=20), num_workers=16, - cache_dir=cache_dir, input_processors=input_processors, output_processors=output_processors, ) @@ -61,7 +59,6 @@ sample_dataset = base_dataset.set_task( LengthOfStayStageNetMIMIC4(padding=20), num_workers=16, - cache_dir=cache_dir, ) # Save processors for future runs diff --git a/examples/length_of_stay/length_of_stay_mimic4_stagenet.py b/examples/length_of_stay/length_of_stay_mimic4_stagenet.py index 149db7e0b..3c560fe5b 100644 --- a/examples/length_of_stay/length_of_stay_mimic4_stagenet.py +++ b/examples/length_of_stay/length_of_stay_mimic4_stagenet.py @@ -43,7 +43,6 @@ # - This ensures consistent encoding and saves computation time # - Processors include vocabulary mappings and sequence length statistics processor_dir = "/home/yongdaf2/los_sn/processors" - cache_dir = "/home/yongdaf2/los_sn/cache" if os.path.exists(os.path.join(processor_dir, "input_processors.pkl")): print("\n=== Loading Pre-fitted Processors ===") @@ -52,7 +51,6 @@ sample_dataset = base_dataset.set_task( LengthOfStayStageNetMIMIC4(padding=20), num_workers=16, - cache_dir=cache_dir, input_processors=input_processors, output_processors=output_processors, ) @@ -61,7 +59,6 @@ sample_dataset = base_dataset.set_task( LengthOfStayStageNetMIMIC4(padding=20), num_workers=16, - cache_dir=cache_dir, ) # Save processors for future runs diff --git a/examples/length_of_stay/length_of_stay_mimic4_transformer.py b/examples/length_of_stay/length_of_stay_mimic4_transformer.py index 094be8965..3d7b5aa16 100644 --- a/examples/length_of_stay/length_of_stay_mimic4_transformer.py +++ b/examples/length_of_stay/length_of_stay_mimic4_transformer.py @@ -43,7 +43,6 @@ # - This ensures consistent encoding and 
saves computation time # - Processors include vocabulary mappings and sequence length statistics processor_dir = "/home/yongdaf2/los_tf/processors" - cache_dir = "/home/yongdaf2/los_tf/cache" if os.path.exists(os.path.join(processor_dir, "input_processors.pkl")): print("\n=== Loading Pre-fitted Processors ===") @@ -52,7 +51,6 @@ sample_dataset = base_dataset.set_task( LengthOfStayStageNetMIMIC4(padding=20), num_workers=16, - cache_dir=cache_dir, input_processors=input_processors, output_processors=output_processors, ) @@ -61,7 +59,6 @@ sample_dataset = base_dataset.set_task( LengthOfStayStageNetMIMIC4(padding=20), num_workers=16, - cache_dir=cache_dir, ) # Save processors for future runs diff --git a/examples/lime_stagenet_mimic4.py b/examples/lime_stagenet_mimic4.py index fe1a69922..b3b54085d 100644 --- a/examples/lime_stagenet_mimic4.py +++ b/examples/lime_stagenet_mimic4.py @@ -99,7 +99,7 @@ def print_top_attributions( """Print top-k most important features from LIME attributions.""" if icd_code_to_desc is None: icd_code_to_desc = {} - + for feature_key, attr in attributions.items(): attr_cpu = attr.detach().cpu() if attr_cpu.dim() == 0 or attr_cpu.size(0) == 0: @@ -118,7 +118,7 @@ def print_top_attributions( print(f" Shape: {attr_cpu[0].shape}") print(f" Total attribution sum: {flattened.sum().item():+.6f}") print(f" Mean attribution: {flattened.mean().item():+.6f}") - + k = min(top_k, flattened.numel()) top_values, top_indices = torch.topk(flattened.abs(), k=k) processor = processors.get(feature_key) if processors else None @@ -170,7 +170,6 @@ def main(): sample_dataset = dataset.set_task( MortalityPredictionStageNetMIMIC4(), - cache_dir="~/.cache/pyhealth/mimic4_stagenet_mortality", input_processors=input_processors, output_processors=output_processors, ) @@ -225,7 +224,7 @@ def main(): probs = output["y_prob"] label_key = model.label_key true_label = sample_batch_device[label_key] - + # Handle binary classification (single probability output) if 
probs.shape[-1] == 1: prob_death = probs[0].item() @@ -262,10 +261,10 @@ def main(): print("Positive values increase the mortality prediction, negative values decrease it.") print_top_attributions( - attributions, - sample_batch_device, - input_processors, - top_k=15, + attributions, + sample_batch_device, + input_processors, + top_k=15, icd_code_to_desc=ICD_CODE_TO_DESC, method_name="LIME" ) @@ -336,11 +335,11 @@ def main(): if key in attributions: flat_lasso = attributions[key][0].flatten().abs() flat_ridge = attr_ridge[key][0].flatten().abs() - + k = min(5, flat_lasso.numel()) top_lasso = torch.topk(flat_lasso, k=k) top_ridge = torch.topk(flat_ridge, k=k) - + print(f"\n {key}:") print(f" Lasso non-zero features: {(flat_lasso > 1e-6).sum().item()}/{flat_lasso.numel()}") print(f" Ridge non-zero features: {(flat_ridge > 1e-6).sum().item()}/{flat_ridge.numel()}") diff --git a/examples/mortality_prediction/ehrmamba_mimic4_full.py b/examples/mortality_prediction/ehrmamba_mimic4_full.py index a77e92c60..aabe7fd59 100644 --- a/examples/mortality_prediction/ehrmamba_mimic4_full.py +++ b/examples/mortality_prediction/ehrmamba_mimic4_full.py @@ -41,7 +41,6 @@ EPOCHS = 20 DATASET_CACHE = os.path.join(CACHE_BASE, "mimic4_ehr") -TASK_CACHE = os.path.join(CACHE_BASE, "mimic4_ihm_ehrmamba") def main(): @@ -52,9 +51,6 @@ def main(): dev = quick_test epochs = 2 if quick_test else EPOCHS dataset_cache = os.path.join(CACHE_BASE, "mimic4_ehr_quick" if quick_test else "mimic4_ehr") - task_cache = os.path.join( - CACHE_BASE, "mimic4_ihm_ehrmamba_quick" if quick_test else "mimic4_ihm_ehrmamba" - ) num_workers = 1 if quick_test else 4 print("EHRMamba – In-hospital mortality (full MIMIC-IV)") @@ -64,7 +60,7 @@ def main(): print("gpu:", gpu_id, "(CUDA_VISIBLE_DEVICES)") print("device:", DEVICE) print("ehr_root:", EHR_ROOT) - print("cache: dataset", dataset_cache, "| task", task_cache) + print("cache:", dataset_cache) print("seed:", SEED, "| batch_size:", BATCH_SIZE, "| epochs:", epochs) 
t0 = time.perf_counter() @@ -81,7 +77,6 @@ def main(): sample_dataset = dataset.set_task( task, num_workers=num_workers, - cache_dir=task_cache, ) print(f"Task set in {time.perf_counter() - t1:.1f}s | samples: {len(sample_dataset)}") diff --git a/examples/mortality_prediction/mortality_mimic4_stageattn.py b/examples/mortality_prediction/mortality_mimic4_stageattn.py index 18649832b..f302a8600 100644 --- a/examples/mortality_prediction/mortality_mimic4_stageattn.py +++ b/examples/mortality_prediction/mortality_mimic4_stageattn.py @@ -43,7 +43,6 @@ # - This ensures consistent encoding and saves computation time # - Processors include vocabulary mappings and sequence length statistics processor_dir = "/home/yongdaf2/stageattn/processors" - cache_dir = "/home/yongdaf2/stageattn/cache" if os.path.exists(os.path.join(processor_dir, "input_processors.pkl")): print("\n=== Loading Pre-fitted Processors ===") @@ -52,7 +51,6 @@ sample_dataset = base_dataset.set_task( MortalityPredictionStageNetMIMIC4(padding=20), num_workers=16, - cache_dir=cache_dir, input_processors=input_processors, output_processors=output_processors, ) @@ -61,7 +59,6 @@ sample_dataset = base_dataset.set_task( MortalityPredictionStageNetMIMIC4(padding=20), num_workers=16, - cache_dir=cache_dir, ) # Save processors for future runs diff --git a/examples/mortality_prediction/mortality_mimic4_stagenet_v2.py b/examples/mortality_prediction/mortality_mimic4_stagenet_v2.py index 0e267ae1f..acf9598d0 100644 --- a/examples/mortality_prediction/mortality_mimic4_stagenet_v2.py +++ b/examples/mortality_prediction/mortality_mimic4_stagenet_v2.py @@ -160,7 +160,6 @@ def generate_holdout_set( # - This ensures consistent encoding and saves computation time # - Processors include vocabulary mappings and sequence length statistics processor_dir = "../../output/processors/stagenet_mortality_mimic4" - cache_dir = "../../mimic4_stagenet_cache_v3" if os.path.exists(os.path.join(processor_dir, "input_processors.pkl")): 
print("\n=== Loading Pre-fitted Processors ===") @@ -169,7 +168,6 @@ def generate_holdout_set( sample_dataset = base_dataset.set_task( MortalityPredictionStageNetMIMIC4(padding=20), num_workers=4, - cache_dir=cache_dir, input_processors=input_processors, output_processors=output_processors, ) @@ -178,7 +176,6 @@ def generate_holdout_set( sample_dataset = base_dataset.set_task( MortalityPredictionStageNetMIMIC4(padding=20), num_workers=4, - cache_dir=cache_dir, ) # Save processors for future runs diff --git a/examples/mortality_prediction/mortality_mimic4_transformer.py b/examples/mortality_prediction/mortality_mimic4_transformer.py index 838496849..d2afc585d 100644 --- a/examples/mortality_prediction/mortality_mimic4_transformer.py +++ b/examples/mortality_prediction/mortality_mimic4_transformer.py @@ -43,7 +43,6 @@ # - This ensures consistent encoding and saves computation time # - Processors include vocabulary mappings and sequence length statistics processor_dir = "/home/yongdaf2/mp_tf/processors" - cache_dir = "/home/yongdaf2/mp_tf/cache" if os.path.exists(os.path.join(processor_dir, "input_processors.pkl")): print("\n=== Loading Pre-fitted Processors ===") @@ -52,7 +51,6 @@ sample_dataset = base_dataset.set_task( MortalityPredictionStageNetMIMIC4(padding=20), num_workers=16, - cache_dir=cache_dir, input_processors=input_processors, output_processors=output_processors, ) @@ -61,7 +59,6 @@ sample_dataset = base_dataset.set_task( MortalityPredictionStageNetMIMIC4(padding=20), num_workers=16, - cache_dir=cache_dir, ) # Save processors for future runs diff --git a/examples/transformer_mimic4.ipynb b/examples/transformer_mimic4.ipynb index f7a435449..37c098521 100644 --- a/examples/transformer_mimic4.ipynb +++ b/examples/transformer_mimic4.ipynb @@ -97,7 +97,7 @@ }, { "cell_type": "code", - "execution_count": 3, + "execution_count": null, "id": "46bfc544", "metadata": {}, "outputs": [ @@ -161,7 +161,6 @@ "task = InHospitalMortalityMIMIC4()\n", "sample_dataset = 
dataset.set_task(\n", " task,\n", - " cache_dir=\"../../test_cache_transformer_m4\"\n", ")\n", "train_dataset, val_dataset, test_dataset = split_by_sample(sample_dataset, ratios=[0.7, 0.1, 0.2])" ] diff --git a/examples/tutorial_stagenet_comprehensive.ipynb b/examples/tutorial_stagenet_comprehensive.ipynb index e6ce98399..7e1a9a3fd 100644 --- a/examples/tutorial_stagenet_comprehensive.ipynb +++ b/examples/tutorial_stagenet_comprehensive.ipynb @@ -891,7 +891,7 @@ }, { "cell_type": "code", - "execution_count": 4, + "execution_count": null, "id": "8e01f7ec", "metadata": {}, "outputs": [ @@ -953,9 +953,8 @@ ], "source": [ "from pyhealth.datasets.utils import save_processors, load_processors\n", - "import os \n", + "import os\n", "processor_dir = \"../../output/processors/stagenet_mortality_mimic4\"\n", - "cache_dir = \"../../mimic4_stagenet_cache_v3\"\n", "\n", "if os.path.exists(os.path.join(processor_dir, \"input_processors.pkl\")):\n", " print(\"\\n=== Loading Pre-fitted Processors ===\")\n", @@ -964,7 +963,6 @@ " sample_dataset = base_dataset.set_task(\n", " MortalityPredictionStageNetMIMIC4(padding=20),\n", " num_workers=1,\n", - " cache_dir=cache_dir,\n", " input_processors=input_processors,\n", " output_processors=output_processors,\n", " )\n", @@ -973,7 +971,6 @@ " sample_dataset = base_dataset.set_task(\n", " MortalityPredictionStageNetMIMIC4(padding=20),\n", " num_workers=1,\n", - " cache_dir=cache_dir,\n", " )\n", "\n", " # Save processors for future runs\n", From 9b7c48f9830f463ab468634927e5a4acff51c27f Mon Sep 17 00:00:00 2001 From: EricSchrock Date: Mon, 16 Feb 2026 04:07:33 +0000 Subject: [PATCH 13/16] Fix cache unit test (call_count increments trigger new caches and we can no longer override the default task cache) --- tests/core/test_caching.py | 48 +++++++++++--------------------------- 1 file changed, 13 insertions(+), 35 deletions(-) diff --git a/tests/core/test_caching.py b/tests/core/test_caching.py index 299c2f430..4b7f1d512 100644 --- 
a/tests/core/test_caching.py +++ b/tests/core/test_caching.py @@ -20,16 +20,13 @@ class MockTask(BaseTask): input_schema = {"test_attribute": "raw"} output_schema = {"test_label": "binary"} - def __init__(self, param=None): - self.call_count = 0 - if param: - self.param = param + def __init__(self, param=0): + self.param = param def __call__(self, patient): """Return mock samples based on patient data.""" # Extract patient's test data from the patient's data source patient_data = patient.data_source - self.call_count += 1 samples = [] for row in patient_data.iter_rows(named=True): @@ -49,16 +46,13 @@ class MockTask2(BaseTask): input_schema = {"test_attribute": "raw"} output_schema = {"test_label": "multiclass"} - def __init__(self, param=None): - self.call_count = 0 - if param: - self.param = param + def __init__(self, param=0): + self.param = param def __call__(self, patient): """Return mock samples based on patient data.""" # Extract patient's test data from the patient's data source patient_data = patient.data_source - self.call_count += 1 samples = [] for row in patient_data.iter_rows(named=True): @@ -108,19 +102,10 @@ def load_data(self) -> dd.DataFrame: class TestCachingFunctionality(BaseTestCase): """Test cases for caching functionality in BaseDataset.set_task().""" - - @classmethod - def setUpClass(cls): - cls.temp_dir = tempfile.TemporaryDirectory() - cls.dataset = MockDataset(cache_dir=cls.temp_dir.name) - def setUp(self): + self.temp_dir = tempfile.TemporaryDirectory() + self.dataset = MockDataset(cache_dir=self.temp_dir.name) self.task = MockTask() - self.cache_dir = Path(self.temp_dir.name) / "task_cache" - self.cache_dir.mkdir() - - def tearDown(self): - shutil.rmtree(self.cache_dir) def test_set_task_signature(self): """Test that set_task has the correct method signature.""" @@ -151,15 +136,10 @@ def test_set_task_writes_cache_and_metadata(self): self.assertEqual(sample_dataset.dataset_name, "TestDataset") self.assertEqual(sample_dataset.task_name, 
self.task.task_name) self.assertEqual(len(sample_dataset), 4) - self.assertEqual(self.task.call_count, 2) # Ensure intermediate cache files are created in default location task_params = json.dumps( - { - **vars(self.task), - "input_schema": self.task.input_schema, - "output_schema": self.task.output_schema, - }, + {"input_schema": {"test_attribute": "raw"}, "output_schema": {"test_label": "binary"}, "param": 0}, sort_keys=True, default=str ) @@ -195,7 +175,7 @@ def test_set_task_writes_cache_and_metadata(self): def test_default_cache_dir_is_used(self): """When cache_dir is omitted, default cache dir should be used.""" task_params = json.dumps( - {"input_schema": {"test_attribute": "raw"}, "output_schema": {"test_label": "binary"}, "call_count": 0}, + {"input_schema": {"test_attribute": "raw"}, "output_schema": {"test_label": "binary"}, "param": 0}, sort_keys=True, default=str ) @@ -213,15 +193,13 @@ def test_default_cache_dir_is_used(self): def test_reuses_existing_cache_without_regeneration(self): """Second call should reuse cached samples instead of recomputing.""" sample_dataset = self.dataset.set_task(self.task) - self.assertEqual(self.task.call_count, 2) with patch.object( - self.task, "__call__", side_effect=AssertionError("Task should not rerun") + type(self.task), "__call__", side_effect=AssertionError("Task should not rerun") ): cached_dataset = self.dataset.set_task(self.task) self.assertEqual(len(cached_dataset), 4) - self.assertEqual(self.task.call_count, 2) sample_dataset.close() cached_dataset.close() @@ -233,13 +211,13 @@ def test_tasks_with_diff_param_values_get_diff_caches(self): self.assertNotEqual(sample_dataset1.path, sample_dataset2.path) task_params1 = json.dumps( - {"input_schema": {"test_attribute": "raw"}, "output_schema": {"test_label": "binary"}, "call_count": 0, "param": 1}, + {"input_schema": {"test_attribute": "raw"}, "output_schema": {"test_label": "binary"}, "param": 1}, sort_keys=True, default=str ) task_params2 = json.dumps( - 
{"input_schema": {"test_attribute": "raw"}, "output_schema": {"test_label": "binary"}, "call_count": 0, "param": 2}, + {"input_schema": {"test_attribute": "raw"}, "output_schema": {"test_label": "binary"}, "param": 2}, sort_keys=True, default=str ) @@ -265,13 +243,13 @@ def test_tasks_with_diff_output_schemas_get_diff_caches(self): self.assertNotEqual(sample_dataset1.path, sample_dataset2.path) task_params1 = json.dumps( - {"input_schema": {"test_attribute": "raw"}, "output_schema": {"test_label": "binary"}, "call_count": 0}, + {"input_schema": {"test_attribute": "raw"}, "output_schema": {"test_label": "binary"}, "param": 0}, sort_keys=True, default=str ) task_params2 = json.dumps( - {"input_schema": {"test_attribute": "raw"}, "output_schema": {"test_label": "multiclass"}, "call_count": 0}, + {"input_schema": {"test_attribute": "raw"}, "output_schema": {"test_label": "multiclass"}, "param": 0}, sort_keys=True, default=str ) From 6c9a31f55a635330f459129c91cee45390d3b596 Mon Sep 17 00:00:00 2001 From: John Wu Date: Mon, 16 Feb 2026 11:57:07 -0600 Subject: [PATCH 14/16] update cache size directory check to be more in line with our tasks/ subdir setup with how cache_dir works --- examples/benchmark_perf/benchmark_workers_1.py | 3 +-- examples/benchmark_perf/benchmark_workers_12.py | 3 +-- examples/benchmark_perf/benchmark_workers_4.py | 3 +-- examples/benchmark_perf/benchmark_workers_n.py | 8 ++------ .../benchmark_workers_n_drug_recommendation.py | 8 ++------ .../benchmark_perf/benchmark_workers_n_length_of_stay.py | 8 ++------ examples/mortality_prediction/multimodal_mimic4_demo.py | 4 ++-- 7 files changed, 11 insertions(+), 26 deletions(-) diff --git a/examples/benchmark_perf/benchmark_workers_1.py b/examples/benchmark_perf/benchmark_workers_1.py index 1106ac007..c4fbcab5d 100644 --- a/examples/benchmark_perf/benchmark_workers_1.py +++ b/examples/benchmark_perf/benchmark_workers_1.py @@ -140,7 +140,6 @@ def main(): sample_dataset = base_dataset.set_task( 
MortalityPredictionStageNetMIMIC4(), num_workers=1, - cache_dir=f"{cache_root}/task_samples", ) task_time = time.time() - task_start @@ -149,7 +148,7 @@ def main(): # Measure cache sizes print("\n[3/3] Measuring cache sizes...") base_cache_dir = f"{cache_root}/base_dataset" - task_cache_dir = f"{cache_root}/task_samples" + task_cache_dir = f"{base_cache_dir}/tasks" base_cache_size = get_directory_size(base_cache_dir) task_cache_size = get_directory_size(task_cache_dir) diff --git a/examples/benchmark_perf/benchmark_workers_12.py b/examples/benchmark_perf/benchmark_workers_12.py index 01302fa02..4f41796c0 100644 --- a/examples/benchmark_perf/benchmark_workers_12.py +++ b/examples/benchmark_perf/benchmark_workers_12.py @@ -140,7 +140,6 @@ def main(): sample_dataset = base_dataset.set_task( MortalityPredictionStageNetMIMIC4(), num_workers=12, - cache_dir=f"{cache_root}/task_samples", ) task_time = time.time() - task_start @@ -149,7 +148,7 @@ def main(): # Measure cache sizes print("\n[3/3] Measuring cache sizes...") base_cache_dir = f"{cache_root}/base_dataset" - task_cache_dir = f"{cache_root}/task_samples" + task_cache_dir = f"{base_cache_dir}/tasks" base_cache_size = get_directory_size(base_cache_dir) task_cache_size = get_directory_size(task_cache_dir) diff --git a/examples/benchmark_perf/benchmark_workers_4.py b/examples/benchmark_perf/benchmark_workers_4.py index 82b56059a..82abca3f9 100644 --- a/examples/benchmark_perf/benchmark_workers_4.py +++ b/examples/benchmark_perf/benchmark_workers_4.py @@ -140,7 +140,6 @@ def main(): sample_dataset = base_dataset.set_task( MortalityPredictionStageNetMIMIC4(), num_workers=4, - cache_dir=f"{cache_root}/task_samples_old", ) task_time = time.time() - task_start @@ -149,7 +148,7 @@ def main(): # Measure cache sizes print("\n[3/3] Measuring cache sizes...") base_cache_dir = f"{cache_root}/base_dataset" - task_cache_dir = f"{cache_root}/task_samples" + task_cache_dir = f"{base_cache_dir}/tasks" base_cache_size = 
get_directory_size(base_cache_dir) task_cache_size = get_directory_size(task_cache_dir) diff --git a/examples/benchmark_perf/benchmark_workers_n.py b/examples/benchmark_perf/benchmark_workers_n.py index 08454e16b..bea46c1ca 100644 --- a/examples/benchmark_perf/benchmark_workers_n.py +++ b/examples/benchmark_perf/benchmark_workers_n.py @@ -282,11 +282,8 @@ def main() -> None: print("\n[1/1] Sweeping num_workers (each run reloads dataset + task)...") for w in args.workers: for r in range(args.repeats): - task_cache_dir = cache_root / f"task_samples_w{w}" - # Ensure no cache artifacts before this run. remove_dir(base_cache_dir) - ensure_empty_dir(task_cache_dir) tracker.reset() run_start = time.time() @@ -311,13 +308,13 @@ def main() -> None: sample_dataset = base_dataset.set_task( MortalityPredictionStageNetMIMIC4(), num_workers=w, - cache_dir=str(task_cache_dir), ) task_process_s = time.time() - task_start total_s = time.time() - run_start peak_rss_bytes = tracker.peak_bytes() - task_cache_bytes = get_directory_size(task_cache_dir) + tasks_dir = base_cache_dir / "tasks" + task_cache_bytes = get_directory_size(tasks_dir) # Capture sample count BEFORE cleaning up the cache (litdata needs it). num_samples = len(sample_dataset) @@ -327,7 +324,6 @@ def main() -> None: del base_dataset # Clean up to avoid disk growth across an overnight sweep. 
- remove_dir(task_cache_dir) remove_dir(base_cache_dir) results.append( diff --git a/examples/benchmark_perf/benchmark_workers_n_drug_recommendation.py b/examples/benchmark_perf/benchmark_workers_n_drug_recommendation.py index 7cfe46e46..6aa56c679 100644 --- a/examples/benchmark_perf/benchmark_workers_n_drug_recommendation.py +++ b/examples/benchmark_perf/benchmark_workers_n_drug_recommendation.py @@ -284,11 +284,8 @@ def main() -> None: print("\n[1/1] Sweeping num_workers (each run reloads dataset + task)...") for w in args.workers: for r in range(args.repeats): - task_cache_dir = cache_root / f"task_samples_drug_rec_w{w}" - # Ensure no cache artifacts before this run. remove_dir(base_cache_dir) - ensure_empty_dir(task_cache_dir) tracker.reset() run_start = time.time() @@ -313,13 +310,13 @@ def main() -> None: sample_dataset = base_dataset.set_task( DrugRecommendationMIMIC4(), num_workers=w, - cache_dir=str(task_cache_dir), ) task_process_s = time.time() - task_start total_s = time.time() - run_start peak_rss_bytes = tracker.peak_bytes() - task_cache_bytes = get_directory_size(task_cache_dir) + tasks_dir = base_cache_dir / "tasks" + task_cache_bytes = get_directory_size(tasks_dir) # Capture sample count BEFORE cleaning up the cache (litdata needs it). num_samples = len(sample_dataset) @@ -329,7 +326,6 @@ def main() -> None: del base_dataset # Clean up to avoid disk growth across an overnight sweep. 
- remove_dir(task_cache_dir) remove_dir(base_cache_dir) results.append( diff --git a/examples/benchmark_perf/benchmark_workers_n_length_of_stay.py b/examples/benchmark_perf/benchmark_workers_n_length_of_stay.py index b8603e01f..541144adc 100644 --- a/examples/benchmark_perf/benchmark_workers_n_length_of_stay.py +++ b/examples/benchmark_perf/benchmark_workers_n_length_of_stay.py @@ -284,11 +284,8 @@ def main() -> None: print("\n[1/1] Sweeping num_workers (each run reloads dataset + task)...") for w in args.workers: for r in range(args.repeats): - task_cache_dir = cache_root / f"task_samples_los_w{w}" - # Ensure no cache artifacts before this run. remove_dir(base_cache_dir) - ensure_empty_dir(task_cache_dir) tracker.reset() run_start = time.time() @@ -313,13 +310,13 @@ def main() -> None: sample_dataset = base_dataset.set_task( LengthOfStayPredictionMIMIC4(), num_workers=w, - cache_dir=str(task_cache_dir), ) task_process_s = time.time() - task_start total_s = time.time() - run_start peak_rss_bytes = tracker.peak_bytes() - task_cache_bytes = get_directory_size(task_cache_dir) + tasks_dir = base_cache_dir / "tasks" + task_cache_bytes = get_directory_size(tasks_dir) # Capture sample count BEFORE cleaning up the cache (litdata needs it). num_samples = len(sample_dataset) @@ -329,7 +326,6 @@ def main() -> None: del base_dataset # Clean up to avoid disk growth across an overnight sweep. 
- remove_dir(task_cache_dir) remove_dir(base_cache_dir) results.append( diff --git a/examples/mortality_prediction/multimodal_mimic4_demo.py b/examples/mortality_prediction/multimodal_mimic4_demo.py index c84a71c37..3c9b4cab4 100644 --- a/examples/mortality_prediction/multimodal_mimic4_demo.py +++ b/examples/mortality_prediction/multimodal_mimic4_demo.py @@ -883,7 +883,6 @@ def main(): cache_suffix += "_full" base_cache_dir = cache_root / f"base_dataset{cache_suffix}" - task_cache_dir = cache_root / f"task_samples{cache_suffix}" # Initialize memory tracker tracker = PeakMemoryTracker(poll_interval_s=0.1) @@ -971,7 +970,8 @@ def main(): if not args.skip_benchmark: print(" Calculating cache size...", flush=True) cache_calc_start = time.time() - task_cache_bytes = get_directory_size(task_cache_dir) + tasks_dir = base_cache_dir / "tasks" + task_cache_bytes = get_directory_size(tasks_dir) cache_calc_time = time.time() - cache_calc_start print(f" ✓ Task cache size: {format_size(task_cache_bytes)} " f"(calculated in {cache_calc_time:.1f}s)") From 18b82ac9cb1dea4ed94037ec1e1125f7380bd5a5 Mon Sep 17 00:00:00 2001 From: EricSchrock Date: Mon, 16 Feb 2026 22:00:29 +0000 Subject: [PATCH 15/16] Remove unused ensure_empty_dir function --- .../benchmark_perf/benchmark_workers_n.py | 7 ------ ...benchmark_workers_n_drug_recommendation.py | 7 ------ .../benchmark_workers_n_length_of_stay.py | 7 ------ .../multimodal_mimic4_demo.py | 24 +++++++------------ 4 files changed, 8 insertions(+), 37 deletions(-) diff --git a/examples/benchmark_perf/benchmark_workers_n.py b/examples/benchmark_perf/benchmark_workers_n.py index bea46c1ca..6b53091cb 100644 --- a/examples/benchmark_perf/benchmark_workers_n.py +++ b/examples/benchmark_perf/benchmark_workers_n.py @@ -159,13 +159,6 @@ def parse_workers(value: str) -> list[int]: return workers -def ensure_empty_dir(path: str | Path) -> None: - p = Path(path) - if p.exists(): - shutil.rmtree(p) - p.mkdir(parents=True, exist_ok=True) - - def 
remove_dir(path: str | Path, retries: int = 3, delay: float = 1.0) -> None: """Remove a directory with retry logic for busy file handles.""" p = Path(path) diff --git a/examples/benchmark_perf/benchmark_workers_n_drug_recommendation.py b/examples/benchmark_perf/benchmark_workers_n_drug_recommendation.py index 6aa56c679..2d12f746e 100644 --- a/examples/benchmark_perf/benchmark_workers_n_drug_recommendation.py +++ b/examples/benchmark_perf/benchmark_workers_n_drug_recommendation.py @@ -159,13 +159,6 @@ def parse_workers(value: str) -> list[int]: return workers -def ensure_empty_dir(path: str | Path) -> None: - p = Path(path) - if p.exists(): - shutil.rmtree(p) - p.mkdir(parents=True, exist_ok=True) - - def remove_dir(path: str | Path, retries: int = 3, delay: float = 1.0) -> None: """Remove a directory with retry logic for busy file handles.""" p = Path(path) diff --git a/examples/benchmark_perf/benchmark_workers_n_length_of_stay.py b/examples/benchmark_perf/benchmark_workers_n_length_of_stay.py index 541144adc..79c2a6f23 100644 --- a/examples/benchmark_perf/benchmark_workers_n_length_of_stay.py +++ b/examples/benchmark_perf/benchmark_workers_n_length_of_stay.py @@ -159,13 +159,6 @@ def parse_workers(value: str) -> list[int]: return workers -def ensure_empty_dir(path: str | Path) -> None: - p = Path(path) - if p.exists(): - shutil.rmtree(p) - p.mkdir(parents=True, exist_ok=True) - - def remove_dir(path: str | Path, retries: int = 3, delay: float = 1.0) -> None: """Remove a directory with retry logic for busy file handles.""" p = Path(path) diff --git a/examples/mortality_prediction/multimodal_mimic4_demo.py b/examples/mortality_prediction/multimodal_mimic4_demo.py index 3c9b4cab4..bce91e63e 100644 --- a/examples/mortality_prediction/multimodal_mimic4_demo.py +++ b/examples/mortality_prediction/multimodal_mimic4_demo.py @@ -75,14 +75,6 @@ def get_directory_size(path: str | Path) -> int: return total -def ensure_empty_dir(path: str | Path) -> None: - """Ensure 
directory exists and is empty.""" - p = Path(path) - if p.exists(): - shutil.rmtree(p) - p.mkdir(parents=True, exist_ok=True) - - def remove_dir(path: str | Path, retries: int = 3, delay: float = 1.0) -> None: """Remove a directory with retry logic.""" p = Path(path) @@ -558,7 +550,7 @@ def nested_indices_to_codes(nested_indices, processor): visit_info = lookup_icd_codes(visit_codes, "ICD10CM", max_display=10) if all(info["name"] == "[Unknown code]" for info in visit_info): visit_info = lookup_icd_codes(visit_codes, "ICD9CM", max_display=10) - + codes_preview = ", ".join(info['code'] for info in visit_info) if len(visit_codes) > 10: codes_preview += f" (+{len(visit_codes) - 10} more)" @@ -578,7 +570,7 @@ def nested_indices_to_codes(nested_indices, processor): if all(info["name"] in ["[Unknown code]", "[ICD10PROC unavailable]"] for info in visit_info): visit_info = lookup_icd_codes(visit_codes, "ICD9PROC", max_display=10) - + codes_preview = ", ".join(info['code'] for info in visit_info) if len(visit_codes) > 10: codes_preview += f" (+{len(visit_codes) - 10} more)" @@ -651,7 +643,7 @@ def nested_indices_to_codes(nested_indices, processor): # - lab_times: list of floats from raw processor lab_values_tensor = sample.get("lab_values") lab_times_raw = sample.get("lab_times", []) - + labs_data = None if lab_values_tensor is not None and len(lab_times_raw) > 0: # Convert tensor to list of lists for display functions @@ -870,7 +862,7 @@ def main(): from pyhealth.tasks import MultimodalMortalityPredictionMIMIC4 cache_root = Path(args.cache_dir) - + # Create cache directory name based on configuration cache_suffix = "_dev" if args.dev else "" if args.no_notes and args.no_cxr: @@ -881,7 +873,7 @@ def main(): cache_suffix += "_ehr_notes" else: cache_suffix += "_full" - + base_cache_dir = cache_root / f"base_dataset{cache_suffix}" # Initialize memory tracker @@ -950,9 +942,9 @@ def main(): task_process_s = time.time() - task_start total_s = time.time() - run_start 
peak_rss_bytes = tracker.peak_bytes() - + print(f" ✓ Task completed in {task_process_s:.2f}s", flush=True) - + # Get sample count first (faster, uses cached count if available) print(" Getting sample count...", flush=True) try: @@ -965,7 +957,7 @@ def main(): except Exception as e: print(f" ✗ Error getting sample count: {e}") num_samples = 0 - + # Calculate cache size (can be slow for large directories) if not args.skip_benchmark: print(" Calculating cache size...", flush=True) From 314a648aad27907f7a509591847e0a40a2c36e08 Mon Sep 17 00:00:00 2001 From: EricSchrock Date: Mon, 16 Feb 2026 22:12:27 +0000 Subject: [PATCH 16/16] Correct cache size calculations (the "workers_n" examples do not need to be corrected because they were already using the default task cache path) --- examples/benchmark_perf/benchmark_workers_1.py | 11 +++++------ examples/benchmark_perf/benchmark_workers_12.py | 11 +++++------ examples/benchmark_perf/benchmark_workers_4.py | 11 +++++------ 3 files changed, 15 insertions(+), 18 deletions(-) diff --git a/examples/benchmark_perf/benchmark_workers_1.py b/examples/benchmark_perf/benchmark_workers_1.py index c4fbcab5d..7549a1805 100644 --- a/examples/benchmark_perf/benchmark_workers_1.py +++ b/examples/benchmark_perf/benchmark_workers_1.py @@ -117,6 +117,7 @@ def main(): print("\n[1/2] Loading MIMIC-IV base dataset...") dataset_start = time.time() + base_cache_dir = f"{cache_root}/base_dataset" base_dataset = MIMIC4Dataset( ehr_root="/srv/local/data/physionet.org/files/mimiciv/2.2/", ehr_tables=[ @@ -127,7 +128,7 @@ def main(): "labevents", ], dev=dev, - cache_dir=f"{cache_root}/base_dataset", + cache_dir=base_cache_dir, ) dataset_time = time.time() - dataset_start @@ -147,12 +148,10 @@ def main(): # Measure cache sizes print("\n[3/3] Measuring cache sizes...") - base_cache_dir = f"{cache_root}/base_dataset" - task_cache_dir = f"{base_cache_dir}/tasks" - base_cache_size = get_directory_size(base_cache_dir) - task_cache_size = 
get_directory_size(task_cache_dir) - total_cache_size = base_cache_size + task_cache_size + total_cache_size = get_directory_size(base_cache_dir) + task_cache_size = get_directory_size(f"{base_cache_dir}/tasks") + base_cache_size = total_cache_size - task_cache_size print(f"✓ Base dataset cache: {format_size(base_cache_size)}") print(f"✓ Task samples cache: {format_size(task_cache_size)}") diff --git a/examples/benchmark_perf/benchmark_workers_12.py b/examples/benchmark_perf/benchmark_workers_12.py index 4f41796c0..207522d97 100644 --- a/examples/benchmark_perf/benchmark_workers_12.py +++ b/examples/benchmark_perf/benchmark_workers_12.py @@ -117,6 +117,7 @@ def main(): print("\n[1/2] Loading MIMIC-IV base dataset...") dataset_start = time.time() + base_cache_dir = f"{cache_root}/base_dataset" base_dataset = MIMIC4Dataset( ehr_root="/srv/local/data/physionet.org/files/mimiciv/2.2/", ehr_tables=[ @@ -127,7 +128,7 @@ def main(): "labevents", ], dev=dev, - cache_dir=f"{cache_root}/base_dataset", + cache_dir=base_cache_dir, ) dataset_time = time.time() - dataset_start @@ -147,12 +148,10 @@ def main(): # Measure cache sizes print("\n[3/3] Measuring cache sizes...") - base_cache_dir = f"{cache_root}/base_dataset" - task_cache_dir = f"{base_cache_dir}/tasks" - base_cache_size = get_directory_size(base_cache_dir) - task_cache_size = get_directory_size(task_cache_dir) - total_cache_size = base_cache_size + task_cache_size + total_cache_size = get_directory_size(base_cache_dir) + task_cache_size = get_directory_size(f"{base_cache_dir}/tasks") + base_cache_size = total_cache_size - task_cache_size print(f"✓ Base dataset cache: {format_size(base_cache_size)}") print(f"✓ Task samples cache: {format_size(task_cache_size)}") diff --git a/examples/benchmark_perf/benchmark_workers_4.py b/examples/benchmark_perf/benchmark_workers_4.py index 82abca3f9..bcae29835 100644 --- a/examples/benchmark_perf/benchmark_workers_4.py +++ b/examples/benchmark_perf/benchmark_workers_4.py @@ -117,6 
+117,7 @@ def main(): print("\n[1/2] Loading MIMIC-IV base dataset...") dataset_start = time.time() + base_cache_dir = f"{cache_root}/base_dataset" base_dataset = MIMIC4Dataset( ehr_root="/srv/local/data/physionet.org/files/mimiciv/2.2/", ehr_tables=[ @@ -127,7 +128,7 @@ def main(): "labevents", ], dev=dev, - # cache_dir=f"{cache_root}/base_dataset", + cache_dir=base_cache_dir, ) dataset_time = time.time() - dataset_start @@ -147,12 +148,10 @@ def main(): # Measure cache sizes print("\n[3/3] Measuring cache sizes...") - base_cache_dir = f"{cache_root}/base_dataset" - task_cache_dir = f"{base_cache_dir}/tasks" - base_cache_size = get_directory_size(base_cache_dir) - task_cache_size = get_directory_size(task_cache_dir) - total_cache_size = base_cache_size + task_cache_size + total_cache_size = get_directory_size(base_cache_dir) + task_cache_size = get_directory_size(f"{base_cache_dir}/tasks") + base_cache_size = total_cache_size - task_cache_size print(f"✓ Base dataset cache: {format_size(base_cache_size)}") print(f"✓ Task samples cache: {format_size(task_cache_size)}")