[XLA:Python] Fix __cuda_array_interface__.

Adds a test for __cuda_array_interface__ that does not depend on cupy.

Fixes https://github.com/google/jax/issues/16440

PiperOrigin-RevId: 541965361
This commit is contained in:
Peter Hawkins 2023-06-20 10:08:38 -07:00 committed by jax authors
parent afcd1a7c8e
commit 0ec03dbdce
2 changed files with 28 additions and 3 deletions

View File

@ -37,6 +37,10 @@ Remember to align the itemized text with the first line of an item within a list
## jaxlib 0.4.13
* Bug fixes
* `__cuda_array_interface__` was broken in previous jaxlib versions and is now
fixed ({jax-issue}`16440`).
## jax 0.4.12 (June 8, 2023)
* Changes

View File

@ -21,6 +21,7 @@ from jax import config
import jax.dlpack
import jax.numpy as jnp
from jax._src import test_util as jtu
from jax._src.lib import xla_extension_version
import numpy as np
@ -48,6 +49,8 @@ numpy_dtypes = sorted(
[dt for dt in jax.dlpack.SUPPORTED_DTYPES if dt != jnp.bfloat16],
key=lambda x: x.__name__)
cuda_array_interface_dtypes = [dt for dt in dlpack_dtypes if dt != jnp.bfloat16]
nonempty_nonscalar_array_shapes = [(4,), (3, 4), (2, 3, 4)]
empty_array_shapes = []
empty_array_shapes += [(0,), (0, 4), (3, 0),]
@ -162,6 +165,7 @@ class DLPackTest(jtu.JaxTestCase):
self.assertAllClose(x_np, x_jax)
@unittest.skipIf(xla_extension_version < 163, "Test requires jaxlib 0.4.13")
class CudaArrayInterfaceTest(jtu.JaxTestCase):
def setUp(self):
@ -171,12 +175,29 @@ class CudaArrayInterfaceTest(jtu.JaxTestCase):
@jtu.sample_product(
shape=all_shapes,
dtype=dlpack_dtypes,
dtype=cuda_array_interface_dtypes,
)
def testCudaArrayInterfaceWorks(self, shape, dtype):
rng = jtu.rand_default(self.rng())
x = rng(shape, dtype)
y = jnp.array(x)
a = y.__cuda_array_interface__
self.assertEqual(shape, a["shape"])
self.assertEqual(x.__array_interface__["typestr"], a["typestr"])
def testCudaArrayInterfaceBfloat16Fails(self):
rng = jtu.rand_default(self.rng())
x = rng((2, 2), jnp.bfloat16)
y = jnp.array(x)
with self.assertRaisesRegex(RuntimeError, ".*not supported for bfloat16.*"):
_ = y.__cuda_array_interface__
@jtu.sample_product(
shape=all_shapes,
dtype=cuda_array_interface_dtypes,
)
@unittest.skipIf(not cupy, "Test requires CuPy")
def testJaxToCuPy(self, shape, dtype):
if dtype == jnp.bfloat16:
raise unittest.SkipTest("cupy does not support bfloat16")
rng = jtu.rand_default(self.rng())
x = rng(shape, dtype)
y = jnp.array(x)