Skip to content

Add Flux2 LoKR adapter support prototype with dual conversion paths#13326

Open
CalamitousFelicitousness wants to merge 11 commits intohuggingface:mainfrom
CalamitousFelicitousness:feature/flux2-klein-lokr
Open

Add Flux2 LoKR adapter support prototype with dual conversion paths#13326
CalamitousFelicitousness wants to merge 11 commits intohuggingface:mainfrom
CalamitousFelicitousness:feature/flux2-klein-lokr

Conversation

@CalamitousFelicitousness
Copy link
Copy Markdown
Contributor

@CalamitousFelicitousness CalamitousFelicitousness commented Mar 25, 2026

Adds support for Flux2 LoKR, with dual path to benchmark implementations.

  • Custom lossless path: BFL LoKR keys → peft LoKrConfig (fuse-first QKV)
  • Generic lossy path: optional SVD conversion via peft.convert_to_lora
  • Fix alpha handling for lora_down/lora_up format checkpoints
  • Re-fuse LoRA keys when model QKV is fused from prior LoKR load

Given that Civitai does not have a LoKR category, I didn't feel like digging for a SFW one, so I just used what a user brought up to me when they reported the issue.

Benchmark test
"""Benchmark: Lossless LoKR vs Lossy LoRA-via-SVD on Flux2 Klein 9B.

Generates images using both conversion paths for visual comparison.
Uses 4-bit quantization + CPU offload to fit on a single 24GB GPU.

Usage:
    python benchmark_lokr.py
    python benchmark_lokr.py --prompt "a cat in a garden" --ranks 32 64 128
"""

import argparse
import gc
import os
import time

import torch

# Use local model cache
# NOTE(review): machine-specific absolute path — adjust or remove before
# running elsewhere. Presumably must be set before the diffusers import
# below so the hub cache location takes effect (hence the E402 suppressions).
os.environ["HF_HUB_CACHE"] = "/home/ohiom/database/models/huggingface"

from diffusers import Flux2KleinPipeline  # noqa: E402
from peft import convert_to_lora  # noqa: E402

# Benchmark configuration: hub model id, local LoKR checkpoint, output folder.
# NOTE(review): LOKR_PATH and OUTPUT_DIR are hard-coded to the author's
# machine — parameterize or make relative before merging.
MODEL_ID = "black-forest-labs/FLUX.2-klein-9B"
LOKR_PATH = "/home/ohiom/database/models/Lora/Flux.2 Klein 9B/klein_snofs_v1_1.safetensors"
OUTPUT_DIR = "/home/ohiom/diffusers/benchmark_output"


def load_pipeline():
    """Instantiate Flux2 Klein 9B in bf16 and enable model-level CPU offload.

    Returns:
        The configured ``Flux2KleinPipeline`` instance.
    """
    pipeline = Flux2KleinPipeline.from_pretrained(MODEL_ID, torch_dtype=torch.bfloat16)
    pipeline.enable_model_cpu_offload()
    return pipeline


def generate(pipe, prompt, seed, num_steps=4, guidance_scale=1.0):
    """Run one 1024x1024 pipeline call and return the first image.

    A CPU-side ``torch.Generator`` seeded with ``seed`` keeps results
    reproducible across runs.
    """
    rng = torch.Generator(device="cpu").manual_seed(seed)
    result = pipe(
        prompt=prompt,
        num_inference_steps=num_steps,
        guidance_scale=guidance_scale,
        generator=rng,
        height=1024,
        width=1024,
    )
    return result.images[0]


def benchmark_lossless(pipe, prompt, seed):
    """Path A: Load LoKR natively (lossless).

    Loads the LoKR checkpoint, generates one image, then unloads the
    adapter so the pipeline is clean for the next benchmark path.
    """
    print("\n=== Path A: Lossless LoKR ===")
    start = time.time()
    pipe.load_lora_weights(LOKR_PATH)
    elapsed = time.time() - start
    print(f"  Loaded in {elapsed:.1f}s")

    start = time.time()
    image = generate(pipe, prompt, seed)
    elapsed = time.time() - start
    print(f"  Generated in {elapsed:.1f}s")

    pipe.unload_lora_weights()
    return image


