From 944245451b9e4091b66c8b5d5304e408fd4a8051 Mon Sep 17 00:00:00 2001 From: Alexkkir Date: Tue, 24 Mar 2026 08:14:49 +0000 Subject: [PATCH] refactor: use defaultdict for _SET_ADAPTER_SCALE_FN_MAPPING --- src/diffusers/loaders/peft.py | 35 ++++++++--------------------------- 1 file changed, 8 insertions(+), 27 deletions(-) diff --git a/src/diffusers/loaders/peft.py b/src/diffusers/loaders/peft.py index a96542c2a50c..daa078bc25d5 100644 --- a/src/diffusers/loaders/peft.py +++ b/src/diffusers/loaders/peft.py @@ -15,6 +15,7 @@ import inspect import json import os +from collections import defaultdict from functools import partial from pathlib import Path from typing import Literal @@ -44,33 +45,13 @@ logger = logging.get_logger(__name__) -_SET_ADAPTER_SCALE_FN_MAPPING = { - "UNet2DConditionModel": _maybe_expand_lora_scales, - "UNetMotionModel": _maybe_expand_lora_scales, - "SD3Transformer2DModel": lambda model_cls, weights: weights, - "FluxTransformer2DModel": lambda model_cls, weights: weights, - "CogVideoXTransformer3DModel": lambda model_cls, weights: weights, - "ConsisIDTransformer3DModel": lambda model_cls, weights: weights, - "HeliosTransformer3DModel": lambda model_cls, weights: weights, - "MochiTransformer3DModel": lambda model_cls, weights: weights, - "HunyuanVideoTransformer3DModel": lambda model_cls, weights: weights, - "LTXVideoTransformer3DModel": lambda model_cls, weights: weights, - "SanaTransformer2DModel": lambda model_cls, weights: weights, - "AuraFlowTransformer2DModel": lambda model_cls, weights: weights, - "Lumina2Transformer2DModel": lambda model_cls, weights: weights, - "WanTransformer3DModel": lambda model_cls, weights: weights, - "CogView4Transformer2DModel": lambda model_cls, weights: weights, - "HiDreamImageTransformer2DModel": lambda model_cls, weights: weights, - "HunyuanVideoFramepackTransformer3DModel": lambda model_cls, weights: weights, - "WanVACETransformer3DModel": lambda model_cls, weights: weights, - "ChromaTransformer2DModel": lambda model_cls, weights: weights, - "ChronoEditTransformer3DModel": lambda model_cls, weights: weights, - "QwenImageTransformer2DModel": lambda model_cls, weights: weights, - "Flux2Transformer2DModel": lambda model_cls, weights: weights, - "ZImageTransformer2DModel": lambda model_cls, weights: weights, - "LTX2VideoTransformer3DModel": lambda model_cls, weights: weights, - "LTX2TextConnectors": lambda model_cls, weights: weights, -} +_SET_ADAPTER_SCALE_FN_MAPPING = defaultdict( + lambda: (lambda model_cls, weights: weights), + { + "UNet2DConditionModel": _maybe_expand_lora_scales, + "UNetMotionModel": _maybe_expand_lora_scales, + }, +) class PeftAdapterMixin: