mirror of
https://github.com/ROCm/jax.git
synced 2025-04-16 11:56:07 +00:00
Add dynamic slice U8 index test
This commit is contained in:
parent
92e8ec426c
commit
27fc797a67
@ -2424,6 +2424,17 @@ class LaxTest(jtu.JaxTestCase):
|
||||
self.assertArraysEqual(np.full((1, 30), np.float32(42)),
|
||||
f(np.zeros((1, 24), dtype=np.float32)))
|
||||
|
||||
def testDynamicSliceU8Index(self):
|
||||
# Regression test for u8 index in dynamic-slice (#6122)
|
||||
# TODO(b/183216273): enable this test for CPU & GPU when possible.
|
||||
if jtu.device_under_test() == "cpu":
|
||||
raise unittest.SkipTest("DynamicSliceU8Index test is a known failure on CPU.")
|
||||
if jtu.device_under_test() == "gpu":
|
||||
raise unittest.SkipTest("DynamicSliceU8Index test is a known failure on GPU.")
|
||||
x = np.arange(200)
|
||||
np.testing.assert_equal(
|
||||
np.array(lax.dynamic_slice(x, np.uint8([128]), (1,))), [128])
|
||||
|
||||
|
||||
class LazyConstantTest(jtu.JaxTestCase):
|
||||
def _Check(self, make_const, expected):
|
||||
|
Loading…
x
Reference in New Issue
Block a user