mirror of
https://github.com/ROCm/jax.git
synced 2025-04-16 11:56:07 +00:00
add jet rule for dynamic_update_slice
This commit is contained in:
parent
2dc804371c
commit
db7eea1f60
@ -335,6 +335,15 @@ deflinear(lax.reduce_window_sum_p)
|
||||
deflinear(lax.fft_p)
|
||||
deflinear(dispatch.device_put_p)
|
||||
|
||||
def _dynamic_update_slice_jet_rule(primals_in, series_in, **params):
|
||||
operand, update, *start_indices = primals_in
|
||||
primal_out = lax.dynamic_update_slice_p.bind(operand, update, *start_indices)
|
||||
series_out = [lax.dynamic_update_slice_p.bind(*terms_in[:2], *start_indices, **params)
|
||||
for terms_in in zip(*series_in)]
|
||||
return primal_out, series_out
|
||||
|
||||
jet_rules[lax.dynamic_update_slice_p] = _dynamic_update_slice_jet_rule
|
||||
|
||||
def _cumulative_jet_rule(primals_in, series_in, *, axis: int, reverse: bool,
|
||||
combine_fn: Callable):
|
||||
# Irrespective of backend, we always use the parallel prefix scan
|
||||
|
@ -312,6 +312,8 @@ class JetTest(jtu.JaxTestCase):
|
||||
def test_cummin(self): self.unary_check(partial(lax.cummin, axis=0))
|
||||
@jtu.skip_on_devices("tpu")
|
||||
def test_dynamic_slice(self): self.unary_check(partial(lax.dynamic_slice, start_indices=(0,0), slice_sizes=(1,1)))
|
||||
@jtu.skip_on_devices("tpu")
|
||||
def test_dynamic_update_slice(self): self.unary_check(partial(lax.dynamic_update_slice, start_indices=(1,2), update=jnp.arange(6.0).reshape(2, 3)))
|
||||
|
||||
|
||||
@jtu.skip_on_devices("tpu")
|
||||
|
Loading…
x
Reference in New Issue
Block a user