mirror of
https://github.com/ROCm/jax.git
synced 2025-04-16 11:56:07 +00:00
Support __cuda_array_interface__ on JAX DeviceArrays. (#2133)
Allows exporting GPU device-resident arrays to other libraries, e.g., CuPy.
This commit is contained in:
parent
4c30c0285c
commit
843e22dd17
@ -28,10 +28,10 @@ http_archive(
|
||||
# and update the sha256 with the result.
|
||||
http_archive(
|
||||
name = "org_tensorflow",
|
||||
sha256 = "46d17ceaae12196c1cb2e99ca0cf040e7cefdc45ae9eede60c3caf3c5aaf5e48",
|
||||
strip_prefix = "tensorflow-93ccefd6d3b8d32f7afcc43568fc7e872e744767",
|
||||
sha256 = "4ce0e08aa014fafa7a0e8fb3531bdc914bd8a49828e1f5c31bb8adfb751ad73d",
|
||||
strip_prefix = "tensorflow-210649dd56d7c4b75e3e8e2a851b61c80ae13dbb",
|
||||
urls = [
|
||||
"https://github.com/tensorflow/tensorflow/archive/93ccefd6d3b8d32f7afcc43568fc7e872e744767.tar.gz",
|
||||
"https://github.com/tensorflow/tensorflow/archive/210649dd56d7c4b75e3e8e2a851b61c80ae13dbb.tar.gz",
|
||||
],
|
||||
)
|
||||
|
||||
|
@ -832,6 +832,10 @@ class DeviceArray(DeviceValue):
|
||||
def __array__(self, dtype=None, context=None):
|
||||
return onp.asarray(self._value, dtype=dtype)
|
||||
|
||||
@property
|
||||
def __cuda_array_interface__(self):
|
||||
return _force(self).device_buffer.__cuda_array_interface__
|
||||
|
||||
__str__ = partialmethod(_forward_to_value, str)
|
||||
__bool__ = __nonzero__ = partialmethod(_forward_to_value, bool)
|
||||
def __float__(self): return self._value.__float__()
|
||||
|
@ -31,11 +31,17 @@ try:
|
||||
except ImportError:
|
||||
torch = None
|
||||
|
||||
try:
|
||||
import cupy
|
||||
except ImportError:
|
||||
cupy = None
|
||||
|
||||
scalar_types = [jnp.bool_, jnp.int8, jnp.int16, jnp.int32, jnp.int64,
|
||||
jnp.uint8, jnp.uint16, jnp.uint32, jnp.uint64,
|
||||
jnp.bfloat16, jnp.float16, jnp.float32, jnp.float64]
|
||||
torch_types = [jnp.int8, jnp.int16, jnp.int32, jnp.int64,
|
||||
|
||||
standard_dtypes = [jnp.bool_, jnp.int8, jnp.int16, jnp.int32, jnp.int64,
|
||||
jnp.uint8, jnp.uint16, jnp.uint32, jnp.uint64,
|
||||
jnp.float16, jnp.float32, jnp.float64]
|
||||
all_dtypes = standard_dtypes + [jnp.bfloat16]
|
||||
torch_dtypes = [jnp.int8, jnp.int16, jnp.int32, jnp.int64,
|
||||
jnp.uint8, jnp.float16, jnp.float32, jnp.float64]
|
||||
|
||||
nonempty_nonscalar_array_shapes = [(4,), (3, 4), (2, 3, 4)]
|
||||
@ -59,7 +65,7 @@ class DLPackTest(jtu.JaxTestCase):
|
||||
jtu.format_shape_dtype_string(shape, dtype)),
|
||||
"shape": shape, "dtype": dtype}
|
||||
for shape in all_shapes
|
||||
for dtype in scalar_types))
|
||||
for dtype in all_dtypes))
|
||||
def testJaxRoundTrip(self, shape, dtype):
|
||||
rng = jtu.rand_default()
|
||||
np = rng(shape, dtype)
|
||||
@ -77,7 +83,7 @@ class DLPackTest(jtu.JaxTestCase):
|
||||
jtu.format_shape_dtype_string(shape, dtype)),
|
||||
"shape": shape, "dtype": dtype}
|
||||
for shape in all_shapes
|
||||
for dtype in torch_types))
|
||||
for dtype in torch_dtypes))
|
||||
@unittest.skipIf(not torch, "Test requires PyTorch")
|
||||
def testTorchToJax(self, shape, dtype):
|
||||
rng = jtu.rand_default()
|
||||
@ -93,7 +99,7 @@ class DLPackTest(jtu.JaxTestCase):
|
||||
jtu.format_shape_dtype_string(shape, dtype)),
|
||||
"shape": shape, "dtype": dtype}
|
||||
for shape in all_shapes
|
||||
for dtype in torch_types))
|
||||
for dtype in torch_dtypes))
|
||||
@unittest.skipIf(not torch or jax.lib.version <= (0, 1, 38),
|
||||
"Test requires PyTorch and jaxlib >= 0.1.39")
|
||||
# TODO(phawkins): the dlpack destructor issues errors in jaxlib 0.1.38.
|
||||
@ -106,5 +112,29 @@ class DLPackTest(jtu.JaxTestCase):
|
||||
self.assertAllClose(np, y.numpy(), check_dtypes=True)
|
||||
|
||||
|
||||
class CudaArrayInterfaceTest(jtu.JaxTestCase):
|
||||
|
||||
def setUp(self):
|
||||
if jtu.device_under_test() != "gpu":
|
||||
self.skipTest("__cuda_array_interface__ is only supported on GPU")
|
||||
|
||||
@parameterized.named_parameters(jtu.cases_from_list(
|
||||
{"testcase_name": "_{}".format(
|
||||
jtu.format_shape_dtype_string(shape, dtype)),
|
||||
"shape": shape, "dtype": dtype}
|
||||
for shape in all_shapes
|
||||
for dtype in standard_dtypes))
|
||||
@unittest.skipIf(not cupy or jax.lib.version <= (0, 1, 38),
|
||||
"Test requires CuPy and jaxlib >= 0.1.39")
|
||||
def testJaxToCuPy(self, shape, dtype):
|
||||
rng = jtu.rand_default()
|
||||
x = rng(shape, dtype)
|
||||
y = jnp.array(x)
|
||||
z = cupy.asarray(y)
|
||||
self.assertEqual(y.__cuda_array_interface__["data"][0],
|
||||
z.__cuda_array_interface__["data"][0])
|
||||
self.assertAllClose(x, cupy.asnumpy(z), check_dtypes=True)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
absltest.main()
|
Loading…
x
Reference in New Issue
Block a user