feat: TorchTRT Annotation Layer for CUDA generated kernels #4199

Open
bowang007 wants to merge 2 commits into main from tta_cuda_plugin

Conversation

@bowang007
Collaborator

@bowang007 bowang007 commented Apr 21, 2026

Description

This PR introduces torch_tensorrt.annotation, an experimental module for registering hand-written CUDA C++ kernels as both PyTorch custom ops (for eager execution) and TensorRT Quick Deployable Plugins with AOT support (for torch_tensorrt.compile).

Usage

  import torch, torch_tensorrt
  import torch_tensorrt.annotation as tta

  CU = """
  extern "C" __global__ void my_sigmoid(const float* x, int n, float* y) {
      int i = blockIdx.x * blockDim.x + threadIdx.x;
      if (i < n) y[i] = 1.0f / (1.0f + __expf(-x[i]));
  }
  """

  tta.auto_cuda_kernel_plugin(
      "ann_ex::sigmoid",
      tta.KernelSpec(
          kernel_source=CU, kernel_name="my_sigmoid",
          inputs=[tta.InputDecl("x")],
          outputs=[tta.OutputDecl("y", shape=tta.SameAs(0))],
          extras=[tta.Numel("x")],
          geometry=tta.Elementwise(block=(256,), layout="flat"),
      ),
  )

After this call, torch.ops.ann_ex.sigmoid is available in eager and is embedded as a TensorRT plugin during torch_tensorrt.compile. The meta function, eager launch, AOT implementation, and PyTorch schema are all derived from the KernelSpec.
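For reference, the kernel body computes the standard logistic function. A plain-Python equivalent (a hypothetical helper, not part of the module, but handy for sanity-checking the eager op's output) is:

```python
import math

def sigmoid_ref(xs):
    # Reference for the my_sigmoid kernel: y[i] = 1 / (1 + exp(-x[i])).
    # Note the CUDA kernel uses __expf (fast approximate exp), so device
    # results may differ from this reference by a few ULPs.
    return [1.0 / (1.0 + math.exp(-x)) for x in xs]
```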

API Surface

The module exposes two primary entry points, layered by declarativeness:

auto_cuda_kernel_plugin is the recommended default. The caller supplies a KernelSpec dataclass describing the kernel's inputs, outputs (with a shape relation such as SameAs or ReduceDims), scalar extras (Numel, DimSize), and launch geometry (Elementwise or Reduction). The framework derives the meta function, eager CUDA launch, TensorRT AOT implementation, and PyTorch schema. This path covers pointwise kernels (1-D flat or N-D grid launches), reductions (with optional keepdim), multi-input kernels, and scalar (non-tensor) kernel arguments via ScalarInput.
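To illustrate the Elementwise(block=(256,), layout="flat") geometry from the usage example: a 1-D flat launch reduces to a ceiling division over the flattened element count. This is a sketch of the arithmetic only (the function name is hypothetical, not the framework's actual code):

```python
def flat_launch_dims(numel: int, block: int = 256):
    # Ceiling division: enough blocks so every flattened element gets one
    # thread; the kernel's `if (i < n)` guard masks the tail of the last block.
    grid = (numel + block - 1) // block
    return grid, block
```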

manual_cuda_kernel_plugin is the lower-level alternative for kernels outside the declarative DSL: shape-changing outputs, multi-output kernels, or non-standard launch geometries. The caller provides eager_fn and aot_fn directly; the decorator still registers the PyTorch op, TRT plugin, AOT implementation, and converter in a single call.

A Custom(fn=...) geometry is also available for callers who want the declarative path's schema/meta derivation but need to hand-write the TRT KernelLaunchParams.
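To make the shape relations concrete, here is an illustrative pure-Python sketch of what SameAs and a ReduceDims-style relation (with optional keepdim) imply for the derived meta function. The helper names are hypothetical; this mirrors the semantics described above, not the module's actual implementation:

```python
def same_as(input_shapes, index):
    # SameAs(i): the output shape mirrors input i exactly (the sigmoid case).
    return tuple(input_shapes[index])

def reduce_dims(shape, dims, keepdim=False):
    # A ReduceDims-style relation: reduced axes become size 1 (keepdim=True)
    # or are dropped from the output shape entirely (keepdim=False).
    if keepdim:
        return tuple(1 if d in dims else s for d, s in enumerate(shape))
    return tuple(s for d, s in enumerate(shape) if d not in dims)
```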

Type of change

  • New feature (non-breaking change which adds functionality)

Checklist:

  • My code follows the style guidelines of this project (You can use the linters)
  • I have performed a self-review of my own code
  • I have commented my code, particularly in hard-to-understand areas and hacks
  • I have made corresponding changes to the documentation
  • I have added tests to verify my fix or my feature
  • New and existing unit tests pass locally with my changes
  • I have added the relevant labels to my PR so that relevant reviewers are notified

@meta-cla meta-cla Bot added the cla signed label Apr 21, 2026
@github-actions github-actions Bot added component: tests Issues re: Tests component: api [Python] Issues re: Python API labels Apr 21, 2026
@github-actions github-actions Bot requested a review from lanluo-nvidia April 21, 2026 16:55
github-actions[bot]

This comment was marked as outdated.

@bowang007 bowang007 marked this pull request as draft April 21, 2026 16:56
@github-actions github-actions Bot added the component: build system Issues re: Build system label Apr 22, 2026

@github-actions github-actions Bot added the documentation Improvements or additions to documentation label Apr 22, 2026

@bowang007 bowang007 requested a review from narendasan April 22, 2026 18:06
@bowang007 bowang007 marked this pull request as ready for review April 22, 2026 18:09
return (0, torch.cuda.current_stream().cuda_stream)


def _eager_repeat2(x: torch.Tensor) -> torch.Tensor:

This is like "launch_fn" we planned for other kernel providers. Is it possible to keep only this one and infer _aot_repeat2 from this?

Collaborator Author

@BowenFu No, the aot_repeat func is required since it provides the TensorRT symbolic launch metadata.


Marking as resolved.

aot_fn=_aot_repeat2,
supports_dynamic_shapes=True,
)
def _repeat2_meta(x: torch.Tensor) -> torch.Tensor:

Is this a dead fn?

Collaborator Author

@BowenFu This is the FakeTensor function used for shape inference; PyTorch relies on it to infer the output shape.


Marking as resolved.

Collaborator

@narendasan narendasan left a comment


Cool, I think this is getting really close. I think we just have a few naming things to make this more user friendly, and I think we should let users provide PTX directly in addition to the CUDA APIs. Also, did you add nvrtc as an optional dependency in the pyproject.toml (maybe under an extras group called kernels)?

@@ -0,0 +1,156 @@
.. _torch_tensorrt_annotation_py:

torch_tensorrt.annotation
Collaborator

I think this should be called torch_tensorrt.kernels

try:
import cuda.core # noqa: F401
except ImportError:
try:
Collaborator

For example code, just assume all these libraries are there; it makes it easier to read.

kernel_source=CU_SIGMOID,
kernel_name="my_sigmoid",
inputs=[tta.InputDecl("x")],
outputs=[tta.OutputDecl("y", shape=tta.SameAs(0))],
Collaborator

You should be able to use names of other known symbols (like this could be tta.SameAs("x"))

return params, extra


@tta.manual_cuda_kernel_plugin(
Collaborator

What is the distinction between manual and auto from a user perspective?

Collaborator

To me it just seems like the function bodies are just manually configured here, why not just support additional kwargs in the same api?

def _torch_op_already_registered(op_name: str) -> bool:
"""Return True if ``op_name`` is already known to the torch dispatcher."""
try:
return bool(torch._C._jit_get_schemas_for_operator(op_name))
Collaborator

There's probably a more modern approach here (like some API in torch.library)

)


def custom_plugin(
Collaborator

There is now a torch_tensorrt.annotation (kernels).custom_plugin and a torch_tensorrt.dynamo.conversion.plugins.custom_op. Why can't we just centralize on one?

Collaborator

@narendasan narendasan May 1, 2026

Or make it clear what is used for what by disambiguating the names

# Numel("x") pass x.numel() to the kernel as an int extra.
# Elementwise(flat) 1-D launch over the flattened output; any input rank works.

tta.auto_cuda_kernel_plugin(
Collaborator

maybe we can call this something like torch_tensorrt.kernels.cuda_kernel_op

aot_fn=_aot_repeat2,
supports_dynamic_shapes=True,
)
def _repeat2_meta(x: torch.Tensor) -> torch.Tensor:
Collaborator

Why is the meta kernel the one that we decorate? to me the obvious thing to decorate is the jit_impl_fn

Collaborator

@narendasan narendasan May 1, 2026

I think the manual API as a decorator is somewhat confusing; imo we either do the workflow that we already have with multiple decorators (https://docs.pytorch.org/TensorRT/tutorials/_rendered_examples/dynamo/nvrtc_aot_plugin.html) or we don't decorate anything and just have a function that takes all of the kernel source, meta, jit, and aot as arguments.

return
from torch_tensorrt.annotation._custom_plugin._nvrtc import compile_to_ptx

_ptx, device, kernel = compile_to_ptx(
Collaborator

Could we also expose a torch_tensorrt.kernels.ptx_op, that just takes externally created valid ptx?


Labels

cla signed
component: api [Python] Issues re: Python API
component: build system Issues re: Build system
component: tests Issues re: Tests
documentation Improvements or additions to documentation
