[PJRT C API] Add stream extension to support DLPack and implement this extension in CUDA plugin.

PiperOrigin-RevId: 626408630
This commit is contained in:
Jieying Luo 2024-04-19 10:41:20 -07:00 committed by jax authors
parent cea36a0438
commit b2375fa7e9

View File

@ -73,12 +73,11 @@ class DLPackTest(jtu.JaxTestCase):
@jtu.sample_product(
shape=all_shapes,
dtype=dlpack_dtypes,
copy=[False, True, None]
copy=[False, True, None],
use_stream=[False, True],
)
@jtu.run_on_devices("gpu")
def testJaxRoundTrip(self, shape, dtype, copy):
if xb.using_pjrt_c_api():
self.skipTest("DLPack support is incomplete in the PJRT C API") # TODO(skyewm)
def testJaxRoundTrip(self, shape, dtype, copy, use_stream):
rng = jtu.rand_default(self.rng())
np = rng(shape, dtype)
@ -91,7 +90,11 @@ class DLPackTest(jtu.JaxTestCase):
device = jax.devices("gpu")[0]
y = jax.device_put(x, device)
dl_device = y.__dlpack_device__()
dlpack = jax.dlpack.to_dlpack(y, copy=copy)
if use_stream:
stream = tuple(y.devices())[0].get_stream_for_external_ready_events()
dlpack = jax.dlpack.to_dlpack(y, copy=copy, stream=stream)
else:
dlpack = jax.dlpack.to_dlpack(y, copy=copy)
z = jax.dlpack.from_dlpack(dlpack)
self.assertEqual(z.devices(), {device})