mirror of
https://github.com/ROCm/jax.git
synced 2025-04-14 10:56:06 +00:00
Adding support for copy_p primitive to jet.
PiperOrigin-RevId: 694296952
This commit is contained in:
parent
1f1d27de2f
commit
6e8a35f08c
@ -329,6 +329,7 @@ deflinear(lax.slice_p)
|
||||
deflinear(lax.reduce_sum_p)
|
||||
deflinear(lax.reduce_window_sum_p)
|
||||
deflinear(lax.fft_p)
|
||||
deflinear(lax.copy_p)
|
||||
deflinear(dispatch.device_put_p)
|
||||
|
||||
def _dynamic_slice_jet_rule(primals_in, series_in, **params):
|
||||
|
@ -319,6 +319,8 @@ class JetTest(jtu.JaxTestCase):
|
||||
def test_dynamic_slice(self): self.unary_check(partial(lax.dynamic_slice, start_indices=(1,2), slice_sizes=(1,1)))
|
||||
@jtu.skip_on_devices("tpu")
|
||||
def test_dynamic_update_slice(self): self.unary_check(partial(lax.dynamic_update_slice, start_indices=(1,2), update=np.arange(6.0).reshape(2, 3)))
|
||||
@jtu.skip_on_devices("tpu")
|
||||
def test_copy(self): self.unary_check(jnp.array)
|
||||
|
||||
|
||||
@jtu.skip_on_devices("tpu")
|
||||
|
Loading…
x
Reference in New Issue
Block a user