diag_indices_docstring_added

see_also_diagonal_added
This commit is contained in:
selamw1 2024-07-31 15:36:38 -07:00
parent 1ac2085417
commit a11ddfd4bc

View File

@ -5236,8 +5236,31 @@ def fill_diagonal(a: ArrayLike, val: ArrayLike, wrap: bool = False, *, inplace:
return a.at[idx].set(val if val.ndim == 0 else _tile_to_size(val.ravel(), n))
@util.implements(np.diag_indices)
def diag_indices(n: int, ndim: int = 2) -> tuple[Array, ...]:
"""Return indices for accessing the main diagonal of a multidimensional array.
JAX implementation of :func:`numpy.diag_indices`.
Args:
n: int. The size of each dimension of the square array.
ndim: optional, int, default=2. The number of dimensions of the array.
Returns:
A tuple of arrays, each of length `n`, containing the indices to access
the main diagonal.
See also:
- :func:`jax.numpy.diag_indices_from`
- :func:`jax.numpy.diagonal`
Examples:
>>> jnp.diag_indices(3)
(Array([0, 1, 2], dtype=int32), Array([0, 1, 2], dtype=int32))
>>> jnp.diag_indices(4, ndim=3)
(Array([0, 1, 2, 3], dtype=int32),
Array([0, 1, 2, 3], dtype=int32),
Array([0, 1, 2, 3], dtype=int32))
"""
n = core.concrete_or_error(operator.index, n, "'n' argument of jnp.diag_indices()")
ndim = core.concrete_or_error(operator.index, ndim, "'ndim' argument of jnp.diag_indices()")
if n < 0:
@ -5248,8 +5271,36 @@ def diag_indices(n: int, ndim: int = 2) -> tuple[Array, ...]:
.format(ndim))
return (lax.iota(int_, n),) * ndim
@util.implements(np.diag_indices_from)
def diag_indices_from(arr: ArrayLike) -> tuple[Array, ...]:
"""Return indices for accessing the main diagonal of a given array.
JAX implementation of :func:`numpy.diag_indices_from`.
Args:
arr: Input array. Must be at least 2-dimensional and have equal length along
all dimensions.
Returns:
A tuple of arrays containing the indices to access the main diagonal of
the input array.
See also:
- :func:`jax.numpy.diag_indices`
- :func:`jax.numpy.diagonal`
Examples:
>>> arr = jnp.array([[1, 2, 3],
... [4, 5, 6],
... [7, 8, 9]])
>>> jnp.diag_indices_from(arr)
(Array([0, 1, 2], dtype=int32), Array([0, 1, 2], dtype=int32))
>>> arr = jnp.array([[[1, 2], [3, 4]],
... [[5, 6], [7, 8]]])
>>> jnp.diag_indices_from(arr)
(Array([0, 1], dtype=int32),
Array([0, 1], dtype=int32),
Array([0, 1], dtype=int32))
"""
util.check_arraylike("diag_indices_from", arr)
nd = ndim(arr)
if not ndim(arr) >= 2: