Add a warmup loop to pmap_simple_8_devices_100_args benchmark so as to not measure the compile time.

PiperOrigin-RevId: 401402336
This commit is contained in:
Yash Katariya 2021-10-06 19:51:08 -07:00 committed by jax authors
parent 3c117fd6ed
commit bfbdfa87e7

View File

@ -288,6 +288,9 @@ def pmap_simple_8_devices_100_args(state):
for i in range(100):
args.append(jnp.array(list(range(i, i+8))))
# Warmup loop.
out = f(*args)
while state:
out = f(*args)
jax.tree_map(lambda x: x.block_until_ready(), out)