Adding support for copy_p primitive to jet.

PiperOrigin-RevId: 694296952
This commit is contained in:
James Martens 2024-11-07 17:05:13 -08:00 committed by jax authors
parent 1f1d27de2f
commit 6e8a35f08c
2 changed files with 3 additions and 0 deletions

View File

@ -329,6 +329,7 @@ deflinear(lax.slice_p)
deflinear(lax.reduce_sum_p)
deflinear(lax.reduce_window_sum_p)
deflinear(lax.fft_p)
deflinear(lax.copy_p)
deflinear(dispatch.device_put_p)
def _dynamic_slice_jet_rule(primals_in, series_in, **params):

View File

@ -319,6 +319,8 @@ class JetTest(jtu.JaxTestCase):
def test_dynamic_slice(self): self.unary_check(partial(lax.dynamic_slice, start_indices=(1,2), slice_sizes=(1,1)))
@jtu.skip_on_devices("tpu")
def test_dynamic_update_slice(self): self.unary_check(partial(lax.dynamic_update_slice, start_indices=(1,2), update=np.arange(6.0).reshape(2, 3)))
@jtu.skip_on_devices("tpu")
def test_copy(self): self.unary_check(jnp.array)
@jtu.skip_on_devices("tpu")