mirror of
https://github.com/ROCm/jax.git
synced 2025-04-18 04:46:06 +00:00
Add test and jet-primitive for dynamic_slice
This commit is contained in:
parent
44c6c055d3
commit
c07a0f1139
@ -238,6 +238,7 @@ deflinear(lax.reshape_p)
|
||||
deflinear(lax.rev_p)
|
||||
deflinear(lax.transpose_p)
|
||||
deflinear(lax.slice_p)
|
||||
deflinear(lax.dynamic_slice_p)
|
||||
deflinear(lax.reduce_sum_p)
|
||||
deflinear(lax.reduce_window_sum_p)
|
||||
deflinear(lax.fft_p)
|
||||
|
@ -303,6 +303,8 @@ class JetTest(jtu.JaxTestCase):
|
||||
def test_cummax(self): self.unary_check(partial(lax.cummax, axis=0))
|
||||
@jtu.skip_on_devices("tpu")
|
||||
def test_cummin(self): self.unary_check(partial(lax.cummin, axis=0))
|
||||
@jtu.skip_on_devices("tpu")
|
||||
def test_dynamic_slice(self): self.unary_check(partial(lax.dynamic_slice, start_indices=(0,0), slice_sizes=(1,1)))
|
||||
|
||||
|
||||
@jtu.skip_on_devices("tpu")
|
||||
|
Loading…
x
Reference in New Issue
Block a user