diff --git a/cuda_pathfinder/cuda/pathfinder/_static_libs/find_bitcode_lib.py b/cuda_pathfinder/cuda/pathfinder/_static_libs/find_bitcode_lib.py index ac038aadfe7..dada58ffd60 100644 --- a/cuda_pathfinder/cuda/pathfinder/_static_libs/find_bitcode_lib.py +++ b/cuda_pathfinder/cuda/pathfinder/_static_libs/find_bitcode_lib.py @@ -3,6 +3,7 @@ import functools import os +import re from dataclasses import dataclass from typing import NoReturn, TypedDict @@ -23,10 +24,12 @@ class LocatedBitcodeLib: abs_path: str filename: str found_via: str + sm_arch: str | None = None class _BitcodeLibInfo(TypedDict): filename: str + arch_specific_filename_template: str | None rel_path: str site_packages_dirs: tuple[str, ...] available_on_windows: bool @@ -35,6 +38,7 @@ class _BitcodeLibInfo(TypedDict): _SUPPORTED_BITCODE_LIBS_INFO: dict[str, _BitcodeLibInfo] = { "device": { "filename": "libdevice.10.bc", + "arch_specific_filename_template": None, "rel_path": os.path.join("nvvm", "libdevice"), "site_packages_dirs": ( "nvidia/cu13/nvvm/libdevice", @@ -44,12 +48,14 @@ class _BitcodeLibInfo(TypedDict): }, "nccl_device": { "filename": "libnccl_device.bc", + "arch_specific_filename_template": None, "rel_path": "lib", "site_packages_dirs": ("nvidia/nccl/lib",), "available_on_windows": False, }, "nvshmem_device": { "filename": "libnvshmem_device.bc", + "arch_specific_filename_template": "libnvshmem_device_{sm_arch}.bc", "rel_path": "lib", "site_packages_dirs": ("nvidia/nvshmem/lib",), "available_on_windows": False, @@ -64,6 +70,23 @@ class _BitcodeLibInfo(TypedDict): ) +def _normalize_sm_arch(sm_arch: str | None) -> str | None: + if sm_arch is None: + return None + + if not isinstance(sm_arch, str): + raise ValueError( + "Invalid sm_arch value. Expected None or an NVIDIA SM architecture like 'sm90', 'sm90a', 'sm_90', or '90'." + ) + + match = re.fullmatch(r"(?:sm_?)?([0-9]{2,3}[a-z]?)", sm_arch) + if match is None: + raise ValueError( + "Invalid sm_arch value. Expected None or an NVIDIA SM architecture like 'sm90', 'sm90a', 'sm_90', or '90'." + ) + return f"sm{match.group(1)}" + + def _no_such_file_in_dir(dir_path: str, filename: str, error_messages: list[str], attachments: list[str]) -> None: error_messages.append(f"No such file: {os.path.join(dir_path, filename)}") if os.path.isdir(dir_path): @@ -75,12 +98,19 @@ def _no_such_file_in_dir(dir_path: str, filename: str, error_messages: list[str] class _FindBitcodeLib: - def __init__(self, name: str) -> None: + def __init__(self, name: str, sm_arch: str | None = None) -> None: if name not in _SUPPORTED_BITCODE_LIBS_INFO: # Updated reference raise ValueError(f"Unknown bitcode library: '{name}'. Supported: {', '.join(SUPPORTED_BITCODE_LIBS)}") self.name: str = name self.config: _BitcodeLibInfo = _SUPPORTED_BITCODE_LIBS_INFO[name] # Updated reference - self.filename: str = self.config["filename"] + self.sm_arch: str | None = _normalize_sm_arch(sm_arch) + if self.sm_arch is None: + self.filename: str = self.config["filename"] + else: + template = self.config["arch_specific_filename_template"] + if template is None: + raise ValueError(f"Bitcode library '{name}' does not support sm_arch lookup") + self.filename = template.format(sm_arch=self.sm_arch) self.rel_path: str = self.config["rel_path"] self.site_packages_dirs: tuple[str, ...] = self.config["site_packages_dirs"] self.error_messages: list[str] = [] @@ -130,14 +160,21 @@ def raise_not_found_error(self) -> NoReturn: raise BitcodeLibNotFoundError(f'Failure finding "{self.filename}": {err}\n{att}') -def locate_bitcode_lib(name: str) -> LocatedBitcodeLib: +def locate_bitcode_lib(name: str, sm_arch: str | None = None) -> LocatedBitcodeLib: """Locate a bitcode library by name. + Args: + name: Supported bitcode library name. + sm_arch: Optional NVIDIA SM architecture for arch-specific bitcode + libraries. Accepted forms include ``"sm90"``, ``"sm90a"``, + ``"sm_90"``, and ``"90"``. + Raises: ValueError: If ``name`` is not a supported bitcode library. + ValueError: If ``sm_arch`` is invalid or unsupported for ``name``. BitcodeLibNotFoundError: If the bitcode library cannot be found. """ - finder = _FindBitcodeLib(name) + finder = _FindBitcodeLib(name, sm_arch=sm_arch) abs_path = finder.try_site_packages() if abs_path is not None: @@ -146,6 +183,7 @@ def locate_bitcode_lib(name: str) -> LocatedBitcodeLib: abs_path=abs_path, filename=finder.filename, found_via="site-packages", + sm_arch=finder.sm_arch, ) abs_path = finder.try_with_conda_prefix() @@ -155,6 +193,7 @@ def locate_bitcode_lib(name: str) -> LocatedBitcodeLib: abs_path=abs_path, filename=finder.filename, found_via="conda", + sm_arch=finder.sm_arch, ) abs_path = finder.try_with_cuda_home() @@ -164,17 +203,25 @@ def locate_bitcode_lib(name: str) -> LocatedBitcodeLib: abs_path=abs_path, filename=finder.filename, found_via="CUDA_PATH", + sm_arch=finder.sm_arch, ) finder.raise_not_found_error() @functools.cache -def find_bitcode_lib(name: str) -> str: +def find_bitcode_lib(name: str, sm_arch: str | None = None) -> str: """Find the absolute path to a bitcode library. + Args: + name: Supported bitcode library name. + sm_arch: Optional NVIDIA SM architecture for arch-specific bitcode + libraries. Accepted forms include ``"sm90"``, ``"sm90a"``, + ``"sm_90"``, and ``"90"``. + Raises: ValueError: If ``name`` is not a supported bitcode library. + ValueError: If ``sm_arch`` is invalid or unsupported for ``name``. BitcodeLibNotFoundError: If the bitcode library cannot be found. """ - return locate_bitcode_lib(name).abs_path + return locate_bitcode_lib(name, sm_arch=sm_arch).abs_path diff --git a/cuda_pathfinder/tests/test_find_bitcode_lib.py b/cuda_pathfinder/tests/test_find_bitcode_lib.py index 659b068f0ff..b596733fd3e 100644 --- a/cuda_pathfinder/tests/test_find_bitcode_lib.py +++ b/cuda_pathfinder/tests/test_find_bitcode_lib.py @@ -23,8 +23,14 @@ def _bitcode_lib_info(libname: str): return find_bitcode_lib_module._SUPPORTED_BITCODE_LIBS_INFO[libname] -def _bitcode_lib_filename(libname: str) -> str: - return _bitcode_lib_info(libname)["filename"] +def _bitcode_lib_filename(libname: str, sm_arch: str | None = None) -> str: + info = _bitcode_lib_info(libname) + normalized_sm_arch = find_bitcode_lib_module._normalize_sm_arch(sm_arch) + if normalized_sm_arch is None: + return info["filename"] + template = info["arch_specific_filename_template"] + assert template is not None + return template.format(sm_arch=normalized_sm_arch) @pytest.fixture @@ -36,9 +42,9 @@ def clear_find_bitcode_lib_cache(): get_cuda_path_or_home.cache_clear() -def _make_bitcode_lib_file(dir_path: Path, libname: str) -> str: +def _make_bitcode_lib_file(dir_path: Path, libname: str, sm_arch: str | None = None) -> str: dir_path.mkdir(parents=True, exist_ok=True) - file_path = dir_path / _bitcode_lib_filename(libname) + file_path = dir_path / _bitcode_lib_filename(libname, sm_arch=sm_arch) file_path.touch() return str(file_path) @@ -65,6 +71,10 @@ def _located_bitcode_lib_asserts(located_bitcode_lib): assert isinstance(located_bitcode_lib.abs_path, str) assert isinstance(located_bitcode_lib.filename, str) assert isinstance(located_bitcode_lib.found_via, str) + assert located_bitcode_lib.sm_arch is None or isinstance( + located_bitcode_lib.sm_arch, + str, + ) assert located_bitcode_lib.found_via in ("site-packages", "conda", "CUDA_PATH") assert os.path.isfile(located_bitcode_lib.abs_path) @@ -89,6 +99,117 @@ def test_locate_bitcode_lib(info_summary_append, libname): assert os.path.basename(lib_path) == expected_filename +@pytest.mark.skipif( + "nvshmem_device" not in SUPPORTED_BITCODE_LIBS, + reason="NVSHMEM bitcode is not supported", +) +@pytest.mark.usefixtures("clear_find_bitcode_lib_cache") +@pytest.mark.parametrize( + ("sm_arch", "expected_normalized_sm_arch"), + ( + ("sm90", "sm90"), + ("sm90a", "sm90a"), + ("sm_90", "sm90"), + ("sm_90a", "sm90a"), + ("90", "sm90"), + ("90a", "sm90a"), + ), +) +def test_locate_bitcode_lib_sm_arch_normalizes( + monkeypatch, + tmp_path, + sm_arch, + expected_normalized_sm_arch, +): + libname = "nvshmem_device" + site_packages_lib_dir = _site_packages_bitcode_lib_dir_under(tmp_path / "site-packages", libname) + expected_path = _make_bitcode_lib_file( + site_packages_lib_dir, + libname, + sm_arch=expected_normalized_sm_arch, + ) + + monkeypatch.setattr( + find_bitcode_lib_module, + "find_sub_dirs_all_sitepackages", + lambda _sub_dir: [str(site_packages_lib_dir)], + ) + monkeypatch.delenv("CONDA_PREFIX", raising=False) + monkeypatch.delenv("CUDA_HOME", raising=False) + monkeypatch.delenv("CUDA_PATH", raising=False) + + located_lib = locate_bitcode_lib(libname, sm_arch=sm_arch) + lib_path = find_bitcode_lib(libname, sm_arch=sm_arch) + + assert located_lib.abs_path == expected_path + assert located_lib.filename == f"libnvshmem_device_{expected_normalized_sm_arch}.bc" + assert located_lib.sm_arch == expected_normalized_sm_arch + assert lib_path == expected_path + + +@pytest.mark.skipif( + "nvshmem_device" not in SUPPORTED_BITCODE_LIBS, + reason="NVSHMEM bitcode is not supported", +) +@pytest.mark.usefixtures("clear_find_bitcode_lib_cache") +def test_locate_bitcode_lib_sm_arch_search_order(monkeypatch, tmp_path): + libname = "nvshmem_device" + sm_arch = "sm90" + site_packages_lib_dir = _site_packages_bitcode_lib_dir_under(tmp_path / "site-packages", libname) + site_packages_path = _make_bitcode_lib_file( + site_packages_lib_dir, + libname, + sm_arch=sm_arch, + ) + + conda_prefix = tmp_path / "conda-prefix" + conda_path = _make_bitcode_lib_file( + _bitcode_lib_dir_under(_conda_anchor(conda_prefix), libname), + libname, + sm_arch=sm_arch, + ) + + cuda_home = tmp_path / "cuda-home" + cuda_home_path = _make_bitcode_lib_file( + _bitcode_lib_dir_under(cuda_home, libname), + libname, + sm_arch=sm_arch, + ) + + site_packages_sub_dirs = tuple( + tuple(rel_dir.split("/")) for rel_dir in _bitcode_lib_info(libname)["site_packages_dirs"] + ) + + def find_expected_sub_dir(sub_dir): + assert sub_dir in site_packages_sub_dirs + if sub_dir == site_packages_sub_dirs[0]: + return [str(site_packages_lib_dir)] + return [] + + monkeypatch.setattr( + find_bitcode_lib_module, + "find_sub_dirs_all_sitepackages", + find_expected_sub_dir, + ) + monkeypatch.setenv("CONDA_PREFIX", str(conda_prefix)) + monkeypatch.setenv("CUDA_HOME", str(cuda_home)) + monkeypatch.delenv("CUDA_PATH", raising=False) + + located_lib = locate_bitcode_lib(libname, sm_arch=sm_arch) + assert located_lib.abs_path == site_packages_path + assert located_lib.found_via == "site-packages" + os.remove(site_packages_path) + + located_lib = locate_bitcode_lib(libname, sm_arch=sm_arch) + assert located_lib.abs_path == conda_path + assert located_lib.found_via == "conda" + os.remove(conda_path) + + located_lib = locate_bitcode_lib(libname, sm_arch=sm_arch) + assert located_lib.abs_path == cuda_home_path + assert located_lib.found_via == "CUDA_PATH" + + @pytest.mark.usefixtures("clear_find_bitcode_lib_cache") @pytest.mark.parametrize("libname", SUPPORTED_BITCODE_LIBS) def test_locate_bitcode_lib_search_order(monkeypatch, tmp_path, libname): @@ -183,3 +304,17 @@ def test_find_bitcode_lib_not_found_error_without_cuda_home(monkeypatch): def test_find_bitcode_lib_invalid_name(): with pytest.raises(ValueError, match="Unknown bitcode library"): find_bitcode_lib_module.locate_bitcode_lib("invalid") + + +@pytest.mark.parametrize( + "sm_arch", + ("", "sm", "sm_90_blah", "gfx90", "90.0", "sm9000", "sm90A", 90), +) +def test_find_bitcode_lib_invalid_sm_arch(sm_arch): + with pytest.raises(ValueError, match="Invalid sm_arch value"): + find_bitcode_lib_module.locate_bitcode_lib("nvshmem_device", sm_arch=sm_arch) + + +def test_find_bitcode_lib_sm_arch_unsupported_for_library(): + with pytest.raises(ValueError, match="does not support sm_arch"): + find_bitcode_lib_module.locate_bitcode_lib("device", sm_arch="sm90")