[JAX] Replace uses of jax.interpreters.pxla.ShardedDeviceArray with jax.Array.

PiperOrigin-RevId: 508463147
This commit is contained in:
Peter Hawkins 2023-02-09 13:38:04 -08:00 committed by jax authors
parent 0c14e9ab49
commit 88cc254f2c

View File

@ -43,8 +43,7 @@ def pmap_shard_sharded_device_array_benchmark():
shape = (nshards, 4)
args = [np.random.random(shape) for _ in range(nargs)]
sharded_args = pmap(lambda x: x)(args)
assert all(isinstance(arg, jax.pxla.ShardedDeviceArray)
for arg in sharded_args)
assert all(isinstance(arg, jax.Array) for arg in sharded_args)
def benchmark_fn():
for _ in range(100):
pmap_fn(*sharded_args)