Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
45 changes: 44 additions & 1 deletion cuda_core/cuda/core/_device.pyx
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-FileCopyrightText: Copyright (c) 2024-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# SPDX-License-Identifier: Apache-2.0

Expand Down Expand Up @@ -1188,6 +1188,49 @@ class Device:
def __reduce__(self):
return Device, (self.device_id,)

def __enter__(self):
"""Set this device as current for the duration of the ``with`` block.

On exit, the previously current device is restored automatically.
Nested ``with`` blocks are supported and restore correctly at each
level.

Returns
-------
Device
This device instance.

Examples
--------
>>> from cuda.core import Device
>>> with Device(0) as dev0:
... buf = dev0.allocate(1024)

See Also
--------
set_current : Non-context-manager entry point.
"""
cdef cydriver.CUcontext prev_ctx
with nogil:
HANDLE_RETURN(cydriver.cuCtxGetCurrent(&prev_ctx))
if not hasattr(_tls, '_ctx_stack'):
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Nit: the try/except AttributeError pattern (EAFP) is marginally faster after the first call per thread, and is the more common pattern elsewhere in _device.pyx (see Device_ensure_tls_devices):

try:
    _tls._ctx_stack
except AttributeError:
    _tls._ctx_stack = []

Not blocking — just a consistency suggestion.

_tls._ctx_stack = []
_tls._ctx_stack.append(<uintptr_t><void*>prev_ctx)
self.set_current()
return self

def __exit__(self, exc_type, exc_val, exc_tb):
"""Restore the previously current device upon exiting the ``with`` block.

Exceptions are not suppressed.
"""
cdef uintptr_t prev_ctx_ptr = _tls._ctx_stack[-1]
cdef cydriver.CUcontext prev_ctx = <cydriver.CUcontext><void*>prev_ctx_ptr
with nogil:
HANDLE_RETURN(cydriver.cuCtxSetCurrent(prev_ctx))
_tls._ctx_stack.pop()
return False
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If cuCtxSetCurrent raises here while an exception is already propagating from the with block, the CUDA error replaces the user's original exception. The original ends up in e.__context__ but is easy to miss.

Consider guarding:

def __exit__(self, exc_type, exc_val, exc_tb):
    cdef uintptr_t prev_ctx_ptr = _tls._ctx_stack.pop()
    cdef cydriver.CUcontext prev_ctx = <cydriver.CUcontext><void*>prev_ctx_ptr
    try:
        with nogil:
            HANDLE_RETURN(cydriver.cuCtxSetCurrent(prev_ctx))
    except Exception:
        if exc_type is None:
            raise  # no active exception, surface the CUDA error
        # else: swallow the restore failure to preserve the original exception;
        # the stack entry is already popped so the next __exit__ won't retry.
    return False

This also simplifies the peek-then-pop dance — just pop eagerly, since a failed cuCtxSetCurrent with a context obtained from cuCtxGetCurrent moments earlier is essentially unrecoverable anyway.

(_graphics.pyx has the same pattern, so this is a pre-existing issue there too, but worth getting right here.)


