From af961094358821218e00e87e92f93acdfe4b0653 Mon Sep 17 00:00:00 2001 From: sayakpaul Date: Thu, 26 Mar 2026 17:01:41 +0530 Subject: [PATCH 01/27] add a profiling worflow. --- profiling/PROFILING_PLAN.md | 168 ++++++++++++++++++++++++++++ profiling/profiling_pipelines.py | 182 +++++++++++++++++++++++++++++++ profiling/profiling_utils.py | 143 ++++++++++++++++++++++++ profiling/run_profiling.sh | 39 +++++++ 4 files changed, 532 insertions(+) create mode 100644 profiling/PROFILING_PLAN.md create mode 100644 profiling/profiling_pipelines.py create mode 100644 profiling/profiling_utils.py create mode 100755 profiling/run_profiling.sh diff --git a/profiling/PROFILING_PLAN.md b/profiling/PROFILING_PLAN.md new file mode 100644 index 000000000000..c31b531936e1 --- /dev/null +++ b/profiling/PROFILING_PLAN.md @@ -0,0 +1,168 @@ +# Profiling Plan: Diffusers Pipeline Profiling with torch.profiler + +## Context + +We want to uncover CPU overhead, CPU-GPU sync points, and other bottlenecks in popular diffusers pipelines — especially issues that become non-trivial under `torch.compile`. The approach is inspired by [flux-fast's run_benchmark.py](https://github.com/huggingface/flux-fast/blob/0a1dcc91658f0df14cd7fce862a5c8842784c6da/run_benchmark.py#L66-L85) which uses `torch.profiler` with method-level annotations, and motivated by issues like [diffusers#11696](https://github.com/huggingface/diffusers/pull/11696) (DtoH sync from scheduler `.item()` call). + +## Target Pipelines + +| Pipeline | Type | Checkpoint | Steps | +|----------|------|-----------|-------| +| `FluxPipeline` | text-to-image | `black-forest-labs/FLUX.1-dev` | 4 | +| `Flux2Pipeline` | text-to-image | `black-forest-labs/FLUX.2-dev` | 4 | +| `WanPipeline` | text-to-video | `Wan-AI/Wan2.1-T2V-14B-Diffusers` | 4 | +| `LTX2Pipeline` | text-to-video | `Lightricks/LTX-2` | 4 | +| `QwenImagePipeline` | text-to-image | `Qwen/Qwen-Image` | 4 | + +## Approach + +Follow the flux-fast pattern: **annotate key pipeline methods** with `torch.profiler.record_function` wrappers, then run the pipeline under `torch.profiler.profile` and export a Chrome trace. + +### New Files + +``` +profiling/ + profiling_utils.py # Annotation helper + profiler setup + profiling_pipelines.py # CLI entry point with pipeline configs +``` + +### Step 1: `profiling_utils.py` — Annotation and Profiler Infrastructure + +**A) `annotate(func, name)` helper** (same pattern as flux-fast): + +```python +def annotate(func, name): + """Wrap a function with torch.profiler.record_function for trace annotation.""" + @functools.wraps(func) + def wrapper(*args, **kwargs): + with torch.profiler.record_function(name): + return func(*args, **kwargs) + return wrapper +``` + +**B) `annotate_pipeline(pipe)` function** — applies annotations to key methods on any pipeline: + +- `pipe.transformer.forward` → `"transformer_forward"` +- `pipe.vae.decode` → `"vae_decode"` (if present) +- `pipe.vae.encode` → `"vae_encode"` (if present) +- `pipe.scheduler.step` → `"scheduler_step"` +- `pipe.encode_prompt` → `"encode_prompt"` (if present, for full-pipeline profiling) + +This is non-invasive — it monkey-patches bound methods without modifying source. + +**C) `PipelineProfiler` class:** + +- `__init__(pipeline_config, output_dir, mode="eager"|"compile")` +- `setup_pipeline()` → loads from pretrained, optionally compiles transformer, calls `annotate_pipeline()` +- `run()`: + 1. Warm up with 1 unannotated run + 2. Profile 1 run with `torch.profiler.profile`: + - `activities=[CPU, CUDA]` + - `record_shapes=True` + - `profile_memory=True` + - `with_stack=True` + 3. Export Chrome trace JSON + 4. Print `key_averages()` summary table (sorted by CUDA time) to stdout + +### Step 2: `profiling_pipelines.py` — CLI with Pipeline Configs + +**Pipeline config registry** — each entry specifies: + +- `pipeline_cls`, `pretrained_model_name_or_path`, `torch_dtype` +- `call_kwargs` with pipeline-specific defaults: + +| Pipeline | Resolution | Frames | Steps | Extra | +|----------|-----------|--------|-------|-------| +| Flux | 1024x1024 | — | 4 | `guidance_scale=3.5` | +| Flux2 | 1024x1024 | — | 4 | `guidance_scale=3.5` | +| Wan | 480x832 | 81 | 4 | — | +| LTX2 | 768x512 | 121 | 4 | `guidance_scale=4.0` | +| QwenImage | 1024x1024 | — | 4 | `true_cfg_scale=4.0` | + +All configs use `output_type="latent"` by default (skip VAE decode for cleaner denoising-loop traces). + +**CLI flags:** + +- `--pipeline flux|flux2|wan|ltx2|qwenimage|all` +- `--mode eager|compile|both` +- `--output_dir profiling_results/` +- `--num_steps N` (override, default 4) +- `--full_decode` (switch output_type from `"latent"` to `"pil"` to include VAE) +- `--compile_mode default|reduce-overhead|max-autotune` +- `--compile_fullgraph` flag + +**Output:** `{output_dir}/{pipeline}_{mode}.json` Chrome trace + stdout summary. + +### Step 3: Known Sync Issues to Validate + +The profiling should surface these known/suspected issues: + +1. **Scheduler DtoH sync via `nonzero().item()`** — For Flux, this was fixed by adding `scheduler.set_begin_index(0)` before the denoising loop ([diffusers#11696](https://github.com/huggingface/diffusers/pull/11696)). Profiling should reveal whether similar sync points exist in other pipelines. + +2. **`modulate_index` tensor rebuilt every forward in `transformer_qwenimage.py`** (line 901-905) — Python list comprehension + `torch.tensor()` each step. Minor but visible in trace. + +3. **Any other `.item()`, `.cpu()`, `.numpy()` calls** in the denoising loop hot path — the profiler's `with_stack=True` will surface these as CPU stalls with Python stack traces. + +## Verification + +1. Run: `python profiling/profiling_pipelines.py --pipeline flux --mode eager --num_steps 4` +2. Verify `profiling_results/flux_eager.json` is produced +3. Open trace in [Perfetto UI](https://ui.perfetto.dev/) — confirm: + - `transformer_forward` and `scheduler_step` annotations visible + - CPU and CUDA timelines present + - Stack traces visible on CPU events +4. Run with `--mode compile` and compare trace for fewer/fused CUDA kernels + +## Interpreting Traces in Perfetto UI + +Open the exported `.json` trace at [ui.perfetto.dev](https://ui.perfetto.dev/). The trace has two main rows: **CPU** (top) and **CUDA** (bottom). + +### What to look for + +**1. Gaps between CUDA kernels** + +Zoom into the CUDA row during the denoising loop. Ideally, GPU kernels should be back-to-back with no gaps. Gaps mean the GPU is idle waiting for the CPU to launch the next kernel. Common causes: +- Python overhead between ops (visible as CPU slices in the CPU row during the gap) +- DtoH sync (`.item()`, `.cpu()`) forcing the GPU to drain before the CPU can proceed + +**2. CPU stalls (DtoH syncs)** + +Look for long CPU slices labeled `cudaStreamSynchronize` or `cudaDeviceSynchronize`. Click on them — if `with_stack=True` was enabled, the bottom panel shows the Python stack trace pointing to the exact line causing the sync (e.g., a `.item()` call in the scheduler). + +**3. Annotated regions** + +Our `record_function` annotations (`transformer_forward`, `scheduler_step`, etc.) appear as labeled spans on the CPU row. This lets you quickly: +- Measure how long each phase takes (click a span to see duration) +- See if `scheduler_step` is disproportionately expensive relative to `transformer_forward` (it should be negligible) +- Spot unexpected CPU work between annotated regions + +**4. Eager vs compile comparison** + +Open both traces side by side (two Perfetto tabs). Key differences to look for: +- **Fewer, wider CUDA kernels** in compile mode (fused ops) vs many small kernels in eager +- **Smaller CPU gaps** between kernels in compile mode (less Python dispatch overhead) +- **Graph breaks**: if compile mode still shows many small kernels in a section, that section likely has a graph break — check `TORCH_LOGS="+dynamo"` output for details + +**5. Memory timeline** + +In Perfetto, look for the memory counter track (if `profile_memory=True`). Spikes during the denoising loop suggest unexpected allocations per step. Steady-state memory during denoising is expected — growing memory is not. + +**6. Kernel launch latency** + +Each CUDA kernel is launched from the CPU. In Perfetto, you can see the CPU-side launch call (e.g., `cudaLaunchKernel`) and the corresponding GPU-side kernel execution. The time between the CPU dispatch and the GPU kernel starting should be minimal (single-digit microseconds). If you see consistent delays > 10-20us between launch and execution: +- The launch queue may be starved because of excessive Python work between ops +- There may be implicit syncs forcing serialization +- `torch.compile` should help here by batching launches — compare eager vs compile to confirm + +To inspect this: zoom into a single denoising step, select a CUDA kernel on the GPU row, and look at the corresponding CPU-side launch slice directly above it. The horizontal offset between them is the launch latency. In a healthy trace, CPU launch slices should be well ahead of GPU execution (the CPU is "feeding" the GPU faster than it can consume). + +### Quick checklist per pipeline + +| Question | Where to look | Healthy | Unhealthy | +|----------|--------------|---------|-----------| +| GPU staying busy? | CUDA row gaps | Back-to-back kernels | Frequent gaps > 100us | +| CPU blocking on GPU? | `cudaStreamSynchronize` slices | Rare/absent during denoise | Present every step | +| Scheduler overhead? | `scheduler_step` span duration | < 1% of step time | > 5% of step time | +| Compile effective? | CUDA kernel count per step | Fewer large kernels | Same as eager | +| Kernel launch latency? | CPU launch → GPU kernel offset | < 10us, CPU ahead of GPU | > 20us or CPU trailing GPU | +| Memory stable? | Memory counter track | Flat during denoise loop | Growing per step | diff --git a/profiling/profiling_pipelines.py b/profiling/profiling_pipelines.py new file mode 100644 index 000000000000..eddbba24bd05 --- /dev/null +++ b/profiling/profiling_pipelines.py @@ -0,0 +1,182 @@ +""" +Profile diffusers pipelines with torch.profiler. + +Usage: + python profiling/profiling_pipelines.py --pipeline flux --mode eager + python profiling/profiling_pipelines.py --pipeline flux --mode compile + python profiling/profiling_pipelines.py --pipeline flux --mode both + python profiling/profiling_pipelines.py --pipeline all --mode eager + python profiling/profiling_pipelines.py --pipeline wan --mode eager --full_decode + python profiling/profiling_pipelines.py --pipeline flux --mode compile --num_steps 4 +""" + +import argparse +import copy +import logging + +import torch + +from profiling_utils import PipelineProfiler, PipelineProfilingConfig + + +logging.basicConfig(level=logging.INFO, format="%(asctime)s %(levelname)s %(name)s: %(message)s") +logger = logging.getLogger(__name__) + +PROMPT = "A cat holding a sign that says hello world" + + +def build_registry(): + """Build the pipeline config registry. Imports are deferred to avoid loading all pipelines upfront.""" + from diffusers import FluxPipeline, Flux2Pipeline, WanPipeline, LTX2Pipeline, QwenImagePipeline + + return { + "flux": PipelineProfilingConfig( + name="flux", + pipeline_cls=FluxPipeline, + pipeline_init_kwargs={ + "pretrained_model_name_or_path": "black-forest-labs/FLUX.1-dev", + "torch_dtype": torch.bfloat16, + }, + pipeline_call_kwargs={ + "prompt": PROMPT, + "height": 1024, + "width": 1024, + "num_inference_steps": 4, + "guidance_scale": 3.5, + "output_type": "latent", + }, + ), + "flux2": PipelineProfilingConfig( + name="flux2", + pipeline_cls=Flux2Pipeline, + pipeline_init_kwargs={ + "pretrained_model_name_or_path": "black-forest-labs/FLUX.2-klein-base-9B", + "torch_dtype": torch.bfloat16, + }, + pipeline_call_kwargs={ + "prompt": PROMPT, + "height": 1024, + "width": 1024, + "num_inference_steps": 4, + "guidance_scale": 3.5, + "output_type": "latent", + }, + ), + "wan": PipelineProfilingConfig( + name="wan", + pipeline_cls=WanPipeline, + pipeline_init_kwargs={ + "pretrained_model_name_or_path": "Wan-AI/Wan2.1-T2V-14B-Diffusers", + "torch_dtype": torch.bfloat16, + }, + pipeline_call_kwargs={ + "prompt": PROMPT, + "negative_prompt": "Bright tones, overexposed, static, blurred details, subtitles, style, works, paintings, images, static, overall gray, worst quality, low quality, JPEG compression residue, ugly, incomplete, extra fingers, poorly drawn hands, poorly drawn faces, deformed, disfigured, misshapen limbs, fused fingers, still picture, messy background, three legs, many people in the background, walking backwards", + "height": 480, + "width": 832, + "num_frames": 81, + "num_inference_steps": 4, + "output_type": "latent", + }, + ), + "ltx2": PipelineProfilingConfig( + name="ltx2", + pipeline_cls=LTX2Pipeline, + pipeline_init_kwargs={ + "pretrained_model_name_or_path": "Lightricks/LTX-2", + "torch_dtype": torch.bfloat16, + }, + pipeline_call_kwargs={ + "prompt": PROMPT, + "negative_prompt": "worst quality, inconsistent motion, blurry, jittery, distorted", + "height": 512, + "width": 768, + "num_frames": 121, + "num_inference_steps": 4, + "guidance_scale": 4.0, + "output_type": "latent", + }, + ), + "qwenimage": PipelineProfilingConfig( + name="qwenimage", + pipeline_cls=QwenImagePipeline, + pipeline_init_kwargs={ + "pretrained_model_name_or_path": "Qwen/Qwen-Image", + "torch_dtype": torch.bfloat16, + }, + pipeline_call_kwargs={ + "prompt": PROMPT, + "negative_prompt": " ", + "height": 1024, + "width": 1024, + "num_inference_steps": 4, + "true_cfg_scale": 4.0, + "output_type": "latent", + }, + ), + } + + +def main(): + parser = argparse.ArgumentParser(description="Profile diffusers pipelines with torch.profiler") + parser.add_argument( + "--pipeline", + choices=["flux", "flux2", "wan", "ltx2", "qwenimage", "all"], + required=True, + help="Which pipeline to profile", + ) + parser.add_argument( + "--mode", + choices=["eager", "compile", "both"], + default="eager", + help="Run in eager mode, compile mode, or both", + ) + parser.add_argument("--output_dir", default="profiling_results", help="Directory for trace output") + parser.add_argument("--num_steps", type=int, default=None, help="Override num_inference_steps") + parser.add_argument("--full_decode", action="store_true", help="Profile including VAE decode (output_type='pil')") + parser.add_argument( + "--compile_mode", + default="default", + choices=["default", "reduce-overhead", "max-autotune"], + help="torch.compile mode", + ) + parser.add_argument("--compile_fullgraph", action="store_true", help="Use fullgraph=True for torch.compile") + parser.add_argument( + "--compile_regional", + action="store_true", + help="Use compile_repeated_blocks() instead of full model compile", + ) + args = parser.parse_args() + + registry = build_registry() + + pipeline_names = list(registry.keys()) if args.pipeline == "all" else [args.pipeline] + modes = ["eager", "compile"] if args.mode == "both" else [args.mode] + + for pipeline_name in pipeline_names: + for mode in modes: + config = copy.deepcopy(registry[pipeline_name]) + + # Apply overrides + if args.num_steps is not None: + config.pipeline_call_kwargs["num_inference_steps"] = args.num_steps + if args.full_decode: + config.pipeline_call_kwargs["output_type"] = "pil" + if mode == "compile": + config.compile_kwargs = { + "fullgraph": args.compile_fullgraph, + "mode": args.compile_mode, + } + config.compile_regional = args.compile_regional + + logger.info(f"Profiling {pipeline_name} in {mode} mode...") + profiler = PipelineProfiler(config, args.output_dir) + try: + trace_file = profiler.run() + logger.info(f"Done: {trace_file}") + except Exception as e: + logger.error(f"Failed to profile {pipeline_name} ({mode}): {e}") + + +if __name__ == "__main__": + main() diff --git a/profiling/profiling_utils.py b/profiling/profiling_utils.py new file mode 100644 index 000000000000..9f9417af270b --- /dev/null +++ b/profiling/profiling_utils.py @@ -0,0 +1,143 @@ +import functools +import gc +import logging +import os +from dataclasses import dataclass, field +from typing import Any + +import torch +import torch.profiler + + +logging.basicConfig(level=logging.INFO, format="%(asctime)s %(levelname)s %(name)s: %(message)s") +logger = logging.getLogger(__name__) + + +def annotate(func, name): + """Wrap a function with torch.profiler.record_function for trace annotation.""" + + @functools.wraps(func) + def wrapper(*args, **kwargs): + with torch.profiler.record_function(name): + return func(*args, **kwargs) + + return wrapper + + +def annotate_pipeline(pipe): + """Apply profiler annotations to key pipeline methods. + + Monkey-patches bound methods so they appear as named spans in the trace. + Non-invasive — no source modifications required. + """ + annotations = [ + ("transformer", "forward", "transformer_forward"), + ("vae", "decode", "vae_decode"), + ("vae", "encode", "vae_encode"), + ("scheduler", "step", "scheduler_step"), + ] + + # Annotate sub-component methods + for component_name, method_name, label in annotations: + component = getattr(pipe, component_name, None) + if component is None: + continue + method = getattr(component, method_name, None) + if method is None: + continue + setattr(component, method_name, annotate(method, label)) + + # Annotate pipeline-level methods + if hasattr(pipe, "encode_prompt"): + pipe.encode_prompt = annotate(pipe.encode_prompt, "encode_prompt") + + +def flush(): + gc.collect() + torch.cuda.empty_cache() + torch.cuda.reset_max_memory_allocated() + torch.cuda.reset_peak_memory_stats() + + +@dataclass +class PipelineProfilingConfig: + name: str + pipeline_cls: Any + pipeline_init_kwargs: dict[str, Any] + pipeline_call_kwargs: dict[str, Any] + compile_kwargs: dict[str, Any] | None = field(default=None) + compile_regional: bool = False + + +class PipelineProfiler: + def __init__(self, config: PipelineProfilingConfig, output_dir: str = "profiling_results"): + self.config = config + self.output_dir = output_dir + os.makedirs(output_dir, exist_ok=True) + + def setup_pipeline(self): + """Load the pipeline from pretrained, optionally compile, and annotate.""" + logger.info(f"Loading pipeline: {self.config.name}") + pipe = self.config.pipeline_cls.from_pretrained(**self.config.pipeline_init_kwargs) + pipe.to("cuda") + + if self.config.compile_kwargs: + if self.config.compile_regional: + logger.info(f"Regional compilation (compile_repeated_blocks) with kwargs: {self.config.compile_kwargs}") + pipe.transformer.compile_repeated_blocks(**self.config.compile_kwargs) + else: + logger.info(f"Full compilation with kwargs: {self.config.compile_kwargs}") + pipe.transformer.compile(**self.config.compile_kwargs) + + annotate_pipeline(pipe) + return pipe + + def run(self): + """Execute the profiling run: warmup, then profile one pipeline call.""" + pipe = self.setup_pipeline() + flush() + + mode = "compile" if self.config.compile_kwargs else "eager" + trace_file = os.path.join(self.output_dir, f"{self.config.name}_{mode}.json") + + # Warmup (pipeline __call__ is already decorated with @torch.no_grad()) + logger.info("Running warmup...") + pipe(**self.config.pipeline_call_kwargs) + flush() + + # Profile + logger.info("Running profiled iteration...") + activities = [ + torch.profiler.ProfilerActivity.CPU, + torch.profiler.ProfilerActivity.CUDA, + ] + with torch.profiler.profile( + activities=activities, + record_shapes=True, + profile_memory=True, + with_stack=True, + ) as prof: + with torch.profiler.record_function("pipeline_call"): + pipe(**self.config.pipeline_call_kwargs) + + # Export trace + prof.export_chrome_trace(trace_file) + logger.info(f"Chrome trace saved to: {trace_file}") + + # Print summary + print("\n" + "=" * 80) + print(f"Profile summary: {self.config.name} ({mode})") + print("=" * 80) + print( + prof.key_averages().table( + sort_by="cuda_time_total", + row_limit=20, + ) + ) + + # Cleanup + pipe.to("cpu") + del pipe + flush() + + return trace_file diff --git a/profiling/run_profiling.sh b/profiling/run_profiling.sh new file mode 100755 index 000000000000..651fa0355c4c --- /dev/null +++ b/profiling/run_profiling.sh @@ -0,0 +1,39 @@ +#!/bin/bash +# Run profiling across all pipelines in eager and compile (regional) modes. +# +# Usage: +# bash profiling/run_profiling.sh +# bash profiling/run_profiling.sh --output_dir my_results + +set -euo pipefail + +OUTPUT_DIR="${1:-profiling_results}" +NUM_STEPS=2 +PIPELINES=("flux" "flux2" "wan" "ltx2" "qwenimage") +MODES=("eager" "compile") + +for pipeline in "${PIPELINES[@]}"; do + for mode in "${MODES[@]}"; do + echo "============================================================" + echo "Profiling: ${pipeline} | mode: ${mode}" + echo "============================================================" + + COMPILE_ARGS="" + if [ "$mode" = "compile" ]; then + COMPILE_ARGS="--compile_regional --compile_fullgraph --compile_mode default" + fi + + python profiling/profiling_pipelines.py \ + --pipeline "$pipeline" \ + --mode "$mode" \ + --output_dir "$OUTPUT_DIR" \ + --num_steps "$NUM_STEPS" \ + $COMPILE_ARGS + + echo "" + done +done + +echo "============================================================" +echo "All traces saved to: ${OUTPUT_DIR}/" +echo "============================================================" From eddef12a5437f0a52315db4a7946e487c0e71b5a Mon Sep 17 00:00:00 2001 From: sayakpaul Date: Fri, 27 Mar 2026 09:13:39 +0530 Subject: [PATCH 02/27] fix --- profiling/profiling_pipelines.py | 4 ++-- profiling/run_profiling.sh | 11 +++++++++-- 2 files changed, 11 insertions(+), 4 deletions(-) diff --git a/profiling/profiling_pipelines.py b/profiling/profiling_pipelines.py index eddbba24bd05..d19540014e11 100644 --- a/profiling/profiling_pipelines.py +++ b/profiling/profiling_pipelines.py @@ -27,7 +27,7 @@ def build_registry(): """Build the pipeline config registry. Imports are deferred to avoid loading all pipelines upfront.""" - from diffusers import FluxPipeline, Flux2Pipeline, WanPipeline, LTX2Pipeline, QwenImagePipeline + from diffusers import FluxPipeline, Flux2KleinPipeline, WanPipeline, LTX2Pipeline, QwenImagePipeline return { "flux": PipelineProfilingConfig( @@ -48,7 +48,7 @@ def build_registry(): ), "flux2": PipelineProfilingConfig( name="flux2", - pipeline_cls=Flux2Pipeline, + pipeline_cls=Flux2KleinPipeline, pipeline_init_kwargs={ "pretrained_model_name_or_path": "black-forest-labs/FLUX.2-klein-base-9B", "torch_dtype": torch.bfloat16, diff --git a/profiling/run_profiling.sh b/profiling/run_profiling.sh index 651fa0355c4c..a6be682e929d 100755 --- a/profiling/run_profiling.sh +++ b/profiling/run_profiling.sh @@ -7,9 +7,16 @@ set -euo pipefail -OUTPUT_DIR="${1:-profiling_results}" +OUTPUT_DIR="profiling_results" +while [[ $# -gt 0 ]]; do + case "$1" in + --output_dir) OUTPUT_DIR="$2"; shift 2 ;; + *) echo "Unknown arg: $1"; exit 1 ;; + esac +done NUM_STEPS=2 -PIPELINES=("flux" "flux2" "wan" "ltx2" "qwenimage") +# PIPELINES=("flux" "flux2" "wan" "ltx2" "qwenimage") +PIPELINES=("flux2") MODES=("eager" "compile") for pipeline in "${PIPELINES[@]}"; do From e4d6293b4df49224c68a2bfaf9d1b13518835ac2 Mon Sep 17 00:00:00 2001 From: sayakpaul Date: Fri, 27 Mar 2026 09:17:50 +0530 Subject: [PATCH 03/27] fix --- profiling/PROFILING_PLAN.md | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/profiling/PROFILING_PLAN.md b/profiling/PROFILING_PLAN.md index c31b531936e1..6526c6af8a0d 100644 --- a/profiling/PROFILING_PLAN.md +++ b/profiling/PROFILING_PLAN.md @@ -9,7 +9,7 @@ We want to uncover CPU overhead, CPU-GPU sync points, and other bottlenecks in p | Pipeline | Type | Checkpoint | Steps | |----------|------|-----------|-------| | `FluxPipeline` | text-to-image | `black-forest-labs/FLUX.1-dev` | 4 | -| `Flux2Pipeline` | text-to-image | `black-forest-labs/FLUX.2-dev` | 4 | +| `Flux2KleinPipeline` | text-to-image | `black-forest-labs/FLUX.2-klein-base-9B` | 4 | | `WanPipeline` | text-to-video | `Wan-AI/Wan2.1-T2V-14B-Diffusers` | 4 | | `LTX2Pipeline` | text-to-video | `Lightricks/LTX-2` | 4 | | `QwenImagePipeline` | text-to-image | `Qwen/Qwen-Image` | 4 | @@ -74,7 +74,7 @@ This is non-invasive — it monkey-patches bound methods without modifying sourc | Pipeline | Resolution | Frames | Steps | Extra | |----------|-----------|--------|-------|-------| | Flux | 1024x1024 | — | 4 | `guidance_scale=3.5` | -| Flux2 | 1024x1024 | — | 4 | `guidance_scale=3.5` | +| Flux2Klein | 1024x1024 | — | 4 | `guidance_scale=3.5` | | Wan | 480x832 | 81 | 4 | — | | LTX2 | 768x512 | 121 | 4 | `guidance_scale=4.0` | | QwenImage | 1024x1024 | — | 4 | `true_cfg_scale=4.0` | From b2b6330a54d43804fb37396b7c45f16ec69e9d5b Mon Sep 17 00:00:00 2001 From: sayakpaul Date: Fri, 27 Mar 2026 10:33:09 +0530 Subject: [PATCH 04/27] more clarification --- profiling/PROFILING_PLAN.md | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/profiling/PROFILING_PLAN.md b/profiling/PROFILING_PLAN.md index 6526c6af8a0d..2926747434df 100644 --- a/profiling/PROFILING_PLAN.md +++ b/profiling/PROFILING_PLAN.md @@ -115,7 +115,9 @@ The profiling should surface these known/suspected issues: ## Interpreting Traces in Perfetto UI -Open the exported `.json` trace at [ui.perfetto.dev](https://ui.perfetto.dev/). The trace has two main rows: **CPU** (top) and **CUDA** (bottom). +Open the exported `.json` trace at [ui.perfetto.dev](https://ui.perfetto.dev/). The trace has two main rows: **CPU** (top) and **CUDA** (bottom). In Perfetto, the CPU row is typically labeled with the process/thread name (e.g., `python (PID)` or `MainThread`) and appears at the top. The CUDA row is labeled `GPU 0` (or similar) and appears below the CPU rows. + +**Navigation:** Use `W` to zoom in, `S` to zoom out, and `A`/`D` to pan left/right. You can also scroll to zoom and click-drag to pan. Use `Shift+scroll` to scroll vertically through rows. ### What to look for @@ -127,7 +129,7 @@ Zoom into the CUDA row during the denoising loop. Ideally, GPU kernels should be **2. CPU stalls (DtoH syncs)** -Look for long CPU slices labeled `cudaStreamSynchronize` or `cudaDeviceSynchronize`. Click on them — if `with_stack=True` was enabled, the bottom panel shows the Python stack trace pointing to the exact line causing the sync (e.g., a `.item()` call in the scheduler). +These appear on the **CPU row** (not the CUDA row) — they are CPU-side blocking calls that wait for the GPU to finish. Look for long slices labeled `cudaStreamSynchronize` or `cudaDeviceSynchronize`. To find them: zoom into the CPU row during a denoising step and look for unusually wide slices, or use Perfetto's search bar (press `/`) and type `cudaStreamSynchronize` to jump directly to matching events. Click on a slice — if `with_stack=True` was enabled, the bottom panel ("Current Selection") shows the Python stack trace pointing to the exact line causing the sync (e.g., a `.item()` call in the scheduler). **3. Annotated regions** @@ -149,7 +151,7 @@ In Perfetto, look for the memory counter track (if `profile_memory=True`). Spike **6. Kernel launch latency** -Each CUDA kernel is launched from the CPU. In Perfetto, you can see the CPU-side launch call (e.g., `cudaLaunchKernel`) and the corresponding GPU-side kernel execution. The time between the CPU dispatch and the GPU kernel starting should be minimal (single-digit microseconds). If you see consistent delays > 10-20us between launch and execution: +Each CUDA kernel is launched from the CPU. The CPU-side launch calls (`cudaLaunchKernel`) appear as small slices on the **CPU row** — zoom in closely to a denoising step to see them. The corresponding GPU-side kernel executions appear on the **CUDA row** directly below. You can also use Perfetto's search bar (`/`) and type `cudaLaunchKernel` to find them. The time between the CPU dispatch and the GPU kernel starting should be minimal (single-digit microseconds). If you see consistent delays > 10-20us between launch and execution: - The launch queue may be starved because of excessive Python work between ops - There may be implicit syncs forcing serialization - `torch.compile` should help here by batching launches — compare eager vs compile to confirm From 60d4148529849bc53ecf5b314dab9497cd169415 Mon Sep 17 00:00:00 2001 From: sayakpaul Date: Fri, 27 Mar 2026 11:10:17 +0530 Subject: [PATCH 05/27] add points. --- profiling/PROFILING_PLAN.md | 1 + 1 file changed, 1 insertion(+) diff --git a/profiling/PROFILING_PLAN.md b/profiling/PROFILING_PLAN.md index 2926747434df..7b119dd2a53d 100644 --- a/profiling/PROFILING_PLAN.md +++ b/profiling/PROFILING_PLAN.md @@ -143,6 +143,7 @@ Our `record_function` annotations (`transformer_forward`, `scheduler_step`, etc. Open both traces side by side (two Perfetto tabs). Key differences to look for: - **Fewer, wider CUDA kernels** in compile mode (fused ops) vs many small kernels in eager - **Smaller CPU gaps** between kernels in compile mode (less Python dispatch overhead) +- **CUDA kernel count per step**: to compare, zoom into a single `transformer_forward` span on the CUDA row and count the distinct kernel slices within it. In eager mode you'll typically see many narrow slices (one per op); in compile mode these fuse into fewer, wider slices. A quick way to estimate: select a time range covering one denoising step on the CUDA row — Perfetto shows the number of slices in the selection summary at the bottom. If compile mode shows a similar kernel count to eager, fusion isn't happening effectively (likely due to graph breaks). - **Graph breaks**: if compile mode still shows many small kernels in a section, that section likely has a graph break — check `TORCH_LOGS="+dynamo"` output for details **5. Memory timeline** From 179fa5134208a9b59c7682f2bf17a67a183f9a1d Mon Sep 17 00:00:00 2001 From: sayakpaul Date: Fri, 27 Mar 2026 11:44:21 +0530 Subject: [PATCH 06/27] up --- profiling/profiling_utils.py | 3 +++ .../pipelines/flux2/pipeline_flux2_klein.py | 14 ++++++++++---- 2 files changed, 13 insertions(+), 4 deletions(-) diff --git a/profiling/profiling_utils.py b/profiling/profiling_utils.py index 9f9417af270b..cd30f912f938 100644 --- a/profiling/profiling_utils.py +++ b/profiling/profiling_utils.py @@ -89,6 +89,9 @@ def setup_pipeline(self): logger.info(f"Full compilation with kwargs: {self.config.compile_kwargs}") pipe.transformer.compile(**self.config.compile_kwargs) + # Disable tqdm progress bar to avoid CPU overhead / IO between steps + pipe.set_progress_bar_config(disable=True) + annotate_pipeline(pipe) return pipe diff --git a/src/diffusers/pipelines/flux2/pipeline_flux2_klein.py b/src/diffusers/pipelines/flux2/pipeline_flux2_klein.py index 936d2c3804ab..b1cf3f4c9cb4 100644 --- a/src/diffusers/pipelines/flux2/pipeline_flux2_klein.py +++ b/src/diffusers/pipelines/flux2/pipeline_flux2_klein.py @@ -397,7 +397,9 @@ def _pack_latents(latents): @staticmethod # Copied from diffusers.pipelines.flux2.pipeline_flux2.Flux2Pipeline._unpack_latents_with_ids - def _unpack_latents_with_ids(x: torch.Tensor, x_ids: torch.Tensor) -> list[torch.Tensor]: + def _unpack_latents_with_ids( + x: torch.Tensor, x_ids: torch.Tensor, height: int | None = None, width: int | None = None + ) -> list[torch.Tensor]: """ using position ids to scatter tokens into place """ @@ -407,8 +409,9 @@ def _unpack_latents_with_ids(x: torch.Tensor, x_ids: torch.Tensor) -> list[torch h_ids = pos[:, 1].to(torch.int64) w_ids = pos[:, 2].to(torch.int64) - h = torch.max(h_ids) + 1 - w = torch.max(w_ids) + 1 + # Use provided height/width to avoid DtoH sync from torch.max().item() + h = height if height is not None else torch.max(h_ids) + 1 + w = width if width is not None else torch.max(w_ids) + 1 flat_ids = h_ids * w + w_ids @@ -895,7 +898,10 @@ def __call__( self._current_timestep = None - latents = self._unpack_latents_with_ids(latents, latent_ids) + # Pass pre-computed latent height/width to avoid DtoH sync from torch.max().item() + latent_height = 2 * (int(height) // (self.vae_scale_factor * 2)) + latent_width = 2 * (int(width) // (self.vae_scale_factor * 2)) + latents = self._unpack_latents_with_ids(latents, latent_ids, latent_height // 2, latent_width // 2) latents_bn_mean = self.vae.bn.running_mean.view(1, -1, 1, 1).to(latents.device, latents.dtype) latents_bn_std = torch.sqrt(self.vae.bn.running_var.view(1, -1, 1, 1) + self.vae.config.batch_norm_eps).to( From 96506c85d009b750c7c44cd70c689a13dfa0d111 Mon Sep 17 00:00:00 2001 From: sayakpaul Date: Fri, 27 Mar 2026 12:24:11 +0530 Subject: [PATCH 07/27] cache hooks --- src/diffusers/hooks/hooks.py | 21 ++++++++++++++++++++- 1 file changed, 20 insertions(+), 1 deletion(-) diff --git a/src/diffusers/hooks/hooks.py b/src/diffusers/hooks/hooks.py index 2d575b85427c..474cc4343cee 100644 --- a/src/diffusers/hooks/hooks.py +++ b/src/diffusers/hooks/hooks.py @@ -271,12 +271,31 @@ def _set_context(self, name: str | None = None) -> None: if hook._is_stateful: hook._set_context(self._module_ref, name) + for registry in self._get_child_registries(): + registry._set_context(name) + + def _get_child_registries(self) -> list["HookRegistry"]: + """Return registries of child modules, using a cached list when available. + + The cache is built on first call and reused for subsequent calls. This avoids the cost of walking the full + module tree via named_modules() on every _set_context call, which is significant for large models (e.g. ~2.7ms + per call on Flux2). + """ + if not hasattr(self, "_child_registries_cache"): + self._child_registries_cache = None + + if self._child_registries_cache is not None: + return self._child_registries_cache + + registries = [] for module_name, module in unwrap_module(self._module_ref).named_modules(): if module_name == "": continue module = unwrap_module(module) if hasattr(module, "_diffusers_hook"): - module._diffusers_hook._set_context(name) + registries.append(module._diffusers_hook) + self._child_registries_cache = registries + return registries def __repr__(self) -> str: registry_repr = "" From 6a23a771aa0d7d0a90b5ce43d1d9f3179af1245e Mon Sep 17 00:00:00 2001 From: sayakpaul Date: Fri, 27 Mar 2026 12:51:09 +0530 Subject: [PATCH 08/27] improve readme. --- examples/profiling/README.md | 195 ++++++++++++++++++++++ examples/profiling/profiling_pipelines.py | 182 ++++++++++++++++++++ examples/profiling/profiling_utils.py | 146 ++++++++++++++++ examples/profiling/run_profiling.sh | 46 +++++ 4 files changed, 569 insertions(+) create mode 100644 examples/profiling/README.md create mode 100644 examples/profiling/profiling_pipelines.py create mode 100644 examples/profiling/profiling_utils.py create mode 100755 examples/profiling/run_profiling.sh diff --git a/examples/profiling/README.md b/examples/profiling/README.md new file mode 100644 index 000000000000..5c444f48dfef --- /dev/null +++ b/examples/profiling/README.md @@ -0,0 +1,195 @@ +# Profiling Plan: Diffusers Pipeline Profiling with torch.profiler + +Education materials to strategically profile pipelines to potentially improve their +runtime with `torch.compile`. To set these pipelines up for success with `torch.compile`, +we often have to get rid of DtoH syncs, CPU overheads, kernel launch delays, and +graph breaks. In this context, profiling serves that purpose for us. + +Thanks to Claude Code for paircoding! We acknowledge the [Claude of OSS](https://claude.com/contact-sales/claude-for-oss) support provided to us. + +## Table of contents + +* [Context](#context) +* [Target pipelines](#target-pipelines) +* [Approach taken](#approach) +* [Verification](#verification) +* [Interpretation](#interpreting-traces-in-perfetto-ui) + +Jump to the "Verification" section to get started right away. + +## Context + +We want to uncover CPU overhead, CPU-GPU sync points, and other bottlenecks in popular diffusers pipelines — especially issues that become non-trivial under `torch.compile`. The approach is inspired by [flux-fast's run_benchmark.py](https://github.com/huggingface/flux-fast/blob/0a1dcc91658f0df14cd7fce862a5c8842784c6da/run_benchmark.py#L66-L85) which uses `torch.profiler` with method-level annotations, and motivated by issues like [diffusers#11696](https://github.com/huggingface/diffusers/pull/11696) (DtoH sync from scheduler `.item()` call). + +## Target Pipelines + +| Pipeline | Type | Checkpoint | Steps | +|----------|------|-----------|-------| +| `FluxPipeline` | text-to-image | `black-forest-labs/FLUX.1-dev` | 2 | +| `Flux2KleinPipeline` | text-to-image | `black-forest-labs/FLUX.2-klein-base-9B` | 2 | +| `WanPipeline` | text-to-video | `Wan-AI/Wan2.1-T2V-14B-Diffusers` | 2 | +| `LTX2Pipeline` | text-to-video | `Lightricks/LTX-2` | 2 | +| `QwenImagePipeline` | text-to-image | `Qwen/Qwen-Image` | 2 | + +> [!NOTE] +> We use realistic inference call hyperparameters that mimic how these pipelines will be actually used. This +> include using classifier-free guidance (where applicable), reasonable dimensions such 1024x1024, etc. +> But we keep the overall running time to a bare minimum (hence 2 `num_inference_steps`). + +## Approach + +Follow the flux-fast pattern: **annotate key pipeline methods** with `torch.profiler.record_function` wrappers, then run the pipeline under `torch.profiler.profile` and export a Chrome trace. + +### New Files + +``` +profiling/ + profiling_utils.py # Annotation helper + profiler setup + profiling_pipelines.py # CLI entry point with pipeline configs +``` + +### Step 1: `profiling_utils.py` — Annotation and Profiler Infrastructure + +**A) `annotate(func, name)` helper** (same pattern as flux-fast): + +```python +def annotate(func, name): + """Wrap a function with torch.profiler.record_function for trace annotation.""" + @functools.wraps(func) + def wrapper(*args, **kwargs): + with torch.profiler.record_function(name): + return func(*args, **kwargs) + return wrapper +``` + +**B) `annotate_pipeline(pipe)` function** — applies annotations to key methods on any pipeline: + +- `pipe.transformer.forward` → `"transformer_forward"` +- `pipe.vae.decode` → `"vae_decode"` (if present) +- `pipe.vae.encode` → `"vae_encode"` (if present) +- `pipe.scheduler.step` → `"scheduler_step"` +- `pipe.encode_prompt` → `"encode_prompt"` (if present, for full-pipeline profiling) + +This is non-invasive — it monkey-patches bound methods without modifying source. + +**C) `PipelineProfiler` class:** + +- `__init__(pipeline_config, output_dir, mode="eager"|"compile")` +- `setup_pipeline()` → loads from pretrained, optionally compiles transformer, calls `annotate_pipeline()` +- `run()`: + 1. Warm up with 1 unannotated run + 2. Profile 1 run with `torch.profiler.profile`: + - `activities=[CPU, CUDA]` + - `record_shapes=True` + - `profile_memory=True` + - `with_stack=True` + 3. Export Chrome trace JSON + 4. Print `key_averages()` summary table (sorted by CUDA time) to stdout + +### Step 2: `profiling_pipelines.py` — CLI with Pipeline Configs + +**Pipeline config registry** — each entry specifies: + +- `pipeline_cls`, `pretrained_model_name_or_path`, `torch_dtype` +- `call_kwargs` with pipeline-specific defaults: + +| Pipeline | Resolution | Frames | Steps | Extra | +|----------|-----------|--------|-------|-------| +| Flux | 1024x1024 | — | 2 | `guidance_scale=3.5` | +| Flux2Klein | 1024x1024 | — | 2 | `guidance_scale=3.5` | +| Wan | 480x832 | 81 | 2 | — | +| LTX2 | 768x512 | 121 | 2 | `guidance_scale=4.0` | +| QwenImage | 1024x1024 | — | 2 | `true_cfg_scale=4.0` | + +All configs use `output_type="latent"` by default (skip VAE decode for cleaner denoising-loop traces). + +**CLI flags:** + +- `--pipeline flux|flux2|wan|ltx2|qwenimage|all` +- `--mode eager|compile|both` +- `--output_dir profiling_results/` +- `--num_steps N` (override, default 4) +- `--full_decode` (switch output_type from `"latent"` to `"pil"` to include VAE) +- `--compile_mode default|reduce-overhead|max-autotune` +- `--compile_fullgraph` flag + +**Output:** `{output_dir}/{pipeline}_{mode}.json` Chrome trace + stdout summary. + +### Step 3: Known Sync Issues to Validate + +The profiling should surface these known/suspected issues: + +1. **Scheduler DtoH sync via `nonzero().item()`** — For Flux, this was fixed by adding `scheduler.set_begin_index(0)` before the denoising loop ([diffusers#11696](https://github.com/huggingface/diffusers/pull/11696)). Profiling should reveal whether similar sync points exist in other pipelines. + +2. **`modulate_index` tensor rebuilt every forward in `transformer_qwenimage.py`** (line 901-905) — Python list comprehension + `torch.tensor()` each step. Minor but visible in trace. + +3. **Any other `.item()`, `.cpu()`, `.numpy()` calls** in the denoising loop hot path — the profiler's `with_stack=True` will surface these as CPU stalls with Python stack traces. + +## Verification + +1. Run: `python profiling/profiling_pipelines.py --pipeline flux --mode eager --num_steps 2` +2. Verify `profiling_results/flux_eager.json` is produced +3. Open trace in [Perfetto UI](https://ui.perfetto.dev/) — confirm: + - `transformer_forward` and `scheduler_step` annotations visible + - CPU and CUDA timelines present + - Stack traces visible on CPU events +4. Run with `--mode compile` and compare trace for fewer/fused CUDA kernels + +You can also use the `run_profiling.sh` script to bulk launch runs for different pipelines. + +## Interpreting Traces in Perfetto UI + +Open the exported `.json` trace at [ui.perfetto.dev](https://ui.perfetto.dev/). The trace has two main rows: **CPU** (top) and **CUDA** (bottom). In Perfetto, the CPU row is typically labeled with the process/thread name (e.g., `python (PID)` or `MainThread`) and appears at the top. The CUDA row is labeled `GPU 0` (or similar) and appears below the CPU rows. + +**Navigation:** Use `W` to zoom in, `S` to zoom out, and `A`/`D` to pan left/right. You can also scroll to zoom and click-drag to pan. Use `Shift+scroll` to scroll vertically through rows. + +### What to look for + +**1. Gaps between CUDA kernels** + +Zoom into the CUDA row during the denoising loop. Ideally, GPU kernels should be back-to-back with no gaps. Gaps mean the GPU is idle waiting for the CPU to launch the next kernel. Common causes: +- Python overhead between ops (visible as CPU slices in the CPU row during the gap) +- DtoH sync (`.item()`, `.cpu()`) forcing the GPU to drain before the CPU can proceed + +**2. CPU stalls (DtoH syncs)** + +These appear on the **CPU row** (not the CUDA row) — they are CPU-side blocking calls that wait for the GPU to finish. Look for long slices labeled `cudaStreamSynchronize` or `cudaDeviceSynchronize`. To find them: zoom into the CPU row during a denoising step and look for unusually wide slices, or use Perfetto's search bar (press `/`) and type `cudaStreamSynchronize` to jump directly to matching events. Click on a slice — if `with_stack=True` was enabled, the bottom panel ("Current Selection") shows the Python stack trace pointing to the exact line causing the sync (e.g., a `.item()` call in the scheduler). + +**3. Annotated regions** + +Our `record_function` annotations (`transformer_forward`, `scheduler_step`, etc.) appear as labeled spans on the CPU row. This lets you quickly: +- Measure how long each phase takes (click a span to see duration) +- See if `scheduler_step` is disproportionately expensive relative to `transformer_forward` (it should be negligible) +- Spot unexpected CPU work between annotated regions + +**4. Eager vs compile comparison** + +Open both traces side by side (two Perfetto tabs). Key differences to look for: +- **Fewer, wider CUDA kernels** in compile mode (fused ops) vs many small kernels in eager +- **Smaller CPU gaps** between kernels in compile mode (less Python dispatch overhead) +- **CUDA kernel count per step**: to compare, zoom into a single `transformer_forward` span on the CUDA row and count the distinct kernel slices within it. In eager mode you'll typically see many narrow slices (one per op); in compile mode these fuse into fewer, wider slices. A quick way to estimate: select a time range covering one denoising step on the CUDA row — Perfetto shows the number of slices in the selection summary at the bottom. If compile mode shows a similar kernel count to eager, fusion isn't happening effectively (likely due to graph breaks). +- **Graph breaks**: if compile mode still shows many small kernels in a section, that section likely has a graph break — check `TORCH_LOGS="+dynamo"` output for details + +**5. Memory timeline** + +In Perfetto, look for the memory counter track (if `profile_memory=True`). Spikes during the denoising loop suggest unexpected allocations per step. Steady-state memory during denoising is expected — growing memory is not. + +**6. Kernel launch latency** + +Each CUDA kernel is launched from the CPU. The CPU-side launch calls (`cudaLaunchKernel`) appear as small slices on the **CPU row** — zoom in closely to a denoising step to see them. The corresponding GPU-side kernel executions appear on the **CUDA row** directly below. You can also use Perfetto's search bar (`/`) and type `cudaLaunchKernel` to find them. The time between the CPU dispatch and the GPU kernel starting should be minimal (single-digit microseconds). If you see consistent delays > 10-20us between launch and execution: +- The launch queue may be starved because of excessive Python work between ops +- There may be implicit syncs forcing serialization +- `torch.compile` should help here by batching launches — compare eager vs compile to confirm + +To inspect this: zoom into a single denoising step, select a CUDA kernel on the GPU row, and look at the corresponding CPU-side launch slice directly above it. The horizontal offset between them is the launch latency. In a healthy trace, CPU launch slices should be well ahead of GPU execution (the CPU is "feeding" the GPU faster than it can consume). + +### Quick checklist per pipeline + +| Question | Where to look | Healthy | Unhealthy | +|----------|--------------|---------|-----------| +| GPU staying busy? | CUDA row gaps | Back-to-back kernels | Frequent gaps > 100us | +| CPU blocking on GPU? | `cudaStreamSynchronize` slices | Rare/absent during denoise | Present every step | +| Scheduler overhead? | `scheduler_step` span duration | < 1% of step time | > 5% of step time | +| Compile effective? | CUDA kernel count per step | Fewer large kernels | Same as eager | +| Kernel launch latency? | CPU launch → GPU kernel offset | < 10us, CPU ahead of GPU | > 20us or CPU trailing GPU | +| Memory stable? | Memory counter track | Flat during denoise loop | Growing per step | diff --git a/examples/profiling/profiling_pipelines.py b/examples/profiling/profiling_pipelines.py new file mode 100644 index 000000000000..d19540014e11 --- /dev/null +++ b/examples/profiling/profiling_pipelines.py @@ -0,0 +1,182 @@ +""" +Profile diffusers pipelines with torch.profiler. + +Usage: + python profiling/profiling_pipelines.py --pipeline flux --mode eager + python profiling/profiling_pipelines.py --pipeline flux --mode compile + python profiling/profiling_pipelines.py --pipeline flux --mode both + python profiling/profiling_pipelines.py --pipeline all --mode eager + python profiling/profiling_pipelines.py --pipeline wan --mode eager --full_decode + python profiling/profiling_pipelines.py --pipeline flux --mode compile --num_steps 4 +""" + +import argparse +import copy +import logging + +import torch + +from profiling_utils import PipelineProfiler, PipelineProfilingConfig + + +logging.basicConfig(level=logging.INFO, format="%(asctime)s %(levelname)s %(name)s: %(message)s") +logger = logging.getLogger(__name__) + +PROMPT = "A cat holding a sign that says hello world" + + +def build_registry(): + """Build the pipeline config registry. Imports are deferred to avoid loading all pipelines upfront.""" + from diffusers import FluxPipeline, Flux2KleinPipeline, WanPipeline, LTX2Pipeline, QwenImagePipeline + + return { + "flux": PipelineProfilingConfig( + name="flux", + pipeline_cls=FluxPipeline, + pipeline_init_kwargs={ + "pretrained_model_name_or_path": "black-forest-labs/FLUX.1-dev", + "torch_dtype": torch.bfloat16, + }, + pipeline_call_kwargs={ + "prompt": PROMPT, + "height": 1024, + "width": 1024, + "num_inference_steps": 4, + "guidance_scale": 3.5, + "output_type": "latent", + }, + ), + "flux2": PipelineProfilingConfig( + name="flux2", + pipeline_cls=Flux2KleinPipeline, + pipeline_init_kwargs={ + "pretrained_model_name_or_path": "black-forest-labs/FLUX.2-klein-base-9B", + "torch_dtype": torch.bfloat16, + }, + pipeline_call_kwargs={ + "prompt": PROMPT, + "height": 1024, + "width": 1024, + "num_inference_steps": 4, + "guidance_scale": 3.5, + "output_type": "latent", + }, + ), + "wan": PipelineProfilingConfig( + name="wan", + pipeline_cls=WanPipeline, + pipeline_init_kwargs={ + "pretrained_model_name_or_path": "Wan-AI/Wan2.1-T2V-14B-Diffusers", + "torch_dtype": torch.bfloat16, + }, + pipeline_call_kwargs={ + "prompt": PROMPT, + "negative_prompt": "Bright tones, overexposed, static, blurred details, subtitles, style, works, paintings, images, static, overall gray, worst quality, low quality, JPEG compression residue, ugly, incomplete, extra fingers, poorly drawn hands, poorly drawn faces, deformed, disfigured, misshapen limbs, fused fingers, still picture, messy background, three legs, many people in the background, walking backwards", + "height": 480, + "width": 832, + "num_frames": 81, + "num_inference_steps": 4, + "output_type": "latent", + }, + ), + "ltx2": PipelineProfilingConfig( + name="ltx2", + pipeline_cls=LTX2Pipeline, + pipeline_init_kwargs={ + "pretrained_model_name_or_path": "Lightricks/LTX-2", + "torch_dtype": torch.bfloat16, + }, + pipeline_call_kwargs={ + "prompt": PROMPT, + "negative_prompt": "worst quality, inconsistent motion, blurry, jittery, distorted", + "height": 512, + "width": 768, + "num_frames": 121, + "num_inference_steps": 4, + "guidance_scale": 4.0, + "output_type": "latent", + }, + ), + "qwenimage": PipelineProfilingConfig( + name="qwenimage", + pipeline_cls=QwenImagePipeline, + pipeline_init_kwargs={ + "pretrained_model_name_or_path": "Qwen/Qwen-Image", + "torch_dtype": torch.bfloat16, + }, + pipeline_call_kwargs={ + "prompt": PROMPT, + "negative_prompt": " ", + "height": 1024, + "width": 1024, + "num_inference_steps": 4, + "true_cfg_scale": 4.0, + "output_type": "latent", + }, + ), + } + + +def main(): + parser = argparse.ArgumentParser(description="Profile diffusers pipelines with torch.profiler") + parser.add_argument( + "--pipeline", + choices=["flux", "flux2", "wan", "ltx2", "qwenimage", "all"], + required=True, + help="Which pipeline to profile", + ) + parser.add_argument( + "--mode", + choices=["eager", "compile", "both"], + default="eager", + help="Run in eager mode, compile mode, or both", + ) + parser.add_argument("--output_dir", default="profiling_results", help="Directory for trace output") + parser.add_argument("--num_steps", type=int, default=None, help="Override num_inference_steps") + parser.add_argument("--full_decode", action="store_true", help="Profile including VAE decode (output_type='pil')") + parser.add_argument( + "--compile_mode", + default="default", + choices=["default", "reduce-overhead", "max-autotune"], + help="torch.compile mode", + ) + parser.add_argument("--compile_fullgraph", action="store_true", help="Use fullgraph=True for torch.compile") + parser.add_argument( + "--compile_regional", + action="store_true", + help="Use compile_repeated_blocks() instead of full model compile", + ) + args = parser.parse_args() + + registry = build_registry() + + pipeline_names = list(registry.keys()) if args.pipeline == "all" else [args.pipeline] + modes = ["eager", "compile"] if args.mode == "both" else [args.mode] + + for pipeline_name in pipeline_names: + for mode in modes: + config = copy.deepcopy(registry[pipeline_name]) + + # Apply overrides + if args.num_steps is not None: + config.pipeline_call_kwargs["num_inference_steps"] = args.num_steps + if args.full_decode: + config.pipeline_call_kwargs["output_type"] = "pil" + if mode == "compile": + config.compile_kwargs = { + "fullgraph": args.compile_fullgraph, + "mode": args.compile_mode, + } + config.compile_regional = args.compile_regional + + logger.info(f"Profiling {pipeline_name} in {mode} mode...") + profiler = PipelineProfiler(config, args.output_dir) + try: + trace_file = profiler.run() + logger.info(f"Done: {trace_file}") + except Exception as e: + logger.error(f"Failed to profile {pipeline_name} ({mode}): {e}") + + +if __name__ == "__main__": + main() diff --git a/examples/profiling/profiling_utils.py b/examples/profiling/profiling_utils.py new file mode 100644 index 000000000000..cd30f912f938 --- /dev/null +++ b/examples/profiling/profiling_utils.py @@ -0,0 +1,146 @@ +import functools +import gc +import logging +import os +from dataclasses import dataclass, field +from typing import Any + +import torch +import torch.profiler + + +logging.basicConfig(level=logging.INFO, format="%(asctime)s %(levelname)s %(name)s: %(message)s") +logger = logging.getLogger(__name__) + + +def annotate(func, name): + """Wrap a function with torch.profiler.record_function for trace annotation.""" + + @functools.wraps(func) + def wrapper(*args, **kwargs): + with torch.profiler.record_function(name): + return func(*args, **kwargs) + + return wrapper + + +def annotate_pipeline(pipe): + """Apply profiler annotations to key pipeline methods. + + Monkey-patches bound methods so they appear as named spans in the trace. + Non-invasive — no source modifications required. + """ + annotations = [ + ("transformer", "forward", "transformer_forward"), + ("vae", "decode", "vae_decode"), + ("vae", "encode", "vae_encode"), + ("scheduler", "step", "scheduler_step"), + ] + + # Annotate sub-component methods + for component_name, method_name, label in annotations: + component = getattr(pipe, component_name, None) + if component is None: + continue + method = getattr(component, method_name, None) + if method is None: + continue + setattr(component, method_name, annotate(method, label)) + + # Annotate pipeline-level methods + if hasattr(pipe, "encode_prompt"): + pipe.encode_prompt = annotate(pipe.encode_prompt, "encode_prompt") + + +def flush(): + gc.collect() + torch.cuda.empty_cache() + torch.cuda.reset_max_memory_allocated() + torch.cuda.reset_peak_memory_stats() + + +@dataclass +class PipelineProfilingConfig: + name: str + pipeline_cls: Any + pipeline_init_kwargs: dict[str, Any] + pipeline_call_kwargs: dict[str, Any] + compile_kwargs: dict[str, Any] | None = field(default=None) + compile_regional: bool = False + + +class PipelineProfiler: + def __init__(self, config: PipelineProfilingConfig, output_dir: str = "profiling_results"): + self.config = config + self.output_dir = output_dir + os.makedirs(output_dir, exist_ok=True) + + def setup_pipeline(self): + """Load the pipeline from pretrained, optionally compile, and annotate.""" + logger.info(f"Loading pipeline: {self.config.name}") + pipe = self.config.pipeline_cls.from_pretrained(**self.config.pipeline_init_kwargs) + pipe.to("cuda") + + if self.config.compile_kwargs: + if self.config.compile_regional: + logger.info(f"Regional compilation (compile_repeated_blocks) with kwargs: {self.config.compile_kwargs}") + pipe.transformer.compile_repeated_blocks(**self.config.compile_kwargs) + else: + logger.info(f"Full compilation with kwargs: {self.config.compile_kwargs}") + pipe.transformer.compile(**self.config.compile_kwargs) + + # Disable tqdm progress bar to avoid CPU overhead / IO between steps + pipe.set_progress_bar_config(disable=True) + + annotate_pipeline(pipe) + return pipe + + def run(self): + """Execute the profiling run: warmup, then profile one pipeline call.""" + pipe = self.setup_pipeline() + flush() + + mode = "compile" if self.config.compile_kwargs else "eager" + trace_file = os.path.join(self.output_dir, f"{self.config.name}_{mode}.json") + + # Warmup (pipeline __call__ is already decorated with @torch.no_grad()) + logger.info("Running warmup...") + pipe(**self.config.pipeline_call_kwargs) + flush() + + # Profile + logger.info("Running profiled iteration...") + activities = [ + torch.profiler.ProfilerActivity.CPU, + torch.profiler.ProfilerActivity.CUDA, + ] + with torch.profiler.profile( + activities=activities, + record_shapes=True, + profile_memory=True, + with_stack=True, + ) as prof: + with torch.profiler.record_function("pipeline_call"): + pipe(**self.config.pipeline_call_kwargs) + + # Export trace + prof.export_chrome_trace(trace_file) + logger.info(f"Chrome trace saved to: {trace_file}") + + # Print summary + print("\n" + "=" * 80) + print(f"Profile summary: {self.config.name} ({mode})") + print("=" * 80) + print( + prof.key_averages().table( + sort_by="cuda_time_total", + row_limit=20, + ) + ) + + # Cleanup + pipe.to("cpu") + del pipe + flush() + + return trace_file diff --git a/examples/profiling/run_profiling.sh b/examples/profiling/run_profiling.sh new file mode 100755 index 000000000000..a6be682e929d --- /dev/null +++ b/examples/profiling/run_profiling.sh @@ -0,0 +1,46 @@ +#!/bin/bash +# Run profiling across all pipelines in eager and compile (regional) modes. +# +# Usage: +# bash profiling/run_profiling.sh +# bash profiling/run_profiling.sh --output_dir my_results + +set -euo pipefail + +OUTPUT_DIR="profiling_results" +while [[ $# -gt 0 ]]; do + case "$1" in + --output_dir) OUTPUT_DIR="$2"; shift 2 ;; + *) echo "Unknown arg: $1"; exit 1 ;; + esac +done +NUM_STEPS=2 +# PIPELINES=("flux" "flux2" "wan" "ltx2" "qwenimage") +PIPELINES=("flux2") +MODES=("eager" "compile") + +for pipeline in "${PIPELINES[@]}"; do + for mode in "${MODES[@]}"; do + echo "============================================================" + echo "Profiling: ${pipeline} | mode: ${mode}" + echo "============================================================" + + COMPILE_ARGS="" + if [ "$mode" = "compile" ]; then + COMPILE_ARGS="--compile_regional --compile_fullgraph --compile_mode default" + fi + + python profiling/profiling_pipelines.py \ + --pipeline "$pipeline" \ + --mode "$mode" \ + --output_dir "$OUTPUT_DIR" \ + --num_steps "$NUM_STEPS" \ + $COMPILE_ARGS + + echo "" + done +done + +echo "============================================================" +echo "All traces saved to: ${OUTPUT_DIR}/" +echo "============================================================" From bf5131fba9db0247df88820730e093bcac1fac37 Mon Sep 17 00:00:00 2001 From: sayakpaul Date: Fri, 27 Mar 2026 12:51:56 +0530 Subject: [PATCH 09/27] propagate deletion. --- examples/profiling/README.md | 4 + profiling/PROFILING_PLAN.md | 171 ----------------------------- profiling/profiling_pipelines.py | 182 ------------------------------- profiling/profiling_utils.py | 146 ------------------------- profiling/run_profiling.sh | 46 -------- 5 files changed, 4 insertions(+), 545 deletions(-) delete mode 100644 profiling/PROFILING_PLAN.md delete mode 100644 profiling/profiling_pipelines.py delete mode 100644 profiling/profiling_utils.py delete mode 100755 profiling/run_profiling.sh diff --git a/examples/profiling/README.md b/examples/profiling/README.md index 5c444f48dfef..8ef6d902337e 100644 --- a/examples/profiling/README.md +++ b/examples/profiling/README.md @@ -193,3 +193,7 @@ To inspect this: zoom into a single denoising step, select a CUDA kernel on the | Compile effective? | CUDA kernel count per step | Fewer large kernels | Same as eager | | Kernel launch latency? | CPU launch → GPU kernel offset | < 10us, CPU ahead of GPU | > 20us or CPU trailing GPU | | Memory stable? | Memory counter track | Flat during denoise loop | Growing per step | + +## Afterwards + +TODO \ No newline at end of file diff --git a/profiling/PROFILING_PLAN.md b/profiling/PROFILING_PLAN.md deleted file mode 100644 index 7b119dd2a53d..000000000000 --- a/profiling/PROFILING_PLAN.md +++ /dev/null @@ -1,171 +0,0 @@ -# Profiling Plan: Diffusers Pipeline Profiling with torch.profiler - -## Context - -We want to uncover CPU overhead, CPU-GPU sync points, and other bottlenecks in popular diffusers pipelines — especially issues that become non-trivial under `torch.compile`. The approach is inspired by [flux-fast's run_benchmark.py](https://github.com/huggingface/flux-fast/blob/0a1dcc91658f0df14cd7fce862a5c8842784c6da/run_benchmark.py#L66-L85) which uses `torch.profiler` with method-level annotations, and motivated by issues like [diffusers#11696](https://github.com/huggingface/diffusers/pull/11696) (DtoH sync from scheduler `.item()` call). - -## Target Pipelines - -| Pipeline | Type | Checkpoint | Steps | -|----------|------|-----------|-------| -| `FluxPipeline` | text-to-image | `black-forest-labs/FLUX.1-dev` | 4 | -| `Flux2KleinPipeline` | text-to-image | `black-forest-labs/FLUX.2-klein-base-9B` | 4 | -| `WanPipeline` | text-to-video | `Wan-AI/Wan2.1-T2V-14B-Diffusers` | 4 | -| `LTX2Pipeline` | text-to-video | `Lightricks/LTX-2` | 4 | -| `QwenImagePipeline` | text-to-image | `Qwen/Qwen-Image` | 4 | - -## Approach - -Follow the flux-fast pattern: **annotate key pipeline methods** with `torch.profiler.record_function` wrappers, then run the pipeline under `torch.profiler.profile` and export a Chrome trace. - -### New Files - -``` -profiling/ - profiling_utils.py # Annotation helper + profiler setup - profiling_pipelines.py # CLI entry point with pipeline configs -``` - -### Step 1: `profiling_utils.py` — Annotation and Profiler Infrastructure - -**A) `annotate(func, name)` helper** (same pattern as flux-fast): - -```python -def annotate(func, name): - """Wrap a function with torch.profiler.record_function for trace annotation.""" - @functools.wraps(func) - def wrapper(*args, **kwargs): - with torch.profiler.record_function(name): - return func(*args, **kwargs) - return wrapper -``` - -**B) `annotate_pipeline(pipe)` function** — applies annotations to key methods on any pipeline: - -- `pipe.transformer.forward` → `"transformer_forward"` -- `pipe.vae.decode` → `"vae_decode"` (if present) -- `pipe.vae.encode` → `"vae_encode"` (if present) -- `pipe.scheduler.step` → `"scheduler_step"` -- `pipe.encode_prompt` → `"encode_prompt"` (if present, for full-pipeline profiling) - -This is non-invasive — it monkey-patches bound methods without modifying source. - -**C) `PipelineProfiler` class:** - -- `__init__(pipeline_config, output_dir, mode="eager"|"compile")` -- `setup_pipeline()` → loads from pretrained, optionally compiles transformer, calls `annotate_pipeline()` -- `run()`: - 1. Warm up with 1 unannotated run - 2. Profile 1 run with `torch.profiler.profile`: - - `activities=[CPU, CUDA]` - - `record_shapes=True` - - `profile_memory=True` - - `with_stack=True` - 3. Export Chrome trace JSON - 4. Print `key_averages()` summary table (sorted by CUDA time) to stdout - -### Step 2: `profiling_pipelines.py` — CLI with Pipeline Configs - -**Pipeline config registry** — each entry specifies: - -- `pipeline_cls`, `pretrained_model_name_or_path`, `torch_dtype` -- `call_kwargs` with pipeline-specific defaults: - -| Pipeline | Resolution | Frames | Steps | Extra | -|----------|-----------|--------|-------|-------| -| Flux | 1024x1024 | — | 4 | `guidance_scale=3.5` | -| Flux2Klein | 1024x1024 | — | 4 | `guidance_scale=3.5` | -| Wan | 480x832 | 81 | 4 | — | -| LTX2 | 768x512 | 121 | 4 | `guidance_scale=4.0` | -| QwenImage | 1024x1024 | — | 4 | `true_cfg_scale=4.0` | - -All configs use `output_type="latent"` by default (skip VAE decode for cleaner denoising-loop traces). - -**CLI flags:** - -- `--pipeline flux|flux2|wan|ltx2|qwenimage|all` -- `--mode eager|compile|both` -- `--output_dir profiling_results/` -- `--num_steps N` (override, default 4) -- `--full_decode` (switch output_type from `"latent"` to `"pil"` to include VAE) -- `--compile_mode default|reduce-overhead|max-autotune` -- `--compile_fullgraph` flag - -**Output:** `{output_dir}/{pipeline}_{mode}.json` Chrome trace + stdout summary. - -### Step 3: Known Sync Issues to Validate - -The profiling should surface these known/suspected issues: - -1. **Scheduler DtoH sync via `nonzero().item()`** — For Flux, this was fixed by adding `scheduler.set_begin_index(0)` before the denoising loop ([diffusers#11696](https://github.com/huggingface/diffusers/pull/11696)). Profiling should reveal whether similar sync points exist in other pipelines. - -2. **`modulate_index` tensor rebuilt every forward in `transformer_qwenimage.py`** (line 901-905) — Python list comprehension + `torch.tensor()` each step. Minor but visible in trace. - -3. **Any other `.item()`, `.cpu()`, `.numpy()` calls** in the denoising loop hot path — the profiler's `with_stack=True` will surface these as CPU stalls with Python stack traces. - -## Verification - -1. Run: `python profiling/profiling_pipelines.py --pipeline flux --mode eager --num_steps 4` -2. Verify `profiling_results/flux_eager.json` is produced -3. Open trace in [Perfetto UI](https://ui.perfetto.dev/) — confirm: - - `transformer_forward` and `scheduler_step` annotations visible - - CPU and CUDA timelines present - - Stack traces visible on CPU events -4. Run with `--mode compile` and compare trace for fewer/fused CUDA kernels - -## Interpreting Traces in Perfetto UI - -Open the exported `.json` trace at [ui.perfetto.dev](https://ui.perfetto.dev/). The trace has two main rows: **CPU** (top) and **CUDA** (bottom). In Perfetto, the CPU row is typically labeled with the process/thread name (e.g., `python (PID)` or `MainThread`) and appears at the top. The CUDA row is labeled `GPU 0` (or similar) and appears below the CPU rows. - -**Navigation:** Use `W` to zoom in, `S` to zoom out, and `A`/`D` to pan left/right. You can also scroll to zoom and click-drag to pan. Use `Shift+scroll` to scroll vertically through rows. - -### What to look for - -**1. Gaps between CUDA kernels** - -Zoom into the CUDA row during the denoising loop. Ideally, GPU kernels should be back-to-back with no gaps. Gaps mean the GPU is idle waiting for the CPU to launch the next kernel. Common causes: -- Python overhead between ops (visible as CPU slices in the CPU row during the gap) -- DtoH sync (`.item()`, `.cpu()`) forcing the GPU to drain before the CPU can proceed - -**2. CPU stalls (DtoH syncs)** - -These appear on the **CPU row** (not the CUDA row) — they are CPU-side blocking calls that wait for the GPU to finish. Look for long slices labeled `cudaStreamSynchronize` or `cudaDeviceSynchronize`. To find them: zoom into the CPU row during a denoising step and look for unusually wide slices, or use Perfetto's search bar (press `/`) and type `cudaStreamSynchronize` to jump directly to matching events. Click on a slice — if `with_stack=True` was enabled, the bottom panel ("Current Selection") shows the Python stack trace pointing to the exact line causing the sync (e.g., a `.item()` call in the scheduler). - -**3. Annotated regions** - -Our `record_function` annotations (`transformer_forward`, `scheduler_step`, etc.) appear as labeled spans on the CPU row. This lets you quickly: -- Measure how long each phase takes (click a span to see duration) -- See if `scheduler_step` is disproportionately expensive relative to `transformer_forward` (it should be negligible) -- Spot unexpected CPU work between annotated regions - -**4. Eager vs compile comparison** - -Open both traces side by side (two Perfetto tabs). Key differences to look for: -- **Fewer, wider CUDA kernels** in compile mode (fused ops) vs many small kernels in eager -- **Smaller CPU gaps** between kernels in compile mode (less Python dispatch overhead) -- **CUDA kernel count per step**: to compare, zoom into a single `transformer_forward` span on the CUDA row and count the distinct kernel slices within it. In eager mode you'll typically see many narrow slices (one per op); in compile mode these fuse into fewer, wider slices. A quick way to estimate: select a time range covering one denoising step on the CUDA row — Perfetto shows the number of slices in the selection summary at the bottom. If compile mode shows a similar kernel count to eager, fusion isn't happening effectively (likely due to graph breaks). -- **Graph breaks**: if compile mode still shows many small kernels in a section, that section likely has a graph break — check `TORCH_LOGS="+dynamo"` output for details - -**5. Memory timeline** - -In Perfetto, look for the memory counter track (if `profile_memory=True`). Spikes during the denoising loop suggest unexpected allocations per step. Steady-state memory during denoising is expected — growing memory is not. - -**6. Kernel launch latency** - -Each CUDA kernel is launched from the CPU. The CPU-side launch calls (`cudaLaunchKernel`) appear as small slices on the **CPU row** — zoom in closely to a denoising step to see them. The corresponding GPU-side kernel executions appear on the **CUDA row** directly below. You can also use Perfetto's search bar (`/`) and type `cudaLaunchKernel` to find them. The time between the CPU dispatch and the GPU kernel starting should be minimal (single-digit microseconds). If you see consistent delays > 10-20us between launch and execution: -- The launch queue may be starved because of excessive Python work between ops -- There may be implicit syncs forcing serialization -- `torch.compile` should help here by batching launches — compare eager vs compile to confirm - -To inspect this: zoom into a single denoising step, select a CUDA kernel on the GPU row, and look at the corresponding CPU-side launch slice directly above it. The horizontal offset between them is the launch latency. In a healthy trace, CPU launch slices should be well ahead of GPU execution (the CPU is "feeding" the GPU faster than it can consume). - -### Quick checklist per pipeline - -| Question | Where to look | Healthy | Unhealthy | -|----------|--------------|---------|-----------| -| GPU staying busy? | CUDA row gaps | Back-to-back kernels | Frequent gaps > 100us | -| CPU blocking on GPU? | `cudaStreamSynchronize` slices | Rare/absent during denoise | Present every step | -| Scheduler overhead? | `scheduler_step` span duration | < 1% of step time | > 5% of step time | -| Compile effective? | CUDA kernel count per step | Fewer large kernels | Same as eager | -| Kernel launch latency? | CPU launch → GPU kernel offset | < 10us, CPU ahead of GPU | > 20us or CPU trailing GPU | -| Memory stable? | Memory counter track | Flat during denoise loop | Growing per step | diff --git a/profiling/profiling_pipelines.py b/profiling/profiling_pipelines.py deleted file mode 100644 index d19540014e11..000000000000 --- a/profiling/profiling_pipelines.py +++ /dev/null @@ -1,182 +0,0 @@ -""" -Profile diffusers pipelines with torch.profiler. - -Usage: - python profiling/profiling_pipelines.py --pipeline flux --mode eager - python profiling/profiling_pipelines.py --pipeline flux --mode compile - python profiling/profiling_pipelines.py --pipeline flux --mode both - python profiling/profiling_pipelines.py --pipeline all --mode eager - python profiling/profiling_pipelines.py --pipeline wan --mode eager --full_decode - python profiling/profiling_pipelines.py --pipeline flux --mode compile --num_steps 4 -""" - -import argparse -import copy -import logging - -import torch - -from profiling_utils import PipelineProfiler, PipelineProfilingConfig - - -logging.basicConfig(level=logging.INFO, format="%(asctime)s %(levelname)s %(name)s: %(message)s") -logger = logging.getLogger(__name__) - -PROMPT = "A cat holding a sign that says hello world" - - -def build_registry(): - """Build the pipeline config registry. Imports are deferred to avoid loading all pipelines upfront.""" - from diffusers import FluxPipeline, Flux2KleinPipeline, WanPipeline, LTX2Pipeline, QwenImagePipeline - - return { - "flux": PipelineProfilingConfig( - name="flux", - pipeline_cls=FluxPipeline, - pipeline_init_kwargs={ - "pretrained_model_name_or_path": "black-forest-labs/FLUX.1-dev", - "torch_dtype": torch.bfloat16, - }, - pipeline_call_kwargs={ - "prompt": PROMPT, - "height": 1024, - "width": 1024, - "num_inference_steps": 4, - "guidance_scale": 3.5, - "output_type": "latent", - }, - ), - "flux2": PipelineProfilingConfig( - name="flux2", - pipeline_cls=Flux2KleinPipeline, - pipeline_init_kwargs={ - "pretrained_model_name_or_path": "black-forest-labs/FLUX.2-klein-base-9B", - "torch_dtype": torch.bfloat16, - }, - pipeline_call_kwargs={ - "prompt": PROMPT, - "height": 1024, - "width": 1024, - "num_inference_steps": 4, - "guidance_scale": 3.5, - "output_type": "latent", - }, - ), - "wan": PipelineProfilingConfig( - name="wan", - pipeline_cls=WanPipeline, - pipeline_init_kwargs={ - "pretrained_model_name_or_path": "Wan-AI/Wan2.1-T2V-14B-Diffusers", - "torch_dtype": torch.bfloat16, - }, - pipeline_call_kwargs={ - "prompt": PROMPT, - "negative_prompt": "Bright tones, overexposed, static, blurred details, subtitles, style, works, paintings, images, static, overall gray, worst quality, low quality, JPEG compression residue, ugly, incomplete, extra fingers, poorly drawn hands, poorly drawn faces, deformed, disfigured, misshapen limbs, fused fingers, still picture, messy background, three legs, many people in the background, walking backwards", - "height": 480, - "width": 832, - "num_frames": 81, - "num_inference_steps": 4, - "output_type": "latent", - }, - ), - "ltx2": PipelineProfilingConfig( - name="ltx2", - pipeline_cls=LTX2Pipeline, - pipeline_init_kwargs={ - "pretrained_model_name_or_path": "Lightricks/LTX-2", - "torch_dtype": torch.bfloat16, - }, - pipeline_call_kwargs={ - "prompt": PROMPT, - "negative_prompt": "worst quality, inconsistent motion, blurry, jittery, distorted", - "height": 512, - "width": 768, - "num_frames": 121, - "num_inference_steps": 4, - "guidance_scale": 4.0, - "output_type": "latent", - }, - ), - "qwenimage": PipelineProfilingConfig( - name="qwenimage", - pipeline_cls=QwenImagePipeline, - pipeline_init_kwargs={ - "pretrained_model_name_or_path": "Qwen/Qwen-Image", - "torch_dtype": torch.bfloat16, - }, - pipeline_call_kwargs={ - "prompt": PROMPT, - "negative_prompt": " ", - "height": 1024, - "width": 1024, - "num_inference_steps": 4, - "true_cfg_scale": 4.0, - "output_type": "latent", - }, - ), - } - - -def main(): - parser = argparse.ArgumentParser(description="Profile diffusers pipelines with torch.profiler") - parser.add_argument( - "--pipeline", - choices=["flux", "flux2", "wan", "ltx2", "qwenimage", "all"], - required=True, - help="Which pipeline to profile", - ) - parser.add_argument( - "--mode", - choices=["eager", "compile", "both"], - default="eager", - help="Run in eager mode, compile mode, or both", - ) - parser.add_argument("--output_dir", default="profiling_results", help="Directory for trace output") - parser.add_argument("--num_steps", type=int, default=None, help="Override num_inference_steps") - parser.add_argument("--full_decode", action="store_true", help="Profile including VAE decode (output_type='pil')") - parser.add_argument( - "--compile_mode", - default="default", - choices=["default", "reduce-overhead", "max-autotune"], - help="torch.compile mode", - ) - parser.add_argument("--compile_fullgraph", action="store_true", help="Use fullgraph=True for torch.compile") - parser.add_argument( - "--compile_regional", - action="store_true", - help="Use compile_repeated_blocks() instead of full model compile", - ) - args = parser.parse_args() - - registry = build_registry() - - pipeline_names = list(registry.keys()) if args.pipeline == "all" else [args.pipeline] - modes = ["eager", "compile"] if args.mode == "both" else [args.mode] - - for pipeline_name in pipeline_names: - for mode in modes: - config = copy.deepcopy(registry[pipeline_name]) - - # Apply overrides - if args.num_steps is not None: - config.pipeline_call_kwargs["num_inference_steps"] = args.num_steps - if args.full_decode: - config.pipeline_call_kwargs["output_type"] = "pil" - if mode == "compile": - config.compile_kwargs = { - "fullgraph": args.compile_fullgraph, - "mode": args.compile_mode, - } - config.compile_regional = args.compile_regional - - logger.info(f"Profiling {pipeline_name} in {mode} mode...") - profiler = PipelineProfiler(config, args.output_dir) - try: - trace_file = profiler.run() - logger.info(f"Done: {trace_file}") - except Exception as e: - logger.error(f"Failed to profile {pipeline_name} ({mode}): {e}") - - -if __name__ == "__main__": - main() diff --git a/profiling/profiling_utils.py b/profiling/profiling_utils.py deleted file mode 100644 index cd30f912f938..000000000000 --- a/profiling/profiling_utils.py +++ /dev/null @@ -1,146 +0,0 @@ -import functools -import gc -import logging -import os -from dataclasses import dataclass, field -from typing import Any - -import torch -import torch.profiler - - -logging.basicConfig(level=logging.INFO, format="%(asctime)s %(levelname)s %(name)s: %(message)s") -logger = logging.getLogger(__name__) - - -def annotate(func, name): - """Wrap a function with torch.profiler.record_function for trace annotation.""" - - @functools.wraps(func) - def wrapper(*args, **kwargs): - with torch.profiler.record_function(name): - return func(*args, **kwargs) - - return wrapper - - -def annotate_pipeline(pipe): - """Apply profiler annotations to key pipeline methods. - - Monkey-patches bound methods so they appear as named spans in the trace. - Non-invasive — no source modifications required. - """ - annotations = [ - ("transformer", "forward", "transformer_forward"), - ("vae", "decode", "vae_decode"), - ("vae", "encode", "vae_encode"), - ("scheduler", "step", "scheduler_step"), - ] - - # Annotate sub-component methods - for component_name, method_name, label in annotations: - component = getattr(pipe, component_name, None) - if component is None: - continue - method = getattr(component, method_name, None) - if method is None: - continue - setattr(component, method_name, annotate(method, label)) - - # Annotate pipeline-level methods - if hasattr(pipe, "encode_prompt"): - pipe.encode_prompt = annotate(pipe.encode_prompt, "encode_prompt") - - -def flush(): - gc.collect() - torch.cuda.empty_cache() - torch.cuda.reset_max_memory_allocated() - torch.cuda.reset_peak_memory_stats() - - -@dataclass -class PipelineProfilingConfig: - name: str - pipeline_cls: Any - pipeline_init_kwargs: dict[str, Any] - pipeline_call_kwargs: dict[str, Any] - compile_kwargs: dict[str, Any] | None = field(default=None) - compile_regional: bool = False - - -class PipelineProfiler: - def __init__(self, config: PipelineProfilingConfig, output_dir: str = "profiling_results"): - self.config = config - self.output_dir = output_dir - os.makedirs(output_dir, exist_ok=True) - - def setup_pipeline(self): - """Load the pipeline from pretrained, optionally compile, and annotate.""" - logger.info(f"Loading pipeline: {self.config.name}") - pipe = self.config.pipeline_cls.from_pretrained(**self.config.pipeline_init_kwargs) - pipe.to("cuda") - - if self.config.compile_kwargs: - if self.config.compile_regional: - logger.info(f"Regional compilation (compile_repeated_blocks) with kwargs: {self.config.compile_kwargs}") - pipe.transformer.compile_repeated_blocks(**self.config.compile_kwargs) - else: - logger.info(f"Full compilation with kwargs: {self.config.compile_kwargs}") - pipe.transformer.compile(**self.config.compile_kwargs) - - # Disable tqdm progress bar to avoid CPU overhead / IO between steps - pipe.set_progress_bar_config(disable=True) - - annotate_pipeline(pipe) - return pipe - - def run(self): - """Execute the profiling run: warmup, then profile one pipeline call.""" - pipe = self.setup_pipeline() - flush() - - mode = "compile" if self.config.compile_kwargs else "eager" - trace_file = os.path.join(self.output_dir, f"{self.config.name}_{mode}.json") - - # Warmup (pipeline __call__ is already decorated with @torch.no_grad()) - logger.info("Running warmup...") - pipe(**self.config.pipeline_call_kwargs) - flush() - - # Profile - logger.info("Running profiled iteration...") - activities = [ - torch.profiler.ProfilerActivity.CPU, - torch.profiler.ProfilerActivity.CUDA, - ] - with torch.profiler.profile( - activities=activities, - record_shapes=True, - profile_memory=True, - with_stack=True, - ) as prof: - with torch.profiler.record_function("pipeline_call"): - pipe(**self.config.pipeline_call_kwargs) - - # Export trace - prof.export_chrome_trace(trace_file) - logger.info(f"Chrome trace saved to: {trace_file}") - - # Print summary - print("\n" + "=" * 80) - print(f"Profile summary: {self.config.name} ({mode})") - print("=" * 80) - print( - prof.key_averages().table( - sort_by="cuda_time_total", - row_limit=20, - ) - ) - - # Cleanup - pipe.to("cpu") - del pipe - flush() - - return trace_file diff --git a/profiling/run_profiling.sh b/profiling/run_profiling.sh deleted file mode 100755 index a6be682e929d..000000000000 --- a/profiling/run_profiling.sh +++ /dev/null @@ -1,46 +0,0 @@ -#!/bin/bash -# Run profiling across all pipelines in eager and compile (regional) modes. -# -# Usage: -# bash profiling/run_profiling.sh -# bash profiling/run_profiling.sh --output_dir my_results - -set -euo pipefail - -OUTPUT_DIR="profiling_results" -while [[ $# -gt 0 ]]; do - case "$1" in - --output_dir) OUTPUT_DIR="$2"; shift 2 ;; - *) echo "Unknown arg: $1"; exit 1 ;; - esac -done -NUM_STEPS=2 -# PIPELINES=("flux" "flux2" "wan" "ltx2" "qwenimage") -PIPELINES=("flux2") -MODES=("eager" "compile") - -for pipeline in "${PIPELINES[@]}"; do - for mode in "${MODES[@]}"; do - echo "============================================================" - echo "Profiling: ${pipeline} | mode: ${mode}" - echo "============================================================" - - COMPILE_ARGS="" - if [ "$mode" = "compile" ]; then - COMPILE_ARGS="--compile_regional --compile_fullgraph --compile_mode default" - fi - - python profiling/profiling_pipelines.py \ - --pipeline "$pipeline" \ - --mode "$mode" \ - --output_dir "$OUTPUT_DIR" \ - --num_steps "$NUM_STEPS" \ - $COMPILE_ARGS - - echo "" - done -done - -echo "============================================================" -echo "All traces saved to: ${OUTPUT_DIR}/" -echo "============================================================" From bfbaf079cd878df2b4765bf37a4994950be39bd9 Mon Sep 17 00:00:00 2001 From: sayakpaul Date: Fri, 27 Mar 2026 13:39:49 +0530 Subject: [PATCH 10/27] up --- examples/profiling/README.md | 1 + 1 file changed, 1 insertion(+) diff --git a/examples/profiling/README.md b/examples/profiling/README.md index 8ef6d902337e..358f42f14993 100644 --- a/examples/profiling/README.md +++ b/examples/profiling/README.md @@ -14,6 +14,7 @@ Thanks to Claude Code for paircoding! We acknowledge the [Claude of OSS](https:/ * [Approach taken](#approach) * [Verification](#verification) * [Interpretation](#interpreting-traces-in-perfetto-ui) +* [Taking profiling-guided steps for improvements](#afterwards) Jump to the "Verification" section to get started right away. From a410b4958c7b769d7383c81958743641dae13068 Mon Sep 17 00:00:00 2001 From: sayakpaul Date: Fri, 27 Mar 2026 16:29:45 +0530 Subject: [PATCH 11/27] up --- examples/profiling/README.md | 52 +++++++++++++++++++++++++++++++++++- 1 file changed, 51 insertions(+), 1 deletion(-) diff --git a/examples/profiling/README.md b/examples/profiling/README.md index 358f42f14993..88070aa6bddf 100644 --- a/examples/profiling/README.md +++ b/examples/profiling/README.md @@ -197,4 +197,54 @@ To inspect this: zoom into a single denoising step, select a CUDA kernel on the ## Afterwards -TODO \ No newline at end of file +To keep the profiling iterations fast, we always used [regional compilation](https://pytorch.org/tutorials/recipes/regional_compilation.html). As one would expect the trace with compilation should show +fewer kernel launches than its eager counterpart: + +TODO: show traces + +_(The traces above were obtained with Flux2.)_ + +### Spotting gaps between launches + +Then a reasonable next step is to spot frequent gaps between kernel executions. In the compiled +case, we don't spot any on the surface. But if we zone in, some become apparent. + +TODO: show gaps in a compile trace + +So, we provided the profile trace (with compilation) to Claude, asked it to find the instances of +"cudaStreamSynchronize" and "cudaDeviceSynchronize", and to come up with some potential fixes. +Claude came back pretty strong: + +``` +Issue 1 — Gap between transformer forwards: +- Root cause: tqdm progress bar update() calls between steps add CPU overhead (I/O, time calculations) +- Fix: profiling/profiling_utils.py — added pipe.set_progress_bar_config(disable=True) during profiling setup. +This eliminates the tqdm overhead from the trace. (The remaining gap from scheduler step + Python dispatch is +inherent to eager-mode execution and should shrink significantly under torch.compile.) + +Issue 2 — cudaStreamSynchronize during last transformer forward: +- Root cause: _unpack_latents_with_ids() (called right after the denoising loop) computes h = torch.max(h_ids) + +1 and w = torch.max(w_ids) + 1 on GPU tensors, then uses them as shape args for torch.zeros((h * w, ch), ...). +This triggers an implicit .item() DtoH sync, blocking the CPU while the GPU is still finishing the last +transformer forward's kernels. +- Fix: Added height/width parameters to _unpack_latents_with_ids(), pre-computed from the known pixel dimensions +at the call site. +``` + +It still didn't eliminate the gaps as expected so, we fed that back to Claude and it spotted +something more crucial. TODO: caching context fix. + +With the fix applied, the improvements were visible: + +TODO: show before and after trace + +Before: + +- `_set_context` total: 21.6ms (8 calls) +- cache_context total: 21.7ms +- CPU gaps: 5,523us / 8,007us / 5,508us + +After: +- `_set_context` total: 0.0ms (8 calls) +- cache_context total: 0.1ms +- CPU gaps: 158us / 2,777us / 136us \ No newline at end of file From 35437a897e5dca2241bc3d5766f7e9f215665af8 Mon Sep 17 00:00:00 2001 From: sayakpaul Date: Fri, 27 Mar 2026 17:01:37 +0530 Subject: [PATCH 12/27] wan fixes. --- examples/profiling/profiling_pipelines.py | 3 +-- examples/profiling/profiling_utils.py | 4 +++- examples/profiling/run_profiling.sh | 2 +- src/diffusers/pipelines/wan/pipeline_wan.py | 4 ++++ .../schedulers/scheduling_unipc_multistep.py | 16 ++++++++-------- 5 files changed, 17 insertions(+), 12 deletions(-) diff --git a/examples/profiling/profiling_pipelines.py b/examples/profiling/profiling_pipelines.py index d19540014e11..9916f1025e15 100644 --- a/examples/profiling/profiling_pipelines.py +++ b/examples/profiling/profiling_pipelines.py @@ -15,7 +15,6 @@ import logging import torch - from profiling_utils import PipelineProfiler, PipelineProfilingConfig @@ -27,7 +26,7 @@ def build_registry(): """Build the pipeline config registry. Imports are deferred to avoid loading all pipelines upfront.""" - from diffusers import FluxPipeline, Flux2KleinPipeline, WanPipeline, LTX2Pipeline, QwenImagePipeline + from diffusers import Flux2KleinPipeline, FluxPipeline, LTX2Pipeline, QwenImagePipeline, WanPipeline return { "flux": PipelineProfilingConfig( diff --git a/examples/profiling/profiling_utils.py b/examples/profiling/profiling_utils.py index cd30f912f938..f811a4c59b23 100644 --- a/examples/profiling/profiling_utils.py +++ b/examples/profiling/profiling_utils.py @@ -83,7 +83,9 @@ def setup_pipeline(self): if self.config.compile_kwargs: if self.config.compile_regional: - logger.info(f"Regional compilation (compile_repeated_blocks) with kwargs: {self.config.compile_kwargs}") + logger.info( + f"Regional compilation (compile_repeated_blocks) with kwargs: {self.config.compile_kwargs}" + ) pipe.transformer.compile_repeated_blocks(**self.config.compile_kwargs) else: logger.info(f"Full compilation with kwargs: {self.config.compile_kwargs}") diff --git a/examples/profiling/run_profiling.sh b/examples/profiling/run_profiling.sh index a6be682e929d..2d62ddd95046 100755 --- a/examples/profiling/run_profiling.sh +++ b/examples/profiling/run_profiling.sh @@ -16,7 +16,7 @@ while [[ $# -gt 0 ]]; do done NUM_STEPS=2 # PIPELINES=("flux" "flux2" "wan" "ltx2" "qwenimage") -PIPELINES=("flux2") +PIPELINES=("wan") MODES=("eager" "compile") for pipeline in "${PIPELINES[@]}"; do diff --git a/src/diffusers/pipelines/wan/pipeline_wan.py b/src/diffusers/pipelines/wan/pipeline_wan.py index d4edff01ad66..6cbe6d85de78 100644 --- a/src/diffusers/pipelines/wan/pipeline_wan.py +++ b/src/diffusers/pipelines/wan/pipeline_wan.py @@ -574,6 +574,10 @@ def __call__( num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order self._num_timesteps = len(timesteps) + # We set the index here to remove DtoH sync, helpful especially during compilation. + # Check out more details here: https://github.com/huggingface/diffusers/pull/11696 + self.scheduler.set_begin_index(0) + if self.config.boundary_ratio is not None: boundary_timestep = self.config.boundary_ratio * self.scheduler.config.num_train_timesteps else: diff --git a/src/diffusers/schedulers/scheduling_unipc_multistep.py b/src/diffusers/schedulers/scheduling_unipc_multistep.py index 71a5444491ed..21f81bc381b1 100644 --- a/src/diffusers/schedulers/scheduling_unipc_multistep.py +++ b/src/diffusers/schedulers/scheduling_unipc_multistep.py @@ -903,8 +903,8 @@ def multistep_uni_p_bh_update( rks.append(rk) D1s.append((mi - m0) / rk) - rks.append(1.0) - rks = torch.tensor(rks, device=device) + rks.append(torch.ones((), device=device)) + rks = torch.stack(rks) R = [] b = [] @@ -929,13 +929,13 @@ def multistep_uni_p_bh_update( h_phi_k = h_phi_k / hh - 1 / factorial_i R = torch.stack(R) - b = torch.tensor(b, device=device) + b = torch.stack(b) if len(b) > 0 else torch.tensor(b, device=device) if len(D1s) > 0: D1s = torch.stack(D1s, dim=1) # (B, K) # for order 2, we use a simplified version if order == 2: - rhos_p = torch.tensor([0.5], dtype=x.dtype, device=device) + rhos_p = torch.ones(1, dtype=x.dtype, device=device) * 0.5 else: rhos_p = torch.linalg.solve(R[:-1, :-1], b[:-1]).to(device).to(x.dtype) else: @@ -1038,8 +1038,8 @@ def multistep_uni_c_bh_update( rks.append(rk) D1s.append((mi - m0) / rk) - rks.append(1.0) - rks = torch.tensor(rks, device=device) + rks.append(torch.ones((), device=device)) + rks = torch.stack(rks) R = [] b = [] @@ -1064,7 +1064,7 @@ def multistep_uni_c_bh_update( h_phi_k = h_phi_k / hh - 1 / factorial_i R = torch.stack(R) - b = torch.tensor(b, device=device) + b = torch.stack(b) if len(b) > 0 else torch.tensor(b, device=device) if len(D1s) > 0: D1s = torch.stack(D1s, dim=1) @@ -1073,7 +1073,7 @@ def multistep_uni_c_bh_update( # for order 1, we use a simplified version if order == 1: - rhos_c = torch.tensor([0.5], dtype=x.dtype, device=device) + rhos_c = torch.ones(1, dtype=x.dtype, device=device) * 0.5 else: rhos_c = torch.linalg.solve(R, b).to(device).to(x.dtype) From 142f417b668b1bad3d0d4a512af73707e730d92c Mon Sep 17 00:00:00 2001 From: sayakpaul Date: Fri, 27 Mar 2026 17:18:03 +0530 Subject: [PATCH 13/27] more --- examples/profiling/README.md | 9 ++++++++- 1 file changed, 8 insertions(+), 1 deletion(-) diff --git a/examples/profiling/README.md b/examples/profiling/README.md index 88070aa6bddf..c84a7c350cf4 100644 --- a/examples/profiling/README.md +++ b/examples/profiling/README.md @@ -247,4 +247,11 @@ Before: After: - `_set_context` total: 0.0ms (8 calls) - cache_context total: 0.1ms -- CPU gaps: 158us / 2,777us / 136us \ No newline at end of file +- CPU gaps: 158us / 2,777us / 136us + +### Notes + +* As mentioned above, we profiled with regional compilation so it's possible that +there are still some gaps outside the compiled regions. A full compilation +will likely mitigate it. In case it doesn't, the above observations could +be useful to mitigate that. \ No newline at end of file From 9ba98a26423b28c48dac743dd3ef55cb8944b004 Mon Sep 17 00:00:00 2001 From: sayakpaul Date: Fri, 27 Mar 2026 17:27:11 +0530 Subject: [PATCH 14/27] up --- examples/profiling/README.md | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/examples/profiling/README.md b/examples/profiling/README.md index c84a7c350cf4..29e211a8a490 100644 --- a/examples/profiling/README.md +++ b/examples/profiling/README.md @@ -249,6 +249,14 @@ After: - cache_context total: 0.1ms - CPU gaps: 158us / 2,777us / 136us +We also profiled the Wan model and uncovered problems related to CPU DtoH syncs. Below is an +overview. + +TODO: provide trace outputs and numbers + +> [!NOTE] +> The above-mentioned fixes are available in [this PR](TODO:link). + ### Notes * As mentioned above, we profiled with regional compilation so it's possible that From 12ba8be7205185c4d96ac3d648110fdc850f1b96 Mon Sep 17 00:00:00 2001 From: sayakpaul Date: Sat, 28 Mar 2026 09:54:32 +0530 Subject: [PATCH 15/27] add more traces. --- examples/profiling/README.md | 104 ++++++++++++++++++++++++++--------- 1 file changed, 77 insertions(+), 27 deletions(-) diff --git a/examples/profiling/README.md b/examples/profiling/README.md index 29e211a8a490..74d8887cb09d 100644 --- a/examples/profiling/README.md +++ b/examples/profiling/README.md @@ -43,10 +43,10 @@ Follow the flux-fast pattern: **annotate key pipeline methods** with `torch.prof ### New Files -``` -profiling/ - profiling_utils.py # Annotation helper + profiler setup - profiling_pipelines.py # CLI entry point with pipeline configs +```bash +profiling_utils.py # Annotation helper + profiler setup +profiling_pipelines.py # CLI entry point with pipeline configs +run_profiling.sh # Bulk launch runs for multiple pipelines ``` ### Step 1: `profiling_utils.py` — Annotation and Profiler Infrastructure @@ -198,22 +198,44 @@ To inspect this: zoom into a single denoising step, select a CUDA kernel on the ## Afterwards To keep the profiling iterations fast, we always used [regional compilation](https://pytorch.org/tutorials/recipes/regional_compilation.html). As one would expect the trace with compilation should show -fewer kernel launches than its eager counterpart: - -TODO: show traces - -_(The traces above were obtained with Flux2.)_ +fewer kernel launches than its eager counterpart. + +_(Unless otherwise specified, the traces below were obtained with **Flux2**.)_ + + + + + + +
+ Image 1
+ Without compile +
+ Image 2
+ With compile +
### Spotting gaps between launches Then a reasonable next step is to spot frequent gaps between kernel executions. In the compiled case, we don't spot any on the surface. But if we zone in, some become apparent. -TODO: show gaps in a compile trace - -So, we provided the profile trace (with compilation) to Claude, asked it to find the instances of + + + + + +
+ Image 1
+ Very small visible gaps in between compiled regions +
+ Image 2
+ Gaps become more visible when zoomed in +
+ +So, we provided the profile trace file (with compilation) to Claude, asked it to find the instances of "cudaStreamSynchronize" and "cudaDeviceSynchronize", and to come up with some potential fixes. -Claude came back pretty strong: +Claude came back with the following: ``` Issue 1 — Gap between transformer forwards: @@ -231,31 +253,59 @@ transformer forward's kernels. at the call site. ``` -It still didn't eliminate the gaps as expected so, we fed that back to Claude and it spotted -something more crucial. TODO: caching context fix. - -With the fix applied, the improvements were visible: +The changes looke reasonable based on our past experience. So, we asked Claude to apply these changes to [`pipeline_flux2_klein.py`](../../src/diffusers/pipelines/flux2/pipeline_flux2_klein.py). We then profiled +the updated pipeline. It still didn't eliminate the gaps as expected so, we fed that back to Claude and +it spotted something more crucial. -TODO: show before and after trace +Under the [`cache_context`](https://github.com/huggingface/diffusers/blob/f2be8bd6b3dc4035bd989dc467f15d86bf3c9c12/src/diffusers/pipelines/flux2/pipeline_flux2_klein.py#L842) manager, there is a call to `_set_context()` upon +enters and exists. It calls `named_modules()` on the entire underlying model (in this case the Flux2 Klein DiT). +For large models, when they are invoked iteratively like our case, it adds to the latency because it involes traversing hundreds of submodules. -Before: +The fix was to build a list of hooked child registries once on the first call and cache it in `_child_registries_cache`. This way, the subsequent calls would return the cached list directly without +any traversal. With the fix applied, the improvements were visible. -- `_set_context` total: 21.6ms (8 calls) +Before: +- _set_context total: 21.6ms (8 calls) - cache_context total: 21.7ms - CPU gaps: 5,523us / 8,007us / 5,508us After: -- `_set_context` total: 0.0ms (8 calls) +- _set_context total: 0.0ms (8 calls) - cache_context total: 0.1ms -- CPU gaps: 158us / 2,777us / 136us +- CPU gaps: 158us / 2,777us / 136us -We also profiled the Wan model and uncovered problems related to CPU DtoH syncs. Below is an -overview. +> [!NOTE] +> The fixes mentioned above and below are available in [this PR](TODO:link). -TODO: provide trace outputs and numbers +### DtoH syncs -> [!NOTE] -> The above-mentioned fixes are available in [this PR](TODO:link). +We also profiled the **Wan** model and uncovered problems related to CPU DtoH syncs. Below is an +overview. + +First, there was a dynamo cache lookup delay making the GPU idle as reported [in this PR](https://github.com/huggingface/diffusers/pull/11696). So, the fix was to call `self.scheduler.set_begin_index(0)` before +the denoising loop. This tells the scheduler the starting index is 0, so `_init_step_index()` skips the `nonzero().item()` (which was causing the sync) path entirely. This fix eliminated the below ~2.3s GPU idle time completely: + +![GPU idle](https://huggingface.co/datasets/sayakpaul/torch-profiling-trace-diffusers/resolve/main/Wan/Screenshot%202026-03-27%20at%205.56.39%E2%80%AFPM.png) + +The UniPC scheduler (used in Wan) creates small constant tensors via `torch.tensor([0.5], dtype=x.dtype, device=device)` during `step()`. This triggers a "cudaMemcpyAsync + cudaStreamSynchronize" to copy +the value from CPU to GPU. The sync itself is normally fast (~6us), but it forces the CPU to wait +until all pending GPU kernels finish before proceeding. Under torch.compile, the GPU has many queued +kernels, so this tiny sync balloons to 2.3s. + +**Fix**: Replace with `torch.ones(1, dtype=x.dtype, device=device) * 0.5`. `torch.ones` allocates on GPU via "cudaMemsetAsync" (no sync), and `* 0.5` is a CUDA kernel launch (no sync). Same result, zero CPU-GPU synchronization. The duration of the scheduling step before and after this fix confirms this: + + + + + + +
+ Image 1
+ CPU<->GPU sync +
+ Image 2
+ Almost no sync +
### Notes From 43e16fba40a4bcf0b00f79535aca6785755410b5 Mon Sep 17 00:00:00 2001 From: sayakpaul Date: Sat, 28 Mar 2026 10:08:35 +0530 Subject: [PATCH 16/27] up --- src/diffusers/pipelines/flux2/pipeline_flux2_klein.py | 1 - 1 file changed, 1 deletion(-) diff --git a/src/diffusers/pipelines/flux2/pipeline_flux2_klein.py b/src/diffusers/pipelines/flux2/pipeline_flux2_klein.py index b1cf3f4c9cb4..1f3b5c3c4fde 100644 --- a/src/diffusers/pipelines/flux2/pipeline_flux2_klein.py +++ b/src/diffusers/pipelines/flux2/pipeline_flux2_klein.py @@ -396,7 +396,6 @@ def _pack_latents(latents): return latents @staticmethod - # Copied from diffusers.pipelines.flux2.pipeline_flux2.Flux2Pipeline._unpack_latents_with_ids def _unpack_latents_with_ids( x: torch.Tensor, x_ids: torch.Tensor, height: int | None = None, width: int | None = None ) -> list[torch.Tensor]: From e26d5c6ee3ee0732ac00d899c3aebea0eca6b28a Mon Sep 17 00:00:00 2001 From: sayakpaul Date: Sat, 28 Mar 2026 15:18:30 +0530 Subject: [PATCH 17/27] better title --- examples/profiling/README.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/examples/profiling/README.md b/examples/profiling/README.md index 74d8887cb09d..29944d3581a6 100644 --- a/examples/profiling/README.md +++ b/examples/profiling/README.md @@ -1,4 +1,4 @@ -# Profiling Plan: Diffusers Pipeline Profiling with torch.profiler +# Profiling a `DiffusionPipeline` with the PyTorch Profiler Education materials to strategically profile pipelines to potentially improve their runtime with `torch.compile`. To set these pipelines up for success with `torch.compile`, From 1131acd6e1ee94c120b791cb47bdcb5d7975f0d0 Mon Sep 17 00:00:00 2001 From: sayakpaul Date: Sun, 29 Mar 2026 10:03:29 +0530 Subject: [PATCH 18/27] cuda graphs. --- examples/profiling/README.md | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/examples/profiling/README.md b/examples/profiling/README.md index 29944d3581a6..55be0d575ea8 100644 --- a/examples/profiling/README.md +++ b/examples/profiling/README.md @@ -312,4 +312,7 @@ kernels, so this tiny sync balloons to 2.3s. * As mentioned above, we profiled with regional compilation so it's possible that there are still some gaps outside the compiled regions. A full compilation will likely mitigate it. In case it doesn't, the above observations could -be useful to mitigate that. \ No newline at end of file +be useful to mitigate that. +* Use of CUDA Graphs can also help mitigate CPU overhead related issues. When +using "reduce-overhead" and "max-autotune" in `torch.compile` triggers the +use of CUDA Graphs. \ No newline at end of file From c642cd0e4f9623b750658803dd101008426153f0 Mon Sep 17 00:00:00 2001 From: sayakpaul Date: Mon, 30 Mar 2026 09:02:49 +0530 Subject: [PATCH 19/27] up --- examples/profiling/README.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/examples/profiling/README.md b/examples/profiling/README.md index 55be0d575ea8..86a99cb8cc19 100644 --- a/examples/profiling/README.md +++ b/examples/profiling/README.md @@ -275,7 +275,7 @@ After: - CPU gaps: 158us / 2,777us / 136us > [!NOTE] -> The fixes mentioned above and below are available in [this PR](TODO:link). +> The fixes mentioned above and below are available in [this PR](https://github.com/huggingface/diffusers/pull/13356). ### DtoH syncs From ed8241a394e78c056d9ad8a9f8b0953a63756f29 Mon Sep 17 00:00:00 2001 From: Sayak Paul Date: Tue, 31 Mar 2026 09:46:45 +0530 Subject: [PATCH 20/27] Apply suggestions from code review Co-authored-by: Steven Liu <59462357+stevhliu@users.noreply.github.com> --- examples/profiling/README.md | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/examples/profiling/README.md b/examples/profiling/README.md index 86a99cb8cc19..3c84e87221a2 100644 --- a/examples/profiling/README.md +++ b/examples/profiling/README.md @@ -34,8 +34,8 @@ We want to uncover CPU overhead, CPU-GPU sync points, and other bottlenecks in p > [!NOTE] > We use realistic inference call hyperparameters that mimic how these pipelines will be actually used. This -> include using classifier-free guidance (where applicable), reasonable dimensions such 1024x1024, etc. -> But we keep the overall running time to a bare minimum (hence 2 `num_inference_steps`). +> includes using classifier-free guidance (where applicable), reasonable dimensions such 1024x1024, etc. +> But we keep the number of inference steps to a bare minimum. ## Approach @@ -197,7 +197,7 @@ To inspect this: zoom into a single denoising step, select a CUDA kernel on the ## Afterwards -To keep the profiling iterations fast, we always used [regional compilation](https://pytorch.org/tutorials/recipes/regional_compilation.html). As one would expect the trace with compilation should show +To keep the profiling iterations fast, we always use [regional compilation](https://pytorch.org/tutorials/recipes/regional_compilation.html). As one would expect the trace with compilation should show fewer kernel launches than its eager counterpart. _(Unless otherwise specified, the traces below were obtained with **Flux2**.)_ @@ -253,12 +253,12 @@ transformer forward's kernels. at the call site. ``` -The changes looke reasonable based on our past experience. So, we asked Claude to apply these changes to [`pipeline_flux2_klein.py`](../../src/diffusers/pipelines/flux2/pipeline_flux2_klein.py). We then profiled +The changes looked reasonable based on our past experience. So, we asked Claude to apply these changes to [`pipeline_flux2_klein.py`](../../src/diffusers/pipelines/flux2/pipeline_flux2_klein.py). We then profiled the updated pipeline. It still didn't eliminate the gaps as expected so, we fed that back to Claude and it spotted something more crucial. Under the [`cache_context`](https://github.com/huggingface/diffusers/blob/f2be8bd6b3dc4035bd989dc467f15d86bf3c9c12/src/diffusers/pipelines/flux2/pipeline_flux2_klein.py#L842) manager, there is a call to `_set_context()` upon -enters and exists. It calls `named_modules()` on the entire underlying model (in this case the Flux2 Klein DiT). +enters and exits. It calls `named_modules()` on the entire underlying model (in this case the Flux2 Klein DiT). For large models, when they are invoked iteratively like our case, it adds to the latency because it involes traversing hundreds of submodules. The fix was to build a list of hooked child registries once on the first call and cache it in `_child_registries_cache`. This way, the subsequent calls would return the cached list directly without From 3ae7d9b4d79776c11e6643a1edc11c692faec2cd Mon Sep 17 00:00:00 2001 From: sayakpaul Date: Tue, 31 Mar 2026 09:50:36 +0530 Subject: [PATCH 21/27] add torch.compile link. --- examples/profiling/README.md | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/examples/profiling/README.md b/examples/profiling/README.md index 3c84e87221a2..1ab2ed91e6b9 100644 --- a/examples/profiling/README.md +++ b/examples/profiling/README.md @@ -315,4 +315,5 @@ will likely mitigate it. In case it doesn't, the above observations could be useful to mitigate that. * Use of CUDA Graphs can also help mitigate CPU overhead related issues. When using "reduce-overhead" and "max-autotune" in `torch.compile` triggers the -use of CUDA Graphs. \ No newline at end of file +use of CUDA Graphs. +* Diffusers' integration of `torch.compile` is documented [here](https://huggingface.co/docs/diffusers/main/en/optimization/fp16#torchcompile). \ No newline at end of file From bfb19afd1e4bd85f45730f04e7207bfd8b7fe451 Mon Sep 17 00:00:00 2001 From: sayakpaul Date: Tue, 31 Mar 2026 09:51:49 +0530 Subject: [PATCH 22/27] approach -> How the tooling works --- examples/profiling/README.md | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/examples/profiling/README.md b/examples/profiling/README.md index 1ab2ed91e6b9..f0f968faf7f1 100644 --- a/examples/profiling/README.md +++ b/examples/profiling/README.md @@ -11,7 +11,7 @@ Thanks to Claude Code for paircoding! We acknowledge the [Claude of OSS](https:/ * [Context](#context) * [Target pipelines](#target-pipelines) -* [Approach taken](#approach) +* [How the tooling works](#how-the-tooling-works) * [Verification](#verification) * [Interpretation](#interpreting-traces-in-perfetto-ui) * [Taking profiling-guided steps for improvements](#afterwards) @@ -37,7 +37,7 @@ We want to uncover CPU overhead, CPU-GPU sync points, and other bottlenecks in p > includes using classifier-free guidance (where applicable), reasonable dimensions such 1024x1024, etc. > But we keep the number of inference steps to a bare minimum. -## Approach +## How the Tooling Works Follow the flux-fast pattern: **annotate key pipeline methods** with `torch.profiler.record_function` wrappers, then run the pipeline under `torch.profiler.profile` and export a Chrome trace. From 40a525e784aa4de98beadaa11334a447bf23795f Mon Sep 17 00:00:00 2001 From: sayakpaul Date: Tue, 31 Mar 2026 09:54:15 +0530 Subject: [PATCH 23/27] table --- examples/profiling/README.md | 14 +++++--------- 1 file changed, 5 insertions(+), 9 deletions(-) diff --git a/examples/profiling/README.md b/examples/profiling/README.md index f0f968faf7f1..3983b516869b 100644 --- a/examples/profiling/README.md +++ b/examples/profiling/README.md @@ -264,15 +264,11 @@ For large models, when they are invoked iteratively like our case, it adds to th The fix was to build a list of hooked child registries once on the first call and cache it in `_child_registries_cache`. This way, the subsequent calls would return the cached list directly without any traversal. With the fix applied, the improvements were visible. -Before: -- _set_context total: 21.6ms (8 calls) -- cache_context total: 21.7ms -- CPU gaps: 5,523us / 8,007us / 5,508us - -After: -- _set_context total: 0.0ms (8 calls) -- cache_context total: 0.1ms -- CPU gaps: 158us / 2,777us / 136us +| | Before | After | +|------------------------|------------------------------|-----------------------------| +| `_set_context` total | 21.6ms (8 calls) | 0.0ms (8 calls) | +| `cache_context` total | 21.7ms | 0.1ms | +| CPU gaps | 5,523us / 8,007us / 5,508us | 158us / 2,777us / 136us | > [!NOTE] > The fixes mentioned above and below are available in [this PR](https://github.com/huggingface/diffusers/pull/13356). From 6cf142902abe7b5f8cde5c8df8349252de106bd0 Mon Sep 17 00:00:00 2001 From: sayakpaul Date: Tue, 31 Mar 2026 19:07:58 +0530 Subject: [PATCH 24/27] unavoidable gaps. --- examples/profiling/README.md | 2 ++ 1 file changed, 2 insertions(+) diff --git a/examples/profiling/README.md b/examples/profiling/README.md index 3983b516869b..10a0dd196352 100644 --- a/examples/profiling/README.md +++ b/examples/profiling/README.md @@ -152,6 +152,8 @@ Zoom into the CUDA row during the denoising loop. Ideally, GPU kernels should be - Python overhead between ops (visible as CPU slices in the CPU row during the gap) - DtoH sync (`.item()`, `.cpu()`) forcing the GPU to drain before the CPU can proceed +No bubbles/gaps is ideal, but for small shapes (small model, small batch size, or both) some bubbles could be unavoidable. + **2. CPU stalls (DtoH syncs)** These appear on the **CPU row** (not the CUDA row) — they are CPU-side blocking calls that wait for the GPU to finish. Look for long slices labeled `cudaStreamSynchronize` or `cudaDeviceSynchronize`. To find them: zoom into the CPU row during a denoising step and look for unusually wide slices, or use Perfetto's search bar (press `/`) and type `cudaStreamSynchronize` to jump directly to matching events. Click on a slice — if `with_stack=True` was enabled, the bottom panel ("Current Selection") shows the Python stack trace pointing to the exact line causing the sync (e.g., a `.item()` call in the scheduler). From fb6afa6da6a885a0daca94be3b254f267d6b0045 Mon Sep 17 00:00:00 2001 From: sayakpaul Date: Tue, 31 Mar 2026 19:08:22 +0530 Subject: [PATCH 25/27] make important --- examples/profiling/README.md | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/examples/profiling/README.md b/examples/profiling/README.md index 10a0dd196352..5d35a181197e 100644 --- a/examples/profiling/README.md +++ b/examples/profiling/README.md @@ -152,7 +152,8 @@ Zoom into the CUDA row during the denoising loop. Ideally, GPU kernels should be - Python overhead between ops (visible as CPU slices in the CPU row during the gap) - DtoH sync (`.item()`, `.cpu()`) forcing the GPU to drain before the CPU can proceed -No bubbles/gaps is ideal, but for small shapes (small model, small batch size, or both) some bubbles could be unavoidable. +> [!IMPORTANT] +> No bubbles/gaps is ideal, but for small shapes (small model, small batch size, or both) some bubbles could be unavoidable. **2. CPU stalls (DtoH syncs)** From 40c330a90d362d1e4385243e9f26cff691077620 Mon Sep 17 00:00:00 2001 From: sayakpaul Date: Tue, 31 Mar 2026 19:17:35 +0530 Subject: [PATCH 26/27] note on regional compilation --- examples/profiling/README.md | 1 + 1 file changed, 1 insertion(+) diff --git a/examples/profiling/README.md b/examples/profiling/README.md index 5d35a181197e..d6333f7ec7c5 100644 --- a/examples/profiling/README.md +++ b/examples/profiling/README.md @@ -112,6 +112,7 @@ All configs use `output_type="latent"` by default (skip VAE decode for cleaner d - `--num_steps N` (override, default 4) - `--full_decode` (switch output_type from `"latent"` to `"pil"` to include VAE) - `--compile_mode default|reduce-overhead|max-autotune` +- `--compile_regional` flag (uses [regional compilation](https://pytorch.org/tutorials/recipes/regional_compilation.html) to compile only the transformer forward pass instead of the full pipeline — faster compile times, ideal for iterative profiling) - `--compile_fullgraph` flag **Output:** `{output_dir}/{pipeline}_{mode}.json` Chrome trace + stdout summary. From 131831ff20e6815075cdae923f4d5be85c9dd175 Mon Sep 17 00:00:00 2001 From: Sayak Paul Date: Tue, 31 Mar 2026 20:56:20 +0530 Subject: [PATCH 27/27] Apply suggestions from code review Co-authored-by: Sayak Paul --- examples/profiling/README.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/examples/profiling/README.md b/examples/profiling/README.md index d6333f7ec7c5..d87816a5a031 100644 --- a/examples/profiling/README.md +++ b/examples/profiling/README.md @@ -199,7 +199,7 @@ To inspect this: zoom into a single denoising step, select a CUDA kernel on the | Kernel launch latency? | CPU launch → GPU kernel offset | < 10us, CPU ahead of GPU | > 20us or CPU trailing GPU | | Memory stable? | Memory counter track | Flat during denoise loop | Growing per step | -## Afterwards +## What Profiling Revealed and Fixes To keep the profiling iterations fast, we always use [regional compilation](https://pytorch.org/tutorials/recipes/regional_compilation.html). As one would expect the trace with compilation should show fewer kernel launches than its eager counterpart.