From 7e754e7af55f5a6ba56b6b8ec665c2202922c586 Mon Sep 17 00:00:00 2001 From: Dev-X25874 <283057883+Dev-X25874@users.noreply.github.com> Date: Thu, 14 May 2026 07:35:26 +0530 Subject: [PATCH 1/2] examples/dreambooth: chunk weighting tensor alongside model_pred and target when using prior preservation (flux LoRA) --- examples/dreambooth/train_dreambooth_lora_flux.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/examples/dreambooth/train_dreambooth_lora_flux.py b/examples/dreambooth/train_dreambooth_lora_flux.py index 2ee8fee80644..5fb666a4d42c 100644 --- a/examples/dreambooth/train_dreambooth_lora_flux.py +++ b/examples/dreambooth/train_dreambooth_lora_flux.py @@ -1823,10 +1823,11 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32): # Chunk the noise and model_pred into two parts and compute the loss on each part separately. model_pred, model_pred_prior = torch.chunk(model_pred, 2, dim=0) target, target_prior = torch.chunk(target, 2, dim=0) + weighting, weighting_prior = torch.chunk(weighting, 2, dim=0) # Compute prior loss prior_loss = torch.mean( - (weighting.float() * (model_pred_prior.float() - target_prior.float()) ** 2).reshape( + (weighting_prior.float() * (model_pred_prior.float() - target_prior.float()) ** 2).reshape( target_prior.shape[0], -1 ), 1, From 9a3f50c1dff9aff1fa1192b30f6f4af4737ca252 Mon Sep 17 00:00:00 2001 From: Dev-X25874 <283057883+Dev-X25874@users.noreply.github.com> Date: Thu, 14 May 2026 07:36:34 +0530 Subject: [PATCH 2/2] examples/dreambooth: chunk weighting tensor alongside model_pred and target when using prior preservation (SD3 LoRA) --- examples/dreambooth/train_dreambooth_lora_sd3.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/examples/dreambooth/train_dreambooth_lora_sd3.py b/examples/dreambooth/train_dreambooth_lora_sd3.py index 81f4681dcc3d..396f18113bf5 100644 --- a/examples/dreambooth/train_dreambooth_lora_sd3.py +++ b/examples/dreambooth/train_dreambooth_lora_sd3.py @@ -1824,10 +1824,11 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32): # Chunk the noise and model_pred into two parts and compute the loss on each part separately. model_pred, model_pred_prior = torch.chunk(model_pred, 2, dim=0) target, target_prior = torch.chunk(target, 2, dim=0) + weighting, weighting_prior = torch.chunk(weighting, 2, dim=0) # Compute prior loss prior_loss = torch.mean( - (weighting.float() * (model_pred_prior.float() - target_prior.float()) ** 2).reshape( + (weighting_prior.float() * (model_pred_prior.float() - target_prior.float()) ** 2).reshape( target_prior.shape[0], -1 ), 1,