Improved docs for jnp.tri, tril and triu

This commit is contained in:
rajasekharporeddy 2024-07-23 20:58:07 +05:30
parent a18872aa13
commit f90d0ee014

View File

@ -4710,17 +4710,114 @@ def trapezoid(y: ArrayLike, x: ArrayLike | None = None, dx: ArrayLike = 1.0,
return 0.5 * (dx_array * (y_arr[..., 1:] + y_arr[..., :-1])).sum(-1)
@util.implements(np.tri)
def tri(N: int, M: int | None = None, k: int = 0, dtype: DTypeLike | None = None) -> Array:
r"""Return an array with ones on and below the diagonal and zeros elsewhere.
JAX implementation of :func:`numpy.tri`
Args:
N: int. Dimension of the rows of the returned array.
M: optional, int. Dimension of the columns of the returned array. If not
specified, then ``M = N``.
k: optional, int, default=0. Specifies the sub-diagonal on and below which
the array is filled with ones. ``k=0`` refers to main diagonal, ``k<0``
refers to sub-diagonal below the main diagonal and ``k>0`` refers to
sub-diagonal above the main diagonal.
dtype: optional, data type of the returned array. The default type is float.
Returns:
An array of shape ``(N, M)`` containing the lower triangle with elements
below the sub-diagonal specified by ``k`` are set to one and zero elsewhere.
See also:
- :func:`jax.numpy.tril`: Returns a lower triangle of an array.
- :func:`jax.numpy.triu`: Returns an upper triangle of an array.
Examples:
>>> jnp.tri(3)
Array([[1., 0., 0.],
[1., 1., 0.],
[1., 1., 1.]], dtype=float32)
When ``M`` is not equal to ``N``:
>>> jnp.tri(3, 4)
Array([[1., 0., 0., 0.],
[1., 1., 0., 0.],
[1., 1., 1., 0.]], dtype=float32)
when ``k>0``:
>>> jnp.tri(3, k=1)
Array([[1., 1., 0.],
[1., 1., 1.],
[1., 1., 1.]], dtype=float32)
When ``k<0``:
>>> jnp.tri(3, 4, k=-1)
Array([[0., 0., 0., 0.],
[1., 0., 0., 0.],
[1., 1., 0., 0.]], dtype=float32)
"""
dtypes.check_user_dtype_supported(dtype, "tri")
M = M if M is not None else N
dtype = dtype or float32
return lax_internal._tri(dtype, (N, M), k)
@util.implements(np.tril)
@partial(jit, static_argnames=('k',))
def tril(m: ArrayLike, k: int = 0) -> Array:
r"""Return lower triangle of an array.
JAX implementation of :func:`numpy.tril`
Args:
m: input array. Must have ``m.ndim >= 2``.
k: k: optional, int, default=0. Specifies the sub-diagonal above which the
elements of the array are set to zero. ``k=0`` refers to main diagonal,
``k<0`` refers to sub-diagonal below the main diagonal and ``k>0`` refers
to sub-diagonal above the main diagonal.
Returns:
An array with same shape as input containing the upper triangle of the given
array with elements below the sub-diagonal specified by ``k`` are set to zero.
See also:
- :func:`jax.numpy.triu`: Returns an upper triangle of an array.
- :func:`jax.numpy.tri`: Returns an array with ones on and below the diagonal
and zeros elsewhere.
Examples:
>>> x = jnp.array([[1, 2, 3, 4],
... [5, 6, 7, 8],
... [9, 10, 11, 12]])
>>> jnp.tril(x)
Array([[ 1, 0, 0, 0],
[ 5, 6, 0, 0],
[ 9, 10, 11, 0]], dtype=int32)
>>> jnp.tril(x, k=1)
Array([[ 1, 2, 0, 0],
[ 5, 6, 7, 0],
[ 9, 10, 11, 12]], dtype=int32)
>>> jnp.tril(x, k=-1)
Array([[ 0, 0, 0, 0],
[ 5, 0, 0, 0],
[ 9, 10, 0, 0]], dtype=int32)
When ``m.ndim > 2``, ``jnp.tril`` operates batch-wise on the trailing axes.
>>> x1 = jnp.array([[[1, 2],
... [3, 4]],
... [[5, 6],
... [7, 8]]])
>>> jnp.tril(x1)
Array([[[1, 0],
[3, 4]],
<BLANKLINE>
[[5, 0],
[7, 8]]], dtype=int32)
"""
util.check_arraylike("tril", m)
m_shape = shape(m)
if len(m_shape) < 2:
@ -4730,9 +4827,62 @@ def tril(m: ArrayLike, k: int = 0) -> Array:
return lax.select(lax.broadcast(mask, m_shape[:-2]), m, zeros_like(m))
@util.implements(np.triu, update_doc=False)
@partial(jit, static_argnames=('k',))
def triu(m: ArrayLike, k: int = 0) -> Array:
r"""Return upper triangle of an array.
JAX implementation of :func:`numpy.triu`
Args:
m: input array. Must have ``m.ndim >= 2``.
k: k: optional, int, default=0. Specifies the sub-diagonal below which the
elements of the array are set to zero. ``k=0`` refers to main diagonal,
``k<0`` refers to sub-diagonal below the main diagonal and ``k>0`` refers
to sub-diagonal above the main diagonal.
Returns:
An array with same shape as input containing the lower triangle of the given
array with elements above the sub-diagonal specified by ``k`` are set to zero.
See also:
- :func:`jax.numpy.tril`: Returns a lower triangle of an array.
- :func:`jax.numpy.tri`: Returns an array with ones on and below the diagonal
and zeros elsewhere.
Examples:
>>> x = jnp.array([[1, 2, 3],
... [4, 5, 6],
... [7, 8, 9],
... [10, 11, 12]])
>>> jnp.triu(x)
Array([[1, 2, 3],
[0, 5, 6],
[0, 0, 9],
[0, 0, 0]], dtype=int32)
>>> jnp.triu(x, k=1)
Array([[0, 2, 3],
[0, 0, 6],
[0, 0, 0],
[0, 0, 0]], dtype=int32)
>>> jnp.triu(x, k=-1)
Array([[ 1, 2, 3],
[ 4, 5, 6],
[ 0, 8, 9],
[ 0, 0, 12]], dtype=int32)
When ``m.ndim > 2``, ``jnp.triu`` operates batch-wise on the trailing axes.
>>> x1 = jnp.array([[[1, 2],
... [3, 4]],
... [[5, 6],
... [7, 8]]])
>>> jnp.triu(x1)
Array([[[1, 2],
[0, 4]],
<BLANKLINE>
[[5, 6],
[0, 8]]], dtype=int32)
"""
util.check_arraylike("triu", m)
m_shape = shape(m)
if len(m_shape) < 2: