Merge pull request #21442 from vfdev-5:added-trace-alias-to-linalg

PiperOrigin-RevId: 638477013
This commit is contained in:
jax authors 2024-05-29 18:27:20 -07:00
commit ef0b5d7385
5 changed files with 66 additions and 9 deletions

View File

@ -495,6 +495,7 @@ jax.numpy.linalg
tensordot
tensorinv
tensorsolve
trace
vector_norm
vecdot

View File

@ -496,7 +496,7 @@ def _slogdet_qr(a: Array) -> tuple[Array, Array]:
@partial(jit, static_argnames=('method',))
def slogdet(a: ArrayLike, *, method: str | None = None) -> SlogdetResult:
"""
Computes the sign and (natural) logarithm of the determinant of an array.
Compute the sign and (natural) logarithm of the determinant of an array.
JAX implementation of :func:`numpy.linalg.slotdet`.
@ -662,7 +662,7 @@ def _det_3x3(a: Array) -> Array:
@jit
def det(a: ArrayLike) -> Array:
"""
Computes the determinant of an array.
Compute the determinant of an array.
JAX implementation of :func:`numpy.linalg.det`.
@ -706,7 +706,7 @@ def _det_jvp(primals, tangents):
def eig(a: ArrayLike) -> tuple[Array, Array]:
"""
Computes the eigenvalues and eigenvectors of a square array.
Compute the eigenvalues and eigenvectors of a square array.
JAX implementation of :func:`numpy.linalg.eig`.
@ -750,7 +750,7 @@ def eig(a: ArrayLike) -> tuple[Array, Array]:
@jit
def eigvals(a: ArrayLike) -> Array:
"""
Computes the eigenvalues of a general matrix.
Compute the eigenvalues of a general matrix.
JAX implementation of :func:`numpy.linalg.eigvals`.
@ -788,7 +788,7 @@ def eigvals(a: ArrayLike) -> Array:
def eigh(a: ArrayLike, UPLO: str | None = None,
symmetrize_input: bool = True) -> EighResult:
"""
Computes the eigenvalues and eigenvectors of a Hermitian matrix.
Compute the eigenvalues and eigenvectors of a Hermitian matrix.
JAX implementation of :func:`numpy.linalg.eigh`.
@ -842,7 +842,7 @@ def eigh(a: ArrayLike, UPLO: str | None = None,
@partial(jit, static_argnames=('UPLO',))
def eigvalsh(a: ArrayLike, UPLO: str | None = 'L') -> Array:
"""
Computes the eigenvalues of a Hermitian matrix.
Compute the eigenvalues of a Hermitian matrix.
JAX implementation of :func:`numpy.linalg.eigvalsh`.
@ -1599,7 +1599,7 @@ def matrix_transpose(x: ArrayLike, /) -> Array:
def vector_norm(x: ArrayLike, /, *, axis: int | None = None, keepdims: bool = False,
ord: int | str = 2) -> Array:
"""Computes the vector norm of a vector or batch of vectors.
"""Compute the vector norm of a vector or batch of vectors.
JAX implementation of :func:`numpy.linalg.vector_norm`.
@ -2136,3 +2136,47 @@ def cond(x: ArrayLike, p=None):
r = norm(x, ord=p, axis=(-2, -1)) * norm(inv(x), ord=p, axis=(-2, -1))
# Convert NaNs to infs where original array has no NaNs.
return jnp.where(ufuncs.isnan(r) & ~ufuncs.isnan(x).any(axis=(-2, -1)), jnp.inf, r)
def trace(x: ArrayLike, /, *,
offset: int = 0, dtype: DTypeLike | None = None) -> Array:
"""Compute the trace of a matrix.
JAX implementation of :func:`numpy.linalg.trace`.
Args:
x: array of shape ``(..., M, N)`` and whose innermost two
dimensions form MxN matrices for which to take the trace.
offset: positive or negative offset from the main diagonal
(default: 0).
dtype: data type of the returned array (default: ``None``). If ``None``,
then output dtype will match the dtype of ``x``, promoted to default
precision in the case of integer types.
Returns:
array of batched traces with shape ``x.shape[:-2]``
See also:
- :func:`jax.numpy.trace`: similar API in the ``jax.numpy`` namespace.
Examples:
Trace of a single matrix:
>>> x = jnp.array([[1, 2, 3, 4],
... [5, 6, 7, 8],
... [9, 10, 11, 12]])
>>> jnp.linalg.trace(x)
Array(18, dtype=int32)
>>> jnp.linalg.trace(x, offset=1)
Array(21, dtype=int32)
>>> jnp.linalg.trace(x, offset=-1, dtype="float32")
Array(15., dtype=float32)
Batched traces:
>>> x = jnp.arange(24).reshape(2, 3, 4)
>>> jnp.linalg.trace(x)
Array([15, 51], dtype=int32)
"""
check_arraylike('jnp.linalg.trace', x)
return jnp.trace(x, offset=offset, axis1=-2, axis2=-1, dtype=dtype)

View File

@ -35,8 +35,7 @@ from jax.numpy.linalg import (
vector_norm as vector_norm,
)
# TODO(micky774): Add trace to jax.numpy.linalg
from jax.numpy import trace as trace
from jax.numpy.linalg import trace as trace
from jax.experimental.array_api._linear_algebra_functions import (
matrix_rank as matrix_rank,

View File

@ -44,6 +44,7 @@ from jax._src.numpy.linalg import (
tensordot as tensordot,
tensorinv as tensorinv,
tensorsolve as tensorsolve,
trace as trace,
vector_norm as vector_norm,
vecdot as vecdot,
)

View File

@ -1305,6 +1305,18 @@ class NumpyLinalgTest(jtu.JaxTestCase):
self._CheckAgainstNumpy(np_fun, lax_fun, args_maker)
self._CompileAndCheck(lax_fun, args_maker)
def testTrace(self):
shape, dtype, offset, out_dtype = (3, 4), "float32", 0, None
rng = jtu.rand_default(self.rng())
args_maker = lambda: [rng(shape, dtype)]
lax_fun = partial(jnp.linalg.trace, offset=offset, dtype=out_dtype)
if jtu.numpy_version() >= (2, 0, 0):
np_fun = partial(np.linalg.trace, offset=offset)
else:
np_fun = partial(np.trace, offset=offset, axis1=-2, axis2=-1, dtype=out_dtype)
self._CheckAgainstNumpy(np_fun, lax_fun, args_maker)
self._CompileAndCheck(lax_fun, args_maker)
class ScipyLinalgTest(jtu.JaxTestCase):