Skip to content
318 changes: 318 additions & 0 deletions examples/profiling/README.md

Large diffs are not rendered by default.

181 changes: 181 additions & 0 deletions examples/profiling/profiling_pipelines.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,181 @@
"""
Profile diffusers pipelines with torch.profiler.

Usage:
python profiling/profiling_pipelines.py --pipeline flux --mode eager
python profiling/profiling_pipelines.py --pipeline flux --mode compile
python profiling/profiling_pipelines.py --pipeline flux --mode both
python profiling/profiling_pipelines.py --pipeline all --mode eager
python profiling/profiling_pipelines.py --pipeline wan --mode eager --full_decode
python profiling/profiling_pipelines.py --pipeline flux --mode compile --num_steps 4
"""

import argparse
import copy
import logging

import torch
from profiling_utils import PipelineProfiler, PipelineProfilingConfig


logging.basicConfig(level=logging.INFO, format="%(asctime)s %(levelname)s %(name)s: %(message)s")
logger = logging.getLogger(__name__)

PROMPT = "A cat holding a sign that says hello world"


def build_registry():
    """Build the pipeline config registry.

    Imports are deferred so that importing this script does not pull in every
    pipeline class. Every entry loads weights in bf16, uses the shared PROMPT,
    and skips VAE decode by default (``output_type="latent"``); the CLI applies
    step-count / decode overrides afterwards.
    """
    from diffusers import Flux2KleinPipeline, FluxPipeline, LTX2Pipeline, QwenImagePipeline, WanPipeline

    def _make_config(name, pipeline_cls, repo_id, call_overrides):
        # Shared boilerplate for every pipeline; per-pipeline call kwargs are
        # merged between the common prompt and the latent-only output type.
        return PipelineProfilingConfig(
            name=name,
            pipeline_cls=pipeline_cls,
            pipeline_init_kwargs={
                "pretrained_model_name_or_path": repo_id,
                "torch_dtype": torch.bfloat16,
            },
            pipeline_call_kwargs={"prompt": PROMPT, **call_overrides, "output_type": "latent"},
        )

    specs = [
        (
            "flux",
            FluxPipeline,
            "black-forest-labs/FLUX.1-dev",
            {"height": 1024, "width": 1024, "num_inference_steps": 4, "guidance_scale": 3.5},
        ),
        (
            "flux2",
            Flux2KleinPipeline,
            "black-forest-labs/FLUX.2-klein-base-9B",
            {"height": 1024, "width": 1024, "num_inference_steps": 4, "guidance_scale": 3.5},
        ),
        (
            "wan",
            WanPipeline,
            "Wan-AI/Wan2.1-T2V-14B-Diffusers",
            {
                "negative_prompt": "Bright tones, overexposed, static, blurred details, subtitles, style, works, paintings, images, static, overall gray, worst quality, low quality, JPEG compression residue, ugly, incomplete, extra fingers, poorly drawn hands, poorly drawn faces, deformed, disfigured, misshapen limbs, fused fingers, still picture, messy background, three legs, many people in the background, walking backwards",
                "height": 480,
                "width": 832,
                "num_frames": 81,
                "num_inference_steps": 4,
            },
        ),
        (
            "ltx2",
            LTX2Pipeline,
            "Lightricks/LTX-2",
            {
                "negative_prompt": "worst quality, inconsistent motion, blurry, jittery, distorted",
                "height": 512,
                "width": 768,
                "num_frames": 121,
                "num_inference_steps": 4,
                "guidance_scale": 4.0,
            },
        ),
        (
            "qwenimage",
            QwenImagePipeline,
            "Qwen/Qwen-Image",
            {
                "negative_prompt": " ",
                "height": 1024,
                "width": 1024,
                "num_inference_steps": 4,
                "true_cfg_scale": 4.0,
            },
        ),
    ]
    return {name: _make_config(name, cls, repo, overrides) for name, cls, repo, overrides in specs}


def main():
    """CLI entry point: parse args, then profile each requested pipeline/mode combination."""
    parser = argparse.ArgumentParser(description="Profile diffusers pipelines with torch.profiler")
    parser.add_argument(
        "--pipeline",
        choices=["flux", "flux2", "wan", "ltx2", "qwenimage", "all"],
        required=True,
        help="Which pipeline to profile",
    )
    parser.add_argument(
        "--mode",
        choices=["eager", "compile", "both"],
        default="eager",
        help="Run in eager mode, compile mode, or both",
    )
    parser.add_argument("--output_dir", default="profiling_results", help="Directory for trace output")
    parser.add_argument("--num_steps", type=int, default=None, help="Override num_inference_steps")
    parser.add_argument("--full_decode", action="store_true", help="Profile including VAE decode (output_type='pil')")
    parser.add_argument(
        "--compile_mode",
        default="default",
        choices=["default", "reduce-overhead", "max-autotune"],
        help="torch.compile mode",
    )
    parser.add_argument("--compile_fullgraph", action="store_true", help="Use fullgraph=True for torch.compile")
    parser.add_argument(
        "--compile_regional",
        action="store_true",
        help="Use compile_repeated_blocks() instead of full model compile",
    )
    args = parser.parse_args()

    registry = build_registry()

    pipeline_names = list(registry.keys()) if args.pipeline == "all" else [args.pipeline]
    modes = ["eager", "compile"] if args.mode == "both" else [args.mode]

    for pipeline_name in pipeline_names:
        for mode in modes:
            # Deep-copy so per-run overrides never leak into the shared registry entry.
            config = copy.deepcopy(registry[pipeline_name])

            # Apply CLI overrides on top of the registry defaults.
            if args.num_steps is not None:
                config.pipeline_call_kwargs["num_inference_steps"] = args.num_steps
            if args.full_decode:
                config.pipeline_call_kwargs["output_type"] = "pil"
            if mode == "compile":
                config.compile_kwargs = {
                    "fullgraph": args.compile_fullgraph,
                    "mode": args.compile_mode,
                }
                config.compile_regional = args.compile_regional

            logger.info(f"Profiling {pipeline_name} in {mode} mode...")
            profiler = PipelineProfiler(config, args.output_dir)
            try:
                trace_file = profiler.run()
                logger.info(f"Done: {trace_file}")
            except Exception:
                # Keep going so one failing pipeline does not abort an
                # "--pipeline all" sweep. logger.exception records the full
                # traceback, not just str(e), which the old logger.error lost.
                logger.exception(f"Failed to profile {pipeline_name} ({mode})")


if __name__ == "__main__":
    main()
148 changes: 148 additions & 0 deletions examples/profiling/profiling_utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,148 @@
import functools
import gc
import logging
import os
from dataclasses import dataclass, field
from typing import Any

import torch
import torch.profiler


logging.basicConfig(level=logging.INFO, format="%(asctime)s %(levelname)s %(name)s: %(message)s")
logger = logging.getLogger(__name__)


def annotate(func, name):
    """Return *func* wrapped so each invocation appears as a span called *name* in profiler traces."""

    @functools.wraps(func)
    def traced(*args, **kwargs):
        # record_function is a no-op context outside an active profiler session.
        span = torch.profiler.record_function(name)
        with span:
            return func(*args, **kwargs)

    return traced


def annotate_pipeline(pipe):
    """Attach named profiler spans to the pipeline's hot-path methods.

    Rebinds methods on the live component objects, so traces get readable
    span names without modifying any library source.
    """
    component_targets = (
        ("transformer", "forward", "transformer_forward"),
        ("vae", "decode", "vae_decode"),
        ("vae", "encode", "vae_encode"),
        ("scheduler", "step", "scheduler_step"),
    )

    # Wrap each sub-component method that actually exists on this pipeline.
    for owner_name, attr_name, span_label in component_targets:
        owner = getattr(pipe, owner_name, None)
        if owner is None:
            continue
        target = getattr(owner, attr_name, None)
        if target is None:
            continue
        setattr(owner, attr_name, annotate(target, span_label))

    # Pipeline-level method (not every pipeline class defines it).
    if hasattr(pipe, "encode_prompt"):
        pipe.encode_prompt = annotate(pipe.encode_prompt, "encode_prompt")


def flush():
    """Run GC and reset CUDA memory stats so each run starts from a clean slate.

    Safe on CPU-only hosts: the CUDA calls are skipped when no device is
    available (previously they would raise).
    """
    gc.collect()
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
        # reset_max_memory_allocated() is a deprecated alias that just calls
        # reset_peak_memory_stats(), so a single call covers both.
        torch.cuda.reset_peak_memory_stats()


@dataclass
class PipelineProfilingConfig:
name: str
pipeline_cls: Any
pipeline_init_kwargs: dict[str, Any]
pipeline_call_kwargs: dict[str, Any]
compile_kwargs: dict[str, Any] | None = field(default=None)
compile_regional: bool = False


