diff --git a/src/diffusers/pipelines/pipeline_loading_utils.py b/src/diffusers/pipelines/pipeline_loading_utils.py index d695f5e7284d..90cbffc5b69d 100644 --- a/src/diffusers/pipelines/pipeline_loading_utils.py +++ b/src/diffusers/pipelines/pipeline_loading_utils.py @@ -69,6 +69,19 @@ TRANSFORMERS_DUMMY_MODULES_FOLDER = "transformers.utils" CONNECTED_PIPES_KEYS = ["prior"] +# Auxiliary (non-weight) files a transformers component saves next to its weights. Repos with a flat, +# transformers-style layout host a component's files at the repo root instead of in a subfolder, where the +# folder-based allow patterns of `DiffusionPipeline.download` would miss them. Root-hosted weights and +# `config.json` are matched by their own patterns, so only these auxiliary filenames need listing. +# Currently the set needed by DiffusionGemma — extend as new flat-layout pipelines require it. +TRANSFORMERS_COMPONENT_AUX_FILES = [ + "chat_template.jinja", + "generation_config.json", + "processor_config.json", + "tokenizer.json", + "tokenizer_config.json", +] + logger = logging.get_logger(__name__) LOADABLE_CLASSES = { @@ -136,6 +149,8 @@ def is_safetensors_compatible(filenames, passed_components=None, folder_names=No ) passed_components = passed_components or [] + # only weight files matter for safetensors compatibility + filenames = filter_model_files(filenames) if folder_names: filenames = {f for f in filenames if os.path.split(f)[0] in folder_names} diff --git a/src/diffusers/pipelines/pipeline_utils.py b/src/diffusers/pipelines/pipeline_utils.py index 1fa4db90d995..a3ef2260751f 100644 --- a/src/diffusers/pipelines/pipeline_utils.py +++ b/src/diffusers/pipelines/pipeline_utils.py @@ -82,6 +82,7 @@ CONNECTED_PIPES_KEYS, CUSTOM_PIPELINE_FILE_NAME, LOADABLE_CLASSES, + TRANSFORMERS_COMPONENT_AUX_FILES, _download_dduf_file, _fetch_class_library_tuple, _get_custom_components_and_folders, @@ -1751,6 +1752,11 @@ def download(cls, pretrained_model_name, **kwargs) -> str | os.PathLike: p for p in allow_patterns if not (len(p.split("/")) == 2 and p.split("/")[0] in passed_components) ] + # Repos with a flat, transformers-style layout host a component's files at the repo root instead of + # in a subfolder, where the folder-based allow patterns above miss its auxiliary files (root-hosted + # weights are already included via `model_filenames`, root `config.json` via `CONFIG_NAME`). + allow_patterns += TRANSFORMERS_COMPONENT_AUX_FILES + if pipeline_class._load_connected_pipes: allow_patterns.append("README.md") diff --git a/tests/pipelines/test_pipeline_utils.py b/tests/pipelines/test_pipeline_utils.py index 6d9e68197976..e0c73b96b8f9 100644 --- a/tests/pipelines/test_pipeline_utils.py +++ b/tests/pipelines/test_pipeline_utils.py @@ -218,6 +218,44 @@ def test_diffusers_is_compatible_no_components_only_variants(self): ] self.assertFalse(is_safetensors_compatible(filenames)) + def test_diffusers_is_compatible_no_components_safetensors(self): + filenames = [ + "diffusion_pytorch_model.safetensors", + ] + self.assertTrue(is_safetensors_compatible(filenames)) + + def test_diffusers_is_compatible_no_components_safetensors_only_variants(self): + filenames = [ + "diffusion_pytorch_model.fp16.safetensors", + ] + self.assertTrue(is_safetensors_compatible(filenames, variant="fp16")) + + def test_transformers_is_compatible_weightless_subfolder(self): + # transformers-style flat layout: transformers-named weights at the root + a weight-less subfolder + filenames = [ + "model.safetensors", + "scheduler/scheduler_config.json", + ] + self.assertTrue(is_safetensors_compatible(filenames)) + + def test_transformers_is_not_compatible_weightless_subfolder(self): + # same flat layout but only .bin weights at the root -> not safetensors compatible + filenames = [ + "pytorch_model.bin", + "scheduler/scheduler_config.json", + ] + self.assertFalse(is_safetensors_compatible(filenames)) + + def test_transformers_is_compatible_sharded_root_weights(self): + # sharded transformers-style weights at the repo root (e.g. DiffusionGemma's model-00001-of-00011.safetensors) + filenames = [ + "model-00001-of-00002.safetensors", + "model-00002-of-00002.safetensors", + "model.safetensors.index.json", + "scheduler/scheduler_config.json", + ] + self.assertTrue(is_safetensors_compatible(filenames)) + def test_is_compatible_mixed_variants(self): filenames = [ "unet/diffusion_pytorch_model.fp16.safetensors", diff --git a/tests/pipelines/test_pipelines.py b/tests/pipelines/test_pipelines.py index 1df2cfa569e7..5092a5f86b3b 100644 --- a/tests/pipelines/test_pipelines.py +++ b/tests/pipelines/test_pipelines.py @@ -2083,6 +2083,28 @@ def test_smart_download(self): # is not downloaded, but all the expected ones assert not os.path.isfile(os.path.join(snapshot_dir, "big_array.npy")) + def test_download_flat_transformers_style_repo(self): + # Repos with a flat, transformers-style layout host a component's files at the repo root instead of in a + # subfolder (here `model` and `processor`; only `scheduler/` has a folder). The download patterns must + # pick up the transformers auxiliary files at the root, while unrelated root files are still skipped. + model_id = "hf-internal-testing/tiny-flat-transformers-style-pipe" + with tempfile.TemporaryDirectory() as tmpdirname: + snapshot_dir = DiffusionPipeline.download(model_id, cache_dir=tmpdirname, force_download=True) + + assert os.path.isfile(os.path.join(snapshot_dir, "model.safetensors")) + assert os.path.isfile(os.path.join(snapshot_dir, CONFIG_NAME)) + for aux_file in [ + "tokenizer.json", + "tokenizer_config.json", + "processor_config.json", + "chat_template.jinja", + "generation_config.json", + ]: + assert os.path.isfile(os.path.join(snapshot_dir, aux_file)) + assert os.path.isfile(os.path.join(snapshot_dir, "scheduler", SCHEDULER_CONFIG_NAME)) + # unrelated root files are still not downloaded + assert not os.path.isfile(os.path.join(snapshot_dir, "big_array.npy")) + def test_warning_unused_kwargs(self): model_id = "hf-internal-testing/unet-pipeline-dummy" logger = logging.get_logger("diffusers.pipelines")