mirror of
https://github.com/ROCm/jax.git
synced 2025-04-19 05:16:06 +00:00
Improved docs for jnp.tri, tril and triu
This commit is contained in:
parent
a18872aa13
commit
f90d0ee014
@ -4710,17 +4710,114 @@ def trapezoid(y: ArrayLike, x: ArrayLike | None = None, dx: ArrayLike = 1.0,
|
||||
return 0.5 * (dx_array * (y_arr[..., 1:] + y_arr[..., :-1])).sum(-1)
|
||||
|
||||
|
||||
@util.implements(np.tri)
|
||||
def tri(N: int, M: int | None = None, k: int = 0, dtype: DTypeLike | None = None) -> Array:
|
||||
r"""Return an array with ones on and below the diagonal and zeros elsewhere.
|
||||
|
||||
JAX implementation of :func:`numpy.tri`
|
||||
|
||||
Args:
|
||||
N: int. Dimension of the rows of the returned array.
|
||||
M: optional, int. Dimension of the columns of the returned array. If not
|
||||
specified, then ``M = N``.
|
||||
k: optional, int, default=0. Specifies the sub-diagonal on and below which
|
||||
the array is filled with ones. ``k=0`` refers to main diagonal, ``k<0``
|
||||
refers to sub-diagonal below the main diagonal and ``k>0`` refers to
|
||||
sub-diagonal above the main diagonal.
|
||||
dtype: optional, data type of the returned array. The default type is float.
|
||||
|
||||
Returns:
|
||||
An array of shape ``(N, M)`` containing the lower triangle with elements
|
||||
below the sub-diagonal specified by ``k`` are set to one and zero elsewhere.
|
||||
|
||||
See also:
|
||||
- :func:`jax.numpy.tril`: Returns a lower triangle of an array.
|
||||
- :func:`jax.numpy.triu`: Returns an upper triangle of an array.
|
||||
|
||||
Examples:
|
||||
>>> jnp.tri(3)
|
||||
Array([[1., 0., 0.],
|
||||
[1., 1., 0.],
|
||||
[1., 1., 1.]], dtype=float32)
|
||||
|
||||
When ``M`` is not equal to ``N``:
|
||||
|
||||
>>> jnp.tri(3, 4)
|
||||
Array([[1., 0., 0., 0.],
|
||||
[1., 1., 0., 0.],
|
||||
[1., 1., 1., 0.]], dtype=float32)
|
||||
|
||||
when ``k>0``:
|
||||
|
||||
>>> jnp.tri(3, k=1)
|
||||
Array([[1., 1., 0.],
|
||||
[1., 1., 1.],
|
||||
[1., 1., 1.]], dtype=float32)
|
||||
|
||||
When ``k<0``:
|
||||
|
||||
>>> jnp.tri(3, 4, k=-1)
|
||||
Array([[0., 0., 0., 0.],
|
||||
[1., 0., 0., 0.],
|
||||
[1., 1., 0., 0.]], dtype=float32)
|
||||
"""
|
||||
dtypes.check_user_dtype_supported(dtype, "tri")
|
||||
M = M if M is not None else N
|
||||
dtype = dtype or float32
|
||||
return lax_internal._tri(dtype, (N, M), k)
|
||||
|
||||
|
||||
@util.implements(np.tril)
|
||||
@partial(jit, static_argnames=('k',))
|
||||
def tril(m: ArrayLike, k: int = 0) -> Array:
|
||||
r"""Return lower triangle of an array.
|
||||
|
||||
JAX implementation of :func:`numpy.tril`
|
||||
|
||||
Args:
|
||||
m: input array. Must have ``m.ndim >= 2``.
|
||||
k: k: optional, int, default=0. Specifies the sub-diagonal above which the
|
||||
elements of the array are set to zero. ``k=0`` refers to main diagonal,
|
||||
``k<0`` refers to sub-diagonal below the main diagonal and ``k>0`` refers
|
||||
to sub-diagonal above the main diagonal.
|
||||
|
||||
Returns:
|
||||
An array with same shape as input containing the upper triangle of the given
|
||||
array with elements below the sub-diagonal specified by ``k`` are set to zero.
|
||||
|
||||
See also:
|
||||
- :func:`jax.numpy.triu`: Returns an upper triangle of an array.
|
||||
- :func:`jax.numpy.tri`: Returns an array with ones on and below the diagonal
|
||||
and zeros elsewhere.
|
||||
|
||||
Examples:
|
||||
>>> x = jnp.array([[1, 2, 3, 4],
|
||||
... [5, 6, 7, 8],
|
||||
... [9, 10, 11, 12]])
|
||||
>>> jnp.tril(x)
|
||||
Array([[ 1, 0, 0, 0],
|
||||
[ 5, 6, 0, 0],
|
||||
[ 9, 10, 11, 0]], dtype=int32)
|
||||
>>> jnp.tril(x, k=1)
|
||||
Array([[ 1, 2, 0, 0],
|
||||
[ 5, 6, 7, 0],
|
||||
[ 9, 10, 11, 12]], dtype=int32)
|
||||
>>> jnp.tril(x, k=-1)
|
||||
Array([[ 0, 0, 0, 0],
|
||||
[ 5, 0, 0, 0],
|
||||
[ 9, 10, 0, 0]], dtype=int32)
|
||||
|
||||
When ``m.ndim > 2``, ``jnp.tril`` operates batch-wise on the trailing axes.
|
||||
|
||||
>>> x1 = jnp.array([[[1, 2],
|
||||
... [3, 4]],
|
||||
... [[5, 6],
|
||||
... [7, 8]]])
|
||||
>>> jnp.tril(x1)
|
||||
Array([[[1, 0],
|
||||
[3, 4]],
|
||||
<BLANKLINE>
|
||||
[[5, 0],
|
||||
[7, 8]]], dtype=int32)
|
||||
"""
|
||||
util.check_arraylike("tril", m)
|
||||
m_shape = shape(m)
|
||||
if len(m_shape) < 2:
|
||||
@ -4730,9 +4827,62 @@ def tril(m: ArrayLike, k: int = 0) -> Array:
|
||||
return lax.select(lax.broadcast(mask, m_shape[:-2]), m, zeros_like(m))
|
||||
|
||||
|
||||
@util.implements(np.triu, update_doc=False)
|
||||
@partial(jit, static_argnames=('k',))
|
||||
def triu(m: ArrayLike, k: int = 0) -> Array:
|
||||
r"""Return upper triangle of an array.
|
||||
|
||||
JAX implementation of :func:`numpy.triu`
|
||||
|
||||
Args:
|
||||
m: input array. Must have ``m.ndim >= 2``.
|
||||
k: k: optional, int, default=0. Specifies the sub-diagonal below which the
|
||||
elements of the array are set to zero. ``k=0`` refers to main diagonal,
|
||||
``k<0`` refers to sub-diagonal below the main diagonal and ``k>0`` refers
|
||||
to sub-diagonal above the main diagonal.
|
||||
|
||||
Returns:
|
||||
An array with same shape as input containing the lower triangle of the given
|
||||
array with elements above the sub-diagonal specified by ``k`` are set to zero.
|
||||
|
||||
See also:
|
||||
- :func:`jax.numpy.tril`: Returns a lower triangle of an array.
|
||||
- :func:`jax.numpy.tri`: Returns an array with ones on and below the diagonal
|
||||
and zeros elsewhere.
|
||||
|
||||
Examples:
|
||||
>>> x = jnp.array([[1, 2, 3],
|
||||
... [4, 5, 6],
|
||||
... [7, 8, 9],
|
||||
... [10, 11, 12]])
|
||||
>>> jnp.triu(x)
|
||||
Array([[1, 2, 3],
|
||||
[0, 5, 6],
|
||||
[0, 0, 9],
|
||||
[0, 0, 0]], dtype=int32)
|
||||
>>> jnp.triu(x, k=1)
|
||||
Array([[0, 2, 3],
|
||||
[0, 0, 6],
|
||||
[0, 0, 0],
|
||||
[0, 0, 0]], dtype=int32)
|
||||
>>> jnp.triu(x, k=-1)
|
||||
Array([[ 1, 2, 3],
|
||||
[ 4, 5, 6],
|
||||
[ 0, 8, 9],
|
||||
[ 0, 0, 12]], dtype=int32)
|
||||
|
||||
When ``m.ndim > 2``, ``jnp.triu`` operates batch-wise on the trailing axes.
|
||||
|
||||
>>> x1 = jnp.array([[[1, 2],
|
||||
... [3, 4]],
|
||||
... [[5, 6],
|
||||
... [7, 8]]])
|
||||
>>> jnp.triu(x1)
|
||||
Array([[[1, 2],
|
||||
[0, 4]],
|
||||
<BLANKLINE>
|
||||
[[5, 6],
|
||||
[0, 8]]], dtype=int32)
|
||||
"""
|
||||
util.check_arraylike("triu", m)
|
||||
m_shape = shape(m)
|
||||
if len(m_shape) < 2:
|
||||
|
Loading…
x
Reference in New Issue
Block a user