def benchmark_lossy(pipe, prompt, seed, rank):
    """Path B: Load LoKR, convert to LoRA via SVD (lossy).

    Args:
        pipe: Flux2 pipeline whose transformer supports peft adapters.
        prompt: Text prompt for generation.
        seed: RNG seed for reproducibility.
        rank: Target LoRA rank for the SVD conversion.

    Returns:
        The generated image from the converted-LoRA path.
    """
    print(f"\n=== Path B: Lossy LoRA via SVD (rank={rank}) ===")
    t0 = time.time()
    pipe.load_lora_weights(LOKR_PATH)
    load_time = time.time() - t0

    # Detect the actual adapter name assigned by peft.
    adapter_name = next(iter(pipe.transformer.peft_config))
    print(f"  Adapter name: {adapter_name}")

    pipe.transformer.to("cuda")
    t0 = time.time()
    lora_config, lora_sd = convert_to_lora(pipe.transformer, rank, adapter_name=adapter_name, progressbar=True)
    convert_time = time.time() - t0
    print(f"  Loaded LoKR in {load_time:.1f}s, converted to LoRA in {convert_time:.1f}s")

    # Replace LoKR adapter with converted LoRA
    from peft import inject_adapter_in_model, set_peft_model_state_dict

    # Diffusers' PeftAdapterMixin exposes `delete_adapters` (plural); there
    # is no `delete_adapter` method on the transformer.
    pipe.transformer.delete_adapters(adapter_name)
    # peft.inject_adapter_in_model takes the config first, then the model.
    inject_adapter_in_model(lora_config, pipe.transformer, adapter_name=adapter_name)
    set_peft_model_state_dict(pipe.transformer, lora_sd, adapter_name=adapter_name)

    t0 = time.time()
    image = generate(pipe, prompt, seed)
    print(f"  Generated in {time.time() - t0:.1f}s")

    pipe.unload_lora_weights()
    return image


def benchmark_baseline(pipe, prompt, seed):
    """Baseline: No adapter.

    Generates one image with no adapter loaded, for comparison against
    the LoKR and SVD paths.
    """
    print("\n=== Baseline: No adapter ===")
    start = time.time()
    image = generate(pipe, prompt, seed)
    elapsed = time.time() - start
    print(f"  Generated in {elapsed:.1f}s")
    return image


def main():
    """Parse CLI args, run baseline / lossless / lossy benchmarks, save images."""
    parser = argparse.ArgumentParser(description="Benchmark LoKR vs LoRA-via-SVD")
    # SFW default prompt; override with --prompt for adapter-specific subjects.
    parser.add_argument(
        "--prompt",
        default="A high-angle POV photograph shows a polar bear.",
    )
    parser.add_argument("--seed", type=int, default=42)
    parser.add_argument("--ranks", type=int, nargs="+", default=[32, 64, 128])
    # NOTE(review): --steps is parsed but never threaded into generate();
    # all paths currently run with generate()'s default step count. Either
    # wire it through the benchmark_* helpers or drop the flag.
    parser.add_argument("--steps", type=int, default=28)
    parser.add_argument("--skip-baseline", action="store_true")
    args = parser.parse_args()

    os.makedirs(OUTPUT_DIR, exist_ok=True)

    print(f"Model: {MODEL_ID}")
    print(f"LoKR:  {LOKR_PATH}")
    print(f"Prompt: {args.prompt}")
    print(f"Seed: {args.seed}")
    print(f"SVD ranks to test: {args.ranks}")

    print("\nLoading pipeline (bf16, model CPU offload)...")
    pipe = load_pipeline()

    # Baseline (no adapter), skippable to save time on repeated runs.
    if not args.skip_baseline:
        img = benchmark_baseline(pipe, args.prompt, args.seed)
        path = os.path.join(OUTPUT_DIR, "baseline.png")
        img.save(path)
        print(f"  Saved: {path}")

    # Path A: Lossless LoKR
    img = benchmark_lossless(pipe, args.prompt, args.seed)
    path = os.path.join(OUTPUT_DIR, "lokr_lossless.png")
    img.save(path)
    print(f"  Saved: {path}")

    # Free GPU memory between paths; each path reloads the adapter.
    gc.collect()
    torch.cuda.empty_cache()

    # Path B: Lossy LoRA via SVD at various ranks
    for rank in args.ranks:
        img = benchmark_lossy(pipe, args.prompt, args.seed, rank)
        path = os.path.join(OUTPUT_DIR, f"lora_svd_rank{rank}.png")
        img.save(path)
        print(f"  Saved: {path}")

        gc.collect()
        torch.cuda.empty_cache()

    print(f"\nAll results saved to {OUTPUT_DIR}/")
    print("Compare: baseline.png vs lokr_lossless.png vs lora_svd_rank*.png")


if __name__ == "__main__":
    main()

Not sure if I'm doing it correctly for PEFT, but due to lack of quantization support I couldn't run the lossy path for now.

What does this PR do?

Fixes #13261

Before submitting

Who can review?

Anyone in the community is free to review the PR once the tests have passed. Feel free to tag
members/contributors who may be interested in your PR.

@sayakpaul @BenjaminBossan

@BenjaminBossan
Copy link
Copy Markdown
Member

PR looks good AFAICT, but of course I'm no Diffusers expert. @CalamitousFelicitousness do you have results from the test script that you can share? Also, for other readers: The LoRA conversion currently requires installing PEFT from the main branch (in the future: version 0.19).

@CalamitousFelicitousness
Copy link
Copy Markdown
Contributor Author

CalamitousFelicitousness commented Mar 25, 2026

PR looks good AFAICT, but of course I'm no Diffusers expert. @CalamitousFelicitousness do you have results from the test script that you can share? Also, for other readers: The LoRA conversion currently requires installing PEFT from the main branch (in the future: version 0.19).

As I mentioned in the message, I can't currently fit it for the SVD tests; it OOMs on my RTX 3090 and my 6000 Ada is not available at the moment. For now I only know the programmatic tests pass.

@sayakpaul
Copy link
Copy Markdown
Member

I will help with the lossy path results.

Copy link
Copy Markdown
Member

@sayakpaul sayakpaul left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks a lot for the clean PR! Left a couple of comments. LMK if anything is unclear.

@sayakpaul
Copy link
Copy Markdown
Member

@claude please review this PR as well.

@sayakpaul
Copy link
Copy Markdown
Member

Also @CalamitousFelicitousness you might want to change the default prompt in the benchmarking script. That is highly NSFW. Let's be mindful of that.

@CalamitousFelicitousness
Copy link
Copy Markdown
Contributor Author

If someone can provide me with a link to a SFW LoKR I can use, I can update it; the prompt was taken verbatim from the creator's examples. Using a prompt unrelated to the aim of the adapter is not ideal.

@sayakpaul
Copy link
Copy Markdown
Member

Then we will have to wait for one that is SFW. Respectfully, we cannot base our work on the grounds of NSFW content.

@chaowenguo can you provide a LoKR checkpoint that doesn't use any form of nudity?

@sayakpaul
Copy link
Copy Markdown
Member

I ran the benchmarking script anyway. The following diff was needed mostly to resolve the conflicts and to fix the benchmarking script:

diff
diff --git a/benchmark_lokr_13326.py b/benchmark_lokr_13326.py
index 7fb53cced..9abbda788 100644
--- a/benchmark_lokr_13326.py
+++ b/benchmark_lokr_13326.py
@@ -19,7 +19,7 @@ from diffusers import Flux2KleinPipeline  # noqa: E402
 from peft import convert_to_lora  # noqa: E402
 
 MODEL_ID = "black-forest-labs/FLUX.2-klein-9B"
-LOKR_PATH = "chaowenguo/lora"
+LOKR_PATH = "puttmorbidly233/lora"
 OUTPUT_DIR = "benchmark_output"
 
 
@@ -48,7 +48,7 @@ def benchmark_lossless(pipe, prompt, seed):
     """Path A: Load LoKR natively (lossless)."""
     print("\n=== Path A: Lossless LoKR ===")
     t0 = time.time()
-    pipe.load_lora_weights(LOKR_PATH, weight_name="klein_snofs_v1_1.safetensors")
+    pipe.load_lora_weights(LOKR_PATH, weight_name="klein_snofs_v1_2.safetensors")
     print(f"  Loaded in {time.time() - t0:.1f}s")
 
     t0 = time.time()
@@ -63,7 +63,7 @@ def benchmark_lossy(pipe, prompt, seed, rank):
     """Path B: Load LoKR, convert to LoRA via SVD (lossy)."""
     print(f"\n=== Path B: Lossy LoRA via SVD (rank={rank}) ===")
     t0 = time.time()
