From eb875cd5dd3a47cb8c99a324f37507a0ddfc622e Mon Sep 17 00:00:00 2001 From: jax authors Date: Fri, 23 Dec 2022 16:05:18 -0800 Subject: [PATCH] Added a pattern-match optimisation for inplace-select. PiperOrigin-RevId: 497425937 --- benchmarks/api_benchmark.py | 36 ++++++++++++++++++++++++++++++++++++ 1 file changed, 36 insertions(+) diff --git a/benchmarks/api_benchmark.py b/benchmarks/api_benchmark.py index 768cb4b1f..c668b4cee 100644 --- a/benchmarks/api_benchmark.py +++ b/benchmarks/api_benchmark.py @@ -836,5 +836,41 @@ def device_put(state): while state: _ = jax.device_put(x).block_until_ready() + +def batch_inplace_while(inplace_op, state): + + @jax.jit + @jax.vmap + def f(init_step, init_xs): + + def cond(carry): + step, xs = carry + return step < xs.size + + def body(carry): + step, xs = carry + if inplace_op == 'scatter': + xs = xs.at[step].set(1) + elif inplace_op == 'dynamic_update_slice': + xs = lax.dynamic_update_index_in_dim(xs, 1., step, 0) + else: + assert False + return step + 1, xs + + return lax.while_loop(cond, body, (init_step, init_xs)) + + size = 100_000 + args = jnp.array([0]), jnp.zeros((1, size)) + f(*args) # compile + while state: + f(*args) + + +google_benchmark.register( + partial(batch_inplace_while, 'scatter'), name='batch_inplace_while_scatter') +google_benchmark.register( + partial(batch_inplace_while, 'dynamic_update_slice'), + name='batch_inplace_while_dynamic_update_slice') + if __name__ == "__main__": google_benchmark.main()