diff --git a/jax/experimental/jet.py b/jax/experimental/jet.py index 8dd2a319a..827e4d01b 100644 --- a/jax/experimental/jet.py +++ b/jax/experimental/jet.py @@ -329,6 +329,7 @@ deflinear(lax.slice_p) deflinear(lax.reduce_sum_p) deflinear(lax.reduce_window_sum_p) deflinear(lax.fft_p) +deflinear(lax.copy_p) deflinear(dispatch.device_put_p) def _dynamic_slice_jet_rule(primals_in, series_in, **params): diff --git a/tests/jet_test.py b/tests/jet_test.py index 4e437c044..7c2c71e9b 100644 --- a/tests/jet_test.py +++ b/tests/jet_test.py @@ -319,6 +319,8 @@ class JetTest(jtu.JaxTestCase): def test_dynamic_slice(self): self.unary_check(partial(lax.dynamic_slice, start_indices=(1,2), slice_sizes=(1,1))) @jtu.skip_on_devices("tpu") def test_dynamic_update_slice(self): self.unary_check(partial(lax.dynamic_update_slice, start_indices=(1,2), update=np.arange(6.0).reshape(2, 3))) + @jtu.skip_on_devices("tpu") + def test_copy(self): self.unary_check(jnp.array) @jtu.skip_on_devices("tpu")