Add numpy.put_along_axis.

This commit is contained in:
carlosgmartin 2024-11-14 15:23:26 -05:00
parent c4a0369f5c
commit 1f114b1cf7
8 changed files with 174 additions and 3 deletions

View File

@ -48,6 +48,7 @@ When releasing, please add the new-release-boilerplate to docs/pallas/CHANGELOG.
* {func}`jax.tree_util.register_dataclass` now allows metadata fields to be
declared inline via {func}`dataclasses.field`. See the function documentation
for examples.
* Added {func}`jax.numpy.put_along_axis`.
* Bug fixes
* Fixed a bug where the GPU implementations of LU and QR decomposition would

View File

@ -337,6 +337,7 @@ namespace; they are listed below.
promote_types
ptp
put
put_along_axis
quantile
r_
rad2deg

View File

@ -68,7 +68,7 @@ from jax._src.typing import (
)
from jax._src.util import (
NumpyComplexWarning, canonicalize_axis as _canonicalize_axis,
ceil_of_ratio, partition_list, safe_zip, subvals,unzip2)
ceil_of_ratio, partition_list, safe_zip, subvals,unzip2, tuple_replace)
from jax.sharding import (Sharding, SingleDeviceSharding, NamedSharding,
PartitionSpec as P)
from jax.tree_util import tree_flatten, tree_leaves, tree_map
@ -11433,6 +11433,105 @@ def take_along_axis(
mode="fill" if mode is None else mode, fill_value=fill_value)
_indices = indices # argument below named 'indices' shadows the function
def _make_along_axis_idx(shape, indices, axis):
return tuple_replace(_indices(shape, sparse=True), axis, indices)
@partial(jit, static_argnames=('axis', 'inplace', 'mode'))
def put_along_axis(
arr: ArrayLike,
indices: ArrayLike,
values: ArrayLike,
axis: int | None,
inplace: bool = True,
*,
mode: str | None = None,
) -> Array:
"""Put values into the destination array by matching 1d index and data slices.
JAX implementation of :func:`numpy.put_along_axis`.
The semantics of :func:`numpy.put_along_axis` are to modify arrays in-place, which
is not possible for JAX's immutable arrays. The JAX version returns a modified
copy of the input, and adds the ``inplace`` parameter which must be set to
`False`` by the user as a reminder of this API difference.
Args:
arr: array into which values will be put.
indices: array of indices at which to put values.
values: array of values to put into the array.
axis: the axis along which to put values. If not specified, the array will
be flattened before indexing is applied.
inplace: must be set to False to indicate that the input is not modified
in-place, but rather a modified copy is returned.
mode: Out-of-bounds indexing mode. For more discussion of ``mode`` options,
see :attr:`jax.numpy.ndarray.at`.
Returns:
A copy of ``a`` with specified entries updated.
See Also:
- :func:`jax.numpy.put`: put elements into an array at given indices.
- :func:`jax.numpy.place`: place elements into an array via boolean mask.
- :func:`jax.numpy.ndarray.at`: array updates using NumPy-style indexing.
- :func:`jax.numpy.take`: extract values from an array at given indices.
- :func:`jax.numpy.take_along_axis`: extract values from an array along an axis.
Examples:
>>> from jax import numpy as jnp
>>> a = jnp.array([[10, 30, 20], [60, 40, 50]])
>>> i = jnp.argmax(a, axis=1, keepdims=True)
>>> print(i)
[[1]
[0]]
>>> b = jnp.put_along_axis(a, i, 99, axis=1, inplace=False)
>>> print(b)
[[10 99 20]
[99 40 50]]
"""
if inplace:
raise ValueError(
"jax.numpy.put_along_axis cannot modify arrays in-place, because JAX arrays"
"are immutable. Pass inplace=False to instead return an updated array.")
util.check_arraylike("put_along_axis", arr, indices, values)
arr = asarray(arr)
indices = asarray(indices)
values = asarray(values)
original_axis = axis
original_arr_shape = arr.shape
if axis is None:
arr = arr.ravel()
axis = 0
if not arr.ndim == indices.ndim:
raise ValueError(
"put_along_axis arguments 'arr' and 'indices' must have same ndim. Got "
f"{arr.ndim=} and {indices.ndim=}."
)
try:
values = broadcast_to(values, indices.shape)
except ValueError:
raise ValueError(
"put_along_axis argument 'values' must be broadcastable to 'indices'. Got "
f"{values.shape=} and {indices.shape=}."
)
idx = _make_along_axis_idx(arr.shape, indices, axis)
result = arr.at[idx].set(values, mode=mode)
if original_axis is None:
result = result.reshape(original_arr_shape)
return result
### Indexing
def _is_integer_index(idx: Any) -> bool:

View File

