Replace jax.xla.DeviceArray private type with the new public type jax.Array.

PiperOrigin-RevId: 477582562
This commit is contained in:
Yash Katariya 2022-09-28 16:33:43 -07:00 committed by jax authors
parent 33dbf0ea1c
commit 84768d2d49

View File

@ -72,7 +72,7 @@ def pmap_shard_device_array_benchmark():
pmap_fn = pmap(lambda *args: jnp.sum(jnp.array(args)))
shape = (nshards, 4)
args = [jnp.array(np.random.random(shape)) for _ in range(nargs)]
assert all(isinstance(arg, jax.xla.DeviceArray) for arg in args)
assert all(isinstance(arg, jax.Array) for arg in args)
def benchmark_fn():
for _ in range(10):
pmap_fn(*args)