diff --git a/src/diffusers/loaders/lora_conversion_utils.py b/src/diffusers/loaders/lora_conversion_utils.py
index 298aa61d37ed..41948d205c89 100644
--- a/src/diffusers/loaders/lora_conversion_utils.py
+++ b/src/diffusers/loaders/lora_conversion_utils.py
@@ -2443,6 +2443,191 @@ def _convert_non_diffusers_flux2_lora_to_diffusers(state_dict):
     return converted_state_dict
 
 
+def _convert_kohya_flux2_lora_to_diffusers(state_dict):
+    def _convert_to_ai_toolkit(sds_sd, ait_sd, sds_key, ait_key):
+        if sds_key + ".lora_down.weight" not in sds_sd:
+            return
+        down_weight = sds_sd.pop(sds_key + ".lora_down.weight")
+
+        # scale weight by alpha and dim
+        rank = down_weight.shape[0]
+        default_alpha = torch.tensor(rank, dtype=down_weight.dtype, device=down_weight.device, requires_grad=False)
+        alpha = sds_sd.pop(sds_key + ".alpha", default_alpha).item()
+        scale = alpha / rank
+
+        scale_down = scale
+        scale_up = 1.0
+        while scale_down * 2 < scale_up:
+            scale_down *= 2
+            scale_up /= 2
+
+        ait_sd[ait_key + ".lora_A.weight"] = down_weight * scale_down
+        ait_sd[ait_key + ".lora_B.weight"] = sds_sd.pop(sds_key + ".lora_up.weight") * scale_up
+
+    def _convert_to_ai_toolkit_cat(sds_sd, ait_sd, sds_key, ait_keys, dims=None):
+        if sds_key + ".lora_down.weight" not in sds_sd:
+            return
+        down_weight = sds_sd.pop(sds_key + ".lora_down.weight")
+        up_weight = sds_sd.pop(sds_key + ".lora_up.weight")
+        sd_lora_rank = down_weight.shape[0]
+
+        default_alpha = torch.tensor(
+            sd_lora_rank, dtype=down_weight.dtype, device=down_weight.device, requires_grad=False
+        )
+        alpha = sds_sd.pop(sds_key + ".alpha", default_alpha)
+        scale = alpha / sd_lora_rank
+
+        scale_down = scale
+        scale_up = 1.0
+        while scale_down * 2 < scale_up:
+            scale_down *= 2
+            scale_up /= 2
+
+        down_weight = down_weight * scale_down
+        up_weight = up_weight * scale_up
+
+        num_splits = len(ait_keys)
+        if dims is None:
+            dims = [up_weight.shape[0] // num_splits] * num_splits
+        else:
+            assert sum(dims) == up_weight.shape[0]
+
+        # check if upweight is sparse
+        is_sparse = False
+        if sd_lora_rank % num_splits == 0:
+            ait_rank = sd_lora_rank // num_splits
+            is_sparse = True
+            i = 0
+            for j in range(len(dims)):
+                for k in range(len(dims)):
+                    if j == k:
+                        continue
+                    is_sparse = is_sparse and torch.all(
+                        up_weight[i : i + dims[j], k * ait_rank : (k + 1) * ait_rank] == 0
+                    )
+                i += dims[j]
+            if is_sparse:
+                logger.info(f"weight is sparse: {sds_key}")
+
+        ait_down_keys = [k + ".lora_A.weight" for k in ait_keys]
+        ait_up_keys = [k + ".lora_B.weight" for k in ait_keys]
+        if not is_sparse:
+            ait_sd.update(dict.fromkeys(ait_down_keys, down_weight))
+            ait_sd.update({k: v for k, v in zip(ait_up_keys, torch.split(up_weight, dims, dim=0))})  # noqa: C416
+        else:
+            ait_sd.update({k: v for k, v in zip(ait_down_keys, torch.chunk(down_weight, num_splits, dim=0))})  # noqa: C416
+            i = 0
+            for j in range(len(dims)):
+                ait_sd[ait_up_keys[j]] = up_weight[i : i + dims[j], j * ait_rank : (j + 1) * ait_rank].contiguous()
+                i += dims[j]
+
+    # Detect number of blocks from keys
+    num_double_layers = 0
+    num_single_layers = 0
+    for key in state_dict.keys():
+        if key.startswith("lora_unet_double_blocks_"):
+            block_idx = int(key.split("_")[4])
+            num_double_layers = max(num_double_layers, block_idx + 1)
+        elif key.startswith("lora_unet_single_blocks_"):
+            block_idx = int(key.split("_")[4])
+            num_single_layers = max(num_single_layers, block_idx + 1)
+
+    ait_sd = {}
+
+    for i in range(num_double_layers):
+        # Attention projections
+        _convert_to_ai_toolkit(
+            state_dict,
+            ait_sd,
+            f"lora_unet_double_blocks_{i}_img_attn_proj",
+            f"transformer.transformer_blocks.{i}.attn.to_out.0",
+        )
+        _convert_to_ai_toolkit_cat(
+            state_dict,
+            ait_sd,
+            f"lora_unet_double_blocks_{i}_img_attn_qkv",
+            [
+                f"transformer.transformer_blocks.{i}.attn.to_q",
+                f"transformer.transformer_blocks.{i}.attn.to_k",
+                f"transformer.transformer_blocks.{i}.attn.to_v",
+            ],
+        )
+        _convert_to_ai_toolkit(
+            state_dict,
+            ait_sd,
+            f"lora_unet_double_blocks_{i}_txt_attn_proj",
+            f"transformer.transformer_blocks.{i}.attn.to_add_out",
+        )
+        _convert_to_ai_toolkit_cat(
+            state_dict,
+            ait_sd,
+            f"lora_unet_double_blocks_{i}_txt_attn_qkv",
+            [
+                f"transformer.transformer_blocks.{i}.attn.add_q_proj",
+                f"transformer.transformer_blocks.{i}.attn.add_k_proj",
+                f"transformer.transformer_blocks.{i}.attn.add_v_proj",
+            ],
+        )
+        # MLP layers (Flux2 uses ff.linear_in/linear_out)
+        _convert_to_ai_toolkit(
+            state_dict,
+            ait_sd,
+            f"lora_unet_double_blocks_{i}_img_mlp_0",
+            f"transformer.transformer_blocks.{i}.ff.linear_in",
+        )
+        _convert_to_ai_toolkit(
+            state_dict,
+            ait_sd,
+            f"lora_unet_double_blocks_{i}_img_mlp_2",
+            f"transformer.transformer_blocks.{i}.ff.linear_out",
+        )
+        _convert_to_ai_toolkit(
+            state_dict,
+            ait_sd,
+            f"lora_unet_double_blocks_{i}_txt_mlp_0",
+            f"transformer.transformer_blocks.{i}.ff_context.linear_in",
+        )
+        _convert_to_ai_toolkit(
+            state_dict,
+            ait_sd,
+            f"lora_unet_double_blocks_{i}_txt_mlp_2",
+            f"transformer.transformer_blocks.{i}.ff_context.linear_out",
+        )
+
+    for i in range(num_single_layers):
+        # Single blocks: linear1 -> attn.to_qkv_mlp_proj (fused, no split needed)
+        _convert_to_ai_toolkit(
+            state_dict,
+            ait_sd,
+            f"lora_unet_single_blocks_{i}_linear1",
+            f"transformer.single_transformer_blocks.{i}.attn.to_qkv_mlp_proj",
+        )
+        # Single blocks: linear2 -> attn.to_out
+        _convert_to_ai_toolkit(
+            state_dict,
+            ait_sd,
+            f"lora_unet_single_blocks_{i}_linear2",
+            f"transformer.single_transformer_blocks.{i}.attn.to_out",
+        )
+
+    # Handle optional extra keys
+    extra_mappings = {
+        "lora_unet_img_in": "transformer.x_embedder",
+        "lora_unet_txt_in": "transformer.context_embedder",
+        "lora_unet_time_in_in_layer": "transformer.time_guidance_embed.timestep_embedder.linear_1",
+        "lora_unet_time_in_out_layer": "transformer.time_guidance_embed.timestep_embedder.linear_2",
+        "lora_unet_final_layer_linear": "transformer.proj_out",
+    }
+    for sds_key, ait_key in extra_mappings.items():
+        _convert_to_ai_toolkit(state_dict, ait_sd, sds_key, ait_key)
+
+    remaining_keys = list(state_dict.keys())
+    if remaining_keys:
+        logger.warning(f"Unsupported keys for Kohya Flux2 LoRA conversion: {remaining_keys}")
+
+    return ait_sd
+
+
 def _convert_non_diffusers_z_image_lora_to_diffusers(state_dict):
     """
     Convert non-diffusers ZImage LoRA state dict to diffusers format.
diff --git a/src/diffusers/loaders/lora_pipeline.py b/src/diffusers/loaders/lora_pipeline.py
index 5d10f596f2e6..6ec23389ac08 100644
--- a/src/diffusers/loaders/lora_pipeline.py
+++ b/src/diffusers/loaders/lora_pipeline.py
@@ -43,6 +43,7 @@
     _convert_bfl_flux_control_lora_to_diffusers,
     _convert_fal_kontext_lora_to_diffusers,
     _convert_hunyuan_video_lora_to_diffusers,
+    _convert_kohya_flux2_lora_to_diffusers,
    _convert_kohya_flux_lora_to_diffusers,
     _convert_musubi_wan_lora_to_diffusers,
     _convert_non_diffusers_flux2_lora_to_diffusers,
@@ -5673,6 +5674,13 @@ def lora_state_dict(
             logger.warning(warn_msg)
             state_dict = {k: v for k, v in state_dict.items() if "dora_scale" not in k}
 
+        is_kohya = any(".lora_down.weight" in k for k in state_dict)
+        if is_kohya:
+            state_dict = _convert_kohya_flux2_lora_to_diffusers(state_dict)
+            # Kohya already takes care of scaling the LoRA parameters with alpha.
+            out = (state_dict, metadata) if return_lora_metadata else state_dict
+            return out
+
         is_peft_format = any(k.startswith("base_model.model.") for k in state_dict)
         if is_peft_format:
             state_dict = {k.replace("base_model.model.", "diffusion_model."): v for k, v in state_dict.items()}
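End to end, the new `is_kohya` branch means a Kohya/sd-scripts-style Flux2 checkpoint can go straight through `load_lora_weights` without manual conversion. A hedged usage sketch follows; the pipeline class, model id, file path, and adapter name are placeholders for whichever Flux2 pipeline actually mixes in the loader modified above.

```python
import torch

from diffusers import Flux2Pipeline  # assumed pipeline class; adjust to the actual Flux2 pipeline

pipe = Flux2Pipeline.from_pretrained(
    "black-forest-labs/FLUX.2-dev",  # placeholder model id
    torch_dtype=torch.bfloat16,
)

# Keys like "lora_unet_double_blocks_0_img_attn_qkv.lora_down.weight" trip the
# `is_kohya` check in lora_state_dict(), get converted to PEFT-style
# "transformer.*.lora_A/lora_B" keys, and are then loaded as usual.
pipe.load_lora_weights("path/to/kohya_flux2_lora.safetensors", adapter_name="kohya")
```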