From 8cb0b7bdfe0e45222950611af184d17fa1b0f03d Mon Sep 17 00:00:00 2001 From: CalamitousFelicitousness Date: Tue, 24 Mar 2026 00:16:37 +0000 Subject: [PATCH 01/11] Add Flux2 LoKR adapter support with dual conversion paths MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - Custom lossless path: BFL LoKR keys → peft LoKrConfig (fuse-first QKV) - Generic lossy path: optional SVD conversion via peft.convert_to_lora - Fix alpha handling for lora_down/lora_up format checkpoints - Re-fuse LoRA keys when model QKV is fused from prior LoKR load --- .../loaders/lora_conversion_utils.py | 115 +++++++++++++++++- src/diffusers/loaders/lora_pipeline.py | 31 ++++- src/diffusers/loaders/peft.py | 111 +++++++++-------- src/diffusers/utils/peft_utils.py | 115 ++++++++++++++++++ 4 files changed, 314 insertions(+), 58 deletions(-) diff --git a/src/diffusers/loaders/lora_conversion_utils.py b/src/diffusers/loaders/lora_conversion_utils.py index 41948d205c89..b9c4501a9813 100644 --- a/src/diffusers/loaders/lora_conversion_utils.py +++ b/src/diffusers/loaders/lora_conversion_utils.py @@ -2331,6 +2331,18 @@ def _convert_non_diffusers_flux2_lora_to_diffusers(state_dict): temp_state_dict[new_key] = v original_state_dict = temp_state_dict + # Bake alpha/rank scaling into lora_A weights so .alpha keys are consumed. + # Matches the pattern used by _convert_kohya_flux_lora_to_diffusers for Flux1. + alpha_keys = [k for k in original_state_dict if k.endswith(".alpha")] + for alpha_key in alpha_keys: + alpha = original_state_dict.pop(alpha_key).item() + module_path = alpha_key[: -len(".alpha")] + lora_a_key = f"{module_path}.lora_A.weight" + if lora_a_key in original_state_dict: + rank = original_state_dict[lora_a_key].shape[0] + scale = alpha / rank + original_state_dict[lora_a_key] = original_state_dict[lora_a_key] * scale + num_double_layers = 0 num_single_layers = 0 for key in original_state_dict.keys(): @@ -2628,6 +2640,105 @@ def _convert_to_ai_toolkit_cat(sds_sd, ait_sd, sds_key, ait_keys, dims=None): return ait_sd +def _convert_non_diffusers_flux2_lokr_to_diffusers(state_dict): + """Convert non-diffusers Flux2 LoKR state dict (kohya/LyCORIS format) to peft-compatible diffusers format. + + Uses fuse-first QKV mapping: BFL fused `img_attn.qkv` maps to diffusers `attn.to_qkv` (created by + `fuse_projections()`), avoiding lossy Kronecker factor splitting. The caller must fuse the model's + QKV projections before injecting the adapter. + """ + converted_state_dict = {} + + prefix = "diffusion_model." + original_state_dict = {k[len(prefix) :] if k.startswith(prefix) else k: v for k, v in state_dict.items()} + + num_double_layers = 0 + num_single_layers = 0 + for key in original_state_dict: + if key.startswith("single_blocks."): + num_single_layers = max(num_single_layers, int(key.split(".")[1]) + 1) + elif key.startswith("double_blocks."): + num_double_layers = max(num_double_layers, int(key.split(".")[1]) + 1) + + lokr_suffixes = ("lokr_w1", "lokr_w1_a", "lokr_w1_b", "lokr_w2", "lokr_w2_a", "lokr_w2_b", "lokr_t2") + + def _remap_lokr_module(bfl_path, diff_path): + """Pop all lokr keys for a BFL module, bake alpha scaling, and store under diffusers path.""" + alpha_key = f"{bfl_path}.alpha" + alpha = original_state_dict.pop(alpha_key).item() if alpha_key in original_state_dict else None + + for suffix in lokr_suffixes: + src_key = f"{bfl_path}.{suffix}" + if src_key not in original_state_dict: + continue + + weight = original_state_dict.pop(src_key) + + # Bake alpha/rank scaling into the first w1 tensor encountered for this module. + # After baking, peft's config uses alpha=r so its runtime scaling is 1.0. + if alpha is not None and suffix in ("lokr_w1", "lokr_w1_a"): + w2a_key = f"{bfl_path}.lokr_w2_a" + w1a_key = f"{bfl_path}.lokr_w1_a" + if w2a_key in original_state_dict: + r_eff = original_state_dict[w2a_key].shape[1] + elif w1a_key in original_state_dict: + r_eff = original_state_dict[w1a_key].shape[1] + else: + r_eff = alpha + scale = alpha / r_eff + weight = weight * scale + alpha = None # only bake once per module + + converted_state_dict[f"{diff_path}.{suffix}"] = weight + + # --- Single blocks --- + for sl in range(num_single_layers): + _remap_lokr_module(f"single_blocks.{sl}.linear1", f"single_transformer_blocks.{sl}.attn.to_qkv_mlp_proj") + _remap_lokr_module(f"single_blocks.{sl}.linear2", f"single_transformer_blocks.{sl}.attn.to_out") + + # --- Double blocks --- + for dl in range(num_double_layers): + tb = f"transformer_blocks.{dl}" + db = f"double_blocks.{dl}" + + # QKV -> fused to_qkv / to_added_qkv (model must be fused before injection) + _remap_lokr_module(f"{db}.img_attn.qkv", f"{tb}.attn.to_qkv") + _remap_lokr_module(f"{db}.txt_attn.qkv", f"{tb}.attn.to_added_qkv") + + # Projections + _remap_lokr_module(f"{db}.img_attn.proj", f"{tb}.attn.to_out.0") + _remap_lokr_module(f"{db}.txt_attn.proj", f"{tb}.attn.to_add_out") + + # MLPs + _remap_lokr_module(f"{db}.img_mlp.0", f"{tb}.ff.linear_in") + _remap_lokr_module(f"{db}.img_mlp.2", f"{tb}.ff.linear_out") + _remap_lokr_module(f"{db}.txt_mlp.0", f"{tb}.ff_context.linear_in") + _remap_lokr_module(f"{db}.txt_mlp.2", f"{tb}.ff_context.linear_out") + + # --- Extra mappings (embedders, modulation, final layer) --- + extra_mappings = { + "img_in": "x_embedder", + "txt_in": "context_embedder", + "time_in.in_layer": "time_guidance_embed.timestep_embedder.linear_1", + "time_in.out_layer": "time_guidance_embed.timestep_embedder.linear_2", + "final_layer.linear": "proj_out", + "final_layer.adaLN_modulation.1": "norm_out.linear", + "single_stream_modulation.lin": "single_stream_modulation.linear", + "double_stream_modulation_img.lin": "double_stream_modulation_img.linear", + "double_stream_modulation_txt.lin": "double_stream_modulation_txt.linear", + } + for bfl_key, diff_key in extra_mappings.items(): + _remap_lokr_module(bfl_key, diff_key) + + if len(original_state_dict) > 0: + raise ValueError(f"`original_state_dict` should be empty at this point but has {original_state_dict.keys()=}.") + + for key in list(converted_state_dict.keys()): + converted_state_dict[f"transformer.{key}"] = converted_state_dict.pop(key) + + return converted_state_dict + + def _convert_non_diffusers_z_image_lora_to_diffusers(state_dict): """ Convert non-diffusers ZImage LoRA state dict to diffusers format. @@ -2785,14 +2896,14 @@ def get_alpha_scales(down_weight, alpha_key): base = k[: -len(lora_dot_down_key)] - # Skip combined "qkv" projection — individual to.q/k/v keys are also present. + # Skip combined "qkv" projection - individual to.q/k/v keys are also present. if base.endswith(".qkv"): state_dict.pop(k) state_dict.pop(k.replace(lora_dot_down_key, lora_dot_up_key), None) state_dict.pop(base + ".alpha", None) continue - # Skip bare "out.lora.*" — "to_out.0.lora.*" covers the same projection. + # Skip bare "out.lora.*" - "to_out.0.lora.*" covers the same projection. if re.search(r"\.out$", base) and ".to_out" not in base: state_dict.pop(k) state_dict.pop(k.replace(lora_dot_down_key, lora_dot_up_key), None) diff --git a/src/diffusers/loaders/lora_pipeline.py b/src/diffusers/loaders/lora_pipeline.py index 6ec23389ac08..6c8bba726ff1 100644 --- a/src/diffusers/loaders/lora_pipeline.py +++ b/src/diffusers/loaders/lora_pipeline.py @@ -46,6 +46,7 @@ _convert_kohya_flux2_lora_to_diffusers, _convert_kohya_flux_lora_to_diffusers, _convert_musubi_wan_lora_to_diffusers, + _convert_non_diffusers_flux2_lokr_to_diffusers, _convert_non_diffusers_flux2_lora_to_diffusers, _convert_non_diffusers_hidream_lora_to_diffusers, _convert_non_diffusers_lora_to_diffusers, @@ -57,6 +58,7 @@ _convert_non_diffusers_z_image_lora_to_diffusers, _convert_xlabs_flux_lora_to_diffusers, _maybe_map_sgm_blocks_to_diffusers, + _refuse_flux2_lora_state_dict, ) @@ -5687,12 +5689,18 @@ def lora_state_dict( is_ai_toolkit = any(k.startswith("diffusion_model.") for k in state_dict) if is_ai_toolkit: - state_dict = _convert_non_diffusers_flux2_lora_to_diffusers(state_dict) + is_lokr = any("lokr_" in k for k in state_dict) + if is_lokr: + state_dict = _convert_non_diffusers_flux2_lokr_to_diffusers(state_dict) + if metadata is None: + metadata = {} + metadata["is_lokr"] = "true" + else: + state_dict = _convert_non_diffusers_flux2_lora_to_diffusers(state_dict) out = (state_dict, metadata) if return_lora_metadata else state_dict return out - # Copied from diffusers.loaders.lora_pipeline.CogVideoXLoraLoaderMixin.load_lora_weights def load_lora_weights( self, pretrained_model_name_or_path_or_dict: str | dict[str, torch.Tensor], @@ -5720,13 +5728,26 @@ def load_lora_weights( kwargs["return_lora_metadata"] = True state_dict, metadata = self.lora_state_dict(pretrained_model_name_or_path_or_dict, **kwargs) - is_correct_format = all("lora" in key for key in state_dict.keys()) + is_correct_format = all("lora" in key or "lokr" in key for key in state_dict.keys()) if not is_correct_format: - raise ValueError("Invalid LoRA checkpoint. Make sure all LoRA param names contain `'lora'` substring.") + raise ValueError("Invalid LoRA/LoKR checkpoint. Make sure all param names contain `'lora'` or `'lokr'`.") + + # For LoKR adapters, fuse QKV projections so peft can target the fused modules directly. + is_lokr = metadata is not None and metadata.get("is_lokr") == "true" + transformer = getattr(self, self.transformer_name) if not hasattr(self, "transformer") else self.transformer + if is_lokr: + transformer.fuse_qkv_projections() + elif ( + hasattr(transformer, "transformer_blocks") + and len(transformer.transformer_blocks) > 0 + and getattr(transformer.transformer_blocks[0].attn, "fused_projections", False) + ): + # Model QKV is fused but LoRA targets separate Q/K/V - re-fuse the keys to match. + state_dict = _refuse_flux2_lora_state_dict(state_dict) self.load_lora_into_transformer( state_dict, - transformer=getattr(self, self.transformer_name) if not hasattr(self, "transformer") else self.transformer, + transformer=transformer, adapter_name=adapter_name, metadata=metadata, _pipeline=self, diff --git a/src/diffusers/loaders/peft.py b/src/diffusers/loaders/peft.py index daa078bc25d5..1ec304f24944 100644 --- a/src/diffusers/loaders/peft.py +++ b/src/diffusers/loaders/peft.py @@ -38,7 +38,7 @@ set_adapter_layers, set_weights_and_activate_adapters, ) -from ..utils.peft_utils import _create_lora_config, _maybe_warn_for_unhandled_keys +from ..utils.peft_utils import _create_lokr_config, _create_lora_config, _maybe_warn_for_unhandled_keys from .lora_base import _fetch_state_dict, _func_optionally_disable_offloading from .unet_loader_utils import _maybe_expand_lora_scales @@ -46,7 +46,7 @@ logger = logging.get_logger(__name__) _SET_ADAPTER_SCALE_FN_MAPPING = defaultdict( - lambda: (lambda model_cls, weights: weights), + lambda: lambda model_cls, weights: weights, { "UNet2DConditionModel": _maybe_expand_lora_scales, "UNetMotionModel": _maybe_expand_lora_scales, @@ -213,56 +213,65 @@ def load_lora_adapter( "Please choose an existing adapter name or set `hotswap=False` to prevent hotswapping." ) - # check with first key if is not in peft format - first_key = next(iter(state_dict.keys())) - if "lora_A" not in first_key: - state_dict = convert_unet_state_dict_to_peft(state_dict) - - # Control LoRA from SAI is different from BFL Control LoRA - # https://huggingface.co/stabilityai/control-lora - # https://huggingface.co/comfyanonymous/ControlNet-v1-1_fp16_safetensors - is_sai_sd_control_lora = "lora_controlnet" in state_dict - if is_sai_sd_control_lora: - state_dict = convert_sai_sd_control_lora_state_dict_to_peft(state_dict) - - rank = {} - for key, val in state_dict.items(): - # Cannot figure out rank from lora layers that don't have at least 2 dimensions. - # Bias layers in LoRA only have a single dimension - if "lora_B" in key and val.ndim > 1: - # Check out https://github.com/huggingface/peft/pull/2419 for the `^` symbol. - # We may run into some ambiguous configuration values when a model has module - # names, sharing a common prefix (`proj_out.weight` and `blocks.transformer.proj_out.weight`, - # for example) and they have different LoRA ranks. - rank[f"^{key}"] = val.shape[1] - - if network_alphas is not None and len(network_alphas) >= 1: - alpha_keys = [k for k in network_alphas.keys() if k.startswith(f"{prefix}.")] - network_alphas = { - k.removeprefix(f"{prefix}."): v for k, v in network_alphas.items() if k in alpha_keys - } - - # adapter_name - if adapter_name is None: - adapter_name = get_adapter_name(self) - - # create LoraConfig - lora_config = _create_lora_config( - state_dict, - network_alphas, - metadata, - rank, - model_state_dict=self.state_dict(), - adapter_name=adapter_name, - ) + # Detect whether this is a LoKR adapter (Kronecker product, not low-rank) + is_lokr = any("lokr_" in k for k in state_dict) + + if is_lokr: + if adapter_name is None: + adapter_name = get_adapter_name(self) + lora_config = _create_lokr_config(state_dict) + is_sai_sd_control_lora = False + else: + # check with first key if is not in peft format + first_key = next(iter(state_dict.keys())) + if "lora_A" not in first_key: + state_dict = convert_unet_state_dict_to_peft(state_dict) + + # Control LoRA from SAI is different from BFL Control LoRA + # https://huggingface.co/stabilityai/control-lora + # https://huggingface.co/comfyanonymous/ControlNet-v1-1_fp16_safetensors + is_sai_sd_control_lora = "lora_controlnet" in state_dict + if is_sai_sd_control_lora: + state_dict = convert_sai_sd_control_lora_state_dict_to_peft(state_dict) + + rank = {} + for key, val in state_dict.items(): + # Cannot figure out rank from lora layers that don't have at least 2 dimensions. + # Bias layers in LoRA only have a single dimension + if "lora_B" in key and val.ndim > 1: + # Check out https://github.com/huggingface/peft/pull/2419 for the `^` symbol. + # We may run into some ambiguous configuration values when a model has module + # names, sharing a common prefix (`proj_out.weight` and `blocks.transformer.proj_out.weight`, + # for example) and they have different LoRA ranks. + rank[f"^{key}"] = val.shape[1] + + if network_alphas is not None and len(network_alphas) >= 1: + alpha_keys = [k for k in network_alphas.keys() if k.startswith(f"{prefix}.")] + network_alphas = { + k.removeprefix(f"{prefix}."): v for k, v in network_alphas.items() if k in alpha_keys + } + + # adapter_name + if adapter_name is None: + adapter_name = get_adapter_name(self) + + # create LoraConfig + lora_config = _create_lora_config( + state_dict, + network_alphas, + metadata, + rank, + model_state_dict=self.state_dict(), + adapter_name=adapter_name, + ) - # Adjust LoRA config for Control LoRA - if is_sai_sd_control_lora: - lora_config.lora_alpha = lora_config.r - lora_config.alpha_pattern = lora_config.rank_pattern - lora_config.bias = "all" - lora_config.modules_to_save = lora_config.exclude_modules - lora_config.exclude_modules = None + # Adjust LoRA config for Control LoRA + if is_sai_sd_control_lora: + lora_config.lora_alpha = lora_config.r + lora_config.alpha_pattern = lora_config.rank_pattern + lora_config.bias = "all" + lora_config.modules_to_save = lora_config.exclude_modules + lora_config.exclude_modules = None # None: ) +def _create_lokr_config(state_dict): + """Create a peft LoKrConfig from a converted LoKR state dict. + + Infers rank, decompose_both, decompose_factor, and target_modules from the state dict key names + and tensor shapes. Alpha scaling is assumed to be already baked into the weights, so config + alpha = r (scaling = 1.0). + + Peft determines w2 decomposition via ``r < max(out_k, in_n) / 2``. We must set per-module rank + values that reproduce the same decomposition pattern as the checkpoint. For modules with full + (non-decomposed) lokr_w2, we set rank = max(lokr_w2.shape) so that peft also creates a full w2. + """ + from peft import LoKrConfig + + # Infer decompose_both from presence of lokr_w1_a keys + decompose_both = any("lokr_w1_a" in k for k in state_dict) + + # Infer decompose_factor from lokr_w1 shapes. + # With a fixed factor (e.g., 4), all w1 shapes are (factor, factor). + # With factor=-1 (near-sqrt), w1 shapes vary per module based on dimension. + w1_shapes = set() + for key, val in state_dict.items(): + if "lokr_w1" in key and "lokr_w1_a" not in key and "lokr_w1_b" not in key and val.ndim == 2: + w1_shapes.add(val.shape[0]) + if len(w1_shapes) == 1: + # All w1 have the same first dimension - this is the decompose_factor + decompose_factor = w1_shapes.pop() + else: + # Shapes vary - near-sqrt factorization was used + decompose_factor = -1 + + # Extract target modules and their decomposition state + lokr_suffixes = {"lokr_w1", "lokr_w1_a", "lokr_w1_b", "lokr_w2", "lokr_w2_a", "lokr_w2_b", "lokr_t2"} + target_modules = set() + for key in state_dict: + for suffix in lokr_suffixes: + if f".{suffix}" in key: + target_modules.add(key.split(f".{suffix}")[0]) + break + + # Build per-module rank dict that ensures peft creates matching decomposition + rank_dict = {} + for key, val in state_dict.items(): + if "lokr_w2_a" in key and val.ndim > 1: + # Decomposed w2: rank = inner dimension of w2_a + module_name = key.split(".lokr_w2_a")[0] + rank_dict[module_name] = val.shape[1] + elif "lokr_w2" in key and "lokr_w2_a" not in key and "lokr_w2_b" not in key and val.ndim > 1: + # Full w2 matrix: set rank high enough so peft also creates full w2. + # Peft uses full w2 when r >= max(out_k, in_n) / 2, where (out_k, in_n) = lokr_w2.shape. + module_name = key.split(".lokr_w2")[0] + if module_name not in rank_dict: + rank_dict[module_name] = max(val.shape) + + # Also extract rank from w1_a if w2 info is missing + for key, val in state_dict.items(): + if "lokr_w1_a" in key and val.ndim > 1: + module_name = key.split(".lokr_w1_a")[0] + if module_name not in rank_dict: + rank_dict[module_name] = val.shape[1] + + # Determine default rank (most common) and per-module rank pattern + if rank_dict: + import collections + + r = collections.Counter(rank_dict.values()).most_common()[0][0] + rank_pattern = {k: v for k, v in rank_dict.items() if v != r} + else: + r = 1 + rank_pattern = {} + + lokr_config_kwargs = { + "r": r, + "alpha": r, # alpha baked into weights, so runtime scaling = alpha/r = 1.0 + "target_modules": list(target_modules), + "rank_pattern": rank_pattern, + "alpha_pattern": dict(rank_pattern), # keep alpha=r per module + "decompose_both": decompose_both, + "decompose_factor": decompose_factor, + } + + try: + return LoKrConfig(**lokr_config_kwargs) + except TypeError as e: + raise TypeError("`LoKrConfig` class could not be instantiated.") from e + + +def _convert_adapter_to_lora(model, rank, adapter_name="default"): + """Convert a loaded non-LoRA peft adapter (e.g., LoKR) to LoRA via truncated SVD. + + Wraps ``peft.convert_to_lora`` which materializes each adapter layer's delta weight + and decomposes it as ``U @ diag(S) @ V ≈ lora_B @ lora_A``. The conversion is lossy: + higher ``rank`` preserves more fidelity at the cost of larger LoRA matrices. + + Args: + model: ``nn.Module`` with a peft adapter already injected. + rank: ``int`` for a fixed LoRA rank, or ``float`` in (0, 1] as an energy threshold + (picks the smallest rank capturing that fraction of singular value energy). + adapter_name: Name of the adapter to convert. + + Returns: + Tuple of ``(LoraConfig, state_dict)`` for the converted LoRA adapter. + + Raises: + ImportError: If peft does not provide ``convert_to_lora`` (requires peft >= 0.19.0). + """ + try: + from peft import convert_to_lora + except ImportError: + raise ImportError( + "`peft.convert_to_lora` is required for lossy LoKR-to-LoRA conversion. " + "Install peft >= 0.19.0 or from source: pip install git+https://github.com/huggingface/peft.git" + ) + return convert_to_lora(model, rank, adapter_name=adapter_name) + + def _create_lora_config( state_dict, network_alphas, metadata, rank_pattern_dict, is_unet=True, model_state_dict=None, adapter_name=None ): From b5958e633eec8ee784e57f9b409280f440085a0e Mon Sep 17 00:00:00 2001 From: CalamitousFelicitousness Date: Thu, 26 Mar 2026 20:57:08 +0000 Subject: [PATCH 02/11] Add comprehensive Flux2 LoKR adapter support with dual conversion paths - BFL format: remap keys + split fused QKV via Kronecker re-factorization (Van Loan) - LyCORIS format: decode underscore-encoded paths to diffusers module names - Diffusers native format: add transformer. prefix and bake alpha - Generic lossy path: _convert_adapter_to_lora utility wrapping peft.convert_to_lora - Fix alpha handling for lora_down/lora_up format checkpoints --- benchmark_lokr.py | 159 ++++++++++++ .../loaders/lora_conversion_utils.py | 231 +++++++++++++++++- src/diffusers/loaders/lora_pipeline.py | 33 ++- 3 files changed, 391 insertions(+), 32 deletions(-) create mode 100644 benchmark_lokr.py diff --git a/benchmark_lokr.py b/benchmark_lokr.py new file mode 100644 index 000000000000..bbde88a3378b --- /dev/null +++ b/benchmark_lokr.py @@ -0,0 +1,159 @@ +"""Benchmark: Lossless LoKR vs Lossy LoRA-via-SVD on Flux2 Klein 9B. + +Generates images using both conversion paths for visual comparison. +Uses bf16 with CPU offload. + +Usage: + python benchmark_lokr.py + python benchmark_lokr.py --lokr-path "puttmorbidly233/lora" --lokr-name "klein_snofs_v1_2.safetensors" + python benchmark_lokr.py --prompt "a portrait in besch art style" --ranks 32 64 128 +""" + +import argparse +import gc +import os +import time + +import torch +from diffusers import Flux2KleinPipeline +from peft import convert_to_lora + +MODEL_ID = "black-forest-labs/FLUX.2-klein-9B" +DEFAULT_LOKR_PATH = "gattaplayer/besch-flux2-klein-9b-lokr-lion-3e-6-bs2-ga2-v02" +OUTPUT_DIR = "benchmark_output" + + +def load_pipeline(): + """Load Flux2 Klein 9B in bf16 with model CPU offload.""" + pipe = Flux2KleinPipeline.from_pretrained(MODEL_ID, torch_dtype=torch.bfloat16) + pipe.enable_model_cpu_offload() + return pipe + + +def generate(pipe, prompt, seed, num_steps=4, guidance_scale=1.0): + """Generate a single image with fixed seed for reproducibility.""" + generator = torch.Generator(device="cpu").manual_seed(seed) + image = pipe( + prompt=prompt, + num_inference_steps=num_steps, + guidance_scale=guidance_scale, + generator=generator, + height=1024, + width=1024, + ).images[0] + return image + + +def benchmark_lossless(pipe, prompt, seed, lokr_path, lokr_name): + """Path A: Load LoKR natively (lossless).""" + print("\n=== Path A: Lossless LoKR ===") + t0 = time.time() + kwargs = {"weight_name": lokr_name} if lokr_name else {} + pipe.load_lora_weights(lokr_path, **kwargs) + print(f" Loaded in {time.time() - t0:.1f}s") + + t0 = time.time() + image = generate(pipe, prompt, seed) + print(f" Generated in {time.time() - t0:.1f}s") + + pipe.unload_lora_weights() + return image + + +def benchmark_lossy(pipe, prompt, seed, rank, lokr_path, lokr_name): + """Path B: Load LoKR, convert to LoRA via SVD (lossy).""" + print(f"\n=== Path B: Lossy LoRA via SVD (rank={rank}) ===") + t0 = time.time() + kwargs = {"weight_name": lokr_name} if lokr_name else {} + pipe.load_lora_weights(lokr_path, **kwargs) + load_time = time.time() - t0 + + # Detect the actual adapter name assigned by peft + adapter_name = next(iter(pipe.transformer.peft_config.keys())) + print(f" Adapter name: {adapter_name}") + + t0 = time.time() + lora_config, lora_sd = convert_to_lora(pipe.transformer, rank, adapter_name=adapter_name, progressbar=True) + convert_time = time.time() - t0 + print(f" Loaded LoKR in {load_time:.1f}s, converted to LoRA in {convert_time:.1f}s") + + # Replace LoKR adapter with converted LoRA + from peft import inject_adapter_in_model, set_peft_model_state_dict + + pipe.transformer.delete_adapters(adapter_name) + inject_adapter_in_model(lora_config, pipe.transformer, adapter_name=adapter_name) + set_peft_model_state_dict(pipe.transformer, lora_sd, adapter_name=adapter_name) + + t0 = time.time() + image = generate(pipe, prompt, seed) + print(f" Generated in {time.time() - t0:.1f}s") + + pipe.unload_lora_weights() + return image + + +def benchmark_baseline(pipe, prompt, seed): + """Baseline: No adapter.""" + print("\n=== Baseline: No adapter ===") + t0 = time.time() + image = generate(pipe, prompt, seed) + print(f" Generated in {time.time() - t0:.1f}s") + return image + + +def main(): + parser = argparse.ArgumentParser(description="Benchmark LoKR vs LoRA-via-SVD") + parser.add_argument("--prompt", default="a portrait painting in besch art style") + parser.add_argument("--lokr-path", default=DEFAULT_LOKR_PATH, help="HF repo or local path to LoKR checkpoint") + parser.add_argument("--lokr-name", default=None, help="Filename within HF repo (if multi-file)") + parser.add_argument("--seed", type=int, default=42) + parser.add_argument("--ranks", type=int, nargs="+", default=[32, 64, 128]) + parser.add_argument("--skip-baseline", action="store_true") + parser.add_argument("--skip-lossy", action="store_true") + args = parser.parse_args() + + os.makedirs(OUTPUT_DIR, exist_ok=True) + + print(f"Model: {MODEL_ID}") + print(f"LoKR: {args.lokr_path}" + (f" ({args.lokr_name})" if args.lokr_name else "")) + print(f"Prompt: {args.prompt}") + print(f"Seed: {args.seed}") + if not args.skip_lossy: + print(f"SVD ranks to test: {args.ranks}") + + print("\nLoading pipeline (bf16, model CPU offload)...") + pipe = load_pipeline() + + # Baseline + if not args.skip_baseline: + img = benchmark_baseline(pipe, args.prompt, args.seed) + path = os.path.join(OUTPUT_DIR, "baseline.png") + img.save(path) + print(f" Saved: {path}") + + # Path A: Lossless LoKR + img = benchmark_lossless(pipe, args.prompt, args.seed, args.lokr_path, args.lokr_name) + path = os.path.join(OUTPUT_DIR, "lokr_lossless.png") + img.save(path) + print(f" Saved: {path}") + + gc.collect() + torch.cuda.empty_cache() + + # Path B: Lossy LoRA via SVD at various ranks + if not args.skip_lossy: + for rank in args.ranks: + img = benchmark_lossy(pipe, args.prompt, args.seed, rank, args.lokr_path, args.lokr_name) + path = os.path.join(OUTPUT_DIR, f"lora_svd_rank{rank}.png") + img.save(path) + print(f" Saved: {path}") + + gc.collect() + torch.cuda.empty_cache() + + print(f"\nAll results saved to {OUTPUT_DIR}/") + print("Compare: baseline.png vs lokr_lossless.png vs lora_svd_rank*.png") + + +if __name__ == "__main__": + main() diff --git a/src/diffusers/loaders/lora_conversion_utils.py b/src/diffusers/loaders/lora_conversion_utils.py index b9c4501a9813..8c96ce565b88 100644 --- a/src/diffusers/loaders/lora_conversion_utils.py +++ b/src/diffusers/loaders/lora_conversion_utils.py @@ -2640,12 +2640,59 @@ def _convert_to_ai_toolkit_cat(sds_sd, ait_sd, sds_key, ait_keys, dims=None): return ait_sd +def _nearest_kronecker_product(matrix, m1, n1, m2, n2): + """Find the nearest rank-1 Kronecker product approximation (Van Loan & Pitsianis). + + Given matrix M of shape (m1*m2, n1*n2), finds w1 (m1, n1) and w2 (m2, n2) + minimizing ||M - kron(w1, w2)||_F via rank-1 SVD of a rearranged matrix. + """ + # Rearrange M into R of shape (m1*n1, m2*n2) + # R[i*n1+j, k*n2+l] = M[i*m2+k, j*n2+l] + R = matrix.reshape(m1, m2, n1, n2).permute(0, 2, 1, 3).reshape(m1 * n1, m2 * n2) + # Rank-1 SVD + U, S, Vh = torch.linalg.svd(R, full_matrices=False) + sigma = S[0] + sqrt_s = torch.sqrt(sigma) + w1 = sqrt_s * U[:, 0].reshape(m1, n1) + w2 = sqrt_s * Vh[0].reshape(m2, n2) + return w1, w2 + + +def _split_lokr_qkv(w1, w2, target_keys, factor): + """Split fused LoKR QKV factors into separate per-projection Kronecker factors. + + Materializes kron(w1, w2), chunks along dim=0, and re-factorizes each chunk + as a rank-1 Kronecker product using the Van Loan algorithm. + + Args: + w1: First Kronecker factor, shape (f, f) where f = decompose_factor. + w2: Second Kronecker factor, shape (out_total/f, in_total/f). + target_keys: List of target projection names (e.g., ["to_q", "to_k", "to_v"]). + factor: Kronecker decompose factor for the split chunks. + + Returns: + Dict mapping "{target_key}.lokr_w1" and "{target_key}.lokr_w2" to tensors. + """ + full_delta = torch.kron(w1.float(), w2.float()) + chunks = torch.chunk(full_delta, len(target_keys), dim=0) + + result = {} + for target_key, chunk in zip(target_keys, chunks): + rows, cols = chunk.shape + m1 = n1 = factor + m2 = rows // m1 + n2 = cols // n1 + new_w1, new_w2 = _nearest_kronecker_product(chunk, m1, n1, m2, n2) + result[f"{target_key}.lokr_w1"] = new_w1.to(w1.dtype) + result[f"{target_key}.lokr_w2"] = new_w2.to(w2.dtype) + return result + + def _convert_non_diffusers_flux2_lokr_to_diffusers(state_dict): - """Convert non-diffusers Flux2 LoKR state dict (kohya/LyCORIS format) to peft-compatible diffusers format. + """Convert BFL-format Flux2 LoKR state dict to peft-compatible diffusers format. - Uses fuse-first QKV mapping: BFL fused `img_attn.qkv` maps to diffusers `attn.to_qkv` (created by - `fuse_projections()`), avoiding lossy Kronecker factor splitting. The caller must fuse the model's - QKV projections before injecting the adapter. + Handles fused QKV by splitting via Kronecker re-factorization (Van Loan algorithm). + Non-QKV modules are remapped directly. Alpha scaling is baked into lokr_w1. """ converted_state_dict = {} @@ -2662,8 +2709,25 @@ def _convert_non_diffusers_flux2_lokr_to_diffusers(state_dict): lokr_suffixes = ("lokr_w1", "lokr_w1_a", "lokr_w1_b", "lokr_w2", "lokr_w2_a", "lokr_w2_b", "lokr_t2") + def _pop_alpha_and_bake(bfl_path, w1_weight): + """Pop alpha for a module and bake scaling into w1. Returns scaled w1.""" + alpha_key = f"{bfl_path}.alpha" + if alpha_key not in original_state_dict: + return w1_weight + alpha = original_state_dict.pop(alpha_key).item() + w2a_key = f"{bfl_path}.lokr_w2_a" + w1a_key = f"{bfl_path}.lokr_w1_a" + if w2a_key in original_state_dict: + r_eff = original_state_dict[w2a_key].shape[1] + elif w1a_key in original_state_dict: + r_eff = original_state_dict[w1a_key].shape[1] + else: + r_eff = alpha + return w1_weight * (alpha / r_eff) + def _remap_lokr_module(bfl_path, diff_path): - """Pop all lokr keys for a BFL module, bake alpha scaling, and store under diffusers path.""" + """Pop all LoKR keys for a BFL module, bake alpha, and store under diffusers path.""" + # Pop alpha separately (consumed by first w1 tensor) alpha_key = f"{bfl_path}.alpha" alpha = original_state_dict.pop(alpha_key).item() if alpha_key in original_state_dict else None @@ -2674,8 +2738,7 @@ def _remap_lokr_module(bfl_path, diff_path): weight = original_state_dict.pop(src_key) - # Bake alpha/rank scaling into the first w1 tensor encountered for this module. - # After baking, peft's config uses alpha=r so its runtime scaling is 1.0. + # Bake alpha/rank scaling into the first w1 tensor for this module. if alpha is not None and suffix in ("lokr_w1", "lokr_w1_a"): w2a_key = f"{bfl_path}.lokr_w2_a" w1a_key = f"{bfl_path}.lokr_w1_a" @@ -2685,12 +2748,41 @@ def _remap_lokr_module(bfl_path, diff_path): r_eff = original_state_dict[w1a_key].shape[1] else: r_eff = alpha - scale = alpha / r_eff - weight = weight * scale - alpha = None # only bake once per module + weight = weight * (alpha / r_eff) + alpha = None converted_state_dict[f"{diff_path}.{suffix}"] = weight + def _remap_lokr_qkv(bfl_path, target_keys): + """Pop fused QKV LoKR factors, split into separate projections via Kronecker re-factorization.""" + w1_key = f"{bfl_path}.lokr_w1" + w2_key = f"{bfl_path}.lokr_w2" + if w1_key not in original_state_dict or w2_key not in original_state_dict: + # Fall back to direct remap if decomposed factors (w1_a/w1_b) are used + _remap_lokr_module(bfl_path, target_keys[0].rsplit(".", 1)[0]) + return + + w1 = original_state_dict.pop(w1_key) + w2 = original_state_dict.pop(w2_key) + + # Bake alpha before splitting + alpha_key = f"{bfl_path}.alpha" + if alpha_key in original_state_dict: + alpha = original_state_dict.pop(alpha_key).item() + w2a_key = f"{bfl_path}.lokr_w2_a" + w1a_key = f"{bfl_path}.lokr_w1_a" + if w2a_key in original_state_dict: + r_eff = original_state_dict[w2a_key].shape[1] + elif w1a_key in original_state_dict: + r_eff = original_state_dict[w1a_key].shape[1] + else: + r_eff = alpha + w1 = w1 * (alpha / r_eff) + + factor = w1.shape[0] + split_result = _split_lokr_qkv(w1, w2, target_keys, factor) + converted_state_dict.update(split_result) + # --- Single blocks --- for sl in range(num_single_layers): _remap_lokr_module(f"single_blocks.{sl}.linear1", f"single_transformer_blocks.{sl}.attn.to_qkv_mlp_proj") @@ -2701,9 +2793,11 @@ def _remap_lokr_module(bfl_path, diff_path): tb = f"transformer_blocks.{dl}" db = f"double_blocks.{dl}" - # QKV -> fused to_qkv / to_added_qkv (model must be fused before injection) - _remap_lokr_module(f"{db}.img_attn.qkv", f"{tb}.attn.to_qkv") - _remap_lokr_module(f"{db}.txt_attn.qkv", f"{tb}.attn.to_added_qkv") + # Split fused QKV into separate Q/K/V via Kronecker re-factorization + _remap_lokr_qkv(f"{db}.img_attn.qkv", [f"{tb}.attn.to_q", f"{tb}.attn.to_k", f"{tb}.attn.to_v"]) + _remap_lokr_qkv( + f"{db}.txt_attn.qkv", [f"{tb}.attn.add_q_proj", f"{tb}.attn.add_k_proj", f"{tb}.attn.add_v_proj"] + ) # Projections _remap_lokr_module(f"{db}.img_attn.proj", f"{tb}.attn.to_out.0") @@ -2739,6 +2833,117 @@ def _remap_lokr_module(bfl_path, diff_path): return converted_state_dict +# Mapping from LyCORIS underscore-encoded sub-paths to dotted diffusers module paths +_LYCORIS_SUBPATH_MAP = { + "attn_to_q": "attn.to_q", + "attn_to_k": "attn.to_k", + "attn_to_v": "attn.to_v", + "attn_to_out_0": "attn.to_out.0", + "attn_to_add_out": "attn.to_add_out", + "attn_add_q_proj": "attn.add_q_proj", + "attn_add_k_proj": "attn.add_k_proj", + "attn_add_v_proj": "attn.add_v_proj", + "attn_to_qkv_mlp_proj": "attn.to_qkv_mlp_proj", + "attn_to_out": "attn.to_out", + "ff_context_linear_in": "ff_context.linear_in", + "ff_context_linear_out": "ff_context.linear_out", + "ff_linear_in": "ff.linear_in", + "ff_linear_out": "ff.linear_out", +} + + +def _bake_lokr_alpha(state_dict): + """Consume .alpha keys by baking alpha/rank scaling into lokr_w1 weights in-place.""" + lokr_w1_suffixes = (".lokr_w1", ".lokr_w1_a") + alpha_keys = [k for k in state_dict if k.endswith(".alpha")] + + for alpha_key in alpha_keys: + alpha = state_dict.pop(alpha_key).item() + module_path = alpha_key[: -len(".alpha")] + + # Find the w1 tensor to bake into + for w1_suffix in lokr_w1_suffixes: + w1_key = f"{module_path}{w1_suffix}" + if w1_key in state_dict: + # Determine effective rank + w2a_key = f"{module_path}.lokr_w2_a" + w1a_key = f"{module_path}.lokr_w1_a" + if w2a_key in state_dict: + r_eff = state_dict[w2a_key].shape[1] + elif w1a_key in state_dict: + r_eff = state_dict[w1a_key].shape[1] + else: + r_eff = alpha + state_dict[w1_key] = state_dict[w1_key] * (alpha / r_eff) + break + + +def _convert_lycoris_flux2_lokr_to_diffusers(state_dict): + """Convert LyCORIS underscore-format Flux2 LoKR state dict to peft-compatible diffusers format. + + LyCORIS keys use underscore-encoded paths (e.g., lycoris_transformer_blocks_0_attn_to_q.lokr_w1). + Decodes these to dotted diffusers paths using a known sub-path lookup table. + """ + import re + + converted_state_dict = {} + original_state_dict = dict(state_dict) + + _bake_lokr_alpha(original_state_dict) + + lycoris_pattern = re.compile(r"^lycoris_((?:single_)?transformer_blocks)_(\d+)_(.+)$") + + for key in list(original_state_dict.keys()): + # Split key into module_path and lokr suffix + parts = key.rsplit(".", 1) + if len(parts) != 2: + continue + module_encoded, suffix = parts + + match = lycoris_pattern.match(module_encoded) + if not match: + continue + + container, block_idx, sub_path = match.groups() + + # Decode sub-path using lookup table (try longest match first) + diff_sub_path = None + for lycoris_sub, diff_sub in sorted(_LYCORIS_SUBPATH_MAP.items(), key=lambda x: -len(x[0])): + if sub_path == lycoris_sub: + diff_sub_path = diff_sub + break + + if diff_sub_path is None: + continue + + diff_key = f"transformer.{container}.{block_idx}.{diff_sub_path}.{suffix}" + converted_state_dict[diff_key] = original_state_dict.pop(key) + + if len(original_state_dict) > 0: + logger.warning(f"Unconverted LyCORIS LoKR keys: {list(original_state_dict.keys())}") + + return converted_state_dict + + +def _convert_diffusers_flux2_lokr_to_peft(state_dict): + """Convert diffusers-native Flux2 LoKR state dict by adding transformer. prefix and baking alpha. + + Diffusers-native keys already use dotted module paths matching the model structure. + Only alpha baking and the transformer. prefix are needed. + """ + original_state_dict = dict(state_dict) + _bake_lokr_alpha(original_state_dict) + + converted_state_dict = {} + for key, val in original_state_dict.items(): + if key.startswith("transformer."): + converted_state_dict[key] = val + else: + converted_state_dict[f"transformer.{key}"] = val + + return converted_state_dict + + def _convert_non_diffusers_z_image_lora_to_diffusers(state_dict): """ Convert non-diffusers ZImage LoRA state dict to diffusers format. diff --git a/src/diffusers/loaders/lora_pipeline.py b/src/diffusers/loaders/lora_pipeline.py index 6c8bba726ff1..847b07319e3f 100644 --- a/src/diffusers/loaders/lora_pipeline.py +++ b/src/diffusers/loaders/lora_pipeline.py @@ -41,10 +41,12 @@ ) from .lora_conversion_utils import ( _convert_bfl_flux_control_lora_to_diffusers, + _convert_diffusers_flux2_lokr_to_peft, _convert_fal_kontext_lora_to_diffusers, _convert_hunyuan_video_lora_to_diffusers, _convert_kohya_flux2_lora_to_diffusers, _convert_kohya_flux_lora_to_diffusers, + _convert_lycoris_flux2_lokr_to_diffusers, _convert_musubi_wan_lora_to_diffusers, _convert_non_diffusers_flux2_lokr_to_diffusers, _convert_non_diffusers_flux2_lora_to_diffusers, @@ -58,7 +60,6 @@ _convert_non_diffusers_z_image_lora_to_diffusers, _convert_xlabs_flux_lora_to_diffusers, _maybe_map_sgm_blocks_to_diffusers, - _refuse_flux2_lora_state_dict, ) @@ -5687,15 +5688,20 @@ def lora_state_dict( if is_peft_format: state_dict = {k.replace("base_model.model.", "diffusion_model."): v for k, v in state_dict.items()} - is_ai_toolkit = any(k.startswith("diffusion_model.") for k in state_dict) - if is_ai_toolkit: - is_lokr = any("lokr_" in k for k in state_dict) - if is_lokr: + is_lokr = any("lokr_" in k for k in state_dict) + if is_lokr: + if any(k.startswith("diffusion_model.") for k in state_dict): state_dict = _convert_non_diffusers_flux2_lokr_to_diffusers(state_dict) - if metadata is None: - metadata = {} - metadata["is_lokr"] = "true" + elif any(k.startswith("lycoris_") for k in state_dict): + state_dict = _convert_lycoris_flux2_lokr_to_diffusers(state_dict) else: + state_dict = _convert_diffusers_flux2_lokr_to_peft(state_dict) + if metadata is None: + metadata = {} + metadata["is_lokr"] = "true" + else: + is_ai_toolkit = any(k.startswith("diffusion_model.") for k in state_dict) + if is_ai_toolkit: state_dict = _convert_non_diffusers_flux2_lora_to_diffusers(state_dict) out = (state_dict, metadata) if return_lora_metadata else state_dict @@ -5732,18 +5738,7 @@ def load_lora_weights( if not is_correct_format: raise ValueError("Invalid LoRA/LoKR checkpoint. Make sure all param names contain `'lora'` or `'lokr'`.") - # For LoKR adapters, fuse QKV projections so peft can target the fused modules directly. - is_lokr = metadata is not None and metadata.get("is_lokr") == "true" transformer = getattr(self, self.transformer_name) if not hasattr(self, "transformer") else self.transformer - if is_lokr: - transformer.fuse_qkv_projections() - elif ( - hasattr(transformer, "transformer_blocks") - and len(transformer.transformer_blocks) > 0 - and getattr(transformer.transformer_blocks[0].attn, "fused_projections", False) - ): - # Model QKV is fused but LoRA targets separate Q/K/V - re-fuse the keys to match. - state_dict = _refuse_flux2_lora_state_dict(state_dict) self.load_lora_into_transformer( state_dict, From 0920939e6d31d625e165319169362e5bec80e814 Mon Sep 17 00:00:00 2001 From: CalamitousFelicitousness Date: Thu, 26 Mar 2026 22:12:51 +0000 Subject: [PATCH 03/11] Add three-tier LoKR quality comparison (fuse-first, Kronecker split, SVD) - Add fuse_qkv parameter to BFL LoKR converter for lossless fuse-first path - Thread fuse_qkv through lora_pipeline.py (lora_state_dict -> load_lora_weights) - Fuse model QKV projections before adapter injection when fuse_qkv=True - Update benchmark script with --tiers, --no-offload flags for all three paths --- benchmark_lokr.py | 129 ++++++++++++------ .../loaders/lora_conversion_utils.py | 26 ++-- src/diffusers/loaders/lora_pipeline.py | 9 +- 3 files changed, 110 insertions(+), 54 deletions(-) diff --git a/benchmark_lokr.py b/benchmark_lokr.py index bbde88a3378b..6d3a16da0574 100644 --- a/benchmark_lokr.py +++ b/benchmark_lokr.py @@ -1,12 +1,20 @@ -"""Benchmark: Lossless LoKR vs Lossy LoRA-via-SVD on Flux2 Klein 9B. +"""Benchmark: Three-tier LoKR quality comparison on Flux2 Klein 9B. + +Tier 1 - Fuse-first (lossless): Fuse model QKV, map BFL LoKR directly. Exact. +Tier 2 - Kronecker split (default): Split fused QKV via Van Loan re-factorization. Slight loss. +Tier 3 - SVD to LoRA (fully lossy): Convert entire LoKR to LoRA via peft.convert_to_lora. + +Tiers 1+2 only apply to BFL-format LoKR (fused QKV). LyCORIS and diffusers-native +formats already have separate Q/K/V and only run the default path. -Generates images using both conversion paths for visual comparison. Uses bf16 with CPU offload. Usage: python benchmark_lokr.py python benchmark_lokr.py --lokr-path "puttmorbidly233/lora" --lokr-name "klein_snofs_v1_2.safetensors" python benchmark_lokr.py --prompt "a portrait in besch art style" --ranks 32 64 128 + python benchmark_lokr.py --tiers 1 2 # skip SVD tier + python benchmark_lokr.py --tiers 2 3 # skip fuse-first tier """ import argparse @@ -15,18 +23,22 @@ import time import torch + from diffusers import Flux2KleinPipeline -from peft import convert_to_lora + MODEL_ID = "black-forest-labs/FLUX.2-klein-9B" DEFAULT_LOKR_PATH = "gattaplayer/besch-flux2-klein-9b-lokr-lion-3e-6-bs2-ga2-v02" OUTPUT_DIR = "benchmark_output" -def load_pipeline(): - """Load Flux2 Klein 9B in bf16 with model CPU offload.""" +def load_pipeline(no_offload=False): + """Load Flux2 Klein 9B in bf16.""" pipe = Flux2KleinPipeline.from_pretrained(MODEL_ID, torch_dtype=torch.bfloat16) - pipe.enable_model_cpu_offload() + if no_offload: + pipe = pipe.to("cuda") + else: + pipe.enable_model_cpu_offload() return pipe @@ -44,9 +56,34 @@ def generate(pipe, prompt, seed, num_steps=4, guidance_scale=1.0): return image -def benchmark_lossless(pipe, prompt, seed, lokr_path, lokr_name): - """Path A: Load LoKR natively (lossless).""" - print("\n=== Path A: Lossless LoKR ===") +def benchmark_baseline(pipe, prompt, seed): + """Baseline: No adapter.""" + print("\n=== Baseline: No adapter ===") + t0 = time.time() + image = generate(pipe, prompt, seed) + print(f" Generated in {time.time() - t0:.1f}s") + return image + + +def benchmark_tier1_fuse_first(pipe, prompt, seed, lokr_path, lokr_name): + """Tier 1: Fuse model QKV, then load BFL LoKR directly (lossless).""" + print("\n=== Tier 1: Fuse-first LoKR (lossless) ===") + t0 = time.time() + kwargs = {"weight_name": lokr_name} if lokr_name else {} + pipe.load_lora_weights(lokr_path, fuse_qkv=True, **kwargs) + print(f" Loaded in {time.time() - t0:.1f}s") + + t0 = time.time() + image = generate(pipe, prompt, seed) + print(f" Generated in {time.time() - t0:.1f}s") + + pipe.unload_lora_weights() + return image + + +def benchmark_tier2_kronecker_split(pipe, prompt, seed, lokr_path, lokr_name): + """Tier 2: Split fused QKV via Kronecker re-factorization (default path).""" + print("\n=== Tier 2: Kronecker split LoKR (default) ===") t0 = time.time() kwargs = {"weight_name": lokr_name} if lokr_name else {} pipe.load_lora_weights(lokr_path, **kwargs) @@ -60,15 +97,16 @@ def benchmark_lossless(pipe, prompt, seed, lokr_path, lokr_name): return image -def benchmark_lossy(pipe, prompt, seed, rank, lokr_path, lokr_name): - """Path B: Load LoKR, convert to LoRA via SVD (lossy).""" - print(f"\n=== Path B: Lossy LoRA via SVD (rank={rank}) ===") +def benchmark_tier3_svd(pipe, prompt, seed, rank, lokr_path, lokr_name): + """Tier 3: Convert LoKR to LoRA via SVD (fully lossy).""" + from peft import convert_to_lora, inject_adapter_in_model, set_peft_model_state_dict + + print(f"\n=== Tier 3: SVD to LoRA (rank={rank}) ===") t0 = time.time() kwargs = {"weight_name": lokr_name} if lokr_name else {} pipe.load_lora_weights(lokr_path, **kwargs) load_time = time.time() - t0 - # Detect the actual adapter name assigned by peft adapter_name = next(iter(pipe.transformer.peft_config.keys())) print(f" Adapter name: {adapter_name}") @@ -77,9 +115,6 @@ def benchmark_lossy(pipe, prompt, seed, rank, lokr_path, lokr_name): convert_time = time.time() - t0 print(f" Loaded LoKR in {load_time:.1f}s, converted to LoRA in {convert_time:.1f}s") - # Replace LoKR adapter with converted LoRA - from peft import inject_adapter_in_model, set_peft_model_state_dict - pipe.transformer.delete_adapters(adapter_name) inject_adapter_in_model(lora_config, pipe.transformer, adapter_name=adapter_name) set_peft_model_state_dict(pipe.transformer, lora_sd, adapter_name=adapter_name) @@ -92,24 +127,18 @@ def benchmark_lossy(pipe, prompt, seed, rank, lokr_path, lokr_name): return image -def benchmark_baseline(pipe, prompt, seed): - """Baseline: No adapter.""" - print("\n=== Baseline: No adapter ===") - t0 = time.time() - image = generate(pipe, prompt, seed) - print(f" Generated in {time.time() - t0:.1f}s") - return image - - def main(): - parser = argparse.ArgumentParser(description="Benchmark LoKR vs LoRA-via-SVD") + parser = argparse.ArgumentParser(description="Benchmark LoKR quality tiers") parser.add_argument("--prompt", default="a portrait painting in besch art style") parser.add_argument("--lokr-path", default=DEFAULT_LOKR_PATH, help="HF repo or local path to LoKR checkpoint") parser.add_argument("--lokr-name", default=None, help="Filename within HF repo (if multi-file)") parser.add_argument("--seed", type=int, default=42) - parser.add_argument("--ranks", type=int, nargs="+", default=[32, 64, 128]) + parser.add_argument( + "--tiers", type=int, nargs="+", default=[1, 2, 3], help="Tiers to run (1=fuse, 2=kronecker, 3=svd)" + ) + parser.add_argument("--ranks", type=int, nargs="+", default=[32, 64, 128], help="SVD ranks for tier 3") parser.add_argument("--skip-baseline", action="store_true") - parser.add_argument("--skip-lossy", action="store_true") + parser.add_argument("--no-offload", action="store_true", help="Keep model on GPU instead of CPU offload") args = parser.parse_args() os.makedirs(OUTPUT_DIR, exist_ok=True) @@ -118,11 +147,13 @@ def main(): print(f"LoKR: {args.lokr_path}" + (f" ({args.lokr_name})" if args.lokr_name else "")) print(f"Prompt: {args.prompt}") print(f"Seed: {args.seed}") - if not args.skip_lossy: - print(f"SVD ranks to test: {args.ranks}") + print(f"Tiers: {args.tiers}") + if 3 in args.tiers: + print(f"SVD ranks: {args.ranks}") - print("\nLoading pipeline (bf16, model CPU offload)...") - pipe = load_pipeline() + mode = "on GPU" if args.no_offload else "with CPU offload" + print(f"\nLoading pipeline (bf16, {mode})...") + pipe = load_pipeline(no_offload=args.no_offload) # Baseline if not args.skip_baseline: @@ -131,28 +162,36 @@ def main(): img.save(path) print(f" Saved: {path}") - # Path A: Lossless LoKR - img = benchmark_lossless(pipe, args.prompt, args.seed, args.lokr_path, args.lokr_name) - path = os.path.join(OUTPUT_DIR, "lokr_lossless.png") - img.save(path) - print(f" Saved: {path}") + # Tier 1: Fuse-first (lossless, BFL only) + if 1 in args.tiers: + img = benchmark_tier1_fuse_first(pipe, args.prompt, args.seed, args.lokr_path, args.lokr_name) + path = os.path.join(OUTPUT_DIR, "tier1_fuse_lossless.png") + img.save(path) + print(f" Saved: {path}") + gc.collect() + torch.cuda.empty_cache() - gc.collect() - torch.cuda.empty_cache() + # Tier 2: Kronecker split (default) + if 2 in args.tiers: + img = benchmark_tier2_kronecker_split(pipe, args.prompt, args.seed, args.lokr_path, args.lokr_name) + path = os.path.join(OUTPUT_DIR, "tier2_kronecker.png") + img.save(path) + print(f" Saved: {path}") + gc.collect() + torch.cuda.empty_cache() - # Path B: Lossy LoRA via SVD at various ranks - if not args.skip_lossy: + # Tier 3: SVD to LoRA at various ranks + if 3 in args.tiers: for rank in args.ranks: - img = benchmark_lossy(pipe, args.prompt, args.seed, rank, args.lokr_path, args.lokr_name) - path = os.path.join(OUTPUT_DIR, f"lora_svd_rank{rank}.png") + img = benchmark_tier3_svd(pipe, args.prompt, args.seed, rank, args.lokr_path, args.lokr_name) + path = os.path.join(OUTPUT_DIR, f"tier3_svd_rank{rank}.png") img.save(path) print(f" Saved: {path}") - gc.collect() torch.cuda.empty_cache() print(f"\nAll results saved to {OUTPUT_DIR}/") - print("Compare: baseline.png vs lokr_lossless.png vs lora_svd_rank*.png") + print("Compare: baseline.png vs tier1_fuse_lossless.png vs tier2_kronecker.png vs tier3_svd_rank*.png") if __name__ == "__main__": diff --git a/src/diffusers/loaders/lora_conversion_utils.py b/src/diffusers/loaders/lora_conversion_utils.py index 8c96ce565b88..a76429ee09a8 100644 --- a/src/diffusers/loaders/lora_conversion_utils.py +++ b/src/diffusers/loaders/lora_conversion_utils.py @@ -2688,11 +2688,15 @@ def _split_lokr_qkv(w1, w2, target_keys, factor): return result -def _convert_non_diffusers_flux2_lokr_to_diffusers(state_dict): +def _convert_non_diffusers_flux2_lokr_to_diffusers(state_dict, fuse_qkv=False): """Convert BFL-format Flux2 LoKR state dict to peft-compatible diffusers format. - Handles fused QKV by splitting via Kronecker re-factorization (Van Loan algorithm). - Non-QKV modules are remapped directly. Alpha scaling is baked into lokr_w1. + Args: + state_dict: BFL-format LoKR state dict with ``diffusion_model.`` prefix. + fuse_qkv: If True, map fused QKV directly to ``to_qkv``/``to_added_qkv`` targets + (lossless, but requires the model's QKV to be fused before injection). + If False (default), split fused QKV into separate Q/K/V via Kronecker + re-factorization (slightly lossy, no model fusion needed). """ converted_state_dict = {} @@ -2793,11 +2797,17 @@ def _remap_lokr_qkv(bfl_path, target_keys): tb = f"transformer_blocks.{dl}" db = f"double_blocks.{dl}" - # Split fused QKV into separate Q/K/V via Kronecker re-factorization - _remap_lokr_qkv(f"{db}.img_attn.qkv", [f"{tb}.attn.to_q", f"{tb}.attn.to_k", f"{tb}.attn.to_v"]) - _remap_lokr_qkv( - f"{db}.txt_attn.qkv", [f"{tb}.attn.add_q_proj", f"{tb}.attn.add_k_proj", f"{tb}.attn.add_v_proj"] - ) + if fuse_qkv: + # Lossless: map directly to fused targets (caller must fuse model QKV first) + _remap_lokr_module(f"{db}.img_attn.qkv", f"{tb}.attn.to_qkv") + _remap_lokr_module(f"{db}.txt_attn.qkv", f"{tb}.attn.to_added_qkv") + else: + # Split fused QKV into separate Q/K/V via Kronecker re-factorization + _remap_lokr_qkv(f"{db}.img_attn.qkv", [f"{tb}.attn.to_q", f"{tb}.attn.to_k", f"{tb}.attn.to_v"]) + _remap_lokr_qkv( + f"{db}.txt_attn.qkv", + [f"{tb}.attn.add_q_proj", f"{tb}.attn.add_k_proj", f"{tb}.attn.add_v_proj"], + ) # Projections _remap_lokr_module(f"{db}.img_attn.proj", f"{tb}.attn.to_out.0") diff --git a/src/diffusers/loaders/lora_pipeline.py b/src/diffusers/loaders/lora_pipeline.py index 847b07319e3f..913242195996 100644 --- a/src/diffusers/loaders/lora_pipeline.py +++ b/src/diffusers/loaders/lora_pipeline.py @@ -5648,6 +5648,7 @@ def lora_state_dict( weight_name = kwargs.pop("weight_name", None) use_safetensors = kwargs.pop("use_safetensors", None) return_lora_metadata = kwargs.pop("return_lora_metadata", False) + fuse_qkv = kwargs.pop("fuse_qkv", False) allow_pickle = False if use_safetensors is None: @@ -5691,7 +5692,7 @@ def lora_state_dict( is_lokr = any("lokr_" in k for k in state_dict) if is_lokr: if any(k.startswith("diffusion_model.") for k in state_dict): - state_dict = _convert_non_diffusers_flux2_lokr_to_diffusers(state_dict) + state_dict = _convert_non_diffusers_flux2_lokr_to_diffusers(state_dict, fuse_qkv=fuse_qkv) elif any(k.startswith("lycoris_") for k in state_dict): state_dict = _convert_lycoris_flux2_lokr_to_diffusers(state_dict) else: @@ -5699,6 +5700,8 @@ def lora_state_dict( if metadata is None: metadata = {} metadata["is_lokr"] = "true" + if fuse_qkv: + metadata["fuse_qkv"] = "true" else: is_ai_toolkit = any(k.startswith("diffusion_model.") for k in state_dict) if is_ai_toolkit: @@ -5740,6 +5743,10 @@ def load_lora_weights( transformer = getattr(self, self.transformer_name) if not hasattr(self, "transformer") else self.transformer + # Fuse model QKV projections before injection if requested (lossless path for BFL LoKR) + if metadata and metadata.get("fuse_qkv") == "true": + transformer.fuse_qkv_projections() + self.load_lora_into_transformer( state_dict, transformer=transformer, From 16c274b0f3c9692f1deec80852d9ba31b076b2fb Mon Sep 17 00:00:00 2001 From: CalamitousFelicitousness Date: Thu, 26 Mar 2026 22:42:03 +0000 Subject: [PATCH 04/11] Fix fuse_qkv only applying to BFL-format LoKR fuse_qkv=True with LyCORIS or diffusers-native checkpoints would fuse the model QKV then fail injection (adapter targets separate Q/K/V modules that no longer exist). Now only set fuse_qkv metadata for BFL format. --- benchmark_lokr.py | 3 ++- src/diffusers/loaders/lora_pipeline.py | 6 ++++-- 2 files changed, 6 insertions(+), 3 deletions(-) diff --git a/benchmark_lokr.py b/benchmark_lokr.py index 6d3a16da0574..46ebc85ce4aa 100644 --- a/benchmark_lokr.py +++ b/benchmark_lokr.py @@ -162,8 +162,9 @@ def main(): img.save(path) print(f" Saved: {path}") - # Tier 1: Fuse-first (lossless, BFL only) + # Tier 1: Fuse-first (lossless, BFL format only - identical to tier 2 for other formats) if 1 in args.tiers: + print("\n Note: Tier 1 only differs from tier 2 for BFL-format LoKR (fused QKV).") img = benchmark_tier1_fuse_first(pipe, args.prompt, args.seed, args.lokr_path, args.lokr_name) path = os.path.join(OUTPUT_DIR, "tier1_fuse_lossless.png") img.save(path) diff --git a/src/diffusers/loaders/lora_pipeline.py b/src/diffusers/loaders/lora_pipeline.py index 913242195996..383ae88c3c5f 100644 --- a/src/diffusers/loaders/lora_pipeline.py +++ b/src/diffusers/loaders/lora_pipeline.py @@ -5691,7 +5691,8 @@ def lora_state_dict( is_lokr = any("lokr_" in k for k in state_dict) if is_lokr: - if any(k.startswith("diffusion_model.") for k in state_dict): + is_bfl_format = any(k.startswith("diffusion_model.") for k in state_dict) + if is_bfl_format: state_dict = _convert_non_diffusers_flux2_lokr_to_diffusers(state_dict, fuse_qkv=fuse_qkv) elif any(k.startswith("lycoris_") for k in state_dict): state_dict = _convert_lycoris_flux2_lokr_to_diffusers(state_dict) @@ -5700,7 +5701,8 @@ def lora_state_dict( if metadata is None: metadata = {} metadata["is_lokr"] = "true" - if fuse_qkv: + # Only fuse model QKV for BFL format (which has fused QKV keys to map 1:1) + if fuse_qkv and is_bfl_format: metadata["fuse_qkv"] = "true" else: is_ai_toolkit = any(k.startswith("diffusion_model.") for k in state_dict) From 3ee2721d6796998b18e19df3a43aff873e8be8b9 Mon Sep 17 00:00:00 2001 From: CalamitousFelicitousness Date: Thu, 26 Mar 2026 23:00:05 +0000 Subject: [PATCH 05/11] Add weight-space error analysis for Kronecker split vs lossless Compares materialized kron(w1, w2) from fuse-first path against reconstructed cat(kron(w1_q, w2_q), ...) from Van Loan split. Reports per-module and aggregate relative Frobenius norm error. No model loading needed - runs on checkpoint state dict only. --- benchmark_lokr.py | 121 ++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 121 insertions(+) diff --git a/benchmark_lokr.py b/benchmark_lokr.py index 46ebc85ce4aa..ce9230a98c24 100644 --- a/benchmark_lokr.py +++ b/benchmark_lokr.py @@ -15,6 +15,7 @@ python benchmark_lokr.py --prompt "a portrait in besch art style" --ranks 32 64 128 python benchmark_lokr.py --tiers 1 2 # skip SVD tier python benchmark_lokr.py --tiers 2 3 # skip fuse-first tier + python benchmark_lokr.py --weight-space # weight-space error analysis only (no image generation) """ import argparse @@ -56,6 +57,113 @@ def generate(pipe, prompt, seed, num_steps=4, guidance_scale=1.0): return image +# --------------------------------------------------------------------------- +# Weight-space error analysis +# --------------------------------------------------------------------------- + + +def load_raw_state_dict(lokr_path, lokr_name): + """Download/load a LoKR checkpoint and return the raw state dict.""" + from huggingface_hub import hf_hub_download + from safetensors.torch import load_file + + if os.path.isfile(lokr_path): + return load_file(lokr_path) + + if os.path.isdir(lokr_path): + path = os.path.join(lokr_path, lokr_name) if lokr_name else lokr_path + return load_file(path) + + # HF repo + path = hf_hub_download(lokr_path, filename=lokr_name or "pytorch_lora_weights.safetensors") + return load_file(path) + + +def weight_space_analysis(lokr_path, lokr_name): + """Compare tier 1 (lossless) vs tier 2 (Kronecker split) in weight space. + + For each fused QKV module, materializes the exact delta from the fuse-first path + and the reconstructed delta from the Kronecker split path, then reports the + relative Frobenius norm error. + """ + from diffusers.loaders.lora_conversion_utils import _convert_non_diffusers_flux2_lokr_to_diffusers + + raw_sd = load_raw_state_dict(lokr_path, lokr_name) + + is_bfl = any(k.startswith("diffusion_model.") for k in raw_sd) + if not is_bfl: + print(" Checkpoint is not BFL format - no fused QKV to compare.") + print(" Tiers 1 and 2 produce identical results for this format.") + return + + # Convert both ways from the same raw state dict + sd_fused = _convert_non_diffusers_flux2_lokr_to_diffusers(dict(raw_sd), fuse_qkv=True) + sd_split = _convert_non_diffusers_flux2_lokr_to_diffusers(dict(raw_sd), fuse_qkv=False) + + # Find all fused QKV modules (to_qkv and to_added_qkv) + qkv_modules = {} + for key in sd_fused: + if ".to_qkv.lokr_w1" in key or ".to_added_qkv.lokr_w1" in key: + module_path = key.rsplit(".lokr_w1", 1)[0] + qkv_modules[module_path] = key + + print(f"\n Found {len(qkv_modules)} fused QKV modules to compare\n") + print(f" {'Module':<65} {'Rel Error':>12} {'Abs Error':>12} {'Orig Norm':>12}") + print(f" {'-' * 65} {'-' * 12} {'-' * 12} {'-' * 12}") + + errors = [] + for module_path in sorted(qkv_modules.keys()): + # Materialize exact delta from fused path + w1_f = sd_fused[f"{module_path}.lokr_w1"].float() + w2_f = sd_fused[f"{module_path}.lokr_w2"].float() + delta_exact = torch.kron(w1_f, w2_f) + + # Determine split target keys + if ".to_qkv" in module_path: + base = module_path.replace(".attn.to_qkv", "") + proj_keys = [f"{base}.attn.to_q", f"{base}.attn.to_k", f"{base}.attn.to_v"] + else: + base = module_path.replace(".attn.to_added_qkv", "") + proj_keys = [f"{base}.attn.add_q_proj", f"{base}.attn.add_k_proj", f"{base}.attn.add_v_proj"] + + # Materialize reconstructed delta from split path + chunks = [] + for proj in proj_keys: + w1_key = f"{proj}.lokr_w1" + w2_key = f"{proj}.lokr_w2" + if w1_key not in sd_split: + break + w1_s = sd_split[w1_key].float() + w2_s = sd_split[w2_key].float() + chunks.append(torch.kron(w1_s, w2_s)) + + if len(chunks) != 3: + print(f" {module_path:<65} {'SKIP':>12}") + continue + + delta_recon = torch.cat(chunks, dim=0) + + orig_norm = delta_exact.norm().item() + abs_err = (delta_exact - delta_recon).norm().item() + rel_err = abs_err / orig_norm if orig_norm > 0 else 0.0 + + errors.append(rel_err) + + short_name = module_path.replace("transformer.", "") + print(f" {short_name:<65} {rel_err:>11.6f}% {abs_err:>12.6f} {orig_norm:>12.4f}") + + if errors: + print(f"\n Aggregate over {len(errors)} QKV modules:") + print(f" Mean relative error: {sum(errors) / len(errors):.6f}%") + print(f" Max relative error: {max(errors):.6f}%") + print(f" Min relative error: {min(errors):.6f}%") + + +# --------------------------------------------------------------------------- +# Image generation benchmarks +# --------------------------------------------------------------------------- + + def benchmark_baseline(pipe, prompt, seed): """Baseline: No adapter.""" print("\n=== Baseline: No adapter ===") @@ -127,6 +235,11 @@ def benchmark_tier3_svd(pipe, prompt, seed, rank, lokr_path, lokr_name): return image +# --------------------------------------------------------------------------- +# Main +# --------------------------------------------------------------------------- + + def main(): parser = argparse.ArgumentParser(description="Benchmark LoKR quality tiers") parser.add_argument("--prompt", default="a portrait painting in besch art style") @@ -139,12 +252,20 @@ def main(): parser.add_argument("--ranks", type=int, nargs="+", default=[32, 64, 128], help="SVD ranks for tier 3") parser.add_argument("--skip-baseline", action="store_true") parser.add_argument("--no-offload", action="store_true", help="Keep model on GPU instead of CPU offload") + parser.add_argument("--weight-space", action="store_true", help="Run weight-space error analysis only (no images)") args = parser.parse_args() os.makedirs(OUTPUT_DIR, exist_ok=True) print(f"Model: {MODEL_ID}") print(f"LoKR: {args.lokr_path}" + (f" ({args.lokr_name})" if args.lokr_name else "")) + + # Weight-space analysis (no model needed) + if args.weight_space: + print("\n=== Weight-space error: Tier 1 (lossless) vs Tier 2 (Kronecker split) ===") + weight_space_analysis(args.lokr_path, args.lokr_name) + return + print(f"Prompt: {args.prompt}") print(f"Seed: {args.seed}") print(f"Tiers: {args.tiers}") From d239f2ffc003141d440848f50ec896fed629e507 Mon Sep 17 00:00:00 2001 From: CalamitousFelicitousness Date: Thu, 26 Mar 2026 23:28:21 +0000 Subject: [PATCH 06/11] Extend weight-space analysis to cover all three tiers Tier 1 vs 2 (Kronecker): lightweight, no model needed. Tier 1 vs 3 (SVD): loads model, runs peft.convert_to_lora, compares materialized LoKR deltas against LoRA deltas for all modules at each requested rank. --- benchmark_lokr.py | 204 +++++++++++++++++++++++++++++++++++++--------- 1 file changed, 167 insertions(+), 37 deletions(-) diff --git a/benchmark_lokr.py b/benchmark_lokr.py index ce9230a98c24..2bc894cbc098 100644 --- a/benchmark_lokr.py +++ b/benchmark_lokr.py @@ -79,12 +79,57 @@ def load_raw_state_dict(lokr_path, lokr_name): return load_file(path) -def weight_space_analysis(lokr_path, lokr_name): +def _materialize_lokr_delta(state_dict, module_path): + """Materialize the full delta weight from LoKR factors for a single module.""" + w1_key = f"{module_path}.lokr_w1" + w2_key = f"{module_path}.lokr_w2" + w1a_key = f"{module_path}.lokr_w1_a" + w1b_key = f"{module_path}.lokr_w1_b" + w2a_key = f"{module_path}.lokr_w2_a" + w2b_key = f"{module_path}.lokr_w2_b" + + # w1: full or decomposed + if w1_key in state_dict: + w1 = state_dict[w1_key].float() + elif w1a_key in state_dict and w1b_key in state_dict: + w1 = state_dict[w1a_key].float() @ state_dict[w1b_key].float() + else: + return None + + # w2: full or decomposed + if w2_key in state_dict: + w2 = state_dict[w2_key].float() + elif w2a_key in state_dict and w2b_key in state_dict: + w2 = state_dict[w2a_key].float() @ state_dict[w2b_key].float() + else: + return None + + return torch.kron(w1, w2) + + +def _print_error_table(title, results): + """Print a formatted error table and aggregate stats.""" + print(f"\n {title}\n") + print(f" {'Module':<60} {'Rel Error %':>12} {'Abs Error':>12} {'Orig Norm':>12}") + print(f" {'-' * 60} {'-' * 12} {'-' * 12} {'-' * 12}") + + errors = [] + for name, rel_err, abs_err, orig_norm in results: + errors.append(rel_err) + print(f" {name:<60} {rel_err:>11.6f}% {abs_err:>12.6f} {orig_norm:>12.4f}") + + if errors: + print(f"\n Aggregate over {len(errors)} modules:") + print(f" Mean relative error: {sum(errors) / len(errors):.6f}%") + print(f" Max relative error: {max(errors):.6f}%") + print(f" Min relative error: {min(errors):.6f}%") + + +def weight_space_kronecker(lokr_path, lokr_name): """Compare tier 1 (lossless) vs tier 2 (Kronecker split) in weight space. - For each fused QKV module, materializes the exact delta from the fuse-first path - and the reconstructed delta from the Kronecker split path, then reports the - relative Frobenius norm error. + No model loading needed - operates on checkpoint state dicts only. + Only meaningful for BFL-format LoKR (fused QKV). """ from diffusers.loaders.lora_conversion_utils import _convert_non_diffusers_flux2_lokr_to_diffusers @@ -96,27 +141,22 @@ def weight_space_analysis(lokr_path, lokr_name): print(" Tiers 1 and 2 produce identical results for this format.") return - # Convert both ways from the same raw state dict sd_fused = _convert_non_diffusers_flux2_lokr_to_diffusers(dict(raw_sd), fuse_qkv=True) sd_split = _convert_non_diffusers_flux2_lokr_to_diffusers(dict(raw_sd), fuse_qkv=False) - # Find all fused QKV modules (to_qkv and to_added_qkv) - qkv_modules = {} + # Find all fused QKV modules + qkv_modules = [] for key in sd_fused: if ".to_qkv.lokr_w1" in key or ".to_added_qkv.lokr_w1" in key: - module_path = key.rsplit(".lokr_w1", 1)[0] - qkv_modules[module_path] = key + qkv_modules.append(key.rsplit(".lokr_w1", 1)[0]) - print(f"\n Found {len(qkv_modules)} fused QKV modules to compare\n") - print(f" {'Module':<65} {'Rel Error':>12} {'Abs Error':>12} {'Orig Norm':>12}") - print(f" {'-' * 65} {'-' * 12} {'-' * 12} {'-' * 12}") + print(f"\n Found {len(qkv_modules)} fused QKV modules to compare") - errors = [] - for module_path in sorted(qkv_modules.keys()): - # Materialize exact delta from fused path - w1_f = sd_fused[f"{module_path}.lokr_w1"].float() - w2_f = sd_fused[f"{module_path}.lokr_w2"].float() - delta_exact = torch.kron(w1_f, w2_f) + results = [] + for module_path in sorted(qkv_modules): + delta_exact = _materialize_lokr_delta(sd_fused, module_path) + if delta_exact is None: + continue # Determine split target keys if ".to_qkv" in module_path: @@ -126,37 +166,122 @@ def weight_space_analysis(lokr_path, lokr_name): base = module_path.replace(".attn.to_added_qkv", "") proj_keys = [f"{base}.attn.add_q_proj", f"{base}.attn.add_k_proj", f"{base}.attn.add_v_proj"] - # Materialize reconstructed delta from split path chunks = [] for proj in proj_keys: - w1_key = f"{proj}.lokr_w1" - w2_key = f"{proj}.lokr_w2" - if w1_key not in sd_split: + delta = _materialize_lokr_delta(sd_split, proj) + if delta is None: break - w1_s = sd_split[w1_key].float() - w2_s = sd_split[w2_key].float() - chunks.append(torch.kron(w1_s, w2_s)) + chunks.append(delta) if len(chunks) != 3: - print(f" {module_path:<65} {'SKIP':>12}") continue delta_recon = torch.cat(chunks, dim=0) - orig_norm = delta_exact.norm().item() abs_err = (delta_exact - delta_recon).norm().item() rel_err = abs_err / orig_norm if orig_norm > 0 else 0.0 - errors.append(rel_err) - short_name = module_path.replace("transformer.", "") - print(f" {short_name:<65} {rel_err:>11.6f}% {abs_err:>12.6f} {orig_norm:>12.4f}") + results.append((short_name, rel_err, abs_err, orig_norm)) - if errors: - print(f"\n Aggregate over {len(errors)} QKV modules:") - print(f" Mean relative error: {sum(errors) / len(errors):.6f}%") - print(f" Max relative error: {max(errors):.6f}%") - print(f" Min relative error: {min(errors):.6f}%") + _print_error_table("Tier 1 (lossless) vs Tier 2 (Kronecker split) - QKV modules only", results) + + +def weight_space_svd(lokr_path, lokr_name, ranks, no_offload=False): + """Compare tier 1 (lossless) vs tier 3 (SVD to LoRA) in weight space. + + Requires loading the full model to run peft.convert_to_lora. + Compares materialized LoKR deltas against LoRA deltas for ALL modules. + """ + from peft import convert_to_lora + + # Build reference deltas from the converted state dict (tier 2 / default path) + # For non-QKV modules tier 2 is identical to tier 1, so this is ground truth. + raw_sd = load_raw_state_dict(lokr_path, lokr_name) + is_bfl = any(k.startswith("diffusion_model.") for k in raw_sd) + + if is_bfl: + from diffusers.loaders.lora_conversion_utils import _convert_non_diffusers_flux2_lokr_to_diffusers + + sd_ref = _convert_non_diffusers_flux2_lokr_to_diffusers(dict(raw_sd), fuse_qkv=False) + else: + # For non-BFL, just use the default conversion as reference (already lossless) + from diffusers.loaders.lora_conversion_utils import ( + _convert_diffusers_flux2_lokr_to_peft, + _convert_lycoris_flux2_lokr_to_diffusers, + ) + + if any(k.startswith("lycoris_") for k in raw_sd): + sd_ref = _convert_lycoris_flux2_lokr_to_diffusers(dict(raw_sd)) + else: + sd_ref = _convert_diffusers_flux2_lokr_to_peft(dict(raw_sd)) + + # Find all LoKR modules and materialize their deltas + ref_deltas = {} + lokr_modules = set() + for key in sd_ref: + if ".lokr_w1" in key and ".lokr_w1_" not in key: + module_path = key.rsplit(".lokr_w1", 1)[0] + lokr_modules.add(module_path) + elif ".lokr_w1_a" in key: + module_path = key.rsplit(".lokr_w1_a", 1)[0] + lokr_modules.add(module_path) + + for module_path in lokr_modules: + delta = _materialize_lokr_delta(sd_ref, module_path) + if delta is not None: + ref_deltas[module_path] = delta + + print(f"\n Materialized {len(ref_deltas)} reference LoKR deltas") + + # Load model and LoKR adapter + print("\n Loading model for SVD conversion...") + pipe = load_pipeline(no_offload=no_offload) + kwargs = {"weight_name": lokr_name} if lokr_name else {} + pipe.load_lora_weights(lokr_path, **kwargs) + adapter_name = next(iter(pipe.transformer.peft_config.keys())) + + for rank in ranks: + print(f"\n Converting to LoRA rank={rank}...") + t0 = time.time() + lora_config, lora_sd = convert_to_lora(pipe.transformer, rank, adapter_name=adapter_name, progressbar=True) + print(f" Converted in {time.time() - t0:.1f}s") + + # Compare each module: LoKR delta vs LoRA delta (lora_B @ lora_A) + results = [] + for module_path in sorted(ref_deltas.keys()): + delta_ref = ref_deltas[module_path] + + # Map module_path to LoRA key format: transformer.X.Y -> base_model.model.X.Y + lora_module = module_path.replace("transformer.", "") + lora_a_key = f"base_model.model.{lora_module}.lora_A.weight" + lora_b_key = f"base_model.model.{lora_module}.lora_B.weight" + + if lora_a_key not in lora_sd or lora_b_key not in lora_sd: + # Try without base_model.model prefix + lora_a_key = f"{lora_module}.lora_A.weight" + lora_b_key = f"{lora_module}.lora_B.weight" + + if lora_a_key not in lora_sd or lora_b_key not in lora_sd: + continue + + lora_a = lora_sd[lora_a_key].float() + lora_b = lora_sd[lora_b_key].float() + delta_lora = lora_b @ lora_a + + orig_norm = delta_ref.norm().item() + abs_err = (delta_ref - delta_lora).norm().item() + rel_err = abs_err / orig_norm if orig_norm > 0 else 0.0 + + short_name = module_path.replace("transformer.", "") + results.append((short_name, rel_err, abs_err, orig_norm)) + + _print_error_table(f"Tier 1 (lossless) vs Tier 3 (SVD rank={rank}) - all modules", results) + + pipe.unload_lora_weights() + del pipe + gc.collect() + torch.cuda.empty_cache() # --------------------------------------------------------------------------- @@ -260,10 +385,15 @@ def main(): print(f"Model: {MODEL_ID}") print(f"LoKR: {args.lokr_path}" + (f" ({args.lokr_name})" if args.lokr_name else "")) - # Weight-space analysis (no model needed) + # Weight-space analysis if args.weight_space: print("\n=== Weight-space error: Tier 1 (lossless) vs Tier 2 (Kronecker split) ===") - weight_space_analysis(args.lokr_path, args.lokr_name) + weight_space_kronecker(args.lokr_path, args.lokr_name) + + if args.ranks: + print("\n=== Weight-space error: Tier 1 (lossless) vs Tier 3 (SVD to LoRA) ===") + weight_space_svd(args.lokr_path, args.lokr_name, args.ranks, no_offload=args.no_offload) + return print(f"Prompt: {args.prompt}") From 4aa8a2b78c709d6bbb5cdd03c93f076fc1373c00 Mon Sep 17 00:00:00 2001 From: CalamitousFelicitousness Date: Thu, 26 Mar 2026 23:44:04 +0000 Subject: [PATCH 07/11] Address remaining review feedback from sayakpaul - Rename lora_config to adapter_config in load_lora_adapter (peft.py) - Remove redundant import collections (already at module level in peft_utils.py) --- src/diffusers/loaders/peft.py | 22 +++++++++++----------- src/diffusers/utils/peft_utils.py | 2 -- 2 files changed, 11 insertions(+), 13 deletions(-) diff --git a/src/diffusers/loaders/peft.py b/src/diffusers/loaders/peft.py index 1ec304f24944..7c8fa0a9e854 100644 --- a/src/diffusers/loaders/peft.py +++ b/src/diffusers/loaders/peft.py @@ -219,7 +219,7 @@ def load_lora_adapter( if is_lokr: if adapter_name is None: adapter_name = get_adapter_name(self) - lora_config = _create_lokr_config(state_dict) + adapter_config = _create_lokr_config(state_dict) is_sai_sd_control_lora = False else: # check with first key if is not in peft format @@ -256,7 +256,7 @@ def load_lora_adapter( adapter_name = get_adapter_name(self) # create LoraConfig - lora_config = _create_lora_config( + adapter_config = _create_lora_config( state_dict, network_alphas, metadata, @@ -267,11 +267,11 @@ def load_lora_adapter( # Adjust LoRA config for Control LoRA if is_sai_sd_control_lora: - lora_config.lora_alpha = lora_config.r - lora_config.alpha_pattern = lora_config.rank_pattern - lora_config.bias = "all" - lora_config.modules_to_save = lora_config.exclude_modules - lora_config.exclude_modules = None + adapter_config.lora_alpha = adapter_config.r + adapter_config.alpha_pattern = adapter_config.rank_pattern + adapter_config.bias = "all" + adapter_config.modules_to_save = adapter_config.exclude_modules + adapter_config.exclude_modules = None # Date: Thu, 26 Mar 2026 23:51:32 +0000 Subject: [PATCH 08/11] Fix device mismatch in SVD weight-space comparison --- benchmark_lokr.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/benchmark_lokr.py b/benchmark_lokr.py index 2bc894cbc098..27188ad3c549 100644 --- a/benchmark_lokr.py +++ b/benchmark_lokr.py @@ -265,12 +265,12 @@ def weight_space_svd(lokr_path, lokr_name, ranks, no_offload=False): if lora_a_key not in lora_sd or lora_b_key not in lora_sd: continue - lora_a = lora_sd[lora_a_key].float() - lora_b = lora_sd[lora_b_key].float() + lora_a = lora_sd[lora_a_key].float().cpu() + lora_b = lora_sd[lora_b_key].float().cpu() delta_lora = lora_b @ lora_a orig_norm = delta_ref.norm().item() - abs_err = (delta_ref - delta_lora).norm().item() + abs_err = (delta_ref.cpu() - delta_lora).norm().item() rel_err = abs_err / orig_norm if orig_norm > 0 else 0.0 short_name = module_path.replace("transformer.", "") From 401c00eaeb188697fb470748003bd4a738d6104f Mon Sep 17 00:00:00 2001 From: CalamitousFelicitousness Date: Fri, 27 Mar 2026 00:35:34 +0000 Subject: [PATCH 09/11] Add LoKR/LoRA config debug output to tier 3 benchmark --- benchmark_lokr.py | 12 ++++++++++++ 1 file changed, 12 insertions(+) diff --git a/benchmark_lokr.py b/benchmark_lokr.py index 27188ad3c549..ff7c2432b955 100644 --- a/benchmark_lokr.py +++ b/benchmark_lokr.py @@ -246,6 +246,12 @@ def weight_space_svd(lokr_path, lokr_name, ranks, no_offload=False): t0 = time.time() lora_config, lora_sd = convert_to_lora(pipe.transformer, rank, adapter_name=adapter_name, progressbar=True) print(f" Converted in {time.time() - t0:.1f}s") + print(f" LoRA config: alpha={lora_config.lora_alpha}, r={lora_config.r}") + + # Also print the LoKR config for reference + lokr_cfg = pipe.transformer.peft_config.get(adapter_name) + if lokr_cfg: + print(f" LoKR config: alpha={lokr_cfg.alpha}, r={lokr_cfg.r}") # Compare each module: LoKR delta vs LoRA delta (lora_B @ lora_A) results = [] @@ -343,10 +349,16 @@ def benchmark_tier3_svd(pipe, prompt, seed, rank, lokr_path, lokr_name): adapter_name = next(iter(pipe.transformer.peft_config.keys())) print(f" Adapter name: {adapter_name}") + t0 = time.time() + lokr_cfg = pipe.transformer.peft_config.get(adapter_name) + if lokr_cfg: + print(f" LoKR config: alpha={lokr_cfg.alpha}, r={lokr_cfg.r}") + t0 = time.time() lora_config, lora_sd = convert_to_lora(pipe.transformer, rank, adapter_name=adapter_name, progressbar=True) convert_time = time.time() - t0 print(f" Loaded LoKR in {load_time:.1f}s, converted to LoRA in {convert_time:.1f}s") + print(f" LoRA config: alpha={lora_config.lora_alpha}, r={lora_config.r}") pipe.transformer.delete_adapters(adapter_name) inject_adapter_in_model(lora_config, pipe.transformer, adapter_name=adapter_name) From b08b26ea9c8fd27dd8fbe2cf3bc930cfa5de0478 Mon Sep 17 00:00:00 2001 From: CalamitousFelicitousness Date: Fri, 27 Mar 2026 01:38:45 +0000 Subject: [PATCH 10/11] Fix AttributeError: handle both LoKrConfig.alpha and LoraConfig.lora_alpha --- benchmark_lokr.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/benchmark_lokr.py b/benchmark_lokr.py index ff7c2432b955..a8fd6d052758 100644 --- a/benchmark_lokr.py +++ b/benchmark_lokr.py @@ -251,7 +251,8 @@ def weight_space_svd(lokr_path, lokr_name, ranks, no_offload=False): # Also print the LoKR config for reference lokr_cfg = pipe.transformer.peft_config.get(adapter_name) if lokr_cfg: - print(f" LoKR config: alpha={lokr_cfg.alpha}, r={lokr_cfg.r}") + alpha = getattr(lokr_cfg, "alpha", getattr(lokr_cfg, "lora_alpha", "?")) + print(f" Adapter config: {type(lokr_cfg).__name__}, alpha={alpha}, r={lokr_cfg.r}") # Compare each module: LoKR delta vs LoRA delta (lora_B @ lora_A) results = [] @@ -352,7 +353,8 @@ def benchmark_tier3_svd(pipe, prompt, seed, rank, lokr_path, lokr_name): t0 = time.time() lokr_cfg = pipe.transformer.peft_config.get(adapter_name) if lokr_cfg: - print(f" LoKR config: alpha={lokr_cfg.alpha}, r={lokr_cfg.r}") + alpha = getattr(lokr_cfg, "alpha", getattr(lokr_cfg, "lora_alpha", "?")) + print(f" Adapter config: {type(lokr_cfg).__name__}, alpha={alpha}, r={lokr_cfg.r}") t0 = time.time() lora_config, lora_sd = convert_to_lora(pipe.transformer, rank, adapter_name=adapter_name, progressbar=True) From 613e6a13df0e45f722f7891efb1a7ade35770e89 Mon Sep 17 00:00:00 2001 From: "github-actions[bot]" Date: Fri, 27 Mar 2026 09:11:18 +0000 Subject: [PATCH 11/11] Apply style fixes --- .../loaders/lora_conversion_utils.py | 21 +++++++++---------- src/diffusers/utils/peft_utils.py | 17 +++++++-------- 2 files changed, 18 insertions(+), 20 deletions(-) diff --git a/src/diffusers/loaders/lora_conversion_utils.py b/src/diffusers/loaders/lora_conversion_utils.py index a76429ee09a8..4d38e241c423 100644 --- a/src/diffusers/loaders/lora_conversion_utils.py +++ b/src/diffusers/loaders/lora_conversion_utils.py @@ -2643,8 +2643,8 @@ def _convert_to_ai_toolkit_cat(sds_sd, ait_sd, sds_key, ait_keys, dims=None): def _nearest_kronecker_product(matrix, m1, n1, m2, n2): """Find the nearest rank-1 Kronecker product approximation (Van Loan & Pitsianis). - Given matrix M of shape (m1*m2, n1*n2), finds w1 (m1, n1) and w2 (m2, n2) - minimizing ||M - kron(w1, w2)||_F via rank-1 SVD of a rearranged matrix. + Given matrix M of shape (m1*m2, n1*n2), finds w1 (m1, n1) and w2 (m2, n2) minimizing ||M - kron(w1, w2)||_F via + rank-1 SVD of a rearranged matrix. """ # Rearrange M into R of shape (m1*n1, m2*n2) # R[i*n1+j, k*n2+l] = M[i*m2+k, j*n2+l] @@ -2661,8 +2661,8 @@ def _nearest_kronecker_product(matrix, m1, n1, m2, n2): def _split_lokr_qkv(w1, w2, target_keys, factor): """Split fused LoKR QKV factors into separate per-projection Kronecker factors. - Materializes kron(w1, w2), chunks along dim=0, and re-factorizes each chunk - as a rank-1 Kronecker product using the Van Loan algorithm. + Materializes kron(w1, w2), chunks along dim=0, and re-factorizes each chunk as a rank-1 Kronecker product using the + Van Loan algorithm. Args: w1: First Kronecker factor, shape (f, f) where f = decompose_factor. @@ -2694,9 +2694,8 @@ def _convert_non_diffusers_flux2_lokr_to_diffusers(state_dict, fuse_qkv=False): Args: state_dict: BFL-format LoKR state dict with ``diffusion_model.`` prefix. fuse_qkv: If True, map fused QKV directly to ``to_qkv``/``to_added_qkv`` targets - (lossless, but requires the model's QKV to be fused before injection). - If False (default), split fused QKV into separate Q/K/V via Kronecker - re-factorization (slightly lossy, no model fusion needed). + (lossless, but requires the model's QKV to be fused before injection). If False (default), split fused QKV + into separate Q/K/V via Kronecker re-factorization (slightly lossy, no model fusion needed). """ converted_state_dict = {} @@ -2891,8 +2890,8 @@ def _bake_lokr_alpha(state_dict): def _convert_lycoris_flux2_lokr_to_diffusers(state_dict): """Convert LyCORIS underscore-format Flux2 LoKR state dict to peft-compatible diffusers format. - LyCORIS keys use underscore-encoded paths (e.g., lycoris_transformer_blocks_0_attn_to_q.lokr_w1). - Decodes these to dotted diffusers paths using a known sub-path lookup table. + LyCORIS keys use underscore-encoded paths (e.g., lycoris_transformer_blocks_0_attn_to_q.lokr_w1). Decodes these to + dotted diffusers paths using a known sub-path lookup table. """ import re @@ -2938,8 +2937,8 @@ def _convert_lycoris_flux2_lokr_to_diffusers(state_dict): def _convert_diffusers_flux2_lokr_to_peft(state_dict): """Convert diffusers-native Flux2 LoKR state dict by adding transformer. prefix and baking alpha. - Diffusers-native keys already use dotted module paths matching the model structure. - Only alpha baking and the transformer. prefix are needed. + Diffusers-native keys already use dotted module paths matching the model structure. Only alpha baking and the + transformer. prefix are needed. """ original_state_dict = dict(state_dict) _bake_lokr_alpha(original_state_dict) diff --git a/src/diffusers/utils/peft_utils.py b/src/diffusers/utils/peft_utils.py index 3089255a922b..6b21fbba13f5 100644 --- a/src/diffusers/utils/peft_utils.py +++ b/src/diffusers/utils/peft_utils.py @@ -347,13 +347,12 @@ def check_peft_version(min_version: str) -> None: def _create_lokr_config(state_dict): """Create a peft LoKrConfig from a converted LoKR state dict. - Infers rank, decompose_both, decompose_factor, and target_modules from the state dict key names - and tensor shapes. Alpha scaling is assumed to be already baked into the weights, so config - alpha = r (scaling = 1.0). + Infers rank, decompose_both, decompose_factor, and target_modules from the state dict key names and tensor shapes. + Alpha scaling is assumed to be already baked into the weights, so config alpha = r (scaling = 1.0). - Peft determines w2 decomposition via ``r < max(out_k, in_n) / 2``. We must set per-module rank - values that reproduce the same decomposition pattern as the checkpoint. For modules with full - (non-decomposed) lokr_w2, we set rank = max(lokr_w2.shape) so that peft also creates a full w2. + Peft determines w2 decomposition via ``r < max(out_k, in_n) / 2``. We must set per-module rank values that + reproduce the same decomposition pattern as the checkpoint. For modules with full (non-decomposed) lokr_w2, we set + rank = max(lokr_w2.shape) so that peft also creates a full w2. """ from peft import LoKrConfig @@ -431,9 +430,9 @@ def _create_lokr_config(state_dict): def _convert_adapter_to_lora(model, rank, adapter_name="default"): """Convert a loaded non-LoRA peft adapter (e.g., LoKR) to LoRA via truncated SVD. - Wraps ``peft.convert_to_lora`` which materializes each adapter layer's delta weight - and decomposes it as ``U @ diag(S) @ V ≈ lora_B @ lora_A``. The conversion is lossy: - higher ``rank`` preserves more fidelity at the cost of larger LoRA matrices. + Wraps ``peft.convert_to_lora`` which materializes each adapter layer's delta weight and decomposes it as ``U @ + diag(S) @ V ≈ lora_B @ lora_A``. The conversion is lossy: higher ``rank`` preserves more fidelity at the cost of + larger LoRA matrices. Args: model: ``nn.Module`` with a peft adapter already injected.