Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
59 changes: 53 additions & 6 deletions cuda_pathfinder/cuda/pathfinder/_static_libs/find_bitcode_lib.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@

import functools
import os
import re
from dataclasses import dataclass
from typing import NoReturn, TypedDict

Expand All @@ -23,10 +24,12 @@ class LocatedBitcodeLib:
abs_path: str
filename: str
found_via: str
sm_arch: str | None = None


class _BitcodeLibInfo(TypedDict):
filename: str
arch_specific_filename_template: str | None
rel_path: str
site_packages_dirs: tuple[str, ...]
available_on_windows: bool
Expand All @@ -35,6 +38,7 @@ class _BitcodeLibInfo(TypedDict):
_SUPPORTED_BITCODE_LIBS_INFO: dict[str, _BitcodeLibInfo] = {
"device": {
"filename": "libdevice.10.bc",
"arch_specific_filename_template": None,
"rel_path": os.path.join("nvvm", "libdevice"),
"site_packages_dirs": (
"nvidia/cu13/nvvm/libdevice",
Expand All @@ -44,12 +48,14 @@ class _BitcodeLibInfo(TypedDict):
},
"nccl_device": {
"filename": "libnccl_device.bc",
"arch_specific_filename_template": None,
"rel_path": "lib",
"site_packages_dirs": ("nvidia/nccl/lib",),
"available_on_windows": False,
},
"nvshmem_device": {
"filename": "libnvshmem_device.bc",
"arch_specific_filename_template": "libnvshmem_device_{sm_arch}.bc",
"rel_path": "lib",
"site_packages_dirs": ("nvidia/nvshmem/lib",),
"available_on_windows": False,
Expand All @@ -64,6 +70,23 @@ class _BitcodeLibInfo(TypedDict):
)


def _normalize_sm_arch(sm_arch: str | None) -> str | None:
if sm_arch is None:
return None

if not isinstance(sm_arch, str):
raise ValueError(
"Invalid sm_arch value. Expected None or an NVIDIA SM architecture like 'sm90', 'sm90a', 'sm_90', or '90'."
)

match = re.fullmatch(r"(?:sm_?)?([0-9]{2,3}[a-z]?)", sm_arch)
if match is None:
raise ValueError(
"Invalid sm_arch value. Expected None or an NVIDIA SM architecture like 'sm90', 'sm90a', 'sm_90', or '90'."
)
return f"sm{match.group(1)}"


def _no_such_file_in_dir(dir_path: str, filename: str, error_messages: list[str], attachments: list[str]) -> None:
error_messages.append(f"No such file: {os.path.join(dir_path, filename)}")
if os.path.isdir(dir_path):
Expand All @@ -75,12 +98,19 @@ def _no_such_file_in_dir(dir_path: str, filename: str, error_messages: list[str]


class _FindBitcodeLib:
def __init__(self, name: str) -> None:
def __init__(self, name: str, sm_arch: str | None = None) -> None:
if name not in _SUPPORTED_BITCODE_LIBS_INFO: # Updated reference
raise ValueError(f"Unknown bitcode library: '{name}'. Supported: {', '.join(SUPPORTED_BITCODE_LIBS)}")
self.name: str = name
self.config: _BitcodeLibInfo = _SUPPORTED_BITCODE_LIBS_INFO[name] # Updated reference
self.filename: str = self.config["filename"]
self.sm_arch: str | None = _normalize_sm_arch(sm_arch)
if self.sm_arch is None:
self.filename: str = self.config["filename"]
else:
template = self.config["arch_specific_filename_template"]
if template is None:
raise ValueError(f"Bitcode library '{name}' does not support sm_arch lookup")
self.filename = template.format(sm_arch=self.sm_arch)
self.rel_path: str = self.config["rel_path"]
self.site_packages_dirs: tuple[str, ...] = self.config["site_packages_dirs"]
self.error_messages: list[str] = []
Expand Down Expand Up @@ -130,14 +160,21 @@ def raise_not_found_error(self) -> NoReturn:
raise BitcodeLibNotFoundError(f'Failure finding "{self.filename}": {err}\n{att}')


def locate_bitcode_lib(name: str) -> LocatedBitcodeLib:
def locate_bitcode_lib(name: str, sm_arch: str | None = None) -> LocatedBitcodeLib:
"""Locate a bitcode library by name.

Args:
name: Supported bitcode library name.
sm_arch: Optional NVIDIA SM architecture for arch-specific bitcode
libraries. Accepted forms include ``"sm90"``, ``"sm90a"``,
``"sm_90"``, and ``"90"``.

Raises:
ValueError: If ``name`` is not a supported bitcode library.
ValueError: If ``sm_arch`` is invalid or unsupported for ``name``.
BitcodeLibNotFoundError: If the bitcode library cannot be found.
"""
finder = _FindBitcodeLib(name)
finder = _FindBitcodeLib(name, sm_arch=sm_arch)

abs_path = finder.try_site_packages()
if abs_path is not None:
Expand All @@ -146,6 +183,7 @@ def locate_bitcode_lib(name: str) -> LocatedBitcodeLib:
abs_path=abs_path,
filename=finder.filename,
found_via="site-packages",
sm_arch=finder.sm_arch,
)

abs_path = finder.try_with_conda_prefix()
Expand All @@ -155,6 +193,7 @@ def locate_bitcode_lib(name: str) -> LocatedBitcodeLib:
abs_path=abs_path,
filename=finder.filename,
found_via="conda",
sm_arch=finder.sm_arch,
)

abs_path = finder.try_with_cuda_home()
Expand All @@ -164,17 +203,25 @@ def locate_bitcode_lib(name: str) -> LocatedBitcodeLib:
abs_path=abs_path,
filename=finder.filename,
found_via="CUDA_PATH",
sm_arch=finder.sm_arch,
)

