Merge pull request #13083 from jakevdp:jet-dynamic-update

PiperOrigin-RevId: 485740508
This commit is contained in:
jax authors 2022-11-02 17:44:09 -07:00
commit 5448ea6c10
2 changed files with 11 additions and 0 deletions

View File

@ -335,6 +335,15 @@ deflinear(lax.reduce_window_sum_p)
deflinear(lax.fft_p)
deflinear(dispatch.device_put_p)
def _dynamic_update_slice_jet_rule(primals_in, series_in, **params):
operand, update, *start_indices = primals_in
primal_out = lax.dynamic_update_slice_p.bind(operand, update, *start_indices)
series_out = [lax.dynamic_update_slice_p.bind(*terms_in[:2], *start_indices, **params)
for terms_in in zip(*series_in)]
return primal_out, series_out
jet_rules[lax.dynamic_update_slice_p] = _dynamic_update_slice_jet_rule
def _cumulative_jet_rule(primals_in, series_in, *, axis: int, reverse: bool,
combine_fn: Callable):
# Irrespective of backend, we always use the parallel prefix scan

View File

@ -312,6 +312,8 @@ class JetTest(jtu.JaxTestCase):
def test_cummin(self): self.unary_check(partial(lax.cummin, axis=0))
@jtu.skip_on_devices("tpu")
def test_dynamic_slice(self): self.unary_check(partial(lax.dynamic_slice, start_indices=(0,0), slice_sizes=(1,1)))
@jtu.skip_on_devices("tpu")
def test_dynamic_update_slice(self): self.unary_check(partial(lax.dynamic_update_slice, start_indices=(1,2), update=jnp.arange(6.0).reshape(2, 3)))
@jtu.skip_on_devices("tpu")