mirror of
https://github.com/ROCm/jax.git
synced 2025-04-16 03:46:06 +00:00
Added a pattern-match optimisation for inplace-select.
PiperOrigin-RevId: 497425937
This commit is contained in:
parent
0c51ca274b
commit
eb875cd5dd
@ -836,5 +836,41 @@ def device_put(state):
|
||||
while state:
|
||||
_ = jax.device_put(x).block_until_ready()
|
||||
|
||||
|
||||
def batch_inplace_while(inplace_op, state):
|
||||
|
||||
@jax.jit
|
||||
@jax.vmap
|
||||
def f(init_step, init_xs):
|
||||
|
||||
def cond(carry):
|
||||
step, xs = carry
|
||||
return step < xs.size
|
||||
|
||||
def body(carry):
|
||||
step, xs = carry
|
||||
if inplace_op == 'scatter':
|
||||
xs = xs.at[step].set(1)
|
||||
elif inplace_op == 'dynamic_update_slice':
|
||||
xs = lax.dynamic_update_index_in_dim(xs, 1., step, 0)
|
||||
else:
|
||||
assert False
|
||||
return step + 1, xs
|
||||
|
||||
return lax.while_loop(cond, body, (init_step, init_xs))
|
||||
|
||||
size = 100_000
|
||||
args = jnp.array([0]), jnp.zeros((1, size))
|
||||
f(*args) # compile
|
||||
while state:
|
||||
f(*args)
|
||||
|
||||
|
||||
google_benchmark.register(
|
||||
partial(batch_inplace_while, 'scatter'), name='batch_inplace_while_scatter')
|
||||
google_benchmark.register(
|
||||
partial(batch_inplace_while, 'dynamic_update_slice'),
|
||||
name='batch_inplace_while_dynamic_update_slice')
|
||||
|
||||
if __name__ == "__main__":
|
||||
google_benchmark.main()
|
||||
|
Loading…
x
Reference in New Issue
Block a user