Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 15 additions & 0 deletions src/diffusers/pipelines/pipeline_loading_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,19 @@
TRANSFORMERS_DUMMY_MODULES_FOLDER = "transformers.utils"
CONNECTED_PIPES_KEYS = ["prior"]

# Auxiliary (non-weight) files a transformers component saves next to its weights. Repos with a flat,
# transformers-style layout host a component's files at the repo root instead of in a subfolder, where the
# folder-based allow patterns of `DiffusionPipeline.download` would miss them. Root-hosted weights and
# `config.json` are matched by their own patterns, so only these auxiliary filenames need listing.
# Currently the set needed by DiffusionGemma — extend as new flat-layout pipelines require it.
TRANSFORMERS_COMPONENT_AUX_FILES = [
    # Bare filenames with no folder prefix. The `download` classmethod in `pipeline_utils.py` appends this
    # list to its `allow_patterns` as-is. `test_download_flat_transformers_style_repo` in
    # `tests/pipelines/test_pipelines.py` asserts on these same five names, so update both together.
    "chat_template.jinja",
    "generation_config.json",
    "processor_config.json",
    "tokenizer.json",
    "tokenizer_config.json",
]

logger = logging.get_logger(__name__)

LOADABLE_CLASSES = {
Expand Down Expand Up @@ -136,6 +149,8 @@ def is_safetensors_compatible(filenames, passed_components=None, folder_names=No
)

passed_components = passed_components or []
# only weight files matter for safetensors compatibility
filenames = filter_model_files(filenames)

Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think this is a bug here, but let me know. cc @DN6

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM 👍🏽

if folder_names:
filenames = {f for f in filenames if os.path.split(f)[0] in folder_names}

Expand Down
6 changes: 6 additions & 0 deletions src/diffusers/pipelines/pipeline_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,6 +82,7 @@
CONNECTED_PIPES_KEYS,
CUSTOM_PIPELINE_FILE_NAME,
LOADABLE_CLASSES,
TRANSFORMERS_COMPONENT_AUX_FILES,
_download_dduf_file,
_fetch_class_library_tuple,
_get_custom_components_and_folders,
Expand Down Expand Up @@ -1751,6 +1752,11 @@ def download(cls, pretrained_model_name, **kwargs) -> str | os.PathLike:
p for p in allow_patterns if not (len(p.split("/")) == 2 and p.split("/")[0] in passed_components)
]

# Repos with a flat, transformers-style layout host a component's files at the repo root instead of
# in a subfolder, where the folder-based allow patterns above miss its auxiliary files (root-hosted
# weights are already included via `model_filenames`, root `config.json` via `CONFIG_NAME`).
allow_patterns += TRANSFORMERS_COMPONENT_AUX_FILES

if pipeline_class._load_connected_pipes:
allow_patterns.append("README.md")

Expand Down
38 changes: 38 additions & 0 deletions tests/pipelines/test_pipeline_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -218,6 +218,44 @@ def test_diffusers_is_compatible_no_components_only_variants(self):
]
self.assertFalse(is_safetensors_compatible(filenames))

def test_diffusers_is_compatible_no_components_safetensors(self):
    """A single folder-less safetensors weight file must be reported as compatible."""
    root_only_weights = ["diffusion_pytorch_model.safetensors"]
    compatible = is_safetensors_compatible(root_only_weights)
    self.assertTrue(compatible)

def test_diffusers_is_compatible_no_components_safetensors_only_variants(self):
    """A single folder-less fp16 safetensors file is compatible when the fp16 variant is requested."""
    root_only_variant_weights = ["diffusion_pytorch_model.fp16.safetensors"]
    compatible = is_safetensors_compatible(root_only_variant_weights, variant="fp16")
    self.assertTrue(compatible)

def test_transformers_is_compatible_weightless_subfolder(self):
    """Flat transformers-style layout: root safetensors weights plus a subfolder without weights."""
    repo_listing = ["model.safetensors", "scheduler/scheduler_config.json"]
    compatible = is_safetensors_compatible(repo_listing)
    self.assertTrue(compatible)

def test_transformers_is_not_compatible_weightless_subfolder(self):
    """Same flat layout, but the root only carries `.bin` weights, so it is not safetensors compatible."""
    repo_listing = ["pytorch_model.bin", "scheduler/scheduler_config.json"]
    compatible = is_safetensors_compatible(repo_listing)
    self.assertFalse(compatible)

def test_transformers_is_compatible_sharded_root_weights(self):
    """Sharded transformers-style weights at the repo root (e.g. DiffusionGemma's
    model-00001-of-00011.safetensors) are safetensors compatible."""
    shard_files = [
        "model-00001-of-00002.safetensors",
        "model-00002-of-00002.safetensors",
    ]
    repo_listing = shard_files + [
        "model.safetensors.index.json",
        "scheduler/scheduler_config.json",
    ]
    compatible = is_safetensors_compatible(repo_listing)
    self.assertTrue(compatible)

def test_is_compatible_mixed_variants(self):
filenames = [
"unet/diffusion_pytorch_model.fp16.safetensors",
Expand Down
22 changes: 22 additions & 0 deletions tests/pipelines/test_pipelines.py
Original file line number Diff line number Diff line change
Expand Up @@ -2083,6 +2083,28 @@ def test_smart_download(self):
# is not downloaded, but all the expected ones
assert not os.path.isfile(os.path.join(snapshot_dir, "big_array.npy"))

def test_download_flat_transformers_style_repo(self):
    """Check `DiffusionPipeline.download` against a flat, transformers-style repo.

    In this layout the `model` and `processor` components keep their files at the repo root and only
    `scheduler/` has its own folder. The download must fetch the transformers auxiliary files found at
    the root while still skipping unrelated root files.
    """
    repo_id = "hf-internal-testing/tiny-flat-transformers-style-pipe"
    root_aux_files = (
        "tokenizer.json",
        "tokenizer_config.json",
        "processor_config.json",
        "chat_template.jinja",
        "generation_config.json",
    )

    with tempfile.TemporaryDirectory() as cache_root:
        snapshot = DiffusionPipeline.download(repo_id, cache_dir=cache_root, force_download=True)

        def downloaded(*relative_parts):
            # True when the given path, relative to the snapshot root, exists as a regular file.
            return os.path.isfile(os.path.join(snapshot, *relative_parts))

        assert downloaded("model.safetensors")
        assert downloaded(CONFIG_NAME)
        for aux_name in root_aux_files:
            assert downloaded(aux_name)
        assert downloaded("scheduler", SCHEDULER_CONFIG_NAME)
        # unrelated root files are still not downloaded
        assert not downloaded("big_array.npy")

def test_warning_unused_kwargs(self):
model_id = "hf-internal-testing/unet-pipeline-dummy"
logger = logging.get_logger("diffusers.pipelines")
Expand Down
Loading