Add a benchmark with many arguments.

PiperOrigin-RevId: 393216026
This commit is contained in:
Jean-Baptiste Lespiau 2021-08-26 15:05:44 -07:00 committed by jax authors
parent 26c9671413
commit 6cb8737c1a

View File

@ -266,6 +266,33 @@ def pmap_simple_8_devices(state):
d.block_until_ready()
@google_benchmark.register
@required_devices(8)
def pmap_simple_dispatch_8_devices_100_args(state):
f = jax.pmap(lambda *args: args[1:] + (args[0] + 1,))
args = []
for i in range(100):
args.append(jnp.array(list(range(i, i+8))))
args = f(*args)
while state:
args = f(*args)
@google_benchmark.register
@required_devices(8)
def pmap_simple_8_devices_100_args(state):
f = jax.pmap(lambda *args: args[1:] + (args[0] + 1,))
args = []
for i in range(100):
args.append(jnp.array(list(range(i, i+8))))
while state:
out = f(*args)
jax.tree_map(lambda x: x.block_until_ready(), out)
def _run_sda_index_bench(state, num_devices):
x = jax.pmap(jnp.sin)(jnp.arange(num_devices))
jax.device_get(x)