From 88cc254f2c608d42e835c9c3b5a0c24f7e47bd9c Mon Sep 17 00:00:00 2001 From: Peter Hawkins Date: Thu, 9 Feb 2023 13:38:04 -0800 Subject: [PATCH] [JAX] Replace uses of jax.interpreters.pxla.ShardedDeviceArray with jax.Array. PiperOrigin-RevId: 508463147 --- benchmarks/pmap_benchmark.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/benchmarks/pmap_benchmark.py b/benchmarks/pmap_benchmark.py index c1d16e7d0..e11319bd0 100644 --- a/benchmarks/pmap_benchmark.py +++ b/benchmarks/pmap_benchmark.py @@ -43,8 +43,7 @@ def pmap_shard_sharded_device_array_benchmark(): shape = (nshards, 4) args = [np.random.random(shape) for _ in range(nargs)] sharded_args = pmap(lambda x: x)(args) - assert all(isinstance(arg, jax.pxla.ShardedDeviceArray) - for arg in sharded_args) + assert all(isinstance(arg, jax.Array) for arg in sharded_args) def benchmark_fn(): for _ in range(100): pmap_fn(*sharded_args)