diff --git a/cuda_core/cuda/core/_cpp/resource_handles.cpp b/cuda_core/cuda/core/_cpp/resource_handles.cpp index 6e82d734e35..8eba111f455 100644 --- a/cuda_core/cuda/core/_cpp/resource_handles.cpp +++ b/cuda_core/cuda/core/_cpp/resource_handles.cpp @@ -69,6 +69,7 @@ decltype(&cuLibraryGetKernel) p_cuLibraryGetKernel = nullptr; // Graph decltype(&cuGraphDestroy) p_cuGraphDestroy = nullptr; +decltype(&cuGraphExecDestroy) p_cuGraphExecDestroy = nullptr; // Linker decltype(&cuLinkDestroy) p_cuLinkDestroy = nullptr; @@ -1126,6 +1127,28 @@ GraphHandle create_graph_handle_ref(CUgraph graph, const GraphHandle& h_parent) return GraphHandle(box, &box->resource); } +// ============================================================================ +// Graph Exec Handles +// ============================================================================ + +namespace { +struct GraphExecBox { + CUgraphExec resource; +}; +} // namespace + +GraphExecHandle create_graph_exec_handle(CUgraphExec graph_exec) { + auto box = std::shared_ptr( + new GraphExecBox{graph_exec}, + [](const GraphExecBox* b) { + GILReleaseGuard gil; + p_cuGraphExecDestroy(b->resource); + delete b; + } + ); + return GraphExecHandle(box, &box->resource); +} + namespace { struct GraphNodeBox { mutable CUgraphNode resource; diff --git a/cuda_core/cuda/core/_cpp/resource_handles.hpp b/cuda_core/cuda/core/_cpp/resource_handles.hpp index 32a88f0b3cd..b7ad736e6f4 100644 --- a/cuda_core/cuda/core/_cpp/resource_handles.hpp +++ b/cuda_core/cuda/core/_cpp/resource_handles.hpp @@ -100,6 +100,7 @@ extern decltype(&cuLibraryGetKernel) p_cuLibraryGetKernel; // Graph extern decltype(&cuGraphDestroy) p_cuGraphDestroy; +extern decltype(&cuGraphExecDestroy) p_cuGraphExecDestroy; // Linker extern decltype(&cuLinkDestroy) p_cuLinkDestroy; @@ -164,6 +165,7 @@ using MemoryPoolHandle = std::shared_ptr; using LibraryHandle = std::shared_ptr; using KernelHandle = std::shared_ptr; using GraphHandle = std::shared_ptr; +using GraphExecHandle = std::shared_ptr; using GraphNodeHandle = std::shared_ptr; using GraphicsResourceHandle = std::shared_ptr; using NvrtcProgramHandle = std::shared_ptr; @@ -441,6 +443,14 @@ GraphHandle create_graph_handle(CUgraph graph); // but h_parent will be prevented from destruction while this handle exists. GraphHandle create_graph_handle_ref(CUgraph graph, const GraphHandle& h_parent); +// ============================================================================ +// Graph exec handle functions +// ============================================================================ + +// Wrap an externally-created CUgraphExec with RAII cleanup. +// When the last reference is released, cuGraphExecDestroy is called automatically. +GraphExecHandle create_graph_exec_handle(CUgraphExec graph_exec); + // ============================================================================ // Graph node handle functions // ============================================================================ @@ -571,6 +581,10 @@ inline CUgraph as_cu(const GraphHandle& h) noexcept { return h ? *h : nullptr; } +inline CUgraphExec as_cu(const GraphExecHandle& h) noexcept { + return h ? *h : nullptr; +} + inline CUgraphNode as_cu(const GraphNodeHandle& h) noexcept { return h ? *h : nullptr; } @@ -633,6 +647,10 @@ inline std::intptr_t as_intptr(const GraphHandle& h) noexcept { return reinterpret_cast(as_cu(h)); } +inline std::intptr_t as_intptr(const GraphExecHandle& h) noexcept { + return reinterpret_cast(as_cu(h)); +} + inline std::intptr_t as_intptr(const GraphNodeHandle& h) noexcept { return reinterpret_cast(as_cu(h)); } @@ -743,6 +761,10 @@ inline PyObject* as_py(const GraphHandle& h) noexcept { return detail::make_py("cuda.bindings.driver", "CUgraph", as_intptr(h)); } +inline PyObject* as_py(const GraphExecHandle& h) noexcept { + return detail::make_py("cuda.bindings.driver", "CUgraphExec", as_intptr(h)); +} + inline PyObject* as_py(const GraphNodeHandle& h) noexcept { if (!as_intptr(h)) { Py_RETURN_NONE; diff --git a/cuda_core/cuda/core/_device.pyx b/cuda_core/cuda/core/_device.pyx index 451ca25ddaa..a90f246d381 100644 --- a/cuda_core/cuda/core/_device.pyx +++ b/cuda_core/cuda/core/_device.pyx @@ -1454,7 +1454,7 @@ class Device: from cuda.core.graph._graph_builder import GraphBuilder self._check_context_initialized() - return GraphBuilder._init(stream=self.create_stream(), is_stream_owner=True) + return GraphBuilder._init(self.create_stream()) cdef inline int Device_ensure_cuda_initialized() except? -1: diff --git a/cuda_core/cuda/core/_resource_handles.pxd b/cuda_core/cuda/core/_resource_handles.pxd index 0ed3d6e5942..3f82eb22c15 100644 --- a/cuda_core/cuda/core/_resource_handles.pxd +++ b/cuda_core/cuda/core/_resource_handles.pxd @@ -28,6 +28,7 @@ cdef extern from "_cpp/resource_handles.hpp" namespace "cuda_core": ctypedef shared_ptr[const cydriver.CUlibrary] LibraryHandle ctypedef shared_ptr[const cydriver.CUkernel] KernelHandle ctypedef shared_ptr[const cydriver.CUgraph] GraphHandle + ctypedef shared_ptr[const cydriver.CUgraphExec] GraphExecHandle ctypedef shared_ptr[const cydriver.CUgraphNode] GraphNodeHandle ctypedef shared_ptr[const cydriver.CUgraphicsResource] GraphicsResourceHandle ctypedef shared_ptr[const cynvrtc.nvrtcProgram] NvrtcProgramHandle @@ -54,6 +55,7 @@ cdef extern from "_cpp/resource_handles.hpp" namespace "cuda_core": cydriver.CUlibrary as_cu(LibraryHandle h) noexcept nogil cydriver.CUkernel as_cu(KernelHandle h) noexcept nogil cydriver.CUgraph as_cu(GraphHandle h) noexcept nogil + cydriver.CUgraphExec as_cu(GraphExecHandle h) noexcept nogil cydriver.CUgraphNode as_cu(GraphNodeHandle h) noexcept nogil cydriver.CUgraphicsResource as_cu(GraphicsResourceHandle h) noexcept nogil cynvrtc.nvrtcProgram as_cu(NvrtcProgramHandle h) noexcept nogil @@ -71,6 +73,7 @@ cdef extern from "_cpp/resource_handles.hpp" namespace "cuda_core": intptr_t as_intptr(LibraryHandle h) noexcept nogil intptr_t as_intptr(KernelHandle h) noexcept nogil intptr_t as_intptr(GraphHandle h) noexcept nogil + intptr_t as_intptr(GraphExecHandle h) noexcept nogil intptr_t as_intptr(GraphNodeHandle h) noexcept nogil intptr_t as_intptr(GraphicsResourceHandle h) noexcept nogil intptr_t as_intptr(NvrtcProgramHandle h) noexcept nogil @@ -89,6 +92,7 @@ cdef extern from "_cpp/resource_handles.hpp" namespace "cuda_core": object as_py(LibraryHandle h) object as_py(KernelHandle h) object as_py(GraphHandle h) + object as_py(GraphExecHandle h) object as_py(GraphNodeHandle h) object as_py(GraphicsResourceHandle h) object as_py(NvrtcProgramHandle h) @@ -195,6 +199,9 @@ cdef LibraryHandle get_kernel_library(const KernelHandle& h) noexcept nogil cdef GraphHandle create_graph_handle(cydriver.CUgraph graph) except+ nogil cdef GraphHandle create_graph_handle_ref(cydriver.CUgraph graph, const GraphHandle& h_parent) except+ nogil +# Graph exec handles +cdef GraphExecHandle create_graph_exec_handle(cydriver.CUgraphExec graph_exec) except+ nogil + # Graph node handles cdef GraphNodeHandle create_graph_node_handle(cydriver.CUgraphNode node, const GraphHandle& h_graph) except+ nogil cdef GraphHandle graph_node_get_graph(const GraphNodeHandle& h) noexcept nogil diff --git a/cuda_core/cuda/core/_resource_handles.pyi b/cuda_core/cuda/core/_resource_handles.pyi index 490073c9fd1..8b2d9e75e18 100644 --- a/cuda_core/cuda/core/_resource_handles.pyi +++ b/cuda_core/cuda/core/_resource_handles.pyi @@ -13,6 +13,7 @@ DevicePtrHandle = shared_ptr LibraryHandle = shared_ptr KernelHandle = shared_ptr GraphHandle = shared_ptr +GraphExecHandle = shared_ptr GraphNodeHandle = shared_ptr GraphicsResourceHandle = shared_ptr NvrtcProgramHandle = shared_ptr diff --git a/cuda_core/cuda/core/_resource_handles.pyx b/cuda_core/cuda/core/_resource_handles.pyx index c87956f0c68..ba7d078a3fa 100644 --- a/cuda_core/cuda/core/_resource_handles.pyx +++ b/cuda_core/cuda/core/_resource_handles.pyx @@ -151,6 +151,10 @@ cdef extern from "_cpp/resource_handles.hpp" namespace "cuda_core": GraphHandle create_graph_handle_ref "cuda_core::create_graph_handle_ref" ( cydriver.CUgraph graph, const GraphHandle& h_parent) except+ nogil + # Graph exec handles + GraphExecHandle create_graph_exec_handle "cuda_core::create_graph_exec_handle" ( + cydriver.CUgraphExec graph_exec) except+ nogil + # Graph node handles GraphNodeHandle create_graph_node_handle "cuda_core::create_graph_node_handle" ( cydriver.CUgraphNode node, const GraphHandle& h_graph) except+ nogil @@ -276,6 +280,7 @@ cdef extern from "_cpp/resource_handles.hpp" namespace "cuda_core": # Graph void* p_cuGraphDestroy "reinterpret_cast(cuda_core::p_cuGraphDestroy)" + void* p_cuGraphExecDestroy "reinterpret_cast(cuda_core::p_cuGraphExecDestroy)" # Linker void* p_cuLinkDestroy "reinterpret_cast(cuda_core::p_cuLinkDestroy)" @@ -324,7 +329,7 @@ cdef void _init_driver_fn_pointers() noexcept: global p_cuMemFreeAsync, p_cuMemFree, p_cuMemFreeHost global p_cuMemPoolImportPointer global p_cuLibraryLoadFromFile, p_cuLibraryLoadData, p_cuLibraryUnload, p_cuLibraryGetKernel - global p_cuGraphDestroy + global p_cuGraphDestroy, p_cuGraphExecDestroy global p_cuLinkDestroy global p_cuGraphicsUnmapResources, p_cuGraphicsUnregisterResource global p_cuDevSmResourceSplit @@ -380,6 +385,7 @@ cdef void _init_driver_fn_pointers() noexcept: # Graph p_cuGraphDestroy = _get_driver_fn("cuGraphDestroy") + p_cuGraphExecDestroy = _get_driver_fn("cuGraphExecDestroy") # Linker p_cuLinkDestroy = _get_driver_fn("cuLinkDestroy") diff --git a/cuda_core/cuda/core/_stream.pyx b/cuda_core/cuda/core/_stream.pyx index 37088e9ace5..5212ec5c7de 100644 --- a/cuda_core/cuda/core/_stream.pyx +++ b/cuda_core/cuda/core/_stream.pyx @@ -414,7 +414,7 @@ cdef class Stream: """ from cuda.core.graph._graph_builder import GraphBuilder - return GraphBuilder._init(stream=self, is_stream_owner=False) + return GraphBuilder._init(self) LEGACY_DEFAULT_STREAM: Stream = Stream._legacy_default() diff --git a/cuda_core/cuda/core/graph/_graph_builder.pxd b/cuda_core/cuda/core/graph/_graph_builder.pxd new file mode 100644 index 00000000000..660ebe8ec7d --- /dev/null +++ b/cuda_core/cuda/core/graph/_graph_builder.pxd @@ -0,0 +1,27 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# SPDX-License-Identifier: Apache-2.0 + +from cuda.bindings cimport cydriver + +from cuda.core._resource_handles cimport GraphExecHandle, GraphHandle, StreamHandle +from cuda.core._stream cimport Stream + + +cdef class GraphBuilder: + cdef: + GraphHandle _h_graph + StreamHandle _h_stream + int _kind + int _state + Stream _stream # cached to avoid reconstruction from _h_stream handle + object __weakref__ + + +cdef class Graph: + cdef: + GraphExecHandle _h_graph_exec + object __weakref__ + + @staticmethod + cdef Graph _init(cydriver.CUgraphExec graph_exec) diff --git a/cuda_core/cuda/core/graph/_graph_builder.pyi b/cuda_core/cuda/core/graph/_graph_builder.pyi index 6dbca20e60b..af1748ad86c 100644 --- a/cuda_core/cuda/core/graph/_graph_builder.pyi +++ b/cuda_core/cuda/core/graph/_graph_builder.pyi @@ -8,6 +8,8 @@ from cuda.core._stream import Stream from cuda.core._utils.cuda_utils import driver from cuda.core.graph._graph_definition import GraphCondition, GraphDefinition +_BuilderKind = int +_CaptureState = int @dataclass class GraphDebugPrintOptions: @@ -106,23 +108,19 @@ class GraphBuilder: """ - class _MembersNeededForFinalize: - __slots__ = ('conditional_graph', 'graph', 'is_join_required', 'is_stream_owner', 'stream') - - def __init__(self, graph_builder_obj: GraphBuilder, stream_obj: Stream | None, is_stream_owner: bool, conditional_graph, is_join_required: bool) -> None: - ... - - def close(self) -> None: - ... - __slots__ = ('__weakref__', '_building_ended', '_mnff') + def __init__(self): + ... - def __init__(self) -> None: + def __dealloc__(self): ... - @classmethod - def _init(cls, stream: Stream | None, is_stream_owner: bool, conditional_graph: object=None, is_join_required: bool=False) -> GraphBuilder: + @staticmethod + def _init(stream: Stream): ... + def close(self): + """Destroy the graph builder.""" + @property def stream(self) -> Stream: """Returns the stream associated with the graph builder.""" @@ -155,7 +153,7 @@ class GraphBuilder: def end_building(self) -> GraphBuilder: """Ends the building process.""" - def complete(self, options: GraphCompleteOptions | None=None) -> 'Graph': + def complete(self, options: GraphCompleteOptions | None=None) -> Graph: """Completes the graph builder and returns the built :obj:`~graph.Graph` object. Parameters @@ -245,9 +243,6 @@ class GraphBuilder: A condition variable for controlling conditional execution. """ - def _cond_with_params(self, node_params: object) -> tuple[GraphBuilder, ...]: - ... - def if_then(self, condition: GraphCondition) -> GraphBuilder: """Adds an if condition branch and returns a new graph builder for it. @@ -335,15 +330,7 @@ class GraphBuilder: """ - def close(self) -> None: - """Destroy the graph builder. - - Closes the associated stream if we own it. Borrowed stream - object will instead have their references released. - - """ - - def embed(self, child: GraphBuilder) -> None: + def embed(self, child: GraphBuilder): """Embed a previously-built :obj:`~graph.GraphBuilder` as a child node. Parameters @@ -392,21 +379,7 @@ class Graph: """ - class _MembersNeededForFinalize: - __slots__ = 'graph' - - def __init__(self, graph_obj: Graph, graph: driver.CUgraphExec) -> None: - ... - - def close(self) -> None: - ... - __slots__ = ('__weakref__', '_mnff') - - def __init__(self) -> None: - ... - - @classmethod - def _init(cls, graph: driver.CUgraphExec) -> Graph: + def __init__(self): ... def close(self) -> None: @@ -457,5 +430,5 @@ class Graph: """ __all__ = ['Graph', 'GraphBuilder', 'GraphCompleteOptions', 'GraphDebugPrintOptions'] -def _instantiate_graph(h_graph, options: GraphCompleteOptions | None=None) -> 'Graph': +def _instantiate_graph(h_graph, options: GraphCompleteOptions | None=None) -> Graph: ... \ No newline at end of file diff --git a/cuda_core/cuda/core/graph/_graph_builder.pyx b/cuda_core/cuda/core/graph/_graph_builder.pyx index 3dcc9431759..cea4a38ba24 100644 --- a/cuda_core/cuda/core/graph/_graph_builder.pyx +++ b/cuda_core/cuda/core/graph/_graph_builder.pyx @@ -2,7 +2,6 @@ # # SPDX-License-Identifier: Apache-2.0 -import weakref from dataclasses import dataclass from typing import TYPE_CHECKING @@ -12,7 +11,11 @@ from cuda.bindings cimport cydriver from cuda.core.graph._graph_definition cimport GraphCondition from cuda.core.graph._utils cimport _attach_host_callback_to_graph -from cuda.core._resource_handles cimport as_cu +from cuda.core._resource_handles cimport ( + GraphHandle, + as_cu, as_py, + create_graph_exec_handle, create_graph_handle, create_graph_handle_ref, +) from cuda.core._stream cimport Stream from cuda.core._utils.cuda_utils cimport HANDLE_RETURN from cuda.core._utils.version cimport cy_binding_version, cy_driver_version @@ -151,7 +154,8 @@ class GraphCompleteOptions: use_node_priority: bool = False -def _instantiate_graph(h_graph, options: GraphCompleteOptions | None = None) -> "Graph": +def _instantiate_graph(h_graph, options: GraphCompleteOptions | None = None) -> Graph: + cdef cydriver.CUgraphExec c_exec params = driver.CUDA_GRAPH_INSTANTIATE_PARAMS() if options: flags = 0 @@ -166,7 +170,10 @@ def _instantiate_graph(h_graph, options: GraphCompleteOptions | None = None) -> flags |= driver.CUgraphInstantiate_flags.CUDA_GRAPH_INSTANTIATE_FLAG_USE_NODE_PRIORITY params.flags = flags - graph = Graph._init(handle_return(driver.cuGraphInstantiateWithParams(h_graph, params))) + py_exec = handle_return(driver.cuGraphInstantiateWithParams(h_graph, params)) + # Check result_out before wrapping the exec: on a non-SUCCESS result the exec + # may be invalid, and Graph._init's RAII deleter would call cuGraphExecDestroy + # on it during the exception unwind below. if params.result_out == driver.CUgraphInstantiateResult.CUDA_GRAPH_INSTANTIATE_ERROR: raise RuntimeError( "Instantiation failed for an unexpected reason which is described in the return value of the function." @@ -186,10 +193,46 @@ def _instantiate_graph(h_graph, options: GraphCompleteOptions | None = None) -> raise RuntimeError("One or more conditional handles are not associated with conditional builders.") elif params.result_out != driver.CUgraphInstantiateResult.CUDA_GRAPH_INSTANTIATE_SUCCESS: raise RuntimeError(f"Graph instantiation failed with unexpected error code: {params.result_out}") - return graph + + c_exec = int(py_exec) + return Graph._init(c_exec) -class GraphBuilder: +# Distinguishes the three kinds of GraphBuilder, which differ in how they +# begin/end stream capture and whether they own the resulting CUgraph. +# Each kind progresses through _CaptureState as follows: +# +# PRIMARY: NOT_STARTED -> CAPTURING -> ENDED +# FORKED: CAPTURING (never transitions; joined and closed) +# CONDITIONAL_BODY: NOT_STARTED -> CAPTURING -> ENDED +# +cdef enum _BuilderKind: + # PRIMARY: The top-level builder created by Device or Stream. Owns the + # captured CUgraph via an owning GraphHandle. Progresses through all three + # capture states; responsible for ending capture if destroyed early. + PRIMARY = 0 + # FORKED: Created by split(). Captures on a private stream forked from the + # primary. Starts in CAPTURING state and never transitions; the user joins + # it back to the primary via join(), which closes the builder. Must NOT + # call cuStreamEndCapture (the driver requires all forked streams to be + # joined first). + FORKED = 1 + # CONDITIONAL_BODY: Created by if_then/if_else/switch/while_loop. Captures + # into a non-owned body graph via cuStreamBeginCaptureToGraph. The body + # graph's lifetime is tied to a parent graph. Progresses through all three + # capture states like PRIMARY. + CONDITIONAL_BODY = 2 + + +# Tracks the capture lifecycle of a GraphBuilder. +cdef enum _CaptureState: + CAPTURE_NOT_STARTED = 0 + CAPTURING = 1 + CAPTURE_ENDED = 2 # Finished, valid handle + CLOSED = 3 # No valid handle + + +cdef class GraphBuilder: """A graph under construction by stream capture. A graph groups a set of CUDA kernels and other CUDA operations together and executes @@ -202,63 +245,43 @@ class GraphBuilder: """ - class _MembersNeededForFinalize: - __slots__ = ("conditional_graph", "graph", "is_join_required", "is_stream_owner", "stream") - - def __init__(self, graph_builder_obj: GraphBuilder, stream_obj: Stream | None, is_stream_owner: bool, conditional_graph, is_join_required: bool) -> None: - self.stream = stream_obj - self.is_stream_owner = is_stream_owner - self.graph = None - self.conditional_graph = conditional_graph - self.is_join_required = is_join_required - weakref.finalize(graph_builder_obj, self.close) - - def close(self) -> None: - if self.stream: - if not self.is_join_required: - capture_status = handle_return(driver.cuStreamGetCaptureInfo(self.stream.handle))[0] - if capture_status != driver.CUstreamCaptureStatus.CU_STREAM_CAPTURE_STATUS_NONE: - # Note how this condition only occures for the primary graph builder - # This is because calling cuStreamEndCapture streams that were split off of the primary - # would error out with CUDA_ERROR_STREAM_CAPTURE_UNJOINED. - # Therefore, it is currently a requirement that users join all split graph builders - # before a graph builder can be clearly destroyed. - handle_return(driver.cuStreamEndCapture(self.stream.handle)) - if self.is_stream_owner: - self.stream.close() - self.stream = None - if self.graph: - handle_return(driver.cuGraphDestroy(self.graph)) - self.graph = None - self.conditional_graph = None - - __slots__ = ("__weakref__", "_building_ended", "_mnff") - - def __init__(self) -> None: + def __init__(self): raise NotImplementedError( - "directly creating a Graph object can be ambiguous. Please either " + "directly creating a GraphBuilder object can be ambiguous. Please either " "call Device.create_graph_builder() or stream.create_graph_builder()" ) - @classmethod - def _init(cls, stream: Stream | None, is_stream_owner: bool, conditional_graph: object = None, is_join_required: bool = False) -> GraphBuilder: - self = cls.__new__(cls) - self._mnff = GraphBuilder._MembersNeededForFinalize( - self, stream, is_stream_owner, conditional_graph, is_join_required - ) + def __dealloc__(self): + GB_end_capture_if_needed(self, False) - self._building_ended = False + @staticmethod + def _init(Stream stream): + cdef GraphBuilder self = GraphBuilder.__new__(GraphBuilder) + # _h_graph set by begin_building + self._h_stream = stream._h_stream + self._kind = PRIMARY + self._state = CAPTURE_NOT_STARTED + self._stream = stream return self + def close(self): + """Destroy the graph builder.""" + with nogil: + GB_end_capture_if_needed(self, True) + self._h_graph.reset() + self._h_stream.reset() + self._state = CLOSED + self._stream = None + @property def stream(self) -> Stream: """Returns the stream associated with the graph builder.""" - return self._mnff.stream + return self._stream @property def is_join_required(self) -> bool: """Returns True if this graph builder must be joined before building is ended.""" - return self._mnff.is_join_required + return self._kind == FORKED def begin_building(self, mode: str | None = "relaxed") -> GraphBuilder: """Begins the building process. @@ -276,64 +299,79 @@ class GraphBuilder: Default set to use relaxed. """ - if self._building_ended: - raise RuntimeError("Cannot resume building after building has ended.") - if mode not in ("global", "thread_local", "relaxed"): - raise ValueError(f"Unsupported build mode: {mode}") + GB_check_open(self) + if self._state != CAPTURE_NOT_STARTED: + if self._state == CAPTURING: + raise RuntimeError("Graph builder is already building.") + else: + raise RuntimeError("Cannot resume building after building has ended.") + cdef cydriver.CUstreamCaptureMode c_mode if mode == "global": - capture_mode = driver.CUstreamCaptureMode.CU_STREAM_CAPTURE_MODE_GLOBAL + c_mode = cydriver.CU_STREAM_CAPTURE_MODE_GLOBAL elif mode == "thread_local": - capture_mode = driver.CUstreamCaptureMode.CU_STREAM_CAPTURE_MODE_THREAD_LOCAL + c_mode = cydriver.CU_STREAM_CAPTURE_MODE_THREAD_LOCAL elif mode == "relaxed": - capture_mode = driver.CUstreamCaptureMode.CU_STREAM_CAPTURE_MODE_RELAXED + c_mode = cydriver.CU_STREAM_CAPTURE_MODE_RELAXED else: raise ValueError(f"Unsupported build mode: {mode}") - if self._mnff.conditional_graph: - handle_return( - driver.cuStreamBeginCaptureToGraph( - self._mnff.stream.handle, - self._mnff.conditional_graph, - None, # dependencies - None, # dependencyData - 0, # numDependencies - capture_mode, - ) - ) + cdef cydriver.CUstream c_stream = as_cu(self._h_stream) + cdef cydriver.CUgraph c_graph + cdef cydriver.CUstreamCaptureStatus c_status + if self._kind == CONDITIONAL_BODY: + c_graph = as_cu(self._h_graph) + with nogil: + HANDLE_RETURN(cydriver.cuStreamBeginCaptureToGraph( + c_stream, c_graph, NULL, NULL, 0, c_mode)) + self._state = CAPTURING else: - handle_return(driver.cuStreamBeginCapture(self._mnff.stream.handle, capture_mode)) + with nogil: + HANDLE_RETURN(cydriver.cuStreamBeginCapture(c_stream, c_mode)) + # Capture is active now; set CAPTURING before the calls below so a + # failure in _get_capture_info/create_graph_handle still lets + # cleanup end the capture rather than leaving the stream poisoned. + self._state = CAPTURING + with nogil: + # The driver rejects a NULL captureStatus_out, so pass a + # stack-local even though we only want the graph handle. + _get_capture_info(c_stream, &c_status, &c_graph) + self._h_graph = create_graph_handle(c_graph) return self @property def is_building(self) -> bool: """Returns True if the graph builder is currently building.""" - capture_status = handle_return(driver.cuStreamGetCaptureInfo(self._mnff.stream.handle))[0] - if capture_status == driver.CUstreamCaptureStatus.CU_STREAM_CAPTURE_STATUS_NONE: + GB_check_open(self) + cdef cydriver.CUstream c_stream = as_cu(self._h_stream) + cdef cydriver.CUstreamCaptureStatus status + with nogil: + _get_capture_info(c_stream, &status, NULL) + if status == cydriver.CU_STREAM_CAPTURE_STATUS_NONE: return False - elif capture_status == driver.CUstreamCaptureStatus.CU_STREAM_CAPTURE_STATUS_ACTIVE: + elif status == cydriver.CU_STREAM_CAPTURE_STATUS_ACTIVE: return True - elif capture_status == driver.CUstreamCaptureStatus.CU_STREAM_CAPTURE_STATUS_INVALIDATED: + elif status == cydriver.CU_STREAM_CAPTURE_STATUS_INVALIDATED: raise RuntimeError( "Build process encountered an error and has been invalidated. Build process must now be ended." ) else: - raise NotImplementedError(f"Unsupported capture status type received: {capture_status}") + raise NotImplementedError(f"Unsupported capture status type received: {status}") def end_building(self) -> GraphBuilder: """Ends the building process.""" + GB_check_open(self) if not self.is_building: raise RuntimeError("Graph builder is not building.") - if self._mnff.conditional_graph: - self._mnff.conditional_graph = handle_return(driver.cuStreamEndCapture(self.stream.handle)) - else: - self._mnff.graph = handle_return(driver.cuStreamEndCapture(self.stream.handle)) + cdef cydriver.CUstream c_stream = as_cu(self._h_stream) + with nogil: + HANDLE_RETURN(cydriver.cuStreamEndCapture(c_stream, NULL)) # TODO: Resolving https://github.com/NVIDIA/cuda-python/issues/617 would allow us to # resume the build process after the first call to end_building() - self._building_ended = True + self._state = CAPTURE_ENDED return self - def complete(self, options: GraphCompleteOptions | None = None) -> "Graph": + def complete(self, options: GraphCompleteOptions | None = None) -> Graph: """Completes the graph builder and returns the built :obj:`~graph.Graph` object. Parameters @@ -347,10 +385,11 @@ class GraphBuilder: The newly built graph. """ - if not self._building_ended: + GB_check_open(self) + if self._state != CAPTURE_ENDED: raise RuntimeError("Graph has not finished building.") - return _instantiate_graph(self._mnff.graph, options) + return _instantiate_graph(as_py(self._h_graph), options) def debug_dot_print(self, path: str, options: GraphDebugPrintOptions | None = None) -> None: """Generates a DOT debug file for the graph builder. @@ -363,12 +402,15 @@ class GraphBuilder: Customizable dataclass for the debug print options. """ - if not self._building_ended: + GB_check_open(self) + if self._state != CAPTURE_ENDED: raise RuntimeError("Graph has not finished building.") - flags = options._to_flags() if options else 0 - cdef bytes path_bytes = path.encode('utf-8') - cdef const char* c_path = path_bytes - handle_return(driver.cuGraphDebugDotPrint(self._mnff.graph, c_path, flags)) + cdef unsigned int c_flags = options._to_flags() if options else 0 + cdef cydriver.CUgraph c_graph = as_cu(self._h_graph) + cdef bytes b_path = path.encode('utf-8') + cdef const char* c_path = b_path + with nogil: + HANDLE_RETURN(cydriver.cuGraphDebugDotPrint(c_graph, c_path, c_flags)) def split(self, count: int) -> tuple[GraphBuilder, ...]: """Splits the original graph builder into multiple graph builders. @@ -390,15 +432,16 @@ class GraphBuilder: """ if count < 2: raise ValueError(f"Invalid split count: expecting >= 2, got {count}") + GB_check_open(self) + if self._state != CAPTURING: + raise RuntimeError("Graph builder must be building before it can be split.") - event = self._mnff.stream.record() + event = self._stream.record() result = [self] for i in range(count - 1): - stream = self._mnff.stream.device.create_stream() + stream = self._stream.device.create_stream() stream.wait(event) - result.append( - GraphBuilder._init(stream=stream, is_stream_owner=True, conditional_graph=None, is_join_required=True) - ) + result.append(GB_init_forked(stream, self._h_graph)) event.close() return tuple(result) @@ -443,10 +486,11 @@ class GraphBuilder: def __cuda_stream__(self) -> tuple[int, int]: """Return an instance of a __cuda_stream__ protocol.""" + GB_check_open(self) return self.stream.__cuda_stream__() def _get_conditional_context(self) -> driver.CUcontext: - return self._mnff.stream.context.handle + return self._stream.context.handle def create_condition(self, default_value: int | None = None) -> GraphCondition: """Create a condition variable for use with conditional nodes. @@ -467,6 +511,7 @@ class GraphBuilder: GraphCondition A condition variable for controlling conditional execution. """ + GB_check_open(self) if cy_driver_version() < (12, 3, 0): raise RuntimeError(f"Driver version {'.'.join(map(str, cy_driver_version()))} does not support conditional handles") if cy_binding_version() < (12, 3, 0): @@ -477,7 +522,7 @@ class GraphBuilder: default_value = 0 flags = 0 - status, _, graph, *_, _ = handle_return(driver.cuStreamGetCaptureInfo(self._mnff.stream.handle)) + status, _, graph, *_, _ = handle_return(driver.cuStreamGetCaptureInfo(self._stream.handle)) if status != driver.CUstreamCaptureStatus.CU_STREAM_CAPTURE_STATUS_ACTIVE: raise RuntimeError("Cannot create a condition when graph is not being built") @@ -486,42 +531,6 @@ class GraphBuilder: ) return GraphCondition._from_handle(int(raw_handle)) - def _cond_with_params(self, node_params: object) -> tuple[GraphBuilder, ...]: - # Get current capture info to ensure we're in a valid state - status, _, graph, *deps_info, num_dependencies = handle_return( - driver.cuStreamGetCaptureInfo(self._mnff.stream.handle) - ) - if status != driver.CUstreamCaptureStatus.CU_STREAM_CAPTURE_STATUS_ACTIVE: - raise RuntimeError("Cannot add conditional node when not actively capturing") - - # Add the conditional node to the graph - deps_info_update = [ - [handle_return(driver.cuGraphAddNode(graph, *deps_info, num_dependencies, node_params))] - ] + [None] * (len(deps_info) - 1) - - # Update the stream's capture dependencies - handle_return( - driver.cuStreamUpdateCaptureDependencies( - self._mnff.stream.handle, - *deps_info_update, # dependencies, edgeData - 1, # numDependencies - driver.CUstreamUpdateCaptureDependencies_flags.CU_STREAM_SET_CAPTURE_DEPENDENCIES, - ) - ) - - # Create new graph builders for each condition - return tuple( - [ - GraphBuilder._init( - stream=self._mnff.stream.device.create_stream(), - is_stream_owner=True, - conditional_graph=node_params.conditional.phGraph_out[i], - is_join_required=False, - ) - for i in range(node_params.conditional.size) - ] - ) - def if_then(self, condition: GraphCondition) -> GraphBuilder: """Adds an if condition branch and returns a new graph builder for it. @@ -542,6 +551,7 @@ class GraphBuilder: The newly created conditional graph builder. """ + GB_check_open(self) if cy_driver_version() < (12, 3, 0): raise RuntimeError(f"Driver version {'.'.join(map(str, cy_driver_version()))} does not support conditional if") if cy_binding_version() < (12, 3, 0): @@ -556,7 +566,7 @@ class GraphBuilder: node_params.conditional.type = driver.CUgraphConditionalNodeType.CU_GRAPH_COND_TYPE_IF node_params.conditional.size = 1 node_params.conditional.ctx = self._get_conditional_context() - return self._cond_with_params(node_params)[0] + return GB_cond_with_params(self, node_params)[0] def if_else(self, condition: GraphCondition) -> tuple[GraphBuilder, GraphBuilder]: """Adds an if-else condition branch and returns new graph builders for both branches. @@ -578,6 +588,7 @@ class GraphBuilder: A tuple of two new graph builders, one for the if branch and one for the else branch. """ + GB_check_open(self) if cy_driver_version() < (12, 8, 0): raise RuntimeError(f"Driver version {'.'.join(map(str, cy_driver_version()))} does not support conditional if-else") if cy_binding_version() < (12, 8, 0): @@ -592,7 +603,7 @@ class GraphBuilder: node_params.conditional.type = driver.CUgraphConditionalNodeType.CU_GRAPH_COND_TYPE_IF node_params.conditional.size = 2 node_params.conditional.ctx = self._get_conditional_context() - return self._cond_with_params(node_params) + return GB_cond_with_params(self, node_params) def switch(self, condition: GraphCondition, count: int) -> tuple[GraphBuilder, ...]: """Adds a switch condition branch and returns new graph builders for all cases. @@ -617,6 +628,7 @@ class GraphBuilder: A tuple of new graph builders, one for each branch. """ + GB_check_open(self) if cy_driver_version() < (12, 8, 0): raise RuntimeError(f"Driver version {'.'.join(map(str, cy_driver_version()))} does not support conditional switch") if cy_binding_version() < (12, 8, 0): @@ -631,7 +643,7 @@ class GraphBuilder: node_params.conditional.type = driver.CUgraphConditionalNodeType.CU_GRAPH_COND_TYPE_SWITCH node_params.conditional.size = count node_params.conditional.ctx = self._get_conditional_context() - return self._cond_with_params(node_params) + return GB_cond_with_params(self, node_params) def while_loop(self, condition: GraphCondition) -> GraphBuilder: """Adds a while loop and returns a new graph builder for it. @@ -653,6 +665,7 @@ class GraphBuilder: The newly created while loop graph builder. """ + GB_check_open(self) if cy_driver_version() < (12, 3, 0): raise RuntimeError(f"Driver version {'.'.join(map(str, cy_driver_version()))} does not support conditional while loop") if cy_binding_version() < (12, 3, 0): @@ -667,18 +680,9 @@ class GraphBuilder: node_params.conditional.type = driver.CUgraphConditionalNodeType.CU_GRAPH_COND_TYPE_WHILE node_params.conditional.size = 1 node_params.conditional.ctx = self._get_conditional_context() - return self._cond_with_params(node_params)[0] - - def close(self) -> None: - """Destroy the graph builder. - - Closes the associated stream if we own it. Borrowed stream - object will instead have their references released. - - """ - self._mnff.close() + return GB_cond_with_params(self, node_params)[0] - def embed(self, child: GraphBuilder) -> None: + def embed(self, GraphBuilder child): """Embed a previously-built :obj:`~graph.GraphBuilder` as a child node. Parameters @@ -686,13 +690,14 @@ class GraphBuilder: child : :obj:`~graph.GraphBuilder` The child graph builder. Must have finished building. """ - if not child._building_ended: + GB_check_open(self) + if child._state != CAPTURE_ENDED: raise ValueError("Child graph has not finished building.") if not self.is_building: raise ValueError("Parent graph is not being built.") - stream_handle = self._mnff.stream.handle + stream_handle = self._stream.handle _, _, graph_out, *deps_info_out, num_dependencies_out = handle_return( driver.cuStreamGetCaptureInfo(stream_handle) ) @@ -704,7 +709,7 @@ class GraphBuilder: [ handle_return( driver.cuGraphAddChildGraphNode( - graph_out, *deps_info_trimmed, num_dependencies_out, child._mnff.graph + graph_out, *deps_info_trimmed, num_dependencies_out, as_py(child._h_graph) ) ) ] @@ -746,18 +751,14 @@ class GraphBuilder: pointer (caller manages lifetime). If bytes-like, the data is copied and its lifetime is tied to the graph. """ - cdef Stream stream = self._mnff.stream + GB_check_open(self) + cdef Stream stream = self._stream cdef cydriver.CUstream c_stream = as_cu(stream._h_stream) cdef cydriver.CUstreamCaptureStatus capture_status cdef cydriver.CUgraph c_graph = NULL with nogil: - IF CUDA_CORE_BUILD_MAJOR >= 13: - HANDLE_RETURN(cydriver.cuStreamGetCaptureInfo( - c_stream, &capture_status, NULL, &c_graph, NULL, NULL, NULL)) - ELSE: - HANDLE_RETURN(cydriver.cuStreamGetCaptureInfo( - c_stream, &capture_status, NULL, &c_graph, NULL, NULL)) + _get_capture_info(c_stream, &capture_status, &c_graph) if capture_status != cydriver.CU_STREAM_CAPTURE_STATUS_ACTIVE: raise RuntimeError("Cannot add callback when graph is not being built") @@ -770,7 +771,114 @@ class GraphBuilder: HANDLE_RETURN(cydriver.cuLaunchHostFunc(c_stream, c_fn, c_user_data)) -class Graph: +cdef inline int GB_check_open(GraphBuilder gb) except -1: + """Reject operations on a builder that has been closed. + + A CLOSED builder has reset its stream and graph handles, so any method + that dereferences them would read a null handle (or, for the cached + Stream, a None typed as cdef Stream). Guarding here yields a clear error + instead. + """ + if gb._state == CLOSED: + raise RuntimeError("Graph builder has been closed.") + return 0 + + +cdef inline int GB_end_capture_if_needed(GraphBuilder gb, bint check_status) except -1 nogil: + """End an in-progress capture if this builder owns it. + + Only a CAPTURING PRIMARY or CONDITIONAL_BODY builder owns the live + capture. A FORKED builder must not call cuStreamEndCapture: the driver + requires forked streams to be joined first. + + A NULL phGraph ends the capture and discards the graph; the driver + guards every write to phGraph (cuapiStreamEndCaptureCommon). + + check_status=True checks the driver return (close()); False ignores it + (__dealloc__). + """ + if gb._h_stream and gb._state == CAPTURING and gb._kind != FORKED: + if check_status: + HANDLE_RETURN(cydriver.cuStreamEndCapture(as_cu(gb._h_stream), NULL)) + else: + cydriver.cuStreamEndCapture(as_cu(gb._h_stream), NULL) + return 0 + + +cdef inline GraphBuilder GB_init_forked(Stream stream, GraphHandle h_primary_graph): + cdef GraphBuilder gb = GraphBuilder.__new__(GraphBuilder) + # A FORKED builder captures into the primary's CUgraph. It holds the + # primary's GraphHandle so conditional bodies created on it (via + # GB_init_conditional -> create_graph_handle_ref(cond_graph, parent._h_graph)) + # have a valid parent handle to pin. + gb._h_graph = h_primary_graph + gb._h_stream = stream._h_stream + gb._kind = FORKED + gb._state = CAPTURING + gb._stream = stream + return gb + + +cdef inline GraphBuilder GB_init_conditional(Stream stream, cydriver.CUgraph cond_graph, GraphBuilder parent): + cdef GraphBuilder gb = GraphBuilder.__new__(GraphBuilder) + gb._h_graph = create_graph_handle_ref(cond_graph, parent._h_graph) + gb._h_stream = stream._h_stream + gb._kind = CONDITIONAL_BODY + gb._state = CAPTURE_NOT_STARTED + gb._stream = stream + return gb + + +cdef inline int _get_capture_info( + cydriver.CUstream stream, + cydriver.CUstreamCaptureStatus* status, + cydriver.CUgraph* graph) except?-1 nogil: + """Thin wrapper around ``cuStreamGetCaptureInfo`` that papers over the + CUDA 12 vs 13 signature change. + + ``status`` must be non-NULL: the driver rejects ``captureStatus_out=NULL`` + with ``CUDA_ERROR_INVALID_VALUE``. ``graph`` may be NULL when the caller + does not need the graph handle. + """ + IF CUDA_CORE_BUILD_MAJOR >= 13: + return HANDLE_RETURN(cydriver.cuStreamGetCaptureInfo( + stream, status, NULL, graph, NULL, NULL, NULL)) + ELSE: + return HANDLE_RETURN(cydriver.cuStreamGetCaptureInfo( + stream, status, NULL, graph, NULL, NULL)) + + +cdef inline tuple GB_cond_with_params(GraphBuilder gb, node_params): + status, _, graph, *deps_info, num_dependencies = handle_return( + driver.cuStreamGetCaptureInfo(gb._stream.handle) + ) + if status != driver.CUstreamCaptureStatus.CU_STREAM_CAPTURE_STATUS_ACTIVE: + raise RuntimeError("Cannot add conditional node when not actively capturing") + + deps_info_update = [ + [handle_return(driver.cuGraphAddNode(graph, *deps_info, num_dependencies, node_params))] + ] + [None] * (len(deps_info) - 1) + + handle_return( + driver.cuStreamUpdateCaptureDependencies( + gb._stream.handle, + *deps_info_update, # dependencies, edgeData + 1, # numDependencies + driver.CUstreamUpdateCaptureDependencies_flags.CU_STREAM_SET_CAPTURE_DEPENDENCIES, + ) + ) + + return tuple( + GB_init_conditional( + gb._stream.device.create_stream(), + int(node_params.conditional.phGraph_out[i]), + gb, + ) + for i in range(node_params.conditional.size) + ) + + +cdef class Graph: """An executable graph. A graph groups a set of CUDA kernels and other CUDA operations together and executes @@ -781,32 +889,18 @@ class Graph: """ - class _MembersNeededForFinalize: - __slots__ = "graph" - - def __init__(self, graph_obj: Graph, graph: driver.CUgraphExec) -> None: - self.graph = graph - weakref.finalize(graph_obj, self.close) - - def close(self) -> None: - if self.graph: - handle_return(driver.cuGraphExecDestroy(self.graph)) - self.graph = None - - __slots__ = ("__weakref__", "_mnff") - - def __init__(self) -> None: + def __init__(self): raise RuntimeError("directly constructing a Graph instance is not supported") - @classmethod - def _init(cls, graph: driver.CUgraphExec) -> Graph: - self = cls.__new__(cls) - self._mnff = Graph._MembersNeededForFinalize(self, graph) + @staticmethod + cdef Graph _init(cydriver.CUgraphExec graph_exec): + cdef Graph self = Graph.__new__(Graph) + self._h_graph_exec = create_graph_exec_handle(graph_exec) return self def close(self) -> None: """Destroy the graph.""" - self._mnff.close() + self._h_graph_exec.reset() @property def handle(self) -> driver.CUgraphExec: @@ -818,7 +912,7 @@ class Graph: handle, call ``int()`` on the returned object. """ - return self._mnff.graph + return as_py(self._h_graph_exec) def update(self, source: "GraphBuilder | GraphDefinition") -> None: """Update the graph using a new graph definition. @@ -835,12 +929,14 @@ class Graph: from cuda.core.graph import GraphDefinition cdef cydriver.CUgraph cu_graph - cdef cydriver.CUgraphExec cu_exec = int(self._mnff.graph) + cdef cydriver.CUgraphExec cu_exec = as_cu(self._h_graph_exec) if isinstance(source, GraphBuilder): - if not source._building_ended: + if (source)._state == CLOSED: + raise ValueError("Source graph builder has been closed.") + if (source)._state != CAPTURE_ENDED: raise ValueError("Graph has not finished building.") - cu_graph = int(source._mnff.graph) + cu_graph = as_cu((source)._h_graph) elif isinstance(source, GraphDefinition): cu_graph = int(source.handle) else: @@ -866,7 +962,10 @@ class Graph: The stream in which to upload the graph """ - handle_return(driver.cuGraphUpload(self._mnff.graph, stream.handle)) + cdef cydriver.CUgraphExec c_exec = as_cu(self._h_graph_exec) + cdef cydriver.CUstream c_stream = int(stream.handle) + with nogil: + HANDLE_RETURN(cydriver.cuGraphUpload(c_exec, c_stream)) def launch(self, stream: Stream) -> None: """Launches the graph in a stream. @@ -877,4 +976,7 @@ class Graph: The stream in which to launch the graph. """ - handle_return(driver.cuGraphLaunch(self._mnff.graph, stream.handle)) + cdef cydriver.CUgraphExec c_exec = as_cu(self._h_graph_exec) + cdef cydriver.CUstream c_stream = int(stream.handle) + with nogil: + HANDLE_RETURN(cydriver.cuGraphLaunch(c_exec, c_stream)) diff --git a/cuda_core/tests/graph/test_graph_builder.py b/cuda_core/tests/graph/test_graph_builder.py index e0e3fd9a51c..18dfe21cc12 100644 --- a/cuda_core/tests/graph/test_graph_builder.py +++ b/cuda_core/tests/graph/test_graph_builder.py @@ -167,6 +167,86 @@ def test_graph_capture_errors(init_cuda): gb.end_building().complete() +def test_graph_begin_building_twice(init_cuda): + """Calling begin_building() while already capturing is a clear error.""" + gb = Device().create_graph_builder() + gb.begin_building() + with pytest.raises(RuntimeError, match="^Graph builder is already building."): + gb.begin_building() + gb.end_building() + + +def test_graph_split_requires_building(init_cuda): + """A builder must be capturing before it can be split.""" + gb = Device().create_graph_builder() + with pytest.raises(RuntimeError, match="^Graph builder must be building before it can be split."): + gb.split(2) + + +def test_graph_complete_after_close_forked(init_cuda): + """complete() on a forked builder closed via join() must not deref a null handle.""" + mod = compile_common_kernels() + empty_kernel = mod.get_kernel("empty_kernel") + + gb = Device().create_graph_builder().begin_building() + launch(gb, LaunchConfig(grid=1, block=1), empty_kernel) + left, right = gb.split(2) + launch(left, LaunchConfig(grid=1, block=1), empty_kernel) + launch(right, LaunchConfig(grid=1, block=1), empty_kernel) + + # join() closes the non-root builder (right); it must now be rejected, not crash. + GraphBuilder.join(left, right) + with pytest.raises(RuntimeError, match="^Graph builder has been closed."): + right.complete() + + +def test_graph_update_after_source_close(init_cuda): + """Graph.update() with a closed source builder must raise, not deref a null handle.""" + mod = compile_common_kernels() + empty_kernel = mod.get_kernel("empty_kernel") + + gb = Device().create_graph_builder().begin_building() + launch(gb, LaunchConfig(grid=1, block=1), empty_kernel) + graph = gb.end_building().complete() + + source = Device().create_graph_builder().begin_building() + launch(source, LaunchConfig(grid=1, block=1), empty_kernel) + source.end_building() + source.close() + + with pytest.raises(ValueError, match="^Source graph builder has been closed."): + graph.update(source) + + +def test_graph_gc_mid_capture(init_cuda): + """Dropping a builder mid-capture ends the orphaned capture so the stream stays usable.""" + import gc + + mod = compile_common_kernels() + empty_kernel = mod.get_kernel("empty_kernel") + + stream = Device().create_stream() + gb = stream.create_graph_builder().begin_building() + launch(gb, LaunchConfig(grid=1, block=1), empty_kernel) + + # Drop the builder without end_building()/close(); __dealloc__ must end the capture. + del gb + gc.collect() + + # If the capture were left active, the stream would be poisoned for new work. + launch(stream, LaunchConfig(grid=1, block=1), empty_kernel) + stream.sync() + stream.close() + + +def test_graph_embed_non_builder(init_cuda): + """embed() rejects a non-GraphBuilder argument with a TypeError.""" + gb = Device().create_graph_builder().begin_building() + with pytest.raises(TypeError): + gb.embed(object()) + gb.end_building() + + def test_graph_capture_callback_python(init_cuda): results = [] @@ -260,6 +340,21 @@ def test_graph_child_graph(init_cuda): b.close() +def test_graph_close_is_idempotent(init_cuda): + """Re-entrant close must not double-destroy the graph exec (Glasswing V18.1).""" + mod = compile_common_kernels() + empty_kernel = mod.get_kernel("empty_kernel") + + gb = Device().create_graph_builder().begin_building() + launch(gb, LaunchConfig(grid=1, block=1), empty_kernel) + graph = gb.end_building().complete() + gb.close() + + graph.close() + graph.close() + assert int(graph.handle) == 0 + + def test_graph_stream_lifetime(init_cuda): mod = compile_common_kernels() empty_kernel = mod.get_kernel("empty_kernel") diff --git a/cuda_core/tests/graph/test_graph_builder_conditional.py b/cuda_core/tests/graph/test_graph_builder_conditional.py index 69956cf0f21..ff9b40a7016 100644 --- a/cuda_core/tests/graph/test_graph_builder_conditional.py +++ b/cuda_core/tests/graph/test_graph_builder_conditional.py @@ -290,3 +290,97 @@ def test_graph_conditional_while(init_cuda, condition_value): # Close the memory resource now because the garbage collected might # de-allocate it during the next graph builder process b.close() + + +@requires_module(np, "2.1") +def test_graph_conditional_on_forked_builder(init_cuda): + """A conditional created on a forked builder keeps its body graph's parent + handle pinned to the owning primary graph.""" + mod = compile_conditional_kernels(int) + add_one = mod.get_kernel("add_one") + set_handle = mod.get_kernel("set_handle") + + launch_stream = Device().create_stream() + mr = LegacyPinnedMemoryResource() + b = mr.allocate(4) + arr = np.from_dlpack(b).view(np.int32) + arr[0] = 0 + + gb = Device().create_graph_builder().begin_building() + launch(gb, LaunchConfig(grid=1, block=1), add_one, arr.ctypes.data) + + # Fork, then create the conditional on the forked builder (not the primary). + left, right = gb.split(2) + try: + condition = right.create_condition() + except RuntimeError as e: + with pytest.raises(RuntimeError, match="^(Driver|Binding) version"): + raise e + right.end_building() + GraphBuilder.join(left, right).end_building() + b.close() + pytest.skip("Driver does not support conditional handle") + launch(right, LaunchConfig(grid=1, block=1), set_handle, condition, 1) + gb_if = right.if_then(condition).begin_building() + launch(gb_if, LaunchConfig(grid=1, block=1), add_one, arr.ctypes.data) + gb_if.end_building() + + gb = GraphBuilder.join(left, right) + graph = gb.end_building().complete() + + arr[0] = 0 + graph.launch(launch_stream) + launch_stream.sync() + # add_one on primary (1) + add_one inside the taken if-branch (1) + assert arr[0] == 2 + + b.close() + + +@requires_module(np, "2.1") +def test_graph_conditional_nested(init_cuda): + """A conditional nested inside another conditional body exercises the + multi-level body -> outer-body -> primary keep-alive chain.""" + mod = compile_conditional_kernels(int) + add_one = mod.get_kernel("add_one") + set_handle = mod.get_kernel("set_handle") + + launch_stream = Device().create_stream() + mr = LegacyPinnedMemoryResource() + b = mr.allocate(4) + arr = np.from_dlpack(b).view(np.int32) + arr[0] = 0 + + gb = Device().create_graph_builder().begin_building() + + try: + outer_condition = gb.create_condition() + except RuntimeError as e: + with pytest.raises(RuntimeError, match="^(Driver|Binding) version"): + raise e + gb.end_building() + b.close() + pytest.skip("Driver does not support conditional handle") + launch(gb, LaunchConfig(grid=1, block=1), set_handle, outer_condition, 1) + + # Outer if-branch + gb_outer = gb.if_then(outer_condition).begin_building() + launch(gb_outer, LaunchConfig(grid=1, block=1), add_one, arr.ctypes.data) + + # Inner if-branch, created inside the outer body + inner_condition = gb_outer.create_condition() + launch(gb_outer, LaunchConfig(grid=1, block=1), set_handle, inner_condition, 1) + gb_inner = gb_outer.if_then(inner_condition).begin_building() + launch(gb_inner, LaunchConfig(grid=1, block=1), add_one, arr.ctypes.data) + gb_inner.end_building() + gb_outer.end_building() + + graph = gb.end_building().complete() + + arr[0] = 0 + graph.launch(launch_stream) + launch_stream.sync() + # add_one in outer body (1) + add_one in inner body (1) + assert arr[0] == 2 + + b.close()