-    pipe.load_lora_weights(LOKR_PATH, weight_name="klein_snofs_v1_1.safetensors")
+    pipe.load_lora_weights(LOKR_PATH, weight_name="klein_snofs_v1_2.safetensors")
     load_time = time.time() - t0
 
     # Detect the actual adapter name assigned by peft
@@ -79,8 +79,8 @@ def benchmark_lossy(pipe, prompt, seed, rank):
     # Replace LoKR adapter with converted LoRA
     from peft import inject_adapter_in_model, set_peft_model_state_dict
 
-    pipe.transformer.delete_adapter(adapter_name)
-    inject_adapter_in_model(pipe.transformer, lora_config, adapter_name=adapter_name)
+    pipe.transformer.delete_adapters(adapter_name)
+    inject_adapter_in_model(lora_config, pipe.transformer, adapter_name=adapter_name)
     set_peft_model_state_dict(pipe.transformer, lora_sd, adapter_name=adapter_name)
 
     t0 = time.time()
@@ -104,11 +104,10 @@ def main():
     parser = argparse.ArgumentParser(description="Benchmark LoKR vs LoRA-via-SVD")
     parser.add_argument(
         "--prompt",
-        default="A high-angle POV photograph shows a nude white woman with blonde hair",
+        default="A high-angle POV photograph shows a polar bear.",
     )
     parser.add_argument("--seed", type=int, default=42)
     parser.add_argument("--ranks", type=int, nargs="+", default=[32, 64, 128])
-    parser.add_argument("--steps", type=int, default=28)
     parser.add_argument("--skip-baseline", action="store_true")
     args = parser.parse_args()
 
diff --git a/src/diffusers/loaders/lora_conversion_utils.py b/src/diffusers/loaders/lora_conversion_utils.py
index df58af0a1..6f33b8219 100644
--- a/src/diffusers/loaders/lora_conversion_utils.py
+++ b/src/diffusers/loaders/lora_conversion_utils.py
@@ -2455,7 +2455,6 @@ def _convert_non_diffusers_flux2_lora_to_diffusers(state_dict):
     return converted_state_dict
 
 
