From 6e8a35f08ca9f76d19534aaf9eb6550f143f0e63 Mon Sep 17 00:00:00 2001 From: James Martens Date: Thu, 7 Nov 2024 17:05:13 -0800 Subject: [PATCH] Adding support for copy_p primitive to jet. PiperOrigin-RevId: 694296952 --- jax/experimental/jet.py | 1 + tests/jet_test.py | 2 ++ 2 files changed, 3 insertions(+) diff --git a/jax/experimental/jet.py b/jax/experimental/jet.py index 8dd2a319a..827e4d01b 100644 --- a/jax/experimental/jet.py +++ b/jax/experimental/jet.py @@ -329,6 +329,7 @@ deflinear(lax.slice_p) deflinear(lax.reduce_sum_p) deflinear(lax.reduce_window_sum_p) deflinear(lax.fft_p) +deflinear(lax.copy_p) deflinear(dispatch.device_put_p) def _dynamic_slice_jet_rule(primals_in, series_in, **params): diff --git a/tests/jet_test.py b/tests/jet_test.py index 4e437c044..7c2c71e9b 100644 --- a/tests/jet_test.py +++ b/tests/jet_test.py @@ -319,6 +319,8 @@ class JetTest(jtu.JaxTestCase): def test_dynamic_slice(self): self.unary_check(partial(lax.dynamic_slice, start_indices=(1,2), slice_sizes=(1,1))) @jtu.skip_on_devices("tpu") def test_dynamic_update_slice(self): self.unary_check(partial(lax.dynamic_update_slice, start_indices=(1,2), update=np.arange(6.0).reshape(2, 3))) + @jtu.skip_on_devices("tpu") + def test_copy(self): self.unary_check(jnp.array) @jtu.skip_on_devices("tpu")