diff --git a/benchmarks/pmap_benchmark.py b/benchmarks/pmap_benchmark.py index 5a1dfc583..608e8ca25 100644 --- a/benchmarks/pmap_benchmark.py +++ b/benchmarks/pmap_benchmark.py @@ -39,7 +39,7 @@ def pmap_shard_sharded_device_array_benchmark(): """ def get_benchmark_fn(nargs, nshards): - pmap_fn = pmap(lambda *args: jnp.sum(args)) + pmap_fn = pmap(lambda *args: jnp.sum(jnp.array(args))) shape = (nshards, 4) args = [np.random.random(shape) for _ in range(nargs)] sharded_args = pmap(lambda x: x)(args) @@ -69,7 +69,7 @@ def pmap_shard_device_array_benchmark(): """ def get_benchmark_fn(nargs, nshards): - pmap_fn = pmap(lambda *args: jnp.sum(args)) + pmap_fn = pmap(lambda *args: jnp.sum(jnp.array(args))) shape = (nshards, 4) args = [jnp.array(np.random.random(shape)) for _ in range(nargs)] assert all(isinstance(arg, jax.xla.DeviceArray) for arg in args)