Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
62 changes: 54 additions & 8 deletions cuda_core/cuda/core/_memory/_virtual_memory_resource.py
Original file line number Diff line number Diff line change
Expand Up @@ -322,8 +322,17 @@ def _grow_allocation_fast_path(
)
res, new_handle = driver.cuMemCreate(aligned_additional_size, prop, 0)
raise_if_driver_error(res)
# Register undo for creation
trans.append(lambda h=new_handle: raise_if_driver_error(driver.cuMemRelease(h)[0]))
new_handle_released = False

def _release_new_handle() -> None:
nonlocal new_handle_released
if not new_handle_released:
raise_if_driver_error(driver.cuMemRelease(new_handle)[0])
new_handle_released = True

# Register undo for creation. Callback is conditional to avoid
# double-release after an explicit successful release.
trans.append(_release_new_handle)

# Map the new physical memory to the extended VA range
(res,) = driver.cuMemMap(new_ptr, aligned_additional_size, 0, new_handle, 0)
Expand All @@ -339,6 +348,9 @@ def _grow_allocation_fast_path(
(res,) = driver.cuMemSetAccess(new_ptr, aligned_additional_size, descs, len(descs))
raise_if_driver_error(res)

# Release handle ownership now that mapping is stable.
_release_new_handle()

# All succeeded, cancel undo actions
trans.commit()

Expand Down Expand Up @@ -389,8 +401,17 @@ def _grow_allocation_slow_path(
# Get the old allocation handle for remapping
result, old_handle = driver.cuMemRetainAllocationHandle(buf.handle)
raise_if_driver_error(result)
# Register undo for old_handle
trans.append(lambda h=old_handle: raise_if_driver_error(driver.cuMemRelease(h)[0]))
old_handle_released = False

def _release_old_handle() -> None:
nonlocal old_handle_released
if not old_handle_released:
raise_if_driver_error(driver.cuMemRelease(old_handle)[0])
old_handle_released = True

# Register undo for old handle. Callback is conditional to avoid
# double-release after explicit success.
trans.append(_release_old_handle)

# Unmap the old VA range (aligned previous size)
aligned_prev_size = total_aligned_size - aligned_additional_size
Expand Down Expand Up @@ -419,8 +440,17 @@ def _remap_old() -> None:
res, new_handle = driver.cuMemCreate(aligned_additional_size, prop, 0)
raise_if_driver_error(res)

# Register undo for new physical memory
trans.append(lambda h=new_handle: raise_if_driver_error(driver.cuMemRelease(h)[0]))
new_handle_released = False

def _release_new_handle() -> None:
nonlocal new_handle_released
if not new_handle_released:
raise_if_driver_error(driver.cuMemRelease(new_handle)[0])
new_handle_released = True

# Register undo for new physical memory. Callback is conditional to
# avoid double-release after explicit success.
trans.append(_release_new_handle)

# Map the new physical memory to the extended portion (aligned offset)
(res,) = driver.cuMemMap(int(new_ptr) + aligned_prev_size, aligned_additional_size, 0, new_handle, 0)
Expand All @@ -439,6 +469,10 @@ def _remap_old() -> None:
(res,) = driver.cuMemSetAccess(new_ptr, total_aligned_size, descs, len(descs))
raise_if_driver_error(res)

# Release handles once all operations that need them have completed.
_release_new_handle()
_release_old_handle()

# All succeeded, cancel undo actions
trans.commit()

Expand Down Expand Up @@ -542,8 +576,17 @@ def allocate(self, size: int, *, stream: Stream | GraphBuilder | None = None) ->
# ---- Create physical memory ----
res, handle = driver.cuMemCreate(aligned_size, prop, 0)
raise_if_driver_error(res)
# Register undo for physical memory
trans.append(lambda h=handle: raise_if_driver_error(driver.cuMemRelease(h)[0]))
handle_released = False

def _release_handle() -> None:
nonlocal handle_released
if not handle_released:
raise_if_driver_error(driver.cuMemRelease(handle)[0])
handle_released = True

# Register undo for physical memory. Callback is conditional to
# avoid double-release after explicit success.
trans.append(_release_handle)

# ---- Reserve VA space ----
# Potentially, use a separate size for the VA reservation from the physical allocation size
Expand All @@ -563,6 +606,9 @@ def allocate(self, size: int, *, stream: Stream | GraphBuilder | None = None) ->
(res,) = driver.cuMemSetAccess(ptr, aligned_size, descs, len(descs))
raise_if_driver_error(res)

# Release handle ownership once map+access setup succeeded.
_release_handle()

trans.commit()

# Done — return a Buffer that tracks this VA range
Expand Down
Loading
Loading