finder.raise_not_found_error()


@functools.cache
def find_bitcode_lib(name: str) -> str:
def find_bitcode_lib(name: str, sm_arch: str | None = None) -> str:
"""Find the absolute path to a bitcode library.

Args:
name: Supported bitcode library name.
sm_arch: Optional NVIDIA SM architecture for arch-specific bitcode
libraries. Accepted forms include ``"sm90"``, ``"sm90a"``,
``"sm_90"``, and ``"90"``.

Raises:
ValueError: If ``name`` is not a supported bitcode library.
ValueError: If ``sm_arch`` is invalid or unsupported for ``name``.
BitcodeLibNotFoundError: If the bitcode library cannot be found.
"""
return locate_bitcode_lib(name).abs_path
return locate_bitcode_lib(name, sm_arch=sm_arch).abs_path
143 changes: 139 additions & 4 deletions cuda_pathfinder/tests/test_find_bitcode_lib.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,8 +23,14 @@ def _bitcode_lib_info(libname: str):
return find_bitcode_lib_module._SUPPORTED_BITCODE_LIBS_INFO[libname]


def _bitcode_lib_filename(libname: str) -> str:
return _bitcode_lib_info(libname)["filename"]
def _bitcode_lib_filename(libname: str, sm_arch: str | None = None) -> str:
info = _bitcode_lib_info(libname)
normalized_sm_arch = find_bitcode_lib_module._normalize_sm_arch(sm_arch)
if normalized_sm_arch is None:
return info["filename"]
template = info["arch_specific_filename_template"]
assert template is not None
return template.format(sm_arch=normalized_sm_arch)


@pytest.fixture
Expand All @@ -36,9 +42,9 @@ def clear_find_bitcode_lib_cache():
get_cuda_path_or_home.cache_clear()


def _make_bitcode_lib_file(dir_path: Path, libname: str) -> str:
def _make_bitcode_lib_file(dir_path: Path, libname: str, sm_arch: str | None = None) -> str:
dir_path.mkdir(parents=True, exist_ok=True)
file_path = dir_path / _bitcode_lib_filename(libname)
file_path = dir_path / _bitcode_lib_filename(libname, sm_arch=sm_arch)
file_path.touch()
return str(file_path)

Expand All @@ -65,6 +71,10 @@ def _located_bitcode_lib_asserts(located_bitcode_lib):
assert isinstance(located_bitcode_lib.abs_path, str)
assert isinstance(located_bitcode_lib.filename, str)
assert isinstance(located_bitcode_lib.found_via, str)
assert located_bitcode_lib.sm_arch is None or isinstance(
located_bitcode_lib.sm_arch,
str,
)
assert located_bitcode_lib.found_via in ("site-packages", "conda", "CUDA_PATH")
assert os.path.isfile(located_bitcode_lib.abs_path)

Expand All @@ -89,6 +99,117 @@ def test_locate_bitcode_lib(info_summary_append, libname):
assert os.path.basename(lib_path) == expected_filename


@pytest.mark.skipif(
"nvshmem_device" not in SUPPORTED_BITCODE_LIBS,
reason="NVSHMEM bitcode is not supported",
)
@pytest.mark.usefixtures("clear_find_bitcode_lib_cache")
@pytest.mark.parametrize(
("sm_arch", "expected_normalized_sm_arch"),
(
("sm90", "sm90"),
("sm90a", "sm90a"),
("sm_90", "sm90"),
("sm_90a", "sm90a"),
("90", "sm90"),
("90a", "sm90a"),
),
)
def test_locate_bitcode_lib_sm_arch_normalizes(
monkeypatch,
tmp_path,
sm_arch,
expected_normalized_sm_arch,
):
libname = "nvshmem_device"
site_packages_lib_dir = _site_packages_bitcode_lib_dir_under(tmp_path / "site-packages", libname)
expected_path = _make_bitcode_lib_file(
site_packages_lib_dir,
libname,
sm_arch=expected_normalized_sm_arch,
)

