Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
73 changes: 72 additions & 1 deletion cuda_bindings/tests/conftest.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,11 @@
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: LicenseRef-NVIDIA-SOFTWARE-LICENSE

import functools
import inspect
import pathlib
import sys
from contextlib import contextmanager
from importlib.metadata import PackageNotFoundError, distribution

import pytest
Expand All @@ -25,6 +28,74 @@
sys.path.insert(0, test_helpers_root)


def pytest_configure(config):
    """Register the per-thread CUDA context plugin for parallel test runs.

    The ``parallel_threads`` option is read from ``config.option`` and falls
    back to ``0`` when the attribute is absent. The mini-plugin is registered
    only when the value is ``"auto"`` or converts to an integer greater than
    one, so that each thread gets a CUDA context.
    """
    requested = getattr(config.option, "parallel_threads", 0)
    if requested != "auto" and int(requested) <= 1:
        # Not a multi-threaded run: nothing to register.
        return
    config.pluginmanager.register(
        _CudaBindingsParallelPlugin(),
        name="_cuda_bindings_parallel_plugin",
    )


@contextmanager
def _thread_context():
    """Create a CUDA context on device ordinal 0 and destroy it on exit.

    Used to set up ``device`` and ``ctx`` for the individual threads spawned
    by pytest-run-parallel.

    Yields:
        A ``(device, ctx)`` tuple: the handle returned by ``cuDeviceGet(0)``
        and the context returned by ``cuCtxCreate``.

    Raises:
        AssertionError: if any of the driver calls does not report
            ``CUDA_SUCCESS``.
    """
    success = cuda.CUresult.CUDA_SUCCESS

    status, device = cuda.cuDeviceGet(0)
    assert status == success

    status, ctx = cuda.cuCtxCreate(None, 0, device)
    assert status == success

    try:
        yield device, ctx
    finally:
        # Runs even when the wrapped test body raises.
        (status,) = cuda.cuCtxDestroy(ctx)
        assert status == success


def _wrap_worker_cuda_test(func):
    """Wrap a test callable so it runs inside its own ``_thread_context()``.

    When the wrapped callable is invoked, a fresh ``(device, ctx)`` pair is
    created and injected into the keyword arguments, overriding any value a
    fixture supplied, but only for the names that actually appear in
    ``func``'s signature.

    Args:
        func: The test callable to wrap.

    Returns:
        ``func`` itself if it is already marked as wrapped, otherwise a new
        wrapper (created with ``functools.wraps``) carrying the marker.
    """
    marker = "_cuda_bindings_worker_cuda_wrapped"
    if getattr(func, marker, False):
        # Already wrapped once; wrapping again would nest contexts.
        return func

    accepted = inspect.signature(func).parameters
    inject_device = "device" in accepted
    inject_ctx = "ctx" in accepted

    @functools.wraps(func)
    def wrapper(*args, **kwargs):
        with _thread_context() as (thread_device, thread_ctx):
            overrides = {}
            # NOTE(review): _thread_context yields whatever cuDeviceGet
            # returned after asserting success, so the None check looks purely
            # defensive; when it triggers, the fixture-provided value is kept.
            if inject_device and thread_device is not None:
                overrides["device"] = thread_device
            if inject_ctx:
                overrides["ctx"] = thread_ctx
            kwargs.update(overrides)
            return func(*args, **kwargs)

    setattr(wrapper, marker, True)
    return wrapper


def _item_needs_thread_ctx(item):
    """Return True when *item* requests a fixture that needs a thread context.

    Items without a ``fixturenames`` attribute are treated as requesting no
    fixtures.
    """
    requested = getattr(item, "fixturenames", ())
    # 'device' is the main fixture that sets up a CUDA context; 'driver' is
    # specific to the cufile tests and is used there instead.
    return any(name in requested for name in ("device", "driver"))


class _CudaBindingsParallelPlugin:
    """Mini pytest plugin used only for pytest-run-parallel testing.

    pytest-run-parallel spawns new threads for each test, and the correct
    CUDA context has to be initialized and passed in for each of these.

    The plugin looks for tests that request context-specific fixtures and
    wraps them so that those fixture values are replaced; new
    context-specific fixtures may have to be added to the lookup.
    """

    @pytest.hookimpl()
    def pytest_collection_modifyitems(self, config, items):
        """Wrap each collected item that needs a per-thread CUDA context."""
        for test_item in filter(_item_needs_thread_ctx, items):
            test_item.obj = _wrap_worker_cuda_test(test_item.obj)


@pytest.fixture(scope="module")
def cuda_driver():
(err,) = cuda.cuInit(0)
Expand Down
93 changes: 73 additions & 20 deletions cuda_core/tests/conftest.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
# SPDX-FileCopyrightText: Copyright (c) 2024-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0

import functools
import multiprocessing
import os
import pathlib
Expand Down Expand Up @@ -91,6 +92,70 @@ def xfail_if_mempool_oom(err_or_exc, api_name=None, device=0):
sys.path.insert(0, test_helpers_root)


def pytest_configure(config):
    """Register the cuda.core parallel plugin when tests run multi-threaded.

    ``config.option.parallel_threads`` (default ``0`` when the attribute is
    missing) selects the mode: ``"auto"`` or any value converting to an
    integer above one registers the mini-plugin that ensures each thread has
    a CUDA context.
    """
    thread_setting = getattr(config.option, "parallel_threads", 0)
    if thread_setting != "auto" and int(thread_setting) <= 1:
        # Not a multi-threaded run: leave the plugin manager untouched.
        return
    config.pluginmanager.register(
        _CudaCoreParallelPlugin(),
        name="_cuda_core_parallel_plugin",
    )


@contextmanager
def _init_cuda_context():
    """Make ``Device(0)`` current, yield it, and unset it again on exit.

    When the ``CUDA_CORE_TEST_BLOCKING_SYNC`` environment variable is set to
    a non-zero integer, ``CU_CTX_SCHED_BLOCKING_SYNC`` is additionally set on
    the device's primary context to avoid spin-waiting on synchronization.

    Yields:
        The ``Device(0)`` object after ``set_current()`` has been called.

    Raises:
        ValueError: if ``CUDA_CORE_TEST_BLOCKING_SYNC`` is not an integer.
    """
    # TODO: rename this to e.g. init_context
    device = Device(0)
    device.set_current()

    # Everything after set_current() lives inside the try block, so a failure
    # while parsing the environment variable or setting the flags still
    # unsets the device instead of leaving it current.
    try:
        # Set option to avoid spin-waiting on synchronization.
        if int(os.environ.get("CUDA_CORE_TEST_BLOCKING_SYNC", 0)) != 0:
            handle_return(
                driver.cuDevicePrimaryCtxSetFlags(device.device_id, driver.CUctx_flags.CU_CTX_SCHED_BLOCKING_SYNC)
            )

        yield device
    finally:
        _ = _device_unset_current()


def _wrap_worker_cuda_test(func):
    """Wrap a test callable so it runs inside ``_init_cuda_context()``.

    On each call the wrapper enters ``_init_cuda_context()`` and, for every
    context-dependent fixture name already present in the keyword arguments,
    replaces the value with one created inside that context.

    Args:
        func: The test callable to wrap.

    Returns:
        ``func`` itself if it is already marked as wrapped, otherwise a new
        wrapper (created with ``functools.wraps``) carrying the marker.
    """
    marker = "_cuda_core_worker_cuda_wrapped"
    if getattr(func, marker, False):
        return func

    @functools.wraps(func)
    def wrapper(*args, **kwargs):
        with _init_cuda_context() as device:
            # Factories are only invoked for fixtures the test requested, in
            # this fixed order.
            replacements = (
                ("init_cuda", lambda: device),
                ("mempool_device_x2", lambda: _mempool_device_impl(2)),
                ("mempool_device_x3", lambda: _mempool_device_impl(3)),
            )
            for fixture_name, make_value in replacements:
                if fixture_name in kwargs:
                    kwargs[fixture_name] = make_value()
            return func(*args, **kwargs)

    setattr(wrapper, marker, True)
    return wrapper


def _item_uses_init_cuda(item):
    """Return True when *item* lists ``init_cuda`` among its fixture names.

    Items without a ``fixturenames`` attribute are treated as requesting no
    fixtures.
    """
    fixture_names = getattr(item, "fixturenames", ())
    return "init_cuda" in fixture_names


class _CudaCoreParallelPlugin:
    """Mini pytest plugin used only for pytest-run-parallel testing.

    pytest-run-parallel spawns new threads for each test, and the correct
    CUDA context has to be initialized and passed in for each of these.

    The plugin looks for tests that request context-specific fixtures and
    wraps them so that those fixture values are replaced; new
    context-specific fixtures may have to be added to the lookup.
    """

    @pytest.hookimpl(tryfirst=True)
    def pytest_collection_modifyitems(self, config, items):
        """Wrap each collected item that requests the ``init_cuda`` fixture."""
        for test_item in filter(_item_uses_init_cuda, items):
            test_item.obj = _wrap_worker_cuda_test(test_item.obj)


def skip_if_pinned_memory_unsupported(device):
try:
if not device.properties.host_memory_pools_supported:
Expand Down Expand Up @@ -194,18 +259,8 @@ def session_setup():

@pytest.fixture
def init_cuda():
    """Yield the device provided by ``_init_cuda_context()``.

    Setup and teardown are delegated entirely to ``_init_cuda_context()``.
    The fixture must yield exactly once, so the inline setup/yield/teardown
    that duplicated the helper is not repeated here.
    """
    # TODO: rename this to e.g. init_context
    with _init_cuda_context() as device:
        yield device


def _device_unset_current() -> bool:
Expand Down Expand Up @@ -247,7 +302,7 @@ def pop_all_contexts():


@pytest.fixture
def ipc_device():
def ipc_device(init_cuda):
"""Obtains a device suitable for IPC-enabled mempool tests, or skips.

The fixture also tracks every ``multiprocessing.Process`` spawned during
Expand All @@ -257,8 +312,7 @@ def ipc_device():
"""
from helpers.child_processes import track_child_processes

device = Device(0)
device.set_current()
device = init_cuda

if not device.properties.memory_pools_supported:
pytest.skip("Device does not support mempool operations")
Expand Down Expand Up @@ -293,10 +347,9 @@ def ipc_memory_resource(request, ipc_device):


@pytest.fixture
def mempool_device():
def mempool_device(init_cuda):
"""Obtains a device suitable for mempool tests, or skips."""
device = Device(0)
device.set_current()
device = init_cuda

if not device.properties.memory_pools_supported:
pytest.skip("Device does not support mempool operations")
Expand All @@ -323,13 +376,13 @@ def _mempool_device_impl(num):


@pytest.fixture
def mempool_device_x2(init_cuda):
    """Fixture that provides two devices if available, otherwise skips test."""
    # ``init_cuda`` is requested but not used directly: depending on it makes
    # pytest run that fixture before this one.
    return _mempool_device_impl(2)


@pytest.fixture
def mempool_device_x3(init_cuda):
    """Fixture that provides three devices if available, otherwise skips test."""
    # ``init_cuda`` is requested but not used directly: depending on it makes
    # pytest run that fixture before this one.
    return _mempool_device_impl(3)

Expand Down