diff --git a/CHANGELOG.md b/CHANGELOG.md index 17d15c740..b0b64ac71 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -48,6 +48,7 @@ When releasing, please add the new-release-boilerplate to docs/pallas/CHANGELOG. * {func}`jax.tree_util.register_dataclass` now allows metadata fields to be declared inline via {func}`dataclasses.field`. See the function documentation for examples. + * Added {func}`jax.numpy.put_along_axis`. * Bug fixes * Fixed a bug where the GPU implementations of LU and QR decomposition would diff --git a/docs/jax.numpy.rst b/docs/jax.numpy.rst index 3922c92d9..30553a360 100644 --- a/docs/jax.numpy.rst +++ b/docs/jax.numpy.rst @@ -337,6 +337,7 @@ namespace; they are listed below. promote_types ptp put + put_along_axis quantile r_ rad2deg diff --git a/jax/_src/numpy/lax_numpy.py b/jax/_src/numpy/lax_numpy.py index b90004e19..3ff38f16b 100644 --- a/jax/_src/numpy/lax_numpy.py +++ b/jax/_src/numpy/lax_numpy.py @@ -68,7 +68,7 @@ from jax._src.typing import ( ) from jax._src.util import ( NumpyComplexWarning, canonicalize_axis as _canonicalize_axis, - ceil_of_ratio, partition_list, safe_zip, subvals,unzip2) + ceil_of_ratio, partition_list, safe_zip, subvals,unzip2, tuple_replace) from jax.sharding import (Sharding, SingleDeviceSharding, NamedSharding, PartitionSpec as P) from jax.tree_util import tree_flatten, tree_leaves, tree_map @@ -11433,6 +11433,105 @@ def take_along_axis( mode="fill" if mode is None else mode, fill_value=fill_value) +_indices = indices # argument below named 'indices' shadows the function + + +def _make_along_axis_idx(shape, indices, axis): + return tuple_replace(_indices(shape, sparse=True), axis, indices) + + +@partial(jit, static_argnames=('axis', 'inplace', 'mode')) +def put_along_axis( + arr: ArrayLike, + indices: ArrayLike, + values: ArrayLike, + axis: int | None, + inplace: bool = True, + *, + mode: str | None = None, +) -> Array: + """Put values into the destination array by matching 1d index and data slices. + + JAX implementation of :func:`numpy.put_along_axis`. + + The semantics of :func:`numpy.put_along_axis` are to modify arrays in-place, which + is not possible for JAX's immutable arrays. The JAX version returns a modified + copy of the input, and adds the ``inplace`` parameter which must be set to + `False`` by the user as a reminder of this API difference. + + Args: + arr: array into which values will be put. + indices: array of indices at which to put values. + values: array of values to put into the array. + axis: the axis along which to put values. If not specified, the array will + be flattened before indexing is applied. + inplace: must be set to False to indicate that the input is not modified + in-place, but rather a modified copy is returned. + mode: Out-of-bounds indexing mode. For more discussion of ``mode`` options, + see :attr:`jax.numpy.ndarray.at`. + + Returns: + A copy of ``a`` with specified entries updated. + + See Also: + - :func:`jax.numpy.put`: put elements into an array at given indices. + - :func:`jax.numpy.place`: place elements into an array via boolean mask. + - :func:`jax.numpy.ndarray.at`: array updates using NumPy-style indexing. + - :func:`jax.numpy.take`: extract values from an array at given indices. + - :func:`jax.numpy.take_along_axis`: extract values from an array along an axis. + + Examples: + >>> from jax import numpy as jnp + >>> a = jnp.array([[10, 30, 20], [60, 40, 50]]) + >>> i = jnp.argmax(a, axis=1, keepdims=True) + >>> print(i) + [[1] + [0]] + >>> b = jnp.put_along_axis(a, i, 99, axis=1, inplace=False) + >>> print(b) + [[10 99 20] + [99 40 50]] + """ + if inplace: + raise ValueError( + "jax.numpy.put_along_axis cannot modify arrays in-place, because JAX arrays" + "are immutable. Pass inplace=False to instead return an updated array.") + + util.check_arraylike("put_along_axis", arr, indices, values) + arr = asarray(arr) + indices = asarray(indices) + values = asarray(values) + + original_axis = axis + original_arr_shape = arr.shape + + if axis is None: + arr = arr.ravel() + axis = 0 + + if not arr.ndim == indices.ndim: + raise ValueError( + "put_along_axis arguments 'arr' and 'indices' must have same ndim. Got " + f"{arr.ndim=} and {indices.ndim=}." + ) + + try: + values = broadcast_to(values, indices.shape) + except ValueError: + raise ValueError( + "put_along_axis argument 'values' must be broadcastable to 'indices'. Got " + f"{values.shape=} and {indices.shape=}." + ) + + idx = _make_along_axis_idx(arr.shape, indices, axis) + result = arr.at[idx].set(values, mode=mode) + + if original_axis is None: + result = result.reshape(original_arr_shape) + + return result + + ### Indexing def _is_integer_index(idx: Any) -> bool: diff --git a/jax/_src/test_util.py b/jax/_src/test_util.py index 78de511d4..e546ebd2a 100644 --- a/jax/_src/test_util.py +++ b/jax/_src/test_util.py @@ -965,6 +965,31 @@ def rand_unique_int(rng, high=None): size=shape, replace=False) return fn +def rand_indices_unique_along_axis(rng): + """Sample an array of given shape containing indices up to dim (exclusive), + such that the indices are unique along the given axis. + Optionally, convert some of the resulting indices to negative indices.""" + def fn(dim, shape, axis, allow_negative=True): + batch_size = math.prod(shape[:axis] + shape[axis:][1:]) + idx = [ + rng.choice(dim, size=shape[axis], replace=False) + for _ in range(batch_size) + ] + idx = np.array(idx).reshape(batch_size, shape[axis]) + idx = idx.reshape(shape[:axis] + shape[axis:][1:] + (shape[axis],)) + idx = np.moveaxis(idx, -1, axis) + + # assert that indices are unique along the given axis + count = partial(np.bincount, minlength=dim) + assert (np.apply_along_axis(count, axis, idx) <= 1).all() + + if allow_negative: + mask = rng.choice([False, True], idx.shape) + idx[mask] -= dim + return idx + + return fn + def rand_bool(rng): def generator(shape, dtype): return _cast_to_shape( diff --git a/jax/_src/util.py b/jax/_src/util.py index fce342c49..8dcc5eaa5 100644 --- a/jax/_src/util.py +++ b/jax/_src/util.py @@ -453,6 +453,10 @@ def tuple_update(t, idx, val): assert 0 <= idx < len(t), (idx, len(t)) return t[:idx] + (val,) + t[idx+1:] +def tuple_replace(tupl, index, item): + # unlike tuple_update, works with negative indices as well + return tupl[:index] + (item,) + tupl[index:][1:] + class HashableFunction: """Decouples function equality and hash from its identity. diff --git a/jax/numpy/__init__.py b/jax/numpy/__init__.py index 9be73e96a..2ab0a0e3d 100644 --- a/jax/numpy/__init__.py +++ b/jax/numpy/__init__.py @@ -202,6 +202,7 @@ from jax._src.numpy.lax_numpy import ( printoptions as printoptions, promote_types as promote_types, put as put, + put_along_axis as put_along_axis, ravel as ravel, ravel_multi_index as ravel_multi_index, repeat as repeat, diff --git a/jax/numpy/__init__.pyi b/jax/numpy/__init__.pyi index d391abd46..339174136 100644 --- a/jax/numpy/__init__.pyi +++ b/jax/numpy/__init__.pyi @@ -742,6 +742,8 @@ def ptp(a: ArrayLike, axis: _Axis = ..., out: None = ..., keepdims: builtins.bool = ...) -> Array: ... def put(a: ArrayLike, ind: ArrayLike, v: ArrayLike, mode: str | None = ..., *, inplace: builtins.bool = ...) -> Array: ... +def put_along_axis(arr: ArrayLike, indices: ArrayLike, values: ArrayLike, + axis: int | None, inplace: bool = True, *, mode: str | None = None) -> Array: ... def quantile(a: ArrayLike, q: ArrayLike, axis: int | tuple[int, ...] | None = ..., out: None = ..., overwrite_input: builtins.bool = ..., method: str = ..., keepdims: builtins.bool = ..., *, interpolation: DeprecatedArg | str = ...) -> Array: ... diff --git a/tests/lax_numpy_test.py b/tests/lax_numpy_test.py index 7c2728af4..a1817f528 100644 --- a/tests/lax_numpy_test.py +++ b/tests/lax_numpy_test.py @@ -51,7 +51,7 @@ from jax._src import deprecations from jax._src import dtypes from jax._src import test_util as jtu from jax._src.lax import lax as lax_internal -from jax._src.util import safe_zip, NumpyComplexWarning +from jax._src.util import safe_zip, NumpyComplexWarning, tuple_replace config.parse_flags_with_absl() @@ -5962,6 +5962,45 @@ class LaxBackedNumpyTests(jtu.JaxTestCase): self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker) self._CompileAndCheck(jnp_fun, args_maker) + @jtu.sample_product( + [ + dict(a_shape=a_shape, i_shape=i_shape, v_shape=v_shape, axis=axis) + for a_shape in nonempty_array_shapes + for axis in list(range(-len(a_shape), len(a_shape))) + for i_shape in [tuple_replace(a_shape, axis, J) for J in range(a_shape[axis] + 1)] + for v_shape in [(), (1,), i_shape] + ] + [ + dict(a_shape=a_shape, i_shape=i_shape, v_shape=v_shape, axis=None) + for a_shape in nonempty_array_shapes + for i_shape in [(J,) for J in range(math.prod(a_shape) + 1)] + for v_shape in [(), (1,), i_shape] + ], + dtype=jtu.dtypes.all, + mode=[None, "promise_in_bounds", "clip"], + ) + def testPutAlongAxis(self, a_shape, i_shape, v_shape, axis, dtype, mode): + a_rng = jtu.rand_default(self.rng()) + if axis is None: + size = math.prod(a_shape) + else: + size = a_shape[axis] + i_rng = jtu.rand_indices_unique_along_axis(self.rng()) + + def args_maker(): + a = a_rng(a_shape, dtype) + i = i_rng(dim=size, shape=i_shape, axis=0 if axis is None else axis) + v = a_rng(v_shape, dtype) + return a, i, v + + def np_fun(a, i, v): + a_copy = a.copy() + np.put_along_axis(a_copy, i, v, axis=axis) + return a_copy + + jnp_fun = partial(jnp.put_along_axis, axis=axis, inplace=False, mode=mode) + self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker) + self._CompileAndCheck(jnp_fun, args_maker) + def test_rot90_error(self): with self.assertRaisesRegex( ValueError, @@ -6229,7 +6268,6 @@ class NumpySignaturesTest(jtu.JaxTestCase): 'nditer', 'nested_iters', 'poly1d', - 'put_along_axis', 'putmask', 'real_if_close', 'recarray',