mirror of
https://github.com/ROCm/jax.git
synced 2025-04-17 20:36:05 +00:00
Merge pull request #10150 from jakevdp:normalize-unsigned
PiperOrigin-RevId: 443416771
This commit is contained in:
commit
bef5e02816
@ -3417,6 +3417,8 @@ def _take(a, indices, axis: Optional[int] = None, out=None, mode=None):
|
||||
|
||||
def _normalize_index(index, axis_size):
|
||||
"""Normalizes an index value in the range [-N, N) to the range [0, N)."""
|
||||
if issubdtype(_dtype(index), np.unsignedinteger):
|
||||
return index
|
||||
if core.is_constant_dim(axis_size):
|
||||
axis_size_val = _lax_const(index, axis_size)
|
||||
else:
|
||||
|
@ -27,6 +27,7 @@ from absl.testing import parameterized
|
||||
import numpy as np
|
||||
|
||||
import jax
|
||||
from jax import lax
|
||||
from jax import numpy as jnp
|
||||
from jax import ops
|
||||
|
||||
@ -616,6 +617,22 @@ class IndexingTest(jtu.JaxTestCase):
|
||||
self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker)
|
||||
self._CompileAndCheck(jnp_fun, args_maker)
|
||||
|
||||
@parameterized.named_parameters(
|
||||
{"testcase_name": f"_{dtype}", "dtype": dtype}
|
||||
for dtype in jtu.dtypes.unsigned + jtu.dtypes.integer)
|
||||
def testIndicesNormalizationByType(self, dtype):
|
||||
x = jnp.arange(10)
|
||||
jaxpr = jax.make_jaxpr(x.__getitem__)(jnp.arange(3, dtype=dtype))
|
||||
primitives = [eqn.primitive for eqn in jaxpr.eqns]
|
||||
if np.issubdtype(dtype, np.unsignedinteger):
|
||||
# Unsigned integers should not require lt, add, and select.
|
||||
self.assertEqual(primitives, [lax.convert_element_type_p, lax.broadcast_in_dim_p, lax.gather_p])
|
||||
else:
|
||||
# May or may not contain convert_element_type.
|
||||
self.assertIn(len(primitives), [5, 6])
|
||||
self.assertEqual(primitives[:3], [lax.lt_p, lax.add_p, lax.select_n_p])
|
||||
self.assertEqual(primitives[-2:], [lax.broadcast_in_dim_p, lax.gather_p])
|
||||
|
||||
@parameterized.named_parameters(
|
||||
{"testcase_name": "{}_inshape={}_indexer={}"
|
||||
.format(name, jtu.format_shape_dtype_string(shape, dtype), indexer),
|
||||
|
Loading…
x
Reference in New Issue
Block a user