-<<<<<<< HEAD
 def _convert_non_diffusers_flux2_lokr_to_diffusers(state_dict):
     """Convert non-diffusers Flux2 LoKR state dict (kohya/LyCORIS format) to peft-compatible diffusers format.
 
@@ -2600,191 +2599,6 @@ def _refuse_flux2_lora_state_dict(state_dict):
     # Pass through all non-QKV keys unchanged
     converted.update(remaining)
     return converted
-=======
-def _convert_kohya_flux2_lora_to_diffusers(state_dict):
-    def _convert_to_ai_toolkit(sds_sd, ait_sd, sds_key, ait_key):
-        if sds_key + ".lora_down.weight" not in sds_sd:
-            return
-        down_weight = sds_sd.pop(sds_key + ".lora_down.weight")
-
-        # scale weight by alpha and dim
-        rank = down_weight.shape[0]
-        default_alpha = torch.tensor(rank, dtype=down_weight.dtype, device=down_weight.device, requires_grad=False)
-        alpha = sds_sd.pop(sds_key + ".alpha", default_alpha).item()
-        scale = alpha / rank
-
-        scale_down = scale
-        scale_up = 1.0
-        while scale_down * 2 < scale_up:
-            scale_down *= 2
-            scale_up /= 2
-
-        ait_sd[ait_key + ".lora_A.weight"] = down_weight * scale_down
-        ait_sd[ait_key + ".lora_B.weight"] = sds_sd.pop(sds_key + ".lora_up.weight") * scale_up
-
-    def _convert_to_ai_toolkit_cat(sds_sd, ait_sd, sds_key, ait_keys, dims=None):
-        if sds_key + ".lora_down.weight" not in sds_sd:
-            return
-        down_weight = sds_sd.pop(sds_key + ".lora_down.weight")
-        up_weight = sds_sd.pop(sds_key + ".lora_up.weight")
-        sd_lora_rank = down_weight.shape[0]
-
-        default_alpha = torch.tensor(
-            sd_lora_rank, dtype=down_weight.dtype, device=down_weight.device, requires_grad=False
-        )
-        alpha = sds_sd.pop(sds_key + ".alpha", default_alpha)
-        scale = alpha / sd_lora_rank
-
-        scale_down = scale
-        scale_up = 1.0
-        while scale_down * 2 < scale_up:
-            scale_down *= 2
-            scale_up /= 2
-
-        down_weight = down_weight * scale_down
-        up_weight = up_weight * scale_up
-
-        num_splits = len(ait_keys)
-        if dims is None:
-            dims = [up_weight.shape[0] // num_splits] * num_splits
-        else:
-            assert sum(dims) == up_weight.shape[0]
-
-        # check if upweight is sparse
-        is_sparse = False
-        if sd_lora_rank % num_splits == 0:
-            ait_rank = sd_lora_rank // num_splits
-            is_sparse = True
-            i = 0
-            for j in range(len(dims)):
-                for k in range(len(dims)):
-                    if j == k:
-                        continue
-                    is_sparse = is_sparse and torch.all(
-                        up_weight[i : i + dims[j], k * ait_rank : (k + 1) * ait_rank] == 0
-                    )
-                i += dims[j]
-            if is_sparse:
-                logger.info(f"weight is sparse: {sds_key}")
-
-        ait_down_keys = [k + ".lora_A.weight" for k in ait_keys]
-        ait_up_keys = [k + ".lora_B.weight" for k in ait_keys]
-        if not is_sparse:
-            ait_sd.update(dict.fromkeys(ait_down_keys, down_weight))
-            ait_sd.update({k: v for k, v in zip(ait_up_keys, torch.split(up_weight, dims, dim=0))})  # noqa: C416
-        else:
-            ait_sd.update({k: v for k, v in zip(ait_down_keys, torch.chunk(down_weight, num_splits, dim=0))})  # noqa: C416
-            i = 0
-            for j in range(len(dims)):
-                ait_sd[ait_up_keys[j]] = up_weight[i : i + dims[j], j * ait_rank : (j + 1) * ait_rank].contiguous()
-                i += dims[j]
-
-    # Detect number of blocks from keys
-    num_double_layers = 0
-    num_single_layers = 0
-    for key in state_dict.keys():
-        if key.startswith("lora_unet_double_blocks_"):
-            block_idx = int(key.split("_")[4])
-            num_double_layers = max(num_double_layers, block_idx + 1)
-        elif key.startswith("lora_unet_single_blocks_"):
-            block_idx = int(key.split("_")[4])
-            num_single_layers = max(num_single_layers, block_idx + 1)
-
-    ait_sd = {}
-
-    for i in range(num_double_layers):
-        # Attention projections
-        _convert_to_ai_toolkit(
-            state_dict,
-            ait_sd,
-            f"lora_unet_double_blocks_{i}_img_attn_proj",
-            f"transformer.transformer_blocks.{i}.attn.to_out.0",
-        )
-        _convert_to_ai_toolkit_cat(
-            state_dict,
-            ait_sd,
-            f"lora_unet_double_blocks_{i}_img_attn_qkv",
-            [
-                f"transformer.transformer_blocks.{i}.attn.to_q",
-                f"transformer.transformer_blocks.{i}.attn.to_k",
-                f"transformer.transformer_blocks.{i}.attn.to_v",
-            ],
-        )
-        _convert_to_ai_toolkit(
-            state_dict,
-            ait_sd,
-            f"lora_unet_double_blocks_{i}_txt_attn_proj",
-            f"transformer.transformer_blocks.{i}.attn.to_add_out",
-        )
-        _convert_to_ai_toolkit_cat(
-            state_dict,
-            ait_sd,
-            f"lora_unet_double_blocks_{i}_txt_attn_qkv",
-            [
-                f"transformer.transformer_blocks.{i}.attn.add_q_proj",
-                f"transformer.transformer_blocks.{i}.attn.add_k_proj",
-                f"transformer.transformer_blocks.{i}.attn.add_v_proj",
-            ],
-        )
-        # MLP layers (Flux2 uses ff.linear_in/linear_out)
-        _convert_to_ai_toolkit(
-            state_dict,
-            ait_sd,
-            f"lora_unet_double_blocks_{i}_img_mlp_0",
-            f"transformer.transformer_blocks.{i}.ff.linear_in",
-        )
-        _convert_to_ai_toolkit(
-            state_dict,
-            ait_sd,
-            f"lora_unet_double_blocks_{i}_img_mlp_2",
-            f"transformer.transformer_blocks.{i}.ff.linear_out",
-        )
-        _convert_to_ai_toolkit(
-            state_dict,
-            ait_sd,
-            f"lora_unet_double_blocks_{i}_txt_mlp_0",
-            f"transformer.transformer_blocks.{i}.ff_context.linear_in",
-        )
-        _convert_to_ai_toolkit(
-            state_dict,
-            ait_sd,
-            f"lora_unet_double_blocks_{i}_txt_mlp_2",
-            f"transformer.transformer_blocks.{i}.ff_context.linear_out",
-        )
-
-    for i in range(num_single_layers):
-        # Single blocks: linear1 -> attn.to_qkv_mlp_proj (fused, no split needed)
-        _convert_to_ai_toolkit(
-            state_dict,
-            ait_sd,
-            f"lora_unet_single_blocks_{i}_linear1",
-            f"transformer.single_transformer_blocks.{i}.attn.to_qkv_mlp_proj",
-        )
-        # Single blocks: linear2 -> attn.to_out
-        _convert_to_ai_toolkit(
-            state_dict,
-            ait_sd,
-            f"lora_unet_single_blocks_{i}_linear2",
-            f"transformer.single_transformer_blocks.{i}.attn.to_out",
-        )
-
-    # Handle optional extra keys
-    extra_mappings = {
-        "lora_unet_img_in": "transformer.x_embedder",
-        "lora_unet_txt_in": "transformer.context_embedder",
-        "lora_unet_time_in_in_layer": "transformer.time_guidance_embed.timestep_embedder.linear_1",
-        "lora_unet_time_in_out_layer": "transformer.time_guidance_embed.timestep_embedder.linear_2",
-        "lora_unet_final_layer_linear": "transformer.proj_out",
-    }
-    for sds_key, ait_key in extra_mappings.items():
-        _convert_to_ai_toolkit(state_dict, ait_sd, sds_key, ait_key)
-
-    remaining_keys = list(state_dict.keys())
-    if remaining_keys:
-        logger.warning(f"Unsupported keys for Kohya Flux2 LoRA conversion: {remaining_keys}")
-
-    return ait_sd
->>>>>>> 153fcbc5a (fix klein lora loading. (#13313))
 
 
 def _convert_non_diffusers_z_image_lora_to_diffusers(state_dict):
diff --git a/src/diffusers/loaders/lora_pipeline.py b/src/diffusers/loaders/lora_pipeline.py
index 6c8bba726..fea38ddc3 100644
--- a/src/diffusers/loaders/lora_pipeline.py
+++ b/src/diffusers/loaders/lora_pipeline.py
@@ -43,7 +43,7 @@ from .lora_conversion_utils import (
     _convert_bfl_flux_control_lora_to_diffusers,
     _convert_fal_kontext_lora_to_diffusers,
     _convert_hunyuan_video_lora_to_diffusers,
-    _convert_kohya_flux2_lora_to_diffusers,
+    # _convert_kohya_flux2_lora_to_diffusers,
     _convert_kohya_flux_lora_to_diffusers,
     _convert_musubi_wan_lora_to_diffusers,
     _convert_non_diffusers_flux2_lokr_to_diffusers,
diff --git a/updates.patch b/updates.patch
index 3c31af601..523f11b9d 100644
--- a/updates.patch
+++ b/updates.patch
@@ -1,24 +0,0 @@
-diff --git a/tests/models/testing_utils/quantization.py b/tests/models/testing_utils/quantization.py
-index 20e6e09fd..62fc2f81e 100644
---- a/tests/models/testing_utils/quantization.py
-+++ b/tests/models/testing_utils/quantization.py
-@@ -175,6 +175,11 @@ class QuantizationTesterMixin:
-         model_quantized.to(torch_device)
- 
-         inputs = self.get_dummy_inputs()
-+        model_dtype = next(model_quantized.parameters()).dtype
-+        inputs = {
-+            k: v.to(dtype=model_dtype) if torch.is_tensor(v) and torch.is_floating_point(v) else v
-+            for k, v in inputs.items()
-+        }
-         output = model_quantized(**inputs, return_dict=False)[0]
- 
-         assert output is not None, "Model output is None"
-@@ -928,6 +933,7 @@ class TorchAoTesterMixin(TorchAoConfigMixin, QuantizationTesterMixin):
-         """Test that device_map='auto' works correctly with quantization."""
-         self._test_quantization_device_map(TorchAoConfigMixin.TORCHAO_QUANT_TYPES["int8wo"])
- 
-+    @pytest.mark.xfail(reason="dequantize is not implemented in torchao")
-     def test_torchao_dequantize(self):
-         """Test that dequantize() works correctly."""
-         self._test_dequantize(TorchAoConfigMixin.TORCHAO_QUANT_TYPES["int8wo"])

I know the results don't mean much but here's a comparative plot:

Baseline LoRA SVD Rank 32 LoRA SVD Rank 64 LoRA SVD Rank 128 LoKR Lossless
baseline rank32 rank64 rank128 lokr

Additionally, here's the log: https://pastebin.com/HUV2GUjc

@iwr-redmond
Copy link
Copy Markdown

This LoKR is for a giantess effect and appears to be SFW.

@BenjaminBossan
Copy link
Copy Markdown
Member

LoKR Lossless

I think the bear that lost the head would beg to differ.

@CalamitousFelicitousness
Copy link
Copy Markdown
Contributor Author

CalamitousFelicitousness commented Mar 26, 2026

LoKR Lossless

I think the bear that lost the head would beg to differ.

Depending on the LoRA the head might have been given. Which is not just an offtop joke since that genuinely could be the LoRA working correctly.

- Custom lossless path: BFL LoKR keys → peft LoKrConfig (fuse-first QKV)
- Generic lossy path: optional SVD conversion via peft.convert_to_lora
- Fix alpha handling for lora_down/lora_up format checkpoints
- Re-fuse LoRA keys when model QKV is fused from prior LoKR load
- BFL format: remap keys + split fused QKV via Kronecker re-factorization (Van Loan)
- LyCORIS format: decode underscore-encoded paths to diffusers module names
- Diffusers native format: add transformer. prefix and bake alpha
- Generic lossy path: _convert_adapter_to_lora utility wrapping peft.convert_to_lora
- Fix alpha handling for lora_down/lora_up format checkpoints
…SVD)

- Add fuse_qkv parameter to BFL LoKR converter for lossless fuse-first path
- Thread fuse_qkv through lora_pipeline.py (lora_state_dict -> load_lora_weights)
- Fuse model QKV projections before adapter injection when fuse_qkv=True
- Update benchmark script with --tiers, --no-offload flags for all three paths
fuse_qkv=True with LyCORIS or diffusers-native checkpoints would fuse
the model QKV then fail injection (adapter targets separate Q/K/V modules
that no longer exist). Now only set fuse_qkv metadata for BFL format.
Compares materialized kron(w1, w2) from fuse-first path against
reconstructed cat(kron(w1_q, w2_q), ...) from Van Loan split.
Reports per-module and aggregate relative Frobenius norm error.
No model loading needed - runs on checkpoint state dict only.
@CalamitousFelicitousness
Copy link
Copy Markdown
Contributor Author

CalamitousFelicitousness commented Mar 26, 2026

I ran the benchmarking script anyway. The following diff was needed mostly to resolve the conflicts and to fix the benchmarking script:
Additionally, here's the log: https://pastebin.com/HUV2GUjc

So, the SVD conversions either don't work properly due to a bug in my code, or the loss is big enough that it nullifies a lot of the effect the adapter should have.

@CalamitousFelicitousness
Copy link
Copy Markdown
Contributor Author

This LoKR is for a giantess effect and appears to be SFW.

Seems to be just a LoRA.

Tier 1 vs 2 (Kronecker): lightweight, no model needed.
Tier 1 vs 3 (SVD): loads model, runs peft.convert_to_lora,
compares materialized LoKR deltas against LoRA deltas for
all modules at each requested rank.
- Rename lora_config to adapter_config in load_lora_adapter (peft.py)
- Remove redundant import collections (already at module level in peft_utils.py)
@CalamitousFelicitousness
Copy link
Copy Markdown
Contributor Author

CalamitousFelicitousness commented Mar 26, 2026

So far it seems that the lossy path decimates the impact of the LoRA, the remnants are there and are clear, but even at rank 128 it does not result in the expected outcome.

I added additional tests to measure the impacts, fused vs non-fused QKV are as follows, rest is ongoing:

Losless vs QKV Kronecker split (rest of modules lossless) ``` === Weight-space error: Tier 1 (lossless) vs Tier 2 (Kronecker split) ===

Found 16 fused QKV modules to compare

Tier 1 (lossless) vs Tier 2 (Kronecker split) - QKV modules only

Module Rel Error % Abs Error Orig Norm


transformer_blocks.0.attn.to_added_qkv 0.699293% 8.666989 12.3939
transformer_blocks.0.attn.to_qkv 0.707887% 3.421168 4.8329
transformer_blocks.1.attn.to_added_qkv 0.760188% 8.873293 11.6725
transformer_blocks.1.attn.to_qkv 0.664022% 4.771522 7.1858
transformer_blocks.2.attn.to_added_qkv 0.722216% 10.066339 13.9381
transformer_blocks.2.attn.to_qkv 0.672340% 5.162684 7.6787
transformer_blocks.3.attn.to_added_qkv 0.712900% 8.487123 11.9051
transformer_blocks.3.attn.to_qkv 0.765270% 5.232332 6.8372
transformer_blocks.4.attn.to_added_qkv 0.701074% 9.473522 13.5129
transformer_blocks.4.attn.to_qkv 0.703233% 3.634037 5.1676
transformer_blocks.5.attn.to_added_qkv 0.720226% 8.557453 11.8816
transformer_blocks.5.attn.to_qkv 0.709118% 3.895474 5.4934
transformer_blocks.6.attn.to_added_qkv 0.702485% 7.409676 10.5478
transformer_blocks.6.attn.to_qkv 0.732748% 4.318213 5.8932
transformer_blocks.7.attn.to_added_qkv 0.738822% 8.467062 11.4602
transformer_blocks.7.attn.to_qkv 0.665311% 3.437334 5.1665

Aggregate over 16 modules:
Mean relative error: 0.711071%
Max relative error: 0.765270%
Min relative error: 0.664022%

</details>

@sayakpaul
Copy link
Copy Markdown
Member

Thanks @CalamitousFelicitousness. Did you want me to take another look at the latest changes?

@HuggingFaceDocBuilderDev
Copy link
Copy Markdown

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

@CalamitousFelicitousness
Copy link
Copy Markdown
Contributor Author

Thanks @CalamitousFelicitousness. Did you want me to take another look at the latest changes?

Yes, please. I'll update the main description with the commands for more benchmarking, as I added some more tests to try and chase the problem down.

I'm starting to lean towards this potentially being something in PEFT. Loss in PEFT path is higher than mine in both my completely lossless and QKV lossy paths, especially that SVD touches all 144 modules (rather than just 16 in QKV), but on the face of it seems unlikely it would be enough to cause the fidelity loss.

What is of confirmed concern is the 7+ minute conversion time on A6000 GPU.

@sayakpaul
Copy link
Copy Markdown
Member

Yes the conversion time is indeed a thing which can be sped up by providing compilation_kwargs (prefer using fullgraph=True, dynamic=True, mode="max-autotune-no-cudagraphs").

I will run the benchmark first to eliminate any pending issues and then review. I am hoping that the benchmarking script is ready to fire (I think that the bare minimum that was there when I ran it yesterday is sufficient).

@sayakpaul
Copy link
Copy Markdown
Member

@bot /style

@github-actions
Copy link
Copy Markdown
Contributor

github-actions bot commented Mar 27, 2026

Style bot fixed some files and pushed the changes.

@BenjaminBossan
Copy link
Copy Markdown
Member

I'm starting to lean towards this potentially being something in PEFT. Loss in PEFT path is higher than mine in both my completely lossless and QKV lossy paths, especially that SVD touches all 144 modules (rather than just 16 in QKV), but on the face of it seems unlikely it would be enough to cause the fidelity loss.

The PEFT approach is very brute force, but I don't know any other way that would work with the breadth of methods we support in PEFT. That said, if for a specific method like LoKr, there is a more precise implementation, I'd lean towards using that here (and PEFT for everything else). The main disadvantage is probably the maintenance burden compared to relying on PEFT for the weight conversion.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

black-forest-labs/FLUX.2-klein-9B not working with lora with lokr

5 participants