From c4f9c3b9fefd77a5ee0be1b1129f4f8456be0da4 Mon Sep 17 00:00:00 2001 From: Ricardo-M-L Date: Tue, 21 Apr 2026 10:59:47 +0800 Subject: [PATCH] Raise ValueError instead of tearing down CUDA when AuraFlow latents exceed pos_embed_max_size When the input latent grid exceeds the pretrained positional embedding grid, pe_selection_index_based_on_dim silently produces negative / out-of-range gather indices. On CUDA this trips a vectorized_gather_kernel device-side assert, which destroys the CUDA context for the entire process and forces a Python restart (see #12656). Check the bounds up front and raise a ValueError with a clear message about the largest supported resolution, matching how PatchEmbed.cropped_pos_embed in models/embeddings.py handles the same situation for SD3. Fixes #12656 --- .../models/transformers/auraflow_transformer_2d.py | 13 +++++++++++++ 1 file changed, 13 insertions(+) diff --git a/src/diffusers/models/transformers/auraflow_transformer_2d.py b/src/diffusers/models/transformers/auraflow_transformer_2d.py index 3fa4df738784..b973e98b995b 100644 --- a/src/diffusers/models/transformers/auraflow_transformer_2d.py +++ b/src/diffusers/models/transformers/auraflow_transformer_2d.py @@ -76,6 +76,19 @@ def pe_selection_index_based_on_dim(self, h, w): h_p, w_p = h // self.patch_size, w // self.patch_size h_max, w_max = int(self.pos_embed_max_size**0.5), int(self.pos_embed_max_size**0.5) + # Guard against inputs larger than the pretrained positional embedding grid: + # without this check the centered crop produces negative / out-of-range + # indices, which silently corrupt the output on CPU and trigger a + # `vectorized_gather_kernel` device-side assert on CUDA that tears down + # the entire process (see #12656). + if h_p > h_max or w_p > w_max: + raise ValueError( + f"Input latent size ({h_p}, {w_p}) exceeds the pretrained positional " + f"embedding grid ({h_max}, {w_max}). The positional embedding supports " + f"latents up to ({h_max * self.patch_size}, {w_max * self.patch_size}) " + f"pixels at patch_size={self.patch_size}." + ) + # Calculate the top-left corner indices for the centered patch grid starth = h_max // 2 - h_p // 2 startw = w_max // 2 - w_p // 2