diff --git a/benchmark_lokr.py b/benchmark_lokr.py new file mode 100644 index 000000000000..a8fd6d052758 --- /dev/null +++ b/benchmark_lokr.py @@ -0,0 +1,464 @@ +"""Benchmark: Three-tier LoKR quality comparison on Flux2 Klein 9B. + +Tier 1 - Fuse-first (lossless): Fuse model QKV, map BFL LoKR directly. Exact. +Tier 2 - Kronecker split (default): Split fused QKV via Van Loan re-factorization. Slight loss. +Tier 3 - SVD to LoRA (fully lossy): Convert entire LoKR to LoRA via peft.convert_to_lora. + +Tiers 1+2 only apply to BFL-format LoKR (fused QKV). LyCORIS and diffusers-native +formats already have separate Q/K/V and only run the default path. + +Uses bf16 with CPU offload. + +Usage: + python benchmark_lokr.py + python benchmark_lokr.py --lokr-path "puttmorbidly233/lora" --lokr-name "klein_snofs_v1_2.safetensors" + python benchmark_lokr.py --prompt "a portrait in besch art style" --ranks 32 64 128 + python benchmark_lokr.py --tiers 1 2 # skip SVD tier + python benchmark_lokr.py --tiers 2 3 # skip fuse-first tier + python benchmark_lokr.py --weight-space # weight-space error analysis only (no image generation) +""" + +import argparse +import gc +import os +import time + +import torch + +from diffusers import Flux2KleinPipeline + + +MODEL_ID = "black-forest-labs/FLUX.2-klein-9B" +DEFAULT_LOKR_PATH = "gattaplayer/besch-flux2-klein-9b-lokr-lion-3e-6-bs2-ga2-v02" +OUTPUT_DIR = "benchmark_output" + + +def load_pipeline(no_offload=False): + """Load Flux2 Klein 9B in bf16.""" + pipe = Flux2KleinPipeline.from_pretrained(MODEL_ID, torch_dtype=torch.bfloat16) + if no_offload: + pipe = pipe.to("cuda") + else: + pipe.enable_model_cpu_offload() + return pipe + + +def generate(pipe, prompt, seed, num_steps=4, guidance_scale=1.0): + """Generate a single image with fixed seed for reproducibility.""" + generator = torch.Generator(device="cpu").manual_seed(seed) + image = pipe( + prompt=prompt, + num_inference_steps=num_steps, + guidance_scale=guidance_scale, + generator=generator, + height=1024, + width=1024, + ).images[0] + return image + + +# --------------------------------------------------------------------------- +# Weight-space error analysis +# --------------------------------------------------------------------------- + + +def load_raw_state_dict(lokr_path, lokr_name): + """Download/load a LoKR checkpoint and return the raw state dict.""" + from huggingface_hub import hf_hub_download + from safetensors.torch import load_file + + if os.path.isfile(lokr_path): + return load_file(lokr_path) + + if os.path.isdir(lokr_path): + path = os.path.join(lokr_path, lokr_name) if lokr_name else lokr_path + return load_file(path) + + # HF repo + path = hf_hub_download(lokr_path, filename=lokr_name or "pytorch_lora_weights.safetensors") + return load_file(path) + + +def _materialize_lokr_delta(state_dict, module_path): + """Materialize the full delta weight from LoKR factors for a single module.""" + w1_key = f"{module_path}.lokr_w1" + w2_key = f"{module_path}.lokr_w2" + w1a_key = f"{module_path}.lokr_w1_a" + w1b_key = f"{module_path}.lokr_w1_b" + w2a_key = f"{module_path}.lokr_w2_a" + w2b_key = f"{module_path}.lokr_w2_b" + + # w1: full or decomposed + if w1_key in state_dict: + w1 = state_dict[w1_key].float() + elif w1a_key in state_dict and w1b_key in state_dict: + w1 = state_dict[w1a_key].float() @ state_dict[w1b_key].float() + else: + return None + + # w2: full or decomposed + if w2_key in state_dict: + w2 = state_dict[w2_key].float() + elif w2a_key in state_dict and w2b_key in state_dict: + w2 = state_dict[w2a_key].float() @ state_dict[w2b_key].float() + else: + return None + + return torch.kron(w1, w2) + + +def _print_error_table(title, results): + """Print a formatted error table and aggregate stats.""" + print(f"\n {title}\n") + print(f" {'Module':<60} {'Rel Error %':>12} {'Abs Error':>12} {'Orig Norm':>12}") + print(f" {'-' * 60} {'-' * 12} {'-' * 12} {'-' * 12}") + + errors = [] + for name, rel_err, abs_err, orig_norm in results: + errors.append(rel_err) + print(f" {name:<60} {rel_err:>11.6f}% {abs_err:>12.6f} {orig_norm:>12.4f}") + + if errors: + print(f"\n Aggregate over {len(errors)} modules:") + print(f" Mean relative error: {sum(errors) / len(errors):.6f}%") + print(f" Max relative error: {max(errors):.6f}%") + print(f" Min relative error: {min(errors):.6f}%") + + +def weight_space_kronecker(lokr_path, lokr_name): + """Compare tier 1 (lossless) vs tier 2 (Kronecker split) in weight space. + + No model loading needed - operates on checkpoint state dicts only. + Only meaningful for BFL-format LoKR (fused QKV). + """ + from diffusers.loaders.lora_conversion_utils import _convert_non_diffusers_flux2_lokr_to_diffusers + + raw_sd = load_raw_state_dict(lokr_path, lokr_name) + + is_bfl = any(k.startswith("diffusion_model.") for k in raw_sd) + if not is_bfl: + print(" Checkpoint is not BFL format - no fused QKV to compare.") + print(" Tiers 1 and 2 produce identical results for this format.") + return + + sd_fused = _convert_non_diffusers_flux2_lokr_to_diffusers(dict(raw_sd), fuse_qkv=True) + sd_split = _convert_non_diffusers_flux2_lokr_to_diffusers(dict(raw_sd), fuse_qkv=False) + + # Find all fused QKV modules + qkv_modules = [] + for key in sd_fused: + if ".to_qkv.lokr_w1" in key or ".to_added_qkv.lokr_w1" in key: + qkv_modules.append(key.rsplit(".lokr_w1", 1)[0]) + + print(f"\n Found {len(qkv_modules)} fused QKV modules to compare") + + results = [] + for module_path in sorted(qkv_modules): + delta_exact = _materialize_lokr_delta(sd_fused, module_path) + if delta_exact is None: + continue + + # Determine split target keys + if ".to_qkv" in module_path: + base = module_path.replace(".attn.to_qkv", "") + proj_keys = [f"{base}.attn.to_q", f"{base}.attn.to_k", f"{base}.attn.to_v"] + else: + base = module_path.replace(".attn.to_added_qkv", "") + proj_keys = [f"{base}.attn.add_q_proj", f"{base}.attn.add_k_proj", f"{base}.attn.add_v_proj"] + + chunks = [] + for proj in proj_keys: + delta = _materialize_lokr_delta(sd_split, proj) + if delta is None: + break + chunks.append(delta) + + if len(chunks) != 3: + continue + + delta_recon = torch.cat(chunks, dim=0) + orig_norm = delta_exact.norm().item() + abs_err = (delta_exact - delta_recon).norm().item() + rel_err = abs_err / orig_norm if orig_norm > 0 else 0.0 + + short_name = module_path.replace("transformer.", "") + results.append((short_name, rel_err, abs_err, orig_norm)) + + _print_error_table("Tier 1 (lossless) vs Tier 2 (Kronecker split) - QKV modules only", results) + + +def weight_space_svd(lokr_path, lokr_name, ranks, no_offload=False): + """Compare tier 1 (lossless) vs tier 3 (SVD to LoRA) in weight space. + + Requires loading the full model to run peft.convert_to_lora. + Compares materialized LoKR deltas against LoRA deltas for ALL modules. + """ + from peft import convert_to_lora + + # Build reference deltas from the converted state dict (tier 2 / default path) + # For non-QKV modules tier 2 is identical to tier 1, so this is ground truth. + raw_sd = load_raw_state_dict(lokr_path, lokr_name) + is_bfl = any(k.startswith("diffusion_model.") for k in raw_sd) + + if is_bfl: + from diffusers.loaders.lora_conversion_utils import _convert_non_diffusers_flux2_lokr_to_diffusers + + sd_ref = _convert_non_diffusers_flux2_lokr_to_diffusers(dict(raw_sd), fuse_qkv=False) + else: + # For non-BFL, just use the default conversion as reference (already lossless) + from diffusers.loaders.lora_conversion_utils import ( + _convert_diffusers_flux2_lokr_to_peft, + _convert_lycoris_flux2_lokr_to_diffusers, + ) + + if any(k.startswith("lycoris_") for k in raw_sd): + sd_ref = _convert_lycoris_flux2_lokr_to_diffusers(dict(raw_sd)) + else: + sd_ref = _convert_diffusers_flux2_lokr_to_peft(dict(raw_sd)) + + # Find all LoKR modules and materialize their deltas + ref_deltas = {} + lokr_modules = set() + for key in sd_ref: + if ".lokr_w1" in key and ".lokr_w1_" not in key: + module_path = key.rsplit(".lokr_w1", 1)[0] + lokr_modules.add(module_path) + elif ".lokr_w1_a" in key: + module_path = key.rsplit(".lokr_w1_a", 1)[0] + lokr_modules.add(module_path) + + for module_path in lokr_modules: + delta = _materialize_lokr_delta(sd_ref, module_path) + if delta is not None: + ref_deltas[module_path] = delta + + print(f"\n Materialized {len(ref_deltas)} reference LoKR deltas") + + # Load model and LoKR adapter + print("\n Loading model for SVD conversion...") + pipe = load_pipeline(no_offload=no_offload) + kwargs = {"weight_name": lokr_name} if lokr_name else {} + pipe.load_lora_weights(lokr_path, **kwargs) + adapter_name = next(iter(pipe.transformer.peft_config.keys())) + + for rank in ranks: + print(f"\n Converting to LoRA rank={rank}...") + t0 = time.time() + lora_config, lora_sd = convert_to_lora(pipe.transformer, rank, adapter_name=adapter_name, progressbar=True) + print(f" Converted in {time.time() - t0:.1f}s") + print(f" LoRA config: alpha={lora_config.lora_alpha}, r={lora_config.r}") + + # Also print the LoKR config for reference + lokr_cfg = pipe.transformer.peft_config.get(adapter_name) + if lokr_cfg: + alpha = getattr(lokr_cfg, "alpha", getattr(lokr_cfg, "lora_alpha", "?")) + print(f" Adapter config: {type(lokr_cfg).__name__}, alpha={alpha}, r={lokr_cfg.r}") + + # Compare each module: LoKR delta vs LoRA delta (lora_B @ lora_A) + results = [] + for module_path in sorted(ref_deltas.keys()): + delta_ref = ref_deltas[module_path] + + # Map module_path to LoRA key format: transformer.X.Y -> base_model.model.X.Y + lora_module = module_path.replace("transformer.", "") + lora_a_key = f"base_model.model.{lora_module}.lora_A.weight" + lora_b_key = f"base_model.model.{lora_module}.lora_B.weight" + + if lora_a_key not in lora_sd or lora_b_key not in lora_sd: + # Try without base_model.model prefix + lora_a_key = f"{lora_module}.lora_A.weight" + lora_b_key = f"{lora_module}.lora_B.weight" + + if lora_a_key not in lora_sd or lora_b_key not in lora_sd: + continue + + lora_a = lora_sd[lora_a_key].float().cpu() + lora_b = lora_sd[lora_b_key].float().cpu() + delta_lora = lora_b @ lora_a + + orig_norm = delta_ref.norm().item() + abs_err = (delta_ref.cpu() - delta_lora).norm().item() + rel_err = abs_err / orig_norm if orig_norm > 0 else 0.0 + + short_name = module_path.replace("transformer.", "") + results.append((short_name, rel_err, abs_err, orig_norm)) + + _print_error_table(f"Tier 1 (lossless) vs Tier 3 (SVD rank={rank}) - all modules", results) + + pipe.unload_lora_weights() + del pipe + gc.collect() + torch.cuda.empty_cache() + + +# --------------------------------------------------------------------------- +# Image generation benchmarks +# --------------------------------------------------------------------------- + + +def benchmark_baseline(pipe, prompt, seed): + """Baseline: No adapter.""" + print("\n=== Baseline: No adapter ===") + t0 = time.time() + image = generate(pipe, prompt, seed) + print(f" Generated in {time.time() - t0:.1f}s") + return image + + +def benchmark_tier1_fuse_first(pipe, prompt, seed, lokr_path, lokr_name): + """Tier 1: Fuse model QKV, then load BFL LoKR directly (lossless).""" + print("\n=== Tier 1: Fuse-first LoKR (lossless) ===") + t0 = time.time() + kwargs = {"weight_name": lokr_name} if lokr_name else {} + pipe.load_lora_weights(lokr_path, fuse_qkv=True, **kwargs) + print(f" Loaded in {time.time() - t0:.1f}s") + + t0 = time.time() + image = generate(pipe, prompt, seed) + print(f" Generated in {time.time() - t0:.1f}s") + + pipe.unload_lora_weights() + return image + + +def benchmark_tier2_kronecker_split(pipe, prompt, seed, lokr_path, lokr_name): + """Tier 2: Split fused QKV via Kronecker re-factorization (default path).""" + print("\n=== Tier 2: Kronecker split LoKR (default) ===") + t0 = time.time() + kwargs = {"weight_name": lokr_name} if lokr_name else {} + pipe.load_lora_weights(lokr_path, **kwargs) + print(f" Loaded in {time.time() - t0:.1f}s") + + t0 = time.time() + image = generate(pipe, prompt, seed) + print(f" Generated in {time.time() - t0:.1f}s") + + pipe.unload_lora_weights() + return image + + +def benchmark_tier3_svd(pipe, prompt, seed, rank, lokr_path, lokr_name): + """Tier 3: Convert LoKR to LoRA via SVD (fully lossy).""" + from peft import convert_to_lora, inject_adapter_in_model, set_peft_model_state_dict + + print(f"\n=== Tier 3: SVD to LoRA (rank={rank}) ===") + t0 = time.time() + kwargs = {"weight_name": lokr_name} if lokr_name else {} + pipe.load_lora_weights(lokr_path, **kwargs) + load_time = time.time() - t0 + + adapter_name = next(iter(pipe.transformer.peft_config.keys())) + print(f" Adapter name: {adapter_name}") + + t0 = time.time() + lokr_cfg = pipe.transformer.peft_config.get(adapter_name) + if lokr_cfg: + alpha = getattr(lokr_cfg, "alpha", getattr(lokr_cfg, "lora_alpha", "?")) + print(f" Adapter config: {type(lokr_cfg).__name__}, alpha={alpha}, r={lokr_cfg.r}") + + t0 = time.time() + lora_config, lora_sd = convert_to_lora(pipe.transformer, rank, adapter_name=adapter_name, progressbar=True) + convert_time = time.time() - t0 + print(f" Loaded LoKR in {load_time:.1f}s, converted to LoRA in {convert_time:.1f}s") + print(f" LoRA config: alpha={lora_config.lora_alpha}, r={lora_config.r}") + + pipe.transformer.delete_adapters(adapter_name) + inject_adapter_in_model(lora_config, pipe.transformer, adapter_name=adapter_name) + set_peft_model_state_dict(pipe.transformer, lora_sd, adapter_name=adapter_name) + + t0 = time.time() + image = generate(pipe, prompt, seed) + print(f" Generated in {time.time() - t0:.1f}s") + + pipe.unload_lora_weights() + return image + + +# --------------------------------------------------------------------------- +# Main +# --------------------------------------------------------------------------- + + +def main(): + parser = argparse.ArgumentParser(description="Benchmark LoKR quality tiers") + parser.add_argument("--prompt", default="a portrait painting in besch art style") + parser.add_argument("--lokr-path", default=DEFAULT_LOKR_PATH, help="HF repo or local path to LoKR checkpoint") + parser.add_argument("--lokr-name", default=None, help="Filename within HF repo (if multi-file)") + parser.add_argument("--seed", type=int, default=42) + parser.add_argument( + "--tiers", type=int, nargs="+", default=[1, 2, 3], help="Tiers to run (1=fuse, 2=kronecker, 3=svd)" + ) + parser.add_argument("--ranks", type=int, nargs="+", default=[32, 64, 128], help="SVD ranks for tier 3") + parser.add_argument("--skip-baseline", action="store_true") + parser.add_argument("--no-offload", action="store_true", help="Keep model on GPU instead of CPU offload") + parser.add_argument("--weight-space", action="store_true", help="Run weight-space error analysis only (no images)") + args = parser.parse_args() + + os.makedirs(OUTPUT_DIR, exist_ok=True) + + print(f"Model: {MODEL_ID}") + print(f"LoKR: {args.lokr_path}" + (f" ({args.lokr_name})" if args.lokr_name else "")) + + # Weight-space analysis + if args.weight_space: + print("\n=== Weight-space error: Tier 1 (lossless) vs Tier 2 (Kronecker split) ===") + weight_space_kronecker(args.lokr_path, args.lokr_name) + + if args.ranks: + print("\n=== Weight-space error: Tier 1 (lossless) vs Tier 3 (SVD to LoRA) ===") + weight_space_svd(args.lokr_path, args.lokr_name, args.ranks, no_offload=args.no_offload) + + return + + print(f"Prompt: {args.prompt}") + print(f"Seed: {args.seed}") + print(f"Tiers: {args.tiers}") + if 3 in args.tiers: + print(f"SVD ranks: {args.ranks}") + + mode = "on GPU" if args.no_offload else "with CPU offload" + print(f"\nLoading pipeline (bf16, {mode})...") + pipe = load_pipeline(no_offload=args.no_offload) + + # Baseline + if not args.skip_baseline: + img = benchmark_baseline(pipe, args.prompt, args.seed) + path = os.path.join(OUTPUT_DIR, "baseline.png") + img.save(path) + print(f" Saved: {path}") + + # Tier 1: Fuse-first (lossless, BFL format only - identical to tier 2 for other formats) + if 1 in args.tiers: + print("\n Note: Tier 1 only differs from tier 2 for BFL-format LoKR (fused QKV).") + img = benchmark_tier1_fuse_first(pipe, args.prompt, args.seed, args.lokr_path, args.lokr_name) + path = os.path.join(OUTPUT_DIR, "tier1_fuse_lossless.png") + img.save(path) + print(f" Saved: {path}") + gc.collect() + torch.cuda.empty_cache() + + # Tier 2: Kronecker split (default) + if 2 in args.tiers: + img = benchmark_tier2_kronecker_split(pipe, args.prompt, args.seed, args.lokr_path, args.lokr_name) + path = os.path.join(OUTPUT_DIR, "tier2_kronecker.png") + img.save(path) + print(f" Saved: {path}") + gc.collect() + torch.cuda.empty_cache() + + # Tier 3: SVD to LoRA at various ranks + if 3 in args.tiers: + for rank in args.ranks: + img = benchmark_tier3_svd(pipe, args.prompt, args.seed, rank, args.lokr_path, args.lokr_name) + path = os.path.join(OUTPUT_DIR, f"tier3_svd_rank{rank}.png") + img.save(path) + print(f" Saved: {path}") + gc.collect() + torch.cuda.empty_cache() + + print(f"\nAll results saved to {OUTPUT_DIR}/") + print("Compare: baseline.png vs tier1_fuse_lossless.png vs tier2_kronecker.png vs tier3_svd_rank*.png") + + +if __name__ == "__main__": + main() diff --git a/src/diffusers/loaders/lora_conversion_utils.py b/src/diffusers/loaders/lora_conversion_utils.py index 41948d205c89..4d38e241c423 100644 --- a/src/diffusers/loaders/lora_conversion_utils.py +++ b/src/diffusers/loaders/lora_conversion_utils.py @@ -2331,6 +2331,18 @@ def _convert_non_diffusers_flux2_lora_to_diffusers(state_dict): temp_state_dict[new_key] = v original_state_dict = temp_state_dict + # Bake alpha/rank scaling into lora_A weights so .alpha keys are consumed. + # Matches the pattern used by _convert_kohya_flux_lora_to_diffusers for Flux1. + alpha_keys = [k for k in original_state_dict if k.endswith(".alpha")] + for alpha_key in alpha_keys: + alpha = original_state_dict.pop(alpha_key).item() + module_path = alpha_key[: -len(".alpha")] + lora_a_key = f"{module_path}.lora_A.weight" + if lora_a_key in original_state_dict: + rank = original_state_dict[lora_a_key].shape[0] + scale = alpha / rank + original_state_dict[lora_a_key] = original_state_dict[lora_a_key] * scale + num_double_layers = 0 num_single_layers = 0 for key in original_state_dict.keys(): @@ -2628,6 +2640,319 @@ def _convert_to_ai_toolkit_cat(sds_sd, ait_sd, sds_key, ait_keys, dims=None): return ait_sd +def _nearest_kronecker_product(matrix, m1, n1, m2, n2): + """Find the nearest rank-1 Kronecker product approximation (Van Loan & Pitsianis). + + Given matrix M of shape (m1*m2, n1*n2), finds w1 (m1, n1) and w2 (m2, n2) minimizing ||M - kron(w1, w2)||_F via + rank-1 SVD of a rearranged matrix. + """ + # Rearrange M into R of shape (m1*n1, m2*n2) + # R[i*n1+j, k*n2+l] = M[i*m2+k, j*n2+l] + R = matrix.reshape(m1, m2, n1, n2).permute(0, 2, 1, 3).reshape(m1 * n1, m2 * n2) + # Rank-1 SVD + U, S, Vh = torch.linalg.svd(R, full_matrices=False) + sigma = S[0] + sqrt_s = torch.sqrt(sigma) + w1 = sqrt_s * U[:, 0].reshape(m1, n1) + w2 = sqrt_s * Vh[0].reshape(m2, n2) + return w1, w2 + + +def _split_lokr_qkv(w1, w2, target_keys, factor): + """Split fused LoKR QKV factors into separate per-projection Kronecker factors. + + Materializes kron(w1, w2), chunks along dim=0, and re-factorizes each chunk as a rank-1 Kronecker product using the + Van Loan algorithm. + + Args: + w1: First Kronecker factor, shape (f, f) where f = decompose_factor. + w2: Second Kronecker factor, shape (out_total/f, in_total/f). + target_keys: List of target projection names (e.g., ["to_q", "to_k", "to_v"]). + factor: Kronecker decompose factor for the split chunks. + + Returns: + Dict mapping "{target_key}.lokr_w1" and "{target_key}.lokr_w2" to tensors. + """ + full_delta = torch.kron(w1.float(), w2.float()) + chunks = torch.chunk(full_delta, len(target_keys), dim=0) + + result = {} + for target_key, chunk in zip(target_keys, chunks): + rows, cols = chunk.shape + m1 = n1 = factor + m2 = rows // m1 + n2 = cols // n1 + new_w1, new_w2 = _nearest_kronecker_product(chunk, m1, n1, m2, n2) + result[f"{target_key}.lokr_w1"] = new_w1.to(w1.dtype) + result[f"{target_key}.lokr_w2"] = new_w2.to(w2.dtype) + return result + + +def _convert_non_diffusers_flux2_lokr_to_diffusers(state_dict, fuse_qkv=False): + """Convert BFL-format Flux2 LoKR state dict to peft-compatible diffusers format. + + Args: + state_dict: BFL-format LoKR state dict with ``diffusion_model.`` prefix. + fuse_qkv: If True, map fused QKV directly to ``to_qkv``/``to_added_qkv`` targets + (lossless, but requires the model's QKV to be fused before injection). If False (default), split fused QKV + into separate Q/K/V via Kronecker re-factorization (slightly lossy, no model fusion needed). + """ + converted_state_dict = {} + + prefix = "diffusion_model." + original_state_dict = {k[len(prefix) :] if k.startswith(prefix) else k: v for k, v in state_dict.items()} + + num_double_layers = 0 + num_single_layers = 0 + for key in original_state_dict: + if key.startswith("single_blocks."): + num_single_layers = max(num_single_layers, int(key.split(".")[1]) + 1) + elif key.startswith("double_blocks."): + num_double_layers = max(num_double_layers, int(key.split(".")[1]) + 1) + + lokr_suffixes = ("lokr_w1", "lokr_w1_a", "lokr_w1_b", "lokr_w2", "lokr_w2_a", "lokr_w2_b", "lokr_t2") + + def _pop_alpha_and_bake(bfl_path, w1_weight): + """Pop alpha for a module and bake scaling into w1. Returns scaled w1.""" + alpha_key = f"{bfl_path}.alpha" + if alpha_key not in original_state_dict: + return w1_weight + alpha = original_state_dict.pop(alpha_key).item() + w2a_key = f"{bfl_path}.lokr_w2_a" + w1a_key = f"{bfl_path}.lokr_w1_a" + if w2a_key in original_state_dict: + r_eff = original_state_dict[w2a_key].shape[1] + elif w1a_key in original_state_dict: + r_eff = original_state_dict[w1a_key].shape[1] + else: + r_eff = alpha + return w1_weight * (alpha / r_eff) + + def _remap_lokr_module(bfl_path, diff_path): + """Pop all LoKR keys for a BFL module, bake alpha, and store under diffusers path.""" + # Pop alpha separately (consumed by first w1 tensor) + alpha_key = f"{bfl_path}.alpha" + alpha = original_state_dict.pop(alpha_key).item() if alpha_key in original_state_dict else None + + for suffix in lokr_suffixes: + src_key = f"{bfl_path}.{suffix}" + if src_key not in original_state_dict: + continue + + weight = original_state_dict.pop(src_key) + + # Bake alpha/rank scaling into the first w1 tensor for this module. + if alpha is not None and suffix in ("lokr_w1", "lokr_w1_a"): + w2a_key = f"{bfl_path}.lokr_w2_a" + w1a_key = f"{bfl_path}.lokr_w1_a" + if w2a_key in original_state_dict: + r_eff = original_state_dict[w2a_key].shape[1] + elif w1a_key in original_state_dict: + r_eff = original_state_dict[w1a_key].shape[1] + else: + r_eff = alpha + weight = weight * (alpha / r_eff) + alpha = None + + converted_state_dict[f"{diff_path}.{suffix}"] = weight + + def _remap_lokr_qkv(bfl_path, target_keys): + """Pop fused QKV LoKR factors, split into separate projections via Kronecker re-factorization.""" + w1_key = f"{bfl_path}.lokr_w1" + w2_key = f"{bfl_path}.lokr_w2" + if w1_key not in original_state_dict or w2_key not in original_state_dict: + # Fall back to direct remap if decomposed factors (w1_a/w1_b) are used + _remap_lokr_module(bfl_path, target_keys[0].rsplit(".", 1)[0]) + return + + w1 = original_state_dict.pop(w1_key) + w2 = original_state_dict.pop(w2_key) + + # Bake alpha before splitting + alpha_key = f"{bfl_path}.alpha" + if alpha_key in original_state_dict: + alpha = original_state_dict.pop(alpha_key).item() + w2a_key = f"{bfl_path}.lokr_w2_a" + w1a_key = f"{bfl_path}.lokr_w1_a" + if w2a_key in original_state_dict: + r_eff = original_state_dict[w2a_key].shape[1] + elif w1a_key in original_state_dict: + r_eff = original_state_dict[w1a_key].shape[1] + else: + r_eff = alpha + w1 = w1 * (alpha / r_eff) + + factor = w1.shape[0] + split_result = _split_lokr_qkv(w1, w2, target_keys, factor) + converted_state_dict.update(split_result) + + # --- Single blocks --- + for sl in range(num_single_layers): + _remap_lokr_module(f"single_blocks.{sl}.linear1", f"single_transformer_blocks.{sl}.attn.to_qkv_mlp_proj") + _remap_lokr_module(f"single_blocks.{sl}.linear2", f"single_transformer_blocks.{sl}.attn.to_out") + + # --- Double blocks --- + for dl in range(num_double_layers): + tb = f"transformer_blocks.{dl}" + db = f"double_blocks.{dl}" + + if fuse_qkv: + # Lossless: map directly to fused targets (caller must fuse model QKV first) + _remap_lokr_module(f"{db}.img_attn.qkv", f"{tb}.attn.to_qkv") + _remap_lokr_module(f"{db}.txt_attn.qkv", f"{tb}.attn.to_added_qkv") + else: + # Split fused QKV into separate Q/K/V via Kronecker re-factorization + _remap_lokr_qkv(f"{db}.img_attn.qkv", [f"{tb}.attn.to_q", f"{tb}.attn.to_k", f"{tb}.attn.to_v"]) + _remap_lokr_qkv( + f"{db}.txt_attn.qkv", + [f"{tb}.attn.add_q_proj", f"{tb}.attn.add_k_proj", f"{tb}.attn.add_v_proj"], + ) + + # Projections + _remap_lokr_module(f"{db}.img_attn.proj", f"{tb}.attn.to_out.0") + _remap_lokr_module(f"{db}.txt_attn.proj", f"{tb}.attn.to_add_out") + + # MLPs + _remap_lokr_module(f"{db}.img_mlp.0", f"{tb}.ff.linear_in") + _remap_lokr_module(f"{db}.img_mlp.2", f"{tb}.ff.linear_out") + _remap_lokr_module(f"{db}.txt_mlp.0", f"{tb}.ff_context.linear_in") + _remap_lokr_module(f"{db}.txt_mlp.2", f"{tb}.ff_context.linear_out") + + # --- Extra mappings (embedders, modulation, final layer) --- + extra_mappings = { + "img_in": "x_embedder", + "txt_in": "context_embedder", + "time_in.in_layer": "time_guidance_embed.timestep_embedder.linear_1", + "time_in.out_layer": "time_guidance_embed.timestep_embedder.linear_2", + "final_layer.linear": "proj_out", + "final_layer.adaLN_modulation.1": "norm_out.linear", + "single_stream_modulation.lin": "single_stream_modulation.linear", + "double_stream_modulation_img.lin": "double_stream_modulation_img.linear", + "double_stream_modulation_txt.lin": "double_stream_modulation_txt.linear", + } + for bfl_key, diff_key in extra_mappings.items(): + _remap_lokr_module(bfl_key, diff_key) + + if len(original_state_dict) > 0: + raise ValueError(f"`original_state_dict` should be empty at this point but has {original_state_dict.keys()=}.") + + for key in list(converted_state_dict.keys()): + converted_state_dict[f"transformer.{key}"] = converted_state_dict.pop(key) + + return converted_state_dict + + +# Mapping from LyCORIS underscore-encoded sub-paths to dotted diffusers module paths +_LYCORIS_SUBPATH_MAP = { + "attn_to_q": "attn.to_q", + "attn_to_k": "attn.to_k", + "attn_to_v": "attn.to_v", + "attn_to_out_0": "attn.to_out.0", + "attn_to_add_out": "attn.to_add_out", + "attn_add_q_proj": "attn.add_q_proj", + "attn_add_k_proj": "attn.add_k_proj", + "attn_add_v_proj": "attn.add_v_proj", + "attn_to_qkv_mlp_proj": "attn.to_qkv_mlp_proj", + "attn_to_out": "attn.to_out", + "ff_context_linear_in": "ff_context.linear_in", + "ff_context_linear_out": "ff_context.linear_out", + "ff_linear_in": "ff.linear_in", + "ff_linear_out": "ff.linear_out", +} + + +def _bake_lokr_alpha(state_dict): + """Consume .alpha keys by baking alpha/rank scaling into lokr_w1 weights in-place.""" + lokr_w1_suffixes = (".lokr_w1", ".lokr_w1_a") + alpha_keys = [k for k in state_dict if k.endswith(".alpha")] + + for alpha_key in alpha_keys: + alpha = state_dict.pop(alpha_key).item() + module_path = alpha_key[: -len(".alpha")] + + # Find the w1 tensor to bake into + for w1_suffix in lokr_w1_suffixes: + w1_key = f"{module_path}{w1_suffix}" + if w1_key in state_dict: + # Determine effective rank + w2a_key = f"{module_path}.lokr_w2_a" + w1a_key = f"{module_path}.lokr_w1_a" + if w2a_key in state_dict: + r_eff = state_dict[w2a_key].shape[1] + elif w1a_key in state_dict: + r_eff = state_dict[w1a_key].shape[1] + else: + r_eff = alpha + state_dict[w1_key] = state_dict[w1_key] * (alpha / r_eff) + break + + +def _convert_lycoris_flux2_lokr_to_diffusers(state_dict): + """Convert LyCORIS underscore-format Flux2 LoKR state dict to peft-compatible diffusers format. + + LyCORIS keys use underscore-encoded paths (e.g., lycoris_transformer_blocks_0_attn_to_q.lokr_w1). Decodes these to + dotted diffusers paths using a known sub-path lookup table. + """ + import re + + converted_state_dict = {} + original_state_dict = dict(state_dict) + + _bake_lokr_alpha(original_state_dict) + + lycoris_pattern = re.compile(r"^lycoris_((?:single_)?transformer_blocks)_(\d+)_(.+)$") + + for key in list(original_state_dict.keys()): + # Split key into module_path and lokr suffix + parts = key.rsplit(".", 1) + if len(parts) != 2: + continue + module_encoded, suffix = parts + + match = lycoris_pattern.match(module_encoded) + if not match: + continue + + container, block_idx, sub_path = match.groups() + + # Decode sub-path using lookup table (try longest match first) + diff_sub_path = None + for lycoris_sub, diff_sub in sorted(_LYCORIS_SUBPATH_MAP.items(), key=lambda x: -len(x[0])): + if sub_path == lycoris_sub: + diff_sub_path = diff_sub + break + + if diff_sub_path is None: + continue + + diff_key = f"transformer.{container}.{block_idx}.{diff_sub_path}.{suffix}" + converted_state_dict[diff_key] = original_state_dict.pop(key) + + if len(original_state_dict) > 0: + logger.warning(f"Unconverted LyCORIS LoKR keys: {list(original_state_dict.keys())}") + + return converted_state_dict + + +def _convert_diffusers_flux2_lokr_to_peft(state_dict): + """Convert diffusers-native Flux2 LoKR state dict by adding transformer. prefix and baking alpha. + + Diffusers-native keys already use dotted module paths matching the model structure. Only alpha baking and the + transformer. prefix are needed. + """ + original_state_dict = dict(state_dict) + _bake_lokr_alpha(original_state_dict) + + converted_state_dict = {} + for key, val in original_state_dict.items(): + if key.startswith("transformer."): + converted_state_dict[key] = val + else: + converted_state_dict[f"transformer.{key}"] = val + + return converted_state_dict + + def _convert_non_diffusers_z_image_lora_to_diffusers(state_dict): """ Convert non-diffusers ZImage LoRA state dict to diffusers format. @@ -2785,14 +3110,14 @@ def get_alpha_scales(down_weight, alpha_key): base = k[: -len(lora_dot_down_key)] - # Skip combined "qkv" projection — individual to.q/k/v keys are also present. + # Skip combined "qkv" projection - individual to.q/k/v keys are also present. if base.endswith(".qkv"): state_dict.pop(k) state_dict.pop(k.replace(lora_dot_down_key, lora_dot_up_key), None) state_dict.pop(base + ".alpha", None) continue - # Skip bare "out.lora.*" — "to_out.0.lora.*" covers the same projection. + # Skip bare "out.lora.*" - "to_out.0.lora.*" covers the same projection. if re.search(r"\.out$", base) and ".to_out" not in base: state_dict.pop(k) state_dict.pop(k.replace(lora_dot_down_key, lora_dot_up_key), None) diff --git a/src/diffusers/loaders/lora_pipeline.py b/src/diffusers/loaders/lora_pipeline.py index 6ec23389ac08..383ae88c3c5f 100644 --- a/src/diffusers/loaders/lora_pipeline.py +++ b/src/diffusers/loaders/lora_pipeline.py @@ -41,11 +41,14 @@ ) from .lora_conversion_utils import ( _convert_bfl_flux_control_lora_to_diffusers, + _convert_diffusers_flux2_lokr_to_peft, _convert_fal_kontext_lora_to_diffusers, _convert_hunyuan_video_lora_to_diffusers, _convert_kohya_flux2_lora_to_diffusers, _convert_kohya_flux_lora_to_diffusers, + _convert_lycoris_flux2_lokr_to_diffusers, _convert_musubi_wan_lora_to_diffusers, + _convert_non_diffusers_flux2_lokr_to_diffusers, _convert_non_diffusers_flux2_lora_to_diffusers, _convert_non_diffusers_hidream_lora_to_diffusers, _convert_non_diffusers_lora_to_diffusers, @@ -5645,6 +5648,7 @@ def lora_state_dict( weight_name = kwargs.pop("weight_name", None) use_safetensors = kwargs.pop("use_safetensors", None) return_lora_metadata = kwargs.pop("return_lora_metadata", False) + fuse_qkv = kwargs.pop("fuse_qkv", False) allow_pickle = False if use_safetensors is None: @@ -5685,14 +5689,29 @@ def lora_state_dict( if is_peft_format: state_dict = {k.replace("base_model.model.", "diffusion_model."): v for k, v in state_dict.items()} - is_ai_toolkit = any(k.startswith("diffusion_model.") for k in state_dict) - if is_ai_toolkit: - state_dict = _convert_non_diffusers_flux2_lora_to_diffusers(state_dict) + is_lokr = any("lokr_" in k for k in state_dict) + if is_lokr: + is_bfl_format = any(k.startswith("diffusion_model.") for k in state_dict) + if is_bfl_format: + state_dict = _convert_non_diffusers_flux2_lokr_to_diffusers(state_dict, fuse_qkv=fuse_qkv) + elif any(k.startswith("lycoris_") for k in state_dict): + state_dict = _convert_lycoris_flux2_lokr_to_diffusers(state_dict) + else: + state_dict = _convert_diffusers_flux2_lokr_to_peft(state_dict) + if metadata is None: + metadata = {} + metadata["is_lokr"] = "true" + # Only fuse model QKV for BFL format (which has fused QKV keys to map 1:1) + if fuse_qkv and is_bfl_format: + metadata["fuse_qkv"] = "true" + else: + is_ai_toolkit = any(k.startswith("diffusion_model.") for k in state_dict) + if is_ai_toolkit: + state_dict = _convert_non_diffusers_flux2_lora_to_diffusers(state_dict) out = (state_dict, metadata) if return_lora_metadata else state_dict return out - # Copied from diffusers.loaders.lora_pipeline.CogVideoXLoraLoaderMixin.load_lora_weights def load_lora_weights( self, pretrained_model_name_or_path_or_dict: str | dict[str, torch.Tensor], @@ -5720,13 +5739,19 @@ def load_lora_weights( kwargs["return_lora_metadata"] = True state_dict, metadata = self.lora_state_dict(pretrained_model_name_or_path_or_dict, **kwargs) - is_correct_format = all("lora" in key for key in state_dict.keys()) + is_correct_format = all("lora" in key or "lokr" in key for key in state_dict.keys()) if not is_correct_format: - raise ValueError("Invalid LoRA checkpoint. Make sure all LoRA param names contain `'lora'` substring.") + raise ValueError("Invalid LoRA/LoKR checkpoint. Make sure all param names contain `'lora'` or `'lokr'`.") + + transformer = getattr(self, self.transformer_name) if not hasattr(self, "transformer") else self.transformer + + # Fuse model QKV projections before injection if requested (lossless path for BFL LoKR) + if metadata and metadata.get("fuse_qkv") == "true": + transformer.fuse_qkv_projections() self.load_lora_into_transformer( state_dict, - transformer=getattr(self, self.transformer_name) if not hasattr(self, "transformer") else self.transformer, + transformer=transformer, adapter_name=adapter_name, metadata=metadata, _pipeline=self, diff --git a/src/diffusers/loaders/peft.py b/src/diffusers/loaders/peft.py index daa078bc25d5..7c8fa0a9e854 100644 --- a/src/diffusers/loaders/peft.py +++ b/src/diffusers/loaders/peft.py @@ -38,7 +38,7 @@ set_adapter_layers, set_weights_and_activate_adapters, ) -from ..utils.peft_utils import _create_lora_config, _maybe_warn_for_unhandled_keys +from ..utils.peft_utils import _create_lokr_config, _create_lora_config, _maybe_warn_for_unhandled_keys from .lora_base import _fetch_state_dict, _func_optionally_disable_offloading from .unet_loader_utils import _maybe_expand_lora_scales @@ -46,7 +46,7 @@ logger = logging.get_logger(__name__) _SET_ADAPTER_SCALE_FN_MAPPING = defaultdict( - lambda: (lambda model_cls, weights: weights), + lambda: lambda model_cls, weights: weights, { "UNet2DConditionModel": _maybe_expand_lora_scales, "UNetMotionModel": _maybe_expand_lora_scales, @@ -213,56 +213,65 @@ def load_lora_adapter( "Please choose an existing adapter name or set `hotswap=False` to prevent hotswapping." ) - # check with first key if is not in peft format - first_key = next(iter(state_dict.keys())) - if "lora_A" not in first_key: - state_dict = convert_unet_state_dict_to_peft(state_dict) - - # Control LoRA from SAI is different from BFL Control LoRA - # https://huggingface.co/stabilityai/control-lora - # https://huggingface.co/comfyanonymous/ControlNet-v1-1_fp16_safetensors - is_sai_sd_control_lora = "lora_controlnet" in state_dict - if is_sai_sd_control_lora: - state_dict = convert_sai_sd_control_lora_state_dict_to_peft(state_dict) - - rank = {} - for key, val in state_dict.items(): - # Cannot figure out rank from lora layers that don't have at least 2 dimensions. - # Bias layers in LoRA only have a single dimension - if "lora_B" in key and val.ndim > 1: - # Check out https://github.com/huggingface/peft/pull/2419 for the `^` symbol. - # We may run into some ambiguous configuration values when a model has module - # names, sharing a common prefix (`proj_out.weight` and `blocks.transformer.proj_out.weight`, - # for example) and they have different LoRA ranks. - rank[f"^{key}"] = val.shape[1] - - if network_alphas is not None and len(network_alphas) >= 1: - alpha_keys = [k for k in network_alphas.keys() if k.startswith(f"{prefix}.")] - network_alphas = { - k.removeprefix(f"{prefix}."): v for k, v in network_alphas.items() if k in alpha_keys - } - - # adapter_name - if adapter_name is None: - adapter_name = get_adapter_name(self) - - # create LoraConfig - lora_config = _create_lora_config( - state_dict, - network_alphas, - metadata, - rank, - model_state_dict=self.state_dict(), - adapter_name=adapter_name, - ) + # Detect whether this is a LoKR adapter (Kronecker product, not low-rank) + is_lokr = any("lokr_" in k for k in state_dict) + + if is_lokr: + if adapter_name is None: + adapter_name = get_adapter_name(self) + adapter_config = _create_lokr_config(state_dict) + is_sai_sd_control_lora = False + else: + # check with first key if is not in peft format + first_key = next(iter(state_dict.keys())) + if "lora_A" not in first_key: + state_dict = convert_unet_state_dict_to_peft(state_dict) + + # Control LoRA from SAI is different from BFL Control LoRA + # https://huggingface.co/stabilityai/control-lora + # https://huggingface.co/comfyanonymous/ControlNet-v1-1_fp16_safetensors + is_sai_sd_control_lora = "lora_controlnet" in state_dict + if is_sai_sd_control_lora: + state_dict = convert_sai_sd_control_lora_state_dict_to_peft(state_dict) + + rank = {} + for key, val in state_dict.items(): + # Cannot figure out rank from lora layers that don't have at least 2 dimensions. + # Bias layers in LoRA only have a single dimension + if "lora_B" in key and val.ndim > 1: + # Check out https://github.com/huggingface/peft/pull/2419 for the `^` symbol. + # We may run into some ambiguous configuration values when a model has module + # names, sharing a common prefix (`proj_out.weight` and `blocks.transformer.proj_out.weight`, + # for example) and they have different LoRA ranks. + rank[f"^{key}"] = val.shape[1] + + if network_alphas is not None and len(network_alphas) >= 1: + alpha_keys = [k for k in network_alphas.keys() if k.startswith(f"{prefix}.")] + network_alphas = { + k.removeprefix(f"{prefix}."): v for k, v in network_alphas.items() if k in alpha_keys + } + + # adapter_name + if adapter_name is None: + adapter_name = get_adapter_name(self) + + # create LoraConfig + adapter_config = _create_lora_config( + state_dict, + network_alphas, + metadata, + rank, + model_state_dict=self.state_dict(), + adapter_name=adapter_name, + ) - # Adjust LoRA config for Control LoRA - if is_sai_sd_control_lora: - lora_config.lora_alpha = lora_config.r - lora_config.alpha_pattern = lora_config.rank_pattern - lora_config.bias = "all" - lora_config.modules_to_save = lora_config.exclude_modules - lora_config.exclude_modules = None + # Adjust LoRA config for Control LoRA + if is_sai_sd_control_lora: + adapter_config.lora_alpha = adapter_config.r + adapter_config.alpha_pattern = adapter_config.rank_pattern + adapter_config.bias = "all" + adapter_config.modules_to_save = adapter_config.exclude_modules + adapter_config.exclude_modules = None # None: ) +def _create_lokr_config(state_dict): + """Create a peft LoKrConfig from a converted LoKR state dict. + + Infers rank, decompose_both, decompose_factor, and target_modules from the state dict key names and tensor shapes. + Alpha scaling is assumed to be already baked into the weights, so config alpha = r (scaling = 1.0). + + Peft determines w2 decomposition via ``r < max(out_k, in_n) / 2``. We must set per-module rank values that + reproduce the same decomposition pattern as the checkpoint. For modules with full (non-decomposed) lokr_w2, we set + rank = max(lokr_w2.shape) so that peft also creates a full w2. + """ + from peft import LoKrConfig + + # Infer decompose_both from presence of lokr_w1_a keys + decompose_both = any("lokr_w1_a" in k for k in state_dict) + + # Infer decompose_factor from lokr_w1 shapes. + # With a fixed factor (e.g., 4), all w1 shapes are (factor, factor). + # With factor=-1 (near-sqrt), w1 shapes vary per module based on dimension. + w1_shapes = set() + for key, val in state_dict.items(): + if "lokr_w1" in key and "lokr_w1_a" not in key and "lokr_w1_b" not in key and val.ndim == 2: + w1_shapes.add(val.shape[0]) + if len(w1_shapes) == 1: + # All w1 have the same first dimension - this is the decompose_factor + decompose_factor = w1_shapes.pop() + else: + # Shapes vary - near-sqrt factorization was used + decompose_factor = -1 + + # Extract target modules and their decomposition state + lokr_suffixes = {"lokr_w1", "lokr_w1_a", "lokr_w1_b", "lokr_w2", "lokr_w2_a", "lokr_w2_b", "lokr_t2"} + target_modules = set() + for key in state_dict: + for suffix in lokr_suffixes: + if f".{suffix}" in key: + target_modules.add(key.split(f".{suffix}")[0]) + break + + # Build per-module rank dict that ensures peft creates matching decomposition + rank_dict = {} + for key, val in state_dict.items(): + if "lokr_w2_a" in key and val.ndim > 1: + # Decomposed w2: rank = inner dimension of w2_a + module_name = key.split(".lokr_w2_a")[0] + rank_dict[module_name] = val.shape[1] + elif "lokr_w2" in key and "lokr_w2_a" not in key and "lokr_w2_b" not in key and val.ndim > 1: + # Full w2 matrix: set rank high enough so peft also creates full w2. + # Peft uses full w2 when r >= max(out_k, in_n) / 2, where (out_k, in_n) = lokr_w2.shape. + module_name = key.split(".lokr_w2")[0] + if module_name not in rank_dict: + rank_dict[module_name] = max(val.shape) + + # Also extract rank from w1_a if w2 info is missing + for key, val in state_dict.items(): + if "lokr_w1_a" in key and val.ndim > 1: + module_name = key.split(".lokr_w1_a")[0] + if module_name not in rank_dict: + rank_dict[module_name] = val.shape[1] + + # Determine default rank (most common) and per-module rank pattern + if rank_dict: + r = collections.Counter(rank_dict.values()).most_common()[0][0] + rank_pattern = {k: v for k, v in rank_dict.items() if v != r} + else: + r = 1 + rank_pattern = {} + + lokr_config_kwargs = { + "r": r, + "alpha": r, # alpha baked into weights, so runtime scaling = alpha/r = 1.0 + "target_modules": list(target_modules), + "rank_pattern": rank_pattern, + "alpha_pattern": dict(rank_pattern), # keep alpha=r per module + "decompose_both": decompose_both, + "decompose_factor": decompose_factor, + } + + try: + return LoKrConfig(**lokr_config_kwargs) + except TypeError as e: + raise TypeError("`LoKrConfig` class could not be instantiated.") from e + + +def _convert_adapter_to_lora(model, rank, adapter_name="default"): + """Convert a loaded non-LoRA peft adapter (e.g., LoKR) to LoRA via truncated SVD. + + Wraps ``peft.convert_to_lora`` which materializes each adapter layer's delta weight and decomposes it as ``U @ + diag(S) @ V ≈ lora_B @ lora_A``. The conversion is lossy: higher ``rank`` preserves more fidelity at the cost of + larger LoRA matrices. + + Args: + model: ``nn.Module`` with a peft adapter already injected. + rank: ``int`` for a fixed LoRA rank, or ``float`` in (0, 1] as an energy threshold + (picks the smallest rank capturing that fraction of singular value energy). + adapter_name: Name of the adapter to convert. + + Returns: + Tuple of ``(LoraConfig, state_dict)`` for the converted LoRA adapter. + + Raises: + ImportError: If peft does not provide ``convert_to_lora`` (requires peft >= 0.19.0). + """ + try: + from peft import convert_to_lora + except ImportError: + raise ImportError( + "`peft.convert_to_lora` is required for lossy LoKR-to-LoRA conversion. " + "Install peft >= 0.19.0 or from source: pip install git+https://github.com/huggingface/peft.git" + ) + return convert_to_lora(model, rank, adapter_name=adapter_name) + + def _create_lora_config( state_dict, network_alphas, metadata, rank_pattern_dict, is_unet=True, model_state_dict=None, adapter_name=None ):