mirror of
https://github.com/ROCm/jax.git
synced 2025-04-16 03:46:06 +00:00
[XLA:Python] Fail with an AttributeError if __cuda_array_interface__ is called on a sharded array.
Fixes https://github.com/google/jax/issues/19134 PiperOrigin-RevId: 600570354
This commit is contained in:
parent
46f796b38d
commit
dfda948fbf
@ -61,6 +61,7 @@ jax_test(
|
|||||||
name = "array_interoperability_test",
|
name = "array_interoperability_test",
|
||||||
srcs = ["array_interoperability_test.py"],
|
srcs = ["array_interoperability_test.py"],
|
||||||
disable_backends = ["tpu"],
|
disable_backends = ["tpu"],
|
||||||
|
tags = ["multiaccelerator"],
|
||||||
deps = py_deps("tensorflow_core"),
|
deps = py_deps("tensorflow_core"),
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -19,6 +19,7 @@ from absl.testing import absltest
|
|||||||
import jax
|
import jax
|
||||||
import jax.dlpack
|
import jax.dlpack
|
||||||
import jax.numpy as jnp
|
import jax.numpy as jnp
|
||||||
|
from jax.sharding import PartitionSpec as P
|
||||||
from jax._src import config
|
from jax._src import config
|
||||||
from jax._src import test_util as jtu
|
from jax._src import test_util as jtu
|
||||||
from jax._src import xla_bridge as xb
|
from jax._src import xla_bridge as xb
|
||||||
@ -217,6 +218,24 @@ class CudaArrayInterfaceTest(jtu.JaxTestCase):
|
|||||||
):
|
):
|
||||||
_ = x.__cuda_array_interface__
|
_ = x.__cuda_array_interface__
|
||||||
|
|
||||||
|
@jtu.run_on_devices("cuda")
|
||||||
|
@unittest.skipIf(xla_extension_version < 233, "Requires newer jaxlib")
|
||||||
|
def testCudaArrayInterfaceOnShardedArrayFails(self):
|
||||||
|
devices = jax.local_devices()
|
||||||
|
if len(devices) <= 1:
|
||||||
|
raise unittest.SkipTest("Test requires 2 or more devices")
|
||||||
|
mesh = jax.sharding.Mesh(np.array(devices), ("x",))
|
||||||
|
sharding = jax.sharding.NamedSharding(mesh, P("x"))
|
||||||
|
x = jnp.arange(16)
|
||||||
|
x = jax.device_put(x, sharding)
|
||||||
|
self.assertFalse(hasattr(x, "__cuda_array_interface__"))
|
||||||
|
with self.assertRaisesRegex(
|
||||||
|
AttributeError,
|
||||||
|
"__cuda_array_interface__ is only supported for unsharded arrays.",
|
||||||
|
):
|
||||||
|
_ = x.__cuda_array_interface__
|
||||||
|
|
||||||
|
|
||||||
@jtu.sample_product(
|
@jtu.sample_product(
|
||||||
shape=all_shapes,
|
shape=all_shapes,
|
||||||
dtype=cuda_array_interface_dtypes,
|
dtype=cuda_array_interface_dtypes,
|
||||||
|
Loading…
x
Reference in New Issue
Block a user