Add test and jet-primitive for dynamic_slice

This commit is contained in:
Nicholas Krämer 2022-02-08 13:28:41 +01:00
parent 44c6c055d3
commit c07a0f1139
2 changed files with 3 additions and 0 deletions

View File

@ -238,6 +238,7 @@ deflinear(lax.reshape_p)
deflinear(lax.rev_p)
deflinear(lax.transpose_p)
deflinear(lax.slice_p)
deflinear(lax.dynamic_slice_p)
deflinear(lax.reduce_sum_p)
deflinear(lax.reduce_window_sum_p)
deflinear(lax.fft_p)

View File

@ -303,6 +303,8 @@ class JetTest(jtu.JaxTestCase):
def test_cummax(self): self.unary_check(partial(lax.cummax, axis=0))
@jtu.skip_on_devices("tpu")
def test_cummin(self): self.unary_check(partial(lax.cummin, axis=0))
@jtu.skip_on_devices("tpu")
def test_dynamic_slice(self): self.unary_check(partial(lax.dynamic_slice, start_indices=(0,0), slice_sizes=(1,1)))
@jtu.skip_on_devices("tpu")