diff --git a/CHANGELOG.md b/CHANGELOG.md index 27cfbe85b..ee02a6f5a 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -16,6 +16,8 @@ PLEASE REMEMBER TO CHANGE THE '..main' WITH AN ACTUAL TAG in GITHUB LINK. * `jax.experimental.maps.mesh` has been deleted. Please use `jax.experimental.maps.Mesh`. Please see https://jax.readthedocs.io/en/latest/_autosummary/jax.experimental.maps.Mesh.html#jax.experimental.maps.Mesh for more information. + * {func}`jax.numpy.take_along_axis` now takes an optional `mode` parameter + that specifies the behavior of out-of-bounds indexing. ## jaxlib 0.3.8 (Unreleased) * [GitHub diff --git a/jax/_src/numpy/lax_numpy.py b/jax/_src/numpy/lax_numpy.py index 74a65aad1..190ccb4be 100644 --- a/jax/_src/numpy/lax_numpy.py +++ b/jax/_src/numpy/lax_numpy.py @@ -3427,12 +3427,24 @@ def _normalize_index(index, axis_size): lax.add(index, axis_size_val), index) -@_wraps(np.take_along_axis, update_doc=False) -@partial(jit, static_argnames=('axis',)) -def take_along_axis(arr, indices, axis: Optional[int]): + +TAKE_ALONG_AXIS_DOC = """ +Unlike :func:`numpy.take_along_axis`, :func:`jax.numpy.take_along_axis` takes +an optional ``mode`` parameter controlling how out-of-bounds indices should be +handled. By default, out-of-bounds indices are clamped into range. In a future +change, out-of-bounds indices will return invalid (e.g., ``NaN``) values +instead. See :attr:`jax.numpy.ndarray.at` for more discussion +of out-of-bounds indexing in JAX. +""" + +@_wraps(np.take_along_axis, update_doc=False, + lax_description=TAKE_ALONG_AXIS_DOC) +@partial(jit, static_argnames=('axis', 'mode')) +def take_along_axis(arr, indices, axis: Optional[int], + mode: Optional[Union[str, lax.GatherScatterMode]] = None): _check_arraylike("take_along_axis", arr, indices) # index_dtype = dtypes.dtype(indices) - # TODO(phawkins): reenalbe this check after fixing callers + # TODO(phawkins): reenable this check after fixing callers # if not dtypes.issubdtype(index_dtype, integer): # raise TypeError("take_along_axis indices must be of integer type, got " # f"{str(index_dtype)}") @@ -3510,7 +3522,8 @@ def take_along_axis(arr, indices, axis: Optional[int]): collapsed_slice_dims=tuple(collapsed_slice_dims), start_index_map=tuple(start_index_map)) # TODO(phawkins): change the mode to "fill". - return lax.gather(arr, gather_indices, dnums, tuple(slice_sizes)) + return lax.gather(arr, gather_indices, dnums, tuple(slice_sizes), + mode="clip" if mode is None else mode) ### Indexing diff --git a/tests/lax_numpy_test.py b/tests/lax_numpy_test.py index 634beb99d..3d2566367 100644 --- a/tests/lax_numpy_test.py +++ b/tests/lax_numpy_test.py @@ -4573,6 +4573,19 @@ class LaxBackedNumpyTests(jtu.JaxTestCase): q1 = np.take_along_axis( h, g, axis=-1) np.testing.assert_equal(q0, q1) + def testTakeAlongAxisOutOfBounds(self): + x = jnp.arange(10, dtype=jnp.float32) + idx = jnp.array([-11, -10, -9, -5, -1, 0, 1, 5, 9, 10, 11]) + out = jnp.take_along_axis(x, idx, axis=0) + expected_clip = np.array([0, 0, 1, 5, 9, 0, 1, 5, 9, 9, 9], np.float32) + np.testing.assert_array_equal(expected_clip, out) + out = jnp.take_along_axis(x, idx, axis=0, mode="clip") + np.testing.assert_array_equal(expected_clip, out) + expected_fill = np.array([jnp.nan, 0, 1, 5, 9, 0, 1, 5, 9, jnp.nan, + jnp.nan], np.float32) + out = jnp.take_along_axis(x, idx, axis=0, mode="fill") + np.testing.assert_array_equal(expected_fill, out) + @parameterized.named_parameters(jtu.cases_from_list( {"testcase_name": "_shape={}_n={}_increasing={}".format( jtu.format_shape_dtype_string([shape], dtype), @@ -6162,6 +6175,7 @@ class NumpySignaturesTest(jtu.JaxTestCase): 'broadcast_to': ['arr'], 'einsum': ['precision'], 'einsum_path': ['subscripts'], + 'take_along_axis': ['mode'], } mismatches = {}