From d2d5fdad964590438b3f2ddede78f3a714ac66c5 Mon Sep 17 00:00:00 2001 From: Varaprasad Date: Sat, 4 Apr 2026 00:19:05 +0530 Subject: [PATCH 1/3] fix: cache modulate_index in QwenImageTransformer2DModel to avoid per-step DtoH sync When zero_cond_t=True, the modulate_index tensor was being recreated on every transformer forward pass (once per denoising step) using: torch.tensor(list_comprehension, device=timestep.device, ...) This triggers a Python list comprehension + torch.tensor() from a Python list, which causes a cudaMemcpyAsync + cudaStreamSynchronize (DtoH sync) that forces the CPU to wait for all pending GPU kernels. Since img_shapes (which fully determines modulate_index) is fixed for the entire inference run, the resulting tensor is identical across all steps. We cache it in _modulate_index_cache keyed by (img_shapes, device), so the tensor is built only on the first step and reused thereafter. This eliminates N-1 unnecessary torch.tensor() constructions and DtoH syncs during inference (where N = num_inference_steps). This issue was identified in the profiling guide added in #13356 and referenced in #13401. Follows the same caching pattern as _compute_video_freqs in QwenEmbedRope. --- .../transformers/transformer_qwenimage.py | 25 +++++++++++++++---- 1 file changed, 20 insertions(+), 5 deletions(-) diff --git a/src/diffusers/models/transformers/transformer_qwenimage.py b/src/diffusers/models/transformers/transformer_qwenimage.py index d88aef4dcf2a..4fa82bc85232 100644 --- a/src/diffusers/models/transformers/transformer_qwenimage.py +++ b/src/diffusers/models/transformers/transformer_qwenimage.py @@ -832,6 +832,13 @@ def __init__( self.gradient_checkpointing = False self.zero_cond_t = zero_cond_t + # Cache for modulate_index tensor to avoid rebuilding it on every forward pass. + # The tensor is determined solely by img_shapes (fixed during inference), so it + # only needs to be computed once per unique (img_shapes, device) combination. 
+ # Without caching, every forward call triggers a Python list comprehension + + # torch.tensor() construction which is visible as CPU overhead in profiling traces. + self._modulate_index_cache: dict = {} + @apply_lora_scale("attention_kwargs") def forward( self, @@ -898,11 +905,19 @@ def forward( if self.zero_cond_t: timestep = torch.cat([timestep, timestep * 0], dim=0) - modulate_index = torch.tensor( - [[0] * prod(sample[0]) + [1] * sum([prod(s) for s in sample[1:]]) for sample in img_shapes], - device=timestep.device, - dtype=torch.int, - ) + # Cache modulate_index to avoid rebuilding it on every forward pass. + # img_shapes is fixed during inference (same across all denoising steps), + # so we can build the tensor once and reuse it, eliminating the CPU overhead + # and implicit sync from torch.tensor() on each step. + device = timestep.device + cache_key = (tuple(tuple(s) for s in img_shapes), device) + if cache_key not in self._modulate_index_cache: + self._modulate_index_cache[cache_key] = torch.tensor( + [[0] * prod(sample[0]) + [1] * sum([prod(s) for s in sample[1:]]) for sample in img_shapes], + device=device, + dtype=torch.int, + ) + modulate_index = self._modulate_index_cache[cache_key] else: modulate_index = None From bb43eee08333b610e305c9a2636b803030cc51ac Mon Sep 17 00:00:00 2001 From: G Srinivasa Eswara Vara Prasad <105410227+varaprasadtarunkumar@users.noreply.github.com> Date: Sat, 4 Apr 2026 00:36:04 +0530 Subject: [PATCH 2/3] Update src/diffusers/models/transformers/transformer_qwenimage.py Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> --- .../transformers/transformer_qwenimage.py | 25 ++++++++++++++++--- 1 file changed, 22 insertions(+), 3 deletions(-) diff --git a/src/diffusers/models/transformers/transformer_qwenimage.py b/src/diffusers/models/transformers/transformer_qwenimage.py index 4fa82bc85232..031c70d360fb 100644 --- a/src/diffusers/models/transformers/transformer_qwenimage.py +++ 
b/src/diffusers/models/transformers/transformer_qwenimage.py @@ -909,15 +909,34 @@ def forward( # img_shapes is fixed during inference (same across all denoising steps), # so we can build the tensor once and reuse it, eliminating the CPU overhead # and implicit sync from torch.tensor() on each step. + # + # However, mutating a Python dict inside forward can cause graph breaks or + # repeated recompiles under torch.compile and can be problematic for + # torch.export, so disable cache reads/writes in those modes. device = timestep.device cache_key = (tuple(tuple(s) for s in img_shapes), device) - if cache_key not in self._modulate_index_cache: - self._modulate_index_cache[cache_key] = torch.tensor( + is_compile_or_export = torch.compiler.is_compiling() or ( + hasattr(torch.compiler, "is_exporting") and torch.compiler.is_exporting() + ) + + if is_compile_or_export: + modulate_index = torch.tensor( [[0] * prod(sample[0]) + [1] * sum([prod(s) for s in sample[1:]]) for sample in img_shapes], device=device, dtype=torch.int, ) - modulate_index = self._modulate_index_cache[cache_key] + else: + modulate_index_cache = getattr(self, "_modulate_index_cache", None) + if modulate_index_cache is None: + modulate_index_cache = {} + setattr(self, "_modulate_index_cache", modulate_index_cache) + if cache_key not in modulate_index_cache: + modulate_index_cache[cache_key] = torch.tensor( + [[0] * prod(sample[0]) + [1] * sum([prod(s) for s in sample[1:]]) for sample in img_shapes], + device=device, + dtype=torch.int, + ) + modulate_index = modulate_index_cache[cache_key] else: modulate_index = None From 0663f4d69564dd89b30b83a2c6ce4221bc5b7f36 Mon Sep 17 00:00:00 2001 From: Varaprasad Date: Sat, 4 Apr 2026 00:43:07 +0530 Subject: [PATCH 3/3] add: micro-benchmark script for modulate_index caching fix --- .../profiling/benchmark_modulate_index.py | 114 ++++++++++++++++++ 1 file changed, 114 insertions(+) create mode 100644 examples/profiling/benchmark_modulate_index.py diff --git 
# examples/profiling/benchmark_modulate_index.py
"""
Micro-benchmark: modulate_index tensor creation before vs after caching fix.

This script demonstrates the overhead of recreating the modulate_index tensor
from a Python list comprehension on every forward pass (old behaviour) vs
returning a cached tensor (new behaviour).

Run on any machine — no GPU or model weights required:
    python examples/profiling/benchmark_modulate_index.py

For GPU results, the improvement is larger because torch.tensor() on GPU
additionally triggers cudaMemcpyAsync + cudaStreamSynchronize.
"""

import time
from math import prod

import torch


# ── Helpers ──────────────────────────────────────────────────────────────────

def build_modulate_index(img_shapes, device):
    """Build the modulate_index tensor from scratch (original, uncached behaviour).

    Args:
        img_shapes: One entry per batch sample; each entry is a sequence of
            shapes. The first shape of a sample contributes ``prod(shape)``
            zeros, all remaining shapes together contribute ones.
        device: Device on which the tensor is created.

    Returns:
        A ``torch.int`` tensor with one row per sample.
    """
    rows = [
        [0] * prod(sample[0]) + [1] * sum(prod(shape) for shape in sample[1:])
        for sample in img_shapes
    ]
    return torch.tensor(rows, device=device, dtype=torch.int)


def _cache_key(img_shapes, device):
    """Return a hashable key for ``(img_shapes, device)``."""
    # Convert every nesting level to a tuple so that shapes given as lists
    # (e.g. [1, 64, 64]) are hashable too; tuple inputs yield the same key
    # as before.
    shapes_key = tuple(tuple(tuple(shape) for shape in sample) for sample in img_shapes)
    return shapes_key, torch.device(device)


def build_modulate_index_cached(img_shapes, cache, device):
    """Return the modulate_index tensor, building it only on a cache miss.

    Args:
        img_shapes: Same structure as for ``build_modulate_index``.
        cache: Mutable mapping owned by the caller; it is updated in place.
        device: Device on which the tensor is created (part of the cache key).

    Returns:
        The cached tensor. Repeated calls with equal arguments return the
        very same tensor object.
    """
    key = _cache_key(img_shapes, device)
    try:
        return cache[key]
    except KeyError:
        cache[key] = build_modulate_index(img_shapes, device)
        return cache[key]


def _sync_if_cuda(device=None):
    """Wait for pending CUDA work so wall-clock timings are meaningful.

    With ``device=None`` the sync happens whenever CUDA is available (the
    historical behaviour). With an explicit device it only happens for CUDA
    devices, so CPU measurements are not charged for unrelated GPU work.
    """
    if not torch.cuda.is_available():
        return
    if device is None or torch.device(device).type == "cuda":
        torch.cuda.synchronize()


def timeit(fn, n=1000, device=None):
    """Return the mean wall-clock time of ``fn()`` in microseconds.

    Args:
        fn: Zero-argument callable to measure.
        n: Number of timed calls; must be positive.
        device: Optional device the callable works on; controls whether CUDA
            is synchronized around the timed region.

    Raises:
        ValueError: If ``n`` is not positive.
    """
    if n <= 0:
        raise ValueError(f"n must be a positive integer, got {n!r}")

    # Warmup
    for _ in range(10):
        fn()
    _sync_if_cuda(device)

    start = time.perf_counter()
    for _ in range(n):
        fn()
    _sync_if_cuda(device)
    return (time.perf_counter() - start) / n * 1e6  # µs per call


# ── Benchmark ────────────────────────────────────────────────────────────────

def run(device_str: str, num_inference_steps: int = 20, n_trials: int = 1000):
    """Benchmark uncached vs cached construction on one device and print a report.

    Args:
        device_str: Device string understood by ``torch.device`` (e.g. "cpu").
        num_inference_steps: Number of denoising steps used to extrapolate the
            per-call timings to a whole inference run.
        n_trials: Number of timed calls per variant.

    Returns:
        A dict with the measured and derived numbers (``uncached_us``,
        ``cached_us``, ``speedup``, ``total_uncached_ms``, ``total_cached_ms``,
        ``saved_ms``).

    Raises:
        AssertionError: If the cached and uncached tensors differ.
    """
    device = torch.device(device_str)

    # Demo img_shapes with a single batch item:
    #   sample[0]  = primary patch grid, 64x64 -> 4096 tokens (index 0)
    #   sample[1:] = one condition grid, 32x32 -> 1024 tokens (index 1)
    # The 64x64 grid corresponds to a 128x128 latent with patch size 2;
    # actual numbers depend on the model config.
    patch_h, patch_w = 64, 64
    img_shapes = [
        [(1, patch_h, patch_w), (1, patch_h // 2, patch_w // 2)],  # batch item 1
    ]

    cache: dict = {}

    # --- Pre-cache (first call, same cost as uncached) ---
    build_modulate_index_cached(img_shapes, cache, device)

    # --- Benchmark ---
    uncached_us = timeit(lambda: build_modulate_index(img_shapes, device), n=n_trials, device=device)
    cached_us = timeit(lambda: build_modulate_index_cached(img_shapes, cache, device), n=n_trials, device=device)

    # A cache hit can be faster than the timer resolution; avoid dividing by 0.
    speedup = uncached_us / cached_us if cached_us > 0 else float("inf")

    # With caching, the first step still pays the full build cost and only
    # steps 2..N are cache hits. Account for both so that
    # saved == total_uncached - total_cached.
    first_build_us = uncached_us if num_inference_steps > 0 else 0.0
    cached_steps = max(num_inference_steps - 1, 0)
    total_uncached_ms = uncached_us * num_inference_steps / 1e3
    total_cached_ms = (first_build_us + cached_us * cached_steps) / 1e3
    saved_ms = total_uncached_ms - total_cached_ms

    print(f"\n{'='*60}")
    print(f"  Device            : {device_str}")
    print(f"  img_shapes        : {img_shapes}")
    print(f"  num_inference_steps: {num_inference_steps}")
    print(f"{'='*60}")
    print(f"  Per-call (uncached): {uncached_us:.2f} µs")
    print(f"  Per-call (cached)  : {cached_us:.2f} µs")
    print(f"  Speedup per call   : {speedup:.1f}x")
    print(f"{'─'*60}")
    print(f"  Total over {num_inference_steps} steps (uncached): {total_uncached_ms:.3f} ms")
    print(f"  Total over {num_inference_steps} steps (cached)  : {total_cached_ms:.3f} ms")
    print(f"  CPU overhead saved : {saved_ms:.3f} ms")
    print(f"{'='*60}")

    # Verify outputs are identical. Raised explicitly (not via `assert`) so
    # the check is not stripped when Python runs with -O.
    out_uncached = build_modulate_index(img_shapes, device)
    out_cached = build_modulate_index_cached(img_shapes, cache, device)
    if not torch.equal(out_uncached, out_cached):
        raise AssertionError("BUG: cached and uncached tensors differ!")
    print("  ✅ Output tensors are identical (correctness verified)")
    print(f"{'='*60}\n")

    return {
        "uncached_us": uncached_us,
        "cached_us": cached_us,
        "speedup": speedup,
        "total_uncached_ms": total_uncached_ms,
        "total_cached_ms": total_cached_ms,
        "saved_ms": saved_ms,
    }


if __name__ == "__main__":
    run("cpu", num_inference_steps=20)
    if torch.cuda.is_available():
        run("cuda", num_inference_steps=20)
    else:
        print("(No CUDA device found — run on a GPU machine for full DtoH sync numbers)\n")