Fixes tril/triu comments (they were flipped)

PiperOrigin-RevId: 712544847
This commit is contained in:
Mark Sandler 2025-01-06 08:54:26 -08:00 committed by jax authors
parent e87a2a5929
commit 6c87bf389f

View File

@ -7667,13 +7667,14 @@ def tril(m: ArrayLike, k: int = 0) -> Array:
to sub-diagonal above the main diagonal.
Returns:
An array with same shape as input containing the upper triangle of the given
array with elements below the sub-diagonal specified by ``k`` are set to zero.
An array with same shape as input containing the lower triangle of the given
array with elements above the sub-diagonal specified by ``k`` are set to
zero.
See also:
- :func:`jax.numpy.triu`: Returns an upper triangle of an array.
- :func:`jax.numpy.tri`: Returns an array with ones on and below the diagonal
and zeros elsewhere.
- :func:`jax.numpy.tri`: Returns an array with ones on and below the
diagonal and zeros elsewhere.
Examples:
>>> x = jnp.array([[1, 2, 3, 4],
@ -7729,13 +7730,14 @@ def triu(m: ArrayLike, k: int = 0) -> Array:
to sub-diagonal above the main diagonal.
Returns:
An array with same shape as input containing the lower triangle of the given
array with elements above the sub-diagonal specified by ``k`` are set to zero.
An array with same shape as input containing the upper triangle of the given
array with elements below the sub-diagonal specified by ``k`` are set to
zero.
See also:
- :func:`jax.numpy.tril`: Returns a lower triangle of an array.
- :func:`jax.numpy.tri`: Returns an array with ones on and below the diagonal
and zeros elsewhere.
- :func:`jax.numpy.tri`: Returns an array with ones on and below the
diagonal and zeros elsewhere.
Examples:
>>> x = jnp.array([[1, 2, 3],