Added a pattern-match optimisation for inplace-select.

PiperOrigin-RevId: 497425937
This commit is contained in:
jax authors 2022-12-23 16:05:18 -08:00
parent 0c51ca274b
commit eb875cd5dd

View File

@ -836,5 +836,41 @@ def device_put(state):
while state:
_ = jax.device_put(x).block_until_ready()
def batch_inplace_while(inplace_op, state):
@jax.jit
@jax.vmap
def f(init_step, init_xs):
def cond(carry):
step, xs = carry
return step < xs.size
def body(carry):
step, xs = carry
if inplace_op == 'scatter':
xs = xs.at[step].set(1)
elif inplace_op == 'dynamic_update_slice':
xs = lax.dynamic_update_index_in_dim(xs, 1., step, 0)
else:
assert False
return step + 1, xs
return lax.while_loop(cond, body, (init_step, init_xs))
size = 100_000
args = jnp.array([0]), jnp.zeros((1, size))
f(*args) # compile
while state:
f(*args)
google_benchmark.register(
partial(batch_inplace_while, 'scatter'), name='batch_inplace_while_scatter')
google_benchmark.register(
partial(batch_inplace_while, 'dynamic_update_slice'),
name='batch_inplace_while_dynamic_update_slice')
if __name__ == "__main__":
google_benchmark.main()