mirror of
https://github.com/ROCm/jax.git
synced 2025-04-16 11:56:07 +00:00
[XLA:Python] Raise an AttributeError if __cuda_array_interface__ is called on various invalid buffers, rather than a RuntimeError.
This makes hasattr(x, "__cuda_array_interface__") fail gracefully. In passing, also move the implementation into py_array.cc, and use an allowlist of supported types rather than a denylist. Fixes https://github.com/google/jax/issues/19134 PiperOrigin-RevId: 595788328
This commit is contained in:
parent
ebc7af95df
commit
3ff4eb410d
@ -202,15 +202,22 @@ class DLPackTest(jtu.JaxTestCase):
|
||||
|
||||
class CudaArrayInterfaceTest(jtu.JaxTestCase):
|
||||
|
||||
def setUp(self):
|
||||
super().setUp()
|
||||
if not jtu.test_device_matches(["cuda"]):
|
||||
self.skipTest("__cuda_array_interface__ is only supported on GPU")
|
||||
@jtu.skip_on_devices("cuda")
|
||||
@unittest.skipIf(xla_extension_version < 228, "Requires newer jaxlib")
|
||||
def testCudaArrayInterfaceOnNonCudaFails(self):
|
||||
x = jnp.arange(5)
|
||||
self.assertFalse(hasattr(x, "__cuda_array_interface__"))
|
||||
with self.assertRaisesRegex(
|
||||
AttributeError,
|
||||
"__cuda_array_interface__ is only defined for NVidia GPU buffers.",
|
||||
):
|
||||
_ = x.__cuda_array_interface__
|
||||
|
||||
@jtu.sample_product(
|
||||
shape=all_shapes,
|
||||
dtype=cuda_array_interface_dtypes,
|
||||
)
|
||||
@jtu.run_on_devices("cuda")
|
||||
def testCudaArrayInterfaceWorks(self, shape, dtype):
|
||||
rng = jtu.rand_default(self.rng())
|
||||
x = rng(shape, dtype)
|
||||
@ -220,11 +227,13 @@ class CudaArrayInterfaceTest(jtu.JaxTestCase):
|
||||
self.assertEqual(shape, a["shape"])
|
||||
self.assertEqual(z.__array_interface__["typestr"], a["typestr"])
|
||||
|
||||
@jtu.run_on_devices("cuda")
|
||||
@unittest.skipIf(xla_extension_version < 228, "Requires newer jaxlib")
|
||||
def testCudaArrayInterfaceBfloat16Fails(self):
|
||||
rng = jtu.rand_default(self.rng())
|
||||
x = rng((2, 2), jnp.bfloat16)
|
||||
y = jnp.array(x)
|
||||
with self.assertRaisesRegex(RuntimeError, ".*not supported for bfloat16.*"):
|
||||
with self.assertRaisesRegex(AttributeError, ".*not supported for BF16.*"):
|
||||
_ = y.__cuda_array_interface__
|
||||
|
||||
@jtu.sample_product(
|
||||
@ -232,6 +241,7 @@ class CudaArrayInterfaceTest(jtu.JaxTestCase):
|
||||
dtype=cuda_array_interface_dtypes,
|
||||
)
|
||||
@unittest.skipIf(not cupy, "Test requires CuPy")
|
||||
@jtu.run_on_devices("cuda")
|
||||
def testJaxToCuPy(self, shape, dtype):
|
||||
rng = jtu.rand_default(self.rng())
|
||||
x = rng(shape, dtype)
|
||||
|
Loading…
x
Reference in New Issue
Block a user