Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 11 additions & 4 deletions cuda_core/cuda/core/_memory/_virtual_memory_resource.py
Original file line number Diff line number Diff line change
Expand Up @@ -265,7 +265,8 @@ def modify_allocation(
0,
)

if res != driver.CUresult.CUDA_SUCCESS or new_ptr != (int(buf.handle) + aligned_prev_size):
expected_ptr = int(buf.handle) + aligned_prev_size
if res != driver.CUresult.CUDA_SUCCESS:
# Check for specific errors that are not recoverable with the slow path
if res in (
driver.CUresult.CUDA_ERROR_INVALID_VALUE,
Expand All @@ -274,15 +275,21 @@ def modify_allocation(
driver.CUresult.CUDA_ERROR_NOT_SUPPORTED,
):
raise_if_driver_error(res)
# Fallback: couldn't reserve contiguously, need full remapping
return self._grow_allocation_slow_path(
buf, new_size, prop, aligned_additional_size, total_aligned_size, addr_align
)

if new_ptr != expected_ptr:
(res2,) = driver.cuMemAddressFree(new_ptr, aligned_additional_size)
raise_if_driver_error(res2)
# Fallback: couldn't extend contiguously, need full remapping
return self._grow_allocation_slow_path(
buf, new_size, prop, aligned_additional_size, total_aligned_size, addr_align
)
else:
# Success! We can extend the VA range contiguously
return self._grow_allocation_fast_path(buf, new_size, prop, aligned_additional_size, new_ptr)

# Success! We can extend the VA range contiguously
return self._grow_allocation_fast_path(buf, new_size, prop, aligned_additional_size, new_ptr)

def _grow_allocation_fast_path(
self, buf: Buffer, new_size: int, prop: driver.CUmemAllocationProp, aligned_additional_size: int, new_ptr: int
Expand Down
101 changes: 101 additions & 0 deletions cuda_core/tests/test_memory.py
Original file line number Diff line number Diff line change
Expand Up @@ -1035,6 +1035,107 @@ def __init__(self, size):
assert ("set_access", new_ptr, aligned_additional, 1) in calls


def _make_mock_vmm_resource():
vmm_mr = VirtualMemoryResource.__new__(VirtualMemoryResource)
vmm_mr.device = type("FakeDevice", (), {"device_id": 0})()
vmm_mr.config = VirtualMemoryResourceOptions(handle_type="win32_kmt" if IS_WINDOWS else "posix_fd")
return vmm_mr


def test_vmm_allocator_grow_allocation_does_not_free_failed_adjacent_reservation(monkeypatch):
vmm_mr = _make_mock_vmm_resource()

SUCCESS = driver.CUresult.CUDA_SUCCESS
ERROR = driver.CUresult.CUDA_ERROR_OUT_OF_MEMORY
base_ptr = 0x10_0000
old_size = 2048
new_size = 4096
granularity = 1024
stale_ptr = 0xBAD
calls = []

class FakeBuffer:
handle = base_ptr
size = old_size

def fake_get_allocation_granularity(_, _granularity_flag):
calls.append(("granularity",))
return (SUCCESS, granularity)

def fake_addr_reserve(size, align, hint, flags):
calls.append(("reserve", size, align, hint, flags))
return (ERROR, stale_ptr)

def fake_addr_free(ptr, size):
calls.append(("addr_free", ptr, size))
return (SUCCESS,)

def fake_slow_path(self, buf, result_size, prop, aligned_additional_size, total_aligned_size, addr_align):
calls.append(("slow_path", result_size, aligned_additional_size, total_aligned_size, addr_align))
return buf

monkeypatch.setattr(driver, "cuMemGetAllocationGranularity", fake_get_allocation_granularity)
monkeypatch.setattr(driver, "cuMemAddressReserve", fake_addr_reserve)
monkeypatch.setattr(driver, "cuMemAddressFree", fake_addr_free)
monkeypatch.setattr(VirtualMemoryResource, "_grow_allocation_slow_path", fake_slow_path)

result = vmm_mr.modify_allocation(FakeBuffer(), new_size)

assert isinstance(result, FakeBuffer)
assert calls == [
("granularity",),
("reserve", 2048, granularity, base_ptr + old_size, 0),
("slow_path", new_size, 2048, 4096, granularity),
]


def test_vmm_allocator_grow_allocation_frees_noncontiguous_adjacent_reservation(monkeypatch):
vmm_mr = _make_mock_vmm_resource()

SUCCESS = driver.CUresult.CUDA_SUCCESS
base_ptr = 0x10_0000
old_size = 2048
new_size = 4096
granularity = 1024
noncontiguous_ptr = base_ptr + 4 * granularity
calls = []

class FakeBuffer:
handle = base_ptr
size = old_size

def fake_get_allocation_granularity(_, _granularity_flag):
calls.append(("granularity",))
return (SUCCESS, granularity)

def fake_addr_reserve(size, align, hint, flags):
calls.append(("reserve", size, align, hint, flags))
return (SUCCESS, noncontiguous_ptr)

def fake_addr_free(ptr, size):
calls.append(("addr_free", ptr, size))
return (SUCCESS,)

def fake_slow_path(self, buf, result_size, prop, aligned_additional_size, total_aligned_size, addr_align):
calls.append(("slow_path", result_size, aligned_additional_size, total_aligned_size, addr_align))
return buf

monkeypatch.setattr(driver, "cuMemGetAllocationGranularity", fake_get_allocation_granularity)
monkeypatch.setattr(driver, "cuMemAddressReserve", fake_addr_reserve)
monkeypatch.setattr(driver, "cuMemAddressFree", fake_addr_free)
monkeypatch.setattr(VirtualMemoryResource, "_grow_allocation_slow_path", fake_slow_path)

result = vmm_mr.modify_allocation(FakeBuffer(), new_size)

assert isinstance(result, FakeBuffer)
assert calls == [
("granularity",),
("reserve", 2048, granularity, base_ptr + old_size, 0),
("addr_free", noncontiguous_ptr, 2048),
("slow_path", new_size, 2048, 4096, granularity),
]


def test_vmm_allocator_rdma_unsupported_exception():
"""Test that VirtualMemoryResource throws an exception when RDMA is requested but device doesn't support it.

Expand Down
Loading