lax.dynamic_slice: avoid negative index correction for unsigned indices

This commit is contained in:
Jake VanderPlas 2023-07-25 13:09:09 -07:00
parent def6190dc2
commit 0dbda849ef
2 changed files with 15 additions and 1 deletions

View File

@ -2559,8 +2559,13 @@ def _dynamic_slice_indices(
.format(start_indices.shape)) # type: ignore[union-attr]
start_indices = list(start_indices)
result: list[ArrayLike] = []
# Loop to correct for negative indices.
for i, d in zip(start_indices, operand.shape):
# We test whether i and d are static to avoid unnecessary staging.
# If i is unsigned, then it cannot be negative.
if dtypes.issubdtype(_dtype(i), np.unsignedinteger):
result.append(i)
continue
# Test whether i and d are static to avoid unnecessary staging.
if isinstance(i, (int, np.integer)) and core.is_constant_dim(d):
result.append(lax.convert_element_type(i + d if i < 0 else i, _dtype(i)))
continue

View File

@ -2653,6 +2653,15 @@ class LaxTest(jtu.JaxTestCase):
self.assertArraysEqual(np.full((1, 30), np.float32(42)),
f(np.zeros((1, 24), dtype=np.float32)))
def testDynamicSliceUnsignedNoNormalization(self):
# Test that no negative index correction is done for unsigned indices.
f = lambda x, i: lax.dynamic_slice(x, [i], [1])
x = np.arange(200)
i = np.uint32(128)
jaxpr = jax.make_jaxpr(f)(x, i)
self.assertLen(jaxpr.eqns, 1)
self.assertEqual(jaxpr.eqns[0].primitive, lax.dynamic_slice_p)
def testDynamicSliceU8Index(self):
# Regression test for u8 index in dynamic-slice (#6122)
# TODO(b/183216273): enable this test for CPU & GPU when possible.