fix: cache modulate_index tensor to eliminate per-step DtoH sync #13404
Closed
varaprasadtarunkumar
wants to merge
3
commits into
huggingface:main
from
varaprasadtarunkumar:fix/qwenimage-modulate-index-caching
+152
−4
Commits:
- d2d5fda — fix: cache modulate_index in QwenImageTransformer2DModel to avoid per… (varaprasadtarunkumar)
- bb43eee — Update src/diffusers/models/transformers/transformer_qwenimage.py (varaprasadtarunkumar)
- 0663f4d — add: micro-benchmark script for modulate_index caching fix (varaprasadtarunkumar)
New file `examples/profiling/benchmark_modulate_index.py` (114 lines added):

```python
"""
Micro-benchmark: modulate_index tensor creation before vs after caching fix.

This script demonstrates the overhead of recreating the modulate_index tensor
from a Python list comprehension on every forward pass (old behaviour) vs
returning a cached tensor (new behaviour).

Run on any machine — no GPU or model weights required:

    python examples/profiling/benchmark_modulate_index.py

For GPU results, the improvement is larger because torch.tensor() on GPU
additionally triggers cudaMemcpyAsync + cudaStreamSynchronize.
"""

import time
from math import prod

import torch


# ── Helpers ──────────────────────────────────────────────────────────────────


def build_modulate_index(img_shapes, device):
    """Original implementation: rebuilt every forward pass."""
    return torch.tensor(
        [[0] * prod(sample[0]) + [1] * sum(prod(s) for s in sample[1:]) for sample in img_shapes],
        device=device,
        dtype=torch.int,
    )


def build_modulate_index_cached(img_shapes, cache, device):
    """Fixed implementation: built once, then looked up from cache."""
    cache_key = (tuple(tuple(s) for s in img_shapes), device)
    if cache_key not in cache:
        cache[cache_key] = torch.tensor(
            [[0] * prod(sample[0]) + [1] * sum(prod(s) for s in sample[1:]) for sample in img_shapes],
            device=device,
            dtype=torch.int,
        )
    return cache[cache_key]


def timeit(fn, n=1000):
    """Return the mean wall-clock time per call in microseconds."""
    # Warmup
    for _ in range(10):
        fn()
    if torch.cuda.is_available():
        torch.cuda.synchronize()

    t0 = time.perf_counter()
    for _ in range(n):
        fn()
    if torch.cuda.is_available():
        torch.cuda.synchronize()
    return (time.perf_counter() - t0) / n * 1e6  # µs per call


# ── Benchmark ────────────────────────────────────────────────────────────────


def run(device_str: str, num_inference_steps: int = 20, n_trials: int = 1000):
    device = torch.device(device_str)

    # Realistic img_shapes for QwenImage at 1024x1024 with zero_cond_t=True.
    # sample[0] = primary image patches, sample[1:] = condition patches.
    # 1024x1024 image → 128x128 latent (8x VAE downsample) → with patch_size=2,
    # 128 // 2 = 64 patches per side → 64 * 64 = 4096 tokens.
    # (Simplified for demo — actual numbers depend on model config.)
    patch_h, patch_w = 64, 64  # 128x128 latent / patch size 2
    img_shapes = [
        [(1, patch_h, patch_w), (1, patch_h // 2, patch_w // 2)],  # batch item 1
    ]

    cache: dict = {}

    # --- Pre-cache (first call, same cost as uncached) ---
    _ = build_modulate_index_cached(img_shapes, cache, device)

    # --- Benchmark ---
    uncached_us = timeit(lambda: build_modulate_index(img_shapes, device), n=n_trials)
    cached_us = timeit(lambda: build_modulate_index_cached(img_shapes, cache, device), n=n_trials)

    speedup = uncached_us / cached_us
    total_uncached_ms = uncached_us * num_inference_steps / 1e3
    total_cached_ms = cached_us * num_inference_steps / 1e3
    # First call is shared — only steps 2..N benefit.
    saved_ms = uncached_us * (num_inference_steps - 1) / 1e3

    print(f"\n{'=' * 60}")
    print(f" Device             : {device_str}")
    print(f" img_shapes         : {img_shapes}")
    print(f" num_inference_steps: {num_inference_steps}")
    print(f"{'=' * 60}")
    print(f" Per-call (uncached): {uncached_us:.2f} µs")
    print(f" Per-call (cached)  : {cached_us:.2f} µs")
    print(f" Speedup per call   : {speedup:.1f}x")
    print(f"{'─' * 60}")
    print(f" Total over {num_inference_steps} steps (uncached): {total_uncached_ms:.3f} ms")
    print(f" Total over {num_inference_steps} steps (cached)  : {total_cached_ms:.3f} ms")
    print(f" CPU overhead saved : {saved_ms:.3f} ms")
    print(f"{'=' * 60}")

    # Verify outputs are identical
    out_uncached = build_modulate_index(img_shapes, device)
    out_cached = build_modulate_index_cached(img_shapes, cache, device)
    assert torch.equal(out_uncached, out_cached), "BUG: cached and uncached tensors differ!"
    print(" ✅ Output tensors are identical (correctness verified)")
    print(f"{'=' * 60}\n")


if __name__ == "__main__":
    run("cpu", num_inference_steps=20)
    if torch.cuda.is_available():
        run("cuda", num_inference_steps=20)
    else:
        print("(No CUDA device found — run on a GPU machine for full DtoH sync numbers)\n")
```
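For context, the caching pattern the PR applies can be sketched as a standalone class. The name `ModulateIndexCache` and its interface are illustrative only — they are not the actual code added to `transformer_qwenimage.py` — but the key construction and tensor layout match the benchmark's helpers above:

```python
from math import prod

import torch


class ModulateIndexCache:
    """Illustrative cache keyed by (img_shapes, device); hypothetical, not the PR's code."""

    def __init__(self):
        self._cache = {}

    def get(self, img_shapes, device):
        # img_shapes is a list (one entry per batch item) of lists of shape tuples;
        # converting to nested tuples makes it hashable as a dict key.
        key = (tuple(tuple(s) for s in img_shapes), device)
        if key not in self._cache:
            # 0 marks primary-image tokens, 1 marks condition tokens.
            self._cache[key] = torch.tensor(
                [[0] * prod(sample[0]) + [1] * sum(prod(s) for s in sample[1:])
                 for sample in img_shapes],
                device=device,
                dtype=torch.int,
            )
        return self._cache[key]


cache = ModulateIndexCache()
shapes = [[(1, 4, 4), (1, 2, 2)]]  # 16 primary tokens + 4 condition tokens
t1 = cache.get(shapes, torch.device("cpu"))
t2 = cache.get(shapes, torch.device("cpu"))
assert t1 is t2  # second call returns the same cached tensor — no rebuild
```

Because the cache key includes the device and the full shape structure, a change in resolution, batch composition, or device transparently builds a new entry, while the common case (same shapes every denoising step) hits the cache.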