Skip to content

Add Flux2 LoKR adapter support prototype with dual conversion paths#13326

Open
CalamitousFelicitousness wants to merge 1 commit into huggingface:main from CalamitousFelicitousness:feature/flux2-klein-lokr
Open

Add Flux2 LoKR adapter support prototype with dual conversion paths#13326
CalamitousFelicitousness wants to merge 1 commit into huggingface:main from CalamitousFelicitousness:feature/flux2-klein-lokr

Conversation

@CalamitousFelicitousness
Copy link
Contributor

@CalamitousFelicitousness CalamitousFelicitousness commented Mar 25, 2026

Adds support for Flux2 LoKR, with a dual conversion path so the two implementations can be benchmarked against each other.

  • Custom lossless path: BFL LoKR keys → peft LoKrConfig (fuse-first QKV)
  • Generic lossy path: optional SVD conversion via peft.convert_to_lora
  • Fix alpha handling for lora_down/lora_up format checkpoints
  • Re-fuse LoRA keys when model QKV is fused from prior LoKR load

Given that Civitai does not have a LoKR category, I didn't feel like digging for an SFW example, so I just used the checkpoint the user brought up when they reported the issue.

Benchmark test
"""Benchmark: Lossless LoKR vs Lossy LoRA-via-SVD on Flux2 Klein 9B.

Generates images using both conversion paths for visual comparison.
Uses bf16 with model CPU offload to fit on a single 24GB GPU.

Usage:
    python benchmark_lokr.py
    python benchmark_lokr.py --prompt "a cat in a garden" --ranks 32 64 128
"""

import argparse
import gc
import os
import time

import torch

# Use local model cache. Set before the diffusers/peft imports (presumably so
# the hub cache location is picked up at import time — TODO confirm).
# NOTE(review): hardcoded user-specific path; other machines must edit this
# (or export HF_HUB_CACHE) before running the benchmark.
os.environ["HF_HUB_CACHE"] = "/home/ohiom/database/models/huggingface"

from diffusers import Flux2KleinPipeline  # noqa: E402
from peft import convert_to_lora  # noqa: E402

# Benchmark inputs/outputs.
# NOTE(review): LOKR_PATH and OUTPUT_DIR are hardcoded local paths — adjust
# per machine before running.
MODEL_ID = "black-forest-labs/FLUX.2-klein-9B"
LOKR_PATH = "/home/ohiom/database/models/Lora/Flux.2 Klein 9B/klein_snofs_v1_1.safetensors"
OUTPUT_DIR = "/home/ohiom/diffusers/benchmark_output"


def load_pipeline():
    """Construct the Flux2 Klein 9B pipeline in bf16 with model CPU offload.

    Returns:
        The ready-to-use ``Flux2KleinPipeline`` instance.
    """
    pipeline = Flux2KleinPipeline.from_pretrained(
        MODEL_ID,
        torch_dtype=torch.bfloat16,
    )
    pipeline.enable_model_cpu_offload()
    return pipeline


def generate(pipe, prompt, seed, num_steps=4, guidance_scale=1.0):
    """Run one seeded 1024x1024 generation and return the resulting image.

    A CPU-side ``torch.Generator`` is seeded explicitly so repeated calls with
    the same arguments are reproducible across benchmark paths.
    """
    rng = torch.Generator(device="cpu").manual_seed(seed)
    result = pipe(
        prompt=prompt,
        num_inference_steps=num_steps,
        guidance_scale=guidance_scale,
        generator=rng,
        height=1024,
        width=1024,
    )
    return result.images[0]


def benchmark_lossless(pipe, prompt, seed, num_steps=4):
    """Path A: load the LoKR checkpoint natively (lossless) and generate.

    Args:
        pipe: Loaded Flux2 pipeline.
        prompt: Text prompt to render.
        seed: RNG seed for reproducibility.
        num_steps: Inference step count forwarded to ``generate`` (defaults to
            4, matching the previous hard-wired behavior).

    Returns:
        The generated image. The adapter is unloaded before returning so the
        pipeline is left clean for subsequent benchmark paths.
    """
    print("\n=== Path A: Lossless LoKR ===")
    t0 = time.time()
    pipe.load_lora_weights(LOKR_PATH)
    print(f"  Loaded in {time.time() - t0:.1f}s")

    t0 = time.time()
    image = generate(pipe, prompt, seed, num_steps=num_steps)
    print(f"  Generated in {time.time() - t0:.1f}s")

    pipe.unload_lora_weights()
    return image


