mirror of
https://github.com/ROCm/jax.git
synced 2025-04-16 11:56:07 +00:00
Merge pull request #21442 from vfdev-5:added-trace-alias-to-linalg
PiperOrigin-RevId: 638477013
This commit is contained in:
commit
ef0b5d7385
@ -495,6 +495,7 @@ jax.numpy.linalg
|
||||
tensordot
|
||||
tensorinv
|
||||
tensorsolve
|
||||
trace
|
||||
vector_norm
|
||||
vecdot
|
||||
|
||||
|
@ -496,7 +496,7 @@ def _slogdet_qr(a: Array) -> tuple[Array, Array]:
|
||||
@partial(jit, static_argnames=('method',))
|
||||
def slogdet(a: ArrayLike, *, method: str | None = None) -> SlogdetResult:
|
||||
"""
|
||||
Computes the sign and (natural) logarithm of the determinant of an array.
|
||||
Compute the sign and (natural) logarithm of the determinant of an array.
|
||||
|
||||
JAX implementation of :func:`numpy.linalg.slotdet`.
|
||||
|
||||
@ -662,7 +662,7 @@ def _det_3x3(a: Array) -> Array:
|
||||
@jit
|
||||
def det(a: ArrayLike) -> Array:
|
||||
"""
|
||||
Computes the determinant of an array.
|
||||
Compute the determinant of an array.
|
||||
|
||||
JAX implementation of :func:`numpy.linalg.det`.
|
||||
|
||||
@ -706,7 +706,7 @@ def _det_jvp(primals, tangents):
|
||||
|
||||
def eig(a: ArrayLike) -> tuple[Array, Array]:
|
||||
"""
|
||||
Computes the eigenvalues and eigenvectors of a square array.
|
||||
Compute the eigenvalues and eigenvectors of a square array.
|
||||
|
||||
JAX implementation of :func:`numpy.linalg.eig`.
|
||||
|
||||
@ -750,7 +750,7 @@ def eig(a: ArrayLike) -> tuple[Array, Array]:
|
||||
@jit
|
||||
def eigvals(a: ArrayLike) -> Array:
|
||||
"""
|
||||
Computes the eigenvalues of a general matrix.
|
||||
Compute the eigenvalues of a general matrix.
|
||||
|
||||
JAX implementation of :func:`numpy.linalg.eigvals`.
|
||||
|
||||
@ -788,7 +788,7 @@ def eigvals(a: ArrayLike) -> Array:
|
||||
def eigh(a: ArrayLike, UPLO: str | None = None,
|
||||
symmetrize_input: bool = True) -> EighResult:
|
||||
"""
|
||||
Computes the eigenvalues and eigenvectors of a Hermitian matrix.
|
||||
Compute the eigenvalues and eigenvectors of a Hermitian matrix.
|
||||
|
||||
JAX implementation of :func:`numpy.linalg.eigh`.
|
||||
|
||||
@ -842,7 +842,7 @@ def eigh(a: ArrayLike, UPLO: str | None = None,
|
||||
@partial(jit, static_argnames=('UPLO',))
|
||||
def eigvalsh(a: ArrayLike, UPLO: str | None = 'L') -> Array:
|
||||
"""
|
||||
Computes the eigenvalues of a Hermitian matrix.
|
||||
Compute the eigenvalues of a Hermitian matrix.
|
||||
|
||||
JAX implementation of :func:`numpy.linalg.eigvalsh`.
|
||||
|
||||
@ -1599,7 +1599,7 @@ def matrix_transpose(x: ArrayLike, /) -> Array:
|
||||
|
||||
def vector_norm(x: ArrayLike, /, *, axis: int | None = None, keepdims: bool = False,
|
||||
ord: int | str = 2) -> Array:
|
||||
"""Computes the vector norm of a vector or batch of vectors.
|
||||
"""Compute the vector norm of a vector or batch of vectors.
|
||||
|
||||
JAX implementation of :func:`numpy.linalg.vector_norm`.
|
||||
|
||||
@ -2136,3 +2136,47 @@ def cond(x: ArrayLike, p=None):
|
||||
r = norm(x, ord=p, axis=(-2, -1)) * norm(inv(x), ord=p, axis=(-2, -1))
|
||||
# Convert NaNs to infs where original array has no NaNs.
|
||||
return jnp.where(ufuncs.isnan(r) & ~ufuncs.isnan(x).any(axis=(-2, -1)), jnp.inf, r)
|
||||
|
||||
|
||||
def trace(x: ArrayLike, /, *,
|
||||
offset: int = 0, dtype: DTypeLike | None = None) -> Array:
|
||||
"""Compute the trace of a matrix.
|
||||
|
||||
JAX implementation of :func:`numpy.linalg.trace`.
|
||||
|
||||
Args:
|
||||
x: array of shape ``(..., M, N)`` and whose innermost two
|
||||
dimensions form MxN matrices for which to take the trace.
|
||||
offset: positive or negative offset from the main diagonal
|
||||
(default: 0).
|
||||
dtype: data type of the returned array (default: ``None``). If ``None``,
|
||||
then output dtype will match the dtype of ``x``, promoted to default
|
||||
precision in the case of integer types.
|
||||
|
||||
Returns:
|
||||
array of batched traces with shape ``x.shape[:-2]``
|
||||
|
||||
See also:
|
||||
- :func:`jax.numpy.trace`: similar API in the ``jax.numpy`` namespace.
|
||||
|
||||
Examples:
|
||||
Trace of a single matrix:
|
||||
|
||||
>>> x = jnp.array([[1, 2, 3, 4],
|
||||
... [5, 6, 7, 8],
|
||||
... [9, 10, 11, 12]])
|
||||
>>> jnp.linalg.trace(x)
|
||||
Array(18, dtype=int32)
|
||||
>>> jnp.linalg.trace(x, offset=1)
|
||||
Array(21, dtype=int32)
|
||||
>>> jnp.linalg.trace(x, offset=-1, dtype="float32")
|
||||
Array(15., dtype=float32)
|
||||
|
||||
Batched traces:
|
||||
|
||||
>>> x = jnp.arange(24).reshape(2, 3, 4)
|
||||
>>> jnp.linalg.trace(x)
|
||||
Array([15, 51], dtype=int32)
|
||||
"""
|
||||
check_arraylike('jnp.linalg.trace', x)
|
||||
return jnp.trace(x, offset=offset, axis1=-2, axis2=-1, dtype=dtype)
|
||||
|
@ -35,8 +35,7 @@ from jax.numpy.linalg import (
|
||||
vector_norm as vector_norm,
|
||||
)
|
||||
|
||||
# TODO(micky774): Add trace to jax.numpy.linalg
|
||||
from jax.numpy import trace as trace
|
||||
from jax.numpy.linalg import trace as trace
|
||||
|
||||
from jax.experimental.array_api._linear_algebra_functions import (
|
||||
matrix_rank as matrix_rank,
|
||||
|
@ -44,6 +44,7 @@ from jax._src.numpy.linalg import (
|
||||
tensordot as tensordot,
|
||||
tensorinv as tensorinv,
|
||||
tensorsolve as tensorsolve,
|
||||
trace as trace,
|
||||
vector_norm as vector_norm,
|
||||
vecdot as vecdot,
|
||||
)
|
||||
|
@ -1305,6 +1305,18 @@ class NumpyLinalgTest(jtu.JaxTestCase):
|
||||
self._CheckAgainstNumpy(np_fun, lax_fun, args_maker)
|
||||
self._CompileAndCheck(lax_fun, args_maker)
|
||||
|
||||
def testTrace(self):
|
||||
shape, dtype, offset, out_dtype = (3, 4), "float32", 0, None
|
||||
rng = jtu.rand_default(self.rng())
|
||||
args_maker = lambda: [rng(shape, dtype)]
|
||||
lax_fun = partial(jnp.linalg.trace, offset=offset, dtype=out_dtype)
|
||||
if jtu.numpy_version() >= (2, 0, 0):
|
||||
np_fun = partial(np.linalg.trace, offset=offset)
|
||||
else:
|
||||
np_fun = partial(np.trace, offset=offset, axis1=-2, axis2=-1, dtype=out_dtype)
|
||||
self._CheckAgainstNumpy(np_fun, lax_fun, args_maker)
|
||||
self._CompileAndCheck(lax_fun, args_maker)
|
||||
|
||||
|
||||
class ScipyLinalgTest(jtu.JaxTestCase):
|
||||
|
||||
|
Loading…
x
Reference in New Issue
Block a user