From 92ca76a0395ad32423e681b6d6ce6d84c361852b Mon Sep 17 00:00:00 2001 From: Jake VanderPlas Date: Wed, 20 Apr 2022 16:04:12 -0700 Subject: [PATCH] Skip normalization of unsigned indices --- jax/_src/numpy/lax_numpy.py | 2 ++ tests/lax_numpy_indexing_test.py | 17 +++++++++++++++++ 2 files changed, 19 insertions(+) diff --git a/jax/_src/numpy/lax_numpy.py b/jax/_src/numpy/lax_numpy.py index cab728fac..ff09ed7c4 100644 --- a/jax/_src/numpy/lax_numpy.py +++ b/jax/_src/numpy/lax_numpy.py @@ -3417,6 +3417,8 @@ def _take(a, indices, axis: Optional[int] = None, out=None, mode=None): def _normalize_index(index, axis_size): """Normalizes an index value in the range [-N, N) to the range [0, N).""" + if issubdtype(_dtype(index), np.unsignedinteger): + return index if core.is_constant_dim(axis_size): axis_size_val = _lax_const(index, axis_size) else: diff --git a/tests/lax_numpy_indexing_test.py b/tests/lax_numpy_indexing_test.py index 6d5aa6d39..9b504470f 100644 --- a/tests/lax_numpy_indexing_test.py +++ b/tests/lax_numpy_indexing_test.py @@ -27,6 +27,7 @@ from absl.testing import parameterized import numpy as np import jax +from jax import lax from jax import numpy as jnp from jax import ops @@ -616,6 +617,22 @@ class IndexingTest(jtu.JaxTestCase): self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker) self._CompileAndCheck(jnp_fun, args_maker) + @parameterized.named_parameters( + {"testcase_name": f"_{dtype}", "dtype": dtype} + for dtype in jtu.dtypes.unsigned + jtu.dtypes.integer) + def testIndicesNormalizationByType(self, dtype): + x = jnp.arange(10) + jaxpr = jax.make_jaxpr(x.__getitem__)(jnp.arange(3, dtype=dtype)) + primitives = [eqn.primitive for eqn in jaxpr.eqns] + if np.issubdtype(dtype, np.unsignedinteger): + # Unsigned integers should not require lt, add, and select. + self.assertEqual(primitives, [lax.convert_element_type_p, lax.broadcast_in_dim_p, lax.gather_p]) + else: + # May or may not contain convert_element_type. + self.assertIn(len(primitives), [5, 6]) + self.assertEqual(primitives[:3], [lax.lt_p, lax.add_p, lax.select_n_p]) + self.assertEqual(primitives[-2:], [lax.broadcast_in_dim_p, lax.gather_p]) + @parameterized.named_parameters( {"testcase_name": "{}_inshape={}_indexer={}" .format(name, jtu.format_shape_dtype_string(shape, dtype), indexer),