mirror of
https://github.com/ROCm/jax.git
synced 2025-04-16 03:46:06 +00:00
Add numpy.put_along_axis.
This commit is contained in:
parent
c4a0369f5c
commit
1f114b1cf7
@ -48,6 +48,7 @@ When releasing, please add the new-release-boilerplate to docs/pallas/CHANGELOG.
|
|||||||
* {func}`jax.tree_util.register_dataclass` now allows metadata fields to be
|
* {func}`jax.tree_util.register_dataclass` now allows metadata fields to be
|
||||||
declared inline via {func}`dataclasses.field`. See the function documentation
|
declared inline via {func}`dataclasses.field`. See the function documentation
|
||||||
for examples.
|
for examples.
|
||||||
|
* Added {func}`jax.numpy.put_along_axis`.
|
||||||
|
|
||||||
* Bug fixes
|
* Bug fixes
|
||||||
* Fixed a bug where the GPU implementations of LU and QR decomposition would
|
* Fixed a bug where the GPU implementations of LU and QR decomposition would
|
||||||
|
@ -337,6 +337,7 @@ namespace; they are listed below.
|
|||||||
promote_types
|
promote_types
|
||||||
ptp
|
ptp
|
||||||
put
|
put
|
||||||
|
put_along_axis
|
||||||
quantile
|
quantile
|
||||||
r_
|
r_
|
||||||
rad2deg
|
rad2deg
|
||||||
|
@ -68,7 +68,7 @@ from jax._src.typing import (
|
|||||||
)
|
)
|
||||||
from jax._src.util import (
|
from jax._src.util import (
|
||||||
NumpyComplexWarning, canonicalize_axis as _canonicalize_axis,
|
NumpyComplexWarning, canonicalize_axis as _canonicalize_axis,
|
||||||
ceil_of_ratio, partition_list, safe_zip, subvals,unzip2)
|
ceil_of_ratio, partition_list, safe_zip, subvals,unzip2, tuple_replace)
|
||||||
from jax.sharding import (Sharding, SingleDeviceSharding, NamedSharding,
|
from jax.sharding import (Sharding, SingleDeviceSharding, NamedSharding,
|
||||||
PartitionSpec as P)
|
PartitionSpec as P)
|
||||||
from jax.tree_util import tree_flatten, tree_leaves, tree_map
|
from jax.tree_util import tree_flatten, tree_leaves, tree_map
|
||||||
@ -11433,6 +11433,105 @@ def take_along_axis(
|
|||||||
mode="fill" if mode is None else mode, fill_value=fill_value)
|
mode="fill" if mode is None else mode, fill_value=fill_value)
|
||||||
|
|
||||||
|
|
||||||
|
_indices = indices # argument below named 'indices' shadows the function
|
||||||
|
|
||||||
|
|
||||||
|
def _make_along_axis_idx(shape, indices, axis):
|
||||||
|
return tuple_replace(_indices(shape, sparse=True), axis, indices)
|
||||||
|
|
||||||
|
|
||||||
|
@partial(jit, static_argnames=('axis', 'inplace', 'mode'))
|
||||||
|
def put_along_axis(
|
||||||
|
arr: ArrayLike,
|
||||||
|
indices: ArrayLike,
|
||||||
|
values: ArrayLike,
|
||||||
|
axis: int | None,
|
||||||
|
inplace: bool = True,
|
||||||
|
*,
|
||||||
|
mode: str | None = None,
|
||||||
|
) -> Array:
|
||||||
|
"""Put values into the destination array by matching 1d index and data slices.
|
||||||
|
|
||||||
|
JAX implementation of :func:`numpy.put_along_axis`.
|
||||||
|
|
||||||
|
The semantics of :func:`numpy.put_along_axis` are to modify arrays in-place, which
|
||||||
|
is not possible for JAX's immutable arrays. The JAX version returns a modified
|
||||||
|
copy of the input, and adds the ``inplace`` parameter which must be set to
|
||||||
|
`False`` by the user as a reminder of this API difference.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
arr: array into which values will be put.
|
||||||
|
indices: array of indices at which to put values.
|
||||||
|
values: array of values to put into the array.
|
||||||
|
axis: the axis along which to put values. If not specified, the array will
|
||||||
|
be flattened before indexing is applied.
|
||||||
|
inplace: must be set to False to indicate that the input is not modified
|
||||||
|
in-place, but rather a modified copy is returned.
|
||||||
|
mode: Out-of-bounds indexing mode. For more discussion of ``mode`` options,
|
||||||
|
see :attr:`jax.numpy.ndarray.at`.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
A copy of ``a`` with specified entries updated.
|
||||||
|
|
||||||
|
See Also:
|
||||||
|
- :func:`jax.numpy.put`: put elements into an array at given indices.
|
||||||
|
- :func:`jax.numpy.place`: place elements into an array via boolean mask.
|
||||||
|
- :func:`jax.numpy.ndarray.at`: array updates using NumPy-style indexing.
|
||||||
|
- :func:`jax.numpy.take`: extract values from an array at given indices.
|
||||||
|
- :func:`jax.numpy.take_along_axis`: extract values from an array along an axis.
|
||||||
|
|
||||||
|
Examples:
|
||||||
|
>>> from jax import numpy as jnp
|
||||||
|
>>> a = jnp.array([[10, 30, 20], [60, 40, 50]])
|
||||||
|
>>> i = jnp.argmax(a, axis=1, keepdims=True)
|
||||||
|
>>> print(i)
|
||||||
|
[[1]
|
||||||
|
[0]]
|
||||||
|
>>> b = jnp.put_along_axis(a, i, 99, axis=1, inplace=False)
|
||||||
|
>>> print(b)
|
||||||
|
[[10 99 20]
|
||||||
|
[99 40 50]]
|
||||||
|
"""
|
||||||
|
if inplace:
|
||||||
|
raise ValueError(
|
||||||
|
"jax.numpy.put_along_axis cannot modify arrays in-place, because JAX arrays"
|
||||||
|
"are immutable. Pass inplace=False to instead return an updated array.")
|
||||||
|
|
||||||
|
util.check_arraylike("put_along_axis", arr, indices, values)
|
||||||
|
arr = asarray(arr)
|
||||||
|
indices = asarray(indices)
|
||||||
|
values = asarray(values)
|
||||||
|
|
||||||
|
original_axis = axis
|
||||||
|
original_arr_shape = arr.shape
|
||||||
|
|
||||||
|
if axis is None:
|
||||||
|
arr = arr.ravel()
|
||||||
|
axis = 0
|
||||||
|
|
||||||
|
if not arr.ndim == indices.ndim:
|
||||||
|
raise ValueError(
|
||||||
|
"put_along_axis arguments 'arr' and 'indices' must have same ndim. Got "
|
||||||
|
f"{arr.ndim=} and {indices.ndim=}."
|
||||||
|
)
|
||||||
|
|
||||||
|
try:
|
||||||
|
values = broadcast_to(values, indices.shape)
|
||||||
|
except ValueError:
|
||||||
|
raise ValueError(
|
||||||
|
"put_along_axis argument 'values' must be broadcastable to 'indices'. Got "
|
||||||
|
f"{values.shape=} and {indices.shape=}."
|
||||||
|
)
|
||||||
|
|
||||||
|
idx = _make_along_axis_idx(arr.shape, indices, axis)
|
||||||
|
result = arr.at[idx].set(values, mode=mode)
|
||||||
|
|
||||||
|
if original_axis is None:
|
||||||
|
result = result.reshape(original_arr_shape)
|
||||||
|
|
||||||
|
return result
|
||||||
|
|
||||||
|
|
||||||
### Indexing
|
### Indexing
|
||||||
|
|
||||||
def _is_integer_index(idx: Any) -> bool:
|
def _is_integer_index(idx: Any) -> bool:
|
||||||
|
@ -965,6 +965,31 @@ def rand_unique_int(rng, high=None):
|
|||||||
size=shape, replace=False)
|
size=shape, replace=False)
|
||||||
return fn
|
return fn
|
||||||
|
|
||||||
|
def rand_indices_unique_along_axis(rng):
|
||||||
|
"""Sample an array of given shape containing indices up to dim (exclusive),
|
||||||
|
such that the indices are unique along the given axis.
|
||||||
|
Optionally, convert some of the resulting indices to negative indices."""
|
||||||
|
def fn(dim, shape, axis, allow_negative=True):
|
||||||
|
batch_size = math.prod(shape[:axis] + shape[axis:][1:])
|
||||||
|
idx = [
|
||||||
|
rng.choice(dim, size=shape[axis], replace=False)
|
||||||
|
for _ in range(batch_size)
|
||||||
|
]
|
||||||
|
idx = np.array(idx).reshape(batch_size, shape[axis])
|
||||||
|
idx = idx.reshape(shape[:axis] + shape[axis:][1:] + (shape[axis],))
|
||||||
|
idx = np.moveaxis(idx, -1, axis)
|
||||||
|
|
||||||
|
# assert that indices are unique along the given axis
|
||||||
|
count = partial(np.bincount, minlength=dim)
|
||||||
|
assert (np.apply_along_axis(count, axis, idx) <= 1).all()
|
||||||
|
|
||||||
|
if allow_negative:
|
||||||
|
mask = rng.choice([False, True], idx.shape)
|
||||||
|
idx[mask] -= dim
|
||||||
|
return idx
|
||||||
|
|
||||||
|
return fn
|
||||||
|
|
||||||
def rand_bool(rng):
|
def rand_bool(rng):
|
||||||
def generator(shape, dtype):
|
def generator(shape, dtype):
|
||||||
return _cast_to_shape(
|
return _cast_to_shape(
|
||||||
|
@ -453,6 +453,10 @@ def tuple_update(t, idx, val):
|
|||||||
assert 0 <= idx < len(t), (idx, len(t))
|
assert 0 <= idx < len(t), (idx, len(t))
|
||||||
return t[:idx] + (val,) + t[idx+1:]
|
return t[:idx] + (val,) + t[idx+1:]
|
||||||
|
|
||||||
|
def tuple_replace(tupl, index, item):
|
||||||
|
# unlike tuple_update, works with negative indices as well
|
||||||
|
return tupl[:index] + (item,) + tupl[index:][1:]
|
||||||
|
|
||||||
class HashableFunction:
|
class HashableFunction:
|
||||||
"""Decouples function equality and hash from its identity.
|
"""Decouples function equality and hash from its identity.
|
||||||
|
|
||||||
|
@ -202,6 +202,7 @@ from jax._src.numpy.lax_numpy import (
|
|||||||
printoptions as printoptions,
|
printoptions as printoptions,
|
||||||
promote_types as promote_types,
|
promote_types as promote_types,
|
||||||
put as put,
|
put as put,
|
||||||
|
put_along_axis as put_along_axis,
|
||||||
ravel as ravel,
|
ravel as ravel,
|
||||||
ravel_multi_index as ravel_multi_index,
|
ravel_multi_index as ravel_multi_index,
|
||||||
repeat as repeat,
|
repeat as repeat,
|
||||||
|
@ -742,6 +742,8 @@ def ptp(a: ArrayLike, axis: _Axis = ..., out: None = ...,
|
|||||||
keepdims: builtins.bool = ...) -> Array: ...
|
keepdims: builtins.bool = ...) -> Array: ...
|
||||||
def put(a: ArrayLike, ind: ArrayLike, v: ArrayLike,
|
def put(a: ArrayLike, ind: ArrayLike, v: ArrayLike,
|
||||||
mode: str | None = ..., *, inplace: builtins.bool = ...) -> Array: ...
|
mode: str | None = ..., *, inplace: builtins.bool = ...) -> Array: ...
|
||||||
|
def put_along_axis(arr: ArrayLike, indices: ArrayLike, values: ArrayLike,
|
||||||
|
axis: int | None, inplace: bool = True, *, mode: str | None = None) -> Array: ...
|
||||||
def quantile(a: ArrayLike, q: ArrayLike, axis: int | tuple[int, ...] | None = ...,
|
def quantile(a: ArrayLike, q: ArrayLike, axis: int | tuple[int, ...] | None = ...,
|
||||||
out: None = ..., overwrite_input: builtins.bool = ..., method: str = ...,
|
out: None = ..., overwrite_input: builtins.bool = ..., method: str = ...,
|
||||||
keepdims: builtins.bool = ..., *, interpolation: DeprecatedArg | str = ...) -> Array: ...
|
keepdims: builtins.bool = ..., *, interpolation: DeprecatedArg | str = ...) -> Array: ...
|
||||||
|
@ -51,7 +51,7 @@ from jax._src import deprecations
|
|||||||
from jax._src import dtypes
|
from jax._src import dtypes
|
||||||
from jax._src import test_util as jtu
|
from jax._src import test_util as jtu
|
||||||
from jax._src.lax import lax as lax_internal
|
from jax._src.lax import lax as lax_internal
|
||||||
from jax._src.util import safe_zip, NumpyComplexWarning
|
from jax._src.util import safe_zip, NumpyComplexWarning, tuple_replace
|
||||||
|
|
||||||
config.parse_flags_with_absl()
|
config.parse_flags_with_absl()
|
||||||
|
|
||||||
@ -5962,6 +5962,45 @@ class LaxBackedNumpyTests(jtu.JaxTestCase):
|
|||||||
self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker)
|
self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker)
|
||||||
self._CompileAndCheck(jnp_fun, args_maker)
|
self._CompileAndCheck(jnp_fun, args_maker)
|
||||||
|
|
||||||
|
@jtu.sample_product(
|
||||||
|
[
|
||||||
|
dict(a_shape=a_shape, i_shape=i_shape, v_shape=v_shape, axis=axis)
|
||||||
|
for a_shape in nonempty_array_shapes
|
||||||
|
for axis in list(range(-len(a_shape), len(a_shape)))
|
||||||
|
for i_shape in [tuple_replace(a_shape, axis, J) for J in range(a_shape[axis] + 1)]
|
||||||
|
for v_shape in [(), (1,), i_shape]
|
||||||
|
] + [
|
||||||
|
dict(a_shape=a_shape, i_shape=i_shape, v_shape=v_shape, axis=None)
|
||||||
|
for a_shape in nonempty_array_shapes
|
||||||
|
for i_shape in [(J,) for J in range(math.prod(a_shape) + 1)]
|
||||||
|
for v_shape in [(), (1,), i_shape]
|
||||||
|
],
|
||||||
|
dtype=jtu.dtypes.all,
|
||||||
|
mode=[None, "promise_in_bounds", "clip"],
|
||||||
|
)
|
||||||
|
def testPutAlongAxis(self, a_shape, i_shape, v_shape, axis, dtype, mode):
|
||||||
|
a_rng = jtu.rand_default(self.rng())
|
||||||
|
if axis is None:
|
||||||
|
size = math.prod(a_shape)
|
||||||
|
else:
|
||||||
|
size = a_shape[axis]
|
||||||
|
i_rng = jtu.rand_indices_unique_along_axis(self.rng())
|
||||||
|
|
||||||
|
def args_maker():
|
||||||
|
a = a_rng(a_shape, dtype)
|
||||||
|
i = i_rng(dim=size, shape=i_shape, axis=0 if axis is None else axis)
|
||||||
|
v = a_rng(v_shape, dtype)
|
||||||
|
return a, i, v
|
||||||
|
|
||||||
|
def np_fun(a, i, v):
|
||||||
|
a_copy = a.copy()
|
||||||
|
np.put_along_axis(a_copy, i, v, axis=axis)
|
||||||
|
return a_copy
|
||||||
|
|
||||||
|
jnp_fun = partial(jnp.put_along_axis, axis=axis, inplace=False, mode=mode)
|
||||||
|
self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker)
|
||||||
|
self._CompileAndCheck(jnp_fun, args_maker)
|
||||||
|
|
||||||
def test_rot90_error(self):
|
def test_rot90_error(self):
|
||||||
with self.assertRaisesRegex(
|
with self.assertRaisesRegex(
|
||||||
ValueError,
|
ValueError,
|
||||||
@ -6229,7 +6268,6 @@ class NumpySignaturesTest(jtu.JaxTestCase):
|
|||||||
'nditer',
|
'nditer',
|
||||||
'nested_iters',
|
'nested_iters',
|
||||||
'poly1d',
|
'poly1d',
|
||||||
'put_along_axis',
|
|
||||||
'putmask',
|
'putmask',
|
||||||
'real_if_close',
|
'real_if_close',
|
||||||
'recarray',
|
'recarray',
|
||||||
|
Loading…
x
Reference in New Issue
Block a user