Skip to content

Commit 92aa731

Browse files
Merge pull request #203 from NVIDIA/ksimpson/fix_device_from_ctx
Fix _util.device_from_ctx
2 parents a855762 + 8f357c7 commit 92aa731

File tree

2 files changed

+13
-2
lines changed

2 files changed

+13
-2
lines changed

‎cuda_core/cuda/core/experimental/_utils.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -116,8 +116,9 @@ def inner(*args, **kwargs):
116116

117117
def get_device_from_ctx(ctx_handle) -> int:
118118
"""Get device ID from the given ctx."""
119-
prev_ctx = Device().context.handle
120-
if ctx_handle != prev_ctx:
119+
from cuda.core.experimental._device import Device # avoid circular import
120+
prev_ctx = Device().context._handle
121+
if int(ctx_handle) != int(prev_ctx):
121122
switch_context = True
122123
else:
123124
switch_context = False

‎cuda_core/tests/test_stream.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -71,6 +71,16 @@ def test_stream_context():
7171
context = stream.context
7272
assert context is not None
7373

74+
def test_stream_from_foreign_stream():
75+
device = Device()
76+
other_stream = device.create_stream(options=StreamOptions())
77+
stream = device.create_stream(obj=other_stream)
78+
assert other_stream.handle == stream.handle
79+
device = stream.device
80+
assert isinstance(device, Device)
81+
context = stream.context
82+
assert context is not None
83+
7484
def test_stream_from_handle():
7585
stream = Stream.from_handle(0)
7686
assert isinstance(stream, Stream)

0 commit comments

Comments
 (0)