mirror of
https://github.com/ROCm/jax.git
synced 2025-04-16 03:46:06 +00:00
Add an optional mode= argument to jnp.take_along_axis.
This allows users of jnp.take_along_axis to override the out-of-bounds indexing behavior. Default to "clip", which for the forward computation is identical to the current behavior. In a future change, we will change this to "fill".
This commit is contained in:
parent
7008b32132
commit
a52f07a21b
@ -16,6 +16,8 @@ PLEASE REMEMBER TO CHANGE THE '..main' WITH AN ACTUAL TAG in GITHUB LINK.
|
|||||||
* `jax.experimental.maps.mesh` has been deleted.
|
* `jax.experimental.maps.mesh` has been deleted.
|
||||||
Please use `jax.experimental.maps.Mesh`. Please see https://jax.readthedocs.io/en/latest/_autosummary/jax.experimental.maps.Mesh.html#jax.experimental.maps.Mesh
|
Please use `jax.experimental.maps.Mesh`. Please see https://jax.readthedocs.io/en/latest/_autosummary/jax.experimental.maps.Mesh.html#jax.experimental.maps.Mesh
|
||||||
for more information.
|
for more information.
|
||||||
|
* {func}`jax.numpy.take_along_axis` now takes an optional `mode` parameter
|
||||||
|
that specifies the behavior of out-of-bounds indexing.
|
||||||
|
|
||||||
## jaxlib 0.3.8 (Unreleased)
|
## jaxlib 0.3.8 (Unreleased)
|
||||||
* [GitHub
|
* [GitHub
|
||||||
|
@ -3427,12 +3427,24 @@ def _normalize_index(index, axis_size):
|
|||||||
lax.add(index, axis_size_val),
|
lax.add(index, axis_size_val),
|
||||||
index)
|
index)
|
||||||
|
|
||||||
@_wraps(np.take_along_axis, update_doc=False)
|
|
||||||
@partial(jit, static_argnames=('axis',))
|
TAKE_ALONG_AXIS_DOC = """
|
||||||
def take_along_axis(arr, indices, axis: Optional[int]):
|
Unlike :func:`numpy.take_along_axis`, :func:`jax.numpy.take_along_axis` takes
|
||||||
|
an optional ``mode`` parameter controlling how out-of-bounds indices should be
|
||||||
|
handled. By default, out-of-bounds indices are clamped into range. In a future
|
||||||
|
change, out-of-bounds indices will return invalid (e.g., ``NaN``) values
|
||||||
|
instead. See :attr:`jax.numpy.ndarray.at` for more discussion
|
||||||
|
of out-of-bounds indexing in JAX.
|
||||||
|
"""
|
||||||
|
|
||||||
|
@_wraps(np.take_along_axis, update_doc=False,
|
||||||
|
lax_description=TAKE_ALONG_AXIS_DOC)
|
||||||
|
@partial(jit, static_argnames=('axis', 'mode'))
|
||||||
|
def take_along_axis(arr, indices, axis: Optional[int],
|
||||||
|
mode: Optional[Union[str, lax.GatherScatterMode]] = None):
|
||||||
_check_arraylike("take_along_axis", arr, indices)
|
_check_arraylike("take_along_axis", arr, indices)
|
||||||
# index_dtype = dtypes.dtype(indices)
|
# index_dtype = dtypes.dtype(indices)
|
||||||
# TODO(phawkins): reenalbe this check after fixing callers
|
# TODO(phawkins): reenable this check after fixing callers
|
||||||
# if not dtypes.issubdtype(index_dtype, integer):
|
# if not dtypes.issubdtype(index_dtype, integer):
|
||||||
# raise TypeError("take_along_axis indices must be of integer type, got "
|
# raise TypeError("take_along_axis indices must be of integer type, got "
|
||||||
# f"{str(index_dtype)}")
|
# f"{str(index_dtype)}")
|
||||||
@ -3510,7 +3522,8 @@ def take_along_axis(arr, indices, axis: Optional[int]):
|
|||||||
collapsed_slice_dims=tuple(collapsed_slice_dims),
|
collapsed_slice_dims=tuple(collapsed_slice_dims),
|
||||||
start_index_map=tuple(start_index_map))
|
start_index_map=tuple(start_index_map))
|
||||||
# TODO(phawkins): change the mode to "fill".
|
# TODO(phawkins): change the mode to "fill".
|
||||||
return lax.gather(arr, gather_indices, dnums, tuple(slice_sizes))
|
return lax.gather(arr, gather_indices, dnums, tuple(slice_sizes),
|
||||||
|
mode="clip" if mode is None else mode)
|
||||||
|
|
||||||
### Indexing
|
### Indexing
|
||||||
|
|
||||||
|
@ -4573,6 +4573,19 @@ class LaxBackedNumpyTests(jtu.JaxTestCase):
|
|||||||
q1 = np.take_along_axis( h, g, axis=-1)
|
q1 = np.take_along_axis( h, g, axis=-1)
|
||||||
np.testing.assert_equal(q0, q1)
|
np.testing.assert_equal(q0, q1)
|
||||||
|
|
||||||
|
def testTakeAlongAxisOutOfBounds(self):
|
||||||
|
x = jnp.arange(10, dtype=jnp.float32)
|
||||||
|
idx = jnp.array([-11, -10, -9, -5, -1, 0, 1, 5, 9, 10, 11])
|
||||||
|
out = jnp.take_along_axis(x, idx, axis=0)
|
||||||
|
expected_clip = np.array([0, 0, 1, 5, 9, 0, 1, 5, 9, 9, 9], np.float32)
|
||||||
|
np.testing.assert_array_equal(expected_clip, out)
|
||||||
|
out = jnp.take_along_axis(x, idx, axis=0, mode="clip")
|
||||||
|
np.testing.assert_array_equal(expected_clip, out)
|
||||||
|
expected_fill = np.array([jnp.nan, 0, 1, 5, 9, 0, 1, 5, 9, jnp.nan,
|
||||||
|
jnp.nan], np.float32)
|
||||||
|
out = jnp.take_along_axis(x, idx, axis=0, mode="fill")
|
||||||
|
np.testing.assert_array_equal(expected_fill, out)
|
||||||
|
|
||||||
@parameterized.named_parameters(jtu.cases_from_list(
|
@parameterized.named_parameters(jtu.cases_from_list(
|
||||||
{"testcase_name": "_shape={}_n={}_increasing={}".format(
|
{"testcase_name": "_shape={}_n={}_increasing={}".format(
|
||||||
jtu.format_shape_dtype_string([shape], dtype),
|
jtu.format_shape_dtype_string([shape], dtype),
|
||||||
@ -6162,6 +6175,7 @@ class NumpySignaturesTest(jtu.JaxTestCase):
|
|||||||
'broadcast_to': ['arr'],
|
'broadcast_to': ['arr'],
|
||||||
'einsum': ['precision'],
|
'einsum': ['precision'],
|
||||||
'einsum_path': ['subscripts'],
|
'einsum_path': ['subscripts'],
|
||||||
|
'take_along_axis': ['mode'],
|
||||||
}
|
}
|
||||||
|
|
||||||
mismatches = {}
|
mismatches = {}
|
||||||
|
Loading…
x
Reference in New Issue
Block a user