diff --git a/cuda_core/build_hooks.py b/cuda_core/build_hooks.py index 444da18eb13..f4fb4af01f7 100644 --- a/cuda_core/build_hooks.py +++ b/cuda_core/build_hooks.py @@ -186,6 +186,9 @@ def get_sources(mod_name): # On Windows, _tensor_bridge.pyx needs a stub import library so the MSVC # linker can resolve the AOTI symbols (they live in torch_cpu.dll at # runtime). We generate the .lib from a .def file at build time. + # Note: aoti_torch_get_current_cuda_stream lives in torch_cuda.dll and + # is resolved lazily at runtime (not via the stub lib) — see + # _tensor_bridge.pyx. _aoti_extra_link_args = [] if sys.platform == "win32": _def_file = os.path.join("cuda", "core", "_include", "aoti_shim.def") diff --git a/cuda_core/cuda/core/_include/aoti_shim.def b/cuda_core/cuda/core/_include/aoti_shim.def index 5cc6897e815..e21097bd25e 100644 --- a/cuda_core/cuda/core/_include/aoti_shim.def +++ b/cuda_core/cuda/core/_include/aoti_shim.def @@ -34,4 +34,3 @@ EXPORTS aoti_torch_get_device_index aoti_torch_device_type_cpu aoti_torch_device_type_cuda - aoti_torch_get_current_cuda_stream diff --git a/cuda_core/cuda/core/_include/aoti_shim.h b/cuda_core/cuda/core/_include/aoti_shim.h index 809bdb1a2a6..464d27de46c 100644 --- a/cuda_core/cuda/core/_include/aoti_shim.h +++ b/cuda_core/cuda/core/_include/aoti_shim.h @@ -52,10 +52,13 @@ typedef struct AtenTensorOpaque* AtenTensorHandle; /* * IMPORTANT: Keep the AOTI_SHIM_API declaration list below in sync with - * aoti_shim.def. On Windows, build_hooks.py turns that .def file into the + * aoti_shim.def. On Windows, build_hooks.py turns that .def file into the * stub import library that MSVC needs to link _tensor_bridge without making - * PyTorch a build-time dependency. If you add, remove, or rename an imported - * AOTI symbol here, update aoti_shim.def in the same change. + * PyTorch a build-time dependency. If you add, remove, or rename an + * imported AOTI symbol here, update aoti_shim.def in the same change. 
+ * + * Exception: aoti_torch_get_current_cuda_stream lives in torch_cuda (not + * torch_cpu) and is resolved lazily at runtime — see _tensor_bridge.pyx. */ /* ---- tensor metadata --------------------------------------------------- */ @@ -105,10 +108,11 @@ AOTI_SHIM_API AOTITorchError aoti_torch_get_device_index( AOTI_SHIM_API int32_t aoti_torch_device_type_cpu(void); AOTI_SHIM_API int32_t aoti_torch_device_type_cuda(void); -/* ---- stream -------------------------------------------------------------- */ - -AOTI_SHIM_API AOTITorchError aoti_torch_get_current_cuda_stream( - int32_t device_index, void** ret_stream); +/* ---- stream -------------------------------------------------------------- + * aoti_torch_get_current_cuda_stream is NOT declared here — it lives in + * torch_cuda (not torch_cpu) and is resolved at runtime. See the inline + * C helper _resolve_cuda_stream_fn() in _tensor_bridge.pyx. + * ---------------------------------------------------------------------- */ #ifdef __cplusplus } /* extern "C" */ diff --git a/cuda_core/cuda/core/_tensor_bridge.pyx b/cuda_core/cuda/core/_tensor_bridge.pyx index 388ca738dbb..07eec56537b 100644 --- a/cuda_core/cuda/core/_tensor_bridge.pyx +++ b/cuda_core/cuda/core/_tensor_bridge.pyx @@ -103,8 +103,38 @@ cdef extern from "_include/aoti_shim.h": int32_t aoti_torch_device_type_cpu() int32_t aoti_torch_device_type_cuda() - # stream - AOTITorchError aoti_torch_get_current_cuda_stream(int32_t, void**) + # Note: aoti_torch_get_current_cuda_stream is NOT declared here because + # it lives in torch_cuda.dll (not torch_cpu.dll). It is resolved lazily + # at runtime via dlsym / GetProcAddress — see _resolve_cuda_stream_fn(). + +# Runtime resolution for aoti_torch_get_current_cuda_stream. +# This symbol lives in torch_cuda.dll (Windows) / libtorch_cuda.so (Linux), +# NOT in torch_cpu. We resolve it lazily on first use so that the module +# can be imported even with CPU-only PyTorch. 
+ctypedef AOTITorchError (*_get_cuda_stream_fn_t)(int32_t, void**) nogil + +cdef extern from *: + """ + #ifdef _WIN32 + #include <windows.h> + static void* _resolve_cuda_stream_fn(void) { + HMODULE h = LoadLibraryA("torch_cuda.dll"); + if (!h) return NULL; + return (void*)GetProcAddress(h, "aoti_torch_get_current_cuda_stream"); + } + #else + #include <dlfcn.h> + #ifndef RTLD_DEFAULT + #define RTLD_DEFAULT ((void*)0) + #endif + static void* _resolve_cuda_stream_fn(void) { + return dlsym(RTLD_DEFAULT, "aoti_torch_get_current_cuda_stream"); + } + #endif + """ + void* _resolve_cuda_stream_fn() nogil + +cdef _get_cuda_stream_fn_t _cached_get_cuda_stream = NULL import numpy import sys @@ -274,10 +304,17 @@ cpdef int sync_torch_stream(int32_t device_index, the consumer stream wait on it. This is a no-op if both streams are the same. """ + global _cached_get_cuda_stream cdef void* producer_s cdef EventHandle h_event - check_aoti(aoti_torch_get_current_cuda_stream(device_index, &producer_s), + if _cached_get_cuda_stream == NULL: + _cached_get_cuda_stream = <_get_cuda_stream_fn_t>_resolve_cuda_stream_fn() + if _cached_get_cuda_stream == NULL: + raise RuntimeError( + "Cannot resolve aoti_torch_get_current_cuda_stream from " + "torch_cuda — is CUDA-enabled PyTorch installed?") + check_aoti(_cached_get_cuda_stream(device_index, &producer_s), + b"aoti_torch_get_current_cuda_stream") if producer_s != consumer_s: with nogil: