mirror of
https://github.com/ROCm/jax.git
synced 2025-04-15 19:36:06 +00:00
fix benchmark sums (#4329)
This commit is contained in:
parent
2911bcd634
commit
6a89f60683
@ -39,7 +39,7 @@ def pmap_shard_sharded_device_array_benchmark():
|
||||
"""
|
||||
|
||||
def get_benchmark_fn(nargs, nshards):
|
||||
pmap_fn = pmap(lambda *args: jnp.sum(args))
|
||||
pmap_fn = pmap(lambda *args: jnp.sum(jnp.array(args)))
|
||||
shape = (nshards, 4)
|
||||
args = [np.random.random(shape) for _ in range(nargs)]
|
||||
sharded_args = pmap(lambda x: x)(args)
|
||||
@ -69,7 +69,7 @@ def pmap_shard_device_array_benchmark():
|
||||
"""
|
||||
|
||||
def get_benchmark_fn(nargs, nshards):
|
||||
pmap_fn = pmap(lambda *args: jnp.sum(args))
|
||||
pmap_fn = pmap(lambda *args: jnp.sum(jnp.array(args)))
|
||||
shape = (nshards, 4)
|
||||
args = [jnp.array(np.random.random(shape)) for _ in range(nargs)]
|
||||
assert all(isinstance(arg, jax.xla.DeviceArray) for arg in args)
|
||||
|
Loading…
x
Reference in New Issue
Block a user