diff --git a/examples/dreambooth/train_dreambooth_lora_flux.py b/examples/dreambooth/train_dreambooth_lora_flux.py index 2ee8fee80644..5fb666a4d42c 100644 --- a/examples/dreambooth/train_dreambooth_lora_flux.py +++ b/examples/dreambooth/train_dreambooth_lora_flux.py @@ -1823,10 +1823,11 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32): # Chunk the noise and model_pred into two parts and compute the loss on each part separately. model_pred, model_pred_prior = torch.chunk(model_pred, 2, dim=0) target, target_prior = torch.chunk(target, 2, dim=0) + weighting, weighting_prior = torch.chunk(weighting, 2, dim=0) # Compute prior loss prior_loss = torch.mean( - (weighting.float() * (model_pred_prior.float() - target_prior.float()) ** 2).reshape( + (weighting_prior.float() * (model_pred_prior.float() - target_prior.float()) ** 2).reshape( target_prior.shape[0], -1 ), 1, diff --git a/examples/dreambooth/train_dreambooth_lora_sd3.py b/examples/dreambooth/train_dreambooth_lora_sd3.py index 81f4681dcc3d..396f18113bf5 100644 --- a/examples/dreambooth/train_dreambooth_lora_sd3.py +++ b/examples/dreambooth/train_dreambooth_lora_sd3.py @@ -1824,10 +1824,11 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32): # Chunk the noise and model_pred into two parts and compute the loss on each part separately. model_pred, model_pred_prior = torch.chunk(model_pred, 2, dim=0) target, target_prior = torch.chunk(target, 2, dim=0) + weighting, weighting_prior = torch.chunk(weighting, 2, dim=0) # Compute prior loss prior_loss = torch.mean( - (weighting.float() * (model_pred_prior.float() - target_prior.float()) ** 2).reshape( + (weighting_prior.float() * (model_pred_prior.float() - target_prior.float()) ** 2).reshape( target_prior.shape[0], -1 ), 1,