fix benchmark sums (#4329)

This commit is contained in:
Jake Vanderplas 2020-09-18 09:24:00 -07:00 committed by GitHub
parent 2911bcd634
commit 6a89f60683
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -39,7 +39,7 @@ def pmap_shard_sharded_device_array_benchmark():
"""
def get_benchmark_fn(nargs, nshards):
pmap_fn = pmap(lambda *args: jnp.sum(args))
pmap_fn = pmap(lambda *args: jnp.sum(jnp.array(args)))
shape = (nshards, 4)
args = [np.random.random(shape) for _ in range(nargs)]
sharded_args = pmap(lambda x: x)(args)
@ -69,7 +69,7 @@ def pmap_shard_device_array_benchmark():
"""
def get_benchmark_fn(nargs, nshards):
pmap_fn = pmap(lambda *args: jnp.sum(args))
pmap_fn = pmap(lambda *args: jnp.sum(jnp.array(args)))
shape = (nshards, 4)
args = [jnp.array(np.random.random(shape)) for _ in range(nargs)]
assert all(isinstance(arg, jax.xla.DeviceArray) for arg in args)