From 14b140377db5cafdf6acb80f2f5df110737b7ef4 Mon Sep 17 00:00:00 2001
From: Aditya Borate <23110065@iitgn.ac.in>
Date: Sat, 21 Mar 2026 09:46:57 +0000
Subject: [PATCH] Corrected casting of latents_bn_std

---
 src/diffusers/pipelines/flux2/pipeline_flux2.py          | 4 +++-
 src/diffusers/pipelines/flux2/pipeline_flux2_klein.py    | 4 +++-
 src/diffusers/pipelines/flux2/pipeline_flux2_klein_kv.py | 4 +++-
 3 files changed, 9 insertions(+), 3 deletions(-)

diff --git a/src/diffusers/pipelines/flux2/pipeline_flux2.py b/src/diffusers/pipelines/flux2/pipeline_flux2.py
index 4b60c6042d4f..b1645b4ae244 100644
--- a/src/diffusers/pipelines/flux2/pipeline_flux2.py
+++ b/src/diffusers/pipelines/flux2/pipeline_flux2.py
@@ -611,7 +611,9 @@ def _encode_vae_image(self, image: torch.Tensor, generator: torch.Generator):
         image_latents = self._patchify_latents(image_latents)
 
         latents_bn_mean = self.vae.bn.running_mean.view(1, -1, 1, 1).to(image_latents.device, image_latents.dtype)
-        latents_bn_std = torch.sqrt(self.vae.bn.running_var.view(1, -1, 1, 1) + self.vae.config.batch_norm_eps)
+        latents_bn_std = torch.sqrt(self.vae.bn.running_var.view(1, -1, 1, 1) + self.vae.config.batch_norm_eps).to(
+            image_latents.device, image_latents.dtype
+        )
         image_latents = (image_latents - latents_bn_mean) / latents_bn_std
 
         return image_latents
diff --git a/src/diffusers/pipelines/flux2/pipeline_flux2_klein.py b/src/diffusers/pipelines/flux2/pipeline_flux2_klein.py
index 936d2c3804ab..e50ced76ec55 100644
--- a/src/diffusers/pipelines/flux2/pipeline_flux2_klein.py
+++ b/src/diffusers/pipelines/flux2/pipeline_flux2_klein.py
@@ -465,7 +465,9 @@ def _encode_vae_image(self, image: torch.Tensor, generator: torch.Generator):
         image_latents = self._patchify_latents(image_latents)
 
         latents_bn_mean = self.vae.bn.running_mean.view(1, -1, 1, 1).to(image_latents.device, image_latents.dtype)
-        latents_bn_std = torch.sqrt(self.vae.bn.running_var.view(1, -1, 1, 1) + self.vae.config.batch_norm_eps)
+        latents_bn_std = torch.sqrt(self.vae.bn.running_var.view(1, -1, 1, 1) + self.vae.config.batch_norm_eps).to(
+            image_latents.device, image_latents.dtype
+        )
         image_latents = (image_latents - latents_bn_mean) / latents_bn_std
 
         return image_latents
diff --git a/src/diffusers/pipelines/flux2/pipeline_flux2_klein_kv.py b/src/diffusers/pipelines/flux2/pipeline_flux2_klein_kv.py
index 671953be63c1..78ed42f20afb 100644
--- a/src/diffusers/pipelines/flux2/pipeline_flux2_klein_kv.py
+++ b/src/diffusers/pipelines/flux2/pipeline_flux2_klein_kv.py
@@ -477,7 +477,9 @@ def _encode_vae_image(self, image: torch.Tensor, generator: torch.Generator):
         image_latents = self._patchify_latents(image_latents)
 
         latents_bn_mean = self.vae.bn.running_mean.view(1, -1, 1, 1).to(image_latents.device, image_latents.dtype)
-        latents_bn_std = torch.sqrt(self.vae.bn.running_var.view(1, -1, 1, 1) + self.vae.config.batch_norm_eps)
+        latents_bn_std = torch.sqrt(self.vae.bn.running_var.view(1, -1, 1, 1) + self.vae.config.batch_norm_eps).to(
+            image_latents.device, image_latents.dtype
+        )
         image_latents = (image_latents - latents_bn_mean) / latents_bn_std
 
         return image_latents