jnp.linalg: add symmetrize_input argument & docs

This commit is contained in:
Jake VanderPlas 2025-04-07 14:46:38 -07:00
parent 74917ce51e
commit 96e63eaee8
2 changed files with 19 additions and 8 deletions

View File

@ -72,8 +72,8 @@ def _symmetrize(x: Array) -> Array: return (x + _H(x)) / 2
@export
@partial(jit, static_argnames=['upper'])
def cholesky(a: ArrayLike, *, upper: bool = False) -> Array:
@partial(jit, static_argnames=['upper', 'symmetrize_input'])
def cholesky(a: ArrayLike, *, upper: bool = False, symmetrize_input: bool = True) -> Array:
"""Compute the Cholesky decomposition of a matrix.
JAX implementation of :func:`numpy.linalg.cholesky`.
@ -98,6 +98,10 @@ def cholesky(a: ArrayLike, *, upper: bool = False) -> Array:
Must have shape ``(..., N, N)``.
upper: if True, compute the upper Cholesky decomposition `U`. if False
(default), compute the lower Cholesky decomposition `L`.
symmetrize_input: if True (default) then input is symmetrized, which leads
to better behavior under automatic differentiation. Note that when this
is set to True, both the upper and lower triangles of the input will
be used in computing the decomposition.
Returns:
array of shape ``(..., N, N)`` representing the Cholesky decomposition
@ -135,7 +139,7 @@ def cholesky(a: ArrayLike, *, upper: bool = False) -> Array:
"""
a = ensure_arraylike("jnp.linalg.cholesky", a)
a, = promote_dtypes_inexact(a)
L = lax_linalg.cholesky(a)
L = lax_linalg.cholesky(a, symmetrize_input=symmetrize_input)
return L.mT.conj() if upper else L
@ -821,7 +825,9 @@ def eigh(a: ArrayLike, UPLO: str | None = None,
UPLO: specifies whether the calculation is done with the lower triangular
part of ``a`` (``'L'``, default) or the upper triangular part (``'U'``).
symmetrize_input: if True (default) then input is symmetrized, which leads
to better behavior under automatic differentiation.
to better behavior under automatic differentiation. Note that when this
is set to True, both the upper and lower triangles of the input will
be used in computing the decomposition.
Returns:
A namedtuple ``(eigenvalues, eigenvectors)`` where
@ -863,8 +869,9 @@ def eigh(a: ArrayLike, UPLO: str | None = None,
@export
@partial(jit, static_argnames=('UPLO',))
def eigvalsh(a: ArrayLike, UPLO: str | None = 'L') -> Array:
@partial(jit, static_argnames=('UPLO', 'symmetrize_input'))
def eigvalsh(a: ArrayLike, UPLO: str | None = 'L', *,
symmetrize_input: bool = True) -> Array:
"""
Compute the eigenvalues of a Hermitian matrix.
@ -875,6 +882,10 @@ def eigvalsh(a: ArrayLike, UPLO: str | None = 'L') -> Array:
or symmetric (if real) matrix.
UPLO: specifies whether the calculation is done with the lower triangular
part of ``a`` (``'L'``, default) or the upper triangular part (``'U'``).
symmetrize_input: if True (default) then input is symmetrized, which leads
to better behavior under automatic differentiation. Note that when this
is set to True, both the upper and lower triangles of the input will
be used in computing the decomposition.
Returns:
An array of shape ``(..., M)`` containing the eigenvalues, sorted in
@ -894,7 +905,7 @@ def eigvalsh(a: ArrayLike, UPLO: str | None = 'L') -> Array:
"""
a = ensure_arraylike("jnp.linalg.eigvalsh", a)
a, = promote_dtypes_inexact(a)
w, _ = eigh(a, UPLO)
w, _ = eigh(a, UPLO, symmetrize_input=symmetrize_input)
return w

View File

@ -96,7 +96,7 @@ class NumpyLinalgTest(jtu.JaxTestCase):
a = rng(factor_shape, dtype)
return [np.matmul(a, jnp.conj(T(a)))]
jnp_fun = partial(jnp.linalg.cholesky, upper=upper)
jnp_fun = partial(jnp.linalg.cholesky, upper=upper, symmetrize_input=True)
def np_fun(x, upper=upper):
# Upper argument added in NumPy 2.0.0