mirror of
https://github.com/ROCm/jax.git
synced 2025-04-15 19:36:06 +00:00
[JAX] Replace uses of jax.interpreters.pxla.ShardedDeviceArray with jax.Array.
PiperOrigin-RevId: 508463147
This commit is contained in:
parent
0c14e9ab49
commit
88cc254f2c
@ -43,8 +43,7 @@ def pmap_shard_sharded_device_array_benchmark():
|
||||
shape = (nshards, 4)
|
||||
args = [np.random.random(shape) for _ in range(nargs)]
|
||||
sharded_args = pmap(lambda x: x)(args)
|
||||
assert all(isinstance(arg, jax.pxla.ShardedDeviceArray)
|
||||
for arg in sharded_args)
|
||||
assert all(isinstance(arg, jax.Array) for arg in sharded_args)
|
||||
def benchmark_fn():
|
||||
for _ in range(100):
|
||||
pmap_fn(*sharded_args)
|
||||
|
Loading…
x
Reference in New Issue
Block a user