-
Notifications
You must be signed in to change notification settings - Fork 255
Add Device context manager for temporary device switching #1597
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -1,4 +1,4 @@ | ||
| # SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. | ||
| # SPDX-FileCopyrightText: Copyright (c) 2024-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. | ||
| # | ||
| # SPDX-License-Identifier: Apache-2.0 | ||
|
|
||
|
|
@@ -1188,6 +1188,49 @@ class Device: | |
| def __reduce__(self): | ||
| return Device, (self.device_id,) | ||
|
|
||
def __enter__(self):
    """Set this device as current for the duration of the ``with`` block.

    On exit, the previously current device is restored automatically.
    Nested ``with`` blocks are supported and restore correctly at each
    level.

    Returns
    -------
    Device
        This device instance.

    Examples
    --------
    >>> from cuda.core import Device
    >>> with Device(0) as dev0:
    ...     buf = dev0.allocate(1024)

    See Also
    --------
    set_current : Non-context-manager entry point.
    """
    cdef cydriver.CUcontext prev_ctx
    # Snapshot whatever context is current (possibly NULL) so that
    # __exit__ can restore it exactly, even across nested blocks.
    with nogil:
        HANDLE_RETURN(cydriver.cuCtxGetCurrent(&prev_ctx))
    # EAFP: after the first call on this thread the attribute exists, so
    # the try-path is the common fast path; this also matches the pattern
    # used elsewhere in _device.pyx (e.g. Device._ensure_tls_devices).
    try:
        stack = _tls._ctx_stack
    except AttributeError:
        stack = _tls._ctx_stack = []
    stack.append(<uintptr_t><void*>prev_ctx)
    self.set_current()
    return self
|
|
||
def __exit__(self, exc_type, exc_val, exc_tb):
    """Restore the previously current device upon exiting the ``with`` block.

    Exceptions raised inside the block are not suppressed (this method
    always returns ``False``).
    """
    # Pop eagerly: even if the restore below fails, the entry must not be
    # left on the stack where an enclosing __exit__ would retry it.
    cdef uintptr_t prev_ctx_ptr = _tls._ctx_stack.pop()
    cdef cydriver.CUcontext prev_ctx = <cydriver.CUcontext><void*>prev_ctx_ptr
    try:
        with nogil:
            HANDLE_RETURN(cydriver.cuCtxSetCurrent(prev_ctx))
    except Exception:
        if exc_type is None:
            raise  # no in-flight exception; surface the CUDA error
        # else: swallow the restore failure so the original exception from
        # the with-block body is preserved.
    return False
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. If `cuCtxSetCurrent` fails here while an exception is already propagating, the stack entry is never popped and the CUDA error replaces the original exception. Consider guarding: def __exit__(self, exc_type, exc_val, exc_tb):
cdef uintptr_t prev_ctx_ptr = _tls._ctx_stack.pop()
cdef cydriver.CUcontext prev_ctx = <cydriver.CUcontext><void*>prev_ctx_ptr
try:
with nogil:
HANDLE_RETURN(cydriver.cuCtxSetCurrent(prev_ctx))
except Exception:
if exc_type is None:
raise # no active exception, surface the CUDA error
# else: swallow the restore failure to preserve the original exception;
# the stack entry is already popped so the next __exit__ won't retry.
return FalseThis also simplifies the peek-then-pop dance — just pop eagerly, since a failed ( |
||
|
|
||
| def set_current(self, ctx: Context = None) -> Context | None: | ||
| """Set device to be used for GPU executions. | ||
|
|
||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -436,3 +436,179 @@ def test_device_set_membership(init_cuda): | |
| # Same device_id should not add duplicate | ||
| device_set.add(dev0_b) | ||
| assert len(device_set) == 1, "Should not add duplicate device" | ||
|
|
||
|
|
||
| # ============================================================================ | ||
| # Device Context Manager Tests | ||
| # ============================================================================ | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Nit: consider putting |
||
|
|
||
|
|
||
def _get_current_context():
    """Query the driver for the active CUcontext, as an int (0 means NULL / no context)."""
    ctx = handle_return(driver.cuCtxGetCurrent())
    return int(ctx)
|
|
||
|
|
||
def test_context_manager_basic(deinit_cuda):
    """Entering ``with Device(0)`` activates a context; exiting deactivates it."""
    assert _get_current_context() == 0, "Should start with no active context"

    dev = Device(0)
    with dev:
        assert _get_current_context() != 0, "Device should be current inside the with block"

    assert _get_current_context() == 0, "No context should be current after exiting"
|
|
||
|
|
||
def test_context_manager_restores_previous(deinit_cuda):
    """Exiting the with-block restores the prior context rather than NULL."""
    device = Device(0)
    device.set_current()
    saved_ctx = _get_current_context()
    assert saved_ctx != 0

    with Device(0):
        pass

    assert _get_current_context() == saved_ctx, "Should restore the previous context"
|
|
||
|
|
||
def test_context_manager_exception_safety(deinit_cuda):
    """The previous context comes back even when the block raises."""
    # Start with no active context so restoration is distinguishable
    assert _get_current_context() == 0

    with pytest.raises(RuntimeError, match="test error"):
        with Device(0):
            assert _get_current_context() != 0, "Device should be active inside the block"
            raise RuntimeError("test error")

    assert _get_current_context() == 0, "Must restore NULL context after exception"
|
|
||
|
|
||
def test_context_manager_returns_device(deinit_cuda):
    """The ``as`` target of the with-statement is the Device itself."""
    expected = Device(0)
    with expected as entered:
        assert entered is expected

    assert _get_current_context() == 0
|
|
||
|
|
||
def test_context_manager_nesting_same_device(deinit_cuda):
    """Re-entering the same device's context manager is well behaved."""
    device = Device(0)

    with device:
        outer_ctx = _get_current_context()
        with device:
            inner_ctx = _get_current_context()
            assert inner_ctx == outer_ctx, "Same device should yield same context"
        assert _get_current_context() == outer_ctx, "Outer context restored after inner exit"

    assert _get_current_context() == 0
|
|
||
|
|
||
def test_context_manager_deep_nesting(deinit_cuda):
    """Three levels of reentrancy on one device unwind correctly."""
    device = Device(0)

    with device:
        first = _get_current_context()
        with device:
            second = _get_current_context()
            with device:
                assert _get_current_context() != 0
            assert _get_current_context() == second
        assert _get_current_context() == first

    assert _get_current_context() == 0
|
|
||
|
|
||
def test_context_manager_nesting_different_devices(mempool_device_x2):
    """Entering another device's context and leaving it brings the original back."""
    dev0, dev1 = mempool_device_x2
    outer_ctx = _get_current_context()

    with dev1:
        inner_ctx = _get_current_context()
        assert inner_ctx != outer_ctx, "Different device should have different context"

    assert _get_current_context() == outer_ctx, "Original device context should be restored"
|
|
||
|
|
||
def test_context_manager_deep_nesting_multi_gpu(mempool_device_x2):
    """Interleaved nesting across two devices unwinds level by level."""
    dev0, dev1 = mempool_device_x2

    with dev0:
        ctx_a = _get_current_context()
        with dev1:
            ctx_b = _get_current_context()
            assert ctx_b != ctx_a
            with dev0:
                assert _get_current_context() == ctx_a, "Same device should yield same primary context"
                with dev1:
                    assert _get_current_context() == ctx_b
                assert _get_current_context() == ctx_a
            assert _get_current_context() == ctx_b
        assert _get_current_context() == ctx_a
|
|
||
|
|
||
def test_context_manager_set_current_inside(mempool_device_x2):
    """A set_current() call inside the block does not break restoration."""
    dev0, dev1 = mempool_device_x2
    saved_ctx = _get_current_context()  # dev0 is current from fixture

    with dev0:
        dev1.set_current()  # change the active device inside the block
        assert _get_current_context() != saved_ctx

    assert _get_current_context() == saved_ctx, "Must restore the context saved at __enter__"
|
|
||
|
|
||
def test_context_manager_device_usable_after_exit(deinit_cuda):
    """Exiting the context manager leaves the Device singleton intact."""
    dev = Device(0)
    with dev:
        pass

    assert _get_current_context() == 0

    # Device should still be usable via set_current
    dev.set_current()
    assert _get_current_context() != 0
    assert dev.create_stream() is not None
|
|
||
|
|
||
def test_context_manager_initializes_device(deinit_cuda):
    """Entering the with-block makes the device ready for allocations."""
    dev = Device(0)
    with dev:
        # allocate requires an active context; should not raise
        allocation = dev.allocate(1024)
        assert allocation.handle != 0
|
|
||
|
|
||
def test_context_manager_thread_safety(mempool_device_x3):
    """Concurrent threads using context managers on different devices don't interfere."""
    import concurrent.futures
    import threading

    devices = mempool_device_x3
    barrier = threading.Barrier(len(devices))
    errors = []  # collected per-thread failures; asserted at the end

    def run_on(device):
        try:
            saved_ctx = _get_current_context()
            with device:
                barrier.wait(timeout=5)
                allocation = device.allocate(1024)
                assert allocation.handle != 0
            assert _get_current_context() == saved_ctx
        except Exception as e:
            errors.append(e)

    with concurrent.futures.ThreadPoolExecutor(max_workers=len(devices)) as pool:
        pool.map(run_on, devices)

    assert not errors, f"Thread errors: {errors}"
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Nit: the
`try/except AttributeError` pattern (EAFP) is marginally faster after the first call per thread, and is the more common pattern elsewhere in `_device.pyx` (see `Device._ensure_tls_devices`). Not blocking — just a consistency suggestion.