Support __cuda_array_interface__ on JAX DeviceArrays. (#2133)

Allows exporting GPU device-resident arrays to other libraries, e.g., CuPy.
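For context, a minimal usage sketch mirroring the new testJaxToCuPy test added below (assumes a CUDA GPU, CuPy installed, and jaxlib >= 0.1.39; array values are illustrative). CuPy consumes the exported interface via cupy.asarray, wrapping the same device buffer rather than copying through the host.

import jax.numpy as jnp
import cupy

# Build an array on the GPU with JAX.
x = jnp.arange(12, dtype=jnp.float32).reshape(3, 4)

# cupy.asarray reads x.__cuda_array_interface__ and creates a CuPy view
# over the same device memory (no host round-trip).
y = cupy.asarray(x)

# Both objects report the same device pointer.
assert (x.__cuda_array_interface__["data"][0]
        == y.__cuda_array_interface__["data"][0])

print(cupy.asnumpy(y))  # copy back to host for inspection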
Peter Hawkins 2020-01-31 10:09:40 -05:00 committed by GitHub
parent 4c30c0285c
commit 843e22dd17
3 changed files with 44 additions and 10 deletions


@@ -28,10 +28,10 @@ http_archive(
 # and update the sha256 with the result.
 http_archive(
     name = "org_tensorflow",
-    sha256 = "46d17ceaae12196c1cb2e99ca0cf040e7cefdc45ae9eede60c3caf3c5aaf5e48",
-    strip_prefix = "tensorflow-93ccefd6d3b8d32f7afcc43568fc7e872e744767",
+    sha256 = "4ce0e08aa014fafa7a0e8fb3531bdc914bd8a49828e1f5c31bb8adfb751ad73d",
+    strip_prefix = "tensorflow-210649dd56d7c4b75e3e8e2a851b61c80ae13dbb",
     urls = [
-        "https://github.com/tensorflow/tensorflow/archive/93ccefd6d3b8d32f7afcc43568fc7e872e744767.tar.gz",
+        "https://github.com/tensorflow/tensorflow/archive/210649dd56d7c4b75e3e8e2a851b61c80ae13dbb.tar.gz",
     ],
 )


@@ -832,6 +832,10 @@ class DeviceArray(DeviceValue):
   def __array__(self, dtype=None, context=None):
     return onp.asarray(self._value, dtype=dtype)
 
+  @property
+  def __cuda_array_interface__(self):
+    return _force(self).device_buffer.__cuda_array_interface__
+
   __str__ = partialmethod(_forward_to_value, str)
   __bool__ = __nonzero__ = partialmethod(_forward_to_value, bool)
   def __float__(self): return self._value.__float__()
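For reference, the new property simply forwards the backing device buffer's CUDA Array Interface (the Numba-defined protocol). A consumer sees a plain dict; the sketch below shows roughly what it contains for a 3x4 float32 array on GPU (the pointer value and interface version are illustrative, not taken from this commit).

# Illustrative contents of DeviceArray.__cuda_array_interface__ for a
# float32 array of shape (3, 4); pointer and version are example values.
{
    "shape": (3, 4),                    # array dimensions
    "typestr": "<f4",                   # little-endian float32 (NumPy type string)
    "data": (140481152339968, False),   # (device pointer, read_only flag)
    "version": 2,                       # CUDA Array Interface version
}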


@@ -31,11 +31,17 @@ try:
 except ImportError:
   torch = None
 
+try:
+  import cupy
+except ImportError:
+  cupy = None
+
-scalar_types = [jnp.bool_, jnp.int8, jnp.int16, jnp.int32, jnp.int64,
-                jnp.uint8, jnp.uint16, jnp.uint32, jnp.uint64,
-                jnp.bfloat16, jnp.float16, jnp.float32, jnp.float64]
-torch_types = [jnp.int8, jnp.int16, jnp.int32, jnp.int64,
+standard_dtypes = [jnp.bool_, jnp.int8, jnp.int16, jnp.int32, jnp.int64,
+                   jnp.uint8, jnp.uint16, jnp.uint32, jnp.uint64,
+                   jnp.float16, jnp.float32, jnp.float64]
+all_dtypes = standard_dtypes + [jnp.bfloat16]
+torch_dtypes = [jnp.int8, jnp.int16, jnp.int32, jnp.int64,
                jnp.uint8, jnp.float16, jnp.float32, jnp.float64]
 
 nonempty_nonscalar_array_shapes = [(4,), (3, 4), (2, 3, 4)]
@@ -59,7 +65,7 @@ class DLPackTest(jtu.JaxTestCase):
         jtu.format_shape_dtype_string(shape, dtype)),
      "shape": shape, "dtype": dtype}
      for shape in all_shapes
-     for dtype in scalar_types))
+     for dtype in all_dtypes))
   def testJaxRoundTrip(self, shape, dtype):
     rng = jtu.rand_default()
     np = rng(shape, dtype)
@@ -77,7 +83,7 @@ class DLPackTest(jtu.JaxTestCase):
         jtu.format_shape_dtype_string(shape, dtype)),
      "shape": shape, "dtype": dtype}
      for shape in all_shapes
-     for dtype in torch_types))
+     for dtype in torch_dtypes))
   @unittest.skipIf(not torch, "Test requires PyTorch")
   def testTorchToJax(self, shape, dtype):
     rng = jtu.rand_default()
@@ -93,7 +99,7 @@ class DLPackTest(jtu.JaxTestCase):
         jtu.format_shape_dtype_string(shape, dtype)),
      "shape": shape, "dtype": dtype}
      for shape in all_shapes
-     for dtype in torch_types))
+     for dtype in torch_dtypes))
   @unittest.skipIf(not torch or jax.lib.version <= (0, 1, 38),
                    "Test requires PyTorch and jaxlib >= 0.1.39")
   # TODO(phawkins): the dlpack destructor issues errors in jaxlib 0.1.38.
@@ -106,5 +112,29 @@ class DLPackTest(jtu.JaxTestCase):
     self.assertAllClose(np, y.numpy(), check_dtypes=True)
 
 
+class CudaArrayInterfaceTest(jtu.JaxTestCase):
+
+  def setUp(self):
+    if jtu.device_under_test() != "gpu":
+      self.skipTest("__cuda_array_interface__ is only supported on GPU")
+
+  @parameterized.named_parameters(jtu.cases_from_list(
+     {"testcase_name": "_{}".format(
+        jtu.format_shape_dtype_string(shape, dtype)),
+      "shape": shape, "dtype": dtype}
+     for shape in all_shapes
+     for dtype in standard_dtypes))
+  @unittest.skipIf(not cupy or jax.lib.version <= (0, 1, 38),
+                   "Test requires CuPy and jaxlib >= 0.1.39")
+  def testJaxToCuPy(self, shape, dtype):
+    rng = jtu.rand_default()
+    x = rng(shape, dtype)
+    y = jnp.array(x)
+    z = cupy.asarray(y)
+    self.assertEqual(y.__cuda_array_interface__["data"][0],
+                     z.__cuda_array_interface__["data"][0])
+    self.assertAllClose(x, cupy.asnumpy(z), check_dtypes=True)
+
+
 if __name__ == "__main__":
   absltest.main()