diff --git a/jax/_src/numpy/lax_numpy.py b/jax/_src/numpy/lax_numpy.py index 259c47948..f23261252 100644 --- a/jax/_src/numpy/lax_numpy.py +++ b/jax/_src/numpy/lax_numpy.py @@ -7667,13 +7667,14 @@ def tril(m: ArrayLike, k: int = 0) -> Array: to sub-diagonal above the main diagonal. Returns: - An array with same shape as input containing the upper triangle of the given - array with elements below the sub-diagonal specified by ``k`` are set to zero. + An array with same shape as input containing the lower triangle of the given + array with elements above the sub-diagonal specified by ``k`` are set to + zero. See also: - :func:`jax.numpy.triu`: Returns an upper triangle of an array. - - :func:`jax.numpy.tri`: Returns an array with ones on and below the diagonal - and zeros elsewhere. + - :func:`jax.numpy.tri`: Returns an array with ones on and below the + diagonal and zeros elsewhere. Examples: >>> x = jnp.array([[1, 2, 3, 4], @@ -7729,13 +7730,14 @@ def triu(m: ArrayLike, k: int = 0) -> Array: to sub-diagonal above the main diagonal. Returns: - An array with same shape as input containing the lower triangle of the given - array with elements above the sub-diagonal specified by ``k`` are set to zero. + An array with same shape as input containing the upper triangle of the given + array with elements below the sub-diagonal specified by ``k`` are set to + zero. See also: - :func:`jax.numpy.tril`: Returns a lower triangle of an array. - - :func:`jax.numpy.tri`: Returns an array with ones on and below the diagonal - and zeros elsewhere. + - :func:`jax.numpy.tri`: Returns an array with ones on and below the + diagonal and zeros elsewhere. Examples: >>> x = jnp.array([[1, 2, 3],