From 0b63c0010259412f183b0d0d8eb1ce899a927918 Mon Sep 17 00:00:00 2001 From: Andy Jost Date: Tue, 28 Apr 2026 09:44:15 -0700 Subject: [PATCH 1/4] cuda.core: convert GraphBuilder to cdef class with explicit state machine Refactor GraphBuilder from a Python class using _MembersNeededForFinalize to a cdef class with explicit _BuilderKind (PRIMARY/FORKED/CONDITIONAL_BODY) and _CaptureState (NOT_STARTED/CAPTURING/ENDED) tracking. Cleanup moves into __dealloc__/close, and the builder now uses GraphHandle/StreamHandle from _resource_handles instead of holding raw driver objects. Drop the is_stream_owner flag now that StreamHandle owns the lifetime. End-capture paths in __dealloc__ and close guard on _h_stream so cleanup is safe even if _init* fails before completing assignment. Made-with: Cursor --- cuda_core/cuda/core/_device.pyx | 5 +- cuda_core/cuda/core/_stream.pyx | 7 +- cuda_core/cuda/core/graph/_graph_builder.pxd | 19 ++ cuda_core/cuda/core/graph/_graph_builder.pyx | 331 +++++++++++-------- 4 files changed, 208 insertions(+), 154 deletions(-) create mode 100644 cuda_core/cuda/core/graph/_graph_builder.pxd diff --git a/cuda_core/cuda/core/_device.pyx b/cuda_core/cuda/core/_device.pyx index c0d7f09ee4..d9776a72e8 100644 --- a/cuda_core/cuda/core/_device.pyx +++ b/cuda_core/cuda/core/_device.pyx @@ -14,6 +14,7 @@ import threading from cuda.core._context cimport Context from cuda.core._context import ContextOptions from cuda.core._event cimport Event as cyEvent +from cuda.core.graph._graph_builder cimport GraphBuilder from cuda.core._event import Event, EventOptions from cuda.core._memory._buffer cimport Buffer, MemoryResource from cuda.core._resource_handles cimport ( @@ -1370,10 +1371,8 @@ class Device: Newly created graph builder object. """ - from cuda.core.graph._graph_builder import GraphBuilder - self._check_context_initialized() - return GraphBuilder._init(stream=self.create_stream(), is_stream_owner=True) + return GraphBuilder._init(self.create_stream()) cdef inline int Device_ensure_cuda_initialized() except? -1: diff --git a/cuda_core/cuda/core/_stream.pyx b/cuda_core/cuda/core/_stream.pyx index fdb617f032..c7b1312c17 100644 --- a/cuda_core/cuda/core/_stream.pyx +++ b/cuda_core/cuda/core/_stream.pyx @@ -10,6 +10,7 @@ from libc.stdlib cimport strtol, getenv from cuda.bindings cimport cydriver from cuda.core._event cimport Event as cyEvent +from cuda.core.graph._graph_builder cimport GraphBuilder from cuda.core._utils.cuda_utils cimport ( check_or_create_options, HANDLE_RETURN, @@ -371,9 +372,7 @@ cdef class Stream: Newly created graph builder object. """ - from cuda.core.graph._graph_builder import GraphBuilder - - return GraphBuilder._init(stream=self, is_stream_owner=False) + return GraphBuilder._init(self) # c-only python objects, not public @@ -474,8 +473,6 @@ cdef cydriver.CUstream _handle_from_stream_protocol(obj) except*: # Helper for API functions that accept either Stream or GraphBuilder. Performs # needed checks and returns the relevant stream. cdef Stream Stream_accept(arg, bint allow_stream_protocol=False): - from cuda.core.graph._graph_builder import GraphBuilder - if isinstance(arg, Stream): return (arg) elif isinstance(arg, GraphBuilder): diff --git a/cuda_core/cuda/core/graph/_graph_builder.pxd b/cuda_core/cuda/core/graph/_graph_builder.pxd new file mode 100644 index 0000000000..e224f3a510 --- /dev/null +++ b/cuda_core/cuda/core/graph/_graph_builder.pxd @@ -0,0 +1,19 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# SPDX-License-Identifier: Apache-2.0 + +from cuda.core._resource_handles cimport GraphHandle, StreamHandle +from cuda.core._stream cimport Stream + + +cdef class GraphBuilder: + cdef: + GraphHandle _h_graph + StreamHandle _h_stream + int _kind + int _state + Stream _stream # cached to avoid reconstruction from _h_stream handle + object __weakref__ + + @staticmethod + cdef GraphBuilder _init(Stream stream) diff --git a/cuda_core/cuda/core/graph/_graph_builder.pyx b/cuda_core/cuda/core/graph/_graph_builder.pyx index 526c95e04a..61629dcc2e 100644 --- a/cuda_core/cuda/core/graph/_graph_builder.pyx +++ b/cuda_core/cuda/core/graph/_graph_builder.pyx @@ -11,7 +11,10 @@ from cuda.bindings cimport cydriver from cuda.core.graph._graph_definition cimport GraphCondition from cuda.core.graph._utils cimport _attach_host_callback_to_graph -from cuda.core._resource_handles cimport as_cu +from cuda.core._resource_handles cimport ( + GraphHandle, StreamHandle, as_cu, as_py, + create_graph_handle, create_graph_handle_ref, +) from cuda.core._stream cimport Stream from cuda.core._utils.cuda_utils cimport HANDLE_RETURN from cuda.core._utils.version cimport cy_binding_version, cy_driver_version @@ -185,7 +188,40 @@ def _instantiate_graph(h_graph, options: GraphCompleteOptions | None = None) -> return graph -class GraphBuilder: +# Distinguishes the three kinds of GraphBuilder, which differ in how they +# begin/end stream capture and whether they own the resulting CUgraph. +# Each kind progresses through _CaptureState as follows: +# +# PRIMARY: NOT_STARTED -> CAPTURING -> ENDED +# FORKED: CAPTURING (never transitions; joined and closed) +# CONDITIONAL_BODY: NOT_STARTED -> CAPTURING -> ENDED +# +cdef enum _BuilderKind: + # PRIMARY: The top-level builder created by Device or Stream. Owns the + # captured CUgraph via an owning GraphHandle. Progresses through all three + # capture states; responsible for ending capture if destroyed early. + PRIMARY = 0 + # FORKED: Created by split(). Captures on a private stream forked from the + # primary. Starts in CAPTURING state and never transitions; the user joins + # it back to the primary via join(), which closes the builder. Must NOT + # call cuStreamEndCapture (the driver requires all forked streams to be + # joined first). + FORKED = 1 + # CONDITIONAL_BODY: Created by if_then/if_else/switch/while_loop. Captures + # into a non-owned body graph via cuStreamBeginCaptureToGraph. The body + # graph's lifetime is tied to a parent graph. Progresses through all three + # capture states like PRIMARY. + CONDITIONAL_BODY = 2 + + +# Tracks the capture lifecycle of a GraphBuilder. +cdef enum _CaptureState: + CAPTURE_NOT_STARTED = 0 + CAPTURING = 1 + CAPTURE_ENDED = 2 + + +cdef class GraphBuilder: """A graph under construction by stream capture. A graph groups a set of CUDA kernels and other CUDA operations together and executes @@ -198,63 +234,48 @@ class GraphBuilder: """ - class _MembersNeededForFinalize: - __slots__ = ("conditional_graph", "graph", "is_join_required", "is_stream_owner", "stream") - - def __init__(self, graph_builder_obj, stream_obj, is_stream_owner, conditional_graph, is_join_required): - self.stream = stream_obj - self.is_stream_owner = is_stream_owner - self.graph = None - self.conditional_graph = conditional_graph - self.is_join_required = is_join_required - weakref.finalize(graph_builder_obj, self.close) - - def close(self): - if self.stream: - if not self.is_join_required: - capture_status = handle_return(driver.cuStreamGetCaptureInfo(self.stream.handle))[0] - if capture_status != driver.CUstreamCaptureStatus.CU_STREAM_CAPTURE_STATUS_NONE: - # Note how this condition only occures for the primary graph builder - # This is because calling cuStreamEndCapture streams that were split off of the primary - # would error out with CUDA_ERROR_STREAM_CAPTURE_UNJOINED. - # Therefore, it is currently a requirement that users join all split graph builders - # before a graph builder can be clearly destroyed. - handle_return(driver.cuStreamEndCapture(self.stream.handle)) - if self.is_stream_owner: - self.stream.close() - self.stream = None - if self.graph: - handle_return(driver.cuGraphDestroy(self.graph)) - self.graph = None - self.conditional_graph = None - - __slots__ = ("__weakref__", "_building_ended", "_mnff") - def __init__(self): raise NotImplementedError( "directly creating a Graph object can be ambiguous. Please either " "call Device.create_graph_builder() or stream.create_graph_builder()" ) - @classmethod - def _init(cls, stream, is_stream_owner, conditional_graph=None, is_join_required=False): - self = cls.__new__(cls) - self._mnff = GraphBuilder._MembersNeededForFinalize( - self, stream, is_stream_owner, conditional_graph, is_join_required - ) + def __dealloc__(self): + # Note: _stream could be set to None by cyclic-GC tp_clear before + # __dealloc__, but _h_stream is guaranteed to be valid. + if self._h_stream and self._state == CAPTURING and self._kind != FORKED: + with nogil: + cydriver.cuStreamEndCapture(as_cu(self._h_stream), NULL) - self._building_ended = False + @staticmethod + cdef GraphBuilder _init(Stream stream): + cdef GraphBuilder self = GraphBuilder.__new__(GraphBuilder) + # _h_graph set by begin_building + self._h_stream = stream._h_stream + self._kind = PRIMARY + self._state = CAPTURE_NOT_STARTED + self._stream = stream return self + def close(self): + """Destroy the graph builder.""" + if self._h_stream and self._state == CAPTURING and self._kind != FORKED: + with nogil: + HANDLE_RETURN(cydriver.cuStreamEndCapture(as_cu(self._h_stream), NULL)) + self._h_graph.reset() + self._h_stream.reset() + self._state = CAPTURE_ENDED + self._stream = None + @property def stream(self) -> Stream: """Returns the stream associated with the graph builder.""" - return self._mnff.stream + return self._stream @property def is_join_required(self) -> bool: """Returns True if this graph builder must be joined before building is ended.""" - return self._mnff.is_join_required + return self._kind == FORKED def begin_building(self, mode="relaxed") -> GraphBuilder: """Begins the building process. @@ -272,61 +293,65 @@ class GraphBuilder: Default set to use relaxed. """ - if self._building_ended: - raise RuntimeError("Cannot resume building after building has ended.") - if mode not in ("global", "thread_local", "relaxed"): - raise ValueError(f"Unsupported build mode: {mode}") + if self._state != CAPTURE_NOT_STARTED: + if self._state == CAPTURING: + raise RuntimeError("Graph builder is already building.") + else: + raise RuntimeError("Cannot resume building after building has ended.") + cdef cydriver.CUstreamCaptureMode c_mode if mode == "global": - capture_mode = driver.CUstreamCaptureMode.CU_STREAM_CAPTURE_MODE_GLOBAL + c_mode = cydriver.CU_STREAM_CAPTURE_MODE_GLOBAL elif mode == "thread_local": - capture_mode = driver.CUstreamCaptureMode.CU_STREAM_CAPTURE_MODE_THREAD_LOCAL + c_mode = cydriver.CU_STREAM_CAPTURE_MODE_THREAD_LOCAL elif mode == "relaxed": - capture_mode = driver.CUstreamCaptureMode.CU_STREAM_CAPTURE_MODE_RELAXED + c_mode = cydriver.CU_STREAM_CAPTURE_MODE_RELAXED else: raise ValueError(f"Unsupported build mode: {mode}") - if self._mnff.conditional_graph: - handle_return( - driver.cuStreamBeginCaptureToGraph( - self._mnff.stream.handle, - self._mnff.conditional_graph, - None, # dependencies - None, # dependencyData - 0, # numDependencies - capture_mode, - ) - ) + cdef cydriver.CUstream c_stream = as_cu(self._h_stream) + cdef cydriver.CUgraph c_graph + if self._kind == CONDITIONAL_BODY: + c_graph = as_cu(self._h_graph) + with nogil: + HANDLE_RETURN(cydriver.cuStreamBeginCaptureToGraph( + c_stream, c_graph, NULL, NULL, 0, c_mode)) else: - handle_return(driver.cuStreamBeginCapture(self._mnff.stream.handle, capture_mode)) + with nogil: + HANDLE_RETURN(cydriver.cuStreamBeginCapture(c_stream, c_mode)) + _get_capture_info(c_stream, NULL, &c_graph) + self._h_graph = create_graph_handle(c_graph) + self._state = CAPTURING return self @property def is_building(self) -> bool: """Returns True if the graph builder is currently building.""" - capture_status = handle_return(driver.cuStreamGetCaptureInfo(self._mnff.stream.handle))[0] - if capture_status == driver.CUstreamCaptureStatus.CU_STREAM_CAPTURE_STATUS_NONE: + cdef cydriver.CUstream c_stream = as_cu(self._h_stream) + cdef cydriver.CUstreamCaptureStatus status + with nogil: + _get_capture_info(c_stream, &status, NULL) + if status == cydriver.CU_STREAM_CAPTURE_STATUS_NONE: return False - elif capture_status == driver.CUstreamCaptureStatus.CU_STREAM_CAPTURE_STATUS_ACTIVE: + elif status == cydriver.CU_STREAM_CAPTURE_STATUS_ACTIVE: return True - elif capture_status == driver.CUstreamCaptureStatus.CU_STREAM_CAPTURE_STATUS_INVALIDATED: + elif status == cydriver.CU_STREAM_CAPTURE_STATUS_INVALIDATED: raise RuntimeError( "Build process encountered an error and has been invalidated. Build process must now be ended." ) else: - raise NotImplementedError(f"Unsupported capture status type received: {capture_status}") + raise NotImplementedError(f"Unsupported capture status type received: {status}") def end_building(self) -> GraphBuilder: """Ends the building process.""" if not self.is_building: raise RuntimeError("Graph builder is not building.") - if self._mnff.conditional_graph: - self._mnff.conditional_graph = handle_return(driver.cuStreamEndCapture(self.stream.handle)) - else: - self._mnff.graph = handle_return(driver.cuStreamEndCapture(self.stream.handle)) + cdef cydriver.CUstream c_stream = as_cu(self._h_stream) + with nogil: + HANDLE_RETURN(cydriver.cuStreamEndCapture(c_stream, NULL)) # TODO: Resolving https://github.com/NVIDIA/cuda-python/issues/617 would allow us to # resume the build process after the first call to end_building() - self._building_ended = True + self._state = CAPTURE_ENDED return self def complete(self, options: GraphCompleteOptions | None = None) -> "Graph": @@ -343,10 +368,10 @@ class GraphBuilder: The newly built graph. """ - if not self._building_ended: + if self._state != CAPTURE_ENDED: raise RuntimeError("Graph has not finished building.") - return _instantiate_graph(self._mnff.graph, options) + return _instantiate_graph(as_py(self._h_graph), options) def debug_dot_print(self, path, options: GraphDebugPrintOptions | None = None): """Generates a DOT debug file for the graph builder. @@ -359,10 +384,14 @@ class GraphBuilder: Customizable dataclass for the debug print options. """ - if not self._building_ended: + if self._state != CAPTURE_ENDED: raise RuntimeError("Graph has not finished building.") - flags = options._to_flags() if options else 0 - handle_return(driver.cuGraphDebugDotPrint(self._mnff.graph, path, flags)) + cdef unsigned int c_flags = options._to_flags() if options else 0 + cdef cydriver.CUgraph c_graph = as_cu(self._h_graph) + cdef bytes b_path = path.encode() if isinstance(path, str) else path + cdef const char* c_path = b_path + with nogil: + HANDLE_RETURN(cydriver.cuGraphDebugDotPrint(c_graph, c_path, c_flags)) def split(self, count: int) -> tuple[GraphBuilder, ...]: """Splits the original graph builder into multiple graph builders. @@ -385,14 +414,12 @@ class GraphBuilder: if count < 2: raise ValueError(f"Invalid split count: expecting >= 2, got {count}") - event = self._mnff.stream.record() + event = self._stream.record() result = [self] for i in range(count - 1): - stream = self._mnff.stream.device.create_stream() + stream = self._stream.device.create_stream() stream.wait(event) - result.append( - GraphBuilder._init(stream=stream, is_stream_owner=True, conditional_graph=None, is_join_required=True) - ) + result.append(_init_forked(stream)) event.close() return tuple(result) @@ -440,7 +467,7 @@ class GraphBuilder: return self.stream.__cuda_stream__() def _get_conditional_context(self) -> driver.CUcontext: - return self._mnff.stream.context.handle + return self._stream.context.handle def create_condition(self, default_value=None) -> GraphCondition: """Create a condition variable for use with conditional nodes. @@ -471,7 +498,7 @@ class GraphBuilder: default_value = 0 flags = 0 - status, _, graph, *_, _ = handle_return(driver.cuStreamGetCaptureInfo(self._mnff.stream.handle)) + status, _, graph, *_, _ = handle_return(driver.cuStreamGetCaptureInfo(self._stream.handle)) if status != driver.CUstreamCaptureStatus.CU_STREAM_CAPTURE_STATUS_ACTIVE: raise RuntimeError("Cannot create a condition when graph is not being built") @@ -480,42 +507,6 @@ class GraphBuilder: ) return GraphCondition._from_handle(int(raw_handle)) - def _cond_with_params(self, node_params) -> tuple: - # Get current capture info to ensure we're in a valid state - status, _, graph, *deps_info, num_dependencies = handle_return( - driver.cuStreamGetCaptureInfo(self._mnff.stream.handle) - ) - if status != driver.CUstreamCaptureStatus.CU_STREAM_CAPTURE_STATUS_ACTIVE: - raise RuntimeError("Cannot add conditional node when not actively capturing") - - # Add the conditional node to the graph - deps_info_update = [ - [handle_return(driver.cuGraphAddNode(graph, *deps_info, num_dependencies, node_params))] - ] + [None] * (len(deps_info) - 1) - - # Update the stream's capture dependencies - handle_return( - driver.cuStreamUpdateCaptureDependencies( - self._mnff.stream.handle, - *deps_info_update, # dependencies, edgeData - 1, # numDependencies - driver.CUstreamUpdateCaptureDependencies_flags.CU_STREAM_SET_CAPTURE_DEPENDENCIES, - ) - ) - - # Create new graph builders for each condition - return tuple( - [ - GraphBuilder._init( - stream=self._mnff.stream.device.create_stream(), - is_stream_owner=True, - conditional_graph=node_params.conditional.phGraph_out[i], - is_join_required=False, - ) - for i in range(node_params.conditional.size) - ] - ) - def if_then(self, condition: GraphCondition) -> GraphBuilder: """Adds an if condition branch and returns a new graph builder for it. @@ -550,7 +541,7 @@ class GraphBuilder: node_params.conditional.type = driver.CUgraphConditionalNodeType.CU_GRAPH_COND_TYPE_IF node_params.conditional.size = 1 node_params.conditional.ctx = self._get_conditional_context() - return self._cond_with_params(node_params)[0] + return _cond_with_params(self, node_params)[0] def if_else(self, condition: GraphCondition) -> tuple[GraphBuilder, GraphBuilder]: """Adds an if-else condition branch and returns new graph builders for both branches. @@ -586,7 +577,7 @@ class GraphBuilder: node_params.conditional.type = driver.CUgraphConditionalNodeType.CU_GRAPH_COND_TYPE_IF node_params.conditional.size = 2 node_params.conditional.ctx = self._get_conditional_context() - return self._cond_with_params(node_params) + return _cond_with_params(self, node_params) def switch(self, condition: GraphCondition, count: int) -> tuple[GraphBuilder, ...]: """Adds a switch condition branch and returns new graph builders for all cases. @@ -625,7 +616,7 @@ class GraphBuilder: node_params.conditional.type = driver.CUgraphConditionalNodeType.CU_GRAPH_COND_TYPE_SWITCH node_params.conditional.size = count node_params.conditional.ctx = self._get_conditional_context() - return self._cond_with_params(node_params) + return _cond_with_params(self, node_params) def while_loop(self, condition: GraphCondition) -> GraphBuilder: """Adds a while loop and returns a new graph builder for it. @@ -661,18 +652,9 @@ class GraphBuilder: node_params.conditional.type = driver.CUgraphConditionalNodeType.CU_GRAPH_COND_TYPE_WHILE node_params.conditional.size = 1 node_params.conditional.ctx = self._get_conditional_context() - return self._cond_with_params(node_params)[0] - - def close(self): - """Destroy the graph builder. + return _cond_with_params(self, node_params)[0] - Closes the associated stream if we own it. Borrowed stream - object will instead have their references released. - - """ - self._mnff.close() - - def embed(self, child: GraphBuilder): + def embed(self, GraphBuilder child): """Embed a previously-built :obj:`~graph.GraphBuilder` as a child node. Parameters @@ -680,13 +662,13 @@ class GraphBuilder: child : :obj:`~graph.GraphBuilder` The child graph builder. Must have finished building. """ - if not child._building_ended: + if child._state != CAPTURE_ENDED: raise ValueError("Child graph has not finished building.") if not self.is_building: raise ValueError("Parent graph is not being built.") - stream_handle = self._mnff.stream.handle + stream_handle = self._stream.handle _, _, graph_out, *deps_info_out, num_dependencies_out = handle_return( driver.cuStreamGetCaptureInfo(stream_handle) ) @@ -698,7 +680,7 @@ class GraphBuilder: [ handle_return( driver.cuGraphAddChildGraphNode( - graph_out, *deps_info_trimmed, num_dependencies_out, child._mnff.graph + graph_out, *deps_info_trimmed, num_dependencies_out, as_py(child._h_graph) ) ) ] @@ -740,18 +722,13 @@ class GraphBuilder: pointer (caller manages lifetime). If bytes-like, the data is copied and its lifetime is tied to the graph. """ - cdef Stream stream = self._mnff.stream + cdef Stream stream = self._stream cdef cydriver.CUstream c_stream = as_cu(stream._h_stream) cdef cydriver.CUstreamCaptureStatus capture_status cdef cydriver.CUgraph c_graph = NULL with nogil: - IF CUDA_CORE_BUILD_MAJOR >= 13: - HANDLE_RETURN(cydriver.cuStreamGetCaptureInfo( - c_stream, &capture_status, NULL, &c_graph, NULL, NULL, NULL)) - ELSE: - HANDLE_RETURN(cydriver.cuStreamGetCaptureInfo( - c_stream, &capture_status, NULL, &c_graph, NULL, NULL)) + _get_capture_info(c_stream, &capture_status, &c_graph) if capture_status != cydriver.CU_STREAM_CAPTURE_STATUS_ACTIVE: raise RuntimeError("Cannot add callback when graph is not being built") @@ -764,6 +741,68 @@ class GraphBuilder: HANDLE_RETURN(cydriver.cuLaunchHostFunc(c_stream, c_fn, c_user_data)) +cdef inline GraphBuilder _init_forked(Stream stream): + cdef GraphBuilder gb = GraphBuilder.__new__(GraphBuilder) + # _h_graph not used for FORKED builders. Captures to primary graph. + gb._h_stream = stream._h_stream + gb._kind = FORKED + gb._state = CAPTURING + gb._stream = stream + return gb + + +cdef inline GraphBuilder _init_conditional(Stream stream, cydriver.CUgraph cond_graph, GraphBuilder parent): + cdef GraphBuilder gb = GraphBuilder.__new__(GraphBuilder) + gb._h_graph = create_graph_handle_ref(cond_graph, parent._h_graph) + gb._h_stream = stream._h_stream + gb._kind = CONDITIONAL_BODY + gb._state = CAPTURE_NOT_STARTED + gb._stream = stream + return gb + + +cdef inline int _get_capture_info( + cydriver.CUstream stream, + cydriver.CUstreamCaptureStatus* status, + cydriver.CUgraph* graph) except?-1 nogil: + IF CUDA_CORE_BUILD_MAJOR >= 13: + return HANDLE_RETURN(cydriver.cuStreamGetCaptureInfo( + stream, status, NULL, graph, NULL, NULL, NULL)) + ELSE: + return HANDLE_RETURN(cydriver.cuStreamGetCaptureInfo( + stream, status, NULL, graph, NULL, NULL)) + + +cdef inline tuple _cond_with_params(GraphBuilder gb, node_params): + status, _, graph, *deps_info, num_dependencies = handle_return( + driver.cuStreamGetCaptureInfo(gb._stream.handle) + ) + if status != driver.CUstreamCaptureStatus.CU_STREAM_CAPTURE_STATUS_ACTIVE: + raise RuntimeError("Cannot add conditional node when not actively capturing") + + deps_info_update = [ + [handle_return(driver.cuGraphAddNode(graph, *deps_info, num_dependencies, node_params))] + ] + [None] * (len(deps_info) - 1) + + handle_return( + driver.cuStreamUpdateCaptureDependencies( + gb._stream.handle, + *deps_info_update, # dependencies, edgeData + 1, # numDependencies + driver.CUstreamUpdateCaptureDependencies_flags.CU_STREAM_SET_CAPTURE_DEPENDENCIES, + ) + ) + + return tuple( + _init_conditional( + gb._stream.device.create_stream(), + int(node_params.conditional.phGraph_out[i]), + gb, + ) + for i in range(node_params.conditional.size) + ) + + class Graph: """An executable graph. @@ -832,9 +871,9 @@ class Graph: cdef cydriver.CUgraphExec cu_exec = int(self._mnff.graph) if isinstance(source, GraphBuilder): - if not source._building_ended: + if (source)._state != CAPTURE_ENDED: raise ValueError("Graph has not finished building.") - cu_graph = int(source._mnff.graph) + cu_graph = as_cu((source)._h_graph) elif isinstance(source, GraphDefinition): cu_graph = int(source.handle) else: From 035b09d89740e4b21f496c68b3158eafcd152ccc Mon Sep 17 00:00:00 2001 From: Andy Jost Date: Fri, 1 May 2026 15:17:19 -0700 Subject: [PATCH 2/4] cuda.core: convert Graph to cdef class with GraphExecHandle Add a GraphExecHandle to the resource-handle layer (parallel to GraphHandle) wrapping CUgraphExec with RAII cleanup via cuGraphExecDestroy on shared_ptr release. Convert Graph from a Python class using _MembersNeededForFinalize to a cdef class holding a typed _h_graph_exec attribute, dropping the weakref.finalize machinery. update/upload/launch move to nogil cydriver paths consistent with the GraphBuilder rewrite. Also drop quoted forward-reference annotations on create_graph_builder and _instantiate_graph/complete now that GraphBuilder is cimported in _device.pyx and _stream.pyx and Cython accepts the in-module forward reference to Graph. Clears the related "Strings should no longer be used for type declarations" warnings. Made-with: Cursor --- cuda_core/cuda/core/_cpp/resource_handles.cpp | 23 ++++++++ cuda_core/cuda/core/_cpp/resource_handles.hpp | 22 ++++++++ cuda_core/cuda/core/_device.pyx | 2 +- cuda_core/cuda/core/_resource_handles.pxd | 7 +++ cuda_core/cuda/core/_resource_handles.pyx | 7 +++ cuda_core/cuda/core/_stream.pyx | 2 +- cuda_core/cuda/core/graph/_graph_builder.pxd | 13 ++++- cuda_core/cuda/core/graph/_graph_builder.pyx | 54 +++++++++---------- 8 files changed, 97 insertions(+), 33 deletions(-) diff --git a/cuda_core/cuda/core/_cpp/resource_handles.cpp b/cuda_core/cuda/core/_cpp/resource_handles.cpp index 5eb4716b98..029ac46d66 100644 --- a/cuda_core/cuda/core/_cpp/resource_handles.cpp +++ b/cuda_core/cuda/core/_cpp/resource_handles.cpp @@ -63,6 +63,7 @@ decltype(&cuLibraryGetKernel) p_cuLibraryGetKernel = nullptr; // Graph decltype(&cuGraphDestroy) p_cuGraphDestroy = nullptr; +decltype(&cuGraphExecDestroy) p_cuGraphExecDestroy = nullptr; // Linker decltype(&cuLinkDestroy) p_cuLinkDestroy = nullptr; @@ -952,6 +953,28 @@ GraphHandle create_graph_handle_ref(CUgraph graph, const GraphHandle& h_parent) return GraphHandle(box, &box->resource); } +// ============================================================================ +// Graph Exec Handles +// ============================================================================ + +namespace { +struct GraphExecBox { + CUgraphExec resource; +}; +} // namespace + +GraphExecHandle create_graph_exec_handle(CUgraphExec graph_exec) { + auto box = std::shared_ptr( + new GraphExecBox{graph_exec}, + [](const GraphExecBox* b) { + GILReleaseGuard gil; + p_cuGraphExecDestroy(b->resource); + delete b; + } + ); + return GraphExecHandle(box, &box->resource); +} + namespace { struct GraphNodeBox { mutable CUgraphNode resource; diff --git a/cuda_core/cuda/core/_cpp/resource_handles.hpp b/cuda_core/cuda/core/_cpp/resource_handles.hpp index 2e6ebb6271..14bd2a0bc4 100644 --- a/cuda_core/cuda/core/_cpp/resource_handles.hpp +++ b/cuda_core/cuda/core/_cpp/resource_handles.hpp @@ -94,6 +94,7 @@ extern decltype(&cuLibraryGetKernel) p_cuLibraryGetKernel; // Graph extern decltype(&cuGraphDestroy) p_cuGraphDestroy; +extern decltype(&cuGraphExecDestroy) p_cuGraphExecDestroy; // Linker extern decltype(&cuLinkDestroy) p_cuLinkDestroy; @@ -148,6 +149,7 @@ using MemoryPoolHandle = std::shared_ptr; using LibraryHandle = std::shared_ptr; using KernelHandle = std::shared_ptr; using GraphHandle = std::shared_ptr; +using GraphExecHandle = std::shared_ptr; using GraphNodeHandle = std::shared_ptr; using GraphicsResourceHandle = std::shared_ptr; using NvrtcProgramHandle = std::shared_ptr; @@ -403,6 +405,14 @@ GraphHandle create_graph_handle(CUgraph graph); // but h_parent will be prevented from destruction while this handle exists. GraphHandle create_graph_handle_ref(CUgraph graph, const GraphHandle& h_parent); +// ============================================================================ +// Graph exec handle functions +// ============================================================================ + +// Wrap an externally-created CUgraphExec with RAII cleanup. +// When the last reference is released, cuGraphExecDestroy is called automatically. +GraphExecHandle create_graph_exec_handle(CUgraphExec graph_exec); + // ============================================================================ // Graph node handle functions // ============================================================================ @@ -529,6 +539,10 @@ inline CUgraph as_cu(const GraphHandle& h) noexcept { return h ? *h : nullptr; } +inline CUgraphExec as_cu(const GraphExecHandle& h) noexcept { + return h ? *h : nullptr; +} + inline CUgraphNode as_cu(const GraphNodeHandle& h) noexcept { return h ? *h : nullptr; } @@ -587,6 +601,10 @@ inline std::intptr_t as_intptr(const GraphHandle& h) noexcept { return reinterpret_cast(as_cu(h)); } +inline std::intptr_t as_intptr(const GraphExecHandle& h) noexcept { + return reinterpret_cast(as_cu(h)); +} + inline std::intptr_t as_intptr(const GraphNodeHandle& h) noexcept { return reinterpret_cast(as_cu(h)); } @@ -677,6 +695,10 @@ inline PyObject* as_py(const GraphHandle& h) noexcept { return detail::make_py("cuda.bindings.driver", "CUgraph", as_intptr(h)); } +inline PyObject* as_py(const GraphExecHandle& h) noexcept { + return detail::make_py("cuda.bindings.driver", "CUgraphExec", as_intptr(h)); +} + inline PyObject* as_py(const GraphNodeHandle& h) noexcept { if (!as_intptr(h)) { Py_RETURN_NONE; diff --git a/cuda_core/cuda/core/_device.pyx b/cuda_core/cuda/core/_device.pyx index d9776a72e8..269816b025 100644 --- a/cuda_core/cuda/core/_device.pyx +++ b/cuda_core/cuda/core/_device.pyx @@ -1362,7 +1362,7 @@ class Device: self._check_context_initialized() handle_return(runtime.cudaDeviceSynchronize()) - def create_graph_builder(self) -> "GraphBuilder": + def create_graph_builder(self) -> GraphBuilder: """Create a new :obj:`~graph.GraphBuilder` object. Returns diff --git a/cuda_core/cuda/core/_resource_handles.pxd b/cuda_core/cuda/core/_resource_handles.pxd index 0d7d20e574..a059465403 100644 --- a/cuda_core/cuda/core/_resource_handles.pxd +++ b/cuda_core/cuda/core/_resource_handles.pxd @@ -27,6 +27,7 @@ cdef extern from "_cpp/resource_handles.hpp" namespace "cuda_core": ctypedef shared_ptr[const cydriver.CUlibrary] LibraryHandle ctypedef shared_ptr[const cydriver.CUkernel] KernelHandle ctypedef shared_ptr[const cydriver.CUgraph] GraphHandle + ctypedef shared_ptr[const cydriver.CUgraphExec] GraphExecHandle ctypedef shared_ptr[const cydriver.CUgraphNode] GraphNodeHandle ctypedef shared_ptr[const cydriver.CUgraphicsResource] GraphicsResourceHandle ctypedef shared_ptr[const cynvrtc.nvrtcProgram] NvrtcProgramHandle @@ -52,6 +53,7 @@ cdef extern from "_cpp/resource_handles.hpp" namespace "cuda_core": cydriver.CUlibrary as_cu(LibraryHandle h) noexcept nogil cydriver.CUkernel as_cu(KernelHandle h) noexcept nogil cydriver.CUgraph as_cu(GraphHandle h) noexcept nogil + cydriver.CUgraphExec as_cu(GraphExecHandle h) noexcept nogil cydriver.CUgraphNode as_cu(GraphNodeHandle h) noexcept nogil cydriver.CUgraphicsResource as_cu(GraphicsResourceHandle h) noexcept nogil cynvrtc.nvrtcProgram as_cu(NvrtcProgramHandle h) noexcept nogil @@ -68,6 +70,7 @@ cdef extern from "_cpp/resource_handles.hpp" namespace "cuda_core": intptr_t as_intptr(LibraryHandle h) noexcept nogil intptr_t as_intptr(KernelHandle h) noexcept nogil intptr_t as_intptr(GraphHandle h) noexcept nogil + intptr_t as_intptr(GraphExecHandle h) noexcept nogil intptr_t as_intptr(GraphNodeHandle h) noexcept nogil intptr_t as_intptr(GraphicsResourceHandle h) noexcept nogil intptr_t as_intptr(NvrtcProgramHandle h) noexcept nogil @@ -85,6 +88,7 @@ cdef extern from "_cpp/resource_handles.hpp" namespace "cuda_core": object as_py(LibraryHandle h) object as_py(KernelHandle h) object as_py(GraphHandle h) + object as_py(GraphExecHandle h) object as_py(GraphNodeHandle h) object as_py(GraphicsResourceHandle h) object as_py(NvrtcProgramHandle h) @@ -183,6 +187,9 @@ cdef LibraryHandle get_kernel_library(const KernelHandle& h) noexcept nogil cdef GraphHandle create_graph_handle(cydriver.CUgraph graph) except+ nogil cdef GraphHandle create_graph_handle_ref(cydriver.CUgraph graph, const GraphHandle& h_parent) except+ nogil +# Graph exec handles +cdef GraphExecHandle create_graph_exec_handle(cydriver.CUgraphExec graph_exec) except+ nogil + # Graph node handles cdef GraphNodeHandle create_graph_node_handle(cydriver.CUgraphNode node, const GraphHandle& h_graph) except+ nogil cdef GraphHandle graph_node_get_graph(const GraphNodeHandle& h) noexcept nogil diff --git a/cuda_core/cuda/core/_resource_handles.pyx b/cuda_core/cuda/core/_resource_handles.pyx index d30993cc5e..2291b1ec20 100644 --- a/cuda_core/cuda/core/_resource_handles.pyx +++ b/cuda_core/cuda/core/_resource_handles.pyx @@ -27,6 +27,7 @@ from ._resource_handles cimport ( LibraryHandle, KernelHandle, GraphHandle, + GraphExecHandle, GraphicsResourceHandle, NvrtcProgramHandle, NvvmProgramHandle, @@ -154,6 +155,10 @@ cdef extern from "_cpp/resource_handles.hpp" namespace "cuda_core": GraphHandle create_graph_handle_ref "cuda_core::create_graph_handle_ref" ( cydriver.CUgraph graph, const GraphHandle& h_parent) except+ nogil + # Graph exec handles + GraphExecHandle create_graph_exec_handle "cuda_core::create_graph_exec_handle" ( + cydriver.CUgraphExec graph_exec) except+ nogil + # Graph node handles GraphNodeHandle create_graph_node_handle "cuda_core::create_graph_node_handle" ( cydriver.CUgraphNode node, const GraphHandle& h_graph) except+ nogil @@ -265,6 +270,7 @@ cdef extern from "_cpp/resource_handles.hpp" namespace "cuda_core": # Graph void* p_cuGraphDestroy "reinterpret_cast(cuda_core::p_cuGraphDestroy)" + void* p_cuGraphExecDestroy "reinterpret_cast(cuda_core::p_cuGraphExecDestroy)" # Linker void* p_cuLinkDestroy "reinterpret_cast(cuda_core::p_cuLinkDestroy)" @@ -334,6 +340,7 @@ p_cuLibraryGetKernel = _get_driver_fn("cuLibraryGetKernel") # Graph p_cuGraphDestroy = _get_driver_fn("cuGraphDestroy") +p_cuGraphExecDestroy = _get_driver_fn("cuGraphExecDestroy") # Linker p_cuLinkDestroy = _get_driver_fn("cuLinkDestroy") diff --git a/cuda_core/cuda/core/_stream.pyx b/cuda_core/cuda/core/_stream.pyx index c7b1312c17..caf56ee136 100644 --- a/cuda_core/cuda/core/_stream.pyx +++ b/cuda_core/cuda/core/_stream.pyx @@ -361,7 +361,7 @@ cdef class Stream: return Stream._init(obj=_stream_holder()) - def create_graph_builder(self) -> "GraphBuilder": + def create_graph_builder(self) -> GraphBuilder: """Create a new :obj:`~graph.GraphBuilder` object. The new graph builder will be associated with this stream. diff --git a/cuda_core/cuda/core/graph/_graph_builder.pxd b/cuda_core/cuda/core/graph/_graph_builder.pxd index e224f3a510..c33a7d63c1 100644 --- a/cuda_core/cuda/core/graph/_graph_builder.pxd +++ b/cuda_core/cuda/core/graph/_graph_builder.pxd @@ -2,7 +2,9 @@ # # SPDX-License-Identifier: Apache-2.0 -from cuda.core._resource_handles cimport GraphHandle, StreamHandle +from cuda.bindings cimport cydriver + +from cuda.core._resource_handles cimport GraphExecHandle, GraphHandle, StreamHandle from cuda.core._stream cimport Stream @@ -17,3 +19,12 @@ cdef class GraphBuilder: @staticmethod cdef GraphBuilder _init(Stream stream) + + +cdef class Graph: + cdef: + GraphExecHandle _h_graph_exec + object __weakref__ + + @staticmethod + cdef Graph _init(cydriver.CUgraphExec graph_exec) diff --git a/cuda_core/cuda/core/graph/_graph_builder.pyx b/cuda_core/cuda/core/graph/_graph_builder.pyx index 61629dcc2e..f02a0409d8 100644 --- a/cuda_core/cuda/core/graph/_graph_builder.pyx +++ b/cuda_core/cuda/core/graph/_graph_builder.pyx @@ -2,7 +2,6 @@ # # SPDX-License-Identifier: Apache-2.0 -import weakref from dataclasses import dataclass from libc.stdint cimport intptr_t @@ -12,8 +11,8 @@ from cuda.bindings cimport cydriver from cuda.core.graph._graph_definition cimport GraphCondition from cuda.core.graph._utils cimport _attach_host_callback_to_graph from cuda.core._resource_handles cimport ( - GraphHandle, StreamHandle, as_cu, as_py, - create_graph_handle, create_graph_handle_ref, + GraphExecHandle, GraphHandle, StreamHandle, as_cu, as_py, + create_graph_exec_handle, create_graph_handle, create_graph_handle_ref, ) from cuda.core._stream cimport Stream from cuda.core._utils.cuda_utils cimport HANDLE_RETURN @@ -150,7 +149,8 @@ class GraphCompleteOptions: use_node_priority: bool = False -def _instantiate_graph(h_graph, options: GraphCompleteOptions | None = None) -> "Graph": +def _instantiate_graph(h_graph, options: GraphCompleteOptions | None = None) -> Graph: + cdef cydriver.CUgraphExec c_exec params = driver.CUDA_GRAPH_INSTANTIATE_PARAMS() if options: flags = 0 @@ -165,7 +165,9 @@ def _instantiate_graph(h_graph, options: GraphCompleteOptions | None = None) -> flags |= driver.CUgraphInstantiate_flags.CUDA_GRAPH_INSTANTIATE_FLAG_USE_NODE_PRIORITY params.flags = flags - graph = Graph._init(handle_return(driver.cuGraphInstantiateWithParams(h_graph, params))) + py_exec = handle_return(driver.cuGraphInstantiateWithParams(h_graph, params)) + c_exec = int(py_exec) + graph = Graph._init(c_exec) if params.result_out == driver.CUgraphInstantiateResult.CUDA_GRAPH_INSTANTIATE_ERROR: raise RuntimeError( "Instantiation failed for an unexpected reason which is described in the return value of the function." @@ -354,7 +356,7 @@ cdef class GraphBuilder: self._state = CAPTURE_ENDED return self - def complete(self, options: GraphCompleteOptions | None = None) -> "Graph": + def complete(self, options: GraphCompleteOptions | None = None) -> Graph: """Completes the graph builder and returns the built :obj:`~graph.Graph` object. Parameters @@ -803,7 +805,7 @@ cdef inline tuple _cond_with_params(GraphBuilder gb, node_params): ) -class Graph: +cdef class Graph: """An executable graph. A graph groups a set of CUDA kernels and other CUDA operations together and executes @@ -814,32 +816,18 @@ class Graph: """ - class _MembersNeededForFinalize: - __slots__ = "graph" - - def __init__(self, graph_obj, graph): - self.graph = graph - weakref.finalize(graph_obj, self.close) - - def close(self): - if self.graph: - handle_return(driver.cuGraphExecDestroy(self.graph)) - self.graph = None - - __slots__ = ("__weakref__", "_mnff") - def __init__(self): raise RuntimeError("directly constructing a Graph instance is not supported") - @classmethod - def _init(cls, graph): - self = cls.__new__(cls) - self._mnff = Graph._MembersNeededForFinalize(self, graph) + @staticmethod + cdef Graph _init(cydriver.CUgraphExec graph_exec): + cdef Graph self = Graph.__new__(Graph) + self._h_graph_exec = create_graph_exec_handle(graph_exec) return self def close(self): """Destroy the graph.""" - self._mnff.close() + self._h_graph_exec.reset() @property def handle(self) -> driver.CUgraphExec: @@ -851,7 +839,7 @@ class Graph: handle, call ``int()`` on the returned object. """ - return self._mnff.graph + return as_py(self._h_graph_exec) def update(self, source: "GraphBuilder | GraphDefinition") -> None: """Update the graph using a new graph definition. @@ -868,7 +856,7 @@ class Graph: from cuda.core.graph import GraphDefinition cdef cydriver.CUgraph cu_graph - cdef cydriver.CUgraphExec cu_exec = int(self._mnff.graph) + cdef cydriver.CUgraphExec cu_exec = as_cu(self._h_graph_exec) if isinstance(source, GraphBuilder): if (source)._state != CAPTURE_ENDED: @@ -899,7 +887,10 @@ class Graph: The stream in which to upload the graph """ - handle_return(driver.cuGraphUpload(self._mnff.graph, stream.handle)) + cdef cydriver.CUgraphExec c_exec = as_cu(self._h_graph_exec) + cdef cydriver.CUstream c_stream = int(stream.handle) + with nogil: + HANDLE_RETURN(cydriver.cuGraphUpload(c_exec, c_stream)) def launch(self, stream: Stream): """Launches the graph in a stream. @@ -910,4 +901,7 @@ class Graph: The stream in which to launch the graph """ - handle_return(driver.cuGraphLaunch(self._mnff.graph, stream.handle)) + cdef cydriver.CUgraphExec c_exec = as_cu(self._h_graph_exec) + cdef cydriver.CUstream c_stream = int(stream.handle) + with nogil: + HANDLE_RETURN(cydriver.cuGraphLaunch(c_exec, c_stream)) From ae974af76bc04d95de7882650785aa83334da705 Mon Sep 17 00:00:00 2001 From: Andy Jost Date: Fri, 1 May 2026 15:19:59 -0700 Subject: [PATCH 3/4] fix(cuda.core): drop unused handle cimports flagged by cython-lint The cdef-class member declarations live in the .pxd, so the .pyx does not need to re-cimport GraphExecHandle, GraphHandle, or StreamHandle. Made-with: Cursor --- cuda_core/cuda/core/graph/_graph_builder.pyx | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/cuda_core/cuda/core/graph/_graph_builder.pyx b/cuda_core/cuda/core/graph/_graph_builder.pyx index f02a0409d8..f27fbb4a06 100644 --- a/cuda_core/cuda/core/graph/_graph_builder.pyx +++ b/cuda_core/cuda/core/graph/_graph_builder.pyx @@ -11,7 +11,7 @@ from cuda.bindings cimport cydriver from cuda.core.graph._graph_definition cimport GraphCondition from cuda.core.graph._utils cimport _attach_host_callback_to_graph from cuda.core._resource_handles cimport ( - GraphExecHandle, GraphHandle, StreamHandle, as_cu, as_py, + as_cu, as_py, create_graph_exec_handle, create_graph_handle, create_graph_handle_ref, ) from cuda.core._stream cimport Stream From 4e71a1e845cf5049d2f9492ecde45306d6fe30a2 Mon Sep 17 00:00:00 2001 From: Andy Jost Date: Fri, 1 May 2026 15:41:18 -0700 Subject: [PATCH 4/4] fix(cuda.core): break _stream/_device <-> graph._graph_builder import cycle cimport-ing GraphBuilder at the top of _stream.pyx and _device.pyx made Cython emit a Python-level import of cuda.core.graph._graph_builder during _stream module init. That triggered the chain graph -> _graph_node -> _kernel_arg_handler -> _memory._buffer -> _device, which then re-entered the still-initializing _stream module via "from cuda.core._stream import IsStreamT", failing with ImportError: cannot import name IsStreamT. Restore the original lazy "import GraphBuilder" inside create_graph_builder (Stream and Device) and Stream_accept. The return annotations stay as bare names; "from __future__ import annotations" in both files defers their evaluation, so they need not resolve at function-definition time. Made-with: Cursor --- cuda_core/cuda/core/_device.pyx | 3 ++- cuda_core/cuda/core/_stream.pyx | 5 ++++- 2 files changed, 6 insertions(+), 2 deletions(-) diff --git a/cuda_core/cuda/core/_device.pyx b/cuda_core/cuda/core/_device.pyx index 269816b025..c4fba83006 100644 --- a/cuda_core/cuda/core/_device.pyx +++ b/cuda_core/cuda/core/_device.pyx @@ -14,7 +14,6 @@ import threading from cuda.core._context cimport Context from cuda.core._context import ContextOptions from cuda.core._event cimport Event as cyEvent -from cuda.core.graph._graph_builder cimport GraphBuilder from cuda.core._event import Event, EventOptions from cuda.core._memory._buffer cimport Buffer, MemoryResource from cuda.core._resource_handles cimport ( @@ -1371,6 +1370,8 @@ class Device: Newly created graph builder object. """ + from cuda.core.graph._graph_builder import GraphBuilder + self._check_context_initialized() return GraphBuilder._init(self.create_stream()) diff --git a/cuda_core/cuda/core/_stream.pyx b/cuda_core/cuda/core/_stream.pyx index caf56ee136..a2bf0e025c 100644 --- a/cuda_core/cuda/core/_stream.pyx +++ b/cuda_core/cuda/core/_stream.pyx @@ -10,7 +10,6 @@ from libc.stdlib cimport strtol, getenv from cuda.bindings cimport cydriver from cuda.core._event cimport Event as cyEvent -from cuda.core.graph._graph_builder cimport GraphBuilder from cuda.core._utils.cuda_utils cimport ( check_or_create_options, HANDLE_RETURN, @@ -372,6 +371,8 @@ cdef class Stream: Newly created graph builder object. """ + from cuda.core.graph._graph_builder import GraphBuilder + return GraphBuilder._init(self) @@ -473,6 +474,8 @@ cdef cydriver.CUstream _handle_from_stream_protocol(obj) except*: # Helper for API functions that accept either Stream or GraphBuilder. Performs # needed checks and returns the relevant stream. cdef Stream Stream_accept(arg, bint allow_stream_protocol=False): + from cuda.core.graph._graph_builder import GraphBuilder + if isinstance(arg, Stream): return (arg) elif isinstance(arg, GraphBuilder):