mirror of
https://github.com/ROCm/jax.git
synced 2025-04-16 03:46:06 +00:00
Add numpy.put_along_axis.
This commit is contained in:
parent
c4a0369f5c
commit
1f114b1cf7
@ -48,6 +48,7 @@ When releasing, please add the new-release-boilerplate to docs/pallas/CHANGELOG.
|
||||
* {func}`jax.tree_util.register_dataclass` now allows metadata fields to be
|
||||
declared inline via {func}`dataclasses.field`. See the function documentation
|
||||
for examples.
|
||||
* Added {func}`jax.numpy.put_along_axis`.
|
||||
|
||||
* Bug fixes
|
||||
* Fixed a bug where the GPU implementations of LU and QR decomposition would
|
||||
|
@ -337,6 +337,7 @@ namespace; they are listed below.
|
||||
promote_types
|
||||
ptp
|
||||
put
|
||||
put_along_axis
|
||||
quantile
|
||||
r_
|
||||
rad2deg
|
||||
|
@ -68,7 +68,7 @@ from jax._src.typing import (
|
||||
)
|
||||
from jax._src.util import (
|
||||
NumpyComplexWarning, canonicalize_axis as _canonicalize_axis,
|
||||
ceil_of_ratio, partition_list, safe_zip, subvals,unzip2)
|
||||
ceil_of_ratio, partition_list, safe_zip, subvals,unzip2, tuple_replace)
|
||||
from jax.sharding import (Sharding, SingleDeviceSharding, NamedSharding,
|
||||
PartitionSpec as P)
|
||||
from jax.tree_util import tree_flatten, tree_leaves, tree_map
|
||||
@ -11433,6 +11433,105 @@ def take_along_axis(
|
||||
mode="fill" if mode is None else mode, fill_value=fill_value)
|
||||
|
||||
|
||||
_indices = indices # argument below named 'indices' shadows the function
|
||||
|
||||
|
||||
def _make_along_axis_idx(shape, indices, axis):
|
||||
return tuple_replace(_indices(shape, sparse=True), axis, indices)
|
||||
|
||||
|
||||
@partial(jit, static_argnames=('axis', 'inplace', 'mode'))
|
||||
def put_along_axis(
|
||||
arr: ArrayLike,
|
||||
indices: ArrayLike,
|
||||
values: ArrayLike,
|
||||
axis: int | None,
|
||||
inplace: bool = True,
|
||||
*,
|
||||
mode: str | None = None,
|
||||
) -> Array:
|
||||
"""Put values into the destination array by matching 1d index and data slices.
|
||||
|
||||
JAX implementation of :func:`numpy.put_along_axis`.
|
||||
|
||||
The semantics of :func:`numpy.put_along_axis` are to modify arrays in-place, which
|
||||
is not possible for JAX's immutable arrays. The JAX version returns a modified
|
||||
copy of the input, and adds the ``inplace`` parameter which must be set to
|
||||
`False`` by the user as a reminder of this API difference.
|
||||
|
||||
Args:
|
||||
arr: array into which values will be put.
|
||||
indices: array of indices at which to put values.
|
||||
values: array of values to put into the array.
|
||||
axis: the axis along which to put values. If not specified, the array will
|
||||
be flattened before indexing is applied.
|
||||
inplace: must be set to False to indicate that the input is not modified
|
||||
in-place, but rather a modified copy is returned.
|
||||
mode: Out-of-bounds indexing mode. For more discussion of ``mode`` options,
|
||||
see :attr:`jax.numpy.ndarray.at`.
|
||||
|
||||
Returns:
|
||||
A copy of ``a`` with specified entries updated.
|
||||
|
||||
See Also:
|
||||
- :func:`jax.numpy.put`: put elements into an array at given indices.
|
||||
- :func:`jax.numpy.place`: place elements into an array via boolean mask.
|
||||
- :func:`jax.numpy.ndarray.at`: array updates using NumPy-style indexing.
|
||||
- :func:`jax.numpy.take`: extract values from an array at given indices.
|
||||
- :func:`jax.numpy.take_along_axis`: extract values from an array along an axis.
|
||||
|
||||
Examples:
|
||||
>>> from jax import numpy as jnp
|
||||
>>> a = jnp.array([[10, 30, 20], [60, 40, 50]])
|
||||
>>> i = jnp.argmax(a, axis=1, keepdims=True)
|
||||
>>> print(i)
|
||||
[[1]
|
||||
[0]]
|
||||
>>> b = jnp.put_along_axis(a, i, 99, axis=1, inplace=False)
|
||||
>>> print(b)
|
||||
[[10 99 20]
|
||||
[99 40 50]]
|
||||
"""
|
||||
if inplace:
|
||||
raise ValueError(
|
||||
"jax.numpy.put_along_axis cannot modify arrays in-place, because JAX arrays"
|
||||
"are immutable. Pass inplace=False to instead return an updated array.")
|
||||
|
||||
util.check_arraylike("put_along_axis", arr, indices, values)
|
||||
arr = asarray(arr)
|
||||
indices = asarray(indices)
|
||||
values = asarray(values)
|
||||
|
||||
original_axis = axis
|
||||
original_arr_shape = arr.shape
|
||||
|
||||
if axis is None:
|
||||
arr = arr.ravel()
|
||||
axis = 0
|
||||
|
||||
if not arr.ndim == indices.ndim:
|
||||
raise ValueError(
|
||||
"put_along_axis arguments 'arr' and 'indices' must have same ndim. Got "
|
||||
f"{arr.ndim=} and {indices.ndim=}."
|
||||
)
|
||||
|
||||
try:
|
||||
values = broadcast_to(values, indices.shape)
|
||||
except ValueError:
|
||||
raise ValueError(
|
||||
"put_along_axis argument 'values' must be broadcastable to 'indices'. Got "
|
||||
f"{values.shape=} and {indices.shape=}."
|
||||
)
|
||||
|
||||
idx = _make_along_axis_idx(arr.shape, indices, axis)
|
||||
result = arr.at[idx].set(values, mode=mode)
|
||||
|
||||
if original_axis is None:
|
||||
result = result.reshape(original_arr_shape)
|
||||
|
||||
return result
|
||||
|
||||
|
||||
### Indexing
|
||||
|
||||
def _is_integer_index(idx: Any) -> bool:
|
||||
|
@ -965,6 +965,31 @@ def rand_unique_int(rng, high=None):
|
||||
size=shape, replace=False)
|
||||
return fn
|
||||
|
||||
def rand_indices_unique_along_axis(rng):
|
||||
"""Sample an array of given shape containing indices up to dim (exclusive),
|
||||
such that the indices are unique along the given axis.
|
||||
Optionally, convert some of the resulting indices to negative indices."""
|
||||
def fn(dim, shape, axis, allow_negative=True):
|
||||
batch_size = math.prod(shape[:axis] + shape[axis:][1:])
|
||||
idx = [
|
||||
rng.choice(dim, size=shape[axis], replace=False)
|
||||
for _ in range(batch_size)
|
||||
]
|
||||
idx = np.array(idx).reshape(batch_size, shape[axis])
|
||||
idx = idx.reshape(shape[:axis] + shape[axis:][1:] + (shape[axis],))
|
||||
idx = np.moveaxis(idx, -1, axis)
|
||||
|
||||
# assert that indices are unique along the given axis
|
||||
count = partial(np.bincount, minlength=dim)
|
||||
assert (np.apply_along_axis(count, axis, idx) <= 1).all()
|
||||
|
||||
if allow_negative:
|
||||
mask = rng.choice([False, True], idx.shape)
|
||||
idx[mask] -= dim
|
||||
return idx
|
||||
|
||||
return fn
|
||||
|
||||
def rand_bool(rng):
|
||||
def generator(shape, dtype):
|
||||
return _cast_to_shape(
|
||||
|
@ -453,6 +453,10 @@ def tuple_update(t, idx, val):
|
||||
assert 0 <= idx < len(t), (idx, len(t))
|
||||
return t[:idx] + (val,) + t[idx+1:]
|
||||
|
||||
def tuple_replace(tupl, index, item):
|
||||
# unlike tuple_update, works with negative indices as well
|
||||
return tupl[:index] + (item,) + tupl[index:][1:]
|
||||
|
||||
class HashableFunction:
|
||||
"""Decouples function equality and hash from its identity.
|
||||
|
||||
|
@ -202,6 +202,7 @@ from jax._src.numpy.lax_numpy import (
|
||||
printoptions as printoptions,
|
||||
promote_types as promote_types,
|
||||
put as put,
|
||||
put_along_axis as put_along_axis,
|
||||
ravel as ravel,
|
||||
ravel_multi_index as ravel_multi_index,
|
||||
repeat as repeat,
|
||||
|
@ -742,6 +742,8 @@ def ptp(a: ArrayLike, axis: _Axis = ..., out: None = ...,
|
||||
keepdims: builtins.bool = ...) -> Array: ...
|
||||
def put(a: ArrayLike, ind: ArrayLike, v: ArrayLike,
|
||||
mode: str | None = ..., *, inplace: builtins.bool = ...) -> Array: ...
|
||||
def put_along_axis(arr: ArrayLike, indices: ArrayLike, values: ArrayLike,
|
||||
axis: int | None, inplace: bool = True, *, mode: str | None = None) -> Array: ...
|
||||
def quantile(a: ArrayLike, q: ArrayLike, axis: int | tuple[int, ...] | None = ...,
|
||||
out: None = ..., overwrite_input: builtins.bool = ..., method: str = ...,
|
||||
keepdims: builtins.bool = ..., *, interpolation: DeprecatedArg | str = ...) -> Array: ...
|
||||
|
@ -51,7 +51,7 @@ from jax._src import deprecations
|
||||
from jax._src import dtypes
|
||||
from jax._src import test_util as jtu
|
||||
from jax._src.lax import lax as lax_internal
|
||||
from jax._src.util import safe_zip, NumpyComplexWarning
|
||||
from jax._src.util import safe_zip, NumpyComplexWarning, tuple_replace
|
||||
|
||||
config.parse_flags_with_absl()
|
||||
|
||||
@ -5962,6 +5962,45 @@ class LaxBackedNumpyTests(jtu.JaxTestCase):
|
||||
self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker)
|
||||
self._CompileAndCheck(jnp_fun, args_maker)
|
||||
|
||||
@jtu.sample_product(
|
||||
[
|
||||
dict(a_shape=a_shape, i_shape=i_shape, v_shape=v_shape, axis=axis)
|
||||
for a_shape in nonempty_array_shapes
|
||||
for axis in list(range(-len(a_shape), len(a_shape)))
|
||||
for i_shape in [tuple_replace(a_shape, axis, J) for J in range(a_shape[axis] + 1)]
|
||||
for v_shape in [(), (1,), i_shape]
|
||||
] + [
|
||||
dict(a_shape=a_shape, i_shape=i_shape, v_shape=v_shape, axis=None)
|
||||
for a_shape in nonempty_array_shapes
|
||||
for i_shape in [(J,) for J in range(math.prod(a_shape) + 1)]
|
||||
for v_shape in [(), (1,), i_shape]
|
||||
],
|
||||
dtype=jtu.dtypes.all,
|
||||
mode=[None, "promise_in_bounds", "clip"],
|
||||
)
|
||||
def testPutAlongAxis(self, a_shape, i_shape, v_shape, axis, dtype, mode):
|
||||
a_rng = jtu.rand_default(self.rng())
|
||||
if axis is None:
|
||||
size = math.prod(a_shape)
|
||||
else:
|
||||
size = a_shape[axis]
|
||||
i_rng = jtu.rand_indices_unique_along_axis(self.rng())
|
||||
|
||||
def args_maker():
|
||||
a = a_rng(a_shape, dtype)
|
||||
i = i_rng(dim=size, shape=i_shape, axis=0 if axis is None else axis)
|
||||
v = a_rng(v_shape, dtype)
|
||||
return a, i, v
|
||||
|
||||
def np_fun(a, i, v):
|
||||
a_copy = a.copy()
|
||||
np.put_along_axis(a_copy, i, v, axis=axis)
|
||||
return a_copy
|
||||
|
||||
jnp_fun = partial(jnp.put_along_axis, axis=axis, inplace=False, mode=mode)
|
||||
self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker)
|
||||
self._CompileAndCheck(jnp_fun, args_maker)
|
||||
|
||||
def test_rot90_error(self):
|
||||
with self.assertRaisesRegex(
|
||||
ValueError,
|
||||
@ -6229,7 +6268,6 @@ class NumpySignaturesTest(jtu.JaxTestCase):
|
||||
'nditer',
|
||||
'nested_iters',
|
||||
'poly1d',
|
||||
'put_along_axis',
|
||||
'putmask',
|
||||
'real_if_close',
|
||||
'recarray',
|
||||
|
Loading…
x
Reference in New Issue
Block a user