Better doc for jnp.trace

This commit is contained in:
rajasekharporeddy 2024-09-26 09:39:18 +05:30
parent 911acf1bbf
commit 8ffeb2388a

View File

@ -6652,10 +6652,53 @@ def triu(m: ArrayLike, k: int = 0) -> Array:
return lax.select(lax.broadcast(mask, m_shape[:-2]), zeros_like(m), m)
@util.implements(np.trace, skip_params=['out'])
@partial(jit, static_argnames=('axis1', 'axis2', 'dtype'))
def trace(a: ArrayLike, offset: int | ArrayLike = 0, axis1: int = 0, axis2: int = 1,
dtype: DTypeLike | None = None, out: None = None) -> Array:
"""Calculate sum of the diagonal of input along the given axes.
JAX implementation of :func:`numpy.trace`.
Args:
a: input array. Must have ``a.ndim >= 2``.
offset: optional, int, default=0. Diagonal offset from the main diagonal.
Can be positive or negative.
axis1: optional, default=0. The first axis along which to take the sum of
diagonal. Must be a static integer value.
axis2: optional, default=1. The second axis along which to take the sum of
diagonal. Must be a static integer value.
dtype: optional. The dtype of the output array. Should be provided as static
argument in JIT compilation.
out: Not used by JAX.
Returns:
An array of dimension x.ndim-2 containing the sum of the diagonal elements
along axes (axis1, axis2)
See also:
- :func:`jax.numpy.diag`: Returns the specified diagonal or constructs a diagonal
array
- :func:`jax.numpy.diagonal`: Returns the specified diagonal of an array.
- :func:`jax.numpy.diagflat`: Returns a 2-D array with the flattened input array
laid out on the diagonal.
Examples:
>>> x = jnp.arange(1, 9).reshape(2, 2, 2)
>>> x
Array([[[1, 2],
[3, 4]],
<BLANKLINE>
[[5, 6],
[7, 8]]], dtype=int32)
>>> jnp.trace(x)
Array([ 8, 10], dtype=int32)
>>> jnp.trace(x, offset=1)
Array([3, 4], dtype=int32)
>>> jnp.trace(x, axis1=1, axis2=2)
Array([ 5, 13], dtype=int32)
>>> jnp.trace(x, offset=1, axis1=1, axis2=2)
Array([2, 6], dtype=int32)
"""
util.check_arraylike("trace", a)
if out is not None:
raise NotImplementedError("The 'out' argument to jnp.trace is not supported.")