fix: cast input to weight dtype in PatchEmbed and FluxTransformer2DModel to prevent dtype mismatch#13303

Open
s-zx wants to merge 1 commit into huggingface:main from s-zx:fix/patch-embed-dtype-mismatch

Conversation


@s-zx s-zx commented Mar 21, 2026

What does this PR do?

Fixes dtype mismatch errors when running SD3 or FLUX models on CPU with torch_dtype=torch.float16.

When the model's projection layers (the Conv2d in PatchEmbed, the Linear in FluxTransformer2DModel) hold float32 weights but receive float16 inputs, PyTorch raises a RuntimeError on CPU, where autocast is not available:

RuntimeError: Input type (c10::Half) and bias type (float) should be the same

The fix casts the input tensor to match the layer weight dtype before calling the projection:

  • PatchEmbed.forward: latent.to(self.proj.weight.dtype) before Conv2d call
  • FluxTransformer2DModel.forward: hidden_states.to(self.x_embedder.weight.dtype) before Linear call
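A minimal standalone sketch of the problem and the fix (not the diffusers source — a plain nn.Conv2d standing in for PatchEmbed's projection): on CPU, a float32 layer fed a float16 tensor raises the RuntimeError above, and casting the input to the layer's weight dtype resolves it.

```python
import torch
import torch.nn as nn

# A float32 projection layer, as PatchEmbed's Conv2d would be by default.
proj = nn.Conv2d(4, 8, kernel_size=2, stride=2)

# A half-precision input, as produced with torch_dtype=torch.float16.
latent = torch.randn(1, 4, 8, 8, dtype=torch.float16)

try:
    proj(latent)  # on CPU this raises the dtype-mismatch RuntimeError
except RuntimeError as e:
    print(f"RuntimeError: {e}")

# The fix: cast the input to the layer's weight dtype before the call.
out = proj(latent.to(proj.weight.dtype))
print(out.dtype)  # float32, matching the weights
```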

Fixes #13300

…ype mismatch

When running SD3/FLUX models on CPU with torch_dtype=float16, the
projection layers (Conv2d in PatchEmbed, Linear in FluxTransformer2DModel)
may have float32 weights/bias while receiving half-precision inputs,
causing 'Input type (c10::Half) and bias type (float) should be the same'.

Cast inputs to match layer weight dtype before calling the projection.

Fixes huggingface#13300


Development

Successfully merging this pull request may close these issues.

diffusers fails in PyTorch when generating image using stabilityai/stable-diffusion-3.5-large-turbo, black-forest-labs/FLUX.1-dev on CPU
