
feat(runtime): support binding TRTEngine execution to an external CUDA stream#4232

Open
shoumikhin wants to merge 9 commits into pytorch:main from shoumikhin:green-context-external-stream-upstream

Conversation


shoumikhin commented May 1, 2026

Summary

Adds opt-in support for binding torch-tensorrt's TRT engine execution to externally-managed CUDA streams — typically streams created via cuGreenCtxStreamCreate for SM partitioning via CUDA Green Contexts (CUDA 12.4+). The motivating workload is edge / on-device multi-tenant inference on Jetson-class hardware where a vision encoder + policy net + diffusion head all share one process and need disjoint SM partitions to avoid time-slicing.

Currently, core/runtime/execute_engine.cpp lazily pulls a stream from torch's global stream pool on first execute. That pool is bound to the primary CUDA context, so even when a caller sets a green-context-bound stream as current (via c10::cuda::CUDAStreamGuard / torch.cuda.stream(...)), the TRT engine bypasses it and uses a primary-context pool stream — defeating any SM partitioning the caller set up.

Purely additive: no behavior change for callers that don't opt in.

Two complementary mechanisms

The PR ships two ways to bind a stream, sized to two different deployment shapes:

(1) Per-engine binding — for Python / dynamo / output_format="exported_program"

Reach the TRTEngine torchbind through the wrapping nn.Module's named_modules() and bind a stream per engine. This is the canonical multi-engine SM-partitioning case where one compiled model contains several TRT subgraphs that should each run on a distinct green context.

New C++ API on TRTEngine (exposed via torchbind):

void TRTEngine::set_external_stream(int64_t stream_handle);  // reinterpret_cast<int64_t>(cudaStream_t)
void TRTEngine::clear_external_stream();
int64_t TRTEngine::get_external_stream() const;
bool TRTEngine::is_external_stream_set() const;

Reachable from Python and external C++ via torch.classes.tensorrt.Engine.

New Python facade with RAII context-manager semantics:

import torch_tensorrt
from torch_tensorrt.runtime import set_external_stream, clear_external_stream

# Single stream bound to every TRT submodule
with set_external_stream(model, my_stream):
    out = model(x)            # restored on exit

# Per-engine binding (the canonical green-context case)
with set_external_stream(model, {
    "_run_on_acc_0": vision_encoder_stream,    # SM partition A
    "_run_on_acc_1": policy_net_stream,        # SM partition B
    "_run_on_acc_2": diffusion_head_stream,    # SM partition C
}):
    out = model(x)

set_external_stream walks named_modules() recursively, so deeply nested TRT submodules (e.g. HF blocks under wrapper GraphModules) are reachable. Submodule names are dotted paths, validated up front so a bad value cannot leave a partially-bound module. The setter validates the stream's device affinity against the engine's target device (via cuStreamGetCtx + cuCtxGetDevice) and rejects the legacy / per-thread magic stream IDs. The binding is applied atomically across multiple engines: any per-engine failure rolls back successfully-applied bindings before re-raising.
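
For reference, a minimal sketch of that validation path, assuming only the driver-API calls the paragraph names (the helper name and control flow are illustrative, not the PR's actual code):

#include <cuda.h>

// Illustrative check: reject the magic stream IDs, then resolve the stream's
// device via its context and compare against the engine's target device.
bool stream_ok_for_device(CUstream stream, int expected_device_id) {
  if (stream == nullptr || stream == CU_STREAM_LEGACY || stream == CU_STREAM_PER_THREAD)
    return false;                    // null / legacy / per-thread streams are rejected
  CUcontext stream_ctx;
  if (cuStreamGetCtx(stream, &stream_ctx) != CUDA_SUCCESS)
    return false;                    // not a live stream handle
  cuCtxPushCurrent(stream_ctx);      // enter the stream's context...
  CUdevice stream_device;
  cuCtxGetDevice(&stream_device);    // ...to ask which device backs it
  cuCtxPopCurrent(nullptr);          // restore the caller's context
  return static_cast<int>(stream_device) == expected_device_id;
}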

(2) Process-wide stream passthrough — for AOTI / .pt2 C++ deployments

When the model is exported with output_format="aot_inductor" and consumed in pure C++ via torch::inductor::AOTIModelPackageLoader, the live TRTEngine torchbind instances are held in OSSProxyExecutor::custom_objs_, a private member with no public PyTorch accessor. Re-parsing the .pt2 only yields independent IValue copies that the running .so never invokes, so the per-engine API in (1) is unreachable.

The fix: a process-wide opt-in flag that makes execute_engine honor the caller's current CUDA stream instead of the lazy pool stream. Users wrap loader.run(...) in a CUDAStreamGuard and the engine inherits it.

New globals (C++ + Python):

namespace torch_tensorrt::core::runtime {
  bool get_engine_stream_passthrough();
  void set_engine_stream_passthrough(bool);
}
torch_tensorrt.runtime.set_engine_stream_passthrough(True)
torch_tensorrt.runtime.get_engine_stream_passthrough()

C++ usage after merge:

#include <ATen/cuda/CUDAStream.h>
#include <c10/cuda/CUDAGuard.h>
#include "torch/csrc/inductor/aoti_package/model_package_loader.h"

torch::inductor::AOTIModelPackageLoader loader("model.pt2");

// Process-wide opt-in — set once, applies to every loaded engine.
torch_tensorrt::core::runtime::set_engine_stream_passthrough(true);

// Carve out an SM partition with a Green Context and create a stream on it.
CUgreenCtx green_ctx;
CUstream raw_green_stream;
// ... cuDevSmResourceSplitByCount + cuGreenCtxCreate + cuGreenCtxStreamCreate ...

auto stream = c10::cuda::getStreamFromExternal(raw_green_stream, /*device=*/0);
{
  c10::cuda::CUDAStreamGuard guard(stream);
  auto out = loader.run(inputs);   // TRT engine inherits the guarded stream
}
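
For completeness, the elided green-context setup might look roughly like this (a sketch assuming the CUDA 12.4+ driver API; error handling omitted and the SM count is arbitrary):

#include <cuda.h>

// Carve min_sm_count SMs off the device into one partition and return a
// stream bound to it. Sketch only; production code must check every CUresult.
CUstream make_green_stream(CUdevice dev, unsigned int min_sm_count) {
  CUdevResource device_sms;
  cuDeviceGetDevResource(dev, &device_sms, CU_DEV_RESOURCE_TYPE_SM);

  CUdevResource partition, remainder;
  unsigned int num_groups = 1;
  cuDevSmResourceSplitByCount(&partition, &num_groups, &device_sms,
                              &remainder, /*useFlags=*/0, min_sm_count);

  CUdevResourceDesc desc;
  cuDevResourceGenerateDesc(&desc, &partition, /*nbResources=*/1);

  CUgreenCtx green_ctx;
  cuGreenCtxCreate(&green_ctx, desc, dev, CU_GREEN_CTX_DEFAULT_STREAM);

  CUstream stream;
  cuGreenCtxStreamCreate(&stream, green_ctx, CU_STREAM_NON_BLOCKING, /*priority=*/0);
  return stream;
}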

Python usage after merge (also valid for AOTI-loaded models via torch._inductor.aoti_load_package):

import torch
import torch_tensorrt

torch_tensorrt.runtime.set_engine_stream_passthrough(True)

green_stream = torch.cuda.Stream(device=0)   # or torch.cuda.ExternalStream to wrap a green-ctx CUstream
with torch.cuda.stream(green_stream):
    out = aoti_model(x)

Precedence

When multiple sources are configured, the resolver picks in this order on every call, so a set or clear takes effect immediately without recreating the engine (a sketch follows the list):

  1. Per-engine external_stream (set via TRTEngine::set_external_stream)
  2. Process-wide ENGINE_STREAM_PASSTHROUGH → caller's current CUDA stream
  3. Existing pool fallback (getStreamFromPool) — unchanged default behavior
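
In sketch form, the resolution amounts to the following (illustrative shape only; the PR inlines this logic at the stream-resolve sites in execute_engine.cpp rather than exposing a helper like this):

#include <c10/cuda/CUDAStream.h>

c10::cuda::CUDAStream resolve_stream(TRTEngine& engine, c10::DeviceIndex device_id) {
  if (engine.is_external_stream_set()) {
    // 1. The per-engine binding wins unconditionally.
    return c10::cuda::getStreamFromExternal(
        reinterpret_cast<cudaStream_t>(engine.get_external_stream()), device_id);
  }
  if (get_engine_stream_passthrough()) {
    // 2. Passthrough: honor whatever stream the caller made current.
    return c10::cuda::getCurrentCUDAStream(device_id);
  }
  // 3. Default: the pre-existing pool behavior, unchanged.
  return c10::cuda::getStreamFromPool(/*isHighPriority=*/false, device_id);
}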

Mutual exclusion with CUDA Graphs

Both mechanisms are mutually exclusive with CUDA Graphs. The check fires at bind time (set_external_stream / set_engine_stream_passthrough(true) throw if cudagraphs are currently enabled) and again at execute time as defense-in-depth (covers cudagraphs being enabled after the binding):

CUDA Graphs are not supported when an external stream is set on the engine.
Disable cudagraphs or call clear_external_stream() first.

CUDA Graphs are not supported while engine-stream passthrough is enabled.
Disable cudagraphs or call set_engine_stream_passthrough(False) first.

The setter and clearer also invalidate any captured graph (cudagraph.reset()) so a subsequent recapture happens cleanly and never replays against a stale stream identity.

Multi-GPU correctness fix folded in

TRTEngine::engine_stream and TRTEngine::caller_stream are now pinned to the engine's actual device_info.id in the constructor body. The in-class initializers at TRTEngine.h:211-212 default to device 0 (no device arg). Without this fix, the lazy pool re-acquire in execute_engine checked engine_stream == getDefaultCUDAStream(current_device_id) — always false on cuda:N for N>0 — so the engine ran on cuda:0's default stream regardless of the input device. Pre-existing bug; fixed here while we were in the area.
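
In sketch form (field and accessor names as quoted above), the constructor fix amounts to:

// In the TRTEngine constructor body: pin both cached streams to the engine's
// real device instead of the device-0 default from the in-class initializers.
caller_stream = c10::cuda::getDefaultCUDAStream(static_cast<c10::DeviceIndex>(device_info.id));
engine_stream = c10::cuda::getDefaultCUDAStream(static_cast<c10::DeviceIndex>(device_info.id));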

Files changed

• core/runtime/TRTEngine.{h,cpp}: per-engine setter / clearer / getter, external_stream + engine_stream_is_external fields, cudagraph invalidation in setter & clearer, multi-GPU default-stream init in ctor
• core/runtime/execute_engine.cpp: stream-resolve sites in both lambdas (regular + output-allocator paths) with per-engine + passthrough + pool-fallback precedence; cudagraph mutual-exclusion guards for both
• core/runtime/runtime.{h,cpp}: ENGINE_STREAM_PASSTHROUGH global + get_/set_engine_stream_passthrough() accessors
• core/runtime/register_jit_hooks.cpp: torchbind exposure for the per-engine stream methods + the two passthrough globals
• py/torch_tensorrt/dynamo/runtime/_PythonTorchTensorRTModule.py: Python runtime parity
• py/torch_tensorrt/dynamo/runtime/_TorchTensorRTModule.py: passthrough on the C++-backed runtime
• py/torch_tensorrt/runtime/_external_stream.py: top-level facade + context manager + passthrough toggles
• py/torch_tensorrt/runtime/__init__.py: re-export
• tests/py/dynamo/runtime/test_006_external_stream.py: both runtimes (swap, clear, per-engine binding, cudagraph guard, validation, passthrough routing, passthrough+cudagraph mutex)

Test plan

  • pytest tests/py/dynamo/runtime/test_006_external_stream.py — covers both PythonTorchTensorRTModule and TorchTensorRTModule runtime classes, including the new passthrough tests.
  • GPU runtime test on H100 / Hopper: register a green-context-bound stream, run a small TRT engine, verify via nsys profile that kernel launches are confined to the green context's SM partition.
  • GPU runtime test on Jetson Blackwell (sm_110).
  • AOTI C++ test: set_engine_stream_passthrough(true) + wrap AOTIModelPackageLoader::run() with a CUDAStreamGuard on a green-context stream, verify SM-partitioned execution in nsys.

Out of scope (future PRs)

  • Upstream PyTorch PR to add AOTIModelPackageLoader::get_custom_objs() so AOTI users can also use the per-engine API (when they want different streams per submodule inside one .pt2). The passthrough flag in this PR is the interim mechanism while that lands and reaches stable.
  • torch_tensorrt::aoti::TRTAOTILoader C++ wrapper behind a TORCH_TRT_HAVE_AOTI_CUSTOM_OBJS CMake probe — depends on the upstream PR.
  • NCCL + green context interaction. Distributed (NCCL collectives) on green-context-partitioned streams is not validated; the existing requires_native_multidevice path may need follow-up if a user combines both.

meta-cla Bot added the cla signed label May 1, 2026
github-actions Bot added the component: tests, component: core, component: api [Python], component: runtime, and component: dynamo labels May 1, 2026
github-actions Bot requested a review from narendasan May 1, 2026 15:46
feat(runtime): support binding TRTEngine execution to an external CUDA stream

Adds opt-in support for binding torch-tensorrt's TRT engine execution to
externally-managed CUDA streams -- typically streams created via
`cuGreenCtxStreamCreate` for SM partitioning via CUDA Green Contexts (CUDA
12.4+).

Currently, the runtime in `core/runtime/execute_engine.cpp` lazily pulls a
stream from torch's global stream pool on first execute. That pool is bound
to the primary CUDA context, so even when a caller sets a green-context-bound
stream as current, the TRT engine bypasses it and uses a primary-context pool
stream -- defeating any SM partitioning the caller set up.

Purely additive: no behavior change for callers that don't opt in.

This change adds two complementary mechanisms:

(1) Per-engine binding (Python / dynamo / Exported Program path):
    - C++ API on `TRTEngine` (exposed via torchbind):
        void set_external_stream(int64_t stream_handle);
        void clear_external_stream();
        int64_t get_external_stream() const;
      The handle is `reinterpret_cast<int64_t>(cudaStream_t)`. Reachable from
      Python and external C++ via `torch.classes.tensorrt.Engine`.
    - Python facade in `torch_tensorrt.runtime.set_external_stream(module, ...)`
      with optional per-engine binding via `Dict[submodule_name, StreamLike]`
      and RAII context-manager semantics. Walks `named_modules()` so deeply
      nested TRT submodules (e.g. HF blocks under wrapper GraphModules) are
      reachable.

(2) Process-wide stream passthrough (AOTI / .pt2 C++ path):
    - New global flag `ENGINE_STREAM_PASSTHROUGH` and accessors:
        bool get_engine_stream_passthrough();
        void set_engine_stream_passthrough(bool);
      When enabled, `execute_engine` honors the caller's *current* CUDA stream
      (`c10::cuda::getCurrentCUDAStream`) instead of acquiring a pool stream.
      This unblocks `output_format="aot_inductor"` users whose `TRTEngine`
      torchbind constants live inside `OSSProxyExecutor::custom_objs_`
      (private, no public PyTorch accessor) and so are unreachable for the
      per-engine API. Users wrap `loader.run(...)` in a `CUDAStreamGuard`
      bound to e.g. a Green Context stream and the engine inherits it.
    - Python facade: `torch_tensorrt.runtime.set_engine_stream_passthrough(bool)`
      / `get_engine_stream_passthrough()`.

Mutual exclusion with CUDA Graphs is enforced for both mechanisms (throws at
execute time). Setter and clearer also invalidate any captured graph so a
subsequent recapture happens cleanly (avoids replaying against a stale stream
identity).

Multi-GPU correctness: `engine_stream` and `caller_stream` are now pinned to
the engine's actual `device_info.id` in the constructor body (the in-class
initializers default to device 0; without this, the lazy pool re-acquire in
`execute_engine` skipped firing on `cuda:N` for `N>0` because the
`engine_stream == getDefaultCUDAStream(current_device_id)` check was always
false).

Same code path serves both the C++ AOTI runtime (model.so dispatch into
`execute_engine.cpp` via the C-shim) and the dynamo Python runtime
(`PythonTorchTensorRTModule`). Per-engine binding lets callers map distinct
green contexts to distinct TRT subgraphs in one compiled model. The
process-wide passthrough is the alternative for callers who can't reach the
engines individually (AOTI's private custom_objs_ map being the canonical
case).

Files changed:
- core/runtime/TRTEngine.{h,cpp}                                 setter / clearer / getter, `external_stream` and `engine_stream_is_external` fields, cudagraph invalidation in setter & clearer, multi-GPU default-stream init in ctor
- core/runtime/execute_engine.cpp                                stream-resolve sites in both lambdas (regular + output-allocator paths) with per-engine + passthrough + pool fallback precedence; cudagraph mutual-exclusion guards
- core/runtime/runtime.{h,cpp}                                   `ENGINE_STREAM_PASSTHROUGH` global + accessors
- core/runtime/register_jit_hooks.cpp                            torchbind exposure for all three per-engine methods + the two passthrough globals
- py/torch_tensorrt/dynamo/runtime/_PythonTorchTensorRTModule.py Python runtime parity
- py/torch_tensorrt/dynamo/runtime/_TorchTensorRTModule.py       passthrough on the C++-backed runtime
- py/torch_tensorrt/runtime/_external_stream.py                  top-level facade + context manager + passthrough toggles
- py/torch_tensorrt/runtime/__init__.py                          re-export
- tests/py/dynamo/runtime/test_006_external_stream.py            covers both runtimes (swap, clear, per-engine binding, cudagraph guard, validation, passthrough routing, passthrough+cudagraph mutex)

Test plan:
- pytest tests/py/dynamo/runtime/test_006_external_stream.py -- both
  PythonTorchTensorRTModule and TorchTensorRTModule runtime classes.
- GPU runtime test on H100 / Hopper: register a green-context-bound stream,
  run a small TRT engine, verify via nsys profile that kernel launches are
  confined to the green context's SM partition.
- GPU runtime test on Jetson Thor (Blackwell): same as above with sm_110.
- AOTI C++ test: `set_engine_stream_passthrough(True)`, wrap
  `AOTIModelPackageLoader::run()` with a `CUDAStreamGuard` on a green-context
  stream, verify SM-partitioned execution.

Open item (deliberately not in this commit, can land separately):
- Device-affinity validation in `set_external_stream`. The current sanity
  check (`cudaStreamGetFlags`) confirms the handle is real but does not
  validate the stream's device against `device_info.id`. A multi-GPU caller
  could silently bind a wrong-device stream. Clean fix uses `cuStreamGetCtx`
  + `cuCtxGetDevice` (driver API) or `cudaStreamGetDevice` (CUDA 12.5+).
shoumikhin force-pushed the green-context-external-stream-upstream branch from bfa0fea to a0434e4 on May 1, 2026 16:51
narendasan requested a review from cehongwang May 1, 2026 17:02

shoumikhin commented May 1, 2026

Long-term plan: upstream PyTorch PR

Opened pytorch/pytorch#182149 to add AOTIModelPackageLoader::get_custom_objs(). Once it lands and reaches a tagged PyTorch release, AOTI / .pt2 C++ users can reach the live TRTEngine torchbind instances inside the loaded .so and use the per-engine set_external_stream API directly (no need for the process-wide set_engine_stream_passthrough flag).

torch::inductor::AOTIModelPackageLoader loader("model.pt2");
for (auto& [name, ivalue] : loader.get_custom_objs()) {
  if (auto e = ivalue.toCustomClass<torch_tensorrt::TRTEngine>()) {
    e->set_external_stream(reinterpret_cast<int64_t>(my_green_stream));
  }
}
loader.run(inputs);

This PR (#4232) ships set_engine_stream_passthrough as the interim mechanism so edge / on-device users are unblocked today. A follow-up torch_tensorrt::aoti::TRTAOTILoader wrapper (gated on a CMake TORCH_TRT_HAVE_AOTI_CUSTOM_OBJS probe) will provide the clean per-engine API once upstream lands.

shoumikhin and others added 6 commits May 1, 2026 11:23
Bundle 1 (must-fix from reviewers):

1. Device-affinity validation in set_external_stream
   - cuStreamGetCtx + cuCtxPushCurrent + cuCtxGetDevice resolves the stream's
     device and asserts it matches engine.device_info.id. Catches the silent
     cross-device launch (cuda:1 stream bound to cuda:0 engine) before any
     enqueueV3, where the failure would otherwise surface as a confusing
     CUDA error far from the bind site.

2. Reject magic stream values
   - cudaStreamLegacy / cudaStreamPerThread are now explicitly rejected.
     Binding them latches engine_stream_is_external onto a non-isolated
     stream that defeats the whole point of the API.

3. Atomic rollback on partial multi-engine bind (Python facade)
   - The set_external_stream loop now records each successful application
     and reverses them on any failure, so an engine's per-handle validation
     throwing midway through a Dict-shaped binding can no longer leave
     earlier engines in a half-bound state.

4. Re-entrancy / deadlock fix
   - mu is now std::recursive_mutex everywhere on TRTEngine. Allows TRT
     plugin -> Python -> set_external_stream re-entry on the same thread
     without self-deadlock. Zero downside for the non-reentrant path.

5. Cudagraph mutual-exclusion check moved to set time
   - set_external_stream now asserts CUDAGRAPHS_MODE == STANDARD up front
     instead of waiting until next execute. Faster failure, clearer call
     site, no wasted input migration etc. before the throw. The execute-
     time guard remains as defense-in-depth (covers cudagraphs being
     enabled AFTER an external stream is bound).

6. is_external_stream_set() companion accessor
   - Avoids the ambiguous get_external_stream() == 0 sentinel pattern.
     ABI-safe, cheap, exposed via torchbind.

7. Error message typo fix
   - 'wraps a non-null CUDA stream is required' -> 'must wrap a non-null
     CUDA stream'.

Defer to follow-up: Python torch.cuda.default_stream(self.device) one-char
fix, additional tests (green-context smoke, restore-non-zero-prior,
serialize round-trip), passthrough relocation, NCCL+external_stream
LOG_WARNING.
