diff --git a/CHANGELOG.md b/CHANGELOG.md index 4fc2600ea..ee6dad30e 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -17,7 +17,10 @@ PLEASE REMEMBER TO CHANGE THE '..main' WITH AN ACTUAL TAG in GITHUB LINK. Please use `jax.experimental.maps.Mesh`. Please see https://jax.readthedocs.io/en/latest/_autosummary/jax.experimental.maps.Mesh.html#jax.experimental.maps.Mesh for more information. * {func}`jax.numpy.take_along_axis` now takes an optional `mode` parameter - that specifies the behavior of out-of-bounds indexing. + that specifies the behavior of out-of-bounds indexing. By default, + invalid values (e.g., NaN) will be returned for out-of-bounds indices. In + previous versions of JAX, invalid indices were clamped into range. The + previous behavior can be restored by passing `mode="clip"`. * {func}`jax.numpy.take_along_axis` now raises a `TypeError` if its indices are not of an integer type, matching the behavior of {func}`numpy.take_along_axis`. Previously non-integer indices were silently diff --git a/jax/_src/numpy/lax_numpy.py b/jax/_src/numpy/lax_numpy.py index cab728fac..9d1ef2be6 100644 --- a/jax/_src/numpy/lax_numpy.py +++ b/jax/_src/numpy/lax_numpy.py @@ -3431,10 +3431,9 @@ def _normalize_index(index, axis_size): TAKE_ALONG_AXIS_DOC = """ Unlike :func:`numpy.take_along_axis`, :func:`jax.numpy.take_along_axis` takes an optional ``mode`` parameter controlling how out-of-bounds indices should be -handled. By default, out-of-bounds indices are clamped into range. In a future -change, out-of-bounds indices will return invalid (e.g., ``NaN``) values -instead. See :attr:`jax.numpy.ndarray.at` for more discussion -of out-of-bounds indexing in JAX. +handled. By default, out-of-bounds indices yield invalid values (e.g., ``NaN``). +See :attr:`jax.numpy.ndarray.at` for futrher discussion of out-of-bounds +indexing in JAX. """ @_wraps(np.take_along_axis, update_doc=False, @@ -3520,9 +3519,8 @@ def take_along_axis(arr, indices, axis: Optional[int], offset_dims=tuple(offset_dims), collapsed_slice_dims=tuple(collapsed_slice_dims), start_index_map=tuple(start_index_map)) - # TODO(phawkins): change the mode to "fill". return lax.gather(arr, gather_indices, dnums, tuple(slice_sizes), - mode="clip" if mode is None else mode) + mode="fill" if mode is None else mode) ### Indexing diff --git a/tests/lax_numpy_test.py b/tests/lax_numpy_test.py index a635f07d7..0a3ab07d2 100644 --- a/tests/lax_numpy_test.py +++ b/tests/lax_numpy_test.py @@ -4577,15 +4577,16 @@ class LaxBackedNumpyTests(jtu.JaxTestCase): x = jnp.arange(10, dtype=jnp.float32) idx = jnp.array([-11, -10, -9, -5, -1, 0, 1, 5, 9, 10, 11]) out = jnp.take_along_axis(x, idx, axis=0) - expected_clip = np.array([0, 0, 1, 5, 9, 0, 1, 5, 9, 9, 9], np.float32) - np.testing.assert_array_equal(expected_clip, out) - out = jnp.take_along_axis(x, idx, axis=0, mode="clip") - np.testing.assert_array_equal(expected_clip, out) expected_fill = np.array([jnp.nan, 0, 1, 5, 9, 0, 1, 5, 9, jnp.nan, jnp.nan], np.float32) + np.testing.assert_array_equal(expected_fill, out) out = jnp.take_along_axis(x, idx, axis=0, mode="fill") np.testing.assert_array_equal(expected_fill, out) + expected_clip = np.array([0, 0, 1, 5, 9, 0, 1, 5, 9, 9, 9], np.float32) + out = jnp.take_along_axis(x, idx, axis=0, mode="clip") + np.testing.assert_array_equal(expected_clip, out) + def testTakeAlongAxisRequiresIntIndices(self): x = jnp.arange(5) idx = jnp.array([3.], jnp.float32)