mirror of
https://github.com/ROCm/jax.git
synced 2025-04-19 05:16:06 +00:00
Better docs for jnp.apply_along_axis & apply_over_axes
This commit is contained in:
parent
1f0b5728a4
commit
0c307fe706
@ -8309,10 +8309,72 @@ def insert(arr: ArrayLike, obj: ArrayLike | slice, values: ArrayLike,
|
||||
return out
|
||||
|
||||
|
||||
@util.implements(np.apply_along_axis)
|
||||
def apply_along_axis(
|
||||
func1d: Callable, axis: int, arr: ArrayLike, *args, **kwargs
|
||||
) -> Array:
|
||||
"""Apply a function to 1D array slices along an axis.
|
||||
|
||||
JAX implementation of :func:`numpy.apply_along_axis`. While NumPy implements
|
||||
this iteratively, JAX implements this via :func:`jax.vmap`, and so ``func1d``
|
||||
must be compatible with ``vmap``.
|
||||
|
||||
Args:
|
||||
func1d: a callable function with signature ``func1d(arr, /, *args, **kwargs)``
|
||||
where ``*args`` and ``**kwargs`` are the additional positional and keyword
|
||||
arguments passed to :func:`apply_along_axis`.
|
||||
axis: integer axis along which to apply the function.
|
||||
arr: the array over which to apply the function.
|
||||
args, kwargs: additional positional and keyword arguments are passed through
|
||||
to ``func1d``.
|
||||
|
||||
Returns:
|
||||
The result of ``func1d`` applied along the specified axis.
|
||||
|
||||
See also:
|
||||
- :func:`jax.vmap`: a more direct way to create a vectorized version of a function.
|
||||
- :func:`jax.numpy.apply_over_axes`: repeatedly apply a function over multiple axes.
|
||||
- :func:`jax.numpy.vectorize`: create a vectorized version of a function.
|
||||
|
||||
Examples:
|
||||
A simple example in two dimensions, where the function is applied either row-wise
|
||||
or column-wise:
|
||||
|
||||
>>> x = jnp.array([[1, 2, 3],
|
||||
... [4, 5, 6]])
|
||||
>>> def func1d(x):
|
||||
... return jnp.sum(x ** 2)
|
||||
>>> jnp.apply_along_axis(func1d, 0, x)
|
||||
Array([17, 29, 45], dtype=int32)
|
||||
>>> jnp.apply_along_axis(func1d, 1, x)
|
||||
Array([14, 77], dtype=int32)
|
||||
|
||||
For 2D inputs, this can be equivalently expressed using :func:`jax.vmap`,
|
||||
though note that `vmap` specifies the mapped axis rather than the applied axis:
|
||||
|
||||
>>> jax.vmap(func1d, in_axes=1)(x) # same as applying along axis 0
|
||||
Array([17, 29, 45], dtype=int32)
|
||||
>>> jax.vmap(func1d, in_axes=0)(x) # same as applying along axis 1
|
||||
Array([14, 77], dtype=int32)
|
||||
|
||||
For 3D inputs, :func:`apply_along_axis` is equivalent to mapping over two
|
||||
dimensions:
|
||||
|
||||
>>> x_3d = jnp.arange(24).reshape(2, 3, 4)
|
||||
>>> jnp.apply_along_axis(func1d, 2, x_3d)
|
||||
Array([[ 14, 126, 366],
|
||||
[ 734, 1230, 1854]], dtype=int32)
|
||||
>>> jax.vmap(jax.vmap(func1d))(x_3d)
|
||||
Array([[ 14, 126, 366],
|
||||
[ 734, 1230, 1854]], dtype=int32)
|
||||
|
||||
The applied function may also take arbitrary positional or keyword arguments,
|
||||
which should be passed directly as additional arguments to :func:`apply_along_axis`:
|
||||
|
||||
>>> def func1d(x, exponent):
|
||||
... return jnp.sum(x ** exponent)
|
||||
>>> jnp.apply_along_axis(func1d, 0, x, exponent=3)
|
||||
Array([ 65, 133, 243], dtype=int32)
|
||||
"""
|
||||
util.check_arraylike("apply_along_axis", arr)
|
||||
num_dims = ndim(arr)
|
||||
axis = _canonicalize_axis(axis, num_dims)
|
||||
@ -8324,9 +8386,49 @@ def apply_along_axis(
|
||||
return func(arr)
|
||||
|
||||
|
||||
@util.implements(np.apply_over_axes)
|
||||
def apply_over_axes(func: Callable[[ArrayLike, int], Array], a: ArrayLike,
|
||||
axes: Sequence[int]) -> Array:
|
||||
"""Apply a function repeatedly over specified axes.
|
||||
|
||||
JAX implementation of :func:`numpy.apply_over_axes`.
|
||||
|
||||
Args:
|
||||
func: the function to apply, with signature ``func(Array, int) -> Array``, and
|
||||
where ``y = func(x, axis)`` must satisfy ``y.ndim in [x.ndim, x.ndim - 1]``.
|
||||
a: N-dimensional array over which to apply the function.
|
||||
axes: the sequence of axes over which to apply the function.
|
||||
|
||||
Returns:
|
||||
An N-dimensional array containing the result of the repeated function application.
|
||||
|
||||
See also:
|
||||
- :func:`jax.numpy.apply_along_axis`: apply a 1D function along a single axis.
|
||||
|
||||
Examples:
|
||||
This function is designed to have similar semantics to typical associative
|
||||
:mod:`jax.numpy` reductions over one or more axes with ``keepdims=True``.
|
||||
For example:
|
||||
|
||||
>>> x = jnp.array([[1, 2, 3],
|
||||
... [4, 5, 6]])
|
||||
|
||||
>>> jnp.apply_over_axes(jnp.sum, x, [0])
|
||||
Array([[5, 7, 9]], dtype=int32)
|
||||
>>> jnp.sum(x, [0], keepdims=True)
|
||||
Array([[5, 7, 9]], dtype=int32)
|
||||
|
||||
>>> jnp.apply_over_axes(jnp.min, x, [1])
|
||||
Array([[1],
|
||||
[4]], dtype=int32)
|
||||
>>> jnp.min(x, [1], keepdims=True)
|
||||
Array([[1],
|
||||
[4]], dtype=int32)
|
||||
|
||||
>>> jnp.apply_over_axes(jnp.prod, x, [0, 1])
|
||||
Array([[720]], dtype=int32)
|
||||
>>> jnp.prod(x, [0, 1], keepdims=True)
|
||||
Array([[720]], dtype=int32)
|
||||
"""
|
||||
util.check_arraylike("apply_over_axes", a)
|
||||
a_arr = asarray(a)
|
||||
for axis in axes:
|
||||
|
Loading…
x
Reference in New Issue
Block a user