diff --git a/benchmarks/api_benchmark.py b/benchmarks/api_benchmark.py index 768cb4b1f..c668b4cee 100644 --- a/benchmarks/api_benchmark.py +++ b/benchmarks/api_benchmark.py @@ -836,5 +836,41 @@ def device_put(state): while state: _ = jax.device_put(x).block_until_ready() + +def batch_inplace_while(inplace_op, state): + + @jax.jit + @jax.vmap + def f(init_step, init_xs): + + def cond(carry): + step, xs = carry + return step < xs.size + + def body(carry): + step, xs = carry + if inplace_op == 'scatter': + xs = xs.at[step].set(1) + elif inplace_op == 'dynamic_update_slice': + xs = lax.dynamic_update_index_in_dim(xs, 1., step, 0) + else: + assert False + return step + 1, xs + + return lax.while_loop(cond, body, (init_step, init_xs)) + + size = 100_000 + args = jnp.array([0]), jnp.zeros((1, size)) + f(*args) # compile + while state: + f(*args) + + +google_benchmark.register( + partial(batch_inplace_while, 'scatter'), name='batch_inplace_while_scatter') +google_benchmark.register( + partial(batch_inplace_while, 'dynamic_update_slice'), + name='batch_inplace_while_dynamic_update_slice') + if __name__ == "__main__": google_benchmark.main()