monkeypatch.setattr(
find_bitcode_lib_module,
"find_sub_dirs_all_sitepackages",
lambda _sub_dir: [str(site_packages_lib_dir)],
)
monkeypatch.delenv("CONDA_PREFIX", raising=False)
monkeypatch.delenv("CUDA_HOME", raising=False)
monkeypatch.delenv("CUDA_PATH", raising=False)

located_lib = locate_bitcode_lib(libname, sm_arch=sm_arch)
lib_path = find_bitcode_lib(libname, sm_arch=sm_arch)

assert located_lib.abs_path == expected_path
assert located_lib.filename == f"libnvshmem_device_{expected_normalized_sm_arch}.bc"
assert located_lib.sm_arch == expected_normalized_sm_arch
assert lib_path == expected_path


@pytest.mark.skipif(
"nvshmem_device" not in SUPPORTED_BITCODE_LIBS,
reason="NVSHMEM bitcode is not supported",
)
@pytest.mark.usefixtures("clear_find_bitcode_lib_cache")
def test_locate_bitcode_lib_sm_arch_search_order(monkeypatch, tmp_path):
libname = "nvshmem_device"
sm_arch = "sm90"
site_packages_lib_dir = _site_packages_bitcode_lib_dir_under(tmp_path / "site-packages", libname)
site_packages_path = _make_bitcode_lib_file(
site_packages_lib_dir,
libname,
sm_arch=sm_arch,
)

conda_prefix = tmp_path / "conda-prefix"
conda_path = _make_bitcode_lib_file(
_bitcode_lib_dir_under(_conda_anchor(conda_prefix), libname),
libname,
sm_arch=sm_arch,
)

cuda_home = tmp_path / "cuda-home"
cuda_home_path = _make_bitcode_lib_file(
_bitcode_lib_dir_under(cuda_home, libname),
libname,
sm_arch=sm_arch,
)

site_packages_sub_dirs = tuple(
tuple(rel_dir.split("/")) for rel_dir in _bitcode_lib_info(libname)["site_packages_dirs"]
)

def find_expected_sub_dir(sub_dir):
assert sub_dir in site_packages_sub_dirs
if sub_dir == site_packages_sub_dirs[0]:
return [str(site_packages_lib_dir)]
return []

monkeypatch.setattr(
find_bitcode_lib_module,
"find_sub_dirs_all_sitepackages",
find_expected_sub_dir,
)
monkeypatch.setenv("CONDA_PREFIX", str(conda_prefix))
monkeypatch.setenv("CUDA_HOME", str(cuda_home))
monkeypatch.delenv("CUDA_PATH", raising=False)

located_lib = locate_bitcode_lib(libname, sm_arch=sm_arch)
assert located_lib.abs_path == site_packages_path
assert located_lib.found_via == "site-packages"
os.remove(site_packages_path)

located_lib = locate_bitcode_lib(libname, sm_arch=sm_arch)
assert located_lib.abs_path == conda_path
assert located_lib.found_via == "conda"
os.remove(conda_path)

located_lib = locate_bitcode_lib(libname, sm_arch=sm_arch)
assert located_lib.abs_path == cuda_home_path
assert located_lib.found_via == "CUDA_PATH"


@pytest.mark.usefixtures("clear_find_bitcode_lib_cache")
@pytest.mark.parametrize("libname", SUPPORTED_BITCODE_LIBS)
def test_locate_bitcode_lib_search_order(monkeypatch, tmp_path, libname):
Expand Down Expand Up @@ -183,3 +304,17 @@ def test_find_bitcode_lib_not_found_error_without_cuda_home(monkeypatch):
def test_find_bitcode_lib_invalid_name():
with pytest.raises(ValueError, match="Unknown bitcode library"):
find_bitcode_lib_module.locate_bitcode_lib("invalid")


@pytest.mark.parametrize(
"sm_arch",
("", "sm", "sm_90_blah", "gfx90", "90.0", "sm9000", "sm90A", 90),
)
def test_find_bitcode_lib_invalid_sm_arch(sm_arch):
with pytest.raises(ValueError, match="Invalid sm_arch value"):
find_bitcode_lib_module.locate_bitcode_lib("nvshmem_device", sm_arch=sm_arch)


def test_find_bitcode_lib_sm_arch_unsupported_for_library():
with pytest.raises(ValueError, match="does not support sm_arch"):
find_bitcode_lib_module.locate_bitcode_lib("device", sm_arch="sm90")