mirror of
https://github.com/ROCm/jax.git
synced 2025-04-16 11:56:07 +00:00
jnp.linalg: add symmetrize_input argument & docs
This commit is contained in:
parent
74917ce51e
commit
96e63eaee8
@ -72,8 +72,8 @@ def _symmetrize(x: Array) -> Array: return (x + _H(x)) / 2
|
||||
|
||||
|
||||
@export
|
||||
@partial(jit, static_argnames=['upper'])
|
||||
def cholesky(a: ArrayLike, *, upper: bool = False) -> Array:
|
||||
@partial(jit, static_argnames=['upper', 'symmetrize_input'])
|
||||
def cholesky(a: ArrayLike, *, upper: bool = False, symmetrize_input: bool = True) -> Array:
|
||||
"""Compute the Cholesky decomposition of a matrix.
|
||||
|
||||
JAX implementation of :func:`numpy.linalg.cholesky`.
|
||||
@ -98,6 +98,10 @@ def cholesky(a: ArrayLike, *, upper: bool = False) -> Array:
|
||||
Must have shape ``(..., N, N)``.
|
||||
upper: if True, compute the upper Cholesky decomposition `U`. if False
|
||||
(default), compute the lower Cholesky decomposition `L`.
|
||||
symmetrize_input: if True (default) then input is symmetrized, which leads
|
||||
to better behavior under automatic differentiation. Note that when this
|
||||
is set to True, both the upper and lower triangles of the input will
|
||||
be used in computing the decomposition.
|
||||
|
||||
Returns:
|
||||
array of shape ``(..., N, N)`` representing the Cholesky decomposition
|
||||
@ -135,7 +139,7 @@ def cholesky(a: ArrayLike, *, upper: bool = False) -> Array:
|
||||
"""
|
||||
a = ensure_arraylike("jnp.linalg.cholesky", a)
|
||||
a, = promote_dtypes_inexact(a)
|
||||
L = lax_linalg.cholesky(a)
|
||||
L = lax_linalg.cholesky(a, symmetrize_input=symmetrize_input)
|
||||
return L.mT.conj() if upper else L
|
||||
|
||||
|
||||
@ -821,7 +825,9 @@ def eigh(a: ArrayLike, UPLO: str | None = None,
|
||||
UPLO: specifies whether the calculation is done with the lower triangular
|
||||
part of ``a`` (``'L'``, default) or the upper triangular part (``'U'``).
|
||||
symmetrize_input: if True (default) then input is symmetrized, which leads
|
||||
to better behavior under automatic differentiation.
|
||||
to better behavior under automatic differentiation. Note that when this
|
||||
is set to True, both the upper and lower triangles of the input will
|
||||
be used in computing the decomposition.
|
||||
|
||||
Returns:
|
||||
A namedtuple ``(eigenvalues, eigenvectors)`` where
|
||||
@ -863,8 +869,9 @@ def eigh(a: ArrayLike, UPLO: str | None = None,
|
||||
|
||||
|
||||
@export
|
||||
@partial(jit, static_argnames=('UPLO',))
|
||||
def eigvalsh(a: ArrayLike, UPLO: str | None = 'L') -> Array:
|
||||
@partial(jit, static_argnames=('UPLO', 'symmetrize_input'))
|
||||
def eigvalsh(a: ArrayLike, UPLO: str | None = 'L', *,
|
||||
symmetrize_input: bool = True) -> Array:
|
||||
"""
|
||||
Compute the eigenvalues of a Hermitian matrix.
|
||||
|
||||
@ -875,6 +882,10 @@ def eigvalsh(a: ArrayLike, UPLO: str | None = 'L') -> Array:
|
||||
or symmetric (if real) matrix.
|
||||
UPLO: specifies whether the calculation is done with the lower triangular
|
||||
part of ``a`` (``'L'``, default) or the upper triangular part (``'U'``).
|
||||
symmetrize_input: if True (default) then input is symmetrized, which leads
|
||||
to better behavior under automatic differentiation. Note that when this
|
||||
is set to True, both the upper and lower triangles of the input will
|
||||
be used in computing the decomposition.
|
||||
|
||||
Returns:
|
||||
An array of shape ``(..., M)`` containing the eigenvalues, sorted in
|
||||
@ -894,7 +905,7 @@ def eigvalsh(a: ArrayLike, UPLO: str | None = 'L') -> Array:
|
||||
"""
|
||||
a = ensure_arraylike("jnp.linalg.eigvalsh", a)
|
||||
a, = promote_dtypes_inexact(a)
|
||||
w, _ = eigh(a, UPLO)
|
||||
w, _ = eigh(a, UPLO, symmetrize_input=symmetrize_input)
|
||||
return w
|
||||
|
||||
|
||||
|
@ -96,7 +96,7 @@ class NumpyLinalgTest(jtu.JaxTestCase):
|
||||
a = rng(factor_shape, dtype)
|
||||
return [np.matmul(a, jnp.conj(T(a)))]
|
||||
|
||||
jnp_fun = partial(jnp.linalg.cholesky, upper=upper)
|
||||
jnp_fun = partial(jnp.linalg.cholesky, upper=upper, symmetrize_input=True)
|
||||
|
||||
def np_fun(x, upper=upper):
|
||||
# Upper argument added in NumPy 2.0.0
|
||||
|
Loading…
x
Reference in New Issue
Block a user