mirror of
https://github.com/ROCm/jax.git
synced 2025-04-16 11:56:07 +00:00
[PJRT C API] Add stream extension to support DLPack and implement this extension in CUDA plugin.
PiperOrigin-RevId: 626408630
This commit is contained in:
parent
cea36a0438
commit
b2375fa7e9
@ -73,12 +73,11 @@ class DLPackTest(jtu.JaxTestCase):
|
||||
@jtu.sample_product(
|
||||
shape=all_shapes,
|
||||
dtype=dlpack_dtypes,
|
||||
copy=[False, True, None]
|
||||
copy=[False, True, None],
|
||||
use_stream=[False, True],
|
||||
)
|
||||
@jtu.run_on_devices("gpu")
|
||||
def testJaxRoundTrip(self, shape, dtype, copy):
|
||||
if xb.using_pjrt_c_api():
|
||||
self.skipTest("DLPack support is incomplete in the PJRT C API") # TODO(skyewm)
|
||||
def testJaxRoundTrip(self, shape, dtype, copy, use_stream):
|
||||
rng = jtu.rand_default(self.rng())
|
||||
np = rng(shape, dtype)
|
||||
|
||||
@ -91,7 +90,11 @@ class DLPackTest(jtu.JaxTestCase):
|
||||
device = jax.devices("gpu")[0]
|
||||
y = jax.device_put(x, device)
|
||||
dl_device = y.__dlpack_device__()
|
||||
dlpack = jax.dlpack.to_dlpack(y, copy=copy)
|
||||
if use_stream:
|
||||
stream = tuple(y.devices())[0].get_stream_for_external_ready_events()
|
||||
dlpack = jax.dlpack.to_dlpack(y, copy=copy, stream=stream)
|
||||
else:
|
||||
dlpack = jax.dlpack.to_dlpack(y, copy=copy)
|
||||
z = jax.dlpack.from_dlpack(dlpack)
|
||||
|
||||
self.assertEqual(z.devices(), {device})
|
||||
|
Loading…
x
Reference in New Issue
Block a user