mirror of
https://github.com/ROCm/jax.git
synced 2025-04-15 19:36:06 +00:00
[XLA:Python] Fail with an AttributeError if __cuda_array_interface__ is called on a sharded array.
Fixes https://github.com/google/jax/issues/19134 PiperOrigin-RevId: 600570354
This commit is contained in:
parent
46f796b38d
commit
dfda948fbf
@ -61,6 +61,7 @@ jax_test(
|
||||
name = "array_interoperability_test",
|
||||
srcs = ["array_interoperability_test.py"],
|
||||
disable_backends = ["tpu"],
|
||||
tags = ["multiaccelerator"],
|
||||
deps = py_deps("tensorflow_core"),
|
||||
)
|
||||
|
||||
|
@ -19,6 +19,7 @@ from absl.testing import absltest
|
||||
import jax
|
||||
import jax.dlpack
|
||||
import jax.numpy as jnp
|
||||
from jax.sharding import PartitionSpec as P
|
||||
from jax._src import config
|
||||
from jax._src import test_util as jtu
|
||||
from jax._src import xla_bridge as xb
|
||||
@ -217,6 +218,24 @@ class CudaArrayInterfaceTest(jtu.JaxTestCase):
|
||||
):
|
||||
_ = x.__cuda_array_interface__
|
||||
|
||||
@jtu.run_on_devices("cuda")
|
||||
@unittest.skipIf(xla_extension_version < 233, "Requires newer jaxlib")
|
||||
def testCudaArrayInterfaceOnShardedArrayFails(self):
|
||||
devices = jax.local_devices()
|
||||
if len(devices) <= 1:
|
||||
raise unittest.SkipTest("Test requires 2 or more devices")
|
||||
mesh = jax.sharding.Mesh(np.array(devices), ("x",))
|
||||
sharding = jax.sharding.NamedSharding(mesh, P("x"))
|
||||
x = jnp.arange(16)
|
||||
x = jax.device_put(x, sharding)
|
||||
self.assertFalse(hasattr(x, "__cuda_array_interface__"))
|
||||
with self.assertRaisesRegex(
|
||||
AttributeError,
|
||||
"__cuda_array_interface__ is only supported for unsharded arrays.",
|
||||
):
|
||||
_ = x.__cuda_array_interface__
|
||||
|
||||
|
||||
@jtu.sample_product(
|
||||
shape=all_shapes,
|
||||
dtype=cuda_array_interface_dtypes,
|
||||
|
Loading…
x
Reference in New Issue
Block a user