fix: torch-TRT runtime cache attribute + standard-TRT fast refit regression #4225

Merged
zewenli98 merged 3 commits into pytorch:main from tp5uiuc:fix/torch-trt-refit-and-runtime-cache-cleanup on May 4, 2026

tp5uiuc (Collaborator) commented Apr 30, 2026

Summary

Two related fixes to issues introduced by recent torch-TensorRT runtime-cache and refit work, both surfaced by the L2 dynamo compile tests.

  • fix(runtime): preserve runtime_cache/runtime_config on rehydration paths: PythonTorchTensorRTModule.__getstate__ pops runtime_config and runtime_cache (they hold non-picklable native handles), but __setstate__ never restored them. setup_engine() only re-creates them inside _setup_runtime_config(), which is gated by ENABLED_FEATURES.tensorrt_rtx. On standard TRT the gate is false, so after every unpickle / deepcopy the attributes simply do not exist. __del__ -> _save_runtime_cache() then reads self.runtime_cache, nn.Module's __getattr__ raises AttributeError, and the error surfaces as PytestUnraisableExceptionWarning across the refit and weight-stripped-engine test suites. Re-initialize both fields to None in __setstate__ and _load_from_state_dict before setup_engine(). No behavior change on TRT-RTX.
  • fix(refit): scope unset-weights strict check to TRT-RTX only: the strict cross-check unset_weights = {w for w in weight_list if w not in mapping} added in pytorch#4198 broke fast refit on standard TRT for any engine with CONSTANT layers intentionally absent from the mapping (e.g. batch-norm eps constants baked at build time). The pre-existing warn-and-continue branch makes that absence the contract; the strict check made it an error. On a resnet18 disk-engine-cache hit, the path now asserts with "0 missing, 20 unset" and _pretraced_backend silently falls back to GraphModule (visible as XPASS on test_dynamo_compile_with_default_disk_engine_cache and test_torch_compile_with_default_disk_engine_cache). The strict check exists only to guard the TRT-RTX case where each weight lives in its own independent wtsEngine and get_missing_weights() can under-report; on standard TRT, get_missing_weights() is authoritative because connected weight engines surface any unset weight transitively. Gate the assertion behind ENABLED_FEATURES.tensorrt_rtx.
  • chore(refit): drop unused CPU_DEVICE import: pre-existing F401 surfaced by ruff on the touched file.
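The rehydration fix in the first bullet can be illustrated with a minimal, torch-free sketch. The class `EngineModule`, its `rtx_enabled` flag, and `save_runtime_cache()` are hypothetical stand-ins for illustration only, not the actual torch_tensorrt code; the point is only the pattern of re-creating popped attributes in `__setstate__` before the gated setup runs.

```python
import copy
import pickle


class EngineModule:
    """Hypothetical stand-in for a module that owns non-picklable handles."""

    def __init__(self, serialized_engine: bytes, rtx_enabled: bool = False):
        self.serialized_engine = serialized_engine
        self.rtx_enabled = rtx_enabled  # stand-in for the feature gate
        self.runtime_config = None
        self.runtime_cache = None
        self.setup_engine()

    def setup_engine(self) -> None:
        # Only the gated path creates real handles.
        if self.rtx_enabled:
            self.runtime_config = object()  # placeholder for a native handle
            self.runtime_cache = object()

    def __getstate__(self) -> dict:
        state = self.__dict__.copy()
        # Drop the handles so they are never pickled.
        state.pop("runtime_config", None)
        state.pop("runtime_cache", None)
        return state

    def __setstate__(self, state: dict) -> None:
        self.__dict__.update(state)
        # The fix: re-create both attributes before setup_engine(), so they
        # exist even when the gated branch does not run.
        self.runtime_config = None
        self.runtime_cache = None
        self.setup_engine()

    def save_runtime_cache(self) -> bool:
        # Without the re-init above, this read would raise AttributeError
        # on a rehydrated instance when the gate is off.
        return self.runtime_cache is not None


restored = pickle.loads(pickle.dumps(EngineModule(b"engine")))
cloned = copy.deepcopy(restored)
```

Both `pickle` and `copy.deepcopy` go through `__getstate__`/`__setstate__`, which is why the same re-initialization covers both rehydration paths in this sketch.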

Test plan

Verified end-to-end on a fresh standard-TRT (non-RTX) install of nightly torch_tensorrt on Linux + CUDA 13.0:

  • Fix 1: pytest tests/py/dynamo/models/test_model_refit.py::test_refit_one_engine_with_weightmap test_refit_one_engine_python_runtime_with_weightmap test_complex_buffer_with_real_param_refit — baseline reproduces 2× AttributeError: ...runtime_cache; with patch, 0 occurrences.
  • Fix 2: pytest tests/py/dynamo/models/test_engine_cache.py::TestEngineCache::test_dynamo_compile_with_default_disk_engine_cache test_torch_compile_with_default_disk_engine_cache — baseline reproduces 1× AssertionError: Fast refitting failed due to incomplete mapping (0 missing, 20 unset) plus Returning GraphModule forward instead silent fallback; with patch, 0 of either, TRT compilation actually succeeds on the cache-hit path.
  • CI sweep on Windows L2 dynamo compile job to confirm the original PytestUnraisableExceptionWarning groups disappear.

🤖 Generated with Claude Code

tp5uiuc added 2 commits April 30, 2026 14:32
PythonTorchTensorRTModule.__getstate__ pops runtime_config and runtime_cache
(they hold non-picklable native handles), but __setstate__ never restored
them. setup_engine() only re-creates them inside _setup_runtime_config(),
which is gated by ENABLED_FEATURES.tensorrt_rtx. On standard TRT the gate is
false, so on every unpickle / deepcopy the attributes simply do not exist.
__del__ -> _save_runtime_cache() then reads self.runtime_cache, nn.Module's
__getattr__ raises AttributeError, and Python emits a
PytestUnraisableExceptionWarning across the refit and weight-stripped-engine
test suites.

Re-initialize both fields to None inside __setstate__ before calling
setup_engine(). Mirror the same init in _load_from_state_dict so the method
is self-contained even though __init__ usually runs first on that path.

Also annotate the deliberate "import tensorrt as trt" placement (after
torch_tensorrt, so the tensorrt_rtx alias resolves) with an isort:skip
marker, mirroring the convention already used in _refit.py and the
weight-stripped-engine test.

No behavior change on TRT-RTX, where setup_engine() proceeds to populate
the real handles via _setup_runtime_config() as before.

Signed-off-by: tejaswinp <tejaswinp@nvidia.com>

Pre-existing F401 surfaced by ruff when other changes touch this file.

Signed-off-by: tejaswinp <tejaswinp@nvidia.com>
@github-actions github-actions Bot requested a review from narendasan April 30, 2026 21:41
The strict check added in pytorch#4198 — comparing weights actually set against the
full engine weight list — broke fast refit on standard TRT for any engine
that contains CONSTANT layers intentionally absent from the mapping (e.g.
batch-norm eps constants, which are baked in at build time and not
expected to be refit). The pre-existing warn-and-continue branch makes
that absence the contract; the strict check made it an error.

Concretely, on standard TRT compiling resnet18 through torch.compile + the
disk engine cache, the cache-hit path now asserts with "0 missing, 20
unset". _pretraced_backend swallows the assertion and falls back to the
plain GraphModule, silently disabling TRT compilation for the test.

The strict check exists for the TRT-RTX case where each weight lives in
its own independent wtsEngine and get_missing_weights() can under-report.
On standard TRT, get_missing_weights() is authoritative because connected
weight engines surface any unset weight transitively, so the additional
unset-weights cross-check is unnecessary and actively wrong.

Gate the unset-weights assertion behind ENABLED_FEATURES.tensorrt_rtx to
restore the standard-TRT contract while keeping the TRT-RTX safety net.
Also rewrite both fast-refit assertion messages to surface counts vs total
plus example unset weights for diagnostic purposes.

Signed-off-by: tejaswinp <tejaswinp@nvidia.com>
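The gating described in this commit message can be sketched as follows. This is a simplified illustration under stated assumptions: `check_fast_refit`, its parameters, and the exact message wording are hypothetical, and `rtx_enabled` stands in for ENABLED_FEATURES.tensorrt_rtx; the real check lives in the refit code and operates on TensorRT refitter state.

```python
from typing import Iterable, Mapping


def check_fast_refit(
    weight_list: Iterable[str],
    mapping: Mapping[str, object],
    missing_weights: Iterable[str],
    rtx_enabled: bool,
) -> None:
    """Hypothetical helper: raise AssertionError if fast refit is incomplete."""
    weights = list(weight_list)
    missing = list(missing_weights)
    # Weights the engine lists but the mapping never supplied.
    unset = [w for w in weights if w not in mapping]

    # The reported missing list is always checked.
    if missing:
        raise AssertionError(
            f"Fast refitting failed due to incomplete mapping "
            f"({len(missing)} missing of {len(weights)} weights)"
        )
    # The strict cross-check runs only where the missing list may under-report.
    if rtx_enabled and unset:
        raise AssertionError(
            f"Fast refitting failed due to incomplete mapping "
            f"(0 missing, {len(unset)} unset of {len(weights)} weights; "
            f"e.g. {unset[:3]})"
        )


# A constant baked at build time is intentionally absent from the mapping.
engine_weights = ["conv1.weight", "bn1.eps"]
weight_map = {"conv1.weight": object()}

# Passes when the gate is off: the absent constant is tolerated.
check_fast_refit(engine_weights, weight_map, [], rtx_enabled=False)
```

With `rtx_enabled=True` the same inputs raise, which matches the intent of keeping the stricter check as a safety net only on that path.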
@tp5uiuc tp5uiuc force-pushed the fix/torch-trt-refit-and-runtime-cache-cleanup branch from ec2d418 to dc978a0 Compare April 30, 2026 21:58
zewenli98 (Collaborator) left a comment

LGTM

@zewenli98 zewenli98 merged commit e85844a into pytorch:main May 4, 2026
78 of 82 checks passed
Labels

cla signed, component: api [Python], component: core, component: dynamo, component: runtime
