Merge pull request #10150 from jakevdp:normalize-unsigned

PiperOrigin-RevId: 443416771
This commit is contained in:
jax authors 2022-04-21 10:34:35 -07:00
commit bef5e02816
2 changed files with 19 additions and 0 deletions

View File

@ -3417,6 +3417,8 @@ def _take(a, indices, axis: Optional[int] = None, out=None, mode=None):
def _normalize_index(index, axis_size):
"""Normalizes an index value in the range [-N, N) to the range [0, N)."""
if issubdtype(_dtype(index), np.unsignedinteger):
return index
if core.is_constant_dim(axis_size):
axis_size_val = _lax_const(index, axis_size)
else:

View File

@ -27,6 +27,7 @@ from absl.testing import parameterized
import numpy as np
import jax
from jax import lax
from jax import numpy as jnp
from jax import ops
@ -616,6 +617,22 @@ class IndexingTest(jtu.JaxTestCase):
self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker)
self._CompileAndCheck(jnp_fun, args_maker)
@parameterized.named_parameters(
{"testcase_name": f"_{dtype}", "dtype": dtype}
for dtype in jtu.dtypes.unsigned + jtu.dtypes.integer)
def testIndicesNormalizationByType(self, dtype):
x = jnp.arange(10)
jaxpr = jax.make_jaxpr(x.__getitem__)(jnp.arange(3, dtype=dtype))
primitives = [eqn.primitive for eqn in jaxpr.eqns]
if np.issubdtype(dtype, np.unsignedinteger):
# Unsigned integers should not require lt, add, and select.
self.assertEqual(primitives, [lax.convert_element_type_p, lax.broadcast_in_dim_p, lax.gather_p])
else:
# May or may not contain convert_element_type.
self.assertIn(len(primitives), [5, 6])
self.assertEqual(primitives[:3], [lax.lt_p, lax.add_p, lax.select_n_p])
self.assertEqual(primitives[-2:], [lax.broadcast_in_dim_p, lax.gather_p])
@parameterized.named_parameters(
{"testcase_name": "{}_inshape={}_indexer={}"
.format(name, jtu.format_shape_dtype_string(shape, dtype), indexer),