class PipelineProfiler:
    """Profiles one pipeline configuration with torch.profiler.

    Loads the pipeline, runs one warmup call, then profiles a second call,
    exports a Chrome trace, and prints a kernel-time summary table.
    """

    def __init__(self, config: PipelineProfilingConfig, output_dir: str = "profiling_results"):
        # Trace files are written into output_dir (created here if missing).
        self.config = config
        self.output_dir = output_dir
        os.makedirs(output_dir, exist_ok=True)

    def setup_pipeline(self):
        """Load the pipeline from pretrained, optionally compile, and annotate."""
        logger.info(f"Loading pipeline: {self.config.name}")
        pipe = self.config.pipeline_cls.from_pretrained(**self.config.pipeline_init_kwargs)
        pipe.to("cuda")

        # Truthy compile_kwargs selects compile mode; None/empty stays eager.
        if self.config.compile_kwargs:
            if self.config.compile_regional:
                logger.info(
                    f"Regional compilation (compile_repeated_blocks) with kwargs: {self.config.compile_kwargs}"
                )
                pipe.transformer.compile_repeated_blocks(**self.config.compile_kwargs)
            else:
                logger.info(f"Full compilation with kwargs: {self.config.compile_kwargs}")
                pipe.transformer.compile(**self.config.compile_kwargs)

        # Disable tqdm progress bar to avoid CPU overhead / IO between steps
        pipe.set_progress_bar_config(disable=True)

        annotate_pipeline(pipe)
        return pipe

    def run(self) -> str:
        """Execute the profiling run: warmup, then profile one pipeline call.

        Returns:
            Path of the exported Chrome trace JSON file.
        """
        pipe = self.setup_pipeline()
        flush()

        # Mode label mirrors the eager/compile decision made in setup_pipeline.
        mode = "compile" if self.config.compile_kwargs else "eager"
        trace_file = os.path.join(self.output_dir, f"{self.config.name}_{mode}.json")

        # Warmup (pipeline __call__ is already decorated with @torch.no_grad())
        # — this also lets any torch.compile compilation happen outside the
        # profiled iteration so it does not pollute the trace.
        logger.info("Running warmup...")
        pipe(**self.config.pipeline_call_kwargs)
        flush()

        # Profile a single full pipeline call, capturing CPU + CUDA activity.
        logger.info("Running profiled iteration...")
        activities = [
            torch.profiler.ProfilerActivity.CPU,
            torch.profiler.ProfilerActivity.CUDA,
        ]
        with torch.profiler.profile(
            activities=activities,
            record_shapes=True,
            profile_memory=True,
            with_stack=True,
        ) as prof:
            with torch.profiler.record_function("pipeline_call"):
                pipe(**self.config.pipeline_call_kwargs)

        # Export trace (viewable in chrome://tracing or ui.perfetto.dev).
        prof.export_chrome_trace(trace_file)
        logger.info(f"Chrome trace saved to: {trace_file}")

        # Print summary
        print("\n" + "=" * 80)
        print(f"Profile summary: {self.config.name} ({mode})")
        print("=" * 80)
        print(
            prof.key_averages().table(
                sort_by="cuda_time_total",
                row_limit=20,
            )
        )

        # Cleanup: move weights off the GPU and free cached memory before the
        # next pipeline/mode combination loads.
        pipe.to("cpu")
        del pipe
        flush()

        return trace_file
46 changes: 46 additions & 0 deletions examples/profiling/run_profiling.sh
Original file line number Diff line number Diff line change
@@ -0,0 +1,46 @@
#!/bin/bash
# Run profiling for the selected pipelines in eager and compile (regional) modes.
#
# Usage:
# bash profiling/run_profiling.sh
# bash profiling/run_profiling.sh --output_dir my_results

set -euo pipefail

# Only --output_dir is accepted; any other flag aborts the run.
OUTPUT_DIR="profiling_results"
while [[ $# -gt 0 ]]; do
case "$1" in
--output_dir) OUTPUT_DIR="$2"; shift 2 ;;
*) echo "Unknown arg: $1"; exit 1 ;;
esac
done
# Low step count keeps the traces small and runs fast.
NUM_STEPS=2
# NOTE(review): the sweep is currently trimmed to "wan" only — restore the
# full list below to profile every pipeline.
# PIPELINES=("flux" "flux2" "wan" "ltx2" "qwenimage")
PIPELINES=("wan")
MODES=("eager" "compile")

for pipeline in "${PIPELINES[@]}"; do
for mode in "${MODES[@]}"; do
echo "============================================================"
echo "Profiling: ${pipeline} | mode: ${mode}"
echo "============================================================"

# Compile runs use regional (repeated-block) compilation with fullgraph.
COMPILE_ARGS=""
if [ "$mode" = "compile" ]; then
COMPILE_ARGS="--compile_regional --compile_fullgraph --compile_mode default"
fi

# COMPILE_ARGS is intentionally unquoted so it word-splits into separate flags.
python profiling/profiling_pipelines.py \
--pipeline "$pipeline" \
--mode "$mode" \
--output_dir "$OUTPUT_DIR" \
--num_steps "$NUM_STEPS" \
$COMPILE_ARGS

echo ""
done
done

echo "============================================================"
echo "All traces saved to: ${OUTPUT_DIR}/"
echo "============================================================"
Loading
Loading