mirror of
https://github.com/ROCm/jax.git
synced 2025-04-19 05:16:06 +00:00
lax.dynamic_slice: avoid negative index correction for unsigned indices
This commit is contained in:
parent
def6190dc2
commit
0dbda849ef
@ -2559,8 +2559,13 @@ def _dynamic_slice_indices(
|
||||
.format(start_indices.shape)) # type: ignore[union-attr]
|
||||
start_indices = list(start_indices)
|
||||
result: list[ArrayLike] = []
|
||||
# Loop to correct for negative indices.
|
||||
for i, d in zip(start_indices, operand.shape):
|
||||
# We test whether i and d are static to avoid unnecessary staging.
|
||||
# If i is unsigned, then it cannot be negative.
|
||||
if dtypes.issubdtype(_dtype(i), np.unsignedinteger):
|
||||
result.append(i)
|
||||
continue
|
||||
# Test whether i and d are static to avoid unnecessary staging.
|
||||
if isinstance(i, (int, np.integer)) and core.is_constant_dim(d):
|
||||
result.append(lax.convert_element_type(i + d if i < 0 else i, _dtype(i)))
|
||||
continue
|
||||
|
@ -2653,6 +2653,15 @@ class LaxTest(jtu.JaxTestCase):
|
||||
self.assertArraysEqual(np.full((1, 30), np.float32(42)),
|
||||
f(np.zeros((1, 24), dtype=np.float32)))
|
||||
|
||||
def testDynamicSliceUnsignedNoNormalization(self):
|
||||
# Test that no negative index correction is done for unsigned indices.
|
||||
f = lambda x, i: lax.dynamic_slice(x, [i], [1])
|
||||
x = np.arange(200)
|
||||
i = np.uint32(128)
|
||||
jaxpr = jax.make_jaxpr(f)(x, i)
|
||||
self.assertLen(jaxpr.eqns, 1)
|
||||
self.assertEqual(jaxpr.eqns[0].primitive, lax.dynamic_slice_p)
|
||||
|
||||
def testDynamicSliceU8Index(self):
|
||||
# Regression test for u8 index in dynamic-slice (#6122)
|
||||
# TODO(b/183216273): enable this test for CPU & GPU when possible.
|
||||
|
Loading…
x
Reference in New Issue
Block a user