@ -965,6 +965,31 @@ def rand_unique_int(rng, high=None):
size=shape, replace=False)
return fn
def rand_indices_unique_along_axis(rng):
"""Sample an array of given shape containing indices up to dim (exclusive),
such that the indices are unique along the given axis.
Optionally, convert some of the resulting indices to negative indices."""
def fn(dim, shape, axis, allow_negative=True):
batch_size = math.prod(shape[:axis] + shape[axis:][1:])
idx = [
rng.choice(dim, size=shape[axis], replace=False)
for _ in range(batch_size)
]
idx = np.array(idx).reshape(batch_size, shape[axis])
idx = idx.reshape(shape[:axis] + shape[axis:][1:] + (shape[axis],))
idx = np.moveaxis(idx, -1, axis)
# assert that indices are unique along the given axis
count = partial(np.bincount, minlength=dim)
assert (np.apply_along_axis(count, axis, idx) <= 1).all()
if allow_negative:
mask = rng.choice([False, True], idx.shape)
idx[mask] -= dim
return idx
return fn
def rand_bool(rng):
def generator(shape, dtype):
return _cast_to_shape(

View File

@ -453,6 +453,10 @@ def tuple_update(t, idx, val):
assert 0 <= idx < len(t), (idx, len(t))
return t[:idx] + (val,) + t[idx+1:]
def tuple_replace(tupl, index, item):
# unlike tuple_update, works with negative indices as well
return tupl[:index] + (item,) + tupl[index:][1:]
class HashableFunction:
"""Decouples function equality and hash from its identity.

View File

@ -202,6 +202,7 @@ from jax._src.numpy.lax_numpy import (
printoptions as printoptions,
promote_types as promote_types,
put as put,
put_along_axis as put_along_axis,
ravel as ravel,
ravel_multi_index as ravel_multi_index,
repeat as repeat,

View File

@ -742,6 +742,8 @@ def ptp(a: ArrayLike, axis: _Axis = ..., out: None = ...,
keepdims: builtins.bool = ...) -> Array: ...
def put(a: ArrayLike, ind: ArrayLike, v: ArrayLike,
mode: str | None = ..., *, inplace: builtins.bool = ...) -> Array: ...
def put_along_axis(arr: ArrayLike, indices: ArrayLike, values: ArrayLike,
axis: int | None, inplace: bool = True, *, mode: str | None = None) -> Array: ...
def quantile(a: ArrayLike, q: ArrayLike, axis: int | tuple[int, ...] | None = ...,
out: None = ..., overwrite_input: builtins.bool = ..., method: str = ...,
keepdims: builtins.bool = ..., *, interpolation: DeprecatedArg | str = ...) -> Array: ...

View File

@ -51,7 +51,7 @@ from jax._src import deprecations
from jax._src import dtypes
from jax._src import test_util as jtu
from jax._src.lax import lax as lax_internal
from jax._src.util import safe_zip, NumpyComplexWarning
from jax._src.util import safe_zip, NumpyComplexWarning, tuple_replace
config.parse_flags_with_absl()
@ -5962,6 +5962,45 @@ class LaxBackedNumpyTests(jtu.JaxTestCase):
self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker)
self._CompileAndCheck(jnp_fun, args_maker)
@jtu.sample_product(
[
dict(a_shape=a_shape, i_shape=i_shape, v_shape=v_shape, axis=axis)
for a_shape in nonempty_array_shapes
for axis in list(range(-len(a_shape), len(a_shape)))
for i_shape in [tuple_replace(a_shape, axis, J) for J in range(a_shape[axis] + 1)]
for v_shape in [(), (1,), i_shape]
] + [
dict(a_shape=a_shape, i_shape=i_shape, v_shape=v_shape, axis=None)
for a_shape in nonempty_array_shapes
for i_shape in [(J,) for J in range(math.prod(a_shape) + 1)]
for v_shape in [(), (1,), i_shape]
],
dtype=jtu.dtypes.all,
mode=[None, "promise_in_bounds", "clip"],
)
def testPutAlongAxis(self, a_shape, i_shape, v_shape, axis, dtype, mode):
a_rng = jtu.rand_default(self.rng())
if axis is None:
size = math.prod(a_shape)
else:
size = a_shape[axis]
i_rng = jtu.rand_indices_unique_along_axis(self.rng())
def args_maker():
a = a_rng(a_shape, dtype)
i = i_rng(dim=size, shape=i_shape, axis=0 if axis is None else axis)
v = a_rng(v_shape, dtype)
return a, i, v
def np_fun(a, i, v):
a_copy = a.copy()
np.put_along_axis(a_copy, i, v, axis=axis)
return a_copy
jnp_fun = partial(jnp.put_along_axis, axis=axis, inplace=False, mode=mode)
self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker)
self._CompileAndCheck(jnp_fun, args_maker)
def test_rot90_error(self):
with self.assertRaisesRegex(
ValueError,
@ -6229,7 +6268,6 @@ class NumpySignaturesTest(jtu.JaxTestCase):
'nditer',
'nested_iters',
'poly1d',
'put_along_axis',
'putmask',
'real_if_close',
'recarray',