[XLA:Python] Raise an AttributeError if __cuda_array_interface__ is called on various invalid buffers, rather than a RuntimeError.

This makes hasattr(x, "__cuda_array_interface__") fail gracefully.

In passing, also move the implementation into py_array.cc, and use an allowlist of supported types rather than a denylist.

Fixes https://github.com/google/jax/issues/19134

PiperOrigin-RevId: 595788328
This commit is contained in:
Peter Hawkins 2024-01-04 13:26:20 -08:00 committed by jax authors
parent ebc7af95df
commit 3ff4eb410d

View File

@ -202,15 +202,22 @@ class DLPackTest(jtu.JaxTestCase):
class CudaArrayInterfaceTest(jtu.JaxTestCase):
def setUp(self):
super().setUp()
if not jtu.test_device_matches(["cuda"]):
self.skipTest("__cuda_array_interface__ is only supported on GPU")
@jtu.skip_on_devices("cuda")
@unittest.skipIf(xla_extension_version < 228, "Requires newer jaxlib")
def testCudaArrayInterfaceOnNonCudaFails(self):
x = jnp.arange(5)
self.assertFalse(hasattr(x, "__cuda_array_interface__"))
with self.assertRaisesRegex(
AttributeError,
"__cuda_array_interface__ is only defined for NVidia GPU buffers.",
):
_ = x.__cuda_array_interface__
@jtu.sample_product(
shape=all_shapes,
dtype=cuda_array_interface_dtypes,
)
@jtu.run_on_devices("cuda")
def testCudaArrayInterfaceWorks(self, shape, dtype):
rng = jtu.rand_default(self.rng())
x = rng(shape, dtype)
@ -220,11 +227,13 @@ class CudaArrayInterfaceTest(jtu.JaxTestCase):
self.assertEqual(shape, a["shape"])
self.assertEqual(z.__array_interface__["typestr"], a["typestr"])
@jtu.run_on_devices("cuda")
@unittest.skipIf(xla_extension_version < 228, "Requires newer jaxlib")
def testCudaArrayInterfaceBfloat16Fails(self):
rng = jtu.rand_default(self.rng())
x = rng((2, 2), jnp.bfloat16)
y = jnp.array(x)
with self.assertRaisesRegex(RuntimeError, ".*not supported for bfloat16.*"):
with self.assertRaisesRegex(AttributeError, ".*not supported for BF16.*"):
_ = y.__cuda_array_interface__
@jtu.sample_product(
@ -232,6 +241,7 @@ class CudaArrayInterfaceTest(jtu.JaxTestCase):
dtype=cuda_array_interface_dtypes,
)
@unittest.skipIf(not cupy, "Test requires CuPy")
@jtu.run_on_devices("cuda")
def testJaxToCuPy(self, shape, dtype):
rng = jtu.rand_default(self.rng())
x = rng(shape, dtype)