mirror of
https://github.com/ROCm/jax.git
synced 2025-04-16 03:46:06 +00:00
Replace jax.xla.DeviceArray
private type with the new public type jax.Array
.
PiperOrigin-RevId: 477582562
This commit is contained in:
parent
33dbf0ea1c
commit
84768d2d49
@ -72,7 +72,7 @@ def pmap_shard_device_array_benchmark():
|
||||
pmap_fn = pmap(lambda *args: jnp.sum(jnp.array(args)))
|
||||
shape = (nshards, 4)
|
||||
args = [jnp.array(np.random.random(shape)) for _ in range(nargs)]
|
||||
assert all(isinstance(arg, jax.xla.DeviceArray) for arg in args)
|
||||
assert all(isinstance(arg, jax.Array) for arg in args)
|
||||
def benchmark_fn():
|
||||
for _ in range(10):
|
||||
pmap_fn(*args)
|
||||
|
Loading…
x
Reference in New Issue
Block a user