feat: TorchTRT Annotation Layer for CUDA generated kernels (#4199)
Conversation
return (0, torch.cuda.current_stream().cuda_stream)

def _eager_repeat2(x: torch.Tensor) -> torch.Tensor:
This is like "launch_fn" we planned for other kernel providers. Is it possible to keep only this one and infer _aot_repeat2 from this?
@BowenFu No, the _aot_repeat2 function is required since it provides the TensorRT symbolic launch metadata.
aot_fn=_aot_repeat2,
supports_dynamic_shapes=True,
)
def _repeat2_meta(x: torch.Tensor) -> torch.Tensor:
@BowenFu These are fake-tensor functions used for shape inference; PyTorch relies on them to infer the output shape.
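The comment above can be made concrete with a tiny shape-only model: a meta (fake-tensor) function never touches data, it only maps input shapes to an output shape so tracing can proceed. The doubling-along-the-last-dim semantics for repeat2 below is a guess for illustration, not taken from the PR:

```python
def repeat2_meta_shape(shape: tuple) -> tuple:
    # A meta/fake-tensor function computes only the output shape;
    # no kernel is launched and no data is read.
    # Assumed semantics: repeat2 tiles the last dimension twice.
    return (*shape[:-1], 2 * shape[-1])

assert repeat2_meta_shape((3, 4)) == (3, 8)
```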
narendasan left a comment
Cool, I think this is getting really close. We just have a few naming things to make this more user friendly, and I think we should let users provide PTX directly in addition to the CUDA APIs. Also, did you add nvrtc as an optional dependency in the pyproject.toml (maybe under an extras group called kernels)?
@@ -0,0 +1,156 @@
.. _torch_tensorrt_annotation_py:

torch_tensorrt.annotation
I think this should be called torch_tensorrt.kernels
try:
    import cuda.core  # noqa: F401
except ImportError:
    try:
For example code, just assume all these libraries are there; it makes it easier to read.
kernel_source=CU_SIGMOID,
kernel_name="my_sigmoid",
inputs=[tta.InputDecl("x")],
outputs=[tta.OutputDecl("y", shape=tta.SameAs(0))],
You should be able to use names of other known symbols (e.g. this could be tta.SameAs("x")).
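A minimal sketch of how that suggestion could work: SameAs accepting either a positional index or a declared input name, with the two resolving to the same input. The class and field names below mirror the snippet above, but the resolution logic is an assumption, not the PR's implementation:

```python
from dataclasses import dataclass
from typing import Union

@dataclass(frozen=True)
class InputDecl:
    name: str

@dataclass(frozen=True)
class SameAs:
    ref: Union[int, str]  # positional index or declared input name

    def resolve(self, inputs) -> int:
        """Resolve the reference to a positional input index."""
        if isinstance(self.ref, int):
            return self.ref
        for i, decl in enumerate(inputs):
            if decl.name == self.ref:
                return i
        raise KeyError(f"unknown input name: {self.ref!r}")

inputs = [InputDecl("x")]
assert SameAs(0).resolve(inputs) == SameAs("x").resolve(inputs) == 0
```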
return params, extra

@tta.manual_cuda_kernel_plugin(
What is the distinction between manual and auto from a user perspective?
To me it just seems like the function bodies are manually configured here; why not just support additional kwargs in the same API?
def _torch_op_already_registered(op_name: str) -> bool:
    """Return True if ``op_name`` is already known to the torch dispatcher."""
    try:
        return bool(torch._C._jit_get_schemas_for_operator(op_name))
There's probably a more modern approach here (like some API in torch.library).
)

def custom_plugin(
There is now a torch_tensorrt.annotation (kernels) custom_plugin and a torch_tensorrt.dynamo.conversion.plugins.custom_op. Why can't we just centralize on one?
Or make it clear what is used for what by disambiguating the names
# Numel("x")        pass x.numel() to the kernel as an int extra.
# Elementwise(flat) 1-D launch over the flattened output; any input rank works.

tta.auto_cuda_kernel_plugin(
maybe we can call this something like torch_tensorrt.kernels.cuda_kernel_op
aot_fn=_aot_repeat2,
supports_dynamic_shapes=True,
)
def _repeat2_meta(x: torch.Tensor) -> torch.Tensor:
Why is the meta kernel the one that we decorate? To me the obvious thing to decorate is the jit_impl_fn.
I think the manual API as a decorator is somewhat confusing. IMO we either do the workflow that we already have with multiple decorators (https://docs.pytorch.org/TensorRT/tutorials/_rendered_examples/dynamo/nvrtc_aot_plugin.html) or we don't decorate anything and just have a function that takes all of kernel source, meta, jit, and aot as arguments.
return
from torch_tensorrt.annotation._custom_plugin._nvrtc import compile_to_ptx

_ptx, device, kernel = compile_to_ptx(
Could we also expose a torch_tensorrt.kernels.ptx_op that just takes externally created, valid PTX?
Description
This PR introduces torch_tensorrt.annotation, an experimental module for registering hand-written CUDA C++ kernels as both PyTorch custom ops (for eager execution) and TensorRT Quick Deployable Plugins with AOT support (for torch_tensorrt.compile).
Usage
After this call, torch.ops.ann_ex.sigmoid is available in eager and is embedded as a TensorRT plugin during torch_tensorrt.compile. The meta function, eager launch, AOT implementation, and PyTorch schema are all derived from the KernelSpec.
API Surface
The module exposes two primary entry points, layered by declarativeness:
auto_cuda_kernel_plugin is the recommended default. The caller supplies a KernelSpec dataclass describing the kernel's inputs, outputs (with a shape relation such as SameAs or ReduceDims), scalar extras (Numel, DimSize), and launch geometry (Elementwise or Reduction). The framework derives the meta function, eager CUDA launch, TensorRT AOT implementation, and PyTorch schema. This path covers pointwise kernels (1-D flat or N-D grid launches), reductions (with optional keepdim), multi-input kernels, and scalar (non-tensor) kernel arguments via ScalarInput.
manual_cuda_kernel_plugin is the lower-level alternative for kernels outside the declarative DSL: shape-changing outputs, multi-output kernels, or non-standard launch geometries. The caller provides eager_fn and aot_fn directly; the decorator still registers the PyTorch op, TRT plugin, AOT implementation, and converter in a single call.
A Custom(fn=...) geometry is also available for callers who want the declarative path's schema/meta derivation but need to hand-write the TRT KernelLaunchParams.
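For intuition, a hand-written launch-params function for the elementwise case is essentially the classic CUDA grid-size computation: pick a block size and ceil-divide the flattened element count. A minimal sketch, where KernelLaunchParams is a stand-in dataclass, not TensorRT's actual type:

```python
from dataclasses import dataclass
from math import prod

@dataclass
class KernelLaunchParams:  # stand-in for the TRT-side structure
    grid: tuple
    block: tuple

def elementwise_launch_params(out_shape, block_x: int = 256) -> KernelLaunchParams:
    numel = prod(out_shape)
    grid_x = (numel + block_x - 1) // block_x  # ceil division
    return KernelLaunchParams(grid=(grid_x, 1, 1), block=(block_x, 1, 1))

p = elementwise_launch_params((2, 3, 100))  # 600 elements
assert p.grid == (3, 1, 1) and p.block == (256, 1, 1)
```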