mirror of
https://github.com/ROCm/jax.git
synced 2025-04-19 05:16:06 +00:00
Better doc for jnp.trace
This commit is contained in:
parent
911acf1bbf
commit
8ffeb2388a
@ -6652,10 +6652,53 @@ def triu(m: ArrayLike, k: int = 0) -> Array:
|
||||
return lax.select(lax.broadcast(mask, m_shape[:-2]), zeros_like(m), m)
|
||||
|
||||
|
||||
@util.implements(np.trace, skip_params=['out'])
|
||||
@partial(jit, static_argnames=('axis1', 'axis2', 'dtype'))
|
||||
def trace(a: ArrayLike, offset: int | ArrayLike = 0, axis1: int = 0, axis2: int = 1,
|
||||
dtype: DTypeLike | None = None, out: None = None) -> Array:
|
||||
"""Calculate sum of the diagonal of input along the given axes.
|
||||
|
||||
JAX implementation of :func:`numpy.trace`.
|
||||
|
||||
Args:
|
||||
a: input array. Must have ``a.ndim >= 2``.
|
||||
offset: optional, int, default=0. Diagonal offset from the main diagonal.
|
||||
Can be positive or negative.
|
||||
axis1: optional, default=0. The first axis along which to take the sum of
|
||||
diagonal. Must be a static integer value.
|
||||
axis2: optional, default=1. The second axis along which to take the sum of
|
||||
diagonal. Must be a static integer value.
|
||||
dtype: optional. The dtype of the output array. Should be provided as static
|
||||
argument in JIT compilation.
|
||||
out: Not used by JAX.
|
||||
|
||||
Returns:
|
||||
An array of dimension x.ndim-2 containing the sum of the diagonal elements
|
||||
along axes (axis1, axis2)
|
||||
|
||||
See also:
|
||||
- :func:`jax.numpy.diag`: Returns the specified diagonal or constructs a diagonal
|
||||
array
|
||||
- :func:`jax.numpy.diagonal`: Returns the specified diagonal of an array.
|
||||
- :func:`jax.numpy.diagflat`: Returns a 2-D array with the flattened input array
|
||||
laid out on the diagonal.
|
||||
|
||||
Examples:
|
||||
>>> x = jnp.arange(1, 9).reshape(2, 2, 2)
|
||||
>>> x
|
||||
Array([[[1, 2],
|
||||
[3, 4]],
|
||||
<BLANKLINE>
|
||||
[[5, 6],
|
||||
[7, 8]]], dtype=int32)
|
||||
>>> jnp.trace(x)
|
||||
Array([ 8, 10], dtype=int32)
|
||||
>>> jnp.trace(x, offset=1)
|
||||
Array([3, 4], dtype=int32)
|
||||
>>> jnp.trace(x, axis1=1, axis2=2)
|
||||
Array([ 5, 13], dtype=int32)
|
||||
>>> jnp.trace(x, offset=1, axis1=1, axis2=2)
|
||||
Array([2, 6], dtype=int32)
|
||||
"""
|
||||
util.check_arraylike("trace", a)
|
||||
if out is not None:
|
||||
raise NotImplementedError("The 'out' argument to jnp.trace is not supported.")
|
||||
|
Loading…
x
Reference in New Issue
Block a user