
# Copyright 2018 The JAX Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import annotations
from functools import partial
import numpy as np
import textwrap
from typing import overload, Any, Literal
import jax
import jax.numpy as jnp
from jax import jit, vmap, jvp
from jax import lax
from jax._src import dtypes
from jax._src.lax import linalg as lax_linalg
from jax._src.lax import qdwh
from jax._src.numpy.util import (
check_arraylike, promote_dtypes, promote_dtypes_inexact,
promote_dtypes_complex)
from jax._src.typing import Array, ArrayLike
_no_chkfinite_doc = textwrap.dedent("""
Does not support the Scipy argument ``check_finite=True``,
because compiled JAX code cannot perform checks of array values at runtime.
""")
_no_overwrite_and_chkfinite_doc = _no_chkfinite_doc + "\nDoes not support the Scipy argument ``overwrite_*=True``."
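# Internal helper for ``cholesky``: ``lax_linalg.cholesky`` always returns the
# lower factor and reads only the lower triangle of its input, so for the upper
# factorization we conjugate-transpose the input (thereby reading the upper
# triangle of ``a``) and conjugate-transpose the result, giving ``U`` with
# ``A = U^H @ U``.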
@partial(jit, static_argnames=('lower',))
def _cholesky(a: ArrayLike, lower: bool) -> Array:
a, = promote_dtypes_inexact(jnp.asarray(a))
l = lax_linalg.cholesky(a if lower else jnp.conj(a.mT), symmetrize_input=False)
return l if lower else jnp.conj(l.mT)
def cholesky(a: ArrayLike, lower: bool = False, overwrite_a: bool = False,
check_finite: bool = True) -> Array:
"""Compute the Cholesky decomposition of a matrix.
JAX implementation of :func:`scipy.linalg.cholesky`.
The Cholesky decomposition of a matrix `A` is:
.. math::
A = U^HU = LL^H
where `U` is an upper-triangular matrix and `L` is a lower-triangular matrix.
Args:
a: input array, representing a (batched) positive-definite hermitian matrix.
Must have shape ``(..., N, N)``.
lower: if True, compute the lower Cholesky decomposition `L`; if False
(default), compute the upper Cholesky decomposition `U`.
overwrite_a: unused by JAX
check_finite: unused by JAX
Returns:
array of shape ``(..., N, N)`` representing the cholesky decomposition
of the input.
See Also:
- :func:`jax.numpy.linalg.cholesky`: NumPy-style Cholesky API
- :func:`jax.lax.linalg.cholesky`: XLA-style Cholesky API
- :func:`jax.scipy.linalg.cho_factor`
- :func:`jax.scipy.linalg.cho_solve`
Examples:
A small real Hermitian positive-definite matrix:
>>> x = jnp.array([[2., 1.],
... [1., 2.]])
Upper Cholesky factorization:
>>> jax.scipy.linalg.cholesky(x)
Array([[1.4142135 , 0.70710677],
[0. , 1.2247449 ]], dtype=float32)
Lower Cholesky factorization:
>>> jax.scipy.linalg.cholesky(x, lower=True)
Array([[1.4142135 , 0. ],
[0.70710677, 1.2247449 ]], dtype=float32)
Reconstructing ``x`` from its factorization:
>>> L = jax.scipy.linalg.cholesky(x, lower=True)
>>> jnp.allclose(x, L @ L.T)
Array(True, dtype=bool)
"""
del overwrite_a, check_finite # Unused
return _cholesky(a, lower)
def cho_factor(a: ArrayLike, lower: bool = False, overwrite_a: bool = False,
check_finite: bool = True) -> tuple[Array, bool]:
"""Factorization for Cholesky-based linear solves
JAX implementation of :func:`scipy.linalg.cho_factor`. This function returns
a result suitable for use with :func:`jax.scipy.linalg.cho_solve`. For direct
Cholesky decompositions, prefer :func:`jax.scipy.linalg.cholesky`.
Args:
a: input array, representing a (batched) positive-definite hermitian matrix.
Must have shape ``(..., N, N)``.
lower: if True, compute the lower triangular Cholesky decomposition (default: False).
overwrite_a: unused by JAX
check_finite: unused by JAX
Returns:
``(c, lower)``: ``c`` is an array of shape ``(..., N, N)`` representing the lower or
upper cholesky decomposition of the input; ``lower`` is a boolean specifying whether
this is the lower or upper decomposition.
See Also:
- :func:`jax.scipy.linalg.cholesky`
- :func:`jax.scipy.linalg.cho_solve`
Examples:
A small real Hermitian positive-definite matrix:
>>> x = jnp.array([[2., 1.],
... [1., 2.]])
Compute the cholesky factorization via :func:`~jax.scipy.linalg.cho_factor`,
and use it to solve a linear equation via :func:`~jax.scipy.linalg.cho_solve`.
>>> b = jnp.array([3., 4.])
>>> cfac = jax.scipy.linalg.cho_factor(x)
>>> y = jax.scipy.linalg.cho_solve(cfac, b)
>>> y
Array([0.6666666, 1.6666666], dtype=float32)
Check that the result is consistent:
>>> jnp.allclose(x @ y, b)
Array(True, dtype=bool)
"""
del overwrite_a, check_finite # Unused
return (cholesky(a, lower=lower), lower)
@partial(jit, static_argnames=('lower',))
def _cho_solve(c: ArrayLike, b: ArrayLike, lower: bool) -> Array:
c, b = promote_dtypes_inexact(jnp.asarray(c), jnp.asarray(b))
lax_linalg._check_solve_shapes(c, b)
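# Given A = C @ C^H (lower) or A = C^H @ C (upper), solve A @ x = b via two
# back-to-back triangular solves against the factor and its conjugate transpose.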
b = lax_linalg.triangular_solve(c, b, left_side=True, lower=lower,
transpose_a=not lower, conjugate_a=not lower)
b = lax_linalg.triangular_solve(c, b, left_side=True, lower=lower,
transpose_a=lower, conjugate_a=lower)
return b
def cho_solve(c_and_lower: tuple[ArrayLike, bool], b: ArrayLike,
overwrite_b: bool = False, check_finite: bool = True) -> Array:
"""Solve a linear system using a Cholesky factorization
JAX implementation of :func:`scipy.linalg.cho_solve`. Uses the output
of :func:`jax.scipy.linalg.cho_factor`.
Args:
c_and_lower: ``(c, lower)``, where ``c`` is an array of shape ``(..., N, N)``
representing the lower or upper cholesky decomposition of the matrix, and
``lower`` is a boolean specifying whether this is the lower or upper decomposition.
b: right-hand-side of linear system. Must have shape ``(..., N)``
overwrite_b: unused by JAX
check_finite: unused by JAX
Returns:
Array of shape ``(..., N)`` representing the solution of the linear system.
See Also:
- :func:`jax.scipy.linalg.cholesky`
- :func:`jax.scipy.linalg.cho_factor`
Examples:
A small real Hermitian positive-definite matrix:
>>> x = jnp.array([[2., 1.],
... [1., 2.]])
Compute the cholesky factorization via :func:`~jax.scipy.linalg.cho_factor`,
and use it to solve a linear equation via :func:`~jax.scipy.linalg.cho_solve`.
>>> b = jnp.array([3., 4.])
>>> cfac = jax.scipy.linalg.cho_factor(x)
>>> y = jax.scipy.linalg.cho_solve(cfac, b)
>>> y
Array([0.6666666, 1.6666666], dtype=float32)
Check that the result is consistent:
>>> jnp.allclose(x @ y, b)
Array(True, dtype=bool)
"""
del overwrite_b, check_finite # Unused
c, lower = c_and_lower
return _cho_solve(c, b, lower)
@overload
def _svd(x: ArrayLike, *, full_matrices: bool, compute_uv: Literal[True]) -> tuple[Array, Array, Array]: ...
@overload
def _svd(x: ArrayLike, *, full_matrices: bool, compute_uv: Literal[False]) -> Array: ...
@overload
def _svd(x: ArrayLike, *, full_matrices: bool, compute_uv: bool) -> Array | tuple[Array, Array, Array]: ...
@partial(jit, static_argnames=('full_matrices', 'compute_uv'))
def _svd(a: ArrayLike, *, full_matrices: bool, compute_uv: bool) -> Array | tuple[Array, Array, Array]:
a, = promote_dtypes_inexact(jnp.asarray(a))
return lax_linalg.svd(a, full_matrices=full_matrices, compute_uv=compute_uv)
@overload
def svd(a: ArrayLike, full_matrices: bool = True, compute_uv: Literal[True] = True,
overwrite_a: bool = False, check_finite: bool = True,
lapack_driver: str = 'gesdd') -> tuple[Array, Array, Array]: ...
@overload
def svd(a: ArrayLike, full_matrices: bool, compute_uv: Literal[False],
overwrite_a: bool = False, check_finite: bool = True,
lapack_driver: str = 'gesdd') -> Array: ...
@overload
def svd(a: ArrayLike, full_matrices: bool = True, *, compute_uv: Literal[False],
overwrite_a: bool = False, check_finite: bool = True,
lapack_driver: str = 'gesdd') -> Array: ...
@overload
def svd(a: ArrayLike, full_matrices: bool = True, compute_uv: bool = True,
overwrite_a: bool = False, check_finite: bool = True,
lapack_driver: str = 'gesdd') -> Array | tuple[Array, Array, Array]: ...
def svd(a: ArrayLike, full_matrices: bool = True, compute_uv: bool = True,
overwrite_a: bool = False, check_finite: bool = True,
lapack_driver: str = 'gesdd') -> Array | tuple[Array, Array, Array]:
r"""Compute the singular value decomposition.
JAX implementation of :func:`scipy.linalg.svd`.
The SVD of a matrix `A` is given by
.. math::
A = U\Sigma V^H
- :math:`U` contains the left singular vectors and satisfies :math:`U^HU=I`
- :math:`V` contains the right singular vectors and satisfies :math:`V^HV=I`
- :math:`\Sigma` is a diagonal matrix of singular values.
Args:
a: input array, of shape ``(..., N, M)``
full_matrices: if True (default) compute the full matrices; i.e. ``u`` and ``vh`` have
shape ``(..., N, N)`` and ``(..., M, M)``. If False, then the shapes are
``(..., N, K)`` and ``(..., K, M)`` with ``K = min(N, M)``.
compute_uv: if True (default), return the full SVD ``(u, s, vh)``. If False then return
only the singular values ``s``.
overwrite_a: unused by JAX
check_finite: unused by JAX
lapack_driver: unused by JAX
Returns:
A tuple of arrays ``(u, s, vh)`` if ``compute_uv`` is True, otherwise the array ``s``.
- ``u``: left singular vectors of shape ``(..., N, N)`` if ``full_matrices`` is True
or ``(..., N, K)`` otherwise.
- ``s``: singular values of shape ``(..., K)``
- ``vh``: conjugate-transposed right singular vectors of shape ``(..., M, M)``
if ``full_matrices`` is True or ``(..., K, M)`` otherwise.
where ``K = min(N, M)``.
See also:
- :func:`jax.numpy.linalg.svd`: NumPy-style SVD API
- :func:`jax.lax.linalg.svd`: XLA-style SVD API
Examples:
Consider the SVD of a small real-valued array:
>>> x = jnp.array([[1., 2., 3.],
... [6., 5., 4.]])
>>> u, s, vt = jax.scipy.linalg.svd(x, full_matrices=False)
>>> s # doctest: +SKIP
Array([9.361919 , 1.8315067], dtype=float32)
The singular vectors are in the columns of ``u`` and ``v = vt.T``. These vectors are
orthonormal, which can be demonstrated by comparing the matrix product with the
identity matrix:
>>> jnp.allclose(u.T @ u, jnp.eye(2), atol=1E-5)
Array(True, dtype=bool)
>>> v = vt.T
>>> jnp.allclose(v.T @ v, jnp.eye(2), atol=1E-5)
Array(True, dtype=bool)
Given the SVD, ``x`` can be reconstructed via matrix multiplication:
>>> x_reconstructed = u @ jnp.diag(s) @ vt
>>> jnp.allclose(x_reconstructed, x)
Array(True, dtype=bool)
"""
del overwrite_a, check_finite, lapack_driver # unused
return _svd(a, full_matrices=full_matrices, compute_uv=compute_uv)
def det(a: ArrayLike, overwrite_a: bool = False, check_finite: bool = True) -> Array:
"""Compute the determinant of a matrix
JAX implementation of :func:`scipy.linalg.det`.
Args:
a: input array, of shape ``(..., N, N)``
overwrite_a: unused by JAX
check_finite: unused by JAX
Returns:
Determinant of shape ``a.shape[:-2]``
See Also:
:func:`jax.numpy.linalg.det`: NumPy-style determinant API
Examples:
Determinant of a small 2D array:
>>> x = jnp.array([[1., 2.],
... [3., 4.]])
>>> jax.scipy.linalg.det(x)
Array(-2., dtype=float32)
Batch-wise determinant of multiple 2D arrays:
>>> x = jnp.array([[[1., 2.],
... [3., 4.]],
... [[8., 5.],
... [7., 9.]]])
>>> jax.scipy.linalg.det(x)
Array([-2., 37.], dtype=float32)
"""
del overwrite_a, check_finite # unused
return jnp.linalg.det(a)
@overload
def _eigh(a: ArrayLike, b: ArrayLike | None, lower: bool, eigvals_only: Literal[True],
eigvals: None, type: int) -> Array: ...
@overload
def _eigh(a: ArrayLike, b: ArrayLike | None, lower: bool, eigvals_only: Literal[False],
eigvals: None, type: int) -> tuple[Array, Array]: ...
@overload
def _eigh(a: ArrayLike, b: ArrayLike | None, lower: bool, eigvals_only: bool,
eigvals: None, type: int) -> Array | tuple[Array, Array]: ...
@partial(jit, static_argnames=('lower', 'eigvals_only', 'eigvals', 'type'))
def _eigh(a: ArrayLike, b: ArrayLike | None, lower: bool, eigvals_only: bool,
eigvals: None, type: int) -> Array | tuple[Array, Array]:
if b is not None:
raise NotImplementedError("Only the b=None case of eigh is implemented")
if type != 1:
raise NotImplementedError("Only the type=1 case of eigh is implemented.")
if eigvals is not None:
raise NotImplementedError(
"Only the eigvals=None case of eigh is implemented.")
a, = promote_dtypes_inexact(jnp.asarray(a))
v, w = lax_linalg.eigh(a, lower=lower)
if eigvals_only:
return w
else:
return w, v
@overload
def eigh(a: ArrayLike, b: ArrayLike | None = None, lower: bool = True,
eigvals_only: Literal[False] = False, overwrite_a: bool = False,
overwrite_b: bool = False, turbo: bool = True, eigvals: None = None,
type: int = 1, check_finite: bool = True) -> tuple[Array, Array]: ...
@overload
def eigh(a: ArrayLike, b: ArrayLike | None = None, lower: bool = True, *,
eigvals_only: Literal[True], overwrite_a: bool = False,
overwrite_b: bool = False, turbo: bool = True, eigvals: None = None,
type: int = 1, check_finite: bool = True) -> Array: ...
@overload
def eigh(a: ArrayLike, b: ArrayLike | None, lower: bool,
eigvals_only: Literal[True], overwrite_a: bool = False,
overwrite_b: bool = False, turbo: bool = True, eigvals: None = None,
type: int = 1, check_finite: bool = True) -> Array: ...
@overload
def eigh(a: ArrayLike, b: ArrayLike | None = None, lower: bool = True,
eigvals_only: bool = False, overwrite_a: bool = False,
overwrite_b: bool = False, turbo: bool = True, eigvals: None = None,
type: int = 1, check_finite: bool = True) -> Array | tuple[Array, Array]: ...
def eigh(a: ArrayLike, b: ArrayLike | None = None, lower: bool = True,
eigvals_only: bool = False, overwrite_a: bool = False,
overwrite_b: bool = False, turbo: bool = True, eigvals: None = None,
type: int = 1, check_finite: bool = True) -> Array | tuple[Array, Array]:
"""Compute eigenvalues and eigenvectors for a Hermitian matrix
JAX implementation of :func:`scipy.linalg.eigh`.
Args:
a: Hermitian input array of shape ``(..., N, N)``
b: optional Hermitian input of shape ``(..., N, N)``. If specified, compute
the generalized eigenvalue problem (currently only ``b=None`` is implemented).
lower: if True (default) access only the lower portion of the input matrix.
Otherwise access only the upper portion.
eigvals_only: If True, compute only the eigenvalues. If False (default) compute
both eigenvalues and eigenvectors.
type: if ``b`` is specified, ``type`` gives the type of generalized eigenvalue
problem to be computed. Denoting ``(λ, v)`` as an eigenvalue, eigenvector pair:
- ``type = 1`` solves ``a @ v = λ * b @ v`` (default; currently the only
implemented option)
- ``type = 2`` solves ``a @ b @ v = λ * v``
- ``type = 3`` solves ``b @ a @ v = λ * v``
eigvals: a ``(low, high)`` tuple specifying which eigenvalues to compute
(currently only ``eigvals=None`` is implemented).
overwrite_a: unused by JAX.
overwrite_b: unused by JAX.
turbo: unused by JAX.
check_finite: unused by JAX.
Returns:
A tuple of arrays ``(eigvals, eigvecs)`` if ``eigvals_only`` is False, otherwise
an array ``eigvals``.
- ``eigvals``: array of shape ``(..., N)`` containing the eigenvalues.
- ``eigvecs``: array of shape ``(..., N, N)`` containing the eigenvectors.
See also:
- :func:`jax.numpy.linalg.eigh`: NumPy-style eigh API.
- :func:`jax.lax.linalg.eigh`: XLA-style eigh API.
- :func:`jax.numpy.linalg.eig`: non-hermitian eigenvalue problem.
- :func:`jax.scipy.linalg.eigh_tridiagonal`: tri-diagonal eigenvalue problem.
Examples:
Compute the standard eigenvalue decomposition of a simple 2x2 matrix:
>>> a = jnp.array([[2., 1.],
... [1., 2.]])
>>> eigvals, eigvecs = jax.scipy.linalg.eigh(a)
>>> eigvals
Array([1., 3.], dtype=float32)
>>> eigvecs
Array([[-0.70710677, 0.70710677],
[ 0.70710677, 0.70710677]], dtype=float32)
Eigenvectors are orthonormal:
>>> jnp.allclose(eigvecs.T @ eigvecs, jnp.eye(2), atol=1E-5)
Array(True, dtype=bool)
Solution satisfies the eigenvalue problem:
>>> jnp.allclose(a @ eigvecs, eigvecs @ jnp.diag(eigvals))
Array(True, dtype=bool)
"""
del overwrite_a, overwrite_b, turbo, check_finite # unused
return _eigh(a, b, lower, eigvals_only, eigvals, type)
@partial(jit, static_argnames=('output',))
def _schur(a: Array, output: str) -> tuple[Array, Array]:
if output == "complex":
a = a.astype(dtypes.to_complex_dtype(a.dtype))
return lax_linalg.schur(a)
def schur(a: ArrayLike, output: str = 'real') -> tuple[Array, Array]:
"""Compute the Schur decomposition
JAX implementation of :func:`scipy.linalg.schur`.
The Schur form `T` of a matrix `A` satisfies:
.. math::
A = Z T Z^H
where `Z` is unitary, and `T` is upper-triangular for the complex-valued Schur
decomposition (i.e. ``output="complex"``) and is quasi-upper-triangular for the
real-valued Schur decomposition (i.e. ``output="real"``). In the quasi-triangular
case, the diagonal may include 2x2 blocks associated with complex-valued
eigenvalue pairs of `A`.
Args:
a: input array of shape ``(..., N, N)``
output: Specify whether to compute the ``"real"`` (default) or ``"complex"``
Schur decomposition.
Returns:
A tuple of arrays ``(T, Z)``
- ``T`` is a shape ``(..., N, N)`` array containing the upper-triangular
Schur form of the input.
- ``Z`` is a shape ``(..., N, N)`` array containing the unitary Schur
transformation matrix.
See also:
- :func:`jax.scipy.linalg.rsf2csf`: convert real Schur form to complex Schur form.
- :func:`jax.lax.linalg.schur`: XLA-style API for Schur decomposition.
Examples:
A Schur decomposition of a 3x3 matrix:
>>> a = jnp.array([[1., 2., 3.],
... [1., 4., 2.],
... [3., 2., 1.]])
>>> T, Z = jax.scipy.linalg.schur(a)
The Schur form ``T`` is quasi-upper-triangular in general, but is truly
upper-triangular in this case because the eigenvalues of the input are all real:
>>> T # doctest: +SKIP
Array([[-2.0000005 , 0.5066295 , -0.43360388],
[ 0. , 1.5505103 , 0.74519426],
[ 0. , 0. , 6.449491 ]], dtype=float32)
The transformation matrix ``Z`` is unitary:
>>> jnp.allclose(Z.T @ Z, jnp.eye(3), atol=1E-5)
Array(True, dtype=bool)
The input can be reconstructed from the outputs:
>>> jnp.allclose(Z @ T @ Z.T, a)
Array(True, dtype=bool)
"""
if output not in ('real', 'complex'):
raise ValueError(
f"Expected 'output' to be either 'real' or 'complex', got {output=}.")
return _schur(a, output)
def inv(a: ArrayLike, overwrite_a: bool = False, check_finite: bool = True) -> Array:
"""Return the inverse of a square matrix
JAX implementation of :func:`scipy.linalg.inv`.
Args:
a: array of shape ``(..., N, N)`` specifying square array(s) to be inverted.
overwrite_a: unused in JAX
check_finite: unused in JAX
Returns:
Array of shape ``(..., N, N)`` containing the inverse of the input.
Notes:
In most cases, explicitly computing the inverse of a matrix is ill-advised. For
example, to compute ``x = inv(A) @ b``, it is more performant and numerically
precise to use a direct solve, such as :func:`jax.scipy.linalg.solve`.
See Also:
- :func:`jax.numpy.linalg.inv`: NumPy-style API for matrix inverse
- :func:`jax.scipy.linalg.solve`: direct linear solver
Examples:
Compute the inverse of a 3x3 matrix
>>> a = jnp.array([[1., 2., 3.],
... [2., 4., 2.],
... [3., 2., 1.]])
>>> a_inv = jax.scipy.linalg.inv(a)
>>> a_inv # doctest: +SKIP
Array([[ 0. , -0.25 , 0.5 ],
[-0.25 , 0.5 , -0.25000003],
[ 0.5 , -0.25 , 0. ]], dtype=float32)
Check that multiplying with the inverse gives the identity:
>>> jnp.allclose(a @ a_inv, jnp.eye(3), atol=1E-5)
Array(True, dtype=bool)
Multiply the inverse by a vector ``b``, to find a solution to ``a @ x = b``
>>> b = jnp.array([1., 4., 2.])
>>> a_inv @ b
Array([ 0. , 1.25, -0.5 ], dtype=float32)
Note, however, that explicitly computing the inverse in such a case can lead
to poor performance and loss of precision as the size of the problem grows.
Instead, you should use a direct solver like :func:`jax.scipy.linalg.solve`:
>>> jax.scipy.linalg.solve(a, b)
Array([ 0. , 1.25, -0.5 ], dtype=float32)
"""
del overwrite_a, check_finite # unused
return jnp.linalg.inv(a)
@partial(jit, static_argnames=('overwrite_a', 'check_finite'))
def lu_factor(a: ArrayLike, overwrite_a: bool = False, check_finite: bool = True) -> tuple[Array, Array]:
"""Factorization for LU-based linear solves
JAX implementation of :func:`scipy.linalg.lu_factor`.
This function returns a result suitable for use with :func:`jax.scipy.linalg.lu_solve`.
For direct LU decompositions, prefer :func:`jax.scipy.linalg.lu`.
Args:
a: input array of shape ``(..., M, N)``.
overwrite_a: unused by JAX
check_finite: unused by JAX
Returns:
A tuple ``(lu, piv)``
- ``lu`` is an array of shape ``(..., M, N)``, containing ``L`` in its
lower triangle and ``U`` in its upper.
- ``piv`` is an array of shape ``(..., K)`` with ``K = min(M, N)``,
which encodes the pivots.
See Also:
- :func:`jax.scipy.linalg.lu`
- :func:`jax.scipy.linalg.lu_solve`
Examples:
Solving a small linear system via LU factorization:
>>> a = jnp.array([[2., 1.],
... [1., 2.]])
Compute the lu factorization via :func:`~jax.scipy.linalg.lu_factor`,
and use it to solve a linear equation via :func:`~jax.scipy.linalg.lu_solve`.
>>> b = jnp.array([3., 4.])
>>> lufac = jax.scipy.linalg.lu_factor(a)
>>> y = jax.scipy.linalg.lu_solve(lufac, b)
>>> y
Array([0.6666666, 1.6666667], dtype=float32)
Check that the result is consistent:
>>> jnp.allclose(a @ y, b)
Array(True, dtype=bool)
"""
del overwrite_a, check_finite # unused
a, = promote_dtypes_inexact(jnp.asarray(a))
lu, pivots, _ = lax_linalg.lu(a)
return lu, pivots
@partial(jit, static_argnames=('trans', 'overwrite_b', 'check_finite'))
def lu_solve(lu_and_piv: tuple[Array, ArrayLike], b: ArrayLike, trans: int = 0,
overwrite_b: bool = False, check_finite: bool = True) -> Array:
"""Solve a linear system using an LU factorization
JAX implementation of :func:`scipy.linalg.lu_solve`. Uses the output
of :func:`jax.scipy.linalg.lu_factor`.
Args:
lu_and_piv: ``(lu, piv)``, output of :func:`~jax.scipy.linalg.lu_factor`.
``lu`` is an array of shape ``(..., M, N)``, containing ``L`` in its lower
triangle and ``U`` in its upper. ``piv`` is an array of shape ``(..., K)``,
with ``K = min(M, N)``, which encodes the pivots.
b: right-hand-side of linear system. Must have shape ``(..., M)``
trans: type of system to solve. Options are:
- ``0``: :math:`A x = b`
- ``1``: :math:`A^Tx = b`
- ``2``: :math:`A^Hx = b`
overwrite_b: unused by JAX
check_finite: unused by JAX
Returns:
Array of shape ``(..., N)`` representing the solution of the linear system.
See Also:
- :func:`jax.scipy.linalg.lu`
- :func:`jax.scipy.linalg.lu_factor`
Examples:
Solving a small linear system via LU factorization:
>>> a = jnp.array([[2., 1.],
... [1., 2.]])
Compute the lu factorization via :func:`~jax.scipy.linalg.lu_factor`,
and use it to solve a linear equation via :func:`~jax.scipy.linalg.lu_solve`.
>>> b = jnp.array([3., 4.])
>>> lufac = jax.scipy.linalg.lu_factor(a)
>>> y = jax.scipy.linalg.lu_solve(lufac, b)
>>> y
Array([0.6666666, 1.6666667], dtype=float32)
Check that the result is consistent:
>>> jnp.allclose(a @ y, b)
Array(True, dtype=bool)
"""
del overwrite_b, check_finite # unused
lu, pivots = lu_and_piv
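# Convert LAPACK-style pivots (row i was swapped with row pivots[i]) into an
# explicit permutation vector before dispatching to lax_linalg.lu_solve.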
m, _ = lu.shape[-2:]
perm = lax_linalg.lu_pivots_to_permutation(pivots, m)
return lax_linalg.lu_solve(lu, perm, b, trans)
@overload
def _lu(a: ArrayLike, permute_l: Literal[True]) -> tuple[Array, Array]: ...
@overload
def _lu(a: ArrayLike, permute_l: Literal[False]) -> tuple[Array, Array, Array]: ...
@overload
def _lu(a: ArrayLike, permute_l: bool) -> tuple[Array, Array] | tuple[Array, Array, Array]: ...
@partial(jit, static_argnums=(1,))
def _lu(a: ArrayLike, permute_l: bool) -> tuple[Array, Array] | tuple[Array, Array, Array]:
a, = promote_dtypes_inexact(jnp.asarray(a))
lu, _, permutation = lax_linalg.lu(a)
dtype = lax.dtype(a)
m, n = jnp.shape(a)
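# Build the permutation matrix P from the permutation vector returned by
# lax_linalg.lu: column j of P has a one in row permutation[j], so that
# a == P @ L @ U (see the reconstruction example in the ``lu`` docstring).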
p = jnp.real(jnp.array(permutation[None, :] == jnp.arange(m, dtype=permutation.dtype)[:, None], dtype=dtype))
k = min(m, n)
l = jnp.tril(lu, -1)[:, :k] + jnp.eye(m, k, dtype=dtype)
u = jnp.triu(lu)[:k, :]
if permute_l:
return jnp.matmul(p, l, precision=lax.Precision.HIGHEST), u
else:
return p, l, u
@overload
def lu(a: ArrayLike, permute_l: Literal[False] = False, overwrite_a: bool = False,
check_finite: bool = True) -> tuple[Array, Array, Array]: ...
@overload
def lu(a: ArrayLike, permute_l: Literal[True], overwrite_a: bool = False,
check_finite: bool = True) -> tuple[Array, Array]: ...
@overload
def lu(a: ArrayLike, permute_l: bool = False, overwrite_a: bool = False,
check_finite: bool = True) -> tuple[Array, Array] | tuple[Array, Array, Array]: ...
@partial(jit, static_argnames=('permute_l', 'overwrite_a', 'check_finite'))
def lu(a: ArrayLike, permute_l: bool = False, overwrite_a: bool = False,
check_finite: bool = True) -> tuple[Array, Array] | tuple[Array, Array, Array]:
"""Compute the LU decomposition
JAX implementation of :func:`scipy.linalg.lu`.
The LU decomposition of a matrix `A` is:
.. math::
A = P L U
where `P` is a permutation matrix, `L` is lower-triangular and `U` is upper-triangular.
Args:
a: array of shape ``(..., M, N)`` to decompose.
permute_l: if True, then permute ``L`` and return ``(P @ L, U)`` (default: False)
overwrite_a: not used by JAX
check_finite: not used by JAX
Returns:
A tuple of arrays ``(P @ L, U)`` if ``permute_l`` is True, else ``(P, L, U)``:
- ``P`` is a permutation matrix of shape ``(..., M, M)``
- ``L`` is a lower-triangular matrix of shape ``(..., M, K)``
- ``U`` is an upper-triangular matrix of shape ``(..., K, N)``
with ``K = min(M, N)``
See also:
- :func:`jax.numpy.linalg.lu`: NumPy-style API for LU decomposition.
- :func:`jax.lax.linalg.lu`: XLA-style API for LU decomposition.
- :func:`jax.scipy.linalg.lu_solve`: LU-based linear solver.
Examples:
An LU decomposition of a 3x3 matrix:
>>> a = jnp.array([[1., 2., 3.],
... [5., 4., 2.],
... [3., 2., 1.]])
>>> P, L, U = jax.scipy.linalg.lu(a)
``P`` is a permutation matrix: i.e. each row and column has a single ``1``:
>>> P
Array([[0., 1., 0.],
[1., 0., 0.],
[0., 0., 1.]], dtype=float32)
``L`` and ``U`` are lower-triangular and upper-triangular matrices:
>>> with jnp.printoptions(precision=3):
... print(L)
... print(U)
[[ 1. 0. 0. ]
[ 0.2 1. 0. ]
[ 0.6 -0.333 1. ]]
[[5. 4. 2. ]
[0. 1.2 2.6 ]
[0. 0. 0.667]]
The original matrix can be reconstructed by multiplying the three together:
>>> a_reconstructed = P @ L @ U
>>> jnp.allclose(a, a_reconstructed)
Array(True, dtype=bool)
"""
del overwrite_a, check_finite # unused
return _lu(a, permute_l)
@overload
def _qr(a: ArrayLike, mode: Literal["r"], pivoting: Literal[False]
) -> tuple[Array]: ...
@overload
def _qr(a: ArrayLike, mode: Literal["r"], pivoting: Literal[True]
) -> tuple[Array, Array]: ...
@overload
def _qr(a: ArrayLike, mode: Literal["full", "economic"], pivoting: Literal[False]
) -> tuple[Array, Array]: ...
@overload
def _qr(a: ArrayLike, mode: Literal["full", "economic"], pivoting: Literal[True]
) -> tuple[Array, Array, Array]: ...
@overload
def _qr(a: ArrayLike, mode: str, pivoting: Literal[False]
) -> tuple[Array] | tuple[Array, Array]: ...
@overload
def _qr(a: ArrayLike, mode: str, pivoting: Literal[True]
) -> tuple[Array, Array] | tuple[Array, Array, Array]: ...
@overload
def _qr(a: ArrayLike, mode: str, pivoting: bool
) -> tuple[Array] | tuple[Array, Array] | tuple[Array, Array, Array]: ...
@partial(jit, static_argnames=('mode', 'pivoting'))
def _qr(a: ArrayLike, mode: str, pivoting: bool
) -> tuple[Array] | tuple[Array, Array] | tuple[Array, Array, Array]:
if mode in ("full", "r"):
full_matrices = True
elif mode == "economic":
full_matrices = False
else:
raise ValueError(f"Unsupported QR decomposition mode '{mode}'")
a, = promote_dtypes_inexact(jnp.asarray(a))
q, r, *p = lax_linalg.qr(a, pivoting=pivoting, full_matrices=full_matrices)
if mode == "r":
if pivoting:
return r, p[0]
return (r,)
if pivoting:
return q, r, p[0]
return q, r
@overload
def qr(a: ArrayLike, overwrite_a: bool = False, lwork: Any = None, *,
mode: Literal["full", "economic"], pivoting: Literal[False] = False,
check_finite: bool = True) -> tuple[Array, Array]: ...
@overload
def qr(a: ArrayLike, overwrite_a: bool = False, lwork: Any = None, *,
mode: Literal["full", "economic"], pivoting: Literal[True] = True,
check_finite: bool = True) -> tuple[Array, Array, Array]: ...
@overload
def qr(a: ArrayLike, overwrite_a: bool = False, lwork: Any = None, *,
mode: Literal["full", "economic"], pivoting: bool = False,
check_finite: bool = True
) -> tuple[Array, Array] | tuple[Array, Array, Array]: ...
@overload
def qr(a: ArrayLike, overwrite_a: bool = False, lwork: Any = None, *,
mode: Literal["r"], pivoting: Literal[False] = False, check_finite: bool = True
) -> tuple[Array]: ...
@overload
def qr(a: ArrayLike, overwrite_a: bool = False, lwork: Any = None, *,
mode: Literal["r"], pivoting: Literal[True] = True, check_finite: bool = True
) -> tuple[Array, Array]: ...
@overload
def qr(a: ArrayLike, overwrite_a: bool = False, lwork: Any = None, *,
mode: Literal["r"], pivoting: bool = False, check_finite: bool = True
) -> tuple[Array] | tuple[Array, Array]: ...
@overload
def qr(a: ArrayLike, overwrite_a: bool = False, lwork: Any = None, mode: str = "full",
pivoting: bool = False, check_finite: bool = True
) -> tuple[Array] | tuple[Array, Array] | tuple[Array, Array, Array]: ...
def qr(a: ArrayLike, overwrite_a: bool = False, lwork: Any = None, mode: str = "full",
pivoting: bool = False, check_finite: bool = True
) -> tuple[Array] | tuple[Array, Array] | tuple[Array, Array, Array]:
"""Compute the QR decomposition of an array
JAX implementation of :func:`scipy.linalg.qr`.
The QR decomposition of a matrix `A` is given by
.. math::
A = QR
Where `Q` is a unitary matrix (i.e. :math:`Q^HQ=I`) and `R` is an upper-triangular
matrix.
Args:
a: array of shape (..., M, N)
mode: Computational mode. Supported values are:
- ``"full"`` (default): return `Q` of shape ``(M, M)`` and `R` of shape ``(M, N)``.
- ``"r"``: return only `R`
- ``"economic"``: return `Q` of shape ``(M, K)`` and `R` of shape ``(K, N)``,
where K = min(M, N).
pivoting: Allows the QR decomposition to be rank-revealing. If ``True``, compute
the column-pivoted decomposition ``A[:, P] = Q @ R``, where ``P`` is chosen such
that the magnitude of the diagonal entries of ``R`` is non-increasing.
overwrite_a: unused in JAX
lwork: unused in JAX
check_finite: unused in JAX
Returns:
A tuple ``(Q, R)`` or ``(Q, R, P)`` if ``mode`` is not ``"r"`` and ``pivoting`` is
respectively ``False`` or ``True``; otherwise a tuple ``(R,)`` or ``(R, P)`` if
``mode`` is ``"r"`` and ``pivoting`` is respectively ``False`` or ``True``, where:
- ``Q`` is an orthogonal matrix of shape ``(..., M, M)`` (if ``mode`` is ``"full"``)
or ``(..., M, K)`` (if ``mode`` is ``"economic"``),
- ``R`` is an upper-triangular matrix of shape ``(..., M, N)`` (if ``mode`` is
``"r"`` or ``"full"``) or ``(..., K, N)`` (if ``mode`` is ``"economic"``),
- ``P`` is an index vector of shape ``(..., N)``.
with ``K = min(M, N)``.
Notes:
- At present, pivoting is only implemented on the CPU and GPU backends. For further
details about the GPU implementation, see the documentation for
:func:`jax.lax.linalg.qr`.
See also:
- :func:`jax.numpy.linalg.qr`: NumPy-style QR decomposition API
- :func:`jax.lax.linalg.qr`: XLA-style QR decomposition API
Examples:
Compute the QR decomposition of a matrix:
>>> a = jnp.array([[1., 2., 3., 4.],
... [5., 4., 2., 1.],
... [6., 3., 1., 5.]])
>>> Q, R = jax.scipy.linalg.qr(a)
>>> Q # doctest: +SKIP
Array([[-0.12700021, -0.7581426 , -0.6396022 ],
[-0.63500065, -0.43322435, 0.63960224],
[-0.7620008 , 0.48737738, -0.42640156]], dtype=float32)
>>> R # doctest: +SKIP
Array([[-7.8740077, -5.080005 , -2.4130025, -4.953006 ],
[ 0. , -1.7870499, -2.6534991, -1.028908 ],
[ 0. , 0. , -1.0660033, -4.050814 ]], dtype=float32)
Check that ``Q`` is orthonormal:
>>> jnp.allclose(Q.T @ Q, jnp.eye(3), atol=1E-5)
Array(True, dtype=bool)
Reconstruct the input:
>>> jnp.allclose(Q @ R, a)
Array(True, dtype=bool)
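A sketch of the column-pivoted, rank-revealing decomposition; it is marked
skip because pivoting is only available on the CPU and GPU backends (see the
Notes above):
>>> Q, R, P = jax.scipy.linalg.qr(a, pivoting=True)  # doctest: +SKIP
>>> jnp.allclose(a[:, P], Q @ R)  # doctest: +SKIP
Array(True, dtype=bool)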
"""
del overwrite_a, lwork, check_finite # unused
return _qr(a, mode, pivoting)
@partial(jit, static_argnames=('assume_a', 'lower'))
def _solve(a: ArrayLike, b: ArrayLike, assume_a: str, lower: bool) -> Array:
if assume_a != 'pos':
return jnp.linalg.solve(a, b)
a, b = promote_dtypes_inexact(jnp.asarray(a), jnp.asarray(b))
lax_linalg._check_solve_shapes(a, b)
# With custom_linear_solve, we can reuse the same factorization when
# computing sensitivities. This is considerably faster.
factors = cho_factor(lax.stop_gradient(a), lower=lower)
custom_solve = partial(
lax.custom_linear_solve,
lambda x: lax_linalg._broadcasted_matvec(a, x),
solve=lambda _, x: cho_solve(factors, x),
symmetric=True)
if a.ndim == b.ndim + 1:
# b.shape == [..., m]
return custom_solve(b)
else:
# b.shape == [..., m, k]
return vmap(custom_solve, b.ndim - 1, max(a.ndim, b.ndim) - 1)(b)
def solve(a: ArrayLike, b: ArrayLike, lower: bool = False,
overwrite_a: bool = False, overwrite_b: bool = False, debug: bool = False,
check_finite: bool = True, assume_a: str = 'gen') -> Array:
"""Solve a linear system of equations.
JAX implementation of :func:`scipy.linalg.solve`.
This solves a (batched) linear system of equations ``a @ x = b`` for ``x``
given ``a`` and ``b``.
If ``a`` is singular, this will return ``nan`` or ``inf`` values.
Args:
a: array of shape ``(..., N, N)``.
b: array of shape ``(..., N)`` or ``(..., N, M)``
lower: Referenced only if ``assume_a != 'gen'``. If True, only use the lower
triangle of the input; if False (default), only use the upper triangle.
assume_a: specify what properties of ``a`` can be assumed. Options are:
- ``"gen"``: generic matrix (default)
- ``"sym"``: symmetric matrix
- ``"her"``: hermitian matrix
- ``"pos"``: positive-definite matrix
overwrite_a: unused by JAX
overwrite_b: unused by JAX
debug: unused by JAX
check_finite: unused by JAX
Returns:
An array of the same shape as ``b`` containing the solution to the linear
system if ``a`` is non-singular.
If ``a`` is singular, the result contains ``nan`` or ``inf`` values.
See also:
- :func:`jax.scipy.linalg.lu_solve`: Solve via LU factorization.
- :func:`jax.scipy.linalg.cho_solve`: Solve via Cholesky factorization.
- :func:`jax.scipy.linalg.solve_triangular`: Solve a triangular system.
- :func:`jax.numpy.linalg.solve`: NumPy-style API for solving linear systems.
- :func:`jax.lax.custom_linear_solve`: matrix-free linear solver.
Examples:
A simple 3x3 linear system:
>>> A = jnp.array([[1., 2., 3.],
... [2., 4., 2.],
... [3., 2., 1.]])
>>> b = jnp.array([14., 16., 10.])
>>> x = jax.scipy.linalg.solve(A, b)
>>> x
Array([1., 2., 3.], dtype=float32)
Confirming that the result solves the system:
>>> jnp.allclose(A @ x, b)
Array(True, dtype=bool)
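When ``a`` is known to be positive-definite, ``assume_a='pos'`` dispatches to
a Cholesky-based solve, which also reuses the factorization when computing
gradients (a short sketch):
>>> M = jnp.array([[2., 1.],
...                [1., 2.]])
>>> x = jax.scipy.linalg.solve(M, jnp.array([3., 4.]), assume_a='pos')
>>> jnp.allclose(M @ x, jnp.array([3., 4.]))
Array(True, dtype=bool)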
"""
del overwrite_a, overwrite_b, debug, check_finite #unused
valid_assume_a = ['gen', 'sym', 'her', 'pos']
if assume_a not in valid_assume_a:
raise ValueError(f"Expected assume_a to be one of {valid_assume_a}; got {assume_a!r}")
return _solve(a, b, assume_a, lower)
@partial(jit, static_argnames=('trans', 'lower', 'unit_diagonal'))
def _solve_triangular(a: ArrayLike, b: ArrayLike, trans: int | str,
lower: bool, unit_diagonal: bool) -> Array:
if trans == 0 or trans == "N":
transpose_a, conjugate_a = False, False
elif trans == 1 or trans == "T":
transpose_a, conjugate_a = True, False
elif trans == 2 or trans == "C":
transpose_a, conjugate_a = True, True
else:
raise ValueError(f"Invalid 'trans' value {trans}")
a, b = promote_dtypes_inexact(jnp.asarray(a), jnp.asarray(b))
# lax_linalg.triangular_solve only supports matrix 'b's at the moment.
b_is_vector = jnp.ndim(a) == jnp.ndim(b) + 1
if b_is_vector:
b = b[..., None]
out = lax_linalg.triangular_solve(a, b, left_side=True, lower=lower,
transpose_a=transpose_a,
conjugate_a=conjugate_a,
unit_diagonal=unit_diagonal)
if b_is_vector:
return out[..., 0]
else:
return out
def solve_triangular(a: ArrayLike, b: ArrayLike, trans: int | str = 0, lower: bool = False,
unit_diagonal: bool = False, overwrite_b: bool = False,
debug: Any = None, check_finite: bool = True) -> Array:
"""Solve a triangular linear system of equations
JAX implementation of :func:`scipy.linalg.solve_triangular`.
This solves a (batched) linear system of equations ``a @ x = b`` for ``x``
given a triangular matrix ``a`` and a vector or matrix ``b``.
Args:
a: array of shape ``(..., N, N)``. Only part of the array will be accessed,
depending on the ``lower`` and ``unit_diagonal`` arguments.
b: array of shape ``(..., N)`` or ``(..., N, M)``
lower: If True, only use the lower triangle of the input; if False (default),
only use the upper triangle.
unit_diagonal: If True, ignore diagonal elements of ``a`` and assume they are
``1`` (default: False).
trans: type of system to solve. Options are:
- ``0`` or ``'N'``: solve :math:`Ax=b`
- ``1`` or ``'T'``: solve :math:`A^Tx=b`
- ``2`` or ``'C'``: solve :math:`A^Hx=b`
overwrite_b: unused by JAX
debug: unused by JAX
check_finite: unused by JAX
Returns:
An array of the same shape as ``b`` containing the solution to the linear system.
See also:
:func:`jax.scipy.linalg.solve`: Solve a general linear system.
Examples:
A simple 3x3 triangular linear system:
>>> A = jnp.array([[1., 2., 3.],
... [0., 3., 2.],
... [0., 0., 5.]])
>>> b = jnp.array([10., 8., 5.])
>>> x = jax.scipy.linalg.solve_triangular(A, b)
>>> x
Array([3., 2., 1.], dtype=float32)
Confirming that the result solves the system:
>>> jnp.allclose(A @ x, b)
Array(True, dtype=bool)
Computing the transposed problem:
>>> x = jax.scipy.linalg.solve_triangular(A, b, trans='T')
>>> x
Array([10. , -4. , -3.4], dtype=float32)
Confirming that the result solves the system:
>>> jnp.allclose(A.T @ x, b)
Array(True, dtype=bool)
"""
del overwrite_b, debug, check_finite # unused
return _solve_triangular(a, b, trans, lower, unit_diagonal)
@partial(jit, static_argnames=('upper_triangular', 'max_squarings'))
def expm(A: ArrayLike, *, upper_triangular: bool = False, max_squarings: int = 16) -> Array:
"""Compute the matrix exponential
JAX implementation of :func:`scipy.linalg.expm`.
Args:
A: array of shape ``(..., N, N)``
upper_triangular: if True, then assume that ``A`` is upper-triangular. Default=False.
max_squarings: The number of squarings in the scaling-and-squaring approximation method
(default: 16).
Returns:
An array of shape ``(..., N, N)`` containing the matrix exponent of ``A``.
Notes:
This uses the scaling-and-squaring approximation method, with computational complexity
controlled by the optional ``max_squarings`` argument. Theoretically, the number of
required squarings is ``max(0, ceil(log2(norm(A))) - c)`` where ``norm(A)`` is the L1
norm and ``c=2.42`` for float64/complex128, or ``c=1.97`` for float32/complex64.
See Also:
:func:`jax.scipy.linalg.expm_frechet`
Examples:
``expm`` is the matrix exponential, and has similar properties to the more
familiar scalar exponential. For scalars ``a`` and ``b``, :math:`e^{a + b}
= e^a e^b`. However, for matrices, this property only holds when ``A`` and
``B`` commute (``AB = BA``). In this case, ``expm(A+B) = expm(A) @ expm(B)``
>>> A = jnp.array([[2, 0],
... [0, 1]])
>>> B = jnp.array([[3, 0],
... [0, 4]])
>>> jnp.allclose(jax.scipy.linalg.expm(A+B),
... jax.scipy.linalg.expm(A) @ jax.scipy.linalg.expm(B),
... rtol=0.0001)
Array(True, dtype=bool)
If a matrix ``X`` is invertible, then
``expm(X @ A @ inv(X)) = X @ expm(A) @ inv(X)``
>>> X = jnp.array([[3, 1],
... [2, 5]])
>>> X_inv = jax.scipy.linalg.inv(X)
>>> jnp.allclose(jax.scipy.linalg.expm(X @ A @ X_inv),
... X @ jax.scipy.linalg.expm(A) @ X_inv)
Array(True, dtype=bool)
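When the norm of the input requires more than ``max_squarings`` squaring
steps, the result is filled with ``nan`` rather than silently losing accuracy
(a small sketch of this failure mode):
>>> A = jnp.array([[100., 0.],
...                [0., 100.]])
>>> jnp.isnan(jax.scipy.linalg.expm(A, max_squarings=2)).all()
Array(True, dtype=bool)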
"""
A, = promote_dtypes_inexact(A)
if A.ndim < 2 or A.shape[-1] != A.shape[-2]:
raise ValueError(f"Expected A to be a (batched) square matrix, got {A.shape=}.")
if A.ndim > 2:
return jnp.vectorize(
partial(expm, upper_triangular=upper_triangular, max_squarings=max_squarings),
signature="(n,n)->(n,n)")(A)
P, Q, n_squarings = _calc_P_Q(jnp.asarray(A))
def _nan(args):
A, *_ = args
return jnp.full_like(A, jnp.nan)
def _compute(args):
A, P, Q = args
R = _solve_P_Q(P, Q, upper_triangular)
R = _squaring(R, n_squarings, max_squarings)
return R
R = lax.cond(n_squarings > max_squarings, _nan, _compute, (A, P, Q))
return R
@jit
def _calc_P_Q(A: Array) -> tuple[Array, Array, Array]:
if A.ndim != 2 or A.shape[0] != A.shape[1]:
raise ValueError('expected A to be a square matrix')
A_L1 = jnp.linalg.norm(A, 1)
n_squarings: Array
U: Array
V: Array
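# The order of the Pade approximant and the amount of scaling are chosen from
# the L1 norm of A, following Higham, "The Scaling and Squaring Method for the
# Matrix Exponential Revisited" (2005); ``maxnorm`` is the theta threshold for
# the highest-order approximant available at each precision.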
if A.dtype == 'float64' or A.dtype == 'complex128':
maxnorm = 5.371920351148152
n_squarings = jnp.maximum(0, jnp.floor(jnp.log2(A_L1 / maxnorm)))
A = A / 2 ** n_squarings.astype(A.dtype)
conds = jnp.array([1.495585217958292e-002, 2.539398330063230e-001,
9.504178996162932e-001, 2.097847961257068e+000],
dtype=A_L1.dtype)
idx = jnp.digitize(A_L1, conds)
U, V = lax.switch(idx, [_pade3, _pade5, _pade7, _pade9, _pade13], A)
elif A.dtype == 'float32' or A.dtype == 'complex64':
maxnorm = 3.925724783138660
n_squarings = jnp.maximum(0, jnp.floor(jnp.log2(A_L1 / maxnorm)))
A = A / 2 ** n_squarings.astype(A.dtype)
conds = jnp.array([4.258730016922831e-001, 1.880152677804762e+000],
dtype=A_L1.dtype)
idx = jnp.digitize(A_L1, conds)
U, V = lax.switch(idx, [_pade3, _pade5, _pade7], A)
else:
raise TypeError(f"A.dtype={A.dtype} is not supported.")
P = U + V # p_m(A) : numerator
Q = -U + V # q_m(A) : denominator
return P, Q, n_squarings
def _solve_P_Q(P: ArrayLike, Q: ArrayLike, upper_triangular: bool = False) -> Array:
if upper_triangular:
return solve_triangular(Q, P)
else:
return jnp.linalg.solve(Q, P)
def _precise_dot(A: ArrayLike, B: ArrayLike) -> Array:
return jnp.dot(A, B, precision=lax.Precision.HIGHEST)
@partial(jit, static_argnums=2)
def _squaring(R: Array, n_squarings: Array, max_squarings: int) -> Array:
# squaring step to undo scaling
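# n_squarings is a traced value, so it cannot serve as a Python loop bound;
# instead, scan over the static max_squarings and square only while
# i < n_squarings.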
def _squaring_precise(x):
return _precise_dot(x, x)
def _identity(x):
return x
def _scan_f(c, i):
return lax.cond(i < n_squarings, _squaring_precise, _identity, c), None
res, _ = lax.scan(_scan_f, R, jnp.arange(max_squarings, dtype=n_squarings.dtype))
return res
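# Each _pade* helper returns the Pade approximant to exp(A) split as U (the
# odd-power terms, including the leading factor of A) and V (the even-power
# terms), so that the numerator is P = U + V and the denominator is Q = -U + V.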
def _pade3(A: Array) -> tuple[Array, Array]:
b = (120., 60., 12., 1.)
M, N = A.shape
ident = jnp.eye(M, N, dtype=A.dtype)
A2 = _precise_dot(A, A)
U = _precise_dot(A, (b[3]*A2 + b[1]*ident))
V: Array = b[2]*A2 + b[0]*ident
return U, V
def _pade5(A: Array) -> tuple[Array, Array]:
b = (30240., 15120., 3360., 420., 30., 1.)
M, N = A.shape
ident = jnp.eye(M, N, dtype=A.dtype)
A2 = _precise_dot(A, A)
A4 = _precise_dot(A2, A2)
U = _precise_dot(A, b[5]*A4 + b[3]*A2 + b[1]*ident)
V: Array = b[4]*A4 + b[2]*A2 + b[0]*ident
return U, V
def _pade7(A: Array) -> tuple[Array, Array]:
b = (17297280., 8648640., 1995840., 277200., 25200., 1512., 56., 1.)
M, N = A.shape
ident = jnp.eye(M, N, dtype=A.dtype)
A2 = _precise_dot(A, A)
A4 = _precise_dot(A2, A2)
A6 = _precise_dot(A4, A2)
U = _precise_dot(A, b[7]*A6 + b[5]*A4 + b[3]*A2 + b[1]*ident)
V = b[6]*A6 + b[4]*A4 + b[2]*A2 + b[0]*ident
return U, V
def _pade9(A: Array) -> tuple[Array, Array]:
b = (17643225600., 8821612800., 2075673600., 302702400., 30270240.,
2162160., 110880., 3960., 90., 1.)
M, N = A.shape
ident = jnp.eye(M, N, dtype=A.dtype)
A2 = _precise_dot(A, A)
A4 = _precise_dot(A2, A2)
A6 = _precise_dot(A4, A2)
A8 = _precise_dot(A6, A2)
U = _precise_dot(A, b[9]*A8 + b[7]*A6 + b[5]*A4 + b[3]*A2 + b[1]*ident)
V = b[8]*A8 + b[6]*A6 + b[4]*A4 + b[2]*A2 + b[0]*ident
return U, V
def _pade13(A: Array) -> tuple[Array, Array]:
b = (64764752532480000., 32382376266240000., 7771770303897600.,
1187353796428800., 129060195264000., 10559470521600., 670442572800.,
33522128640., 1323241920., 40840800., 960960., 16380., 182., 1.)
M, N = A.shape
ident = jnp.eye(M, N, dtype=A.dtype)
A2 = _precise_dot(A, A)
A4 = _precise_dot(A2, A2)
A6 = _precise_dot(A4, A2)
U = _precise_dot(A, _precise_dot(A6, b[13]*A6 + b[11]*A4 + b[9]*A2) + b[7]*A6 + b[5]*A4 + b[3]*A2 + b[1]*ident)
V = _precise_dot(A6, b[12]*A6 + b[10]*A4 + b[8]*A2) + b[6]*A6 + b[4]*A4 + b[2]*A2 + b[0]*ident
return U, V
@overload
def expm_frechet(A: ArrayLike, E: ArrayLike, *, method: str | None = None,
compute_expm: Literal[True] = True) -> tuple[Array, Array]: ...
@overload
def expm_frechet(A: ArrayLike, E: ArrayLike, *, method: str | None = None,
compute_expm: Literal[False]) -> Array: ...
@overload
def expm_frechet(A: ArrayLike, E: ArrayLike, *, method: str | None = None,
compute_expm: bool = True) -> Array | tuple[Array, Array]: ...
@partial(jit, static_argnames=('method', 'compute_expm'))
def expm_frechet(A: ArrayLike, E: ArrayLike, *, method: str | None = None,
compute_expm: bool = True) -> Array | tuple[Array, Array]:
"""Compute the Frechet derivative of the matrix exponential.
JAX implementation of :func:`scipy.linalg.expm_frechet`
Args:
A: array of shape ``(..., N, N)``
E: array of shape ``(..., N, N)``; specifies the direction of the derivative.
compute_expm: if True (default) then compute and return ``expm(A)``.
method: ignored by JAX
Returns:
A tuple ``(expm_A, expm_frechet_AE)`` if ``compute_expm`` is True, else
the array ``expm_frechet_AE``. Both returned arrays have shape ``(..., N, N)``.
See also:
:func:`jax.scipy.linalg.expm`
Examples:
We can use this API to compute the matrix exponential of ``A``, as well as its
derivative in the direction ``E``:
>>> key1, key2 = jax.random.split(jax.random.key(3372))
>>> A = jax.random.normal(key1, (3, 3))
>>> E = jax.random.normal(key2, (3, 3))
>>> expmA, expm_frechet_AE = jax.scipy.linalg.expm_frechet(A, E)
This can be equivalently computed using JAX's automatic differentiation methods;
here we'll compute the derivative of :func:`~jax.scipy.linalg.expm` in the
direction of ``E`` using :func:`jax.jvp`, and find the same results:
>>> expmA2, expm_frechet_AE2 = jax.jvp(jax.scipy.linalg.expm, (A,), (E,))
>>> jnp.allclose(expmA, expmA2)
Array(True, dtype=bool)
>>> jnp.allclose(expm_frechet_AE, expm_frechet_AE2)
Array(True, dtype=bool)
"""
del method # unused
A_arr = jnp.asarray(A)
E_arr = jnp.asarray(E)
if A_arr.ndim < 2 or A_arr.shape[-2] != A_arr.shape[-1]:
raise ValueError(f'expected A to be a (batched) square matrix, got A.shape={A_arr.shape}')
if E_arr.ndim < 2 or E_arr.shape[-2] != E_arr.shape[-1]:
raise ValueError(f'expected E to be a (batched) square matrix, got E.shape={E_arr.shape}')
if A_arr.shape != E_arr.shape:
raise ValueError('expected A and E to be the same shape, got '
f'A.shape={A_arr.shape} E.shape={E_arr.shape}')
bound_fun = partial(expm, upper_triangular=False, max_squarings=16)
expm_A, expm_frechet_AE = jvp(bound_fun, (A_arr,), (E_arr,))
if compute_expm:
return expm_A, expm_frechet_AE
else:
return expm_frechet_AE
@jit
def block_diag(*arrs: ArrayLike) -> Array:
"""Create a block diagonal matrix from input arrays.
JAX implementation of :func:`scipy.linalg.block_diag`.
Args:
*arrs: arrays of at most two dimensions
Returns:
2D block-diagonal array constructed by placing the input arrays
along the diagonal.
Examples:
>>> A = jnp.ones((1, 1))
>>> B = jnp.ones((2, 2))
>>> C = jnp.ones((3, 3))
>>> jax.scipy.linalg.block_diag(A, B, C)
Array([[1., 0., 0., 0., 0., 0.],
[0., 1., 1., 0., 0., 0.],
[0., 1., 1., 0., 0., 0.],
[0., 0., 0., 1., 1., 1.],
[0., 0., 0., 1., 1., 1.],
[0., 0., 0., 1., 1., 1.]], dtype=float32)
"""
if len(arrs) == 0:
arrs = (jnp.zeros((1, 0)),)
arrs = tuple(promote_dtypes(*arrs))
bad_shapes = [i for i, a in enumerate(arrs) if jnp.ndim(a) > 2]
if bad_shapes:
raise ValueError("Arguments to jax.scipy.linalg.block_diag must have at "
"most 2 dimensions, got {} at argument {}."
.format(arrs[bad_shapes[0]], bad_shapes[0]))
converted_arrs = [jnp.atleast_2d(a) for a in arrs]
acc = converted_arrs[0]
dtype = lax.dtype(acc)
for a in converted_arrs[1:]:
_, c = a.shape
a = lax.pad(a, dtype.type(0), ((0, 0, 0), (acc.shape[-1], 0, 0)))
acc = lax.pad(acc, dtype.type(0), ((0, 0, 0), (0, c, 0)))
acc = lax.concatenate([acc, a], dimension=0)
return acc
@partial(jit, static_argnames=("eigvals_only", "select", "select_range"))
def eigh_tridiagonal(d: ArrayLike, e: ArrayLike, *, eigvals_only: bool = False,
select: str = 'a', select_range: tuple[float, float] | None = None,
tol: float | None = None) -> Array:
"""Solve the eigenvalue problem for a symmetric real tridiagonal matrix
JAX implementation of :func:`scipy.linalg.eigh_tridiagonal`.
Args:
d: real-valued array of shape ``(N,)`` specifying the diagonal elements.
e: real-valued array of shape ``(N - 1,)`` specifying the off-diagonal elements.
eigvals_only: If True, return only the eigenvalues (default: False). Computation
of eigenvectors is not yet implemented, so ``eigvals_only`` must be set to True.
select: specify which eigenvalues to calculate. Supported values are:
- ``'a'``: all eigenvalues
- ``'i'``: eigenvalues with indices ``select_range[0] <= i <= select_range[1]``
JAX does not currently implement ``select = 'v'``.
select_range: range of values used when ``select='i'``.
tol: absolute tolerance to use when solving for the eigenvalues.
Returns:
An array of eigenvalues with shape ``(N,)``.
See also:
:func:`jax.scipy.linalg.eigh`: general Hermitian eigenvalue solver
Examples:
>>> d = jnp.array([1., 2., 3., 4.])
>>> e = jnp.array([1., 1., 1.])
>>> eigvals = jax.scipy.linalg.eigh_tridiagonal(d, e, eigvals_only=True)
>>> eigvals
Array([0.2547188, 1.8227171, 3.1772828, 4.745281 ], dtype=float32)
For comparison, we can construct the full matrix and compute the same result
using :func:`~jax.scipy.linalg.eigh`:
>>> A = jnp.diag(d) + jnp.diag(e, 1) + jnp.diag(e, -1)
>>> A
Array([[1., 1., 0., 0.],
[1., 2., 1., 0.],
[0., 1., 3., 1.],
[0., 0., 1., 4.]], dtype=float32)
>>> eigvals_full = jax.scipy.linalg.eigh(A, eigvals_only=True)
>>> jnp.allclose(eigvals, eigvals_full)
Array(True, dtype=bool)
"""
if not eigvals_only:
raise NotImplementedError("Calculation of eigenvectors is not implemented")
def _sturm(alpha, beta_sq, pivmin, alpha0_perturbation, x):
"""Implements the Sturm sequence recurrence."""
n = alpha.shape[0]
zeros = jnp.zeros(x.shape, dtype=jnp.int32)
ones = jnp.ones(x.shape, dtype=jnp.int32)
# The first step in the Sturm sequence recurrence
# requires special care if x is equal to alpha[0].
def sturm_step0():
q = alpha[0] - x
count = jnp.where(q < 0, ones, zeros)
q = jnp.where(alpha[0] == x, alpha0_perturbation, q)
return q, count
# Subsequent steps all take this form:
def sturm_step(i, q, count):
q = alpha[i] - beta_sq[i - 1] / q - x
count = jnp.where(q <= pivmin, count + 1, count)
q = jnp.where(q <= pivmin, jnp.minimum(q, -pivmin), q)
return q, count
# The first step initializes q and count.
q, count = sturm_step0()
# Peel off ((n-1) % blocksize) steps from the main loop, so we can run
# the bulk of the iterations unrolled by a factor of blocksize.
blocksize = 16
i = 1
peel = (n - 1) % blocksize
unroll_cnt = peel
def unrolled_steps(args):
start, q, count = args
for j in range(unroll_cnt):
q, count = sturm_step(start + j, q, count)
return start + unroll_cnt, q, count
i, q, count = unrolled_steps((i, q, count))
# Run the remaining steps of the Sturm sequence using a partially
# unrolled while loop.
unroll_cnt = blocksize
def cond(iqc):
i, q, count = iqc
return jnp.less(i, n)
_, _, count = lax.while_loop(cond, unrolled_steps, (i, q, count))
return count
alpha = jnp.asarray(d)
beta = jnp.asarray(e)
supported_dtypes = (jnp.float32, jnp.float64, jnp.complex64, jnp.complex128)
if alpha.dtype != beta.dtype:
raise TypeError("diagonal and off-diagonal values must have same dtype, "
f"got {alpha.dtype} and {beta.dtype}")
if alpha.dtype not in supported_dtypes or beta.dtype not in supported_dtypes:
raise TypeError("Only float32 and float64 inputs are supported as inputs "
"to jax.scipy.linalg.eigh_tridiagonal, got "
f"{alpha.dtype} and {beta.dtype}")
n = alpha.shape[0]
if n <= 1:
return jnp.real(alpha)
if jnp.issubdtype(alpha.dtype, np.complexfloating):
alpha = jnp.real(alpha)
beta_sq = jnp.real(beta * jnp.conj(beta))
beta_abs = jnp.sqrt(beta_sq)
else:
beta_abs = jnp.abs(beta)
beta_sq = jnp.square(beta)
# Estimate the largest and smallest eigenvalues of T using the Gershgorin
# circle theorem.
off_diag_abs_row_sum = jnp.concatenate(
[beta_abs[:1], beta_abs[:-1] + beta_abs[1:], beta_abs[-1:]], axis=0)
lambda_est_max = jnp.amax(alpha + off_diag_abs_row_sum)
lambda_est_min = jnp.amin(alpha - off_diag_abs_row_sum)
# Upper bound on 2-norm of T.
t_norm = jnp.maximum(jnp.abs(lambda_est_min), jnp.abs(lambda_est_max))
# Compute the smallest allowed pivot in the Sturm sequence to avoid
# overflow.
finfo = np.finfo(alpha.dtype)
one = np.ones([], dtype=alpha.dtype)
safemin = np.maximum(one / finfo.max, (one + finfo.eps) * finfo.tiny)
pivmin = safemin * jnp.maximum(1, jnp.amax(beta_sq))
alpha0_perturbation = jnp.square(finfo.eps * beta_abs[0])
abs_tol = finfo.eps * t_norm
if tol is not None:
abs_tol = jnp.maximum(tol, abs_tol)
# In the worst case, when the absolute tolerance is eps*lambda_est_max and
# lambda_est_max = -lambda_est_min, we have to take as many bisection steps
# as there are bits in the mantissa plus 1.
# The proof is left as an exercise to the reader.
max_it = finfo.nmant + 1
# Determine the indices of the desired eigenvalues, based on select and
# select_range.
if select == 'a':
target_counts = jnp.arange(n, dtype=jnp.int32)
elif select == 'i':
if select_range is None:
raise ValueError("for select='i', select_range must be specified.")
if select_range[0] > select_range[1]:
raise ValueError('Got empty index range in select_range.')
target_counts = jnp.arange(select_range[0], select_range[1] + 1, dtype=jnp.int32)
elif select == 'v':
# TODO(phawkins): requires dynamic shape support.
raise NotImplementedError("eigh_tridiagonal(..., select='v') is not "
"implemented")
else:
raise ValueError("'select must have a value in {'a', 'i', 'v'}.")
# Run binary search for all desired eigenvalues in parallel, starting from
# the interval lightly wider than the estimated
# [lambda_est_min, lambda_est_max].
fudge = 2.1 # We widen starting interval the Gershgorin interval a bit.
norm_slack = jnp.array(n, alpha.dtype) * fudge * finfo.eps * t_norm
lower = lambda_est_min - norm_slack - 2 * fudge * pivmin
upper = lambda_est_max + norm_slack + fudge * pivmin
# Pre-broadcast the scalars used in the Sturm sequence for improved
# performance.
target_shape = jnp.shape(target_counts)
lower = jnp.broadcast_to(lower, shape=target_shape)
upper = jnp.broadcast_to(upper, shape=target_shape)
mid = 0.5 * (upper + lower)
pivmin = jnp.broadcast_to(pivmin, target_shape)
alpha0_perturbation = jnp.broadcast_to(alpha0_perturbation, target_shape)
# Start parallel binary searches.
def cond(args):
i, lower, _, upper = args
return jnp.logical_and(
jnp.less(i, max_it),
jnp.less(abs_tol, jnp.amax(upper - lower)))
def body(args):
i, lower, mid, upper = args
counts = _sturm(alpha, beta_sq, pivmin, alpha0_perturbation, mid)
lower = jnp.where(counts <= target_counts, mid, lower)
upper = jnp.where(counts > target_counts, mid, upper)
mid = 0.5 * (lower + upper)
return i + 1, lower, mid, upper
_, _, mid, _ = lax.while_loop(cond, body, (0, lower, mid, upper))
return mid
@partial(jit, static_argnames=('side', 'method'))
@jax.default_matmul_precision("float32")
def polar(a: ArrayLike, side: str = 'right', *, method: str = 'qdwh', eps: float | None = None,
max_iterations: int | None = None) -> tuple[Array, Array]:
r"""Computes the polar decomposition.
Given the :math:`m \times n` matrix :math:`a`, returns the factors of the polar
decomposition :math:`u` (also :math:`m \times n`) and :math:`p` such that
:math:`a = up` (if side is ``"right"``; :math:`p` is :math:`n \times n`) or
:math:`a = pu` (if side is ``"left"``; :math:`p` is :math:`m \times m`),
where :math:`p` is positive semidefinite. If :math:`a` is nonsingular,
:math:`p` is positive definite and the
decomposition is unique. :math:`u` has orthonormal columns unless
:math:`n > m`, in which case it has orthonormal rows.
Writing the SVD of :math:`a` as
:math:`a = u_\mathit{svd} \cdot s_\mathit{svd} \cdot v^h_\mathit{svd}`, we
have :math:`u = u_\mathit{svd} \cdot v^h_\mathit{svd}`. Thus the unitary
factor :math:`u` can be constructed as the application of the sign function to
the singular values of :math:`a`; or, if :math:`a` is Hermitian, the
eigenvalues.
Several methods exist to compute the polar decomposition. Currently two
are supported:
* ``method="svd"``:
Computes the SVD of :math:`a` and then forms
:math:`u = u_\mathit{svd} \cdot v^h_\mathit{svd}`.
* ``method="qdwh"``:
Applies the `QDWH`_ (QR-based Dynamically Weighted Halley) algorithm.
Args:
a: The :math:`m \times n` input matrix.
side: Determines whether a right or left polar decomposition is computed.
If ``side`` is ``"right"`` then :math:`a = up`. If ``side`` is ``"left"``
then :math:`a = pu`. The default is ``"right"``.
method: Determines the algorithm used, as described above.
eps: The final result will satisfy
:math:`\left|x_k - x_{k-1}\right| < \left|x_k\right| (4\epsilon)^{\frac{1}{3}}`,
where :math:`x_k` are the QDWH iterates. Ignored if ``method`` is not
``"qdwh"``.
    max_iterations: Iterations will terminate after this many steps even if
      the tolerance above has not been met. Ignored if ``method`` is not
      ``"qdwh"``.
Returns:
A ``(unitary, posdef)`` tuple, where ``unitary`` is the unitary factor
(:math:`m \times n`), and ``posdef`` is the positive-semidefinite factor.
``posdef`` is either :math:`n \times n` or :math:`m \times m` depending on
whether ``side`` is ``"right"`` or ``"left"``, respectively.
Examples:
Polar decomposition of a 3x3 matrix:
>>> a = jnp.array([[1., 2., 3.],
... [5., 4., 2.],
... [3., 2., 1.]])
>>> U, P = jax.scipy.linalg.polar(a)
  ``U`` is a unitary matrix:
>>> jnp.round(U.T @ U) # doctest: +SKIP
Array([[ 1., -0., -0.],
[-0., 1., 0.],
[-0., 0., 1.]], dtype=float32)
  ``P`` is a positive-semidefinite matrix:
>>> with jnp.printoptions(precision=2, suppress=True):
... print(P)
[[4.79 3.25 1.23]
[3.25 3.06 2.01]
[1.23 2.01 2.91]]
  The original matrix can be reconstructed by multiplying ``U`` and ``P``:
>>> a_reconstructed = U @ P
>>> jnp.allclose(a, a_reconstructed)
Array(True, dtype=bool)
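
  The SVD-based method should yield the same factors up to floating-point
  error (the decomposition is unique here because ``a`` is nonsingular):

  >>> U_svd, P_svd = jax.scipy.linalg.polar(a, method='svd')
  >>> jnp.allclose(U, U_svd, atol=1e-4)  # doctest: +SKIP
  Array(True, dtype=bool)
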
.. _QDWH: https://epubs.siam.org/doi/abs/10.1137/090774999
"""
arr = jnp.asarray(a)
if arr.ndim != 2:
raise ValueError("The input `a` must be a 2-D array.")
if side not in ["right", "left"]:
raise ValueError("The argument `side` must be either 'right' or 'left'.")
m, n = arr.shape
if method == "qdwh":
# TODO(phawkins): return info also if the user opts in?
if m >= n and side == "right":
unitary, posdef, _, _ = qdwh.qdwh(arr, is_hermitian=False, eps=eps)
elif m < n and side == "left":
arr = arr.T.conj()
unitary, posdef, _, _ = qdwh.qdwh(arr, is_hermitian=False, eps=eps)
posdef = posdef.T.conj()
unitary = unitary.T.conj()
else:
raise NotImplementedError("method='qdwh' only supports mxn matrices "
"where m < n where side='right' and m >= n "
f"side='left', got {arr.shape} with {side=}")
elif method == "svd":
u_svd, s_svd, vh_svd = lax_linalg.svd(arr, full_matrices=False)
s_svd = s_svd.astype(u_svd.dtype)
unitary = u_svd @ vh_svd
if side == "right":
# a = u * p
posdef = (vh_svd.T.conj() * s_svd[None, :]) @ vh_svd
else:
# a = p * u
posdef = (u_svd * s_svd[None, :]) @ (u_svd.T.conj())
else:
raise ValueError(f"Unknown polar decomposition method {method}.")
return unitary, posdef
@jit
def _sqrtm_triu(T: Array) -> Array:
"""
Implements Björck, Å., & Hammarling, S. (1983).
"A Schur method for the square root of a matrix". Linear algebra and
its applications", 52, 127-140.
"""
diag = jnp.sqrt(jnp.diag(T))
n = diag.size
U = jnp.diag(diag)
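  # Fill in the strict upper triangle column by column via the recurrence
  #   U[i, j] = (T[i, j] - sum_{k=i+1}^{j-1} U[i, k] U[k, j]) / (U[i, i] + U[j, j]),
  # which only references entries of U that have already been computed.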
def i_loop(l, data):
j, U = data
i = j - 1 - l
s = lax.fori_loop(i + 1, j, lambda k, val: val + U[i, k] * U[k, j], 0.0)
value = jnp.where(T[i, j] == s, 0.0,
(T[i, j] - s) / (diag[i] + diag[j]))
return j, U.at[i, j].set(value)
def j_loop(j, U):
_, U = lax.fori_loop(0, j, i_loop, (j, U))
return U
U = lax.fori_loop(0, n, j_loop, U)
return U
@jit
def _sqrtm(A: ArrayLike) -> Array:
T, Z = schur(A, output='complex')
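  # A = Z T Z^H with T upper triangular, so sqrt(A) = Z sqrt(T) Z^H.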
sqrt_T = _sqrtm_triu(T)
return jnp.matmul(jnp.matmul(Z, sqrt_T, precision=lax.Precision.HIGHEST),
jnp.conj(Z.T), precision=lax.Precision.HIGHEST)
def sqrtm(A: ArrayLike, blocksize: int = 1) -> Array:
"""Compute the matrix square root
JAX implementation of :func:`scipy.linalg.sqrtm`.
Args:
A: array of shape ``(N, N)``
blocksize: Not supported in JAX; JAX always uses ``blocksize=1``.
Returns:
An array of shape ``(N, N)`` containing the matrix square root of ``A``
See Also:
:func:`jax.scipy.linalg.expm`
Examples:
>>> a = jnp.array([[1., 2., 3.],
... [2., 4., 2.],
... [3., 2., 1.]])
>>> sqrt_a = jax.scipy.linalg.sqrtm(a)
>>> with jnp.printoptions(precision=2, suppress=True):
... print(sqrt_a)
[[0.92+0.71j 0.54+0.j 0.92-0.71j]
[0.54+0.j 1.85+0.j 0.54-0.j ]
[0.92-0.71j 0.54-0.j 0.92+0.71j]]
By definition, matrix multiplication of the matrix square root with itself should
equal the input:
>>> jnp.allclose(a, sqrt_a @ sqrt_a)
Array(True, dtype=bool)
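
  The implementation works with a complex Schur form, so the result is always
  complex-valued; for a positive-definite input such as the one below, the
  imaginary part is numerically negligible:

  >>> b = jnp.array([[2., 1.],
  ...                [1., 2.]])
  >>> jnp.allclose(jax.scipy.linalg.sqrtm(b).imag, 0.0, atol=1e-6)  # doctest: +SKIP
  Array(True, dtype=bool)
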
Notes:
This function implements the complex Schur method described in [1]_. It does not use
recursive blocking to speed up computations as a Sylvester Equation solver is not
yet available in JAX.
References:
.. [1] Björck, Å., & Hammarling, S. (1983). "A Schur method for the square root of a matrix".
Linear algebra and its applications, 52, 127-140.
"""
if blocksize > 1:
raise NotImplementedError("Blocked version is not implemented yet.")
return _sqrtm(A)
@partial(jit, static_argnames=('check_finite',))
def rsf2csf(T: ArrayLike, Z: ArrayLike, check_finite: bool = True) -> tuple[Array, Array]:
"""Convert real Schur form to complex Schur form.
JAX implementation of :func:`scipy.linalg.rsf2csf`.
Args:
    T: array of shape ``(N, N)`` containing the real Schur form of the input.
    Z: array of shape ``(N, N)`` containing the corresponding Schur
      transformation matrix.
check_finite: unused by JAX
Returns:
    A tuple of arrays ``(T, Z)`` of the same shape as the inputs, containing the
    complex Schur form and the associated Schur transformation matrix.
See Also:
:func:`jax.scipy.linalg.schur`: Schur decomposition
Examples:
>>> A = jnp.array([[0., 3., 3.],
... [0., 1., 2.],
... [2., 0., 1.]])
>>> Tr, Zr = jax.scipy.linalg.schur(A)
>>> Tc, Zc = jax.scipy.linalg.rsf2csf(Tr, Zr)
Both the real and complex form can be used to reconstruct the input matrix
to float32 precision:
>>> jnp.allclose(Zr @ Tr @ Zr.T, A, atol=1E-5)
Array(True, dtype=bool)
>>> jnp.allclose(Zc @ Tc @ Zc.conj().T, A, atol=1E-5)
Array(True, dtype=bool)
The real-valued Schur form is only quasi-upper-triangular, as we can see in this case:
>>> with jax.numpy.printoptions(precision=2, suppress=True):
... print(Tr)
[[ 3.76 -2.17 1.38]
[ 0. -0.88 -0.35]
[ 0. 2.37 -0.88]]
By contrast, the complex form is truly upper-triangular:
>>> with jnp.printoptions(precision=2, suppress=True):
... print(Tc)
[[ 3.76+0.j 1.29-0.78j 2.02-0.5j ]
[ 0. +0.j -0.88+0.91j -2.02+0.j ]
[ 0. +0.j 0. +0.j -0.88-0.91j]]
"""
del check_finite # unused
T_arr = jnp.asarray(T)
Z_arr = jnp.asarray(Z)
if T_arr.ndim != 2 or T_arr.shape[0] != T_arr.shape[1]:
raise ValueError("Input 'T' must be square.")
if Z_arr.ndim != 2 or Z_arr.shape[0] != Z_arr.shape[1]:
raise ValueError("Input 'Z' must be square.")
if T_arr.shape[0] != Z_arr.shape[0]:
raise ValueError(f"Input array shapes must match: Z: {Z_arr.shape} vs. T: {T_arr.shape}")
T_arr, Z_arr = promote_dtypes_complex(T_arr, Z_arr)
eps = jnp.finfo(T_arr.dtype).eps
N = T_arr.shape[0]
if N == 1:
return T_arr, Z_arr
def _update_T_Z(m, T, Z):
mu = jnp.linalg.eigvals(lax.dynamic_slice(T, (m-1, m-1), (2, 2))) - T[m, m]
r = jnp.linalg.norm(jnp.array([mu[0], T[m, m-1]])).astype(T.dtype)
c = mu[0] / r
s = T[m, m-1] / r
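    # Build a 2x2 unitary Givens-like rotation G that triangularizes the
    # 2x2 diagonal block at T[m-1:m+1, m-1:m+1], annihilating the
    # subdiagonal entry T[m, m-1].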
G = jnp.array([[c.conj(), s], [-s, c]], dtype=T.dtype)
# T[m-1:m+1, m-1:] = G @ T[m-1:m+1, m-1:]
T_rows = lax.dynamic_slice_in_dim(T, m-1, 2, axis=0)
col_mask = jnp.arange(N) >= m-1
G_dot_T_zeroed_cols = G @ jnp.where(col_mask, T_rows, 0)
T_rows_new = jnp.where(~col_mask, T_rows, G_dot_T_zeroed_cols)
T = lax.dynamic_update_slice_in_dim(T, T_rows_new, m-1, axis=0)
# T[:m+1, m-1:m+1] = T[:m+1, m-1:m+1] @ G.conj().T
T_cols = lax.dynamic_slice_in_dim(T, m-1, 2, axis=1)
row_mask = jnp.arange(N)[:, jnp.newaxis] < m+1
T_zeroed_rows_dot_GH = jnp.where(row_mask, T_cols, 0) @ G.conj().T
T_cols_new = jnp.where(~row_mask, T_cols, T_zeroed_rows_dot_GH)
T = lax.dynamic_update_slice_in_dim(T, T_cols_new, m-1, axis=1)
# Z[:, m-1:m+1] = Z[:, m-1:m+1] @ G.conj().T
Z_cols = lax.dynamic_slice_in_dim(Z, m-1, 2, axis=1)
Z = lax.dynamic_update_slice_in_dim(Z, Z_cols @ G.conj().T, m-1, axis=1)
return T, Z
  def _rsf2csf_iter(i, TZ):
m = N-i
T, Z = TZ
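    # Only rotate when the subdiagonal entry is non-negligible: such an
    # entry marks a 2x2 block encoding a complex-conjugate eigenvalue pair.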
T, Z = lax.cond(
jnp.abs(T[m, m-1]) > eps*(jnp.abs(T[m-1, m-1]) + jnp.abs(T[m, m])),
_update_T_Z,
lambda m, T, Z: (T, Z),
m, T, Z)
T = T.at[m, m-1].set(0.0)
return T, Z
  return lax.fori_loop(1, N, _rsf2csf_iter, (T_arr, Z_arr))
@overload
def hessenberg(a: ArrayLike, *, calc_q: Literal[False], overwrite_a: bool = False,
check_finite: bool = True) -> Array: ...
@overload
def hessenberg(a: ArrayLike, *, calc_q: Literal[True], overwrite_a: bool = False,
check_finite: bool = True) -> tuple[Array, Array]: ...
@partial(jit, static_argnames=('calc_q', 'check_finite', 'overwrite_a'))
def hessenberg(a: ArrayLike, *, calc_q: bool = False, overwrite_a: bool = False,
check_finite: bool = True) -> Array | tuple[Array, Array]:
"""Compute the Hessenberg form of the matrix
JAX implementation of :func:`scipy.linalg.hessenberg`.
The Hessenberg form `H` of a matrix `A` satisfies:
.. math::
A = Q H Q^H
where `Q` is unitary and `H` is zero below the first subdiagonal.
Args:
    a: array of shape ``(..., N, N)``
calc_q: if True, calculate the ``Q`` matrix (default: False)
overwrite_a: unused by JAX
check_finite: unused by JAX
Returns:
A tuple of arrays ``(H, Q)`` if ``calc_q`` is True, else an array ``H``
- ``H`` has shape ``(..., N, N)`` and is the Hessenberg form of ``a``
- ``Q`` has shape ``(..., N, N)`` and is the associated unitary matrix
Examples:
Computing the Hessenberg form of a 4x4 matrix
>>> a = jnp.array([[1., 2., 3., 4.],
... [1., 4., 2., 3.],
... [3., 2., 1., 4.],
... [2., 3., 2., 2.]])
>>> H, Q = jax.scipy.linalg.hessenberg(a, calc_q=True)
>>> with jnp.printoptions(suppress=True, precision=3):
... print(H)
[[ 1. -5.078 1.167 1.361]
[-3.742 5.786 -3.613 -1.825]
[ 0. -2.992 2.493 -0.577]
[ 0. 0. -0.043 -1.279]]
    Notice the zeros below the first subdiagonal. The original matrix
    can be reconstructed from ``H`` and the unitary matrix ``Q``:
>>> a_reconstructed = Q @ H @ Q.conj().T
>>> jnp.allclose(a_reconstructed, a)
Array(True, dtype=bool)
"""
del overwrite_a, check_finite # unused
n = jnp.shape(a)[-1]
if n == 0:
if calc_q:
return jnp.zeros_like(a), jnp.zeros_like(a)
else:
return jnp.zeros_like(a)
a_out, taus = lax_linalg.hessenberg(a)
h = jnp.triu(a_out, -1)
if calc_q:
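    # The returned Householder reflectors act on rows and columns 1..n-1, so
    # their accumulated product is embedded as the lower-right (n-1) x (n-1)
    # block of Q, with Q[..., 0, 0] = 1.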
q = lax_linalg.householder_product(a_out[..., 1:, :-1], taus)
batch_dims = a_out.shape[:-2]
q = jnp.block([[jnp.ones(batch_dims + (1, 1), dtype=a_out.dtype),
jnp.zeros(batch_dims + (1, n - 1), dtype=a_out.dtype)],
[jnp.zeros(batch_dims + (n - 1, 1), dtype=a_out.dtype), q]])
return h, q
else:
return h
def toeplitz(c: ArrayLike, r: ArrayLike | None = None) -> Array:
r"""Construct a Toeplitz matrix.
JAX implementation of :func:`scipy.linalg.toeplitz`.
A Toeplitz matrix has equal diagonals: :math:`A_{ij} = k_{i - j}`
for :math:`0 \le i < n` and :math:`0 \le j < n`. This function
specifies the diagonals via the first column ``c`` and the first row
``r``, such that for row `i` and column `j`:
.. math::
A_{ij} = \begin{cases}
c_{i - j} & i \ge j \\
r_{j - i} & i < j
\end{cases}
Notice this implies that :math:`r_0` is ignored.
Args:
c: array of shape ``(..., N)`` specifying the first column.
r: (optional) array of shape ``(..., M)`` specifying the first row. Leading
dimensions must be broadcast-compatible with those of ``c``. If not specified,
``r`` defaults to ``conj(c)``.
Returns:
    A Toeplitz matrix of shape ``(..., N, M)``.
Examples:
Specifying ``c`` only:
>>> c = jnp.array([1, 2, 3])
>>> jax.scipy.linalg.toeplitz(c)
Array([[1, 2, 3],
[2, 1, 2],
[3, 2, 1]], dtype=int32)
Specifying ``c`` and ``r``:
>>> r = jnp.array([-1, -2, -3])
>>> jax.scipy.linalg.toeplitz(c, r) # Note r[0] is ignored
Array([[ 1, -2, -3],
[ 2, 1, -2],
[ 3, 2, 1]], dtype=int32)
    When only a complex-valued ``c`` is specified, ``r`` defaults to ``c.conj()``,
    resulting in a Hermitian matrix if ``c[0].imag == 0``:
>>> c = jnp.array([1, 2+1j, 1+2j])
>>> M = jax.scipy.linalg.toeplitz(c)
>>> M
Array([[1.+0.j, 2.-1.j, 1.-2.j],
[2.+1.j, 1.+0.j, 2.-1.j],
[1.+2.j, 2.+1.j, 1.+0.j]], dtype=complex64)
>>> print("M is Hermitian:", jnp.all(M == M.conj().T))
M is Hermitian: True
For N-dimensional ``c`` and/or ``r``, the result is a batch of Toeplitz matrices:
>>> c = jnp.array([[1, 2, 3], [4, 5, 6]])
>>> jax.scipy.linalg.toeplitz(c)
Array([[[1, 2, 3],
[2, 1, 2],
[3, 2, 1]],
<BLANKLINE>
[[4, 5, 6],
[5, 4, 5],
[6, 5, 4]]], dtype=int32)
"""
if r is None:
check_arraylike("toeplitz", c)
r = jnp.conjugate(jnp.asarray(c))
else:
check_arraylike("toeplitz", c, r)
return _toeplitz(jnp.atleast_1d(jnp.asarray(c)), jnp.atleast_1d(jnp.asarray(r)))
@partial(jnp.vectorize, signature="(m),(n)->(m,n)")
def _toeplitz(c: Array, r: Array) -> Array:
  # `c` holds the first column, so its length is the number of rows; `r`
  # holds the first row, so its length is the number of columns.
  nrows, = c.shape
  ncols, = r.shape
  if nrows == 0 or ncols == 0:
    return jnp.empty((nrows, ncols), dtype=jnp.promote_types(c.dtype, r.dtype))
  nelems = nrows + ncols - 1
elems = jnp.concatenate((c[::-1], r[1:]))
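  # Row i of the result is the length-ncols window of `elems` starting at
  # offset nrows - 1 - i; the patch-extraction convolution below materializes
  # every such window at once, and the flip restores row order.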
patches = lax.conv_general_dilated_patches(
elems.reshape((1, nelems, 1)),
      (ncols,), (1,), 'VALID', dimension_numbers=('NTC', 'IOT', 'NTC'),
precision=lax.Precision.HIGHEST)[0]
return jnp.flip(patches, axis=0)
@partial(jit, static_argnames=("n",))
def hilbert(n: int) -> Array:
r"""Create a Hilbert matrix of order n.
JAX implementation of :func:`scipy.linalg.hilbert`.
The Hilbert matrix is defined by:
.. math::
H_{ij} = \frac{1}{i + j + 1}
  for :math:`0 \le i < n` and :math:`0 \le j < n`.
Args:
n: the size of the matrix to create.
Returns:
A Hilbert matrix of shape ``(n, n)``
Examples:
>>> jax.scipy.linalg.hilbert(2)
Array([[1. , 0.5 ],
[0.5 , 0.33333334]], dtype=float32)
>>> jax.scipy.linalg.hilbert(3)
Array([[1. , 0.5 , 0.33333334],
[0.5 , 0.33333334, 0.25 ],
[0.33333334, 0.25 , 0.2 ]], dtype=float32)
"""
a = lax.broadcasted_iota(jnp.float64, (n, 1), 0)
return 1/(a + a.T + 1)