Add dynamic slice U8 index test

This commit is contained in:
Jake VanderPlas 2021-06-23 13:29:15 -07:00
parent 92e8ec426c
commit 27fc797a67

View File

@ -2424,6 +2424,17 @@ class LaxTest(jtu.JaxTestCase):
self.assertArraysEqual(np.full((1, 30), np.float32(42)),
f(np.zeros((1, 24), dtype=np.float32)))
def testDynamicSliceU8Index(self):
# Regression test for u8 index in dynamic-slice (#6122)
# TODO(b/183216273): enable this test for CPU & GPU when possible.
if jtu.device_under_test() == "cpu":
raise unittest.SkipTest("DynamicSliceU8Index test is a known failure on CPU.")
if jtu.device_under_test() == "gpu":
raise unittest.SkipTest("DynamicSliceU8Index test is a known failure on GPU.")
x = np.arange(200)
np.testing.assert_equal(
np.array(lax.dynamic_slice(x, np.uint8([128]), (1,))), [128])
class LazyConstantTest(jtu.JaxTestCase):
def _Check(self, make_const, expected):