diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 40121b6f0fe..f8d2a2731ed 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -32,7 +32,7 @@ repos: language: python additional_dependencies: - https://files.pythonhosted.org/packages/cc/20/ff623b09d963f88bfde16306a54e12ee5ea43e9b597108672ff3a408aad6/pathspec-0.12.1-py3-none-any.whl - exclude: '(.*pixi\.lock)|(\.git_archival\.txt)|(.*\.patch$)' + exclude: '(.*pixi\.lock)|(\.git_archival\.txt)|(.*\.patch$)|(^cuda_core/cuda/core/_vendored/)' args: ["--fix"] - id: no-markdown-in-docs-source @@ -111,6 +111,7 @@ repos: alias: mypy-cuda-core name: mypy-cuda-core files: ^cuda_core/cuda/.*\.(py|pyi)$ + exclude: ^cuda_core/cuda/core/_vendored/ pass_filenames: false args: [--config-file=cuda_core/pyproject.toml, cuda_core/cuda/core] additional_dependencies: diff --git a/cuda_core/AGENTS.md b/cuda_core/AGENTS.md index 2e7391b2b84..ca49d4e9199 100644 --- a/cuda_core/AGENTS.md +++ b/cuda_core/AGENTS.md @@ -145,3 +145,97 @@ so that they are documented but don't appear in the main index. ### API stability Reviews should point out where existing public APIs are broken. + +### Deprecation and API lifecycle + +`cuda.core` follows SemVer (see `docs/source/: + +- **New APIs** may be added at any time (`x.Y.0`). +- **Breaking removals** only happen in **major releases** (`X.0.0`). +- Per the support policy, a deprecation notice must be present for **at least + one minor release** before the API is actually removed. +- Changes should be notated in the code and also in the release notes in the + "Deprecated APIs" section. + +**Annotating a new API** — Use the `versionadded` decorator from the vendored +`cuda.core._vendored.deprecated.sphinx` module: + +```python + +from cuda.core._vendored.deprecated.sphinx import versionadded + +@versionadded("1.2.0") +def new_feature(...): + """Short description. + """ +``` + +Alternatively, if the vagaries of how we implement functions in Cython does not +allow this, you can add the reST `versionadded` directive directly: + +```python +def new_feature(...): + """Short description. + + .. versionadded:: 1.2.0 + """ +``` + +**Annotating a changed API** — Use the `versionchanged` decorator from the +vendored `cuda.core._vendored.deprecated.sphinx` module: + +```python + +from cuda.core._vendored.deprecated.sphinx import versionchanged + +@versionchanged("1.2.0", reason="The old version was broken because...") +def new_feature(...): + """Short description. + """ +``` + +Alternatively, if the vagaries of how we implement functions in Cython does not +allow this, you can add the reST `versionchanged` directive directly: + +```python +def new_feature(...): + """Short description. + + .. versionchanged:: 1.2.0 + The old version was broken because... + """ +``` + +**Deprecating an existing API** — use the `@deprecated` decorator from the +vendored `cuda.core._vendored.deprecated.sphinx` module and add a +`.. deprecated::` directive in the docstring. The decorator emits a +`DeprecationWarning` at call time; the docstring directive surfaces it in the +generated docs. + +```python +from cuda.core._vendored.deprecated.sphinx import deprecated + +@deprecated(version="1.2.0", reason="Use `new_feature` instead.") +def old_feature(...): + """Short description. + """ +``` + +Rules to follow when deprecating: + +- The `version=` argument must be the **first** in which the + deprecation appears, not the release in which removal is planned. +- The `reason=` string must name the replacement (if one exists) so users + know what to migrate to. +- Keep the old implementation fully functional — do not change its behavior, + only add the decorator. +- The deprecated API must remain in the codebase for **at least one full minor + release cycle** before it can be removed in a subsequent major release. + +**Removing a deprecated API** — removals land in a **major release**. Before +removing, verify that the deprecation has been present since at least the +previous minor release. Remove the decorator, the implementation, and any +`__all__` entry; update `api.rst` and the release notes accordingly. + +At some point in the future, we will provide automation for removal of +deprecated APIs. diff --git a/cuda_core/cuda/core/_vendored/__init__.py b/cuda_core/cuda/core/_vendored/__init__.py new file mode 100644 index 00000000000..e69de29bb2d diff --git a/cuda_core/cuda/core/_vendored/deprecated/__init__.py b/cuda_core/cuda/core/_vendored/deprecated/__init__.py new file mode 100644 index 00000000000..20a317a94e2 --- /dev/null +++ b/cuda_core/cuda/core/_vendored/deprecated/__init__.py @@ -0,0 +1,6 @@ +# Vendored from the Deprecated package (https://pypi.org/project/Deprecated/), +# version 1.3.1, (c) Laurent LAPORTE, MIT License. +# Modified to remove the dependency on the `wrapt` package. + +from cuda.core._vendored.deprecated.classic import deprecated +from cuda.core._vendored.deprecated.params import deprecated_params diff --git a/cuda_core/cuda/core/_vendored/deprecated/classic.py b/cuda_core/cuda/core/_vendored/deprecated/classic.py new file mode 100644 index 00000000000..965a10d6a15 --- /dev/null +++ b/cuda_core/cuda/core/_vendored/deprecated/classic.py @@ -0,0 +1,111 @@ +# Vendored from the Deprecated package (https://pypi.org/project/Deprecated/), +# version 1.3.1, (c) Laurent LAPORTE, MIT License. +# Modified to remove the dependency on the `wrapt` package. + +import functools +import inspect +import warnings + +# stacklevel=2 points past the wrapper to the actual call site +_routine_stacklevel = 2 +_class_stacklevel = 2 + +string_types = (bytes, str) + + +class ClassicAdapter: + """ + Classic adapter -- *for advanced usage only* + + This adapter is used to get the deprecation message according to the wrapped + object type: class, function, standard method, static method, or class method. + + This is the base class of the :class:`~deprecated.sphinx.SphinxAdapter` class + which is used to update the wrapped object docstring. + """ + + def __init__(self, reason="", version="", action=None, category=DeprecationWarning, extra_stacklevel=0): + self.reason = reason or "" + self.version = version or "" + self.action = action + self.category = category + self.extra_stacklevel = extra_stacklevel + + def get_deprecated_msg(self, wrapped, instance): + if instance is None: + if inspect.isclass(wrapped): + fmt = "Call to deprecated class {name}." + else: + fmt = "Call to deprecated function (or staticmethod) {name}." + else: + if inspect.isclass(instance): + fmt = "Call to deprecated class method {name}." + else: + fmt = "Call to deprecated method {name}." + if self.reason: + fmt += " ({reason})" + if self.version: + fmt += " -- Deprecated since version {version}." + return fmt.format(name=wrapped.__name__, reason=self.reason or "", version=self.version or "") + + def __call__(self, wrapped): + if inspect.isclass(wrapped): + old_new1 = wrapped.__new__ + + def wrapped_cls(cls, *args, **kwargs): + msg = self.get_deprecated_msg(wrapped, None) + stacklevel = _class_stacklevel + self.extra_stacklevel + if self.action: + with warnings.catch_warnings(): + warnings.simplefilter(self.action, self.category) + warnings.warn(msg, category=self.category, stacklevel=stacklevel) + else: + warnings.warn(msg, category=self.category, stacklevel=stacklevel) + if old_new1 is object.__new__: + return old_new1(cls) + return old_new1(cls, *args, **kwargs) + + wrapped.__new__ = staticmethod(wrapped_cls) + return wrapped + + elif inspect.isroutine(wrapped): + adapter = self + + @functools.wraps(wrapped) + def wrapper(*args, **kwargs): + msg = adapter.get_deprecated_msg(wrapped, None) + stacklevel = _routine_stacklevel + adapter.extra_stacklevel + if adapter.action: + with warnings.catch_warnings(): + warnings.simplefilter(adapter.action, adapter.category) + warnings.warn(msg, category=adapter.category, stacklevel=stacklevel) + else: + warnings.warn(msg, category=adapter.category, stacklevel=stacklevel) + return wrapped(*args, **kwargs) + + return wrapper + + else: + raise TypeError(repr(type(wrapped))) + + +def deprecated(*args, **kwargs): + """ + Decorator which can be used to mark functions as deprecated. + + It will result in a warning being emitted when the function is used. + """ + if args and isinstance(args[0], string_types): + kwargs["reason"] = args[0] + args = args[1:] + + if args and not callable(args[0]): + raise TypeError(repr(type(args[0]))) + + if args: + adapter_cls = kwargs.pop("adapter_cls", ClassicAdapter) + adapter = adapter_cls(**kwargs) + wrapped = args[0] + return adapter(wrapped) + + return functools.partial(deprecated, **kwargs) diff --git a/cuda_core/cuda/core/_vendored/deprecated/params.py b/cuda_core/cuda/core/_vendored/deprecated/params.py new file mode 100644 index 00000000000..6584d86df4d --- /dev/null +++ b/cuda_core/cuda/core/_vendored/deprecated/params.py @@ -0,0 +1,52 @@ +# Vendored from the Deprecated package (https://pypi.org/project/Deprecated/), +# version 1.3.1, (c) Laurent LAPORTE, MIT License. +# Modified to remove the dependency on the `wrapt` package. + +import collections +import functools +import inspect +import warnings + + +class DeprecatedParams: + """ + Decorator for functions where one or more parameters are deprecated. + """ + + def __init__(self, param, reason="", category=DeprecationWarning): + self.messages = {} + self.category = category + self.populate_messages(param, reason=reason) + + def populate_messages(self, param, reason=""): + if isinstance(param, dict): + self.messages.update(param) + elif isinstance(param, str): + fmt = "'{param}' parameter is deprecated" + reason = reason or fmt.format(param=param) + self.messages[param] = reason + else: + raise TypeError(param) + + def check_params(self, signature, *args, **kwargs): + binding = signature.bind(*args, **kwargs) + bound = collections.OrderedDict(binding.arguments, **binding.kwargs) + return [param for param in bound if param in self.messages] + + def warn_messages(self, messages): + for message in messages: + warnings.warn(message, category=self.category, stacklevel=3) + + def __call__(self, f): + signature = inspect.signature(f) + + @functools.wraps(f) + def wrapper(*args, **kwargs): + invalid_params = self.check_params(signature, *args, **kwargs) + self.warn_messages([self.messages[param] for param in invalid_params]) + return f(*args, **kwargs) + + return wrapper + + +deprecated_params = DeprecatedParams diff --git a/cuda_core/cuda/core/_vendored/deprecated/sphinx.py b/cuda_core/cuda/core/_vendored/deprecated/sphinx.py new file mode 100644 index 00000000000..557154c5474 --- /dev/null +++ b/cuda_core/cuda/core/_vendored/deprecated/sphinx.py @@ -0,0 +1,113 @@ +# Vendored from the Deprecated package (https://pypi.org/project/Deprecated/), +# version 1.3.1, (c) Laurent LAPORTE, MIT License. +# Modified to remove the dependency on the `wrapt` package. + +import re +import textwrap + +from cuda.core._vendored.deprecated.classic import ClassicAdapter +from cuda.core._vendored.deprecated.classic import deprecated as _classic_deprecated + + +class SphinxAdapter(ClassicAdapter): + """ + Sphinx adapter -- *for advanced usage only* + + This adapter overrides :class:`~deprecated.classic.ClassicAdapter` to add + Sphinx directives ("versionadded", "versionchanged", "deprecated") to the + end of the decorated function or class docstring. + """ + + def __init__( + self, + directive, + reason="", + version="", + action=None, + category=DeprecationWarning, + extra_stacklevel=0, + line_length=70, + ): + if not version: + raise ValueError("'version' argument is required in Sphinx directives") + self.directive = directive + self.line_length = line_length + super().__init__( + reason=reason, version=version, action=action, category=category, extra_stacklevel=extra_stacklevel + ) + + def __call__(self, wrapped): + fmt = ".. {directive}:: {version}" if self.version else ".. {directive}::" + div_lines = [fmt.format(directive=self.directive, version=self.version)] + width = self.line_length - 3 if self.line_length > 3 else 2**16 + reason = textwrap.dedent(self.reason).strip() + for paragraph in reason.splitlines(): + if paragraph: + div_lines.extend( + textwrap.fill( + paragraph, + width=width, + initial_indent=" ", + subsequent_indent=" ", + ).splitlines() + ) + else: + div_lines.append("") + + docstring = wrapped.__doc__ or "" + lines = docstring.splitlines(True) or [""] + docstring = textwrap.dedent("".join(lines[1:])) if len(lines) > 1 else "" + docstring = lines[0] + docstring + if docstring: + docstring = re.sub(r"\n+$", "", docstring, flags=re.DOTALL) + "\n\n" + else: + docstring = "\n" + + docstring += "".join(f"{line}\n" for line in div_lines) + + wrapped.__doc__ = docstring + if self.directive in {"versionadded", "versionchanged"}: + return wrapped + return super().__call__(wrapped) + + def get_deprecated_msg(self, wrapped, instance): + msg = super().get_deprecated_msg(wrapped, instance) + msg = re.sub(r"(?: : [a-zA-Z]+ )? : [a-zA-Z]+ : (`[^`]*`)", r"\1", msg, flags=re.X) + return msg + + +def versionadded(reason="", version="", line_length=70): + """ + Decorator that inserts a "versionadded" Sphinx directive into the docstring. + """ + return SphinxAdapter( + "versionadded", + reason=reason, + version=version, + line_length=line_length, + ) + + +def versionchanged(reason="", version="", line_length=70): + """ + Decorator that inserts a "versionchanged" Sphinx directive into the docstring. + """ + return SphinxAdapter( + "versionchanged", + reason=reason, + version=version, + line_length=line_length, + ) + + +def deprecated(reason="", version="", line_length=70, **kwargs): + """ + Decorator that inserts a "deprecated" Sphinx directive into the docstring + and emits a :exc:`DeprecationWarning` when the decorated object is called. + """ + directive = kwargs.pop("directive", "deprecated") + adapter_cls = kwargs.pop("adapter_cls", SphinxAdapter) + kwargs["reason"] = reason + kwargs["version"] = version + kwargs["line_length"] = line_length + return _classic_deprecated(directive=directive, adapter_cls=adapter_cls, **kwargs) diff --git a/cuda_core/cuda/core/system/_device.pyi b/cuda_core/cuda/core/system/_device.pyi index 2d35c7f63bc..d07608505a9 100644 --- a/cuda_core/cuda/core/system/_device.pyi +++ b/cuda_core/cuda/core/system/_device.pyi @@ -6,6 +6,8 @@ from typing import Iterable import cuda.core from cuda.bindings import nvml +from cuda.core._vendored.deprecated.sphinx import (deprecated, versionadded, + versionchanged) from cuda.core.system.typing import (AddressingMode, AffinityScope, ClockId, ClocksEventReasons, ClockType, CoolerControl, CoolerTarget, DeviceArch, @@ -787,11 +789,23 @@ class MigInfo: A list of all MIG devices corresponding to this GPU. """ -class NvlinkInfo: +class _NvlinkInfoMeta(type): + + @property + @deprecated(version='1.1.0', reason='Use Device.get_num_nvlinks instead to get the actual number of Nvlinks available on a specific device.') + def max_links(cls): + """ + The statically-defined maximum number of Nvlinks available. Defined in + upstream NVML as ``NVML_NVLINK_MAX_LINKS``. + + To find the actual number of Nvlinks available on a device, use + :py:attr:`Device.get_num_nvlinks`. + """ + +class _NvlinkInfo: """ Nvlink information for a device. """ - max_links = nvml.NVLINK_MAX_LINKS def __init__(self, device: Device, link: int): ... @@ -824,6 +838,9 @@ class NvlinkInfo: `True` if the Nvlink is active. """ +class NvlinkInfo(_NvlinkInfo, metaclass=_NvlinkInfoMeta): + ... + class PciInfo: """ PCI information about a GPU device. @@ -1719,6 +1736,7 @@ class Device: :obj:`~_device.MemoryInfo` object with memory information. """ + @versionchanged(version='1.1.0', reason='Any link number not supported by this specific device will raise a `ValueError`.') def get_nvlink(self, link: int) -> NvlinkInfo: """ Get :obj:`~NvlinkInfo` about this device. @@ -1726,6 +1744,22 @@ class Device: For devices with NVLink support. """ + @versionadded(version='1.1.0') + def get_nvlink_count(self) -> int: + """ + Get the number of NVLink links on this device. + + For devices with NVLink support. + """ + + @versionadded(version='1.1.0') + def get_nvlinks(self) -> Iterable[NvlinkInfo]: + """ + Get :obj:`~NvlinkInfo` about all NVLink links on this device. + + For devices with NVLink support. + """ + @property def pci_info(self) -> PciInfo: """ diff --git a/cuda_core/cuda/core/system/_device.pyx b/cuda_core/cuda/core/system/_device.pyx index f0126b78a5b..3f843b7ce1b 100644 --- a/cuda_core/cuda/core/system/_device.pyx +++ b/cuda_core/cuda/core/system/_device.pyx @@ -33,6 +33,7 @@ from cuda.core.system.typing import ( ThermalController, ThermalTarget, ) +from cuda.core._vendored.deprecated.sphinx import deprecated, versionadded, versionchanged if TYPE_CHECKING: import cuda.core # no-cython-lint @@ -884,16 +885,39 @@ cdef class Device: # NVLINK # See external class definitions in _nvlink.pxi + @versionchanged( + version="1.1.0", + reason="Any link number not supported by this specific device will raise a `ValueError`." + ) def get_nvlink(self, link: int) -> NvlinkInfo: """ Get :obj:`~NvlinkInfo` about this device. For devices with NVLink support. """ - if link < 0 or link >= NvlinkInfo.max_links: - raise ValueError(f"Link index {link} is out of range [0, {NvlinkInfo.max_links})") + if link < 0 or link >= self.get_nvlink_count(): + raise ValueError(f"Link index {link} is out of range [0, {self.get_nvlink_count()})") return NvlinkInfo(self, link) + @versionadded(version="1.1.0") + def get_nvlink_count(self) -> int: + """ + Get the number of NVLink links on this device. + + For devices with NVLink support. + """ + return self.get_field_values([FieldId.DEV_NVLINK_LINK_COUNT])[0].value + + @versionadded(version="1.1.0") + def get_nvlinks(self) -> Iterable[NvlinkInfo]: + """ + Get :obj:`~NvlinkInfo` about all NVLink links on this device. + + For devices with NVLink support. + """ + for link in range(self.get_nvlink_count()): + yield self.get_nvlink(link) + ########################################################################## # PCI INFO # See external class definitions in _pci_info.pxi diff --git a/cuda_core/cuda/core/system/_nvlink.pxi b/cuda_core/cuda/core/system/_nvlink.pxi index 62ab4e716be..9a2ca8dd3ff 100644 --- a/cuda_core/cuda/core/system/_nvlink.pxi +++ b/cuda_core/cuda/core/system/_nvlink.pxi @@ -18,7 +18,24 @@ if _NVLINK_VERSION_6_0 is not None: _NVLINK_VERSION_MAPPING[_NVLINK_VERSION_6_0] = (6, 0) -cdef class NvlinkInfo: +class _NvlinkInfoMeta(type): + @property + @deprecated( + version="1.1.0", + reason="Use Device.get_num_nvlinks instead to get the actual number of Nvlinks available on a specific device." + ) + def max_links(cls): + """ + The statically-defined maximum number of Nvlinks available. Defined in + upstream NVML as ``NVML_NVLINK_MAX_LINKS``. + + To find the actual number of Nvlinks available on a device, use + :py:attr:`Device.get_num_nvlinks`. + """ + return nvml.NVLINK_MAX_LINKS + + +cdef class _NvlinkInfo: """ Nvlink information for a device. """ @@ -67,4 +84,6 @@ cdef class NvlinkInfo: nvml.device_get_nvlink_state(self._device._handle, self._link) == nvml.EnableState.FEATURE_ENABLED ) - max_links = nvml.NVLINK_MAX_LINKS + +class NvlinkInfo(_NvlinkInfo, metaclass=_NvlinkInfoMeta): + pass diff --git a/cuda_core/pyproject.toml b/cuda_core/pyproject.toml index 94424d8488e..dd47b720884 100644 --- a/cuda_core/pyproject.toml +++ b/cuda_core/pyproject.toml @@ -140,6 +140,10 @@ implicit_reexport = true # Ignore missing imports for now (you can tighten this later) ignore_missing_imports = true +[[tool.mypy.overrides]] +module = "cuda.core._vendored.*" +ignore_errors = true + [[tool.mypy.overrides]] # cpdef functions with Cython-native tuple return types can't carry type args # through stubgen-pyx; suppress the resulting type-arg error for this module. diff --git a/cuda_core/tests/system/test_system_device.py b/cuda_core/tests/system/test_system_device.py index 4aa13840b48..578a4bd945d 100644 --- a/cuda_core/tests/system/test_system_device.py +++ b/cuda_core/tests/system/test_system_device.py @@ -763,27 +763,40 @@ def test_compute_running_processes(): def test_nvlink(): for device in system.Device.get_all_devices(): - max_links = _device.NvlinkInfo.max_links - assert isinstance(max_links, int) - assert max_links > 0 - - for link in range(max_links): - with unsupported_before(device, None): - nvlink_info = device.get_nvlink(link) - assert isinstance(nvlink_info, _device.NvlinkInfo) - - with unsupported_before(device, None): - state = nvlink_info.state - assert isinstance(state, bool) - - if not state: - continue - - with unsupported_before(device, None): - version = nvlink_info.version - assert isinstance(version, tuple) - assert len(version) == 2 - assert all(isinstance(i, int) for i in version) + with unsupported_before(device, None): + for link in range(device.get_nvlink_count()): + with unsupported_before(device, None): + nvlink_info = device.get_nvlink(link) + assert isinstance(nvlink_info, _device.NvlinkInfo) + + with unsupported_before(device, None): + state = nvlink_info.state + assert isinstance(state, bool) + + if not state: + continue + + with unsupported_before(device, None): + version = nvlink_info.version + assert isinstance(version, tuple) + assert len(version) == 2 + assert all(isinstance(i, int) for i in version) + + for nvlink_info in device.get_nvlinks(): + assert isinstance(nvlink_info, _device.NvlinkInfo) + + with unsupported_before(device, None): + state = nvlink_info.state + assert isinstance(state, bool) + + if not state: + continue + + with unsupported_before(device, None): + version = nvlink_info.version + assert isinstance(version, tuple) + assert len(version) == 2 + assert all(isinstance(i, int) for i in version) def test_utilization():