feat: a generic run hf tool that works for a number of model classes #4223

Open
narendasan wants to merge 2 commits into main from narendasan/push-prvysnyvnylw

Conversation

@narendasan
Collaborator

Description

A quickly generated tool that lets us run a number of model classes from Hugging Face out of the box, similar to the existing run_llm tool.

Fixes #4220
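
Example invocations (hypothetical; flag names are inferred from the CLIArgs/RunConfig fields visible in the diff below and may differ from the final CLI):

  # hypothetical usage sketch; flags not verified against the actual argument parser
  python tools/hf/run_hf.py --model bert-base-uncased --benchmark
  python tools/hf/run_hf.py --model openai/whisper-small --benchmark
  # tensor-parallel LLM path, launch command taken from the llm_tp.py docstring
  torchtrtrun --nproc_per_node=2 run_hf.py --model meta-llama/Llama-3.2-1B-Instruct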

Type of change

Please delete options that are not relevant and/or add your own.

  • New feature (non-breaking change which adds functionality)
  • This change requires a documentation update

Checklist:

  • My code follows the style guidelines of this project (You can use the linters)
  • I have performed a self-review of my own code
  • I have commented my code, particularly in hard-to-understand areas and hacks
  • I have made corresponding changes to the documentation
  • I have added tests to verify my fix or my feature
  • New and existing unit tests pass locally with my changes
  • I have added the relevant labels to my PR so that the relevant reviewers are notified


@github-actions bot left a comment

There are some changes that do not conform to Python style guidelines:

--- /home/runner/work/TensorRT/TensorRT/tools/hf/common/bench.py	2026-04-29 16:28:51.984749+00:00
+++ /home/runner/work/TensorRT/TensorRT/tools/hf/common/bench.py	2026-04-29 16:29:15.476456+00:00
@@ -1,16 +1,16 @@
"""
Shared benchmarking harness for all HF model strategies.
"""
+
from __future__ import annotations

import timeit
from typing import Callable, Sequence

import numpy as np
import torch
-

WARMUP_ITERS = 5


def warmup_and_time(
--- /home/runner/work/TensorRT/TensorRT/tools/hf/common/metrics.py	2026-04-29 16:28:51.984749+00:00
+++ /home/runner/work/TensorRT/TensorRT/tools/hf/common/metrics.py	2026-04-29 16:29:15.561992+00:00
@@ -3,10 +3,11 @@
  - Encoder / classifier : samples/s (throughput) + latency
  - LLM                  : tokens/s
  - Diffusion            : images/s (one full denoising pass)
  - Audio                : real-time factor (audio_duration / inference_time)
"""
+
from __future__ import annotations

import json


--- /home/runner/work/TensorRT/TensorRT/tools/hf/common/dist.py	2026-04-29 16:28:51.984749+00:00
+++ /home/runner/work/TensorRT/TensorRT/tools/hf/common/dist.py	2026-04-29 16:29:15.564318+00:00
@@ -12,10 +12,11 @@
  MASTER_PORT  – rendezvous port

Single-process (no torchtrtrun) defaults: WORLD_SIZE=1 (or unset), so
all helpers here become no-ops.
"""
+
from __future__ import annotations

import datetime
import logging
import os
@@ -44,11 +45,13 @@

def is_master() -> bool:
    return rank() == 0


-def init_distributed(timeout_hours: int = 2) -> Optional["torch.distributed.ProcessGroup"]:
+def init_distributed(
+    timeout_hours: int = 2,
+) -> Optional["torch.distributed.ProcessGroup"]:
    """
    Initialize torch.distributed if WORLD_SIZE > 1.

    Long timeout (default 2 h) so TRT engine builds don't trigger the NCCL
    watchdog when one rank takes longer than another to build.
@@ -79,17 +82,16 @@

    # Wire NCCL into TRT (sets up the symbol resolution path so TRT's
    # collectives talk to the same libnccl.so torchtrtrun preloaded).
    try:
        from torch_tensorrt.distributed import setup_nccl_for_torch_tensorrt
+
        setup_nccl_for_torch_tensorrt()
    except Exception as e:
        logger.warning(f"setup_nccl_for_torch_tensorrt failed: {e}")

-    logger.info(
-        f"[rank {rank()}/{world_size()}] dist init OK on cuda:{local_rank()}"
-    )
+    logger.info(f"[rank {rank()}/{world_size()}] dist init OK on cuda:{local_rank()}")
    return dist.group.WORLD


def build_device_mesh(mesh_dim_names: tuple[str, ...] = ("tp",)):
    """
@@ -112,7 +114,8 @@
def barrier() -> None:
    """torch.distributed barrier — no-op when not distributed."""
    if not is_distributed():
        return
    import torch.distributed as dist
+
    if dist.is_initialized():
        dist.barrier()
--- /home/runner/work/TensorRT/TensorRT/tools/hf/common/accuracy.py	2026-04-29 16:28:51.984749+00:00
+++ /home/runner/work/TensorRT/TensorRT/tools/hf/common/accuracy.py	2026-04-29 16:29:15.582221+00:00
@@ -8,21 +8,22 @@
Default tolerances are tuned for FP16 (atol=1e-2, rtol=1e-2,
cos_sim_min=0.99).  Override via --accuracy-atol / --accuracy-rtol /
--accuracy-cos-sim-min when comparing tighter precisions or models
known to have larger accumulated error.
"""
+
from __future__ import annotations

from typing import Iterable

import torch
import torch.utils._pytree as pytree

-
# --------------------------------------------------------------------------- #
# Output flattening
# --------------------------------------------------------------------------- #
+

def _flatten_to_tensors(out) -> list[torch.Tensor]:
    """
    Flatten an arbitrary HF model output (ModelOutput dataclass, dict,
    tuple, list, or single tensor) into a list of leaf tensors using
@@ -33,10 +34,11 @@


# --------------------------------------------------------------------------- #
# Per-tensor metrics
# --------------------------------------------------------------------------- #
+

def _cosine_similarity(a: torch.Tensor, b: torch.Tensor) -> float:
    a = a.detach().to(torch.float32).flatten()
    b = b.detach().to(torch.float32).flatten()
    if a.numel() == 0 or b.numel() == 0:
@@ -87,10 +89,11 @@

# --------------------------------------------------------------------------- #
# Compare two outputs (each a tensor / dict / dataclass / tuple)
# --------------------------------------------------------------------------- #

+
def compare_outputs(
    pt_out,
    trt_out,
    *,
    atol: float = 1e-2,
@@ -99,22 +102,28 @@
) -> list[dict]:
    pt_leaves = _flatten_to_tensors(pt_out)
    trt_leaves = _flatten_to_tensors(trt_out)

    if len(pt_leaves) != len(trt_leaves):
-        return [{
-            "name": "<output-count-mismatch>",
-            "shape_pt": f"{len(pt_leaves)} tensors",
-            "shape_trt": f"{len(trt_leaves)} tensors",
-            "shape_match": False,
-            "cos_sim": float("nan"),
-            "max_abs": float("nan"),
-            "mean_abs": float("nan"),
-            "allclose": False,
-        }]
-
-    names = list(output_names) if output_names else [f"out[{i}]" for i in range(len(pt_leaves))]
+        return [
+            {
+                "name": "<output-count-mismatch>",
+                "shape_pt": f"{len(pt_leaves)} tensors",
+                "shape_trt": f"{len(trt_leaves)} tensors",
+                "shape_match": False,
+                "cos_sim": float("nan"),
+                "max_abs": float("nan"),
+                "mean_abs": float("nan"),
+                "allclose": False,
+            }
+        ]
+
+    names = (
+        list(output_names)
+        if output_names
+        else [f"out[{i}]" for i in range(len(pt_leaves))]
+    )
    if len(names) < len(pt_leaves):
        names += [f"out[{i}]" for i in range(len(names), len(pt_leaves))]

    rows: list[dict] = []
    for name, pt, trt in zip(names, pt_leaves, trt_leaves):
@@ -125,10 +134,11 @@


# --------------------------------------------------------------------------- #
# Reporting
# --------------------------------------------------------------------------- #
+

def overall_pass(
    rows: list[dict],
    *,
    cos_sim_min: float = 0.99,
@@ -169,11 +179,13 @@
        print(f"\n{'=' * 70}")
        print(f"  {title}")
        print(f"{'=' * 70}")

    cols = ("name", "shape_pt", "cos_sim", "max_abs", "mean_abs", "allclose")
-    widths = {c: max(len(c), max(len(_fmt(r.get(c, ""), c)) for r in rows)) for c in cols}
+    widths = {
+        c: max(len(c), max(len(_fmt(r.get(c, ""), c)) for r in rows)) for c in cols
+    }
    header = "  ".join(c.ljust(widths[c]) for c in cols)
    print(header)
    print("-" * len(header))
    for r in rows:
        print("  ".join(_fmt(r.get(c, ""), c).ljust(widths[c]) for c in cols))
--- /home/runner/work/TensorRT/TensorRT/tools/hf/common/compile.py	2026-04-29 16:28:51.984749+00:00
+++ /home/runner/work/TensorRT/TensorRT/tools/hf/common/compile.py	2026-04-29 16:29:15.613228+00:00
@@ -22,17 +22,17 @@
    )

`enabled_precisions` is deprecated when `use_explicit_typing=True` and is
NEVER set by these helpers.
"""
+
from __future__ import annotations

from typing import Iterable, Optional

import torch
import torch_tensorrt
-

# --------------------------------------------------------------------------- #
# Precision plumbing
# --------------------------------------------------------------------------- #

@@ -87,10 +87,11 @@


# --------------------------------------------------------------------------- #
# Export helper (with fallback for guard violations)
# --------------------------------------------------------------------------- #
+

def safe_export(
    module: torch.nn.Module,
    args: tuple = (),
    kwargs: Optional[dict] = None,
@@ -110,11 +111,13 @@
                kwargs=kwargs,
                dynamic_shapes=dynamic_shapes,
                strict=False,
            )
        except Exception as e:
-            print(f"[compile] torch.export.export failed ({e}); retrying with deferred guards.")
+            print(
+                f"[compile] torch.export.export failed ({e}); retrying with deferred guards."
+            )
            return torch.export._trace._export(
                module,
                args=args,
                kwargs=kwargs,
                dynamic_shapes=dynamic_shapes,
@@ -124,10 +127,11 @@


# --------------------------------------------------------------------------- #
# Compile wrapper
# --------------------------------------------------------------------------- #
+

def _build_trt_kwargs(
    precision: str,
    autocast: bool,
    min_block_size: int,
@@ -179,13 +183,19 @@
      - C++ runtime (use_python_runtime=False)
      - engine caching ON
      - offload_module_to_cpu OFF (opt-in for memory-constrained models)
    """
    kw = _build_trt_kwargs(
-        precision, autocast, min_block_size, debug,
-        offload_module_to_cpu, cache_built_engines, reuse_cached_engines,
-        engine_cache_dir, extra,
+        precision,
+        autocast,
+        min_block_size,
+        debug,
+        offload_module_to_cpu,
+        cache_built_engines,
+        reuse_cached_engines,
+        engine_cache_dir,
+        extra,
    )
    return torch_tensorrt.dynamo.compile(ep, inputs=list(inputs), **kw)


def maybe_save_exported_program(cfg, ep, *, log_prefix: str) -> None:
@@ -225,11 +235,13 @@

    # torchscript needs real tensors for jit.trace; other formats accept
    # torch_tensorrt.Input specs.
    if fmt == "torchscript":
        use_args = example_arg_inputs if example_arg_inputs is not None else arg_inputs
-        use_kwargs = example_kwarg_inputs if example_kwarg_inputs is not None else kwarg_inputs
+        use_kwargs = (
+            example_kwarg_inputs if example_kwarg_inputs is not None else kwarg_inputs
+        )
    else:
        use_args = arg_inputs
        use_kwargs = kwarg_inputs

    save_kwargs = {"output_format": fmt}
@@ -308,11 +320,14 @@
    single TRT engine end-to-end (no PyTorch fallback partitions).  If
    your model has unsupported ops, this call will fail; use the
    `compile_with_trt` + `torch_tensorrt.save` path instead.
    """
    kw = _build_trt_kwargs(
-        precision, autocast, min_block_size, debug,
+        precision,
+        autocast,
+        min_block_size,
+        debug,
        offload_module_to_cpu,
        # Engine caching is irrelevant for serialization output.
        cache_built_engines=False,
        reuse_cached_engines=False,
        engine_cache_dir=engine_cache_dir,
--- /home/runner/work/TensorRT/TensorRT/tools/hf/run_hf.py	2026-04-29 16:28:51.984749+00:00
+++ /home/runner/work/TensorRT/TensorRT/tools/hf/run_hf.py	2026-04-29 16:29:15.643944+00:00
@@ -64,10 +64,11 @@
  Path B (--autocast) : Model stays FP32; TRT compiler casts via
                        enable_autocast=True, autocast_low_precision_type=<dtype>.

`enabled_precisions` is deprecated under use_explicit_typing and is never set.
"""
+
from __future__ import annotations

import sys
from pathlib import Path

@@ -92,37 +93,45 @@

def _build_strategy(family: str, cfg: RunConfig):
    if family == "llm":
        if is_distributed():
            from strategies.llm_tp import LLMTPStrategy
+
            return LLMTPStrategy(cfg)
        from strategies.llm import LLMStrategy
+
        return LLMStrategy(cfg)
    if family == "encoder":
        from strategies.encoder import EncoderStrategy
+
        return EncoderStrategy(cfg)
    if family == "seq2seq":
        from strategies.seq2seq import SeqToSeqStrategy
+
        return SeqToSeqStrategy(cfg)
    if family == "diffusion":
        from strategies.diffusion import DiffusionStrategy
+
        return DiffusionStrategy(cfg)
    if family == "audio":
        from strategies.audio import AudioStrategy
+
        return AudioStrategy(cfg)
    if family == "multimodal":
        from strategies.multimodal import MultimodalStrategy
+
        return MultimodalStrategy(cfg)
    raise ValueError(f"No strategy implemented for family '{family}'")


def main() -> None:
    args: CLIArgs = tyro.cli(CLIArgs)

    if not args.debug:
        import logging
        import torch_tensorrt
+
        torch_tensorrt.logging.set_level(logging.WARNING)

    # Initialize torch.distributed if launched under torchtrtrun
    # (WORLD_SIZE > 1).  No-op for single-process runs.
    init_distributed()
@@ -159,10 +168,11 @@
    if args.benchmark:
        rows = strategy.benchmark()

    if args.json_out and rows:
        from common.metrics import dump_json
+
        dump_json(rows, args.json_out)


if __name__ == "__main__":
    main()
--- /home/runner/work/TensorRT/TensorRT/tools/hf/detect.py	2026-04-29 16:28:51.984749+00:00
+++ /home/runner/work/TensorRT/TensorRT/tools/hf/detect.py	2026-04-29 16:29:15.678084+00:00
@@ -10,48 +10,92 @@
  "audio"     – speech / ASR (Whisper, wav2vec2, ...)

Returns the family string.  Raises ValueError if the family cannot be
determined and no --task override was given.
"""
+
from __future__ import annotations

from typing import Optional
-

# Maps HuggingFace model_type strings to strategy family names.
# Keep sorted within each group for readability.
_LLM_TYPES = {
-    "bloom", "codegen", "falcon", "gemma", "gemma2", "gemma3",
-    "gpt2", "gpt_bigcode", "gpt_neo", "gpt_neox", "gptj",
-    "llama", "mistral", "mixtral", "mpt", "opt",
-    "phi", "phi3", "qwen2", "qwen2_moe", "starcoder2",
+    "bloom",
+    "codegen",
+    "falcon",
+    "gemma",
+    "gemma2",
+    "gemma3",
+    "gpt2",
+    "gpt_bigcode",
+    "gpt_neo",
+    "gpt_neox",
+    "gptj",
+    "llama",
+    "mistral",
+    "mixtral",
+    "mpt",
+    "opt",
+    "phi",
+    "phi3",
+    "qwen2",
+    "qwen2_moe",
+    "starcoder2",
    "stablelm",
}

_ENCODER_TYPES = {
-    "albert", "bert", "camembert", "convnext", "deberta",
-    "deberta-v2", "distilbert", "electra", "mobilenet_v2",
-    "resnet", "roberta", "swin", "vit", "xlm", "xlm-roberta",
+    "albert",
+    "bert",
+    "camembert",
+    "convnext",
+    "deberta",
+    "deberta-v2",
+    "distilbert",
+    "electra",
+    "mobilenet_v2",
+    "resnet",
+    "roberta",
+    "swin",
+    "vit",
+    "xlm",
+    "xlm-roberta",
    "efficientnet",
}

_SEQ2SEQ_TYPES = {
-    "bart", "longt5", "mt5", "mbart", "pegasus", "t5",
+    "bart",
+    "longt5",
+    "mt5",
+    "mbart",
+    "pegasus",
+    "t5",
}

_DIFFUSION_TYPES = {
-    "flux", "stable-diffusion", "stable_diffusion",
-    "stable-diffusion-xl", "stable_diffusion_xl",
+    "flux",
+    "stable-diffusion",
+    "stable_diffusion",
+    "stable-diffusion-xl",
+    "stable_diffusion_xl",
    "unet-2d-condition",
}

_AUDIO_TYPES = {
-    "hubert", "wav2vec2", "wavlm", "whisper",
+    "hubert",
+    "wav2vec2",
+    "wavlm",
+    "whisper",
}

_MULTIMODAL_TYPES = {
-    "blip", "blip-2", "clip", "clipseg", "siglip",
+    "blip",
+    "blip-2",
+    "clip",
+    "clipseg",
+    "siglip",
}


def detect_family(model_id: str, task_override: Optional[str] = None) -> str:
    """
@@ -72,19 +116,21 @@
    # instead of config.json with a model_type field — fall back to that.
    model_type = ""
    last_error: Optional[Exception] = None
    try:
        from transformers import AutoConfig
+
        cfg = AutoConfig.from_pretrained(model_id, trust_remote_code=False)
        model_type = getattr(cfg, "model_type", "").lower()
    except Exception as e:
        last_error = e

    if not model_type:
        try:
            from huggingface_hub import hf_hub_download
            import json
+
            path = hf_hub_download(model_id, "model_index.json")
            with open(path) as f:
                idx = json.load(f)
            cls_name = idx.get("_class_name", "").lower()
            if "pipeline" in cls_name or any(
@@ -96,11 +142,14 @@

    if not model_type:
        msg = f"Could not determine family for '{model_id}'."
        # Detect gated / authentication errors and surface them clearly.
        err_str = str(last_error) if last_error else ""
-        if any(k in err_str for k in ("Repository Not Found", "401", "gated", "authenticated", "access")):
+        if any(
+            k in err_str
+            for k in ("Repository Not Found", "401", "gated", "authenticated", "access")
+        ):
            msg += (
                "\n\nThis model may be gated or private.  Run "
                "`hf auth login` (or `huggingface-cli login`) and accept the "
                "model's license on https://huggingface.co/" + model_id
            )
--- /home/runner/work/TensorRT/TensorRT/tools/hf/strategies/base.py	2026-04-29 16:28:51.985193+00:00
+++ /home/runner/work/TensorRT/TensorRT/tools/hf/strategies/base.py	2026-04-29 16:29:15.754327+00:00
@@ -5,10 +5,11 @@
  load()      – download & prepare the model (called once, on CPU or GPU)
  compile()   – export + TRT compile (or torch.compile fast path)
  benchmark() – warmup + timed loop + return list[dict] rows for metrics.py
  generate()  – optional: run a single sample-mode forward (text, image, audio)
"""
+
from __future__ import annotations

import abc
import dataclasses
from typing import Literal, Optional
@@ -166,11 +167,12 @@

        pt_out = self._run_pt()  # type: ignore[attr-defined]
        trt_out = self._run_trt()  # type: ignore[attr-defined]

        rows = compare_outputs(
-            pt_out, trt_out,
+            pt_out,
+            trt_out,
            atol=self.cfg.accuracy_atol,
            rtol=self.cfg.accuracy_rtol,
        )
        print_accuracy_table(
            rows,
--- /home/runner/work/TensorRT/TensorRT/tools/hf/strategies/audio.py	2026-04-29 16:28:51.985193+00:00
+++ /home/runner/work/TensorRT/TensorRT/tools/hf/strategies/audio.py	2026-04-29 16:29:15.793592+00:00
@@ -8,10 +8,11 @@
For Whisper the encoder is the bulk of the compute on long audio, so this
still delivers most of the available speedup.

Benchmarking metric: real-time factor (RTF = inference_time / audio_duration).
"""
+
from __future__ import annotations

from typing import Optional

import torch
@@ -89,11 +90,13 @@
        print("[audio] Compiled encoder swapped into self._model.model.encoder")

    def _compile_torch_compile(self, encoder: torch.nn.Module) -> torch.nn.Module:
        print("[audio] torch.compile encoder (backend='torch_tensorrt') ...")
        options = compile_kwargs(self.cfg.precision, self.cfg.autocast)
-        options.update({"min_block_size": self.cfg.min_block_size, "debug": self.cfg.debug})
+        options.update(
+            {"min_block_size": self.cfg.min_block_size, "debug": self.cfg.debug}
+        )

        compiled = torch.compile(encoder, backend="torch_tensorrt", options=options)
        dummy = torch.randn(*self._mel_shape, dtype=self._model_dtype).to(self._device)
        with torch.no_grad():
            _ = compiled(dummy)
@@ -102,11 +105,13 @@

    def _compile_export(self, encoder: torch.nn.Module) -> torch.nn.Module:
        dummy = torch.randn(*self._mel_shape, dtype=self._model_dtype).to(self._device)

        print("[audio] torch.export.export encoder ...")
-        ep = safe_export(encoder, args=(dummy,), dynamic_shapes=({0: torch.export.Dim.AUTO},))
+        ep = safe_export(
+            encoder, args=(dummy,), dynamic_shapes=({0: torch.export.Dim.AUTO},)
+        )
        maybe_save_exported_program(self.cfg, ep, log_prefix="[audio]")

        print("[audio] torch_tensorrt.dynamo.compile encoder ...")
        trt_inputs = [torch_tensorrt.Input(shape=dummy.shape, dtype=dummy.dtype)]
        compiled = compile_with_trt(
@@ -120,29 +125,40 @@
            engine_cache_dir=self.cfg.engine_cache_dir,
        )
        torch.cuda.synchronize()

        maybe_save_trt_module(
-            self.cfg, compiled,
+            self.cfg,
+            compiled,
            arg_inputs=trt_inputs,
            log_prefix="[audio]",
        )
        maybe_save_trt_engine(self.cfg, ep, trt_inputs, log_prefix="[audio]")

        return compiled

    # ---------------------------------------------------------------------- #

    def benchmark(self) -> list[dict]:
-        assert self._model is not None and self._mel_shape is not None, "Call load() first"
+        assert (
+            self._model is not None and self._mel_shape is not None
+        ), "Call load() first"
        rows: list[dict] = []

        dummy = torch.randn(*self._mel_shape, dtype=self._model_dtype).to(self._device)
        audio_duration_s = self.cfg.audio_duration_s

-        encoder_for_bench = self._trt_encoder if self._trt_encoder is not None else self._model.model.encoder
-        backend = f"torch_tensorrt[{self.cfg.mode}]" if self._trt_encoder is not None else "pytorch"
+        encoder_for_bench = (
+            self._trt_encoder
+            if self._trt_encoder is not None
+            else self._model.model.encoder
+        )
+        backend = (
+            f"torch_tensorrt[{self.cfg.mode}]"
+            if self._trt_encoder is not None
+            else "pytorch"
+        )

        def _run_encoder():
            with torch.no_grad():
                return encoder_for_bench(dummy)

@@ -180,11 +196,13 @@

    def _run_trt(self):
        # Reuse the same dummy from _run_pt so PT vs TRT compare on identical input.
        dummy = getattr(self, "_last_dummy", None)
        if dummy is None:
-            dummy = torch.randn(*self._mel_shape, dtype=self._model_dtype).to(self._device)
+            dummy = torch.randn(*self._mel_shape, dtype=self._model_dtype).to(
+                self._device
+            )
        with torch.no_grad():
            return self._trt_encoder(dummy)

    def generate(self) -> None:
        """Transcribe a silence clip end-to-end (TRT encoder + PyTorch decoder)."""
--- /home/runner/work/TensorRT/TensorRT/tools/hf/strategies/encoder.py	2026-04-29 16:28:51.985193+00:00
+++ /home/runner/work/TensorRT/TensorRT/tools/hf/strategies/encoder.py	2026-04-29 16:29:15.896539+00:00
@@ -2,10 +2,11 @@
Encoder strategy: BERT, RoBERTa, ViT, ResNet, EfficientNet, etc.

Whole-model export.  These are the easiest models to compile because they
have no KV cache, no generation loop, and minimal control flow.
"""
+
from __future__ import annotations

from typing import Optional

import torch
@@ -25,11 +26,16 @@
from common.metrics import print_table, report_latency
from strategies.base import ModelStrategy, RunConfig

# HF model_type values that use vision inputs (pixel_values) rather than input_ids.
_VISION_TYPES = {
-    "convnext", "efficientnet", "mobilenet_v2", "resnet", "swin", "vit",
+    "convnext",
+    "efficientnet",
+    "mobilenet_v2",
+    "resnet",
+    "swin",
+    "vit",
}


def _is_vision_model(cfg) -> bool:
    return getattr(cfg, "model_type", "").lower() in _VISION_TYPES
@@ -47,11 +53,13 @@
        nc = getattr(model_cfg, "num_channels", 3)
        return (torch.randn(batch_size, nc, h, w, dtype=dtype).to(device),)

    seq_len = 128
    vocab_size = getattr(model_cfg, "vocab_size", 30522)
-    input_ids = torch.randint(0, vocab_size, (batch_size, seq_len), dtype=torch.int64).to(device)
+    input_ids = torch.randint(
+        0, vocab_size, (batch_size, seq_len), dtype=torch.int64
+    ).to(device)
    attention_mask = torch.ones_like(input_ids)
    return (input_ids, attention_mask)


class EncoderStrategy(ModelStrategy):
@@ -88,11 +96,13 @@
            self._model_cfg,
            self.cfg.batch_size,
            self._model_dtype,
            self._device,
        )
-        print(f"[encoder] Model loaded.  Input shapes: {[t.shape for t in self._dummy_inputs]}")
+        print(
+            f"[encoder] Model loaded.  Input shapes: {[t.shape for t in self._dummy_inputs]}"
+        )

    # ---------------------------------------------------------------------- #

    def compile(self) -> None:
        assert self._model is not None, "Call load() before compile()"
@@ -103,11 +113,13 @@
            self._compile_export()

    def _compile_torch_compile(self) -> None:
        print("[encoder] torch.compile(backend='torch_tensorrt') ...")
        options = compile_kwargs(self.cfg.precision, self.cfg.autocast)
-        options.update({"min_block_size": self.cfg.min_block_size, "debug": self.cfg.debug})
+        options.update(
+            {"min_block_size": self.cfg.min_block_size, "debug": self.cfg.debug}
+        )

        self._trt_model = torch.compile(
            self._model,
            backend="torch_tensorrt",
            options=options,
@@ -131,13 +143,11 @@
        print("[encoder] torch.export.export ...")
        ep = safe_export(self._model, args=dummy, dynamic_shapes=dyn)
        maybe_save_exported_program(self.cfg, ep, log_prefix="[encoder]")

        print("[encoder] torch_tensorrt.dynamo.compile ...")
-        trt_inputs = [
-            torch_tensorrt.Input(shape=t.shape, dtype=t.dtype) for t in dummy
-        ]
+        trt_inputs = [torch_tensorrt.Input(shape=t.shape, dtype=t.dtype) for t in dummy]
        self._trt_model = compile_with_trt(
            ep,
            inputs=trt_inputs,
            precision=self.cfg.precision,
            autocast=self.cfg.autocast,
@@ -148,11 +158,12 @@
        )
        torch.cuda.synchronize()
        print("[encoder] TRT compile done.")

        maybe_save_trt_module(
-            self.cfg, self._trt_model,
+            self.cfg,
+            self._trt_model,
            arg_inputs=trt_inputs,
            example_arg_inputs=list(dummy),
            log_prefix="[encoder]",
        )
        maybe_save_trt_engine(self.cfg, ep, trt_inputs, log_prefix="[encoder]")
@@ -163,11 +174,13 @@
        assert self._dummy_inputs is not None, "Call load() first"
        rows: list[dict] = []

        with torch.no_grad():
            pt_t = warmup_and_time(
-                lambda *a: self._model(*a), self._dummy_inputs, iterations=self.cfg.iterations
+                lambda *a: self._model(*a),
+                self._dummy_inputs,
+                iterations=self.cfg.iterations,
            )
        rows.append(
            report_latency(
                compute_stats(pt_t, self.cfg.batch_size),
                backend="pytorch",
@@ -177,11 +190,13 @@
        )

        if self._trt_model is not None:
            with torch.no_grad():
                trt_t = warmup_and_time(
-                    lambda *a: self._trt_model(*a), self._dummy_inputs, iterations=self.cfg.iterations
+                    lambda *a: self._trt_model(*a),
+                    self._dummy_inputs,
+                    iterations=self.cfg.iterations,
                )
            rows.append(
                report_latency(
                    compute_stats(trt_t, self.cfg.batch_size),
                    backend=f"torch_tensorrt[{self.cfg.mode}]",
--- /home/runner/work/TensorRT/TensorRT/tools/hf/strategies/diffusion.py	2026-04-29 16:28:51.985193+00:00
+++ /home/runner/work/TensorRT/TensorRT/tools/hf/strategies/diffusion.py	2026-04-29 16:29:15.944213+00:00
@@ -12,10 +12,11 @@

Rather than hard-code dummy inputs per architecture, we run one short
inference pass and capture the actual backbone-call args via a forward
pre-hook.  Those captured tensors are then used as the export args.
"""
+
from __future__ import annotations

import timeit
from typing import Optional

@@ -40,10 +41,11 @@
    Insert `.<suffix>` before the extension of base_path.
      foo.trt → foo.<suffix>.trt
      foo     → foo.<suffix>
    """
    import os
+
    root, ext = os.path.splitext(base_path)
    return f"{root}.{suffix}{ext}" if ext else f"{root}.{suffix}"


def _get_backbone(pipe) -> tuple[torch.nn.Module, str]:
@@ -144,11 +146,13 @@
        if self.cfg.save_trt_engine and self.cfg.mode == "export":
            self._save_companion_engines()

    def _compile_torch_compile(self, backbone: torch.nn.Module) -> torch.nn.Module:
        options = compile_kwargs(self.cfg.precision, self.cfg.autocast)
-        options.update({"min_block_size": self.cfg.min_block_size, "debug": self.cfg.debug})
+        options.update(
+            {"min_block_size": self.cfg.min_block_size, "debug": self.cfg.debug}
+        )

        compiled = torch.compile(backbone, backend="torch_tensorrt", options=options)

        # Trigger compilation by running a real pipeline step.
        with torch.no_grad():
@@ -162,11 +166,13 @@
        args, kwargs = _capture_backbone_inputs(
            self._pipe,
            prompt="a photo of an astronaut",
            num_steps=1,
        )
-        print(f"[diffusion] Captured {len(args)} positional + {len(kwargs)} keyword args.")
+        print(
+            f"[diffusion] Captured {len(args)} positional + {len(kwargs)} keyword args."
+        )

        # Stash for accuracy comparison: keep the PT backbone reference and
        # the captured call args/kwargs so we can run both pre- and post-swap.
        self._pt_backbone = backbone
        self._captured_args = args
@@ -196,19 +202,19 @@
        # entry for each: torch_tensorrt.Input for tensors, the original
        # value for everything else.
        trt_kwarg_inputs = {
            k: (
                torch_tensorrt.Input(shape=v.shape, dtype=v.dtype)
-                if isinstance(v, torch.Tensor) else v
+                if isinstance(v, torch.Tensor)
+                else v
            )
            for k, v in kwargs.items()
        }
        # For the compile() call we still need a flat list, but only
        # tensor-typed entries; compile()'s `inputs=` arg is positional.
        flat_trt_inputs = trt_arg_inputs + [
-            v for v in trt_kwarg_inputs.values()
-            if isinstance(v, torch_tensorrt.Input)
+            v for v in trt_kwarg_inputs.values() if isinstance(v, torch_tensorrt.Input)
        ]

        compiled = compile_with_trt(
            ep,
            inputs=flat_trt_inputs,
@@ -221,12 +227,14 @@
        )
        torch.cuda.synchronize()

        if self.cfg.save_engine:
            from common.compile import maybe_save_trt_module
+
            maybe_save_trt_module(
-                self.cfg, compiled,
+                self.cfg,
+                compiled,
                arg_inputs=flat_trt_inputs,
                log_prefix="[diffusion]",
            )

        # Save the backbone engine to <base>.<attr>.trt; companion engines
@@ -246,11 +254,13 @@
                    kwarg_inputs=trt_kwarg_inputs or None,
                    log_prefix=f"[diffusion:{backbone_attr}]",
                )
            except Exception as e:
                print(f"[diffusion:{backbone_attr}] FAILED: {e}")
-                print(f"[diffusion:{backbone_attr}] (continuing with companion engines)")
+                print(
+                    f"[diffusion:{backbone_attr}] (continuing with companion engines)"
+                )

        return compiled

    # ---------------------------------------------------------------------- #
    # Multi-engine save: text_encoder + backbone + vae_decoder
@@ -313,13 +323,16 @@
            0, vocab, (self.cfg.batch_size, max_len), dtype=torch.int64
        ).to(self._device)

        print("[diffusion:text_encoder] torch.export.export ...")
        ep = safe_export(text_encoder, args=(input_ids,))
-        trt_inputs = [torch_tensorrt.Input(shape=input_ids.shape, dtype=input_ids.dtype)]
+        trt_inputs = [
+            torch_tensorrt.Input(shape=input_ids.shape, dtype=input_ids.dtype)
+        ]
        self._serialize_to_path(
-            ep, path,
+            ep,
+            path,
            arg_inputs=trt_inputs,
            log_prefix="[diffusion:text_encoder]",
        )

    def _serialize_vae_decoder(self, path: str) -> None:
@@ -341,19 +354,23 @@
        spatial_factor = 2 ** (len(block_out_channels) - 1)
        h_lat = self.cfg.image_size // spatial_factor
        w_lat = self.cfg.image_size // spatial_factor

        latents = torch.randn(
-            self.cfg.batch_size, latent_channels, h_lat, w_lat,
+            self.cfg.batch_size,
+            latent_channels,
+            h_lat,
+            w_lat,
            dtype=self._model_dtype,
        ).to(self._device)

        print("[diffusion:vae_decoder] torch.export.export ...")
        ep = safe_export(decoder, args=(latents,))
        trt_inputs = [torch_tensorrt.Input(shape=latents.shape, dtype=latents.dtype)]
        self._serialize_to_path(
-            ep, path,
+            ep,
+            path,
            arg_inputs=trt_inputs,
            log_prefix="[diffusion:vae_decoder]",
        )

    # ---------------------------------------------------------------------- #
--- /home/runner/work/TensorRT/TensorRT/tools/hf/strategies/llm_tp.py	2026-04-29 16:28:51.985193+00:00
+++ /home/runner/work/TensorRT/TensorRT/tools/hf/strategies/llm_tp.py	2026-04-29 16:29:15.970101+00:00
@@ -20,10 +20,11 @@

Launch
------
  torchtrtrun --nproc_per_node=2 run_hf.py --model meta-llama/Llama-3.2-1B-Instruct
"""
+
from __future__ import annotations

import logging
import sys
from contextlib import nullcontext
@@ -66,19 +67,21 @@
    )

    plan: dict[str, object] = {}
    n_layers = model.config.num_hidden_layers
    for i in range(n_layers):
-        plan.update({
-            f"model.layers.{i}.self_attn.q_proj": ColwiseParallel(),
-            f"model.layers.{i}.self_attn.k_proj": ColwiseParallel(),
-            f"model.layers.{i}.self_attn.v_proj": ColwiseParallel(),
-            f"model.layers.{i}.self_attn.o_proj": RowwiseParallel(),
-            f"model.layers.{i}.mlp.gate_proj":   ColwiseParallel(),
-            f"model.layers.{i}.mlp.up_proj":     ColwiseParallel(),
-            f"model.layers.{i}.mlp.down_proj":   RowwiseParallel(),
-        })
+        plan.update(
+            {
+                f"model.layers.{i}.self_attn.q_proj": ColwiseParallel(),
+                f"model.layers.{i}.self_attn.k_proj": ColwiseParallel(),
+                f"model.layers.{i}.self_attn.v_proj": ColwiseParallel(),
+                f"model.layers.{i}.self_attn.o_proj": RowwiseParallel(),
+                f"model.layers.{i}.mlp.gate_proj": ColwiseParallel(),
+                f"model.layers.{i}.mlp.up_proj": ColwiseParallel(),
+                f"model.layers.{i}.mlp.down_proj": RowwiseParallel(),
+            }
+        )
    return plan


def _patch_attention_head_counts(model, ws: int) -> None:
    """
@@ -119,11 +122,13 @@
        # Build a 1-D TP device mesh and shard the model.
        self._mesh = build_device_mesh(("tp",))
        ws = world_size()
        from torch.distributed.tensor.parallel import parallelize_module

-        if not hasattr(self._model, "model") or not hasattr(self._model.model, "layers"):
+        if not hasattr(self._model, "model") or not hasattr(
+            self._model.model, "layers"
+        ):
            raise NotImplementedError(
                f"TP plan only supports Llama-family layouts "
                f"(model.model.layers.{{i}}.self_attn / mlp).  Got: "
                f"{type(self._model).__name__}.  GPT-2 / OPT / NeoX support TBD."
            )
@@ -158,18 +163,20 @@
        """
        if is_master():
            print("[llm_tp] torch.compile(backend='torch_tensorrt', dynamic=True) ...")

        opts = compile_kwargs(self.cfg.precision, self.cfg.autocast)
-        opts.update({
-            "min_block_size": self.cfg.min_block_size,
-            "debug": self.cfg.debug,
-            "device": self._device,
-            "disable_tf32": True,
-            "use_python_runtime": False,
-            "assume_dynamic_shape_support": True,
-        })
+        opts.update(
+            {
+                "min_block_size": self.cfg.min_block_size,
+                "debug": self.cfg.debug,
+                "device": self._device,
+                "disable_tf32": True,
+                "use_python_runtime": False,
+                "assume_dynamic_shape_support": True,
+            }
+        )

        with torch_tensorrt.logging.debug() if self.cfg.debug else nullcontext():
            self._trt_model = torch.compile(
                self._model,
                backend="torch_tensorrt",
@@ -264,11 +271,15 @@

    def generate(self) -> None:
        """One greedy decode + print on rank 0."""
        from utils import generate

-        ids = self._tokenizer(self.cfg.prompt, return_tensors="pt")["input_ids"].to(self._device)
+        ids = self._tokenizer(self.cfg.prompt, return_tensors="pt")["input_ids"].to(
+            self._device
+        )
        max_out = ids.shape[1] + self.cfg.num_tokens
-        out = generate(self._trt_model, ids.clone(), max_out, self._tokenizer.eos_token_id)
+        out = generate(
+            self._trt_model, ids.clone(), max_out, self._tokenizer.eos_token_id
+        )
        if is_master():
            text = self._tokenizer.decode(out[0], skip_special_tokens=True)
            print(f"[llm_tp] Output: {text!r}")
--- /home/runner/work/TensorRT/TensorRT/tools/hf/strategies/llm.py	2026-04-29 16:28:51.985193+00:00
+++ /home/runner/work/TensorRT/TensorRT/tools/hf/strategies/llm.py	2026-04-29 16:29:15.999905+00:00
@@ -9,10 +9,11 @@

Generation note: a TRT-compiled GraphModule does NOT have HuggingFace's
.generate() method.  We implement a small greedy decoder that calls the
compiled module step-by-step on (input_ids, position_ids).
"""
+
from __future__ import annotations

import sys
import timeit
from pathlib import Path
@@ -81,10 +82,11 @@
        )

        # Optional: register SDPA converter from tools/llm/torchtrt_ext
        try:
            from torchtrt_ext import register_sdpa
+
            register_sdpa.enable_sdpa_converter(self.cfg.model, self._model.config)
        except Exception:
            pass

        self._tokenizer = AutoTokenizer.from_pretrained(self.cfg.model)
@@ -106,11 +108,13 @@
            self._compile_export()

    def _compile_torch_compile(self) -> None:
        print("[llm] torch.compile(backend='torch_tensorrt') ...")
        options = compile_kwargs(self.cfg.precision, self.cfg.autocast)
-        options.update({"min_block_size": self.cfg.min_block_size, "debug": self.cfg.debug})
+        options.update(
+            {"min_block_size": self.cfg.min_block_size, "debug": self.cfg.debug}
+        )

        self._trt_model = torch.compile(
            self._model,
            backend="torch_tensorrt",
            options=options,
@@ -127,20 +131,24 @@
        # pass that rewrites the exported graph to take KV cache tensors as
        # extra inputs and return updated caches as extra outputs.  This must
        # happen BEFORE compile() so the pass sees the graph.
        if self.cfg.cache == "static_v1":
            import static_cache_v1  # noqa: F401  (registers lowering pass)
+
            print("[llm] Static KV cache v1 lowering pass registered.")
        elif self.cfg.cache == "static_v2":
            import static_cache_v2  # noqa: F401
+
            print("[llm] Static KV cache v2 lowering pass registered.")

        print(f"[llm] Exporting (max_seq_len={max_seq_len}) ...")
        ep = _export_llm(self._model, self._input_ids, max_seq_len)
        maybe_save_exported_program(self.cfg, ep, log_prefix="[llm]")

-        position_ids = torch.arange(self._input_ids.shape[1]).unsqueeze(0).to(self._device)
+        position_ids = (
+            torch.arange(self._input_ids.shape[1]).unsqueeze(0).to(self._device)
+        )

        print("[llm] torch_tensorrt.dynamo.compile ...")
        self._trt_model = compile_with_trt(
            ep,
            inputs=[self._input_ids, position_ids],
@@ -155,23 +163,30 @@
        print("[llm] TRT compile done.")

        # The EP was exported with input_ids as a positional arg and
        # position_ids as a keyword arg, so the serialize path needs the
        # arg/kwarg split explicitly.
-        arg_inputs = [torch_tensorrt.Input(shape=self._input_ids.shape, dtype=torch.int64)]
+        arg_inputs = [
+            torch_tensorrt.Input(shape=self._input_ids.shape, dtype=torch.int64)
+        ]
        kwarg_inputs = {
-            "position_ids": torch_tensorrt.Input(shape=position_ids.shape, dtype=torch.int64),
+            "position_ids": torch_tensorrt.Input(
+                shape=position_ids.shape, dtype=torch.int64
+            ),
        }

        maybe_save_trt_module(
-            self.cfg, self._trt_model,
+            self.cfg,
+            self._trt_model,
            arg_inputs=arg_inputs,
            kwarg_inputs=kwarg_inputs,
            log_prefix="[llm]",
        )
        maybe_save_trt_engine(
-            self.cfg, ep, arg_inputs,
+            self.cfg,
+            ep,
+            arg_inputs,
            kwarg_inputs=kwarg_inputs,
            log_prefix="[llm]",
        )

    # ---------------------------------------------------------------------- #
@@ -186,11 +201,13 @@
            return self._benchmark_generation()
        return self._benchmark_prefill()

    def _benchmark_prefill(self) -> list[dict]:
        assert self._input_ids is not None, "Call load() first"
-        position_ids = torch.arange(self._input_ids.shape[1]).unsqueeze(0).to(self._device)
+        position_ids = (
+            torch.arange(self._input_ids.shape[1]).unsqueeze(0).to(self._device)
+        )
        rows: list[dict] = []
        tokens_per_iter = self.cfg.batch_size * self.cfg.isl

        def _run_pt():
            with torch.no_grad():
@@ -205,10 +222,11 @@
                precision=self.cfg.precision,
            )
        )

        if self._trt_model is not None:
+
            def _run_trt():
                with torch.no_grad():
                    return self._trt_model(self._input_ids, position_ids)

            trt_t = warmup_and_time(_run_trt, (), iterations=self.cfg.iterations)
@@ -286,26 +304,32 @@
        return rows

    # ---------------------------------------------------------------------- #

    def _run_pt(self):
-        position_ids = torch.arange(self._input_ids.shape[1]).unsqueeze(0).to(self._device)
+        position_ids = (
+            torch.arange(self._input_ids.shape[1]).unsqueeze(0).to(self._device)
+        )
        with torch.no_grad():
            return self._model(self._input_ids, position_ids=position_ids)

    def _run_trt(self):
-        position_ids = torch.arange(self._input_ids.shape[1]).unsqueeze(0).to(self._device)
+        position_ids = (
+            torch.arange(self._input_ids.shape[1]).unsqueeze(0).to(self._device)
+        )
        with torch.no_grad():
            return self._trt_model(self._input_ids, position_ids)

    def generate(self) -> None:
        """
        Greedy decode against the original PyTorch model.  TRT-compiled
        GraphModules don't have HF's .generate(); a proper TRT generation
        loop requires KV cache support — see tools/llm/utils.py for that.
        """
-        assert self._tokenizer is not None and self._model is not None, "Call load() first"
+        assert (
+            self._tokenizer is not None and self._model is not None
+        ), "Call load() first"

        inputs = self._tokenizer(self.cfg.prompt, return_tensors="pt").to(self._device)
        input_ids = inputs["input_ids"]

        print(f"[llm] Prompt: {self.cfg.prompt!r}")
--- /home/runner/work/TensorRT/TensorRT/tools/hf/strategies/multimodal.py	2026-04-29 16:28:51.985193+00:00
+++ /home/runner/work/TensorRT/TensorRT/tools/hf/strategies/multimodal.py	2026-04-29 16:29:16.075216+00:00
@@ -10,10 +10,11 @@
  CLIPVisionModel  : vision encoder only (also handled here for convenience)
  CLIPTextModel    : text encoder only

Benchmarking metric: latency + throughput of the full forward.
"""
+
from __future__ import annotations

from typing import Optional

import torch
@@ -31,11 +32,13 @@
)
from common.metrics import print_table, report_latency
from strategies.base import ModelStrategy, RunConfig


-def _build_clip_kwargs(model_cfg, batch_size: int, dtype: torch.dtype, device: str) -> dict:
+def _build_clip_kwargs(
+    model_cfg, batch_size: int, dtype: torch.dtype, device: str
+) -> dict:
    """
    Build {input_ids, attention_mask, pixel_values} for a CLIP-style model.
    Returned as kwargs to avoid positional-argument-order issues across model variants.
    """
    text_cfg = getattr(model_cfg, "text_config", model_cfg)
@@ -49,13 +52,19 @@
    else:
        h = w = int(image_size)
    num_channels = getattr(vision_cfg, "num_channels", 3)

    return {
-        "input_ids": torch.randint(0, vocab, (batch_size, seq_len), dtype=torch.int64).to(device),
-        "attention_mask": torch.ones((batch_size, seq_len), dtype=torch.int64).to(device),
-        "pixel_values": torch.randn(batch_size, num_channels, h, w, dtype=dtype).to(device),
+        "input_ids": torch.randint(
+            0, vocab, (batch_size, seq_len), dtype=torch.int64
+        ).to(device),
+        "attention_mask": torch.ones((batch_size, seq_len), dtype=torch.int64).to(
+            device
+        ),
+        "pixel_values": torch.randn(batch_size, num_channels, h, w, dtype=dtype).to(
+            device
+        ),
    }


class MultimodalStrategy(ModelStrategy):
    def __init__(self, cfg: RunConfig):
@@ -103,13 +112,17 @@
            self._compile_export()

    def _compile_torch_compile(self) -> None:
        print("[multimodal] torch.compile(backend='torch_tensorrt') ...")
        options = compile_kwargs(self.cfg.precision, self.cfg.autocast)
-        options.update({"min_block_size": self.cfg.min_block_size, "debug": self.cfg.debug})
-
-        self._trt_model = torch.compile(self._model, backend="torch_tensorrt", options=options)
+        options.update(
+            {"min_block_size": self.cfg.min_block_size, "debug": self.cfg.debug}
+        )
+
+        self._trt_model = torch.compile(
+            self._model, backend="torch_tensorrt", options=options
+        )
        with torch.no_grad():
            _ = self._trt_model(**self._dummy_kwargs)
        torch.cuda.synchronize()

    def _compile_export(self) -> None:
@@ -140,17 +153,20 @@
        )
        torch.cuda.synchronize()

        # Serializer is strict about arg/kwarg split; CLIP's EP has only kwargs.
        maybe_save_trt_module(
-            self.cfg, self._trt_model,
+            self.cfg,
+            self._trt_model,
            arg_inputs=[],
            kwarg_inputs=trt_kwarg_inputs,
            log_prefix="[multimodal]",
        )
        maybe_save_trt_engine(
-            self.cfg, ep, [],
+            self.cfg,
+            ep,
+            [],
            kwarg_inputs=trt_kwarg_inputs,
            log_prefix="[multimodal]",
        )

    # ---------------------------------------------------------------------- #
@@ -175,10 +191,11 @@

        if self._trt_model is not None:
            # The exported TRT module takes positional inputs in the order
            # they were passed at export time.
            trt_args = tuple(self._dummy_kwargs.values())
+
            def _run_trt():
                with torch.no_grad():
                    return self._trt_model(*trt_args)

            trt_t = warmup_and_time(_run_trt, (), iterations=self.cfg.iterations)
--- /home/runner/work/TensorRT/TensorRT/tools/hf/strategies/seq2seq.py	2026-04-29 16:28:51.985193+00:00
+++ /home/runner/work/TensorRT/TensorRT/tools/hf/strategies/seq2seq.py	2026-04-29 16:29:16.166075+00:00
@@ -9,10 +9,11 @@
.generate() automatically uses the TRT encoder + PyTorch decoder.

Benchmarking metric: encoder latency on a fixed-length input + qualitative
.generate() output via the strategy's generate() method.
"""
+
from __future__ import annotations

from typing import Optional

import torch
@@ -61,54 +62,64 @@
        self._tokenizer = AutoTokenizer.from_pretrained(self.cfg.model)

        # Build encoder dummy inputs.  T5 encoder takes (input_ids, attention_mask).
        seq_len = self.cfg.isl
        vocab = self._model.config.vocab_size
-        input_ids = torch.randint(0, vocab, (self.cfg.batch_size, seq_len), dtype=torch.int64).to(self._device)
+        input_ids = torch.randint(
+            0, vocab, (self.cfg.batch_size, seq_len), dtype=torch.int64
+        ).to(self._device)
        attention_mask = torch.ones_like(input_ids)
        self._dummy_inputs = (input_ids, attention_mask)
        print(f"[seq2seq] Model loaded.  Encoder input shape: {input_ids.shape}")

    # ---------------------------------------------------------------------- #

    def compile(self) -> None:
        assert self._model is not None, "Call load() before compile()"
        # T5: model.encoder.  BART/Pegasus/mBART: model.model.encoder.
-        if hasattr(self._model, "encoder") and not isinstance(
-            getattr(self._model, "encoder", None), type(None)
-        ) and hasattr(self._model.encoder, "forward"):
+        if (
+            hasattr(self._model, "encoder")
+            and not isinstance(getattr(self._model, "encoder", None), type(None))
+            and hasattr(self._model.encoder, "forward")
+        ):
            encoder_owner = self._model
        elif hasattr(self._model, "model") and hasattr(self._model.model, "encoder"):
            encoder_owner = self._model.model
        else:
            raise AttributeError(
                f"Could not find encoder submodule on {type(self._model).__name__}"
            )
        encoder = encoder_owner.encoder
        self._encoder_owner = encoder_owner  # remember for swap-back
-        self._pt_encoder = encoder            # keep PT reference for accuracy
+        self._pt_encoder = encoder  # keep PT reference for accuracy

        if self.cfg.mode == "compile":
            self._trt_encoder = self._compile_torch_compile(encoder)
        else:
            self._trt_encoder = self._compile_export(encoder)

        # Preserve attributes the parent model accesses on the encoder.
        for attr_name in ("config", "dtype", "main_input_name"):
-            if hasattr(encoder, attr_name) and not hasattr(self._trt_encoder, attr_name):
+            if hasattr(encoder, attr_name) and not hasattr(
+                self._trt_encoder, attr_name
+            ):
                try:
                    setattr(self._trt_encoder, attr_name, getattr(encoder, attr_name))
                except Exception:
                    pass

        self._encoder_owner.encoder = self._trt_encoder
-        print(f"[seq2seq] Compiled encoder swapped into {type(self._encoder_owner).__name__}.encoder")
+        print(
+            f"[seq2seq] Compiled encoder swapped into {type(self._encoder_owner).__name__}.encoder"
+        )

    def _compile_torch_compile(self, encoder: torch.nn.Module) -> torch.nn.Module:
        print("[seq2seq] torch.compile encoder (backend='torch_tensorrt') ...")
        options = compile_kwargs(self.cfg.precision, self.cfg.autocast)
-        options.update({"min_block_size": self.cfg.min_block_size, "debug": self.cfg.debug})
+        options.update(
+            {"min_block_size": self.cfg.min_block_size, "debug": self.cfg.debug}
+        )

        compiled = torch.compile(encoder, backend="torch_tensorrt", options=options)
        with torch.no_grad():
            _ = compiled(*self._dummy_inputs)
        torch.cuda.synchronize()
@@ -118,17 +129,22 @@
        print("[seq2seq] torch.export.export encoder ...")
        # Dynamic batch + seq_len.
        ep = safe_export(
            encoder,
            args=self._dummy_inputs,
-            dynamic_shapes=({0: torch.export.Dim.AUTO, 1: torch.export.Dim.AUTO},
-                            {0: torch.export.Dim.AUTO, 1: torch.export.Dim.AUTO}),
+            dynamic_shapes=(
+                {0: torch.export.Dim.AUTO, 1: torch.export.Dim.AUTO},
+                {0: torch.export.Dim.AUTO, 1: torch.export.Dim.AUTO},
+            ),
        )
        maybe_save_exported_program(self.cfg, ep, log_prefix="[seq2seq]")

        print("[seq2seq] torch_tensorrt.dynamo.compile encoder ...")
-        trt_inputs = [torch_tensorrt.Input(shape=t.shape, dtype=t.dtype) for t in self._dummy_inputs]
+        trt_inputs = [
+            torch_tensorrt.Input(shape=t.shape, dtype=t.dtype)
+            for t in self._dummy_inputs
+        ]
        compiled = compile_with_trt(
            ep,
            inputs=trt_inputs,
            precision=self.cfg.precision,
            autocast=self.cfg.autocast,
@@ -138,11 +154,12 @@
            engine_cache_dir=self.cfg.engine_cache_dir,
        )
        torch.cuda.synchronize()

        maybe_save_trt_module(
-            self.cfg, compiled,
+            self.cfg,
+            compiled,
            arg_inputs=trt_inputs,
            log_prefix="[seq2seq]",
        )
        maybe_save_trt_engine(self.cfg, ep, trt_inputs, log_prefix="[seq2seq]")

@@ -151,12 +168,20 @@
    # ---------------------------------------------------------------------- #

    def benchmark(self) -> list[dict]:
        assert self._dummy_inputs is not None, "Call load() first"
        rows: list[dict] = []
-        encoder = self._trt_encoder if self._trt_encoder is not None else self._encoder_owner.encoder
-        backend = f"torch_tensorrt[{self.cfg.mode}]" if self._trt_encoder is not None else "pytorch"
+        encoder = (
+            self._trt_encoder
+            if self._trt_encoder is not None
+            else self._encoder_owner.encoder
+        )
+        backend = (
+            f"torch_tensorrt[{self.cfg.mode}]"
+            if self._trt_encoder is not None
+            else "pytorch"
+        )

        def _run():
            with torch.no_grad():
                return encoder(*self._dummy_inputs)


narendasan force-pushed the narendasan/push-prvysnyvnylw branch from 55e5086 to c5ef4cd on April 29, 2026 19:06

@github-actions bot left a comment

There are some changes that do not conform to Python style guidelines:

--- /home/runner/work/TensorRT/TensorRT/tools/hf/common/bench.py	2026-04-29 19:07:00.389656+00:00
+++ /home/runner/work/TensorRT/TensorRT/tools/hf/common/bench.py	2026-04-29 19:07:24.465515+00:00
@@ -1,16 +1,16 @@
"""
Shared benchmarking harness for all HF model strategies.
"""
+
from __future__ import annotations

import timeit
from typing import Callable, Sequence

import numpy as np
import torch
-

WARMUP_ITERS = 5


def warmup_and_time(
--- /home/runner/work/TensorRT/TensorRT/tools/hf/common/dist.py	2026-04-29 19:07:00.389656+00:00
+++ /home/runner/work/TensorRT/TensorRT/tools/hf/common/dist.py	2026-04-29 19:07:24.511046+00:00
@@ -12,10 +12,11 @@
  MASTER_PORT  – rendezvous port

Single-process (no torchtrtrun) defaults: WORLD_SIZE=1 (or unset), so
all helpers here become no-ops.
"""
+
from __future__ import annotations

import datetime
import logging
import os
@@ -44,11 +45,13 @@

def is_master() -> bool:
    return rank() == 0


-def init_distributed(timeout_hours: int = 2) -> Optional["torch.distributed.ProcessGroup"]:
+def init_distributed(
+    timeout_hours: int = 2,
+) -> Optional["torch.distributed.ProcessGroup"]:
    """
    Initialize torch.distributed if WORLD_SIZE > 1.

    Long timeout (default 2 h) so TRT engine builds don't trigger the NCCL
    watchdog when one rank takes longer than another to build.
@@ -79,17 +82,16 @@

    # Wire NCCL into TRT (sets up the symbol resolution path so TRT's
    # collectives talk to the same libnccl.so torchtrtrun preloaded).
    try:
        from torch_tensorrt.distributed import setup_nccl_for_torch_tensorrt
+
        setup_nccl_for_torch_tensorrt()
    except Exception as e:
        logger.warning(f"setup_nccl_for_torch_tensorrt failed: {e}")

-    logger.info(
-        f"[rank {rank()}/{world_size()}] dist init OK on cuda:{local_rank()}"
-    )
+    logger.info(f"[rank {rank()}/{world_size()}] dist init OK on cuda:{local_rank()}")
    return dist.group.WORLD


def build_device_mesh(mesh_dim_names: tuple[str, ...] = ("tp",)):
    """
@@ -112,7 +114,8 @@
def barrier() -> None:
    """torch.distributed barrier — no-op when not distributed."""
    if not is_distributed():
        return
    import torch.distributed as dist
+
    if dist.is_initialized():
        dist.barrier()
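The dist.py hunks above only reflow code, but they restate two behaviours worth keeping in mind: with WORLD_SIZE unset or 1 every helper is a no-op, and the process group uses a long timeout so a rank that is still building its TRT engine does not trip the NCCL watchdog. A hedged sketch of that initialization, using only the environment variables the module docstring names; the maybe_init name is illustrative, not the tool's API.

import datetime
import os

import torch
import torch.distributed as dist


def maybe_init(timeout_hours: int = 2):
    """Initialize torch.distributed only when launched with WORLD_SIZE > 1 (sketch)."""
    if int(os.environ.get("WORLD_SIZE", "1")) <= 1:
        return None                                   # single-process run: everything stays a no-op
    torch.cuda.set_device(int(os.environ.get("LOCAL_RANK", "0")))
    dist.init_process_group(
        backend="nccl",
        timeout=datetime.timedelta(hours=timeout_hours),  # long timeout so slow TRT builds don't hit the watchdog
    )
    return dist.group.WORLD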
--- /home/runner/work/TensorRT/TensorRT/tools/hf/common/compile.py	2026-04-29 19:07:00.389656+00:00
+++ /home/runner/work/TensorRT/TensorRT/tools/hf/common/compile.py	2026-04-29 19:07:24.605988+00:00
@@ -22,17 +22,17 @@
    )

`enabled_precisions` is deprecated when `use_explicit_typing=True` and is
NEVER set by these helpers.
"""
+
from __future__ import annotations

from typing import Iterable, Optional

import torch
import torch_tensorrt
-

# --------------------------------------------------------------------------- #
# Precision plumbing
# --------------------------------------------------------------------------- #

@@ -87,10 +87,11 @@


# --------------------------------------------------------------------------- #
# Export helper (with fallback for guard violations)
# --------------------------------------------------------------------------- #
+

def safe_export(
    module: torch.nn.Module,
    args: tuple = (),
    kwargs: Optional[dict] = None,
@@ -110,11 +111,13 @@
                kwargs=kwargs,
                dynamic_shapes=dynamic_shapes,
                strict=False,
            )
        except Exception as e:
-            print(f"[compile] torch.export.export failed ({e}); retrying with deferred guards.")
+            print(
+                f"[compile] torch.export.export failed ({e}); retrying with deferred guards."
+            )
            return torch.export._trace._export(
                module,
                args=args,
                kwargs=kwargs,
                dynamic_shapes=dynamic_shapes,
@@ -124,10 +127,11 @@


# --------------------------------------------------------------------------- #
# Compile wrapper
# --------------------------------------------------------------------------- #
+

def _build_trt_kwargs(
    precision: str,
    autocast: bool,
    min_block_size: int,
@@ -179,13 +183,19 @@
      - C++ runtime (use_python_runtime=False)
      - engine caching ON
      - offload_module_to_cpu OFF (opt-in for memory-constrained models)
    """
    kw = _build_trt_kwargs(
-        precision, autocast, min_block_size, debug,
-        offload_module_to_cpu, cache_built_engines, reuse_cached_engines,
-        engine_cache_dir, extra,
+        precision,
+        autocast,
+        min_block_size,
+        debug,
+        offload_module_to_cpu,
+        cache_built_engines,
+        reuse_cached_engines,
+        engine_cache_dir,
+        extra,
    )
    return torch_tensorrt.dynamo.compile(ep, inputs=list(inputs), **kw)


def maybe_save_exported_program(cfg, ep, *, log_prefix: str) -> None:
@@ -225,11 +235,13 @@

    # torchscript needs real tensors for jit.trace; other formats accept
    # torch_tensorrt.Input specs.
    if fmt == "torchscript":
        use_args = example_arg_inputs if example_arg_inputs is not None else arg_inputs
-        use_kwargs = example_kwarg_inputs if example_kwarg_inputs is not None else kwarg_inputs
+        use_kwargs = (
+            example_kwarg_inputs if example_kwarg_inputs is not None else kwarg_inputs
+        )
    else:
        use_args = arg_inputs
        use_kwargs = kwarg_inputs

    save_kwargs = {"output_format": fmt}
@@ -308,11 +320,14 @@
    single TRT engine end-to-end (no PyTorch fallback partitions).  If
    your model has unsupported ops, this call will fail; use the
    `compile_with_trt` + `torch_tensorrt.save` path instead.
    """
    kw = _build_trt_kwargs(
-        precision, autocast, min_block_size, debug,
+        precision,
+        autocast,
+        min_block_size,
+        debug,
        offload_module_to_cpu,
        # Engine caching is irrelevant for serialization output.
        cache_built_engines=False,
        reuse_cached_engines=False,
        engine_cache_dir=engine_cache_dir,
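Taken together, the compile.py hunks above describe a two-step flow: export the module first (with a retry that defers guards when the strict export trips a guard violation), then hand the ExportedProgram to torch_tensorrt.dynamo.compile. The sketch below condenses that sequence; the keyword set that _build_trt_kwargs actually assembles is only paraphrased here, so treat these kwargs as assumptions rather than the helper's full output.

import torch
import torch_tensorrt


def export_and_compile(model, example_inputs):
    """Export -> TRT compile, mirroring the safe_export/compile_with_trt flow (kwargs illustrative).

    example_inputs is assumed to be a tuple of example tensors.
    """
    try:
        ep = torch.export.export(model, args=example_inputs, strict=False)
    except Exception as e:
        # Fallback path shown in safe_export: retry via the private _export entry point.
        print(f"[compile] torch.export.export failed ({e}); retrying with deferred guards.")
        ep = torch.export._trace._export(model, args=example_inputs, strict=False)
    trt_inputs = [torch_tensorrt.Input(shape=t.shape, dtype=t.dtype) for t in example_inputs]
    return torch_tensorrt.dynamo.compile(
        ep,
        inputs=trt_inputs,
        use_explicit_typing=True,    # per the docstring, enabled_precisions is never set alongside this
        cache_built_engines=True,
        reuse_cached_engines=True,
    )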
--- /home/runner/work/TensorRT/TensorRT/tools/hf/common/accuracy.py	2026-04-29 19:07:00.389656+00:00
+++ /home/runner/work/TensorRT/TensorRT/tools/hf/common/accuracy.py	2026-04-29 19:07:24.631701+00:00
@@ -8,21 +8,22 @@
Default tolerances are tuned for FP16 (atol=1e-2, rtol=1e-2,
cos_sim_min=0.99).  Override via --accuracy-atol / --accuracy-rtol /
--accuracy-cos-sim-min when comparing tighter precisions or models
known to have larger accumulated error.
"""
+
from __future__ import annotations

from typing import Iterable

import torch
import torch.utils._pytree as pytree

-
# --------------------------------------------------------------------------- #
# Output flattening
# --------------------------------------------------------------------------- #
+

def _flatten_to_tensors(out) -> list[torch.Tensor]:
    """
    Flatten an arbitrary HF model output (ModelOutput dataclass, dict,
    tuple, list, or single tensor) into a list of leaf tensors using
@@ -33,10 +34,11 @@


# --------------------------------------------------------------------------- #
# Per-tensor metrics
# --------------------------------------------------------------------------- #
+

def _cosine_similarity(a: torch.Tensor, b: torch.Tensor) -> float:
    a = a.detach().to(torch.float32).flatten()
    b = b.detach().to(torch.float32).flatten()
    if a.numel() == 0 or b.numel() == 0:
@@ -87,10 +89,11 @@

# --------------------------------------------------------------------------- #
# Compare two outputs (each a tensor / dict / dataclass / tuple)
# --------------------------------------------------------------------------- #

+
def compare_outputs(
    pt_out,
    trt_out,
    *,
    atol: float = 1e-2,
@@ -99,22 +102,28 @@
) -> list[dict]:
    pt_leaves = _flatten_to_tensors(pt_out)
    trt_leaves = _flatten_to_tensors(trt_out)

    if len(pt_leaves) != len(trt_leaves):
-        return [{
-            "name": "<output-count-mismatch>",
-            "shape_pt": f"{len(pt_leaves)} tensors",
-            "shape_trt": f"{len(trt_leaves)} tensors",
-            "shape_match": False,
-            "cos_sim": float("nan"),
-            "max_abs": float("nan"),
-            "mean_abs": float("nan"),
-            "allclose": False,
-        }]
-
-    names = list(output_names) if output_names else [f"out[{i}]" for i in range(len(pt_leaves))]
+        return [
+            {
+                "name": "<output-count-mismatch>",
+                "shape_pt": f"{len(pt_leaves)} tensors",
+                "shape_trt": f"{len(trt_leaves)} tensors",
+                "shape_match": False,
+                "cos_sim": float("nan"),
+                "max_abs": float("nan"),
+                "mean_abs": float("nan"),
+                "allclose": False,
+            }
+        ]
+
+    names = (
+        list(output_names)
+        if output_names
+        else [f"out[{i}]" for i in range(len(pt_leaves))]
+    )
    if len(names) < len(pt_leaves):
        names += [f"out[{i}]" for i in range(len(names), len(pt_leaves))]

    rows: list[dict] = []
    for name, pt, trt in zip(names, pt_leaves, trt_leaves):
@@ -125,10 +134,11 @@


# --------------------------------------------------------------------------- #
# Reporting
# --------------------------------------------------------------------------- #
+

def overall_pass(
    rows: list[dict],
    *,
    cos_sim_min: float = 0.99,
@@ -169,11 +179,13 @@
        print(f"\n{'=' * 70}")
        print(f"  {title}")
        print(f"{'=' * 70}")

    cols = ("name", "shape_pt", "cos_sim", "max_abs", "mean_abs", "allclose")
-    widths = {c: max(len(c), max(len(_fmt(r.get(c, ""), c)) for r in rows)) for c in cols}
+    widths = {
+        c: max(len(c), max(len(_fmt(r.get(c, ""), c)) for r in rows)) for c in cols
+    }
    header = "  ".join(c.ljust(widths[c]) for c in cols)
    print(header)
    print("-" * len(header))
    for r in rows:
        print("  ".join(_fmt(r.get(c, ""), c).ljust(widths[c]) for c in cols))


@github-actions github-actions Bot left a comment


There are some changes that do not conform to Python style guidelines:

--- /home/runner/work/TensorRT/TensorRT/tools/hf/common/bench.py	2026-04-29 23:52:22.708681+00:00
+++ /home/runner/work/TensorRT/TensorRT/tools/hf/common/bench.py	2026-04-29 23:52:45.045685+00:00
@@ -1,16 +1,16 @@
"""
Shared benchmarking harness for all HF model strategies.
"""
+
from __future__ import annotations

import timeit
from typing import Callable, Sequence

import numpy as np
import torch
-

WARMUP_ITERS = 5


def warmup_and_time(
--- /home/runner/work/TensorRT/TensorRT/tools/hf/common/accuracy.py	2026-04-29 23:52:22.708681+00:00
+++ /home/runner/work/TensorRT/TensorRT/tools/hf/common/accuracy.py	2026-04-29 23:52:45.149558+00:00
@@ -8,21 +8,22 @@
Default tolerances are tuned for FP16 (atol=1e-2, rtol=1e-2,
cos_sim_min=0.99).  Override via --accuracy-atol / --accuracy-rtol /
--accuracy-cos-sim-min when comparing tighter precisions or models
known to have larger accumulated error.
"""
+
from __future__ import annotations

from typing import Iterable

import torch
import torch.utils._pytree as pytree

-
# --------------------------------------------------------------------------- #
# Output flattening
# --------------------------------------------------------------------------- #
+

def _flatten_to_tensors(out) -> list[torch.Tensor]:
    """
    Flatten an arbitrary HF model output (ModelOutput dataclass, dict,
    tuple, list, or single tensor) into a list of leaf tensors using
@@ -33,10 +34,11 @@


# --------------------------------------------------------------------------- #
# Per-tensor metrics
# --------------------------------------------------------------------------- #
+

def _cosine_similarity(a: torch.Tensor, b: torch.Tensor) -> float:
    a = a.detach().to(torch.float32).flatten()
    b = b.detach().to(torch.float32).flatten()
    if a.numel() == 0 or b.numel() == 0:
@@ -87,10 +89,11 @@

# --------------------------------------------------------------------------- #
# Compare two outputs (each a tensor / dict / dataclass / tuple)
# --------------------------------------------------------------------------- #

+
def compare_outputs(
    pt_out,
    trt_out,
    *,
    atol: float = 1e-2,
@@ -99,22 +102,28 @@
) -> list[dict]:
    pt_leaves = _flatten_to_tensors(pt_out)
    trt_leaves = _flatten_to_tensors(trt_out)

    if len(pt_leaves) != len(trt_leaves):
-        return [{
-            "name": "<output-count-mismatch>",
-            "shape_pt": f"{len(pt_leaves)} tensors",
-            "shape_trt": f"{len(trt_leaves)} tensors",
-            "shape_match": False,
-            "cos_sim": float("nan"),
-            "max_abs": float("nan"),
-            "mean_abs": float("nan"),
-            "allclose": False,
-        }]
-
-    names = list(output_names) if output_names else [f"out[{i}]" for i in range(len(pt_leaves))]
+        return [
+            {
+                "name": "<output-count-mismatch>",
+                "shape_pt": f"{len(pt_leaves)} tensors",
+                "shape_trt": f"{len(trt_leaves)} tensors",
+                "shape_match": False,
+                "cos_sim": float("nan"),
+                "max_abs": float("nan"),
+                "mean_abs": float("nan"),
+                "allclose": False,
+            }
+        ]
+
+    names = (
+        list(output_names)
+        if output_names
+        else [f"out[{i}]" for i in range(len(pt_leaves))]
+    )
    if len(names) < len(pt_leaves):
        names += [f"out[{i}]" for i in range(len(names), len(pt_leaves))]

    rows: list[dict] = []
    for name, pt, trt in zip(names, pt_leaves, trt_leaves):
@@ -125,10 +134,11 @@


# --------------------------------------------------------------------------- #
# Reporting
# --------------------------------------------------------------------------- #
+

def overall_pass(
    rows: list[dict],
    *,
    cos_sim_min: float = 0.99,
@@ -169,11 +179,13 @@
        print(f"\n{'=' * 70}")
        print(f"  {title}")
        print(f"{'=' * 70}")

    cols = ("name", "shape_pt", "cos_sim", "max_abs", "mean_abs", "allclose")
-    widths = {c: max(len(c), max(len(_fmt(r.get(c, ""), c)) for r in rows)) for c in cols}
+    widths = {
+        c: max(len(c), max(len(_fmt(r.get(c, ""), c)) for r in rows)) for c in cols
+    }
    header = "  ".join(c.ljust(widths[c]) for c in cols)
    print(header)
    print("-" * len(header))
    for r in rows:
        print("  ".join(_fmt(r.get(c, ""), c).ljust(widths[c]) for c in cols))
--- /home/runner/work/TensorRT/TensorRT/tools/hf/common/dist.py	2026-04-29 23:52:22.708681+00:00
+++ /home/runner/work/TensorRT/TensorRT/tools/hf/common/dist.py	2026-04-29 23:52:45.157749+00:00
@@ -12,10 +12,11 @@
  MASTER_PORT  – rendezvous port

Single-process (no torchtrtrun) defaults: WORLD_SIZE=1 (or unset), so
all helpers here become no-ops.
"""
+
from __future__ import annotations

import datetime
import logging
import os
@@ -44,11 +45,13 @@

def is_master() -> bool:
    return rank() == 0


-def init_distributed(timeout_hours: int = 2) -> Optional["torch.distributed.ProcessGroup"]:
+def init_distributed(
+    timeout_hours: int = 2,
+) -> Optional["torch.distributed.ProcessGroup"]:
    """
    Initialize torch.distributed if WORLD_SIZE > 1.

    Long timeout (default 2 h) so TRT engine builds don't trigger the NCCL
    watchdog when one rank takes longer than another to build.
@@ -79,17 +82,16 @@

    # Wire NCCL into TRT (sets up the symbol resolution path so TRT's
    # collectives talk to the same libnccl.so torchtrtrun preloaded).
    try:
        from torch_tensorrt.distributed import setup_nccl_for_torch_tensorrt
+
        setup_nccl_for_torch_tensorrt()
    except Exception as e:
        logger.warning(f"setup_nccl_for_torch_tensorrt failed: {e}")

-    logger.info(
-        f"[rank {rank()}/{world_size()}] dist init OK on cuda:{local_rank()}"
-    )
+    logger.info(f"[rank {rank()}/{world_size()}] dist init OK on cuda:{local_rank()}")
    return dist.group.WORLD


def build_device_mesh(mesh_dim_names: tuple[str, ...] = ("tp",)):
    """
@@ -112,7 +114,8 @@
def barrier() -> None:
    """torch.distributed barrier — no-op when not distributed."""
    if not is_distributed():
        return
    import torch.distributed as dist
+
    if dist.is_initialized():
        dist.barrier()



Development

Successfully merging this pull request may close these issues.

✨[Feature] Multi-Framework Runner Support in tools
