[XLA:Python] Fail with an AttributeError if __cuda_array_interface__ is accessed on a sharded array.

Fixes https://github.com/google/jax/issues/19134

PiperOrigin-RevId: 600570354
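
From the user's perspective, the new behavior looks roughly like the sketch below. This is an illustration, not part of the commit; it assumes a CUDA backend with at least two local GPUs, a jaxlib recent enough to include this change (xla_extension_version >= 233), and an array whose size divides evenly across the local devices.

  import jax
  import jax.numpy as jnp
  import numpy as np
  from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

  # An array resident on a single device still exposes the CUDA array interface.
  x = jnp.arange(16)
  _ = x.__cuda_array_interface__

  # Shard the same array across all local GPUs.
  devices = jax.local_devices()
  mesh = Mesh(np.array(devices), ("x",))
  sharded = jax.device_put(x, NamedSharding(mesh, P("x")))

  # Accessing the attribute on a sharded array now raises AttributeError.
  try:
    _ = sharded.__cuda_array_interface__
  except AttributeError as err:
    print(err)  # __cuda_array_interface__ is only supported for unsharded arrays.

Because the failure is an AttributeError, hasattr(x, "__cuda_array_interface__") is False for sharded arrays, which is exactly what the new test below asserts, so consumers that feature-probe for the interface can fall back gracefully.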
Peter Hawkins 2024-01-22 14:24:45 -08:00 committed by jax authors
parent 46f796b38d
commit dfda948fbf
2 changed files with 20 additions and 0 deletions


@@ -61,6 +61,7 @@ jax_test(
    name = "array_interoperability_test",
    srcs = ["array_interoperability_test.py"],
    disable_backends = ["tpu"],
    tags = ["multiaccelerator"],
    deps = py_deps("tensorflow_core"),
)


@@ -19,6 +19,7 @@ from absl.testing import absltest
import jax
import jax.dlpack
import jax.numpy as jnp
from jax.sharding import PartitionSpec as P
from jax._src import config
from jax._src import test_util as jtu
from jax._src import xla_bridge as xb
@@ -217,6 +218,24 @@ class CudaArrayInterfaceTest(jtu.JaxTestCase):
    ):
      _ = x.__cuda_array_interface__

  @jtu.run_on_devices("cuda")
  @unittest.skipIf(xla_extension_version < 233, "Requires newer jaxlib")
  def testCudaArrayInterfaceOnShardedArrayFails(self):
    devices = jax.local_devices()
    if len(devices) <= 1:
      raise unittest.SkipTest("Test requires 2 or more devices")
    mesh = jax.sharding.Mesh(np.array(devices), ("x",))
    sharding = jax.sharding.NamedSharding(mesh, P("x"))
    x = jnp.arange(16)
    x = jax.device_put(x, sharding)
    self.assertFalse(hasattr(x, "__cuda_array_interface__"))
    with self.assertRaisesRegex(
        AttributeError,
        "__cuda_array_interface__ is only supported for unsharded arrays.",
    ):
      _ = x.__cuda_array_interface__

  @jtu.sample_product(
      shape=all_shapes,
      dtype=cuda_array_interface_dtypes,