feat: a generic run hf tool that works for a number of model classes #4223
Open
narendasan wants to merge 2 commits into main
Conversation
There are some changes that do not conform to Python style guidelines:
--- /home/runner/work/TensorRT/TensorRT/tools/hf/common/bench.py 2026-04-29 16:28:51.984749+00:00
+++ /home/runner/work/TensorRT/TensorRT/tools/hf/common/bench.py 2026-04-29 16:29:15.476456+00:00
@@ -1,16 +1,16 @@
"""
Shared benchmarking harness for all HF model strategies.
"""
+
from __future__ import annotations
import timeit
from typing import Callable, Sequence
import numpy as np
import torch
-
WARMUP_ITERS = 5
def warmup_and_time(
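For readers skimming the diffs: the harness above follows the standard warmup-then-time pattern. A minimal sketch of what `warmup_and_time` plausibly does (the body is a reconstruction; only the name, imports, and `WARMUP_ITERS` come from the diff):

```python
import timeit
from typing import Callable, Sequence

import torch

WARMUP_ITERS = 5

def warmup_and_time(fn: Callable, args: Sequence, iterations: int = 50) -> list:
    # Warmup runs absorb one-time costs (CUDA context, autotuning, TRT setup).
    for _ in range(WARMUP_ITERS):
        fn(*args)
    torch.cuda.synchronize()
    times = []
    for _ in range(iterations):
        start = timeit.default_timer()
        fn(*args)
        torch.cuda.synchronize()  # stop the clock only after GPU work finishes
        times.append(timeit.default_timer() - start)
    return times
```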
--- /home/runner/work/TensorRT/TensorRT/tools/hf/common/metrics.py 2026-04-29 16:28:51.984749+00:00
+++ /home/runner/work/TensorRT/TensorRT/tools/hf/common/metrics.py 2026-04-29 16:29:15.561992+00:00
@@ -3,10 +3,11 @@
- Encoder / classifier : samples/s (throughput) + latency
- LLM : tokens/s
- Diffusion : images/s (one full denoising pass)
- Audio : real-time factor (audio_duration / inference_time)
"""
+
from __future__ import annotations
import json
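The per-family metric mapping in the docstring above is mechanical; a hedged sketch (the function and key names are invented for illustration, not the module's API):

```python
def family_metric(family: str, batch_size: int, latency_s: float, *,
                  num_tokens: int = 0, audio_duration_s: float = 0.0) -> dict:
    if family == "llm":
        return {"tokens_per_s": num_tokens / latency_s}
    if family == "audio":
        # Real-time factor as stated in the docstring above:
        # audio_duration / inference_time (higher is faster).
        return {"rtf": audio_duration_s / latency_s}
    # encoder / classifier / diffusion: throughput plus raw latency
    return {"samples_per_s": batch_size / latency_s, "latency_ms": latency_s * 1e3}
```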
--- /home/runner/work/TensorRT/TensorRT/tools/hf/common/dist.py 2026-04-29 16:28:51.984749+00:00
+++ /home/runner/work/TensorRT/TensorRT/tools/hf/common/dist.py 2026-04-29 16:29:15.564318+00:00
@@ -12,10 +12,11 @@
MASTER_PORT – rendezvous port
Single-process (no torchtrtrun) defaults: WORLD_SIZE=1 (or unset), so
all helpers here become no-ops.
"""
+
from __future__ import annotations
import datetime
import logging
import os
@@ -44,11 +45,13 @@
def is_master() -> bool:
return rank() == 0
-def init_distributed(timeout_hours: int = 2) -> Optional["torch.distributed.ProcessGroup"]:
+def init_distributed(
+ timeout_hours: int = 2,
+) -> Optional["torch.distributed.ProcessGroup"]:
"""
Initialize torch.distributed if WORLD_SIZE > 1.
Long timeout (default 2 h) so TRT engine builds don't trigger the NCCL
watchdog when one rank takes longer than another to build.
@@ -79,17 +82,16 @@
# Wire NCCL into TRT (sets up the symbol resolution path so TRT's
# collectives talk to the same libnccl.so torchtrtrun preloaded).
try:
from torch_tensorrt.distributed import setup_nccl_for_torch_tensorrt
+
setup_nccl_for_torch_tensorrt()
except Exception as e:
logger.warning(f"setup_nccl_for_torch_tensorrt failed: {e}")
- logger.info(
- f"[rank {rank()}/{world_size()}] dist init OK on cuda:{local_rank()}"
- )
+ logger.info(f"[rank {rank()}/{world_size()}] dist init OK on cuda:{local_rank()}")
return dist.group.WORLD
def build_device_mesh(mesh_dim_names: tuple[str, ...] = ("tp",)):
"""
@@ -112,7 +114,8 @@
def barrier() -> None:
"""torch.distributed barrier — no-op when not distributed."""
if not is_distributed():
return
import torch.distributed as dist
+
if dist.is_initialized():
dist.barrier()
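The helpers `barrier` builds on are simple environment-variable reads; a sketch consistent with the diff (bodies other than `barrier` itself are assumptions):

```python
import os

def world_size() -> int:
    return int(os.environ.get("WORLD_SIZE", "1"))

def rank() -> int:
    return int(os.environ.get("RANK", "0"))

def local_rank() -> int:
    return int(os.environ.get("LOCAL_RANK", "0"))

def is_distributed() -> bool:
    # WORLD_SIZE unset or 1 => single-process; every helper becomes a no-op.
    return world_size() > 1
```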
--- /home/runner/work/TensorRT/TensorRT/tools/hf/common/accuracy.py 2026-04-29 16:28:51.984749+00:00
+++ /home/runner/work/TensorRT/TensorRT/tools/hf/common/accuracy.py 2026-04-29 16:29:15.582221+00:00
@@ -8,21 +8,22 @@
Default tolerances are tuned for FP16 (atol=1e-2, rtol=1e-2,
cos_sim_min=0.99). Override via --accuracy-atol / --accuracy-rtol /
--accuracy-cos-sim-min when comparing tighter precisions or models
known to have larger accumulated error.
"""
+
from __future__ import annotations
from typing import Iterable
import torch
import torch.utils._pytree as pytree
-
# --------------------------------------------------------------------------- #
# Output flattening
# --------------------------------------------------------------------------- #
+
def _flatten_to_tensors(out) -> list[torch.Tensor]:
"""
Flatten an arbitrary HF model output (ModelOutput dataclass, dict,
tuple, list, or single tensor) into a list of leaf tensors using
@@ -33,10 +34,11 @@
# --------------------------------------------------------------------------- #
# Per-tensor metrics
# --------------------------------------------------------------------------- #
+
def _cosine_similarity(a: torch.Tensor, b: torch.Tensor) -> float:
a = a.detach().to(torch.float32).flatten()
b = b.detach().to(torch.float32).flatten()
if a.numel() == 0 or b.numel() == 0:
@@ -87,10 +89,11 @@
# --------------------------------------------------------------------------- #
# Compare two outputs (each a tensor / dict / dataclass / tuple)
# --------------------------------------------------------------------------- #
+
def compare_outputs(
pt_out,
trt_out,
*,
atol: float = 1e-2,
@@ -99,22 +102,28 @@
) -> list[dict]:
pt_leaves = _flatten_to_tensors(pt_out)
trt_leaves = _flatten_to_tensors(trt_out)
if len(pt_leaves) != len(trt_leaves):
- return [{
- "name": "<output-count-mismatch>",
- "shape_pt": f"{len(pt_leaves)} tensors",
- "shape_trt": f"{len(trt_leaves)} tensors",
- "shape_match": False,
- "cos_sim": float("nan"),
- "max_abs": float("nan"),
- "mean_abs": float("nan"),
- "allclose": False,
- }]
-
- names = list(output_names) if output_names else [f"out[{i}]" for i in range(len(pt_leaves))]
+ return [
+ {
+ "name": "<output-count-mismatch>",
+ "shape_pt": f"{len(pt_leaves)} tensors",
+ "shape_trt": f"{len(trt_leaves)} tensors",
+ "shape_match": False,
+ "cos_sim": float("nan"),
+ "max_abs": float("nan"),
+ "mean_abs": float("nan"),
+ "allclose": False,
+ }
+ ]
+
+ names = (
+ list(output_names)
+ if output_names
+ else [f"out[{i}]" for i in range(len(pt_leaves))]
+ )
if len(names) < len(pt_leaves):
names += [f"out[{i}]" for i in range(len(names), len(pt_leaves))]
rows: list[dict] = []
for name, pt, trt in zip(names, pt_leaves, trt_leaves):
@@ -125,10 +134,11 @@
# --------------------------------------------------------------------------- #
# Reporting
# --------------------------------------------------------------------------- #
+
def overall_pass(
rows: list[dict],
*,
cos_sim_min: float = 0.99,
@@ -169,11 +179,13 @@
print(f"\n{'=' * 70}")
print(f" {title}")
print(f"{'=' * 70}")
cols = ("name", "shape_pt", "cos_sim", "max_abs", "mean_abs", "allclose")
- widths = {c: max(len(c), max(len(_fmt(r.get(c, ""), c)) for r in rows)) for c in cols}
+ widths = {
+ c: max(len(c), max(len(_fmt(r.get(c, ""), c)) for r in rows)) for c in cols
+ }
header = " ".join(c.ljust(widths[c]) for c in cols)
print(header)
print("-" * len(header))
for r in rows:
print(" ".join(_fmt(r.get(c, ""), c).ljust(widths[c]) for c in cols))
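For context, the per-tensor metrics the table reports can be computed in a few lines; a sketch under the FP16 defaults named in the docstring (`compare_tensors` is an illustrative name, not the module's API):

```python
import torch

def compare_tensors(pt: torch.Tensor, trt: torch.Tensor,
                    atol: float = 1e-2, rtol: float = 1e-2) -> dict:
    a = pt.detach().to(torch.float32).flatten()
    b = trt.detach().to(torch.float32).flatten()
    diff = (a - b).abs()
    return {
        "cos_sim": torch.nn.functional.cosine_similarity(a, b, dim=0).item(),
        "max_abs": diff.max().item(),
        "mean_abs": diff.mean().item(),
        "allclose": torch.allclose(a, b, atol=atol, rtol=rtol),
    }
```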
--- /home/runner/work/TensorRT/TensorRT/tools/hf/common/compile.py 2026-04-29 16:28:51.984749+00:00
+++ /home/runner/work/TensorRT/TensorRT/tools/hf/common/compile.py 2026-04-29 16:29:15.613228+00:00
@@ -22,17 +22,17 @@
)
`enabled_precisions` is deprecated when `use_explicit_typing=True` and is
NEVER set by these helpers.
"""
+
from __future__ import annotations
from typing import Iterable, Optional
import torch
import torch_tensorrt
-
# --------------------------------------------------------------------------- #
# Precision plumbing
# --------------------------------------------------------------------------- #
@@ -87,10 +87,11 @@
# --------------------------------------------------------------------------- #
# Export helper (with fallback for guard violations)
# --------------------------------------------------------------------------- #
+
def safe_export(
module: torch.nn.Module,
args: tuple = (),
kwargs: Optional[dict] = None,
@@ -110,11 +111,13 @@
kwargs=kwargs,
dynamic_shapes=dynamic_shapes,
strict=False,
)
except Exception as e:
- print(f"[compile] torch.export.export failed ({e}); retrying with deferred guards.")
+ print(
+ f"[compile] torch.export.export failed ({e}); retrying with deferred guards."
+ )
return torch.export._trace._export(
module,
args=args,
kwargs=kwargs,
dynamic_shapes=dynamic_shapes,
@@ -124,10 +127,11 @@
# --------------------------------------------------------------------------- #
# Compile wrapper
# --------------------------------------------------------------------------- #
+
def _build_trt_kwargs(
precision: str,
autocast: bool,
min_block_size: int,
@@ -179,13 +183,19 @@
- C++ runtime (use_python_runtime=False)
- engine caching ON
- offload_module_to_cpu OFF (opt-in for memory-constrained models)
"""
kw = _build_trt_kwargs(
- precision, autocast, min_block_size, debug,
- offload_module_to_cpu, cache_built_engines, reuse_cached_engines,
- engine_cache_dir, extra,
+ precision,
+ autocast,
+ min_block_size,
+ debug,
+ offload_module_to_cpu,
+ cache_built_engines,
+ reuse_cached_engines,
+ engine_cache_dir,
+ extra,
)
return torch_tensorrt.dynamo.compile(ep, inputs=list(inputs), **kw)
def maybe_save_exported_program(cfg, ep, *, log_prefix: str) -> None:
@@ -225,11 +235,13 @@
# torchscript needs real tensors for jit.trace; other formats accept
# torch_tensorrt.Input specs.
if fmt == "torchscript":
use_args = example_arg_inputs if example_arg_inputs is not None else arg_inputs
- use_kwargs = example_kwarg_inputs if example_kwarg_inputs is not None else kwarg_inputs
+ use_kwargs = (
+ example_kwarg_inputs if example_kwarg_inputs is not None else kwarg_inputs
+ )
else:
use_args = arg_inputs
use_kwargs = kwarg_inputs
save_kwargs = {"output_format": fmt}
@@ -308,11 +320,14 @@
single TRT engine end-to-end (no PyTorch fallback partitions). If
your model has unsupported ops, this call will fail; use the
`compile_with_trt` + `torch_tensorrt.save` path instead.
"""
kw = _build_trt_kwargs(
- precision, autocast, min_block_size, debug,
+ precision,
+ autocast,
+ min_block_size,
+ debug,
offload_module_to_cpu,
# Engine caching is irrelevant for serialization output.
cache_built_engines=False,
reuse_cached_engines=False,
engine_cache_dir=engine_cache_dir,
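The path these helpers wrap is the standard export-then-compile flow. A hedged end-to-end sketch using only the public APIs visible in the diff (`torch.export.export`, `torch_tensorrt.dynamo.compile`); the toy model and shapes are placeholders:

```python
import torch
import torch_tensorrt

model = torch.nn.Linear(16, 16).half().cuda().eval()
x = torch.randn(1, 16, dtype=torch.float16, device="cuda")

ep = torch.export.export(model, args=(x,), strict=False)
trt_mod = torch_tensorrt.dynamo.compile(
    ep,
    inputs=[torch_tensorrt.Input(shape=x.shape, dtype=x.dtype)],
    use_explicit_typing=True,  # per the docstring, enabled_precisions is never set
)
out = trt_mod(x)
```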
--- /home/runner/work/TensorRT/TensorRT/tools/hf/run_hf.py 2026-04-29 16:28:51.984749+00:00
+++ /home/runner/work/TensorRT/TensorRT/tools/hf/run_hf.py 2026-04-29 16:29:15.643944+00:00
@@ -64,10 +64,11 @@
Path B (--autocast) : Model stays FP32; TRT compiler casts via
enable_autocast=True, autocast_low_precision_type=<dtype>.
`enabled_precisions` is deprecated under use_explicit_typing and is never set.
"""
+
from __future__ import annotations
import sys
from pathlib import Path
@@ -92,37 +93,45 @@
def _build_strategy(family: str, cfg: RunConfig):
if family == "llm":
if is_distributed():
from strategies.llm_tp import LLMTPStrategy
+
return LLMTPStrategy(cfg)
from strategies.llm import LLMStrategy
+
return LLMStrategy(cfg)
if family == "encoder":
from strategies.encoder import EncoderStrategy
+
return EncoderStrategy(cfg)
if family == "seq2seq":
from strategies.seq2seq import SeqToSeqStrategy
+
return SeqToSeqStrategy(cfg)
if family == "diffusion":
from strategies.diffusion import DiffusionStrategy
+
return DiffusionStrategy(cfg)
if family == "audio":
from strategies.audio import AudioStrategy
+
return AudioStrategy(cfg)
if family == "multimodal":
from strategies.multimodal import MultimodalStrategy
+
return MultimodalStrategy(cfg)
raise ValueError(f"No strategy implemented for family '{family}'")
def main() -> None:
args: CLIArgs = tyro.cli(CLIArgs)
if not args.debug:
import logging
import torch_tensorrt
+
torch_tensorrt.logging.set_level(logging.WARNING)
# Initialize torch.distributed if launched under torchtrtrun
# (WORLD_SIZE > 1). No-op for single-process runs.
init_distributed()
@@ -159,10 +168,11 @@
if args.benchmark:
rows = strategy.benchmark()
if args.json_out and rows:
from common.metrics import dump_json
+
dump_json(rows, args.json_out)
if __name__ == "__main__":
main()
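The CLI side is a thin tyro wrapper over a dataclass; a hedged skeleton (field names beyond `model`, `benchmark`, `debug`, and `json_out`, which appear in the diff, are hypothetical):

```python
from __future__ import annotations

import dataclasses
from typing import Optional

import tyro

@dataclasses.dataclass
class CLIArgs:
    model: str = "bert-base-uncased"  # hypothetical default
    task: Optional[str] = None        # override family auto-detection
    benchmark: bool = False
    debug: bool = False
    json_out: Optional[str] = None

if __name__ == "__main__":
    args = tyro.cli(CLIArgs)  # tyro derives --model/--task/... flags from the fields
```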
--- /home/runner/work/TensorRT/TensorRT/tools/hf/detect.py 2026-04-29 16:28:51.984749+00:00
+++ /home/runner/work/TensorRT/TensorRT/tools/hf/detect.py 2026-04-29 16:29:15.678084+00:00
@@ -10,48 +10,92 @@
"audio" – speech / ASR (Whisper, wav2vec2, ...)
Returns the family string. Raises ValueError if the family cannot be
determined and no --task override was given.
"""
+
from __future__ import annotations
from typing import Optional
-
# Maps HuggingFace model_type strings to strategy family names.
# Keep sorted within each group for readability.
_LLM_TYPES = {
- "bloom", "codegen", "falcon", "gemma", "gemma2", "gemma3",
- "gpt2", "gpt_bigcode", "gpt_neo", "gpt_neox", "gptj",
- "llama", "mistral", "mixtral", "mpt", "opt",
- "phi", "phi3", "qwen2", "qwen2_moe", "starcoder2",
+ "bloom",
+ "codegen",
+ "falcon",
+ "gemma",
+ "gemma2",
+ "gemma3",
+ "gpt2",
+ "gpt_bigcode",
+ "gpt_neo",
+ "gpt_neox",
+ "gptj",
+ "llama",
+ "mistral",
+ "mixtral",
+ "mpt",
+ "opt",
+ "phi",
+ "phi3",
+ "qwen2",
+ "qwen2_moe",
+ "starcoder2",
"stablelm",
}
_ENCODER_TYPES = {
- "albert", "bert", "camembert", "convnext", "deberta",
- "deberta-v2", "distilbert", "electra", "mobilenet_v2",
- "resnet", "roberta", "swin", "vit", "xlm", "xlm-roberta",
+ "albert",
+ "bert",
+ "camembert",
+ "convnext",
+ "deberta",
+ "deberta-v2",
+ "distilbert",
+ "electra",
+ "mobilenet_v2",
+ "resnet",
+ "roberta",
+ "swin",
+ "vit",
+ "xlm",
+ "xlm-roberta",
"efficientnet",
}
_SEQ2SEQ_TYPES = {
- "bart", "longt5", "mt5", "mbart", "pegasus", "t5",
+ "bart",
+ "longt5",
+ "mt5",
+ "mbart",
+ "pegasus",
+ "t5",
}
_DIFFUSION_TYPES = {
- "flux", "stable-diffusion", "stable_diffusion",
- "stable-diffusion-xl", "stable_diffusion_xl",
+ "flux",
+ "stable-diffusion",
+ "stable_diffusion",
+ "stable-diffusion-xl",
+ "stable_diffusion_xl",
"unet-2d-condition",
}
_AUDIO_TYPES = {
- "hubert", "wav2vec2", "wavlm", "whisper",
+ "hubert",
+ "wav2vec2",
+ "wavlm",
+ "whisper",
}
_MULTIMODAL_TYPES = {
- "blip", "blip-2", "clip", "clipseg", "siglip",
+ "blip",
+ "blip-2",
+ "clip",
+ "clipseg",
+ "siglip",
}
def detect_family(model_id: str, task_override: Optional[str] = None) -> str:
"""
@@ -72,19 +116,21 @@
# instead of config.json with a model_type field — fall back to that.
model_type = ""
last_error: Optional[Exception] = None
try:
from transformers import AutoConfig
+
cfg = AutoConfig.from_pretrained(model_id, trust_remote_code=False)
model_type = getattr(cfg, "model_type", "").lower()
except Exception as e:
last_error = e
if not model_type:
try:
from huggingface_hub import hf_hub_download
import json
+
path = hf_hub_download(model_id, "model_index.json")
with open(path) as f:
idx = json.load(f)
cls_name = idx.get("_class_name", "").lower()
if "pipeline" in cls_name or any(
@@ -96,11 +142,14 @@
if not model_type:
msg = f"Could not determine family for '{model_id}'."
# Detect gated / authentication errors and surface them clearly.
err_str = str(last_error) if last_error else ""
- if any(k in err_str for k in ("Repository Not Found", "401", "gated", "authenticated", "access")):
+ if any(
+ k in err_str
+ for k in ("Repository Not Found", "401", "gated", "authenticated", "access")
+ ):
msg += (
"\n\nThis model may be gated or private. Run "
"`hf auth login` (or `huggingface-cli login`) and accept the "
"model's license on https://huggingface.co/" + model_id
)
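Detection boils down to reading `model_type` off `AutoConfig` and mapping it through the sets above; a condensed sketch (tables abbreviated, fallback and error handling trimmed):

```python
from transformers import AutoConfig

_LLM_TYPES = {"llama", "mistral", "gpt2"}      # abbreviated for illustration
_ENCODER_TYPES = {"bert", "vit", "resnet"}

def detect_family(model_id: str) -> str:
    cfg = AutoConfig.from_pretrained(model_id, trust_remote_code=False)
    model_type = getattr(cfg, "model_type", "").lower()
    if model_type in _LLM_TYPES:
        return "llm"
    if model_type in _ENCODER_TYPES:
        return "encoder"
    raise ValueError(f"Could not determine family for '{model_id}'.")
```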
--- /home/runner/work/TensorRT/TensorRT/tools/hf/strategies/base.py 2026-04-29 16:28:51.985193+00:00
+++ /home/runner/work/TensorRT/TensorRT/tools/hf/strategies/base.py 2026-04-29 16:29:15.754327+00:00
@@ -5,10 +5,11 @@
load() – download & prepare the model (called once, on CPU or GPU)
compile() – export + TRT compile (or torch.compile fast path)
benchmark() – warmup + timed loop + return list[dict] rows for metrics.py
generate() – optional: run a single sample-mode forward (text, image, audio)
"""
+
from __future__ import annotations
import abc
import dataclasses
from typing import Literal, Optional
@@ -166,11 +167,12 @@
pt_out = self._run_pt() # type: ignore[attr-defined]
trt_out = self._run_trt() # type: ignore[attr-defined]
rows = compare_outputs(
- pt_out, trt_out,
+ pt_out,
+ trt_out,
atol=self.cfg.accuracy_atol,
rtol=self.cfg.accuracy_rtol,
)
print_accuracy_table(
rows,
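The four-method contract in the docstring maps to a small ABC; an illustrative skeleton (not the exact base class):

```python
import abc

class ModelStrategy(abc.ABC):
    def __init__(self, cfg):
        self.cfg = cfg

    @abc.abstractmethod
    def load(self) -> None: ...            # download & prepare the model

    @abc.abstractmethod
    def compile(self) -> None: ...         # export + TRT compile (or torch.compile)

    @abc.abstractmethod
    def benchmark(self) -> list: ...       # warmup + timed loop -> metric rows

    def generate(self) -> None:            # optional single sample-mode forward
        raise NotImplementedError
```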
--- /home/runner/work/TensorRT/TensorRT/tools/hf/strategies/audio.py 2026-04-29 16:28:51.985193+00:00
+++ /home/runner/work/TensorRT/TensorRT/tools/hf/strategies/audio.py 2026-04-29 16:29:15.793592+00:00
@@ -8,10 +8,11 @@
For Whisper the encoder is the bulk of the compute on long audio, so this
still delivers most of the available speedup.
Benchmarking metric: real-time factor (RTF = inference_time / audio_duration).
"""
+
from __future__ import annotations
from typing import Optional
import torch
@@ -89,11 +90,13 @@
print("[audio] Compiled encoder swapped into self._model.model.encoder")
def _compile_torch_compile(self, encoder: torch.nn.Module) -> torch.nn.Module:
print("[audio] torch.compile encoder (backend='torch_tensorrt') ...")
options = compile_kwargs(self.cfg.precision, self.cfg.autocast)
- options.update({"min_block_size": self.cfg.min_block_size, "debug": self.cfg.debug})
+ options.update(
+ {"min_block_size": self.cfg.min_block_size, "debug": self.cfg.debug}
+ )
compiled = torch.compile(encoder, backend="torch_tensorrt", options=options)
dummy = torch.randn(*self._mel_shape, dtype=self._model_dtype).to(self._device)
with torch.no_grad():
_ = compiled(dummy)
@@ -102,11 +105,13 @@
def _compile_export(self, encoder: torch.nn.Module) -> torch.nn.Module:
dummy = torch.randn(*self._mel_shape, dtype=self._model_dtype).to(self._device)
print("[audio] torch.export.export encoder ...")
- ep = safe_export(encoder, args=(dummy,), dynamic_shapes=({0: torch.export.Dim.AUTO},))
+ ep = safe_export(
+ encoder, args=(dummy,), dynamic_shapes=({0: torch.export.Dim.AUTO},)
+ )
maybe_save_exported_program(self.cfg, ep, log_prefix="[audio]")
print("[audio] torch_tensorrt.dynamo.compile encoder ...")
trt_inputs = [torch_tensorrt.Input(shape=dummy.shape, dtype=dummy.dtype)]
compiled = compile_with_trt(
@@ -120,29 +125,40 @@
engine_cache_dir=self.cfg.engine_cache_dir,
)
torch.cuda.synchronize()
maybe_save_trt_module(
- self.cfg, compiled,
+ self.cfg,
+ compiled,
arg_inputs=trt_inputs,
log_prefix="[audio]",
)
maybe_save_trt_engine(self.cfg, ep, trt_inputs, log_prefix="[audio]")
return compiled
# ---------------------------------------------------------------------- #
def benchmark(self) -> list[dict]:
- assert self._model is not None and self._mel_shape is not None, "Call load() first"
+ assert (
+ self._model is not None and self._mel_shape is not None
+ ), "Call load() first"
rows: list[dict] = []
dummy = torch.randn(*self._mel_shape, dtype=self._model_dtype).to(self._device)
audio_duration_s = self.cfg.audio_duration_s
- encoder_for_bench = self._trt_encoder if self._trt_encoder is not None else self._model.model.encoder
- backend = f"torch_tensorrt[{self.cfg.mode}]" if self._trt_encoder is not None else "pytorch"
+ encoder_for_bench = (
+ self._trt_encoder
+ if self._trt_encoder is not None
+ else self._model.model.encoder
+ )
+ backend = (
+ f"torch_tensorrt[{self.cfg.mode}]"
+ if self._trt_encoder is not None
+ else "pytorch"
+ )
def _run_encoder():
with torch.no_grad():
return encoder_for_bench(dummy)
@@ -180,11 +196,13 @@
def _run_trt(self):
# Reuse the same dummy from _run_pt so PT vs TRT compare on identical input.
dummy = getattr(self, "_last_dummy", None)
if dummy is None:
- dummy = torch.randn(*self._mel_shape, dtype=self._model_dtype).to(self._device)
+ dummy = torch.randn(*self._mel_shape, dtype=self._model_dtype).to(
+ self._device
+ )
with torch.no_grad():
return self._trt_encoder(dummy)
def generate(self) -> None:
"""Transcribe a silence clip end-to-end (TRT encoder + PyTorch decoder)."""
--- /home/runner/work/TensorRT/TensorRT/tools/hf/strategies/encoder.py 2026-04-29 16:28:51.985193+00:00
+++ /home/runner/work/TensorRT/TensorRT/tools/hf/strategies/encoder.py 2026-04-29 16:29:15.896539+00:00
@@ -2,10 +2,11 @@
Encoder strategy: BERT, RoBERTa, ViT, ResNet, EfficientNet, etc.
Whole-model export. These are the easiest models to compile because they
have no KV cache, no generation loop, and minimal control flow.
"""
+
from __future__ import annotations
from typing import Optional
import torch
@@ -25,11 +26,16 @@
from common.metrics import print_table, report_latency
from strategies.base import ModelStrategy, RunConfig
# HF model_type values that use vision inputs (pixel_values) rather than input_ids.
_VISION_TYPES = {
- "convnext", "efficientnet", "mobilenet_v2", "resnet", "swin", "vit",
+ "convnext",
+ "efficientnet",
+ "mobilenet_v2",
+ "resnet",
+ "swin",
+ "vit",
}
def _is_vision_model(cfg) -> bool:
return getattr(cfg, "model_type", "").lower() in _VISION_TYPES
@@ -47,11 +53,13 @@
nc = getattr(model_cfg, "num_channels", 3)
return (torch.randn(batch_size, nc, h, w, dtype=dtype).to(device),)
seq_len = 128
vocab_size = getattr(model_cfg, "vocab_size", 30522)
- input_ids = torch.randint(0, vocab_size, (batch_size, seq_len), dtype=torch.int64).to(device)
+ input_ids = torch.randint(
+ 0, vocab_size, (batch_size, seq_len), dtype=torch.int64
+ ).to(device)
attention_mask = torch.ones_like(input_ids)
return (input_ids, attention_mask)
class EncoderStrategy(ModelStrategy):
@@ -88,11 +96,13 @@
self._model_cfg,
self.cfg.batch_size,
self._model_dtype,
self._device,
)
- print(f"[encoder] Model loaded. Input shapes: {[t.shape for t in self._dummy_inputs]}")
+ print(
+ f"[encoder] Model loaded. Input shapes: {[t.shape for t in self._dummy_inputs]}"
+ )
# ---------------------------------------------------------------------- #
def compile(self) -> None:
assert self._model is not None, "Call load() before compile()"
@@ -103,11 +113,13 @@
self._compile_export()
def _compile_torch_compile(self) -> None:
print("[encoder] torch.compile(backend='torch_tensorrt') ...")
options = compile_kwargs(self.cfg.precision, self.cfg.autocast)
- options.update({"min_block_size": self.cfg.min_block_size, "debug": self.cfg.debug})
+ options.update(
+ {"min_block_size": self.cfg.min_block_size, "debug": self.cfg.debug}
+ )
self._trt_model = torch.compile(
self._model,
backend="torch_tensorrt",
options=options,
@@ -131,13 +143,11 @@
print("[encoder] torch.export.export ...")
ep = safe_export(self._model, args=dummy, dynamic_shapes=dyn)
maybe_save_exported_program(self.cfg, ep, log_prefix="[encoder]")
print("[encoder] torch_tensorrt.dynamo.compile ...")
- trt_inputs = [
- torch_tensorrt.Input(shape=t.shape, dtype=t.dtype) for t in dummy
- ]
+ trt_inputs = [torch_tensorrt.Input(shape=t.shape, dtype=t.dtype) for t in dummy]
self._trt_model = compile_with_trt(
ep,
inputs=trt_inputs,
precision=self.cfg.precision,
autocast=self.cfg.autocast,
@@ -148,11 +158,12 @@
)
torch.cuda.synchronize()
print("[encoder] TRT compile done.")
maybe_save_trt_module(
- self.cfg, self._trt_model,
+ self.cfg,
+ self._trt_model,
arg_inputs=trt_inputs,
example_arg_inputs=list(dummy),
log_prefix="[encoder]",
)
maybe_save_trt_engine(self.cfg, ep, trt_inputs, log_prefix="[encoder]")
@@ -163,11 +174,13 @@
assert self._dummy_inputs is not None, "Call load() first"
rows: list[dict] = []
with torch.no_grad():
pt_t = warmup_and_time(
- lambda *a: self._model(*a), self._dummy_inputs, iterations=self.cfg.iterations
+ lambda *a: self._model(*a),
+ self._dummy_inputs,
+ iterations=self.cfg.iterations,
)
rows.append(
report_latency(
compute_stats(pt_t, self.cfg.batch_size),
backend="pytorch",
@@ -177,11 +190,13 @@
)
if self._trt_model is not None:
with torch.no_grad():
trt_t = warmup_and_time(
- lambda *a: self._trt_model(*a), self._dummy_inputs, iterations=self.cfg.iterations
+ lambda *a: self._trt_model(*a),
+ self._dummy_inputs,
+ iterations=self.cfg.iterations,
)
rows.append(
report_latency(
compute_stats(trt_t, self.cfg.batch_size),
backend=f"torch_tensorrt[{self.cfg.mode}]",
--- /home/runner/work/TensorRT/TensorRT/tools/hf/strategies/diffusion.py 2026-04-29 16:28:51.985193+00:00
+++ /home/runner/work/TensorRT/TensorRT/tools/hf/strategies/diffusion.py 2026-04-29 16:29:15.944213+00:00
@@ -12,10 +12,11 @@
Rather than hard-code dummy inputs per architecture, we run one short
inference pass and capture the actual backbone-call args via a forward
pre-hook. Those captured tensors are then used as the export args.
"""
+
from __future__ import annotations
import timeit
from typing import Optional
@@ -40,10 +41,11 @@
Insert `.<suffix>` before the extension of base_path.
foo.trt → foo.<suffix>.trt
foo → foo.<suffix>
"""
import os
+
root, ext = os.path.splitext(base_path)
return f"{root}.{suffix}{ext}" if ext else f"{root}.{suffix}"
def _get_backbone(pipe) -> tuple[torch.nn.Module, str]:
@@ -144,11 +146,13 @@
if self.cfg.save_trt_engine and self.cfg.mode == "export":
self._save_companion_engines()
def _compile_torch_compile(self, backbone: torch.nn.Module) -> torch.nn.Module:
options = compile_kwargs(self.cfg.precision, self.cfg.autocast)
- options.update({"min_block_size": self.cfg.min_block_size, "debug": self.cfg.debug})
+ options.update(
+ {"min_block_size": self.cfg.min_block_size, "debug": self.cfg.debug}
+ )
compiled = torch.compile(backbone, backend="torch_tensorrt", options=options)
# Trigger compilation by running a real pipeline step.
with torch.no_grad():
@@ -162,11 +166,13 @@
args, kwargs = _capture_backbone_inputs(
self._pipe,
prompt="a photo of an astronaut",
num_steps=1,
)
- print(f"[diffusion] Captured {len(args)} positional + {len(kwargs)} keyword args.")
+ print(
+ f"[diffusion] Captured {len(args)} positional + {len(kwargs)} keyword args."
+ )
# Stash for accuracy comparison: keep the PT backbone reference and
# the captured call args/kwargs so we can run both pre- and post-swap.
self._pt_backbone = backbone
self._captured_args = args
@@ -196,19 +202,19 @@
# entry for each: torch_tensorrt.Input for tensors, the original
# value for everything else.
trt_kwarg_inputs = {
k: (
torch_tensorrt.Input(shape=v.shape, dtype=v.dtype)
- if isinstance(v, torch.Tensor) else v
+ if isinstance(v, torch.Tensor)
+ else v
)
for k, v in kwargs.items()
}
# For the compile() call we still need a flat list, but only
# tensor-typed entries; compile()'s `inputs=` arg is positional.
flat_trt_inputs = trt_arg_inputs + [
- v for v in trt_kwarg_inputs.values()
- if isinstance(v, torch_tensorrt.Input)
+ v for v in trt_kwarg_inputs.values() if isinstance(v, torch_tensorrt.Input)
]
compiled = compile_with_trt(
ep,
inputs=flat_trt_inputs,
@@ -221,12 +227,14 @@
)
torch.cuda.synchronize()
if self.cfg.save_engine:
from common.compile import maybe_save_trt_module
+
maybe_save_trt_module(
- self.cfg, compiled,
+ self.cfg,
+ compiled,
arg_inputs=flat_trt_inputs,
log_prefix="[diffusion]",
)
# Save the backbone engine to <base>.<attr>.trt; companion engines
@@ -246,11 +254,13 @@
kwarg_inputs=trt_kwarg_inputs or None,
log_prefix=f"[diffusion:{backbone_attr}]",
)
except Exception as e:
print(f"[diffusion:{backbone_attr}] FAILED: {e}")
- print(f"[diffusion:{backbone_attr}] (continuing with companion engines)")
+ print(
+ f"[diffusion:{backbone_attr}] (continuing with companion engines)"
+ )
return compiled
# ---------------------------------------------------------------------- #
# Multi-engine save: text_encoder + backbone + vae_decoder
@@ -313,13 +323,16 @@
0, vocab, (self.cfg.batch_size, max_len), dtype=torch.int64
).to(self._device)
print("[diffusion:text_encoder] torch.export.export ...")
ep = safe_export(text_encoder, args=(input_ids,))
- trt_inputs = [torch_tensorrt.Input(shape=input_ids.shape, dtype=input_ids.dtype)]
+ trt_inputs = [
+ torch_tensorrt.Input(shape=input_ids.shape, dtype=input_ids.dtype)
+ ]
self._serialize_to_path(
- ep, path,
+ ep,
+ path,
arg_inputs=trt_inputs,
log_prefix="[diffusion:text_encoder]",
)
def _serialize_vae_decoder(self, path: str) -> None:
@@ -341,19 +354,23 @@
spatial_factor = 2 ** (len(block_out_channels) - 1)
h_lat = self.cfg.image_size // spatial_factor
w_lat = self.cfg.image_size // spatial_factor
latents = torch.randn(
- self.cfg.batch_size, latent_channels, h_lat, w_lat,
+ self.cfg.batch_size,
+ latent_channels,
+ h_lat,
+ w_lat,
dtype=self._model_dtype,
).to(self._device)
print("[diffusion:vae_decoder] torch.export.export ...")
ep = safe_export(decoder, args=(latents,))
trt_inputs = [torch_tensorrt.Input(shape=latents.shape, dtype=latents.dtype)]
self._serialize_to_path(
- ep, path,
+ ep,
+ path,
arg_inputs=trt_inputs,
log_prefix="[diffusion:vae_decoder]",
)
# ---------------------------------------------------------------------- #
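The input-capture trick in the header docstring is worth seeing in isolation: a forward pre-hook records the real args/kwargs the pipeline passes its backbone during one short run, so export needs no per-architecture dummy inputs. A hedged sketch (names illustrative; `with_kwargs=True` is the public `register_forward_pre_hook` option):

```python
import torch

def capture_backbone_inputs(pipe, backbone: torch.nn.Module, prompt: str):
    captured = {}

    def _hook(module, args, kwargs):
        # Record exactly what the pipeline passes to the backbone.
        captured["args"], captured["kwargs"] = args, kwargs

    handle = backbone.register_forward_pre_hook(_hook, with_kwargs=True)
    try:
        with torch.no_grad():
            pipe(prompt, num_inference_steps=1)  # one short denoising pass
    finally:
        handle.remove()
    return captured["args"], captured["kwargs"]
```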
--- /home/runner/work/TensorRT/TensorRT/tools/hf/strategies/llm_tp.py 2026-04-29 16:28:51.985193+00:00
+++ /home/runner/work/TensorRT/TensorRT/tools/hf/strategies/llm_tp.py 2026-04-29 16:29:15.970101+00:00
@@ -20,10 +20,11 @@
Launch
------
torchtrtrun --nproc_per_node=2 run_hf.py --model meta-llama/Llama-3.2-1B-Instruct
"""
+
from __future__ import annotations
import logging
import sys
from contextlib import nullcontext
@@ -66,19 +67,21 @@
)
plan: dict[str, object] = {}
n_layers = model.config.num_hidden_layers
for i in range(n_layers):
- plan.update({
- f"model.layers.{i}.self_attn.q_proj": ColwiseParallel(),
- f"model.layers.{i}.self_attn.k_proj": ColwiseParallel(),
- f"model.layers.{i}.self_attn.v_proj": ColwiseParallel(),
- f"model.layers.{i}.self_attn.o_proj": RowwiseParallel(),
- f"model.layers.{i}.mlp.gate_proj": ColwiseParallel(),
- f"model.layers.{i}.mlp.up_proj": ColwiseParallel(),
- f"model.layers.{i}.mlp.down_proj": RowwiseParallel(),
- })
+ plan.update(
+ {
+ f"model.layers.{i}.self_attn.q_proj": ColwiseParallel(),
+ f"model.layers.{i}.self_attn.k_proj": ColwiseParallel(),
+ f"model.layers.{i}.self_attn.v_proj": ColwiseParallel(),
+ f"model.layers.{i}.self_attn.o_proj": RowwiseParallel(),
+ f"model.layers.{i}.mlp.gate_proj": ColwiseParallel(),
+ f"model.layers.{i}.mlp.up_proj": ColwiseParallel(),
+ f"model.layers.{i}.mlp.down_proj": RowwiseParallel(),
+ }
+ )
return plan
def _patch_attention_head_counts(model, ws: int) -> None:
"""
@@ -119,11 +122,13 @@
# Build a 1-D TP device mesh and shard the model.
self._mesh = build_device_mesh(("tp",))
ws = world_size()
from torch.distributed.tensor.parallel import parallelize_module
- if not hasattr(self._model, "model") or not hasattr(self._model.model, "layers"):
+ if not hasattr(self._model, "model") or not hasattr(
+ self._model.model, "layers"
+ ):
raise NotImplementedError(
f"TP plan only supports Llama-family layouts "
f"(model.model.layers.{{i}}.self_attn / mlp). Got: "
f"{type(self._model).__name__}. GPT-2 / OPT / NeoX support TBD."
)
@@ -158,18 +163,20 @@
"""
if is_master():
print("[llm_tp] torch.compile(backend='torch_tensorrt', dynamic=True) ...")
opts = compile_kwargs(self.cfg.precision, self.cfg.autocast)
- opts.update({
- "min_block_size": self.cfg.min_block_size,
- "debug": self.cfg.debug,
- "device": self._device,
- "disable_tf32": True,
- "use_python_runtime": False,
- "assume_dynamic_shape_support": True,
- })
+ opts.update(
+ {
+ "min_block_size": self.cfg.min_block_size,
+ "debug": self.cfg.debug,
+ "device": self._device,
+ "disable_tf32": True,
+ "use_python_runtime": False,
+ "assume_dynamic_shape_support": True,
+ }
+ )
with torch_tensorrt.logging.debug() if self.cfg.debug else nullcontext():
self._trt_model = torch.compile(
self._model,
backend="torch_tensorrt",
@@ -264,11 +271,15 @@
def generate(self) -> None:
"""One greedy decode + print on rank 0."""
from utils import generate
- ids = self._tokenizer(self.cfg.prompt, return_tensors="pt")["input_ids"].to(self._device)
+ ids = self._tokenizer(self.cfg.prompt, return_tensors="pt")["input_ids"].to(
+ self._device
+ )
max_out = ids.shape[1] + self.cfg.num_tokens
- out = generate(self._trt_model, ids.clone(), max_out, self._tokenizer.eos_token_id)
+ out = generate(
+ self._trt_model, ids.clone(), max_out, self._tokenizer.eos_token_id
+ )
if is_master():
text = self._tokenizer.decode(out[0], skip_special_tokens=True)
print(f"[llm_tp] Output: {text!r}")
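Applying the per-layer plan above is one call to the public `parallelize_module` API; a sketch (the mesh is assumed to come from `build_device_mesh` as in the diff):

```python
from torch.distributed.tensor.parallel import parallelize_module

def shard_model(model, mesh, plan: dict):
    # Colwise q/k/v + gate/up and rowwise o_proj/down_proj keep each block
    # to a single all-reduce per attention/MLP sublayer.
    return parallelize_module(model, mesh, plan)
```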
--- /home/runner/work/TensorRT/TensorRT/tools/hf/strategies/llm.py 2026-04-29 16:28:51.985193+00:00
+++ /home/runner/work/TensorRT/TensorRT/tools/hf/strategies/llm.py 2026-04-29 16:29:15.999905+00:00
@@ -9,10 +9,11 @@
Generation note: a TRT-compiled GraphModule does NOT have HuggingFace's
.generate() method. We implement a small greedy decoder that calls the
compiled module step-by-step on (input_ids, position_ids).
"""
+
from __future__ import annotations
import sys
import timeit
from pathlib import Path
@@ -81,10 +82,11 @@
)
# Optional: register SDPA converter from tools/llm/torchtrt_ext
try:
from torchtrt_ext import register_sdpa
+
register_sdpa.enable_sdpa_converter(self.cfg.model, self._model.config)
except Exception:
pass
self._tokenizer = AutoTokenizer.from_pretrained(self.cfg.model)
@@ -106,11 +108,13 @@
self._compile_export()
def _compile_torch_compile(self) -> None:
print("[llm] torch.compile(backend='torch_tensorrt') ...")
options = compile_kwargs(self.cfg.precision, self.cfg.autocast)
- options.update({"min_block_size": self.cfg.min_block_size, "debug": self.cfg.debug})
+ options.update(
+ {"min_block_size": self.cfg.min_block_size, "debug": self.cfg.debug}
+ )
self._trt_model = torch.compile(
self._model,
backend="torch_tensorrt",
options=options,
@@ -127,20 +131,24 @@
# pass that rewrites the exported graph to take KV cache tensors as
# extra inputs and return updated caches as extra outputs. This must
# happen BEFORE compile() so the pass sees the graph.
if self.cfg.cache == "static_v1":
import static_cache_v1 # noqa: F401 (registers lowering pass)
+
print("[llm] Static KV cache v1 lowering pass registered.")
elif self.cfg.cache == "static_v2":
import static_cache_v2 # noqa: F401
+
print("[llm] Static KV cache v2 lowering pass registered.")
print(f"[llm] Exporting (max_seq_len={max_seq_len}) ...")
ep = _export_llm(self._model, self._input_ids, max_seq_len)
maybe_save_exported_program(self.cfg, ep, log_prefix="[llm]")
- position_ids = torch.arange(self._input_ids.shape[1]).unsqueeze(0).to(self._device)
+ position_ids = (
+ torch.arange(self._input_ids.shape[1]).unsqueeze(0).to(self._device)
+ )
print("[llm] torch_tensorrt.dynamo.compile ...")
self._trt_model = compile_with_trt(
ep,
inputs=[self._input_ids, position_ids],
@@ -155,23 +163,30 @@
print("[llm] TRT compile done.")
# The EP was exported with input_ids as a positional arg and
# position_ids as a keyword arg, so the serialize path needs the
# arg/kwarg split explicitly.
- arg_inputs = [torch_tensorrt.Input(shape=self._input_ids.shape, dtype=torch.int64)]
+ arg_inputs = [
+ torch_tensorrt.Input(shape=self._input_ids.shape, dtype=torch.int64)
+ ]
kwarg_inputs = {
- "position_ids": torch_tensorrt.Input(shape=position_ids.shape, dtype=torch.int64),
+ "position_ids": torch_tensorrt.Input(
+ shape=position_ids.shape, dtype=torch.int64
+ ),
}
maybe_save_trt_module(
- self.cfg, self._trt_model,
+ self.cfg,
+ self._trt_model,
arg_inputs=arg_inputs,
kwarg_inputs=kwarg_inputs,
log_prefix="[llm]",
)
maybe_save_trt_engine(
- self.cfg, ep, arg_inputs,
+ self.cfg,
+ ep,
+ arg_inputs,
kwarg_inputs=kwarg_inputs,
log_prefix="[llm]",
)
# ---------------------------------------------------------------------- #
@@ -186,11 +201,13 @@
return self._benchmark_generation()
return self._benchmark_prefill()
def _benchmark_prefill(self) -> list[dict]:
assert self._input_ids is not None, "Call load() first"
- position_ids = torch.arange(self._input_ids.shape[1]).unsqueeze(0).to(self._device)
+ position_ids = (
+ torch.arange(self._input_ids.shape[1]).unsqueeze(0).to(self._device)
+ )
rows: list[dict] = []
tokens_per_iter = self.cfg.batch_size * self.cfg.isl
def _run_pt():
with torch.no_grad():
@@ -205,10 +222,11 @@
precision=self.cfg.precision,
)
)
if self._trt_model is not None:
+
def _run_trt():
with torch.no_grad():
return self._trt_model(self._input_ids, position_ids)
trt_t = warmup_and_time(_run_trt, (), iterations=self.cfg.iterations)
@@ -286,26 +304,32 @@
return rows
# ---------------------------------------------------------------------- #
def _run_pt(self):
- position_ids = torch.arange(self._input_ids.shape[1]).unsqueeze(0).to(self._device)
+ position_ids = (
+ torch.arange(self._input_ids.shape[1]).unsqueeze(0).to(self._device)
+ )
with torch.no_grad():
return self._model(self._input_ids, position_ids=position_ids)
def _run_trt(self):
- position_ids = torch.arange(self._input_ids.shape[1]).unsqueeze(0).to(self._device)
+ position_ids = (
+ torch.arange(self._input_ids.shape[1]).unsqueeze(0).to(self._device)
+ )
with torch.no_grad():
return self._trt_model(self._input_ids, position_ids)
def generate(self) -> None:
"""
Greedy decode against the original PyTorch model. TRT-compiled
GraphModules don't have HF's .generate(); a proper TRT generation
loop requires KV cache support — see tools/llm/utils.py for that.
"""
- assert self._tokenizer is not None and self._model is not None, "Call load() first"
+ assert (
+ self._tokenizer is not None and self._model is not None
+ ), "Call load() first"
inputs = self._tokenizer(self.cfg.prompt, return_tensors="pt").to(self._device)
input_ids = inputs["input_ids"]
print(f"[llm] Prompt: {self.cfg.prompt!r}")
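The "small greedy decoder" the header mentions reduces to a loop that re-runs the module on the growing prefix; a hedged sketch (no KV cache, so decode cost grows quadratically with length):

```python
import torch

def greedy_generate(model, input_ids: torch.Tensor, max_len: int, eos_id: int):
    while input_ids.shape[1] < max_len:
        position_ids = torch.arange(
            input_ids.shape[1], device=input_ids.device
        ).unsqueeze(0)
        out = model(input_ids, position_ids)
        logits = out.logits if hasattr(out, "logits") else out  # ModelOutput or raw
        next_tok = logits[:, -1, :].argmax(dim=-1, keepdim=True)
        input_ids = torch.cat([input_ids, next_tok], dim=1)
        if (next_tok == eos_id).all():
            break
    return input_ids
```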
--- /home/runner/work/TensorRT/TensorRT/tools/hf/strategies/multimodal.py 2026-04-29 16:28:51.985193+00:00
+++ /home/runner/work/TensorRT/TensorRT/tools/hf/strategies/multimodal.py 2026-04-29 16:29:16.075216+00:00
@@ -10,10 +10,11 @@
CLIPVisionModel : vision encoder only (also handled here for convenience)
CLIPTextModel : text encoder only
Benchmarking metric: latency + throughput of the full forward.
"""
+
from __future__ import annotations
from typing import Optional
import torch
@@ -31,11 +32,13 @@
)
from common.metrics import print_table, report_latency
from strategies.base import ModelStrategy, RunConfig
-def _build_clip_kwargs(model_cfg, batch_size: int, dtype: torch.dtype, device: str) -> dict:
+def _build_clip_kwargs(
+ model_cfg, batch_size: int, dtype: torch.dtype, device: str
+) -> dict:
"""
Build {input_ids, attention_mask, pixel_values} for a CLIP-style model.
Returned as kwargs to avoid positional-argument-order issues across model variants.
"""
text_cfg = getattr(model_cfg, "text_config", model_cfg)
@@ -49,13 +52,19 @@
else:
h = w = int(image_size)
num_channels = getattr(vision_cfg, "num_channels", 3)
return {
- "input_ids": torch.randint(0, vocab, (batch_size, seq_len), dtype=torch.int64).to(device),
- "attention_mask": torch.ones((batch_size, seq_len), dtype=torch.int64).to(device),
- "pixel_values": torch.randn(batch_size, num_channels, h, w, dtype=dtype).to(device),
+ "input_ids": torch.randint(
+ 0, vocab, (batch_size, seq_len), dtype=torch.int64
+ ).to(device),
+ "attention_mask": torch.ones((batch_size, seq_len), dtype=torch.int64).to(
+ device
+ ),
+ "pixel_values": torch.randn(batch_size, num_channels, h, w, dtype=dtype).to(
+ device
+ ),
}
class MultimodalStrategy(ModelStrategy):
def __init__(self, cfg: RunConfig):
@@ -103,13 +112,17 @@
self._compile_export()
def _compile_torch_compile(self) -> None:
print("[multimodal] torch.compile(backend='torch_tensorrt') ...")
options = compile_kwargs(self.cfg.precision, self.cfg.autocast)
- options.update({"min_block_size": self.cfg.min_block_size, "debug": self.cfg.debug})
-
- self._trt_model = torch.compile(self._model, backend="torch_tensorrt", options=options)
+ options.update(
+ {"min_block_size": self.cfg.min_block_size, "debug": self.cfg.debug}
+ )
+
+ self._trt_model = torch.compile(
+ self._model, backend="torch_tensorrt", options=options
+ )
with torch.no_grad():
_ = self._trt_model(**self._dummy_kwargs)
torch.cuda.synchronize()
def _compile_export(self) -> None:
@@ -140,17 +153,20 @@
)
torch.cuda.synchronize()
# Serializer is strict about arg/kwarg split; CLIP's EP has only kwargs.
maybe_save_trt_module(
- self.cfg, self._trt_model,
+ self.cfg,
+ self._trt_model,
arg_inputs=[],
kwarg_inputs=trt_kwarg_inputs,
log_prefix="[multimodal]",
)
maybe_save_trt_engine(
- self.cfg, ep, [],
+ self.cfg,
+ ep,
+ [],
kwarg_inputs=trt_kwarg_inputs,
log_prefix="[multimodal]",
)
# ---------------------------------------------------------------------- #
@@ -175,10 +191,11 @@
if self._trt_model is not None:
# The exported TRT module takes positional inputs in the order
# they were passed at export time.
trt_args = tuple(self._dummy_kwargs.values())
+
def _run_trt():
with torch.no_grad():
return self._trt_model(*trt_args)
trt_t = warmup_and_time(_run_trt, (), iterations=self.cfg.iterations)
--- /home/runner/work/TensorRT/TensorRT/tools/hf/strategies/seq2seq.py 2026-04-29 16:28:51.985193+00:00
+++ /home/runner/work/TensorRT/TensorRT/tools/hf/strategies/seq2seq.py 2026-04-29 16:29:16.166075+00:00
@@ -9,10 +9,11 @@
.generate() automatically uses the TRT encoder + PyTorch decoder.
Benchmarking metric: encoder latency on a fixed-length input + qualitative
.generate() output via the strategy's generate() method.
"""
+
from __future__ import annotations
from typing import Optional
import torch
@@ -61,54 +62,64 @@
self._tokenizer = AutoTokenizer.from_pretrained(self.cfg.model)
# Build encoder dummy inputs. T5 encoder takes (input_ids, attention_mask).
seq_len = self.cfg.isl
vocab = self._model.config.vocab_size
- input_ids = torch.randint(0, vocab, (self.cfg.batch_size, seq_len), dtype=torch.int64).to(self._device)
+ input_ids = torch.randint(
+ 0, vocab, (self.cfg.batch_size, seq_len), dtype=torch.int64
+ ).to(self._device)
attention_mask = torch.ones_like(input_ids)
self._dummy_inputs = (input_ids, attention_mask)
print(f"[seq2seq] Model loaded. Encoder input shape: {input_ids.shape}")
# ---------------------------------------------------------------------- #
def compile(self) -> None:
assert self._model is not None, "Call load() before compile()"
# T5: model.encoder. BART/Pegasus/mBART: model.model.encoder.
- if hasattr(self._model, "encoder") and not isinstance(
- getattr(self._model, "encoder", None), type(None)
- ) and hasattr(self._model.encoder, "forward"):
+ if (
+ hasattr(self._model, "encoder")
+ and not isinstance(getattr(self._model, "encoder", None), type(None))
+ and hasattr(self._model.encoder, "forward")
+ ):
encoder_owner = self._model
elif hasattr(self._model, "model") and hasattr(self._model.model, "encoder"):
encoder_owner = self._model.model
else:
raise AttributeError(
f"Could not find encoder submodule on {type(self._model).__name__}"
)
encoder = encoder_owner.encoder
self._encoder_owner = encoder_owner # remember for swap-back
- self._pt_encoder = encoder # keep PT reference for accuracy
+ self._pt_encoder = encoder # keep PT reference for accuracy
if self.cfg.mode == "compile":
self._trt_encoder = self._compile_torch_compile(encoder)
else:
self._trt_encoder = self._compile_export(encoder)
# Preserve attributes the parent model accesses on the encoder.
for attr_name in ("config", "dtype", "main_input_name"):
- if hasattr(encoder, attr_name) and not hasattr(self._trt_encoder, attr_name):
+ if hasattr(encoder, attr_name) and not hasattr(
+ self._trt_encoder, attr_name
+ ):
try:
setattr(self._trt_encoder, attr_name, getattr(encoder, attr_name))
except Exception:
pass
self._encoder_owner.encoder = self._trt_encoder
- print(f"[seq2seq] Compiled encoder swapped into {type(self._encoder_owner).__name__}.encoder")
+ print(
+ f"[seq2seq] Compiled encoder swapped into {type(self._encoder_owner).__name__}.encoder"
+ )
def _compile_torch_compile(self, encoder: torch.nn.Module) -> torch.nn.Module:
print("[seq2seq] torch.compile encoder (backend='torch_tensorrt') ...")
options = compile_kwargs(self.cfg.precision, self.cfg.autocast)
- options.update({"min_block_size": self.cfg.min_block_size, "debug": self.cfg.debug})
+ options.update(
+ {"min_block_size": self.cfg.min_block_size, "debug": self.cfg.debug}
+ )
compiled = torch.compile(encoder, backend="torch_tensorrt", options=options)
with torch.no_grad():
_ = compiled(*self._dummy_inputs)
torch.cuda.synchronize()
@@ -118,17 +129,22 @@
print("[seq2seq] torch.export.export encoder ...")
# Dynamic batch + seq_len.
ep = safe_export(
encoder,
args=self._dummy_inputs,
- dynamic_shapes=({0: torch.export.Dim.AUTO, 1: torch.export.Dim.AUTO},
- {0: torch.export.Dim.AUTO, 1: torch.export.Dim.AUTO}),
+ dynamic_shapes=(
+ {0: torch.export.Dim.AUTO, 1: torch.export.Dim.AUTO},
+ {0: torch.export.Dim.AUTO, 1: torch.export.Dim.AUTO},
+ ),
)
maybe_save_exported_program(self.cfg, ep, log_prefix="[seq2seq]")
print("[seq2seq] torch_tensorrt.dynamo.compile encoder ...")
- trt_inputs = [torch_tensorrt.Input(shape=t.shape, dtype=t.dtype) for t in self._dummy_inputs]
+ trt_inputs = [
+ torch_tensorrt.Input(shape=t.shape, dtype=t.dtype)
+ for t in self._dummy_inputs
+ ]
compiled = compile_with_trt(
ep,
inputs=trt_inputs,
precision=self.cfg.precision,
autocast=self.cfg.autocast,
@@ -138,11 +154,12 @@
engine_cache_dir=self.cfg.engine_cache_dir,
)
torch.cuda.synchronize()
maybe_save_trt_module(
- self.cfg, compiled,
+ self.cfg,
+ compiled,
arg_inputs=trt_inputs,
log_prefix="[seq2seq]",
)
maybe_save_trt_engine(self.cfg, ep, trt_inputs, log_prefix="[seq2seq]")
@@ -151,12 +168,20 @@
# ---------------------------------------------------------------------- #
def benchmark(self) -> list[dict]:
assert self._dummy_inputs is not None, "Call load() first"
rows: list[dict] = []
- encoder = self._trt_encoder if self._trt_encoder is not None else self._encoder_owner.encoder
- backend = f"torch_tensorrt[{self.cfg.mode}]" if self._trt_encoder is not None else "pytorch"
+ encoder = (
+ self._trt_encoder
+ if self._trt_encoder is not None
+ else self._encoder_owner.encoder
+ )
+ backend = (
+ f"torch_tensorrt[{self.cfg.mode}]"
+ if self._trt_encoder is not None
+ else "pytorch"
+ )
def _run():
with torch.no_grad():
return encoder(*self._dummy_inputs)
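The swap-in dance the diff performs, shown in isolation: compile just the encoder, copy the attributes HF's `.generate()` reads off it, then reattach. A hedged sketch mirroring the diff's logic:

```python
def swap_encoder(model, compiled_encoder):
    # T5 exposes model.encoder; BART/Pegasus/mBART nest it under model.model.
    owner = model if getattr(model, "encoder", None) is not None else model.model
    pt_encoder = owner.encoder
    for attr in ("config", "dtype", "main_input_name"):
        if hasattr(pt_encoder, attr) and not hasattr(compiled_encoder, attr):
            setattr(compiled_encoder, attr, getattr(pt_encoder, attr))
    owner.encoder = compiled_encoder
    return pt_encoder  # keep the PT reference for accuracy checks / swap-back
```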
55e5086 to c5ef4cd
Description
A quickly generated tool that lets us run a number of model classes from Hugging Face out of the box, similar to the run LLM tool.
Fixes #4220
Type of change
Checklist: