diff --git a/cuda_bindings/tests/conftest.py b/cuda_bindings/tests/conftest.py index f30500c134..d0267ae51e 100644 --- a/cuda_bindings/tests/conftest.py +++ b/cuda_bindings/tests/conftest.py @@ -1,8 +1,11 @@ -# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. # SPDX-License-Identifier: LicenseRef-NVIDIA-SOFTWARE-LICENSE +import functools +import inspect import pathlib import sys +from contextlib import contextmanager from importlib.metadata import PackageNotFoundError, distribution import pytest @@ -25,6 +28,74 @@ sys.path.insert(0, test_helpers_root) +def pytest_configure(config): + # When using `parallel-threads`, set up a mini-plugin to ensure each thread has a CUDA context + parallel_threads = getattr(config.option, "parallel_threads", 0) + if parallel_threads == "auto" or int(parallel_threads) > 1: + config.pluginmanager.register(_CudaBindingsParallelPlugin(), name="_cuda_bindings_parallel_plugin") + + +@contextmanager +def _thread_context(): + # Context manager creating a `device` and `ctx` for individual threads under + # pytest-run-parallel; the context is destroyed again on exit + err, device = cuda.cuDeviceGet(0) + assert err == cuda.CUresult.CUDA_SUCCESS + err, ctx = cuda.cuCtxCreate(None, 0, device) + assert err == cuda.CUresult.CUDA_SUCCESS + try: + yield device, ctx + finally: + (err,) = cuda.cuCtxDestroy(ctx) + assert err == cuda.CUresult.CUDA_SUCCESS + + +def _wrap_worker_cuda_test(func): + if getattr(func, "_cuda_bindings_worker_cuda_wrapped", False): + return func + + sig = inspect.signature(func) + wants_device = "device" in sig.parameters + wants_ctx = "ctx" in sig.parameters + + @functools.wraps(func) + def wrapper(*args, **kwargs): + with _thread_context() as (device, ctx): + # `_thread_context` is not expected to yield a None device; if it ever does (defensive path), + # keep whatever the fixture provided in kwargs as-is. 
+ if wants_device and device is not None: + kwargs["device"] = device + if wants_ctx: + kwargs["ctx"] = ctx + return func(*args, **kwargs) + + wrapper._cuda_bindings_worker_cuda_wrapped = True + return wrapper + + +def _item_needs_thread_ctx(item): + fixturenames = getattr(item, "fixturenames", ()) + # The 'device' fixture is the main fixture to set up a CUDA context. + # 'driver' is specific to the cufile tests and used there instead. + return "device" in fixturenames or "driver" in fixturenames + + +class _CudaBindingsParallelPlugin: + """A mini pytest plugin used only for pytest-run-parallel testing. + pytest-run-parallel spawns new threads for each test and we need to + initialize and pass the correct CUDA context for each of these. + + This plugin looks for tests using context-specific fixtures and wraps them to + override `device`/`ctx`; new context-specific fixtures may have to be added. + """ + + @pytest.hookimpl() + def pytest_collection_modifyitems(self, config, items): + for item in items: + if _item_needs_thread_ctx(item): + item.obj = _wrap_worker_cuda_test(item.obj) + + @pytest.fixture(scope="module") def cuda_driver(): (err,) = cuda.cuInit(0) diff --git a/cuda_core/tests/conftest.py b/cuda_core/tests/conftest.py index d7a81d8890..6c42ff44f2 100644 --- a/cuda_core/tests/conftest.py +++ b/cuda_core/tests/conftest.py @@ -1,6 +1,7 @@ # SPDX-FileCopyrightText: Copyright (c) 2024-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
# SPDX-License-Identifier: Apache-2.0 +import functools import multiprocessing import os import pathlib @@ -91,6 +92,70 @@ def xfail_if_mempool_oom(err_or_exc, api_name=None, device=0): sys.path.insert(0, test_helpers_root) +def pytest_configure(config): + # When using `parallel-threads`, set up a mini-plugin to ensure each thread has a CUDA context + parallel_threads = getattr(config.option, "parallel_threads", 0) + if parallel_threads == "auto" or int(parallel_threads) > 1: + config.pluginmanager.register(_CudaCoreParallelPlugin(), name="_cuda_core_parallel_plugin") + + +@contextmanager +def _init_cuda_context(): + # TODO: rename the `init_cuda` fixture to e.g. init_context + device = Device(0) + device.set_current() + + # Set option to avoid spin-waiting on synchronization. + if int(os.environ.get("CUDA_CORE_TEST_BLOCKING_SYNC", 0)) != 0: + handle_return( + driver.cuDevicePrimaryCtxSetFlags(device.device_id, driver.CUctx_flags.CU_CTX_SCHED_BLOCKING_SYNC) + ) + + try: + yield device + finally: + _ = _device_unset_current() + + +def _wrap_worker_cuda_test(func): + if getattr(func, "_cuda_core_worker_cuda_wrapped", False): + return func + + @functools.wraps(func) + def wrapper(*args, **kwargs): + with _init_cuda_context() as device: + if "init_cuda" in kwargs: + kwargs["init_cuda"] = device + if "mempool_device_x2" in kwargs: + kwargs["mempool_device_x2"] = _mempool_device_impl(2) + if "mempool_device_x3" in kwargs: + kwargs["mempool_device_x3"] = _mempool_device_impl(3) + return func(*args, **kwargs) + + wrapper._cuda_core_worker_cuda_wrapped = True + return wrapper + + +def _item_uses_init_cuda(item): + return "init_cuda" in getattr(item, "fixturenames", ()) + + +class _CudaCoreParallelPlugin: + """A mini pytest plugin used only for pytest-run-parallel testing. + pytest-run-parallel spawns new threads for each test and we need to + initialize and pass the correct CUDA context for each of these. 
+ + This plugin looks for tests using the `init_cuda` fixture and wraps them to + override context-specific fixtures; new context-specific fixtures may have to be added. + """ + + @pytest.hookimpl(tryfirst=True) + def pytest_collection_modifyitems(self, config, items): + for item in items: + if _item_uses_init_cuda(item): + item.obj = _wrap_worker_cuda_test(item.obj) + + def skip_if_pinned_memory_unsupported(device): try: if not device.properties.host_memory_pools_supported: @@ -194,18 +259,8 @@ def session_setup(): @pytest.fixture def init_cuda(): - # TODO: rename this to e.g. init_context - device = Device(0) - device.set_current() - - # Set option to avoid spin-waiting on synchronization. - if int(os.environ.get("CUDA_CORE_TEST_BLOCKING_SYNC", 0)) != 0: - handle_return( - driver.cuDevicePrimaryCtxSetFlags(device.device_id, driver.CUctx_flags.CU_CTX_SCHED_BLOCKING_SYNC) - ) - - yield device - _ = _device_unset_current() + with _init_cuda_context() as device: + yield device def _device_unset_current() -> bool: @@ -247,7 +302,7 @@ def pop_all_contexts(): @pytest.fixture -def ipc_device(): +def ipc_device(init_cuda): """Obtains a device suitable for IPC-enabled mempool tests, or skips. 
The fixture also tracks every ``multiprocessing.Process`` spawned during @@ -257,8 +312,7 @@ def ipc_device(): """ from helpers.child_processes import track_child_processes - device = Device(0) - device.set_current() + device = init_cuda if not device.properties.memory_pools_supported: pytest.skip("Device does not support mempool operations") @@ -293,10 +347,9 @@ def ipc_memory_resource(request, ipc_device): @pytest.fixture -def mempool_device(): +def mempool_device(init_cuda): """Obtains a device suitable for mempool tests, or skips.""" - device = Device(0) - device.set_current() + device = init_cuda if not device.properties.memory_pools_supported: pytest.skip("Device does not support mempool operations") @@ -323,13 +376,13 @@ def _mempool_device_impl(num): @pytest.fixture -def mempool_device_x2(): +def mempool_device_x2(init_cuda): """Fixture that provides two devices if available, otherwise skips test.""" return _mempool_device_impl(2) @pytest.fixture -def mempool_device_x3(): +def mempool_device_x3(init_cuda): """Fixture that provides three devices if available, otherwise skips test.""" return _mempool_device_impl(3)