diff --git a/CHANGELOG.md b/CHANGELOG.md index 06ed6c59e..97c1f014c 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -37,6 +37,10 @@ Remember to align the itemized text with the first line of an item within a list ## jaxlib 0.4.13 +* Bug fixes + * `__cuda_array_interface__` was broken in previous jaxlib versions and is now + fixed ({jax-issue}`16440`). + ## jax 0.4.12 (June 8, 2023) * Changes diff --git a/tests/array_interoperability_test.py b/tests/array_interoperability_test.py index 3225b57f4..70d7d2adf 100644 --- a/tests/array_interoperability_test.py +++ b/tests/array_interoperability_test.py @@ -21,6 +21,7 @@ from jax import config import jax.dlpack import jax.numpy as jnp from jax._src import test_util as jtu +from jax._src.lib import xla_extension_version import numpy as np @@ -48,6 +49,8 @@ numpy_dtypes = sorted( [dt for dt in jax.dlpack.SUPPORTED_DTYPES if dt != jnp.bfloat16], key=lambda x: x.__name__) +cuda_array_interface_dtypes = [dt for dt in dlpack_dtypes if dt != jnp.bfloat16] + nonempty_nonscalar_array_shapes = [(4,), (3, 4), (2, 3, 4)] empty_array_shapes = [] empty_array_shapes += [(0,), (0, 4), (3, 0),] @@ -162,6 +165,7 @@ class DLPackTest(jtu.JaxTestCase): self.assertAllClose(x_np, x_jax) +@unittest.skipIf(xla_extension_version < 163, "Test requires jaxlib 0.4.13") class CudaArrayInterfaceTest(jtu.JaxTestCase): def setUp(self): @@ -171,12 +175,29 @@ class CudaArrayInterfaceTest(jtu.JaxTestCase): @jtu.sample_product( shape=all_shapes, - dtype=dlpack_dtypes, + dtype=cuda_array_interface_dtypes, + ) + def testCudaArrayInterfaceWorks(self, shape, dtype): + rng = jtu.rand_default(self.rng()) + x = rng(shape, dtype) + y = jnp.array(x) + a = y.__cuda_array_interface__ + self.assertEqual(shape, a["shape"]) + self.assertEqual(x.__array_interface__["typestr"], a["typestr"]) + + def testCudaArrayInterfaceBfloat16Fails(self): + rng = jtu.rand_default(self.rng()) + x = rng((2, 2), jnp.bfloat16) + y = jnp.array(x) + with self.assertRaisesRegex(RuntimeError, ".*not supported for bfloat16.*"): + _ = y.__cuda_array_interface__ + + @jtu.sample_product( + shape=all_shapes, + dtype=cuda_array_interface_dtypes, ) @unittest.skipIf(not cupy, "Test requires CuPy") def testJaxToCuPy(self, shape, dtype): - if dtype == jnp.bfloat16: - raise unittest.SkipTest("cupy does not support bfloat16") rng = jtu.rand_default(self.rng()) x = rng(shape, dtype) y = jnp.array(x)