Add Flux2 LoKR adapter support prototype with dual conversion paths #13326
CalamitousFelicitousness wants to merge 11 commits into huggingface:main
Conversation
|
PR looks good AFAICT, but of course I'm no Diffusers expert. @CalamitousFelicitousness do you have results from the test script that you can share? Also, for other readers: The LoRA conversion currently requires installing PEFT from the main branch (in the future: version 0.19). |
As I mentioned in the message, I can't currently run the SVD tests; they OOM on my RTX 3090, and my 6000 Ada is not available at the moment. For now I only know that the programmatic tests pass. |
|
I will help with the lossy path results. |
sayakpaul
left a comment
There was a problem hiding this comment.
Thanks a lot for the clean PR! Left a couple of comments. LMK if anything is unclear.
|
@claude please review this PR as well. |
|
Also @CalamitousFelicitousness you might want to change the default prompt in the benchmarking script. That is highly NSFW. Let's be mindful of that. |
|
If someone can provide me with a link to an SFW LoKR I can use, I can update it; the prompt was taken verbatim from the creator's examples. Using a prompt unrelated to the aim of the adapter is not ideal. |
|
Then we will have to wait for one that is SFW. Respectfully, we cannot base our work on the grounds of NSFW content. @chaowenguo can you provide a LoKR checkpoint that doesn't use any form of nudity? |
|
I ran the benchmarking script anyway. The following diff was needed mostly to resolve the conflicts and to fix the benchmarking script:

diff --git a/benchmark_lokr_13326.py b/benchmark_lokr_13326.py
index 7fb53cced..9abbda788 100644
--- a/benchmark_lokr_13326.py
+++ b/benchmark_lokr_13326.py
@@ -19,7 +19,7 @@ from diffusers import Flux2KleinPipeline # noqa: E402
from peft import convert_to_lora # noqa: E402
MODEL_ID = "black-forest-labs/FLUX.2-klein-9B"
-LOKR_PATH = "chaowenguo/lora"
+LOKR_PATH = "puttmorbidly233/lora"
OUTPUT_DIR = "benchmark_output"
@@ -48,7 +48,7 @@ def benchmark_lossless(pipe, prompt, seed):
"""Path A: Load LoKR natively (lossless)."""
print("\n=== Path A: Lossless LoKR ===")
t0 = time.time()
- pipe.load_lora_weights(LOKR_PATH, weight_name="klein_snofs_v1_1.safetensors")
+ pipe.load_lora_weights(LOKR_PATH, weight_name="klein_snofs_v1_2.safetensors")
print(f" Loaded in {time.time() - t0:.1f}s")
t0 = time.time()
@@ -63,7 +63,7 @@ def benchmark_lossy(pipe, prompt, seed, rank):
"""Path B: Load LoKR, convert to LoRA via SVD (lossy)."""
print(f"\n=== Path B: Lossy LoRA via SVD (rank={rank}) ===")
t0 = time.time()
- pipe.load_lora_weights(LOKR_PATH, weight_name="klein_snofs_v1_1.safetensors")
+ pipe.load_lora_weights(LOKR_PATH, weight_name="klein_snofs_v1_2.safetensors")
load_time = time.time() - t0
# Detect the actual adapter name assigned by peft
@@ -79,8 +79,8 @@ def benchmark_lossy(pipe, prompt, seed, rank):
# Replace LoKR adapter with converted LoRA
from peft import inject_adapter_in_model, set_peft_model_state_dict
- pipe.transformer.delete_adapter(adapter_name)
- inject_adapter_in_model(pipe.transformer, lora_config, adapter_name=adapter_name)
+ pipe.transformer.delete_adapters(adapter_name)
+ inject_adapter_in_model(lora_config, pipe.transformer, adapter_name=adapter_name)
set_peft_model_state_dict(pipe.transformer, lora_sd, adapter_name=adapter_name)
t0 = time.time()
@@ -104,11 +104,10 @@ def main():
parser = argparse.ArgumentParser(description="Benchmark LoKR vs LoRA-via-SVD")
parser.add_argument(
"--prompt",
- default="A high-angle POV photograph shows a nude white woman with blonde hair",
+ default="A high-angle POV photograph shows a polar bear.",
)
parser.add_argument("--seed", type=int, default=42)
parser.add_argument("--ranks", type=int, nargs="+", default=[32, 64, 128])
- parser.add_argument("--steps", type=int, default=28)
parser.add_argument("--skip-baseline", action="store_true")
args = parser.parse_args()
diff --git a/src/diffusers/loaders/lora_conversion_utils.py b/src/diffusers/loaders/lora_conversion_utils.py
index df58af0a1..6f33b8219 100644
--- a/src/diffusers/loaders/lora_conversion_utils.py
+++ b/src/diffusers/loaders/lora_conversion_utils.py
@@ -2455,7 +2455,6 @@ def _convert_non_diffusers_flux2_lora_to_diffusers(state_dict):
return converted_state_dict
-<<<<<<< HEAD
def _convert_non_diffusers_flux2_lokr_to_diffusers(state_dict):
"""Convert non-diffusers Flux2 LoKR state dict (kohya/LyCORIS format) to peft-compatible diffusers format.
@@ -2600,191 +2599,6 @@ def _refuse_flux2_lora_state_dict(state_dict):
# Pass through all non-QKV keys unchanged
converted.update(remaining)
return converted
-=======
-def _convert_kohya_flux2_lora_to_diffusers(state_dict):
- def _convert_to_ai_toolkit(sds_sd, ait_sd, sds_key, ait_key):
- if sds_key + ".lora_down.weight" not in sds_sd:
- return
- down_weight = sds_sd.pop(sds_key + ".lora_down.weight")
-
- # scale weight by alpha and dim
- rank = down_weight.shape[0]
- default_alpha = torch.tensor(rank, dtype=down_weight.dtype, device=down_weight.device, requires_grad=False)
- alpha = sds_sd.pop(sds_key + ".alpha", default_alpha).item()
- scale = alpha / rank
-
- scale_down = scale
- scale_up = 1.0
- while scale_down * 2 < scale_up:
- scale_down *= 2
- scale_up /= 2
-
- ait_sd[ait_key + ".lora_A.weight"] = down_weight * scale_down
- ait_sd[ait_key + ".lora_B.weight"] = sds_sd.pop(sds_key + ".lora_up.weight") * scale_up
-
- def _convert_to_ai_toolkit_cat(sds_sd, ait_sd, sds_key, ait_keys, dims=None):
- if sds_key + ".lora_down.weight" not in sds_sd:
- return
- down_weight = sds_sd.pop(sds_key + ".lora_down.weight")
- up_weight = sds_sd.pop(sds_key + ".lora_up.weight")
- sd_lora_rank = down_weight.shape[0]
-
- default_alpha = torch.tensor(
- sd_lora_rank, dtype=down_weight.dtype, device=down_weight.device, requires_grad=False
- )
- alpha = sds_sd.pop(sds_key + ".alpha", default_alpha)
- scale = alpha / sd_lora_rank
-
- scale_down = scale
- scale_up = 1.0
- while scale_down * 2 < scale_up:
- scale_down *= 2
- scale_up /= 2
-
- down_weight = down_weight * scale_down
- up_weight = up_weight * scale_up
-
- num_splits = len(ait_keys)
- if dims is None:
- dims = [up_weight.shape[0] // num_splits] * num_splits
- else:
- assert sum(dims) == up_weight.shape[0]
-
- # check if upweight is sparse
- is_sparse = False
- if sd_lora_rank % num_splits == 0:
- ait_rank = sd_lora_rank // num_splits
- is_sparse = True
- i = 0
- for j in range(len(dims)):
- for k in range(len(dims)):
- if j == k:
- continue
- is_sparse = is_sparse and torch.all(
- up_weight[i : i + dims[j], k * ait_rank : (k + 1) * ait_rank] == 0
- )
- i += dims[j]
- if is_sparse:
- logger.info(f"weight is sparse: {sds_key}")
-
- ait_down_keys = [k + ".lora_A.weight" for k in ait_keys]
- ait_up_keys = [k + ".lora_B.weight" for k in ait_keys]
- if not is_sparse:
- ait_sd.update(dict.fromkeys(ait_down_keys, down_weight))
- ait_sd.update({k: v for k, v in zip(ait_up_keys, torch.split(up_weight, dims, dim=0))}) # noqa: C416
- else:
- ait_sd.update({k: v for k, v in zip(ait_down_keys, torch.chunk(down_weight, num_splits, dim=0))}) # noqa: C416
- i = 0
- for j in range(len(dims)):
- ait_sd[ait_up_keys[j]] = up_weight[i : i + dims[j], j * ait_rank : (j + 1) * ait_rank].contiguous()
- i += dims[j]
-
- # Detect number of blocks from keys
- num_double_layers = 0
- num_single_layers = 0
- for key in state_dict.keys():
- if key.startswith("lora_unet_double_blocks_"):
- block_idx = int(key.split("_")[4])
- num_double_layers = max(num_double_layers, block_idx + 1)
- elif key.startswith("lora_unet_single_blocks_"):
- block_idx = int(key.split("_")[4])
- num_single_layers = max(num_single_layers, block_idx + 1)
-
- ait_sd = {}
-
- for i in range(num_double_layers):
- # Attention projections
- _convert_to_ai_toolkit(
- state_dict,
- ait_sd,
- f"lora_unet_double_blocks_{i}_img_attn_proj",
- f"transformer.transformer_blocks.{i}.attn.to_out.0",
- )
- _convert_to_ai_toolkit_cat(
- state_dict,
- ait_sd,
- f"lora_unet_double_blocks_{i}_img_attn_qkv",
- [
- f"transformer.transformer_blocks.{i}.attn.to_q",
- f"transformer.transformer_blocks.{i}.attn.to_k",
- f"transformer.transformer_blocks.{i}.attn.to_v",
- ],
- )
- _convert_to_ai_toolkit(
- state_dict,
- ait_sd,
- f"lora_unet_double_blocks_{i}_txt_attn_proj",
- f"transformer.transformer_blocks.{i}.attn.to_add_out",
- )
- _convert_to_ai_toolkit_cat(
- state_dict,
- ait_sd,
- f"lora_unet_double_blocks_{i}_txt_attn_qkv",
- [
- f"transformer.transformer_blocks.{i}.attn.add_q_proj",
- f"transformer.transformer_blocks.{i}.attn.add_k_proj",
- f"transformer.transformer_blocks.{i}.attn.add_v_proj",
- ],
- )
- # MLP layers (Flux2 uses ff.linear_in/linear_out)
- _convert_to_ai_toolkit(
- state_dict,
- ait_sd,
- f"lora_unet_double_blocks_{i}_img_mlp_0",
- f"transformer.transformer_blocks.{i}.ff.linear_in",
- )
- _convert_to_ai_toolkit(
- state_dict,
- ait_sd,
- f"lora_unet_double_blocks_{i}_img_mlp_2",
- f"transformer.transformer_blocks.{i}.ff.linear_out",
- )
- _convert_to_ai_toolkit(
- state_dict,
- ait_sd,
- f"lora_unet_double_blocks_{i}_txt_mlp_0",
- f"transformer.transformer_blocks.{i}.ff_context.linear_in",
- )
- _convert_to_ai_toolkit(
- state_dict,
- ait_sd,
- f"lora_unet_double_blocks_{i}_txt_mlp_2",
- f"transformer.transformer_blocks.{i}.ff_context.linear_out",
- )
-
- for i in range(num_single_layers):
- # Single blocks: linear1 -> attn.to_qkv_mlp_proj (fused, no split needed)
- _convert_to_ai_toolkit(
- state_dict,
- ait_sd,
- f"lora_unet_single_blocks_{i}_linear1",
- f"transformer.single_transformer_blocks.{i}.attn.to_qkv_mlp_proj",
- )
- # Single blocks: linear2 -> attn.to_out
- _convert_to_ai_toolkit(
- state_dict,
- ait_sd,
- f"lora_unet_single_blocks_{i}_linear2",
- f"transformer.single_transformer_blocks.{i}.attn.to_out",
- )
-
- # Handle optional extra keys
- extra_mappings = {
- "lora_unet_img_in": "transformer.x_embedder",
- "lora_unet_txt_in": "transformer.context_embedder",
- "lora_unet_time_in_in_layer": "transformer.time_guidance_embed.timestep_embedder.linear_1",
- "lora_unet_time_in_out_layer": "transformer.time_guidance_embed.timestep_embedder.linear_2",
- "lora_unet_final_layer_linear": "transformer.proj_out",
- }
- for sds_key, ait_key in extra_mappings.items():
- _convert_to_ai_toolkit(state_dict, ait_sd, sds_key, ait_key)
-
- remaining_keys = list(state_dict.keys())
- if remaining_keys:
- logger.warning(f"Unsupported keys for Kohya Flux2 LoRA conversion: {remaining_keys}")
-
- return ait_sd
->>>>>>> 153fcbc5a (fix klein lora loading. (#13313))
def _convert_non_diffusers_z_image_lora_to_diffusers(state_dict):
diff --git a/src/diffusers/loaders/lora_pipeline.py b/src/diffusers/loaders/lora_pipeline.py
index 6c8bba726..fea38ddc3 100644
--- a/src/diffusers/loaders/lora_pipeline.py
+++ b/src/diffusers/loaders/lora_pipeline.py
@@ -43,7 +43,7 @@ from .lora_conversion_utils import (
_convert_bfl_flux_control_lora_to_diffusers,
_convert_fal_kontext_lora_to_diffusers,
_convert_hunyuan_video_lora_to_diffusers,
- _convert_kohya_flux2_lora_to_diffusers,
+ # _convert_kohya_flux2_lora_to_diffusers,
_convert_kohya_flux_lora_to_diffusers,
_convert_musubi_wan_lora_to_diffusers,
_convert_non_diffusers_flux2_lokr_to_diffusers,
diff --git a/updates.patch b/updates.patch
index 3c31af601..523f11b9d 100644
--- a/updates.patch
+++ b/updates.patch
@@ -1,24 +0,0 @@
-diff --git a/tests/models/testing_utils/quantization.py b/tests/models/testing_utils/quantization.py
-index 20e6e09fd..62fc2f81e 100644
---- a/tests/models/testing_utils/quantization.py
-+++ b/tests/models/testing_utils/quantization.py
-@@ -175,6 +175,11 @@ class QuantizationTesterMixin:
- model_quantized.to(torch_device)
-
- inputs = self.get_dummy_inputs()
-+ model_dtype = next(model_quantized.parameters()).dtype
-+ inputs = {
-+ k: v.to(dtype=model_dtype) if torch.is_tensor(v) and torch.is_floating_point(v) else v
-+ for k, v in inputs.items()
-+ }
- output = model_quantized(**inputs, return_dict=False)[0]
-
- assert output is not None, "Model output is None"
-@@ -928,6 +933,7 @@ class TorchAoTesterMixin(TorchAoConfigMixin, QuantizationTesterMixin):
- """Test that device_map='auto' works correctly with quantization."""
- self._test_quantization_device_map(TorchAoConfigMixin.TORCHAO_QUANT_TYPES["int8wo"])
-
-+ @pytest.mark.xfail(reason="dequantize is not implemented in torchao")
- def test_torchao_dequantize(self):
- """Test that dequantize() works correctly."""
- self._test_dequantize(TorchAoConfigMixin.TORCHAO_QUANT_TYPES["int8wo"])

I know the results don't mean much, but here's a comparative plot:
Additionally, here's the log: https://pastebin.com/HUV2GUjc |
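One detail worth calling out from the conversion code in the diff above: the effective scale `alpha / rank` is split between the down and up weights with a power-of-two balancing loop, so that neither factor becomes extreme while the product stays exact. A standalone sketch of that logic (the helper name `split_scale` is mine, not from the PR):

```python
def split_scale(alpha: float, rank: int) -> tuple[float, float]:
    """Split scale = alpha / rank into two factors, one applied to the
    down weight and one to the up weight, balanced in powers of two
    (mirrors the loop in the conversion diff above)."""
    scale = alpha / rank
    scale_down, scale_up = scale, 1.0
    while scale_down * 2 < scale_up:
        scale_down *= 2
        scale_up /= 2
    return scale_down, scale_up

# Because only powers of two move between the factors, the product is
# exactly alpha / rank, so the merged delta (up @ down) * scale is unchanged.
print(split_scale(16, 64))  # -> (0.5, 0.5)
```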
|
This LoKR is for a giantess effect and appears to be SFW. |
I think the bear that lost its head would beg to differ.
Depending on the LoRA, the head might have been given back. Which is not just an off-topic joke, since that genuinely could be the LoRA working correctly. |
- Custom lossless path: BFL LoKR keys → peft LoKrConfig (fuse-first QKV)
- Generic lossy path: optional SVD conversion via peft.convert_to_lora
- Fix alpha handling for lora_down/lora_up format checkpoints
- Re-fuse LoRA keys when model QKV is fused from prior LoKR load
- BFL format: remap keys + split fused QKV via Kronecker re-factorization (Van Loan)
- LyCORIS format: decode underscore-encoded paths to diffusers module names
- Diffusers native format: add transformer. prefix and bake alpha
- Generic lossy path: _convert_adapter_to_lora utility wrapping peft.convert_to_lora
- Fix alpha handling for lora_down/lora_up format checkpoints
…SVD)
- Add fuse_qkv parameter to BFL LoKR converter for lossless fuse-first path
- Thread fuse_qkv through lora_pipeline.py (lora_state_dict -> load_lora_weights)
- Fuse model QKV projections before adapter injection when fuse_qkv=True
- Update benchmark script with --tiers, --no-offload flags for all three paths
fuse_qkv=True with LyCORIS or diffusers-native checkpoints would fuse the model QKV and then fail injection (the adapter targets separate Q/K/V modules that no longer exist). Now the fuse_qkv metadata is only set for the BFL format.
Compares the materialized kron(w1, w2) from the fuse-first path against the reconstructed cat(kron(w1_q, w2_q), ...) from the Van Loan split. Reports per-module and aggregate relative Frobenius norm error. No model loading needed: it runs on the checkpoint state dict only.
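For readers unfamiliar with the Van Loan split referenced above: rearranging the blocks of kron(B, C) produces a rank-one matrix whose leading singular vectors recover the two factors. A hypothetical standalone illustration (not the PR's implementation; names are mine):

```python
import numpy as np

def nearest_kron_factors(A, shape_B, shape_C):
    """Best kron(B, C) approximation to A in Frobenius norm (Van Loan).

    Rearranging A's (m2 x n2) blocks into rows yields a matrix R whose
    best rank-1 approximation corresponds to vec(B) * vec(C)^T.
    """
    (m1, n1), (m2, n2) = shape_B, shape_C
    # R[i*n1 + j, p*n2 + q] = A[i*m2 + p, j*n2 + q] = B[i, j] * C[p, q]
    R = A.reshape(m1, m2, n1, n2).transpose(0, 2, 1, 3).reshape(m1 * n1, m2 * n2)
    U, s, Vt = np.linalg.svd(R, full_matrices=False)
    B = (np.sqrt(s[0]) * U[:, 0]).reshape(m1, n1)
    C = (np.sqrt(s[0]) * Vt[0]).reshape(m2, n2)
    return B, C

# An exact Kronecker product is recovered (up to a sign/scale ambiguity
# between the factors that cancels in the product):
rng = np.random.default_rng(0)
B0, C0 = rng.standard_normal((2, 3)), rng.standard_normal((4, 5))
A = np.kron(B0, C0)
B, C = nearest_kron_factors(A, (2, 3), (4, 5))
assert np.allclose(np.kron(B, C), A)
```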
So, the SVD conversions either don't work properly due to a bug in my code, or the loss is big enough that it nullifies a lot of the effects the adapter should have. |
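One plausible explanation for large loss even at rank 128: the rank of kron(w1, w2) is rank(w1) * rank(w2), so a LoKR delta can easily be near full rank, and a truncated-SVD LoRA must then discard genuine signal. A quick numpy sketch of the per-module relative Frobenius error (illustrative shapes, not the actual checkpoint):

```python
import numpy as np

def svd_lora_rel_error(delta, rank):
    """Relative Frobenius error of the best rank-`rank` approximation."""
    U, s, Vt = np.linalg.svd(delta, full_matrices=False)
    approx = (U[:, :rank] * s[:rank]) @ Vt[:rank]
    return np.linalg.norm(delta - approx) / np.linalg.norm(delta)

rng = np.random.default_rng(0)
# A LoKR-style delta: kron of an 8x8 and a 16x16 factor -> 128x128,
# with rank up to 8 * 16 = 128, i.e. potentially full rank.
delta = np.kron(rng.standard_normal((8, 8)), rng.standard_normal((16, 16)))
for r in (32, 64, 128):
    print(r, svd_lora_rel_error(delta, r))
# The error only vanishes once the LoRA rank reaches rank(w1) * rank(w2).
```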
Seems to be just a LoRA. |
Tier 1 vs 2 (Kronecker): lightweight, no model needed. Tier 1 vs 3 (SVD): loads the model, runs peft.convert_to_lora, and compares materialized LoKR deltas against LoRA deltas for all modules at each requested rank.
- Rename lora_config to adapter_config in load_lora_adapter (peft.py) - Remove redundant import collections (already at module level in peft_utils.py)
Force-pushed c41b7b5 to b5958e6 (Compare)
|
So far it seems that the lossy path decimates the impact of the LoRA; the remnants are there and are clear, but even at rank 128 it does not produce the expected outcome. I added additional tests to measure the impact. Fused vs non-fused QKV results are as follows; the rest is ongoing:

Lossless vs QKV Kronecker split (rest of modules lossless):

=== Weight-space error: Tier 1 (lossless) vs Tier 2 (Kronecker split) ===
Found 16 fused QKV modules to compare

Tier 1 (lossless) vs Tier 2 (Kronecker split) - QKV modules only
Module                                  Rel Error %   Abs Error   Orig Norm
transformer_blocks.0.attn.to_added_qkv  0.699293%     8.666989    12.3939
Aggregate over 16 modules: |
|
Thanks @CalamitousFelicitousness. Did you want me to take another look at the latest changes? |
|
The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update. |
Yes, please. I'll update the main description with the commands for more benchmarking, as I added some more tests to try to chase the problem down. I'm starting to lean towards this potentially being something in PEFT. The loss in the PEFT path is higher than in both my completely lossless and QKV-lossy paths, especially since SVD touches all 144 modules (rather than just the 16 QKV ones), but on the face of it, that seems unlikely to be enough to cause the fidelity loss. What is of confirmed concern is the 7+ minute conversion time on an A6000 GPU. |
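On the 7+ minute conversion time: if the bottleneck is a full SVD per module, a randomized SVD (Halko-style range finder) needs only a few passes over each weight at the requested rank and could be substantially faster. This is a hedged numpy sketch of the general technique, not what PEFT currently does; the function name is mine:

```python
import numpy as np

def randomized_svd(A, rank, oversample=8, seed=0):
    """Approximate truncated SVD via a randomized range finder.

    Cost scales with rank + oversample rather than min(m, n), which
    matters when converting many large modules.
    """
    rng = np.random.default_rng(seed)
    # Sample the range of A with a thin random matrix, then orthonormalize.
    Y = A @ rng.standard_normal((A.shape[1], rank + oversample))
    Q, _ = np.linalg.qr(Y)
    # SVD of the small projected matrix, then lift back.
    Ub, s, Vt = np.linalg.svd(Q.T @ A, full_matrices=False)
    return (Q @ Ub)[:, :rank], s[:rank], Vt[:rank]

# Sanity check on an exactly low-rank matrix:
rng = np.random.default_rng(1)
A = rng.standard_normal((256, 40)) @ rng.standard_normal((40, 256))  # rank 40
U, s, Vt = randomized_svd(A, rank=40)
assert np.allclose((U * s) @ Vt, A)
```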
|
Yes, the conversion time is indeed a thing, which can be sped up by providing … I will run the benchmark first to eliminate any pending issues and then review. I am hoping that the benchmarking script is ready to fire (I think that the bare minimum that was there when I ran it yesterday is sufficient). |
|
@bot /style |
|
Style bot fixed some files and pushed the changes. |
The PEFT approach is very brute force, but I don't know any other way that would work with the breadth of methods we support in PEFT. That said, if for a specific method like LoKr, there is a more precise implementation, I'd lean towards using that here (and PEFT for everything else). The main disadvantage is probably the maintenance burden compared to relying on PEFT for the weight conversion. |





Adds support for Flux2 LoKR, with a dual path to benchmark both implementations.
Given that Civitai does not have a LoKR category, I didn't feel like digging for an SFW one, so I just used what the user brought to me when they reported the issue.
Benchmark test
Not sure if I'm doing it correctly for PEFT, but due to the lack of quantization support I couldn't run the lossy path for now.
What does this PR do?
Fixes #13261
Who can review?
Anyone in the community is free to review the PR once the tests have passed. Feel free to tag
members/contributors who may be interested in your PR.
@sayakpaul @BenjaminBossan