def benchmark_lossy(pipe, prompt, seed, rank, num_steps=4):
    """Path B: load LoKR, convert it to a rank-``rank`` LoRA via SVD (lossy).

    Args:
        pipe: Loaded Flux2 pipeline.
        prompt: Text prompt to render.
        seed: RNG seed for reproducibility.
        rank: Target LoRA rank for the SVD approximation.
        num_steps: Inference step count forwarded to ``generate`` (defaults to
            4, matching the previous hard-wired behavior).

    Returns:
        The generated image. The converted adapter is unloaded before
        returning.
    """
    print(f"\n=== Path B: Lossy LoRA via SVD (rank={rank}) ===")
    t0 = time.time()
    pipe.load_lora_weights(LOKR_PATH)
    load_time = time.time() - t0

    # Detect the adapter name peft assigned on load (first registered adapter).
    adapter_name = next(iter(pipe.transformer.peft_config))
    print(f"  Adapter name: {adapter_name}")

    # Conversion runs on GPU. NOTE(review): this moves the full transformer to
    # CUDA despite the pipeline using model CPU offload — confirm it fits.
    pipe.transformer.to("cuda")
    t0 = time.time()
    lora_config, lora_sd = convert_to_lora(pipe.transformer, rank, adapter_name=adapter_name, progressbar=True)
    convert_time = time.time() - t0
    print(f"  Loaded LoKR in {load_time:.1f}s, converted to LoRA in {convert_time:.1f}s")

    # Replace the LoKR adapter with the SVD-converted LoRA under the same name.
    from peft import inject_adapter_in_model, set_peft_model_state_dict

    pipe.transformer.delete_adapter(adapter_name)
    inject_adapter_in_model(pipe.transformer, lora_config, adapter_name=adapter_name)
    set_peft_model_state_dict(pipe.transformer, lora_sd, adapter_name=adapter_name)

    t0 = time.time()
    image = generate(pipe, prompt, seed, num_steps=num_steps)
    print(f"  Generated in {time.time() - t0:.1f}s")

    pipe.unload_lora_weights()
    return image


def benchmark_baseline(pipe, prompt, seed, num_steps=4):
    """Baseline: generate with no adapter loaded.

    Args:
        pipe: Loaded Flux2 pipeline (must have no adapter active).
        prompt: Text prompt to render.
        seed: RNG seed for reproducibility.
        num_steps: Inference step count forwarded to ``generate`` (defaults to
            4, matching the previous hard-wired behavior and keeping the
            signature parallel with the other benchmark helpers).

    Returns:
        The generated image.
    """
    print("\n=== Baseline: No adapter ===")
    t0 = time.time()
    image = generate(pipe, prompt, seed, num_steps=num_steps)
    print(f"  Generated in {time.time() - t0:.1f}s")
    return image


def main():
    """CLI entry point: run baseline, lossless-LoKR, and lossy-SVD paths.

    Saves one PNG per path into ``OUTPUT_DIR`` for side-by-side visual
    comparison.
    """
    parser = argparse.ArgumentParser(description="Benchmark LoKR vs LoRA-via-SVD")
    parser.add_argument(
        "--prompt",
        # Neutral default; the previous default contained content unsuitable
        # for a committed benchmark script.
        default="A watercolor painting of a red fox resting in a snowy forest",
    )
    parser.add_argument("--seed", type=int, default=42)
    parser.add_argument("--ranks", type=int, nargs="+", default=[32, 64, 128])
    # NOTE(review): --steps is parsed but never forwarded to the benchmark
    # helpers, so generate()'s default step count is what actually runs.
    # TODO: wire args.steps through once the helpers accept it.
    parser.add_argument("--steps", type=int, default=28)
    parser.add_argument("--skip-baseline", action="store_true")
    args = parser.parse_args()

    os.makedirs(OUTPUT_DIR, exist_ok=True)

    print(f"Model: {MODEL_ID}")
    print(f"LoKR:  {LOKR_PATH}")
    print(f"Prompt: {args.prompt}")
    print(f"Seed: {args.seed}")
    print(f"SVD ranks to test: {args.ranks}")

    print("\nLoading pipeline (bf16, model CPU offload)...")
    pipe = load_pipeline()

    # Baseline: no adapter.
    if not args.skip_baseline:
        img = benchmark_baseline(pipe, args.prompt, args.seed)
        path = os.path.join(OUTPUT_DIR, "baseline.png")
        img.save(path)
        print(f"  Saved: {path}")

    # Path A: lossless native LoKR load.
    img = benchmark_lossless(pipe, args.prompt, args.seed)
    path = os.path.join(OUTPUT_DIR, "lokr_lossless.png")
    img.save(path)
    print(f"  Saved: {path}")

    # Free GPU memory between paths.
    gc.collect()
    torch.cuda.empty_cache()

    # Path B: lossy LoRA-via-SVD at each requested rank.
    for rank in args.ranks:
        img = benchmark_lossy(pipe, args.prompt, args.seed, rank)
        path = os.path.join(OUTPUT_DIR, f"lora_svd_rank{rank}.png")
        img.save(path)
        print(f"  Saved: {path}")

        gc.collect()
        torch.cuda.empty_cache()

    print(f"\nAll results saved to {OUTPUT_DIR}/")
    print("Compare: baseline.png vs lokr_lossless.png vs lora_svd_rank*.png")

I'm not sure I'm doing it correctly for PEFT, but due to the lack of quantization support I couldn't run the lossy path for now.

What does this PR do?

Fixes #13261

Before submitting

Who can review?

Anyone in the community is free to review the PR once the tests have passed. Feel free to tag
members/contributors who may be interested in your PR.

@sayakpaul @BenjaminBossan

- Custom lossless path: BFL LoKR keys → peft LoKrConfig (fuse-first QKV)
- Generic lossy path: optional SVD conversion via peft.convert_to_lora
- Fix alpha handling for lora_down/lora_up format checkpoints
- Re-fuse LoRA keys when model QKV is fused from prior LoKR load
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

black-forest-labs/FLUX.2-klein-9B not working with lora with lokr

1 participant