mirror of
https://github.com/ROCm/jax.git
synced 2025-04-19 05:16:06 +00:00
diag_indices_docstring_added
see_also_diagonal_added
This commit is contained in:
parent
1ac2085417
commit
a11ddfd4bc
@ -5236,8 +5236,31 @@ def fill_diagonal(a: ArrayLike, val: ArrayLike, wrap: bool = False, *, inplace:
|
||||
return a.at[idx].set(val if val.ndim == 0 else _tile_to_size(val.ravel(), n))
|
||||
|
||||
|
||||
@util.implements(np.diag_indices)
|
||||
def diag_indices(n: int, ndim: int = 2) -> tuple[Array, ...]:
|
||||
"""Return indices for accessing the main diagonal of a multidimensional array.
|
||||
|
||||
JAX implementation of :func:`numpy.diag_indices`.
|
||||
|
||||
Args:
|
||||
n: int. The size of each dimension of the square array.
|
||||
ndim: optional, int, default=2. The number of dimensions of the array.
|
||||
|
||||
Returns:
|
||||
A tuple of arrays, each of length `n`, containing the indices to access
|
||||
the main diagonal.
|
||||
|
||||
See also:
|
||||
- :func:`jax.numpy.diag_indices_from`
|
||||
- :func:`jax.numpy.diagonal`
|
||||
|
||||
Examples:
|
||||
>>> jnp.diag_indices(3)
|
||||
(Array([0, 1, 2], dtype=int32), Array([0, 1, 2], dtype=int32))
|
||||
>>> jnp.diag_indices(4, ndim=3)
|
||||
(Array([0, 1, 2, 3], dtype=int32),
|
||||
Array([0, 1, 2, 3], dtype=int32),
|
||||
Array([0, 1, 2, 3], dtype=int32))
|
||||
"""
|
||||
n = core.concrete_or_error(operator.index, n, "'n' argument of jnp.diag_indices()")
|
||||
ndim = core.concrete_or_error(operator.index, ndim, "'ndim' argument of jnp.diag_indices()")
|
||||
if n < 0:
|
||||
@ -5248,8 +5271,36 @@ def diag_indices(n: int, ndim: int = 2) -> tuple[Array, ...]:
|
||||
.format(ndim))
|
||||
return (lax.iota(int_, n),) * ndim
|
||||
|
||||
@util.implements(np.diag_indices_from)
|
||||
def diag_indices_from(arr: ArrayLike) -> tuple[Array, ...]:
|
||||
"""Return indices for accessing the main diagonal of a given array.
|
||||
|
||||
JAX implementation of :func:`numpy.diag_indices_from`.
|
||||
|
||||
Args:
|
||||
arr: Input array. Must be at least 2-dimensional and have equal length along
|
||||
all dimensions.
|
||||
|
||||
Returns:
|
||||
A tuple of arrays containing the indices to access the main diagonal of
|
||||
the input array.
|
||||
|
||||
See also:
|
||||
- :func:`jax.numpy.diag_indices`
|
||||
- :func:`jax.numpy.diagonal`
|
||||
|
||||
Examples:
|
||||
>>> arr = jnp.array([[1, 2, 3],
|
||||
... [4, 5, 6],
|
||||
... [7, 8, 9]])
|
||||
>>> jnp.diag_indices_from(arr)
|
||||
(Array([0, 1, 2], dtype=int32), Array([0, 1, 2], dtype=int32))
|
||||
>>> arr = jnp.array([[[1, 2], [3, 4]],
|
||||
... [[5, 6], [7, 8]]])
|
||||
>>> jnp.diag_indices_from(arr)
|
||||
(Array([0, 1], dtype=int32),
|
||||
Array([0, 1], dtype=int32),
|
||||
Array([0, 1], dtype=int32))
|
||||
"""
|
||||
util.check_arraylike("diag_indices_from", arr)
|
||||
nd = ndim(arr)
|
||||
if not ndim(arr) >= 2:
|
||||
|
Loading…
x
Reference in New Issue
Block a user