diff --git a/src/diffusers/models/transformers/transformer_flux2.py b/src/diffusers/models/transformers/transformer_flux2.py index 17c8bd0ffd52..c3fa6ac141f3 100644 --- a/src/diffusers/models/transformers/transformer_flux2.py +++ b/src/diffusers/models/transformers/transformer_flux2.py @@ -292,8 +292,8 @@ def __init__(self): self.gate_fn = nn.SiLU() def forward(self, x: torch.Tensor) -> torch.Tensor: - half = x.shape[-1] // 2 - x = self.gate_fn(x[..., :half]) * x[..., half:] + x1, x2 = x.chunk(2, dim=-1) + x = self.gate_fn(x1) * x2 return x