mirror of
https://github.com/ROCm/jax.git
synced 2025-04-15 19:36:06 +00:00
[JAX] Change jnp.take_along_axis to return invalid (e.g. NaN) values for out-of-bounds indices.
Previously, out-of-bounds indices were clipped into range, but that behavior is error prone. We would rather fail in a more visible way when out-of-bounds indices are used. Future changes will migrate other JAX indexing operations to have the same semantics. PiperOrigin-RevId: 443390170
This commit is contained in:
parent
0fc93a0a99
commit
74346f464b
@ -17,7 +17,10 @@ PLEASE REMEMBER TO CHANGE THE '..main' WITH AN ACTUAL TAG in GITHUB LINK.
|
||||
Please use `jax.experimental.maps.Mesh`. Please see https://jax.readthedocs.io/en/latest/_autosummary/jax.experimental.maps.Mesh.html#jax.experimental.maps.Mesh
|
||||
for more information.
|
||||
* {func}`jax.numpy.take_along_axis` now takes an optional `mode` parameter
|
||||
that specifies the behavior of out-of-bounds indexing.
|
||||
that specifies the behavior of out-of-bounds indexing. By default,
|
||||
invalid values (e.g., NaN) will be returned for out-of-bounds indices. In
|
||||
previous versions of JAX, invalid indices were clamped into range. The
|
||||
previous behavior can be restored by passing `mode="clip"`.
|
||||
* {func}`jax.numpy.take_along_axis` now raises a `TypeError` if its indices
|
||||
are not of an integer type, matching the behavior of
|
||||
{func}`numpy.take_along_axis`. Previously non-integer indices were silently
|
||||
|
@ -3431,10 +3431,9 @@ def _normalize_index(index, axis_size):
|
||||
TAKE_ALONG_AXIS_DOC = """
|
||||
Unlike :func:`numpy.take_along_axis`, :func:`jax.numpy.take_along_axis` takes
|
||||
an optional ``mode`` parameter controlling how out-of-bounds indices should be
|
||||
handled. By default, out-of-bounds indices are clamped into range. In a future
|
||||
change, out-of-bounds indices will return invalid (e.g., ``NaN``) values
|
||||
instead. See :attr:`jax.numpy.ndarray.at` for more discussion
|
||||
of out-of-bounds indexing in JAX.
|
||||
handled. By default, out-of-bounds indices yield invalid values (e.g., ``NaN``).
|
||||
See :attr:`jax.numpy.ndarray.at` for futrher discussion of out-of-bounds
|
||||
indexing in JAX.
|
||||
"""
|
||||
|
||||
@_wraps(np.take_along_axis, update_doc=False,
|
||||
@ -3520,9 +3519,8 @@ def take_along_axis(arr, indices, axis: Optional[int],
|
||||
offset_dims=tuple(offset_dims),
|
||||
collapsed_slice_dims=tuple(collapsed_slice_dims),
|
||||
start_index_map=tuple(start_index_map))
|
||||
# TODO(phawkins): change the mode to "fill".
|
||||
return lax.gather(arr, gather_indices, dnums, tuple(slice_sizes),
|
||||
mode="clip" if mode is None else mode)
|
||||
mode="fill" if mode is None else mode)
|
||||
|
||||
### Indexing
|
||||
|
||||
|
@ -4577,15 +4577,16 @@ class LaxBackedNumpyTests(jtu.JaxTestCase):
|
||||
x = jnp.arange(10, dtype=jnp.float32)
|
||||
idx = jnp.array([-11, -10, -9, -5, -1, 0, 1, 5, 9, 10, 11])
|
||||
out = jnp.take_along_axis(x, idx, axis=0)
|
||||
expected_clip = np.array([0, 0, 1, 5, 9, 0, 1, 5, 9, 9, 9], np.float32)
|
||||
np.testing.assert_array_equal(expected_clip, out)
|
||||
out = jnp.take_along_axis(x, idx, axis=0, mode="clip")
|
||||
np.testing.assert_array_equal(expected_clip, out)
|
||||
expected_fill = np.array([jnp.nan, 0, 1, 5, 9, 0, 1, 5, 9, jnp.nan,
|
||||
jnp.nan], np.float32)
|
||||
np.testing.assert_array_equal(expected_fill, out)
|
||||
out = jnp.take_along_axis(x, idx, axis=0, mode="fill")
|
||||
np.testing.assert_array_equal(expected_fill, out)
|
||||
|
||||
expected_clip = np.array([0, 0, 1, 5, 9, 0, 1, 5, 9, 9, 9], np.float32)
|
||||
out = jnp.take_along_axis(x, idx, axis=0, mode="clip")
|
||||
np.testing.assert_array_equal(expected_clip, out)
|
||||
|
||||
def testTakeAlongAxisRequiresIntIndices(self):
|
||||
x = jnp.arange(5)
|
||||
idx = jnp.array([3.], jnp.float32)
|
||||
|
Loading…
x
Reference in New Issue
Block a user