def set_current(self, ctx: Context = None) -> Context | None:
"""Set device to be used for GPU executions.

Expand Down
6 changes: 4 additions & 2 deletions cuda_core/tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -192,13 +192,15 @@ def _mempool_device_impl(num):
@pytest.fixture
def mempool_device_x2():
"""Fixture that provides two devices if available, otherwise skips test."""
return _mempool_device_impl(2)
yield _mempool_device_impl(2)
_device_unset_current()


@pytest.fixture
def mempool_device_x3():
"""Fixture that provides three devices if available, otherwise skips test."""
return _mempool_device_impl(3)
yield _mempool_device_impl(3)
_device_unset_current()


@pytest.fixture(
Expand Down
176 changes: 176 additions & 0 deletions cuda_core/tests/test_device.py
Original file line number Diff line number Diff line change
Expand Up @@ -436,3 +436,179 @@ def test_device_set_membership(init_cuda):
# Same device_id should not add duplicate
device_set.add(dev0_b)
assert len(device_set) == 1, "Should not add duplicate device"


# ============================================================================
# Device Context Manager Tests
# ============================================================================
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Nit: consider putting _get_current_context() in conftest.py — it's generally useful for any context-related test, and other test files may want it as multi-device / context-switching tests grow.



def _get_current_context():
"""Return the current CUcontext as an int (0 means NULL / no context)."""
return int(handle_return(driver.cuCtxGetCurrent()))


def test_context_manager_basic(deinit_cuda):
"""with Device(0) sets the device as current and restores on exit."""
assert _get_current_context() == 0, "Should start with no active context"

with Device(0):
assert _get_current_context() != 0, "Device should be current inside the with block"

assert _get_current_context() == 0, "No context should be current after exiting"


def test_context_manager_restores_previous(deinit_cuda):
"""Context manager restores the previously active context, not NULL."""
dev0 = Device(0)
dev0.set_current()
ctx_before = _get_current_context()
assert ctx_before != 0

with Device(0):
pass

assert _get_current_context() == ctx_before, "Should restore the previous context"


def test_context_manager_exception_safety(deinit_cuda):
"""Device context is restored even when an exception is raised."""
# Start with no active context so restoration is distinguishable
assert _get_current_context() == 0

with pytest.raises(RuntimeError, match="test error"), Device(0):
assert _get_current_context() != 0, "Device should be active inside the block"
raise RuntimeError("test error")

assert _get_current_context() == 0, "Must restore NULL context after exception"


def test_context_manager_returns_device(deinit_cuda):
"""__enter__ returns the Device instance for use in 'as' clause."""
device = Device(0)
with device as dev:
assert dev is device

assert _get_current_context() == 0


def test_context_manager_nesting_same_device(deinit_cuda):
"""Nested with-blocks on the same device work correctly."""
dev0 = Device(0)

with dev0:
ctx_outer = _get_current_context()
with dev0:
ctx_inner = _get_current_context()
assert ctx_inner == ctx_outer, "Same device should yield same context"
assert _get_current_context() == ctx_outer, "Outer context restored after inner exit"

assert _get_current_context() == 0


def test_context_manager_deep_nesting(deinit_cuda):
"""Deep nesting and reentrancy restore correctly at each level."""
dev0 = Device(0)

with dev0:
ctx_level1 = _get_current_context()
with dev0:
ctx_level2 = _get_current_context()
with dev0:
assert _get_current_context() != 0
assert _get_current_context() == ctx_level2
assert _get_current_context() == ctx_level1

assert _get_current_context() == 0


def test_context_manager_nesting_different_devices(mempool_device_x2):
"""Nested with-blocks on different devices restore correctly."""
dev0, dev1 = mempool_device_x2
ctx_dev0 = _get_current_context()

with dev1:
ctx_inside = _get_current_context()
assert ctx_inside != ctx_dev0, "Different device should have different context"

assert _get_current_context() == ctx_dev0, "Original device context should be restored"


def test_context_manager_deep_nesting_multi_gpu(mempool_device_x2):
"""Deep nesting across multiple devices restores correctly at each level."""
dev0, dev1 = mempool_device_x2

with dev0:
ctx_level0 = _get_current_context()
with dev1:
ctx_level1 = _get_current_context()
assert ctx_level1 != ctx_level0
with dev0:
assert _get_current_context() == ctx_level0, "Same device should yield same primary context"
with dev1:
assert _get_current_context() == ctx_level1
assert _get_current_context() == ctx_level0
assert _get_current_context() == ctx_level1
assert _get_current_context() == ctx_level0


def test_context_manager_set_current_inside(mempool_device_x2):
"""set_current() inside a with block does not affect restoration on exit."""
dev0, dev1 = mempool_device_x2
ctx_dev0 = _get_current_context() # dev0 is current from fixture

with dev0:
dev1.set_current() # change the active device inside the block
assert _get_current_context() != ctx_dev0

assert _get_current_context() == ctx_dev0, "Must restore the context saved at __enter__"


def test_context_manager_device_usable_after_exit(deinit_cuda):
"""Device singleton is not corrupted after context manager exit."""
device = Device(0)
with device:
pass

assert _get_current_context() == 0

# Device should still be usable via set_current
device.set_current()
assert _get_current_context() != 0
stream = device.create_stream()
assert stream is not None


def test_context_manager_initializes_device(deinit_cuda):
"""with Device(N) should initialize the device, making it ready for use."""
device = Device(0)
with device:
# allocate requires an active context; should not raise
buf = device.allocate(1024)
assert buf.handle != 0


def test_context_manager_thread_safety(mempool_device_x3):
"""Concurrent threads using context managers on different devices don't interfere."""
import concurrent.futures
import threading

devices = mempool_device_x3
barrier = threading.Barrier(len(devices))
errors = []

def worker(dev):
try:
ctx_before = _get_current_context()
with dev:
barrier.wait(timeout=5)
buf = dev.allocate(1024)
assert buf.handle != 0
assert _get_current_context() == ctx_before
except Exception as e:
errors.append(e)

with concurrent.futures.ThreadPoolExecutor(max_workers=len(devices)) as pool:
pool.map(worker, devices)

assert not errors, f"Thread errors: {errors}"
Loading