diff --git a/cuda_core/cuda/core/_device.pyx b/cuda_core/cuda/core/_device.pyx index da6972f3727..451ca25ddaa 100644 --- a/cuda_core/cuda/core/_device.pyx +++ b/cuda_core/cuda/core/_device.pyx @@ -85,9 +85,12 @@ cdef class DeviceProperties: cdef inline int _get_cached_attribute(self, attr, default=0) except? -2: """Retrieve the attribute value, using cache if applicable.""" - if attr not in self._cache: - self._cache[attr] = self._get_attribute(attr, default) - return self._cache[attr] + cached = self._cache.get(attr) + if cached is not None: + return cached + cdef int value = self._get_attribute(attr, default) + self._cache[attr] = value # setdefault not needed for ints + return value @property def max_threads_per_block(self) -> int: @@ -1131,11 +1134,11 @@ class Device: def compute_capability(self) -> ComputeCapability: """Return a named tuple with 2 fields: major and minor.""" cdef DeviceProperties prop = self.properties - if "compute_capability" in prop._cache: - return prop._cache["compute_capability"] + cached = prop._cache.get("compute_capability") + if cached is not None: + return cached cc = ComputeCapability(prop.compute_capability_major, prop.compute_capability_minor) - prop._cache["compute_capability"] = cc - return cc + return prop._cache.setdefault("compute_capability", cc) @property def arch(self) -> str: diff --git a/cuda_core/cuda/core/_memory/_graph_memory_resource.pyi b/cuda_core/cuda/core/_memory/_graph_memory_resource.pyi index 4ff85eb5972..b34f968fdc9 100644 --- a/cuda_core/cuda/core/_memory/_graph_memory_resource.pyi +++ b/cuda_core/cuda/core/_memory/_graph_memory_resource.pyi @@ -2,8 +2,6 @@ from __future__ import annotations -from functools import cache - from cuda.core._device import Device from cuda.core._memory._buffer import Buffer, MemoryResource from cuda.core._stream import Stream @@ -113,7 +111,6 @@ class GraphMemoryResource(cyGraphMemoryResource): ... @classmethod - @cache def _create(cls, device_id: int) -> GraphMemoryResource: ... __all__ = ['GraphMemoryResource'] \ No newline at end of file diff --git a/cuda_core/cuda/core/_memory/_graph_memory_resource.pyx b/cuda_core/cuda/core/_memory/_graph_memory_resource.pyx index 479322ab017..e845a47b080 100644 --- a/cuda_core/cuda/core/_memory/_graph_memory_resource.pyx +++ b/cuda_core/cuda/core/_memory/_graph_memory_resource.pyx @@ -18,7 +18,6 @@ from cuda.core._resource_handles cimport ( from cuda.core._stream cimport Stream_accept, Stream from cuda.core._utils.cuda_utils cimport HANDLE_RETURN -from functools import cache from typing import TYPE_CHECKING if TYPE_CHECKING: @@ -161,6 +160,8 @@ cdef class cyGraphMemoryResource(MemoryResource): return False +cdef dict _mem_resource_cache = {} + class GraphMemoryResource(cyGraphMemoryResource): """ A memory resource for memory related to graphs. @@ -185,9 +186,16 @@ class GraphMemoryResource(cyGraphMemoryResource): return cls._create(c_device_id) @classmethod - @cache def _create(cls, int device_id) -> GraphMemoryResource: - return cyGraphMemoryResource.__new__(cls, device_id) + # we use a dict currently, because functools.cache is currently less + # thread-safe see also: https://github.com/python/cpython/issues/150708 + res = _mem_resource_cache.get(device_id) + if res is not None: + return res + + # create new instance, but in case of a race may return another: + new = cyGraphMemoryResource.__new__(cls, device_id) + return _mem_resource_cache.setdefault(device_id, new) # Raise an exception if the given stream is capturing. diff --git a/cuda_core/cuda/core/_memoryview.pyx b/cuda_core/cuda/core/_memoryview.pyx index c65107ae273..5d97240360b 100644 --- a/cuda_core/cuda/core/_memoryview.pyx +++ b/cuda_core/cuda/core/_memoryview.pyx @@ -80,7 +80,7 @@ cdef inline bint _is_torch_tensor(object obj): cdef str mod = tp.__module__ or "" cdef bint result = mod.startswith("torch") and hasattr(obj, "data_ptr") \ and _torch_version_check() - _torch_type_cache[tp] = result + _torch_type_cache[tp] = result # setdefault not needed for bools return result diff --git a/cuda_core/cuda/core/_module.pyx b/cuda_core/cuda/core/_module.pyx index 5cb1b7f0059..2c6810a3718 100644 --- a/cuda_core/cuda/core/_module.pyx +++ b/cuda_core/cuda/core/_module.pyx @@ -83,7 +83,7 @@ cdef class KernelAttributes: cdef int result with nogil: HANDLE_RETURN(cydriver.cuKernelGetAttribute(&result, attribute, as_cu(self._h_kernel), device_id)) - self._cache[cache_key] = result + self._cache[cache_key] = result # setdefault not needed for ints return result def __getitem__(self, device: Device | int) -> KernelAttributes: diff --git a/cuda_core/cuda/core/graph/_graph_node.pyx b/cuda_core/cuda/core/graph/_graph_node.pyx index f627edf9bb2..53145dd5e2a 100644 --- a/cuda_core/cuda/core/graph/_graph_node.pyx +++ b/cuda_core/cuda/core/graph/_graph_node.pyx @@ -78,8 +78,7 @@ _node_registry: weakref.WeakValueDictionary[int, GraphNode] = weakref.WeakValueD cdef inline GraphNode _registered(GraphNode n): - _node_registry[n._h_node.get()] = n - return n + return _node_registry.setdefault(n._h_node.get(), n) cdef class GraphNode: