
Add Device context manager for temporary device switching #1597

Draft
Andy-Jost wants to merge 1 commit into NVIDIA:main from Andy-Jost:device-context-manager

Conversation

@Andy-Jost
Contributor

Summary

Closes #1586. Adds __enter__/__exit__ to Device so it can be used as a context manager that temporarily activates a device and restores the previous CUDA context on exit.

from cuda.core import Device

dev0 = Device(0)
dev0.set_current()
# ... do work on device 0 ...

with Device(1) as dev1:
    # device 1 is now current
    buf = dev1.allocate(1024)

# device 0 is automatically restored here

Changes

  • cuda/core/_device.pyx: Added __enter__ and __exit__ methods to Device. On enter, queries the current context via cuCtxGetCurrent and saves it on a per-thread stack (_tls._ctx_stack), then calls set_current(). On exit, restores the saved context via cuCtxSetCurrent. Uses peek-then-pop ordering so the stack is not corrupted if cuCtxSetCurrent raises.
  • tests/test_device.py: Added 12 tests covering basic usage, context restoration, exception safety, same-device nesting, deep nesting, multi-GPU nesting, set_current() inside a with block, device usability after exit, device initialization, and thread safety (3 threads on 3 GPUs).
  • tests/conftest.py: Added teardown to mempool_device_x2 and mempool_device_x3 fixtures to clean up residual contexts between tests.

Design

  • Stateless restoration: Each __enter__ queries the actual CUDA driver state rather than maintaining a Python-side device cache. This ensures correct interoperability with other libraries (PyTorch, CuPy) that use cudaSetDevice/cuCtxSetCurrent.
  • Reentrant: Saved contexts are stored on a per-thread stack (not on the Device singleton), so nested and reentrant usage works correctly.
  • Uses cuCtxGetCurrent/cuCtxSetCurrent: Consistent with set_current() and the runtime API model. Does not use cuCtxPushCurrent/cuCtxPopCurrent.
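The design above can be modeled in a few lines of plain Python. This is a simplified sketch, not the actual Cython implementation: `Device._current` stands in for the driver's current-context state, and the names here (`ctx_stack`, `device_id`) are illustrative only. It shows the three design points in miniature: stateless save-on-enter, a per-thread stack, and peek-then-pop on exit.

```python
import threading

_tls = threading.local()

class Device:
    """Pure-Python model of the PR's __enter__/__exit__ logic.
    The real code queries the CUDA driver via cuCtxGetCurrent/cuCtxSetCurrent."""
    _current = None  # stands in for the driver's current-context state

    def __init__(self, device_id):
        self.device_id = device_id

    def set_current(self):
        Device._current = self.device_id

    def __enter__(self):
        # Stateless restoration: save whatever is *actually* current,
        # even if another library changed it behind our back.
        if not hasattr(_tls, "ctx_stack"):
            _tls.ctx_stack = []       # per-thread stack -> reentrant + thread-safe
        _tls.ctx_stack.append(Device._current)
        self.set_current()
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        prev = _tls.ctx_stack[-1]     # peek first...
        Device._current = prev        # ...restore...
        _tls.ctx_stack.pop()          # ...pop only after restore succeeds
        return False
```

Nesting then behaves as the PR's usage example describes: each exit restores exactly what was current at the matching enter.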

Test Coverage

All tests pass locally on single-GPU (L40) and multi-GPU (3x RTX PRO 6000 Blackwell) machines. Stress-tested with 20 randomized iterations via pytest-repeat + pytest-randomly with no ordering issues.

Made with Cursor

Closes NVIDIA#1586. Adds __enter__/__exit__ to Device so it can be used as
a context manager that saves the current CUDA context on entry and
restores it on exit. Uses cuCtxGetCurrent/cuCtxSetCurrent (not push/pop)
for interoperability with the runtime API. Saved contexts are stored on
a per-thread stack (_tls._ctx_stack) so nested and reentrant usage works
correctly.

Also adds teardown to mempool_device_x2/x3 fixtures to clean up
residual contexts between tests.

Co-authored-by: Cursor <cursoragent@cursor.com>
@Andy-Jost Andy-Jost added this to the cuda.core v0.6.0 milestone Feb 11, 2026
@Andy-Jost Andy-Jost added feature New feature or request cuda.core Everything related to the cuda.core module labels Feb 11, 2026
@Andy-Jost Andy-Jost self-assigned this Feb 11, 2026
@Andy-Jost Andy-Jost requested a review from leofang February 11, 2026 01:37
@copy-pr-bot
Contributor

copy-pr-bot bot commented Feb 11, 2026

This pull request requires additional validation before any workflows can run on NVIDIA's runners.

Pull request vetters can view their responsibilities here.

Contributors can view more details about this message here.

@Andy-Jost
Contributor Author

/ok to test f02b730

@Andy-Jost Andy-Jost marked this pull request as draft February 11, 2026 01:40
Contributor

@cpcloud cpcloud left a comment


Solid implementation — the design choices (stateless restoration via cuCtxGetCurrent/cuCtxSetCurrent, per-thread stack, peek-then-pop in __exit__) are all correct and well-reasoned. Tests are thorough. Two items to address before merge:

  1. __exit__ can mask the caller's exception — if cuCtxSetCurrent raises during unwinding from a user exception, the CUDA error replaces the original (it's still in __context__, but most users won't look there). See inline comment for a suggested fix.
  2. Missing docs / release notes — this adds a new public API entry point on Device but the PR doesn't update interoperability.rst (the "Current device/context" section should mention the context manager alongside set_current()), getting-started.rst, or 0.7.x-notes.rst. Even for a draft, having these roughed in makes the feature discoverable.
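The exception-masking concern in item 1 is easy to reproduce without CUDA at all. The sketch below uses a hypothetical `LeakyCM` whose `__exit__` raises, standing in for a `cuCtxSetCurrent` failure during unwinding; the user's original exception survives only on `__context__`.

```python
class LeakyCM:
    """Hypothetical context manager whose __exit__ raises, modeling a
    cuCtxSetCurrent failure while a user exception is already propagating."""
    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        raise RuntimeError("restore failed")

try:
    with LeakyCM():
        raise ValueError("user error")
except RuntimeError as e:
    masked = e

# The caller sees RuntimeError; the original ValueError is displaced
# and reachable only via implicit exception chaining.
assert isinstance(masked.__context__, ValueError)
```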

with nogil:
    HANDLE_RETURN(cydriver.cuCtxSetCurrent(prev_ctx))
_tls._ctx_stack.pop()
return False
Contributor

If cuCtxSetCurrent raises here while an exception is already propagating from the with block, the CUDA error replaces the user's original exception. The original ends up in e.__context__ but is easy to miss.

Consider guarding:

def __exit__(self, exc_type, exc_val, exc_tb):
    cdef uintptr_t prev_ctx_ptr = _tls._ctx_stack.pop()
    cdef cydriver.CUcontext prev_ctx = <cydriver.CUcontext><void*>prev_ctx_ptr
    try:
        with nogil:
            HANDLE_RETURN(cydriver.cuCtxSetCurrent(prev_ctx))
    except Exception:
        if exc_type is None:
            raise  # no active exception, surface the CUDA error
        # else: swallow the restore failure to preserve the original exception;
        # the stack entry is already popped so the next __exit__ won't retry.
    return False

This also simplifies the peek-then-pop dance — just pop eagerly, since a failed cuCtxSetCurrent with a context obtained from cuCtxGetCurrent moments earlier is essentially unrecoverable anyway.

(_graphics.pyx has the same pattern, so this is a pre-existing issue there too, but worth getting right here.)
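The guard proposed above can also be exercised in pure Python. This is a sketch with a hypothetical `GuardedCM` whose restore step always fails, checking both branches: with an active user exception the restore failure is swallowed, and without one it surfaces.

```python
class GuardedCM:
    """Hypothetical context manager applying the suggested guard:
    a restore failure in __exit__ must not mask a propagating exception."""
    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        try:
            # Simulate cuCtxSetCurrent failing during restore.
            raise RuntimeError("restore failed")
        except RuntimeError:
            if exc_type is None:
                raise  # no active exception: surface the restore error
            # else: swallow, preserving the caller's original exception
        return False

# Branch 1: the user's exception propagates unchanged.
try:
    with GuardedCM():
        raise ValueError("user error")
except ValueError as e:
    seen = e

# Branch 2: with no active exception, the restore failure surfaces.
try:
    with GuardedCM():
        pass
except RuntimeError as e:
    surfaced = e
```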

cdef cydriver.CUcontext prev_ctx
with nogil:
    HANDLE_RETURN(cydriver.cuCtxGetCurrent(&prev_ctx))
if not hasattr(_tls, '_ctx_stack'):
Contributor

Nit: the try/except AttributeError pattern (EAFP) is marginally faster after the first call per thread, and is the more common pattern elsewhere in _device.pyx (see Device_ensure_tls_devices):

try:
    _tls._ctx_stack
except AttributeError:
    _tls._ctx_stack = []

Not blocking — just a consistency suggestion.
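For illustration, the EAFP initialization behaves the same as the `hasattr` check while also showing why `threading.local` makes the stack per-thread. A small sketch (helper name `ensure_stack` is hypothetical):

```python
import threading

_tls = threading.local()

def ensure_stack():
    # EAFP: after the first call in a given thread, this is just an
    # attribute load; AttributeError fires only once per thread.
    try:
        return _tls._ctx_stack
    except AttributeError:
        _tls._ctx_stack = []
        return _tls._ctx_stack

results = {}

def worker(name):
    ensure_stack().append(name)
    # Each thread sees only its own stack.
    results[name] = list(_tls._ctx_stack)

threads = [threading.Thread(target=worker, args=(i,)) for i in range(3)]
for t in threads:
    t.start()
for t in threads:
    t.join()
```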


# ============================================================================
# Device Context Manager Tests
# ============================================================================
Contributor

Nit: consider putting _get_current_context() in conftest.py — it's generally useful for any context-related test, and other test files may want it as multi-device / context-switching tests grow.

Member

@leofang leofang left a comment


Blocking accidental merge of this PR before a design review/survey happens, as discussed offline and described in the linked issue.
