Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 10 additions & 7 deletions cuda_core/cuda/core/_device.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -85,9 +85,12 @@ cdef class DeviceProperties:

cdef inline int _get_cached_attribute(self, attr, default=0) except? -2:
"""Retrieve the attribute value, using cache if applicable."""
if attr not in self._cache:
self._cache[attr] = self._get_attribute(attr, default)
return self._cache[attr]
cached = self._cache.get(attr)
if cached is not None:
return cached
cdef int value = self._get_attribute(attr, default)
self._cache[attr] = value # setdefault not needed for ints
return value

@property
def max_threads_per_block(self) -> int:
Expand Down Expand Up @@ -1131,11 +1134,11 @@ class Device:
def compute_capability(self) -> ComputeCapability:
"""Return a named tuple with 2 fields: major and minor."""
cdef DeviceProperties prop = self.properties
if "compute_capability" in prop._cache:
return prop._cache["compute_capability"]
cached = prop._cache.get("compute_capability")
if cached is not None:
return cached
cc = ComputeCapability(prop.compute_capability_major, prop.compute_capability_minor)
prop._cache["compute_capability"] = cc
return cc
return prop._cache.setdefault("compute_capability", cc)

@property
def arch(self) -> str:
Expand Down
3 changes: 0 additions & 3 deletions cuda_core/cuda/core/_memory/_graph_memory_resource.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,6 @@

from __future__ import annotations

from functools import cache

from cuda.core._device import Device
from cuda.core._memory._buffer import Buffer, MemoryResource
from cuda.core._stream import Stream
Expand Down Expand Up @@ -113,7 +111,6 @@ class GraphMemoryResource(cyGraphMemoryResource):
...

@classmethod
@cache
def _create(cls, device_id: int) -> GraphMemoryResource:
...
__all__ = ['GraphMemoryResource']
14 changes: 11 additions & 3 deletions cuda_core/cuda/core/_memory/_graph_memory_resource.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,6 @@ from cuda.core._resource_handles cimport (
from cuda.core._stream cimport Stream_accept, Stream
from cuda.core._utils.cuda_utils cimport HANDLE_RETURN

from functools import cache
from typing import TYPE_CHECKING

if TYPE_CHECKING:
Expand Down Expand Up @@ -161,6 +160,8 @@ cdef class cyGraphMemoryResource(MemoryResource):
return False


cdef dict _mem_resource_cache = {}

class GraphMemoryResource(cyGraphMemoryResource):
"""
A memory resource for memory related to graphs.
Expand All @@ -185,9 +186,16 @@ class GraphMemoryResource(cyGraphMemoryResource):
return cls._create(c_device_id)

@classmethod
@cache
def _create(cls, int device_id) -> GraphMemoryResource:
return cyGraphMemoryResource.__new__(cls, device_id)
# we use a dict currently, because functools.cache is currently less
# thread-safe see also: https://github.com/python/cpython/issues/150708
res = _mem_resource_cache.get(device_id)
if res is not None:
return res

# create new instance, but in case of a race may return another:
new = cyGraphMemoryResource.__new__(cls, device_id)
return _mem_resource_cache.setdefault(device_id, new)


# Raise an exception if the given stream is capturing.
Expand Down
2 changes: 1 addition & 1 deletion cuda_core/cuda/core/_memoryview.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -80,7 +80,7 @@ cdef inline bint _is_torch_tensor(object obj):
cdef str mod = tp.__module__ or ""
cdef bint result = mod.startswith("torch") and hasattr(obj, "data_ptr") \
and _torch_version_check()
_torch_type_cache[tp] = result
_torch_type_cache[tp] = result # setdefault not needed for bools
return result


Expand Down
2 changes: 1 addition & 1 deletion cuda_core/cuda/core/_module.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -83,7 +83,7 @@ cdef class KernelAttributes:
cdef int result
with nogil:
HANDLE_RETURN(cydriver.cuKernelGetAttribute(&result, attribute, as_cu(self._h_kernel), device_id))
self._cache[cache_key] = result
self._cache[cache_key] = result # setdefault not needed for ints
return result

def __getitem__(self, device: Device | int) -> KernelAttributes:
Expand Down
3 changes: 1 addition & 2 deletions cuda_core/cuda/core/graph/_graph_node.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -78,8 +78,7 @@ _node_registry: weakref.WeakValueDictionary[int, GraphNode] = weakref.WeakValueD


cdef inline GraphNode _registered(GraphNode n):
_node_registry[<uintptr_t>n._h_node.get()] = n
return n
return _node_registry.setdefault(<uintptr_t>n._h_node.get(), n)


cdef class GraphNode:
Expand Down
Loading