mirror of
https://github.com/ROCm/jax.git
synced 2025-04-16 03:46:06 +00:00
Add a warmup loop to pmap_simple_8_devices_100_args benchmark so as to not measure the compile time.
PiperOrigin-RevId: 401402336
This commit is contained in:
parent
3c117fd6ed
commit
bfbdfa87e7
@ -288,6 +288,9 @@ def pmap_simple_8_devices_100_args(state):
|
||||
for i in range(100):
|
||||
args.append(jnp.array(list(range(i, i+8))))
|
||||
|
||||
# Warmup loop.
|
||||
out = f(*args)
|
||||
|
||||
while state:
|
||||
out = f(*args)
|
||||
jax.tree_map(lambda x: x.block_until_ready(), out)
|
||||
|
Loading…
x
Reference in New Issue
Block a user