[JAX] Change jnp.take_along_axis to return invalid (e.g. NaN) values for out-of-bounds indices.

Previously, out-of-bounds indices were clipped into range, but that behavior is error prone. We would rather fail in a more visible way when out-of-bounds indices are used. Future changes will migrate other JAX indexing operations to have the same semantics.

PiperOrigin-RevId: 443390170
This commit is contained in:
Peter Hawkins 2022-04-21 08:51:35 -07:00 committed by jax authors
parent 0fc93a0a99
commit 74346f464b
3 changed files with 13 additions and 11 deletions

View File

@ -17,7 +17,10 @@ PLEASE REMEMBER TO CHANGE THE '..main' WITH AN ACTUAL TAG in GITHUB LINK.
Please use `jax.experimental.maps.Mesh`. Please see https://jax.readthedocs.io/en/latest/_autosummary/jax.experimental.maps.Mesh.html#jax.experimental.maps.Mesh
for more information.
* {func}`jax.numpy.take_along_axis` now takes an optional `mode` parameter
that specifies the behavior of out-of-bounds indexing.
that specifies the behavior of out-of-bounds indexing. By default,
invalid values (e.g., NaN) will be returned for out-of-bounds indices. In
previous versions of JAX, invalid indices were clamped into range. The
previous behavior can be restored by passing `mode="clip"`.
* {func}`jax.numpy.take_along_axis` now raises a `TypeError` if its indices
are not of an integer type, matching the behavior of
{func}`numpy.take_along_axis`. Previously non-integer indices were silently

View File

@ -3431,10 +3431,9 @@ def _normalize_index(index, axis_size):
TAKE_ALONG_AXIS_DOC = """
Unlike :func:`numpy.take_along_axis`, :func:`jax.numpy.take_along_axis` takes
an optional ``mode`` parameter controlling how out-of-bounds indices should be
handled. By default, out-of-bounds indices are clamped into range. In a future
change, out-of-bounds indices will return invalid (e.g., ``NaN``) values
instead. See :attr:`jax.numpy.ndarray.at` for more discussion
of out-of-bounds indexing in JAX.
handled. By default, out-of-bounds indices yield invalid values (e.g., ``NaN``).
See :attr:`jax.numpy.ndarray.at` for futrher discussion of out-of-bounds
indexing in JAX.
"""
@_wraps(np.take_along_axis, update_doc=False,
@ -3520,9 +3519,8 @@ def take_along_axis(arr, indices, axis: Optional[int],
offset_dims=tuple(offset_dims),
collapsed_slice_dims=tuple(collapsed_slice_dims),
start_index_map=tuple(start_index_map))
# TODO(phawkins): change the mode to "fill".
return lax.gather(arr, gather_indices, dnums, tuple(slice_sizes),
mode="clip" if mode is None else mode)
mode="fill" if mode is None else mode)
### Indexing

View File

@ -4577,15 +4577,16 @@ class LaxBackedNumpyTests(jtu.JaxTestCase):
x = jnp.arange(10, dtype=jnp.float32)
idx = jnp.array([-11, -10, -9, -5, -1, 0, 1, 5, 9, 10, 11])
out = jnp.take_along_axis(x, idx, axis=0)
expected_clip = np.array([0, 0, 1, 5, 9, 0, 1, 5, 9, 9, 9], np.float32)
np.testing.assert_array_equal(expected_clip, out)
out = jnp.take_along_axis(x, idx, axis=0, mode="clip")
np.testing.assert_array_equal(expected_clip, out)
expected_fill = np.array([jnp.nan, 0, 1, 5, 9, 0, 1, 5, 9, jnp.nan,
jnp.nan], np.float32)
np.testing.assert_array_equal(expected_fill, out)
out = jnp.take_along_axis(x, idx, axis=0, mode="fill")
np.testing.assert_array_equal(expected_fill, out)
expected_clip = np.array([0, 0, 1, 5, 9, 0, 1, 5, 9, 9, 9], np.float32)
out = jnp.take_along_axis(x, idx, axis=0, mode="clip")
np.testing.assert_array_equal(expected_clip, out)
def testTakeAlongAxisRequiresIntIndices(self):
x = jnp.arange(5)
idx = jnp.array([3.], jnp.float32)