From dfda948fbf5c8159fa9e95d3c4e2b82a6855ac28 Mon Sep 17 00:00:00 2001 From: Peter Hawkins Date: Mon, 22 Jan 2024 14:24:45 -0800 Subject: [PATCH] [XLA:Python] Fail with an AttributeError if __cuda_array_interface__ is called on a sharded array. Fixes https://github.com/google/jax/issues/19134 PiperOrigin-RevId: 600570354 --- tests/BUILD | 1 + tests/array_interoperability_test.py | 19 +++++++++++++++++++ 2 files changed, 20 insertions(+) diff --git a/tests/BUILD b/tests/BUILD index 77931797e..adcd74fcd 100644 --- a/tests/BUILD +++ b/tests/BUILD @@ -61,6 +61,7 @@ jax_test( name = "array_interoperability_test", srcs = ["array_interoperability_test.py"], disable_backends = ["tpu"], + tags = ["multiaccelerator"], deps = py_deps("tensorflow_core"), ) diff --git a/tests/array_interoperability_test.py b/tests/array_interoperability_test.py index 445195626..6a9ace853 100644 --- a/tests/array_interoperability_test.py +++ b/tests/array_interoperability_test.py @@ -19,6 +19,7 @@ from absl.testing import absltest import jax import jax.dlpack import jax.numpy as jnp +from jax.sharding import PartitionSpec as P from jax._src import config from jax._src import test_util as jtu from jax._src import xla_bridge as xb @@ -217,6 +218,24 @@ class CudaArrayInterfaceTest(jtu.JaxTestCase): ): _ = x.__cuda_array_interface__ + @jtu.run_on_devices("cuda") + @unittest.skipIf(xla_extension_version < 233, "Requires newer jaxlib") + def testCudaArrayInterfaceOnShardedArrayFails(self): + devices = jax.local_devices() + if len(devices) <= 1: + raise unittest.SkipTest("Test requires 2 or more devices") + mesh = jax.sharding.Mesh(np.array(devices), ("x",)) + sharding = jax.sharding.NamedSharding(mesh, P("x")) + x = jnp.arange(16) + x = jax.device_put(x, sharding) + self.assertFalse(hasattr(x, "__cuda_array_interface__")) + with self.assertRaisesRegex( + AttributeError, + "__cuda_array_interface__ is only supported for unsharded arrays.", + ): + _ = x.__cuda_array_interface__ + + @jtu.sample_product( shape=all_shapes, dtype=cuda_array_interface_dtypes,