diff --git a/src/diffusers/loaders/single_file_model.py b/src/diffusers/loaders/single_file_model.py
index c7bb2de4437a..c6d022821843 100644
--- a/src/diffusers/loaders/single_file_model.py
+++ b/src/diffusers/loaders/single_file_model.py
@@ -33,6 +33,7 @@
     convert_chroma_transformer_checkpoint_to_diffusers,
     convert_controlnet_checkpoint,
     convert_cosmos_transformer_checkpoint_to_diffusers,
+    convert_ernie_image_transformer_checkpoint_to_diffusers,
     convert_flux2_transformer_checkpoint_to_diffusers,
     convert_flux_transformer_checkpoint_to_diffusers,
     convert_hidream_transformer_to_diffusers,
@@ -114,6 +115,10 @@
         "checkpoint_mapping_fn": convert_chroma_transformer_checkpoint_to_diffusers,
         "default_subfolder": "transformer",
     },
+    "ErnieImageTransformer2DModel": {
+        "checkpoint_mapping_fn": convert_ernie_image_transformer_checkpoint_to_diffusers,
+        "default_subfolder": "transformer",
+    },
     "LTXVideoTransformer3DModel": {
         "checkpoint_mapping_fn": convert_ltx_transformer_checkpoint_to_diffusers,
         "default_subfolder": "transformer",
diff --git a/src/diffusers/loaders/single_file_utils.py b/src/diffusers/loaders/single_file_utils.py
index 98b9e8266506..acadb96120c3 100644
--- a/src/diffusers/loaders/single_file_utils.py
+++ b/src/diffusers/loaders/single_file_utils.py
@@ -4162,3 +4162,31 @@ def update_state_dict_inplace(state_dict, old_key: str, new_key: str) -> None:
         update_state_dict_inplace(converted_state_dict, key, new_key)
 
     return converted_state_dict
+
+
+def convert_ernie_image_transformer_checkpoint_to_diffusers(checkpoint, **kwargs):
+    """Rename the keys of an ERNIE-Image transformer checkpoint for loading into diffusers.
+
+    The only rename performed is dropping a leading `model.diffusion_model.` from a key. Keys that do not start
+    with that prefix are kept unchanged, and values are never modified.
+
+    Args:
+        checkpoint (`dict`):
+            State dict to convert. Every entry is popped from it, so it is empty when this function returns.
+        **kwargs:
+            Ignored.
+
+    Returns:
+        `dict`: A new dict holding every value of `checkpoint` under its renamed key.
+    """
+    prefix = "model.diffusion_model."
+    converted_state_dict = {}
+
+    # Iterate over a snapshot of the keys because entries are popped from `checkpoint` inside the loop.
+    for key in list(checkpoint.keys()):
+        # Strip the marker only when it leads the key; a substring replace would also rewrite unrelated keys
+        # that merely contain it.
+        new_key = key[len(prefix) :] if key.startswith(prefix) else key
+        converted_state_dict[new_key] = checkpoint.pop(key)
+
+    return converted_state_dict
diff --git a/src/diffusers/models/transformers/transformer_ernie_image.py b/src/diffusers/models/transformers/transformer_ernie_image.py
index 473fc1039dc8..9a3e3ab12ea1 100644
--- a/src/diffusers/models/transformers/transformer_ernie_image.py
+++ b/src/diffusers/models/transformers/transformer_ernie_image.py
@@ -25,7 +25,7 @@
 import torch.nn.functional as F
 
 from ...configuration_utils import ConfigMixin, register_to_config
-from ...loaders import PeftAdapterMixin
+from ...loaders import FromOriginalModelMixin, PeftAdapterMixin
 from ...utils import BaseOutput, logging
 from ..attention import AttentionModuleMixin
 from ..attention_dispatch import dispatch_attention_fn
@@ -289,7 +289,7 @@ def forward(self, x: torch.Tensor, conditioning: torch.Tensor) -> torch.Tensor:
         return x
 
 
-class ErnieImageTransformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin):
+class ErnieImageTransformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromOriginalModelMixin):
     _supports_gradient_checkpointing = True
     _repeated_blocks = ["ErnieImageSharedAdaLNBlock"]
 