diff --git a/jax/experimental/jet.py b/jax/experimental/jet.py index 886413ddc..236fccca5 100644 --- a/jax/experimental/jet.py +++ b/jax/experimental/jet.py @@ -335,6 +335,15 @@ deflinear(lax.reduce_window_sum_p) deflinear(lax.fft_p) deflinear(dispatch.device_put_p) +def _dynamic_update_slice_jet_rule(primals_in, series_in, **params): + operand, update, *start_indices = primals_in + primal_out = lax.dynamic_update_slice_p.bind(operand, update, *start_indices) + series_out = [lax.dynamic_update_slice_p.bind(*terms_in[:2], *start_indices, **params) + for terms_in in zip(*series_in)] + return primal_out, series_out + +jet_rules[lax.dynamic_update_slice_p] = _dynamic_update_slice_jet_rule + def _cumulative_jet_rule(primals_in, series_in, *, axis: int, reverse: bool, combine_fn: Callable): # Irrespective of backend, we always use the parallel prefix scan diff --git a/tests/jet_test.py b/tests/jet_test.py index 7b277dec4..b3c2df339 100644 --- a/tests/jet_test.py +++ b/tests/jet_test.py @@ -312,6 +312,8 @@ class JetTest(jtu.JaxTestCase): def test_cummin(self): self.unary_check(partial(lax.cummin, axis=0)) @jtu.skip_on_devices("tpu") def test_dynamic_slice(self): self.unary_check(partial(lax.dynamic_slice, start_indices=(0,0), slice_sizes=(1,1))) + @jtu.skip_on_devices("tpu") + def test_dynamic_update_slice(self): self.unary_check(partial(lax.dynamic_update_slice, start_indices=(1,2), update=jnp.arange(6.0).reshape(2, 3))) @jtu.skip_on_devices("tpu")