# Copyright 2018 The JAX Authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # https://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. from __future__ import annotations from functools import partial import numpy as np import textwrap from typing import overload, Any, Literal import jax import jax.numpy as jnp from jax import jit, vmap, jvp from jax import lax from jax._src import dtypes from jax._src.lax import linalg as lax_linalg from jax._src.lax import qdwh from jax._src.numpy.util import ( check_arraylike, promote_dtypes, promote_dtypes_inexact, promote_dtypes_complex) from jax._src.typing import Array, ArrayLike _no_chkfinite_doc = textwrap.dedent(""" Does not support the Scipy argument ``check_finite=True``, because compiled JAX code cannot perform checks of array values at runtime. """) _no_overwrite_and_chkfinite_doc = _no_chkfinite_doc + "\nDoes not support the Scipy argument ``overwrite_*=True``." @partial(jit, static_argnames=('lower',)) def _cholesky(a: ArrayLike, lower: bool) -> Array: a, = promote_dtypes_inexact(jnp.asarray(a)) l = lax_linalg.cholesky(a if lower else jnp.conj(a.mT), symmetrize_input=False) return l if lower else jnp.conj(l.mT) def cholesky(a: ArrayLike, lower: bool = False, overwrite_a: bool = False, check_finite: bool = True) -> Array: """Compute the Cholesky decomposition of a matrix. JAX implementation of :func:`scipy.linalg.cholesky`. The Cholesky decomposition of a matrix `A` is: .. math:: A = U^HU = LL^H where `U` is an upper-triangular matrix and `L` is a lower-triangular matrix. Args: a: input array, representing a (batched) positive-definite hermitian matrix. Must have shape ``(..., N, N)``. lower: if True, compute the lower Cholesky decomposition `L`. If False (default), compute the upper Cholesky decomposition `U`. overwrite_a: unused by JAX check_finite: unused by JAX Returns: array of shape ``(..., N, N)`` representing the cholesky decomposition of the input. See Also: - :func:`jax.numpy.linalg.cholesky`: NumPy-style Cholesky API - :func:`jax.lax.linalg.cholesky`: XLA-style Cholesky API - :func:`jax.scipy.linalg.cho_factor` - :func:`jax.scipy.linalg.cho_solve` Examples: A small real Hermitian positive-definite matrix: >>> x = jnp.array([[2., 1.], ... [1., 2.]]) Upper Cholesky factorization: >>> jax.scipy.linalg.cholesky(x) Array([[1.4142135 , 0.70710677], [0. , 1.2247449 ]], dtype=float32) Lower Cholesky factorization: >>> jax.scipy.linalg.cholesky(x, lower=True) Array([[1.4142135 , 0. ], [0.70710677, 1.2247449 ]], dtype=float32) Reconstructing ``x`` from its factorization: >>> L = jax.scipy.linalg.cholesky(x, lower=True) >>> jnp.allclose(x, L @ L.T) Array(True, dtype=bool) """ del overwrite_a, check_finite # Unused return _cholesky(a, lower) def cho_factor(a: ArrayLike, lower: bool = False, overwrite_a: bool = False, check_finite: bool = True) -> tuple[Array, bool]: """Factorization for Cholesky-based linear solves JAX implementation of :func:`scipy.linalg.cho_factor`. This function returns a result suitable for use with :func:`jax.scipy.linalg.cho_solve`.
For direct Cholesky decompositions, prefer :func:`jax.scipy.linalg.cholesky`. Args: a: input array, representing a (batched) positive-definite hermitian matrix. Must have shape ``(..., N, N)``. lower: if True, compute the lower triangular Cholesky decomposition (default: False). overwrite_a: unused by JAX check_finite: unused by JAX Returns: ``(c, lower)``: ``c`` is an array of shape ``(..., N, N)`` representing the lower or upper cholesky decomposition of the input; ``lower`` is a boolean specifying whether this is the lower or upper decomposition. See Also: - :func:`jax.scipy.linalg.cholesky` - :func:`jax.scipy.linalg.cho_solve` Examples: A small real Hermitian positive-definite matrix: >>> x = jnp.array([[2., 1.], ... [1., 2.]]) Compute the cholesky factorization via :func:`~jax.scipy.linalg.cho_factor`, and use it to solve a linear equation via :func:`~jax.scipy.linalg.cho_solve`. >>> b = jnp.array([3., 4.]) >>> cfac = jax.scipy.linalg.cho_factor(x) >>> y = jax.scipy.linalg.cho_solve(cfac, b) >>> y Array([0.6666666, 1.6666666], dtype=float32) Check that the result is consistent: >>> jnp.allclose(x @ y, b) Array(True, dtype=bool) """ del overwrite_a, check_finite # Unused return (cholesky(a, lower=lower), lower) @partial(jit, static_argnames=('lower',)) def _cho_solve(c: ArrayLike, b: ArrayLike, lower: bool) -> Array: c, b = promote_dtypes_inexact(jnp.asarray(c), jnp.asarray(b)) lax_linalg._check_solve_shapes(c, b) b = lax_linalg.triangular_solve(c, b, left_side=True, lower=lower, transpose_a=not lower, conjugate_a=not lower) b = lax_linalg.triangular_solve(c, b, left_side=True, lower=lower, transpose_a=lower, conjugate_a=lower) return b def cho_solve(c_and_lower: tuple[ArrayLike, bool], b: ArrayLike, overwrite_b: bool = False, check_finite: bool = True) -> Array: """Solve a linear system using a Cholesky factorization JAX implementation of :func:`scipy.linalg.cho_solve`. Uses the output of :func:`jax.scipy.linalg.cho_factor`. Args: c_and_lower: ``(c, lower)``, where ``c`` is an array of shape ``(..., N, N)`` representing the lower or upper cholesky decomposition of the matrix, and ``lower`` is a boolean specifying whether this is the lower or upper decomposition. b: right-hand-side of linear system. Must have shape ``(..., N)`` overwrite_b: unused by JAX check_finite: unused by JAX Returns: Array of shape ``(..., N)`` representing the solution of the linear system. See Also: - :func:`jax.scipy.linalg.cholesky` - :func:`jax.scipy.linalg.cho_factor` Examples: A small real Hermitian positive-definite matrix: >>> x = jnp.array([[2., 1.], ... [1., 2.]]) Compute the cholesky factorization via :func:`~jax.scipy.linalg.cho_factor`, and use it to solve a linear equation via :func:`~jax.scipy.linalg.cho_solve`. >>> b = jnp.array([3., 4.]) >>> cfac = jax.scipy.linalg.cho_factor(x) >>> y = jax.scipy.linalg.cho_solve(cfac, b) >>> y Array([0.6666666, 1.6666666], dtype=float32) Check that the result is consistent: >>> jnp.allclose(x @ y, b) Array(True, dtype=bool) """ del overwrite_b, check_finite # Unused c, lower = c_and_lower return _cho_solve(c, b, lower) @overload def _svd(x: ArrayLike, *, full_matrices: bool, compute_uv: Literal[True]) -> tuple[Array, Array, Array]: ... @overload def _svd(x: ArrayLike, *, full_matrices: bool, compute_uv: Literal[False]) -> Array: ... @overload def _svd(x: ArrayLike, *, full_matrices: bool, compute_uv: bool) -> Array | tuple[Array, Array, Array]: ...
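# A minimal reference sketch (hypothetical helper, not part of this module):
# _cho_solve above performs two triangular solves because the factorization
# A = L @ L^H reduces A @ x = b to L @ y = b (forward substitution) followed
# by L^H @ x = y (back substitution). Assuming a lower-triangular factor `c`
# as produced by cho_factor(a, lower=True):
def _cho_solve_reference(c: Array, b: Array) -> Array:
  # Forward substitution: solve L @ y = b.
  y = jax.scipy.linalg.solve_triangular(c, b, lower=True)
  # Back substitution: solve L^H @ x = y (c.conj().T is upper-triangular).
  return jax.scipy.linalg.solve_triangular(c.conj().T, y, lower=False)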
@partial(jit, static_argnames=('full_matrices', 'compute_uv')) def _svd(a: ArrayLike, *, full_matrices: bool, compute_uv: bool) -> Array | tuple[Array, Array, Array]: a, = promote_dtypes_inexact(jnp.asarray(a)) return lax_linalg.svd(a, full_matrices=full_matrices, compute_uv=compute_uv) @overload def svd(a: ArrayLike, full_matrices: bool = True, compute_uv: Literal[True] = True, overwrite_a: bool = False, check_finite: bool = True, lapack_driver: str = 'gesdd') -> tuple[Array, Array, Array]: ... @overload def svd(a: ArrayLike, full_matrices: bool, compute_uv: Literal[False], overwrite_a: bool = False, check_finite: bool = True, lapack_driver: str = 'gesdd') -> Array: ... @overload def svd(a: ArrayLike, full_matrices: bool = True, *, compute_uv: Literal[False], overwrite_a: bool = False, check_finite: bool = True, lapack_driver: str = 'gesdd') -> Array: ... @overload def svd(a: ArrayLike, full_matrices: bool = True, compute_uv: bool = True, overwrite_a: bool = False, check_finite: bool = True, lapack_driver: str = 'gesdd') -> Array | tuple[Array, Array, Array]: ... def svd(a: ArrayLike, full_matrices: bool = True, compute_uv: bool = True, overwrite_a: bool = False, check_finite: bool = True, lapack_driver: str = 'gesdd') -> Array | tuple[Array, Array, Array]: r"""Compute the singular value decomposition. JAX implementation of :func:`scipy.linalg.svd`. The SVD of a matrix `A` is given by .. math:: A = U\Sigma V^H - :math:`U` contains the left singular vectors and satisfies :math:`U^HU=I` - :math:`V` contains the right singular vectors and satisfies :math:`V^HV=I` - :math:`\Sigma` is a diagonal matrix of singular values. Args: a: input array, of shape ``(..., N, M)`` full_matrices: if True (default) compute the full matrices; i.e. ``u`` and ``vh`` have shape ``(..., N, N)`` and ``(..., M, M)``. If False, then the shapes are ``(..., N, K)`` and ``(..., K, M)`` with ``K = min(N, M)``. compute_uv: if True (default), return the full SVD ``(u, s, vh)``. If False then return only the singular values ``s``. overwrite_a: unused by JAX check_finite: unused by JAX lapack_driver: unused by JAX Returns: A tuple of arrays ``(u, s, vh)`` if ``compute_uv`` is True, otherwise the array ``s``. - ``u``: left singular vectors of shape ``(..., N, N)`` if ``full_matrices`` is True or ``(..., N, K)`` otherwise. - ``s``: singular values of shape ``(..., K)`` - ``vh``: conjugate-transposed right singular vectors of shape ``(..., M, M)`` if ``full_matrices`` is True or ``(..., K, M)`` otherwise. where ``K = min(N, M)``. See also: - :func:`jax.numpy.linalg.svd`: NumPy-style SVD API - :func:`jax.lax.linalg.svd`: XLA-style SVD API Examples: Consider the SVD of a small real-valued array: >>> x = jnp.array([[1., 2., 3.], ... [6., 5., 4.]]) >>> u, s, vt = jax.scipy.linalg.svd(x, full_matrices=False) >>> s # doctest: +SKIP Array([9.361919 , 1.8315067], dtype=float32) The singular vectors are in the columns of ``u`` and ``v = vt.T``. 
These vectors are orthonormal, which can be demonstrated by comparing the matrix product with the identity matrix: >>> jnp.allclose(u.T @ u, jnp.eye(2), atol=1E-5) Array(True, dtype=bool) >>> v = vt.T >>> jnp.allclose(v.T @ v, jnp.eye(2), atol=1E-5) Array(True, dtype=bool) Given the SVD, ``x`` can be reconstructed via matrix multiplication: >>> x_reconstructed = u @ jnp.diag(s) @ vt >>> jnp.allclose(x_reconstructed, x) Array(True, dtype=bool) """ del overwrite_a, check_finite, lapack_driver # unused return _svd(a, full_matrices=full_matrices, compute_uv=compute_uv) def det(a: ArrayLike, overwrite_a: bool = False, check_finite: bool = True) -> Array: """Compute the determinant of a matrix JAX implementation of :func:`scipy.linalg.det`. Args: a: input array, of shape ``(..., N, N)`` overwrite_a: unused by JAX check_finite: unused by JAX Returns: Determinant of shape ``a.shape[:-2]`` See Also: :func:`jax.numpy.linalg.det`: NumPy-style determinant API Examples: Determinant of a small 2D array: >>> x = jnp.array([[1., 2.], ... [3., 4.]]) >>> jax.scipy.linalg.det(x) Array(-2., dtype=float32) Batch-wise determinant of multiple 2D arrays: >>> x = jnp.array([[[1., 2.], ... [3., 4.]], ... [[8., 5.], ... [7., 9.]]]) >>> jax.scipy.linalg.det(x) Array([-2., 37.], dtype=float32) """ del overwrite_a, check_finite # unused return jnp.linalg.det(a) @overload def _eigh(a: ArrayLike, b: ArrayLike | None, lower: bool, eigvals_only: Literal[True], eigvals: None, type: int) -> Array: ... @overload def _eigh(a: ArrayLike, b: ArrayLike | None, lower: bool, eigvals_only: Literal[False], eigvals: None, type: int) -> tuple[Array, Array]: ... @overload def _eigh(a: ArrayLike, b: ArrayLike | None, lower: bool, eigvals_only: bool, eigvals: None, type: int) -> Array | tuple[Array, Array]: ... @partial(jit, static_argnames=('lower', 'eigvals_only', 'eigvals', 'type')) def _eigh(a: ArrayLike, b: ArrayLike | None, lower: bool, eigvals_only: bool, eigvals: None, type: int) -> Array | tuple[Array, Array]: if b is not None: raise NotImplementedError("Only the b=None case of eigh is implemented") if type != 1: raise NotImplementedError("Only the type=1 case of eigh is implemented.") if eigvals is not None: raise NotImplementedError( "Only the eigvals=None case of eigh is implemented.") a, = promote_dtypes_inexact(jnp.asarray(a)) v, w = lax_linalg.eigh(a, lower=lower) if eigvals_only: return w else: return w, v @overload def eigh(a: ArrayLike, b: ArrayLike | None = None, lower: bool = True, eigvals_only: Literal[False] = False, overwrite_a: bool = False, overwrite_b: bool = False, turbo: bool = True, eigvals: None = None, type: int = 1, check_finite: bool = True) -> tuple[Array, Array]: ... @overload def eigh(a: ArrayLike, b: ArrayLike | None = None, lower: bool = True, *, eigvals_only: Literal[True], overwrite_a: bool = False, overwrite_b: bool = False, turbo: bool = True, eigvals: None = None, type: int = 1, check_finite: bool = True) -> Array: ... @overload def eigh(a: ArrayLike, b: ArrayLike | None, lower: bool, eigvals_only: Literal[True], overwrite_a: bool = False, overwrite_b: bool = False, turbo: bool = True, eigvals: None = None, type: int = 1, check_finite: bool = True) -> Array: ... @overload def eigh(a: ArrayLike, b: ArrayLike | None = None, lower: bool = True, eigvals_only: bool = False, overwrite_a: bool = False, overwrite_b: bool = False, turbo: bool = True, eigvals: None = None, type: int = 1, check_finite: bool = True) -> Array | tuple[Array, Array]: ...
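# A minimal sketch (hypothetical, not part of this module) of the standard
# reduction behind the type=1 generalized problem that _eigh above rejects:
# given b = L @ L^H, the problem a @ v = lam * b @ v is equivalent to the
# ordinary Hermitian problem C @ y = lam * y with C = inv(L) @ a @ inv(L)^H
# and v = inv(L)^H @ y.
def _generalized_eigh_type1_sketch(a: Array, b: Array) -> tuple[Array, Array]:
  L = jnp.linalg.cholesky(b)                                # b = L @ L^H
  La = jax.scipy.linalg.solve_triangular(L, a, lower=True)  # inv(L) @ a
  C = jax.scipy.linalg.solve_triangular(L, La.conj().T, lower=True)
  w, y = jnp.linalg.eigh(C)                                 # standard problem
  v = jax.scipy.linalg.solve_triangular(L.conj().T, y, lower=False)
  return w, v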
def eigh(a: ArrayLike, b: ArrayLike | None = None, lower: bool = True, eigvals_only: bool = False, overwrite_a: bool = False, overwrite_b: bool = False, turbo: bool = True, eigvals: None = None, type: int = 1, check_finite: bool = True) -> Array | tuple[Array, Array]: """Compute eigenvalues and eigenvectors for a Hermitian matrix JAX implementation of :func:`scipy.linalg.eigh`. Args: a: Hermitian input array of shape ``(..., N, N)`` b: optional Hermitian input of shape ``(..., N, N)``. If specified, compute the generalized eigenvalue problem. lower: if True (default) access only the lower portion of the input matrix. Otherwise access only the upper portion. eigvals_only: If True, compute only the eigenvalues. If False (default) compute both eigenvalues and eigenvectors. type: if ``b`` is specified, ``type`` gives the type of generalized eigenvalue problem to be computed. Denoting ``(λ, v)`` as an eigenvalue, eigenvector pair: - ``type = 1`` solves ``a @ v = λ * b @ v`` (default) - ``type = 2`` solves ``a @ b @ v = λ * v`` - ``type = 3`` solves ``b @ a @ v = λ * v`` eigvals: a ``(low, high)`` tuple specifying which eigenvalues to compute. overwrite_a: unused by JAX. overwrite_b: unused by JAX. turbo: unused by JAX. check_finite: unused by JAX. Returns: A tuple of arrays ``(eigvals, eigvecs)`` if ``eigvals_only`` is False, otherwise an array ``eigvals``. - ``eigvals``: array of shape ``(..., N)`` containing the eigenvalues. - ``eigvecs``: array of shape ``(..., N, N)`` containing the eigenvectors. See also: - :func:`jax.numpy.linalg.eigh`: NumPy-style eigh API. - :func:`jax.lax.linalg.eigh`: XLA-style eigh API. - :func:`jax.numpy.linalg.eig`: non-hermitian eigenvalue problem. - :func:`jax.scipy.linalg.eigh_tridiagonal`: tri-diagonal eigenvalue problem. Examples: Compute the standard eigenvalue decomposition of a simple 2x2 matrix: >>> a = jnp.array([[2., 1.], ... [1., 2.]]) >>> eigvals, eigvecs = jax.scipy.linalg.eigh(a) >>> eigvals Array([1., 3.], dtype=float32) >>> eigvecs Array([[-0.70710677, 0.70710677], [ 0.70710677, 0.70710677]], dtype=float32) Eigenvectors are orthonormal: >>> jnp.allclose(eigvecs.T @ eigvecs, jnp.eye(2), atol=1E-5) Array(True, dtype=bool) Solution satisfies the eigenvalue problem: >>> jnp.allclose(a @ eigvecs, eigvecs @ jnp.diag(eigvals)) Array(True, dtype=bool) """ del overwrite_a, overwrite_b, turbo, check_finite # unused return _eigh(a, b, lower, eigvals_only, eigvals, type) @partial(jit, static_argnames=('output',)) def _schur(a: Array, output: str) -> tuple[Array, Array]: if output == "complex": a = a.astype(dtypes.to_complex_dtype(a.dtype)) return lax_linalg.schur(a) def schur(a: ArrayLike, output: str = 'real') -> tuple[Array, Array]: """Compute the Schur decomposition JAX implementation of :func:`scipy.linalg.schur`. The Schur form `T` of a matrix `A` satisfies: .. math:: A = Z T Z^H where `Z` is unitary, and `T` is upper-triangular for the complex-valued Schur decomposition (i.e. ``output="complex"``) and is quasi-upper-triangular for the real-valued Schur decomposition (i.e. ``output="real"``). In the quasi-triangular case, the diagonal may include 2x2 blocks associated with complex-valued eigenvalue pairs of `A`. Args: a: input array of shape ``(..., N, N)`` output: Specify whether to compute the ``"real"`` (default) or ``"complex"`` Schur decomposition. Returns: A tuple of arrays ``(T, Z)`` - ``T`` is a shape ``(..., N, N)`` array containing the upper-triangular Schur form of the input. 
- ``Z`` is a shape ``(..., N, N)`` array containing the unitary Schur transformation matrix. See also: - :func:`jax.scipy.linalg.rsf2csf`: convert real Schur form to complex Schur form. - :func:`jax.lax.linalg.schur`: XLA-style API for Schur decomposition. Examples: A Schur decomposition of a 3x3 matrix: >>> a = jnp.array([[1., 2., 3.], ... [1., 4., 2.], ... [3., 2., 1.]]) >>> T, Z = jax.scipy.linalg.schur(a) The Schur form ``T`` is quasi-upper-triangular in general, but is truly upper-triangular in this case because the input matrix is symmetric: >>> T # doctest: +SKIP Array([[-2.0000005 , 0.5066295 , -0.43360388], [ 0. , 1.5505103 , 0.74519426], [ 0. , 0. , 6.449491 ]], dtype=float32) The transformation matrix ``Z`` is unitary: >>> jnp.allclose(Z.T @ Z, jnp.eye(3), atol=1E-5) Array(True, dtype=bool) The input can be reconstructed from the outputs: >>> jnp.allclose(Z @ T @ Z.T, a) Array(True, dtype=bool) """ if output not in ('real', 'complex'): raise ValueError( f"Expected 'output' to be either 'real' or 'complex', got {output=}.") return _schur(a, output) def inv(a: ArrayLike, overwrite_a: bool = False, check_finite: bool = True) -> Array: """Return the inverse of a square matrix JAX implementation of :func:`scipy.linalg.inv`. Args: a: array of shape ``(..., N, N)`` specifying square array(s) to be inverted. overwrite_a: unused in JAX check_finite: unused in JAX Returns: Array of shape ``(..., N, N)`` containing the inverse of the input. Notes: In most cases, explicitly computing the inverse of a matrix is ill-advised. For example, to compute ``x = inv(A) @ b``, it is more performant and numerically precise to use a direct solve, such as :func:`jax.scipy.linalg.solve`. See Also: - :func:`jax.numpy.linalg.inv`: NumPy-style API for matrix inverse - :func:`jax.scipy.linalg.solve`: direct linear solver Examples: Compute the inverse of a 3x3 matrix >>> a = jnp.array([[1., 2., 3.], ... [2., 4., 2.], ... [3., 2., 1.]]) >>> a_inv = jax.scipy.linalg.inv(a) >>> a_inv # doctest: +SKIP Array([[ 0. , -0.25 , 0.5 ], [-0.25 , 0.5 , -0.25000003], [ 0.5 , -0.25 , 0. ]], dtype=float32) Check that multiplying with the inverse gives the identity: >>> jnp.allclose(a @ a_inv, jnp.eye(3), atol=1E-5) Array(True, dtype=bool) Multiply the inverse by a vector ``b``, to find a solution to ``a @ x = b`` >>> b = jnp.array([1., 4., 2.]) >>> a_inv @ b Array([ 0. , 1.25, -0.5 ], dtype=float32) Note, however, that explicitly computing the inverse in such a case can lead to poor performance and loss of precision as the size of the problem grows. Instead, you should use a direct solver like :func:`jax.scipy.linalg.solve`: >>> jax.scipy.linalg.solve(a, b) Array([ 0. , 1.25, -0.5 ], dtype=float32) """ del overwrite_a, check_finite # unused return jnp.linalg.inv(a) @partial(jit, static_argnames=('overwrite_a', 'check_finite')) def lu_factor(a: ArrayLike, overwrite_a: bool = False, check_finite: bool = True) -> tuple[Array, Array]: """Factorization for LU-based linear solves JAX implementation of :func:`scipy.linalg.lu_factor`. This function returns a result suitable for use with :func:`jax.scipy.linalg.lu_solve`. For direct LU decompositions, prefer :func:`jax.scipy.linalg.lu`. Args: a: input array of shape ``(..., M, N)``. overwrite_a: unused by JAX check_finite: unused by JAX Returns: A tuple ``(lu, piv)`` - ``lu`` is an array of shape ``(..., M, N)``, containing ``L`` in its lower triangle and ``U`` in its upper. - ``piv`` is an array of shape ``(..., K)`` with ``K = min(M, N)``, which encodes the pivots. 
See Also: - :func:`jax.scipy.linalg.lu` - :func:`jax.scipy.linalg.lu_solve` Examples: Solving a small linear system via LU factorization: >>> a = jnp.array([[2., 1.], ... [1., 2.]]) Compute the lu factorization via :func:`~jax.scipy.linalg.lu_factor`, and use it to solve a linear equation via :func:`~jax.scipy.linalg.lu_solve`. >>> b = jnp.array([3., 4.]) >>> lufac = jax.scipy.linalg.lu_factor(a) >>> y = jax.scipy.linalg.lu_solve(lufac, b) >>> y Array([0.6666666, 1.6666667], dtype=float32) Check that the result is consistent: >>> jnp.allclose(a @ y, b) Array(True, dtype=bool) """ del overwrite_a, check_finite # unused a, = promote_dtypes_inexact(jnp.asarray(a)) lu, pivots, _ = lax_linalg.lu(a) return lu, pivots @partial(jit, static_argnames=('trans', 'overwrite_b', 'check_finite')) def lu_solve(lu_and_piv: tuple[Array, ArrayLike], b: ArrayLike, trans: int = 0, overwrite_b: bool = False, check_finite: bool = True) -> Array: """Solve a linear system using an LU factorization JAX implementation of :func:`scipy.linalg.lu_solve`. Uses the output of :func:`jax.scipy.linalg.lu_factor`. Args: lu_and_piv: ``(lu, piv)``, output of :func:`~jax.scipy.linalg.lu_factor`. ``lu`` is an array of shape ``(..., M, N)``, containing ``L`` in its lower triangle and ``U`` in its upper. ``piv`` is an array of shape ``(..., K)``, with ``K = min(M, N)``, which encodes the pivots. b: right-hand-side of linear system. Must have shape ``(..., M)`` trans: type of system to solve. Options are: - ``0``: :math:`A x = b` - ``1``: :math:`A^Tx = b` - ``2``: :math:`A^Hx = b` overwrite_b: unused by JAX check_finite: unused by JAX Returns: Array of shape ``(..., N)`` representing the solution of the linear system. See Also: - :func:`jax.scipy.linalg.lu` - :func:`jax.scipy.linalg.lu_factor` Examples: Solving a small linear system via LU factorization: >>> a = jnp.array([[2., 1.], ... [1., 2.]]) Compute the lu factorization via :func:`~jax.scipy.linalg.lu_factor`, and use it to solve a linear equation via :func:`~jax.scipy.linalg.lu_solve`. >>> b = jnp.array([3., 4.]) >>> lufac = jax.scipy.linalg.lu_factor(a) >>> y = jax.scipy.linalg.lu_solve(lufac, b) >>> y Array([0.6666666, 1.6666667], dtype=float32) Check that the result is consistent: >>> jnp.allclose(a @ y, b) Array(True, dtype=bool) """ del overwrite_b, check_finite # unused lu, pivots = lu_and_piv m, _ = lu.shape[-2:] perm = lax_linalg.lu_pivots_to_permutation(pivots, m) return lax_linalg.lu_solve(lu, perm, b, trans) @overload def _lu(a: ArrayLike, permute_l: Literal[True]) -> tuple[Array, Array]: ... @overload def _lu(a: ArrayLike, permute_l: Literal[False]) -> tuple[Array, Array, Array]: ... @overload def _lu(a: ArrayLike, permute_l: bool) -> tuple[Array, Array] | tuple[Array, Array, Array]: ... @partial(jit, static_argnums=(1,)) def _lu(a: ArrayLike, permute_l: bool) -> tuple[Array, Array] | tuple[Array, Array, Array]: a, = promote_dtypes_inexact(jnp.asarray(a)) lu, _, permutation = lax_linalg.lu(a) dtype = lax.dtype(a) m, n = jnp.shape(a) p = jnp.real(jnp.array(permutation[None, :] == jnp.arange(m, dtype=permutation.dtype)[:, None], dtype=dtype)) k = min(m, n) l = jnp.tril(lu, -1)[:, :k] + jnp.eye(m, k, dtype=dtype) u = jnp.triu(lu)[:k, :] if permute_l: return jnp.matmul(p, l, precision=lax.Precision.HIGHEST), u else: return p, l, u @overload def lu(a: ArrayLike, permute_l: Literal[False] = False, overwrite_a: bool = False, check_finite: bool = True) -> tuple[Array, Array, Array]: ... 
@overload def lu(a: ArrayLike, permute_l: Literal[True], overwrite_a: bool = False, check_finite: bool = True) -> tuple[Array, Array]: ... @overload def lu(a: ArrayLike, permute_l: bool = False, overwrite_a: bool = False, check_finite: bool = True) -> tuple[Array, Array] | tuple[Array, Array, Array]: ... @partial(jit, static_argnames=('permute_l', 'overwrite_a', 'check_finite')) def lu(a: ArrayLike, permute_l: bool = False, overwrite_a: bool = False, check_finite: bool = True) -> tuple[Array, Array] | tuple[Array, Array, Array]: """Compute the LU decomposition JAX implementation of :func:`scipy.linalg.lu`. The LU decomposition of a matrix `A` is: .. math:: A = P L U where `P` is a permutation matrix, `L` is lower-triangular and `U` is upper-triangular. Args: a: array of shape ``(..., M, N)`` to decompose. permute_l: if True, then permute ``L`` and return ``(P @ L, U)`` (default: False) overwrite_a: not used by JAX check_finite: not used by JAX Returns: A tuple of arrays ``(P @ L, U)`` if ``permute_l`` is True, else ``(P, L, U)``: - ``P`` is a permutation matrix of shape ``(..., M, M)`` - ``L`` is a lower-triangular matrix of shape ``(..., M, K)`` - ``U`` is an upper-triangular matrix of shape ``(..., K, N)`` with ``K = min(M, N)`` See also: - :func:`jax.numpy.linalg.lu`: NumPy-style API for LU decomposition. - :func:`jax.lax.linalg.lu`: XLA-style API for LU decomposition. - :func:`jax.scipy.linalg.lu_solve`: LU-based linear solver. Examples: An LU decomposition of a 3x3 matrix: >>> a = jnp.array([[1., 2., 3.], ... [5., 4., 2.], ... [3., 2., 1.]]) >>> P, L, U = jax.scipy.linalg.lu(a) ``P`` is a permutation matrix: i.e. each row and column has a single ``1``: >>> P Array([[0., 1., 0.], [1., 0., 0.], [0., 0., 1.]], dtype=float32) ``L`` and ``U`` are lower-triangular and upper-triangular matrices: >>> with jnp.printoptions(precision=3): ... print(L) ... print(U) [[ 1. 0. 0. ] [ 0.2 1. 0. ] [ 0.6 -0.333 1. ]] [[5. 4. 2. ] [0. 1.2 2.6 ] [0. 0. 0.667]] The original matrix can be reconstructed by multiplying the three together: >>> a_reconstructed = P @ L @ U >>> jnp.allclose(a, a_reconstructed) Array(True, dtype=bool) """ del overwrite_a, check_finite # unused return _lu(a, permute_l) @overload def _qr(a: ArrayLike, mode: Literal["r"], pivoting: Literal[False] ) -> tuple[Array]: ... @overload def _qr(a: ArrayLike, mode: Literal["r"], pivoting: Literal[True] ) -> tuple[Array, Array]: ... @overload def _qr(a: ArrayLike, mode: Literal["full", "economic"], pivoting: Literal[False] ) -> tuple[Array, Array]: ... @overload def _qr(a: ArrayLike, mode: Literal["full", "economic"], pivoting: Literal[True] ) -> tuple[Array, Array, Array]: ... @overload def _qr(a: ArrayLike, mode: str, pivoting: Literal[False] ) -> tuple[Array] | tuple[Array, Array]: ... @overload def _qr(a: ArrayLike, mode: str, pivoting: Literal[True] ) -> tuple[Array, Array] | tuple[Array, Array, Array]: ... @overload def _qr(a: ArrayLike, mode: str, pivoting: bool ) -> tuple[Array] | tuple[Array, Array] | tuple[Array, Array, Array]: ...
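# A small sketch (hypothetical helper, not part of this module) of the
# permutation-matrix construction used by _lu above: column j of p is the
# standard basis vector e[perm[j]], giving the matrix P that satisfies
# A = P @ L @ U in the ``lu`` docstring above.
def _permutation_to_matrix_sketch(perm: Array, dtype=jnp.float32) -> Array:
  m = perm.shape[0]
  return jnp.asarray(perm[None, :] == jnp.arange(m, dtype=perm.dtype)[:, None],
                     dtype=dtype)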
@partial(jit, static_argnames=('mode', 'pivoting')) def _qr(a: ArrayLike, mode: str, pivoting: bool ) -> tuple[Array] | tuple[Array, Array] | tuple[Array, Array, Array]: if mode in ("full", "r"): full_matrices = True elif mode == "economic": full_matrices = False else: raise ValueError(f"Unsupported QR decomposition mode '{mode}'") a, = promote_dtypes_inexact(jnp.asarray(a)) q, r, *p = lax_linalg.qr(a, pivoting=pivoting, full_matrices=full_matrices) if mode == "r": if pivoting: return r, p[0] return (r,) if pivoting: return q, r, p[0] return q, r @overload def qr(a: ArrayLike, overwrite_a: bool = False, lwork: Any = None, *, mode: Literal["full", "economic"], pivoting: Literal[False] = False, check_finite: bool = True) -> tuple[Array, Array]: ... @overload def qr(a: ArrayLike, overwrite_a: bool = False, lwork: Any = None, *, mode: Literal["full", "economic"], pivoting: Literal[True] = True, check_finite: bool = True) -> tuple[Array, Array, Array]: ... @overload def qr(a: ArrayLike, overwrite_a: bool = False, lwork: Any = None, *, mode: Literal["full", "economic"], pivoting: bool = False, check_finite: bool = True ) -> tuple[Array, Array] | tuple[Array, Array, Array]: ... @overload def qr(a: ArrayLike, overwrite_a: bool = False, lwork: Any = None, *, mode: Literal["r"], pivoting: Literal[False] = False, check_finite: bool = True ) -> tuple[Array]: ... @overload def qr(a: ArrayLike, overwrite_a: bool = False, lwork: Any = None, *, mode: Literal["r"], pivoting: Literal[True] = True, check_finite: bool = True ) -> tuple[Array, Array]: ... @overload def qr(a: ArrayLike, overwrite_a: bool = False, lwork: Any = None, *, mode: Literal["r"], pivoting: bool = False, check_finite: bool = True ) -> tuple[Array] | tuple[Array, Array]: ... @overload def qr(a: ArrayLike, overwrite_a: bool = False, lwork: Any = None, mode: str = "full", pivoting: bool = False, check_finite: bool = True ) -> tuple[Array] | tuple[Array, Array] | tuple[Array, Array, Array]: ... def qr(a: ArrayLike, overwrite_a: bool = False, lwork: Any = None, mode: str = "full", pivoting: bool = False, check_finite: bool = True ) -> tuple[Array] | tuple[Array, Array] | tuple[Array, Array, Array]: """Compute the QR decomposition of an array JAX implementation of :func:`scipy.linalg.qr`. The QR decomposition of a matrix `A` is given by .. math:: A = QR Where `Q` is a unitary matrix (i.e. :math:`Q^HQ=I`) and `R` is an upper-triangular matrix. Args: a: array of shape (..., M, N) mode: Computational mode. Supported values are: - ``"full"`` (default): return `Q` of shape ``(M, M)`` and `R` of shape ``(M, N)``. - ``"r"``: return only `R` - ``"economic"``: return `Q` of shape ``(M, K)`` and `R` of shape ``(K, N)``, where K = min(M, N). pivoting: Allows the QR decomposition to be rank-revealing. If ``True``, compute the column-pivoted decomposition ``A[:, P] = Q @ R``, where ``P`` is chosen such that the diagonal of ``R`` is non-increasing. 
overwrite_a: unused in JAX lwork: unused in JAX check_finite: unused in JAX Returns: If ``mode`` is not ``"r"``: a tuple ``(Q, R)`` when ``pivoting`` is ``False``, or ``(Q, R, P)`` when ``pivoting`` is ``True``. If ``mode`` is ``"r"``: a tuple ``(R,)`` or ``(R, P)``, respectively, where: - ``Q`` is an orthogonal matrix of shape ``(..., M, M)`` (if ``mode`` is ``"full"``) or ``(..., M, K)`` (if ``mode`` is ``"economic"``), - ``R`` is an upper-triangular matrix of shape ``(..., M, N)`` (if ``mode`` is ``"r"`` or ``"full"``) or ``(..., K, N)`` (if ``mode`` is ``"economic"``), - ``P`` is an index vector of shape ``(..., N)``. with ``K = min(M, N)``. Notes: - At present, pivoting is only implemented on the CPU and GPU backends. For further details about the GPU implementation, see the documentation for :func:`jax.lax.linalg.qr`. See also: - :func:`jax.numpy.linalg.qr`: NumPy-style QR decomposition API - :func:`jax.lax.linalg.qr`: XLA-style QR decomposition API Examples: Compute the QR decomposition of a matrix: >>> a = jnp.array([[1., 2., 3., 4.], ... [5., 4., 2., 1.], ... [6., 3., 1., 5.]]) >>> Q, R = jax.scipy.linalg.qr(a) >>> Q # doctest: +SKIP Array([[-0.12700021, -0.7581426 , -0.6396022 ], [-0.63500065, -0.43322435, 0.63960224], [-0.7620008 , 0.48737738, -0.42640156]], dtype=float32) >>> R # doctest: +SKIP Array([[-7.8740077, -5.080005 , -2.4130025, -4.953006 ], [ 0. , -1.7870499, -2.6534991, -1.028908 ], [ 0. , 0. , -1.0660033, -4.050814 ]], dtype=float32) Check that ``Q`` is orthonormal: >>> jnp.allclose(Q.T @ Q, jnp.eye(3), atol=1E-5) Array(True, dtype=bool) Reconstruct the input: >>> jnp.allclose(Q @ R, a) Array(True, dtype=bool) """ del overwrite_a, lwork, check_finite # unused return _qr(a, mode, pivoting) @partial(jit, static_argnames=('assume_a', 'lower')) def _solve(a: ArrayLike, b: ArrayLike, assume_a: str, lower: bool) -> Array: if assume_a != 'pos': return jnp.linalg.solve(a, b) a, b = promote_dtypes_inexact(jnp.asarray(a), jnp.asarray(b)) lax_linalg._check_solve_shapes(a, b) # With custom_linear_solve, we can reuse the same factorization when # computing sensitivities. This is considerably faster. factors = cho_factor(lax.stop_gradient(a), lower=lower) custom_solve = partial( lax.custom_linear_solve, lambda x: lax_linalg._broadcasted_matvec(a, x), solve=lambda _, x: cho_solve(factors, x), symmetric=True) if a.ndim == b.ndim + 1: # b.shape == [..., m] return custom_solve(b) else: # b.shape == [..., m, k] return vmap(custom_solve, b.ndim - 1, max(a.ndim, b.ndim) - 1)(b) def solve(a: ArrayLike, b: ArrayLike, lower: bool = False, overwrite_a: bool = False, overwrite_b: bool = False, debug: bool = False, check_finite: bool = True, assume_a: str = 'gen') -> Array: """Solve a linear system of equations. JAX implementation of :func:`scipy.linalg.solve`. This solves a (batched) linear system of equations ``a @ x = b`` for ``x`` given ``a`` and ``b``. If ``a`` is singular, this will return ``nan`` or ``inf`` values. Args: a: array of shape ``(..., N, N)``. b: array of shape ``(..., N)`` or ``(..., N, M)`` lower: Referenced only if ``assume_a != 'gen'``. If True, only use the lower triangle of the input; if False (default), only use the upper triangle. assume_a: specify what properties of ``a`` can be assumed.
Options are: - ``"gen"``: generic matrix (default) - ``"sym"``: symmetric matrix - ``"her"``: hermitian matrix - ``"pos"``: positive-definite matrix overwrite_a: unused by JAX overwrite_b: unused by JAX debug: unused by JAX check_finite: unused by JAX Returns: An array of the same shape as ``b`` containing the solution to the linear system if ``a`` is non-singular. If ``a`` is singular, the result contains ``nan`` or ``inf`` values. See also: - :func:`jax.scipy.linalg.lu_solve`: Solve via LU factorization. - :func:`jax.scipy.linalg.cho_solve`: Solve via Cholesky factorization. - :func:`jax.scipy.linalg.solve_triangular`: Solve a triangular system. - :func:`jax.numpy.linalg.solve`: NumPy-style API for solving linear systems. - :func:`jax.lax.custom_linear_solve`: matrix-free linear solver. Examples: A simple 3x3 linear system: >>> A = jnp.array([[1., 2., 3.], ... [2., 4., 2.], ... [3., 2., 1.]]) >>> b = jnp.array([14., 16., 10.]) >>> x = jax.scipy.linalg.solve(A, b) >>> x Array([1., 2., 3.], dtype=float32) Confirming that the result solves the system: >>> jnp.allclose(A @ x, b) Array(True, dtype=bool) """ del overwrite_a, overwrite_b, debug, check_finite # unused valid_assume_a = ['gen', 'sym', 'her', 'pos'] if assume_a not in valid_assume_a: raise ValueError(f"Expected assume_a to be one of {valid_assume_a}; got {assume_a!r}") return _solve(a, b, assume_a, lower) @partial(jit, static_argnames=('trans', 'lower', 'unit_diagonal')) def _solve_triangular(a: ArrayLike, b: ArrayLike, trans: int | str, lower: bool, unit_diagonal: bool) -> Array: if trans == 0 or trans == "N": transpose_a, conjugate_a = False, False elif trans == 1 or trans == "T": transpose_a, conjugate_a = True, False elif trans == 2 or trans == "C": transpose_a, conjugate_a = True, True else: raise ValueError(f"Invalid 'trans' value {trans}") a, b = promote_dtypes_inexact(jnp.asarray(a), jnp.asarray(b)) # lax_linalg.triangular_solve only supports matrix 'b's at the moment. b_is_vector = jnp.ndim(a) == jnp.ndim(b) + 1 if b_is_vector: b = b[..., None] out = lax_linalg.triangular_solve(a, b, left_side=True, lower=lower, transpose_a=transpose_a, conjugate_a=conjugate_a, unit_diagonal=unit_diagonal) if b_is_vector: return out[..., 0] else: return out def solve_triangular(a: ArrayLike, b: ArrayLike, trans: int | str = 0, lower: bool = False, unit_diagonal: bool = False, overwrite_b: bool = False, debug: Any = None, check_finite: bool = True) -> Array: """Solve a triangular linear system of equations JAX implementation of :func:`scipy.linalg.solve_triangular`. This solves a (batched) linear system of equations ``a @ x = b`` for ``x`` given a triangular matrix ``a`` and a vector or matrix ``b``. Args: a: array of shape ``(..., N, N)``. Only part of the array will be accessed, depending on the ``lower`` and ``unit_diagonal`` arguments. b: array of shape ``(..., N)`` or ``(..., N, M)`` lower: If True, only use the lower triangle of the input; if False (default), only use the upper triangle. unit_diagonal: If True, ignore diagonal elements of ``a`` and assume they are ``1`` (default: False). trans: type of system to solve. Options are: - ``0`` or ``'N'``: solve :math:`Ax=b` - ``1`` or ``'T'``: solve :math:`A^Tx=b` - ``2`` or ``'C'``: solve :math:`A^Hx=b` overwrite_b: unused by JAX debug: unused by JAX check_finite: unused by JAX Returns: An array of the same shape as ``b`` containing the solution to the linear system. See also: :func:`jax.scipy.linalg.solve`: Solve a general linear system.
Examples: A simple 3x3 triangular linear system: >>> A = jnp.array([[1., 2., 3.], ... [0., 3., 2.], ... [0., 0., 5.]]) >>> b = jnp.array([10., 8., 5.]) >>> x = jax.scipy.linalg.solve_triangular(A, b) >>> x Array([3., 2., 1.], dtype=float32) Confirming that the result solves the system: >>> jnp.allclose(A @ x, b) Array(True, dtype=bool) Computing the transposed problem: >>> x = jax.scipy.linalg.solve_triangular(A, b, trans='T') >>> x Array([10. , -4. , -3.4], dtype=float32) Confirming that the result solves the system: >>> jnp.allclose(A.T @ x, b) Array(True, dtype=bool) """ del overwrite_b, debug, check_finite # unused return _solve_triangular(a, b, trans, lower, unit_diagonal) @partial(jit, static_argnames=('upper_triangular', 'max_squarings')) def expm(A: ArrayLike, *, upper_triangular: bool = False, max_squarings: int = 16) -> Array: """Compute the matrix exponential JAX implementation of :func:`scipy.linalg.expm`. Args: A: array of shape ``(..., N, N)`` upper_triangular: if True, then assume that ``A`` is upper-triangular (default: False). max_squarings: The number of squarings in the scaling-and-squaring approximation method (default: 16). Returns: An array of shape ``(..., N, N)`` containing the matrix exponential of ``A``. Notes: This uses the scaling-and-squaring approximation method, with computational complexity controlled by the optional ``max_squarings`` argument. Theoretically, the number of required squarings is ``max(0, ceil(log2(norm(A))) - c)`` where ``norm(A)`` is the L1 norm and ``c=2.42`` for float64/complex128, or ``c=1.97`` for float32/complex64. See Also: :func:`jax.scipy.linalg.expm_frechet` Examples: ``expm`` is the matrix exponential, and has similar properties to the more familiar scalar exponential. For scalars ``a`` and ``b``, :math:`e^{a + b} = e^a e^b`. However, for matrices, this property only holds when ``A`` and ``B`` commute (``AB = BA``). In this case, ``expm(A+B) = expm(A) @ expm(B)`` >>> A = jnp.array([[2, 0], ... [0, 1]]) >>> B = jnp.array([[3, 0], ... [0, 4]]) >>> jnp.allclose(jax.scipy.linalg.expm(A+B), ... jax.scipy.linalg.expm(A) @ jax.scipy.linalg.expm(B), ... rtol=0.0001) Array(True, dtype=bool) If a matrix ``X`` is invertible, then ``expm(X @ A @ inv(X)) = X @ expm(A) @ inv(X)`` >>> X = jnp.array([[3, 1], ... [2, 5]]) >>> X_inv = jax.scipy.linalg.inv(X) >>> jnp.allclose(jax.scipy.linalg.expm(X @ A @ X_inv), ...
X @ jax.scipy.linalg.expm(A) @ X_inv) Array(True, dtype=bool) """ A, = promote_dtypes_inexact(A) if A.ndim < 2 or A.shape[-1] != A.shape[-2]: raise ValueError(f"Expected A to be a (batched) square matrix, got {A.shape=}.") if A.ndim > 2: return jnp.vectorize( partial(expm, upper_triangular=upper_triangular, max_squarings=max_squarings), signature="(n,n)->(n,n)")(A) P, Q, n_squarings = _calc_P_Q(jnp.asarray(A)) def _nan(args): A, *_ = args return jnp.full_like(A, jnp.nan) def _compute(args): A, P, Q = args R = _solve_P_Q(P, Q, upper_triangular) R = _squaring(R, n_squarings, max_squarings) return R R = lax.cond(n_squarings > max_squarings, _nan, _compute, (A, P, Q)) return R @jit def _calc_P_Q(A: Array) -> tuple[Array, Array, Array]: if A.ndim != 2 or A.shape[0] != A.shape[1]: raise ValueError('expected A to be a square matrix') A_L1 = jnp.linalg.norm(A,1) n_squarings: Array U: Array V: Array if A.dtype == 'float64' or A.dtype == 'complex128': maxnorm = 5.371920351148152 n_squarings = jnp.maximum(0, jnp.floor(jnp.log2(A_L1 / maxnorm))) A = A / 2 ** n_squarings.astype(A.dtype) conds = jnp.array([1.495585217958292e-002, 2.539398330063230e-001, 9.504178996162932e-001, 2.097847961257068e+000], dtype=A_L1.dtype) idx = jnp.digitize(A_L1, conds) U, V = lax.switch(idx, [_pade3, _pade5, _pade7, _pade9, _pade13], A) elif A.dtype == 'float32' or A.dtype == 'complex64': maxnorm = 3.925724783138660 n_squarings = jnp.maximum(0, jnp.floor(jnp.log2(A_L1 / maxnorm))) A = A / 2 ** n_squarings.astype(A.dtype) conds = jnp.array([4.258730016922831e-001, 1.880152677804762e+000], dtype=A_L1.dtype) idx = jnp.digitize(A_L1, conds) U, V = lax.switch(idx, [_pade3, _pade5, _pade7], A) else: raise TypeError(f"A.dtype={A.dtype} is not supported.") P = U + V # p_m(A) : numerator Q = -U + V # q_m(A) : denominator return P, Q, n_squarings def _solve_P_Q(P: ArrayLike, Q: ArrayLike, upper_triangular: bool = False) -> Array: if upper_triangular: return solve_triangular(Q, P) else: return jnp.linalg.solve(Q, P) def _precise_dot(A: ArrayLike, B: ArrayLike) -> Array: return jnp.dot(A, B, precision=lax.Precision.HIGHEST) @partial(jit, static_argnums=2) def _squaring(R: Array, n_squarings: Array, max_squarings: int) -> Array: # squaring step to undo scaling def _squaring_precise(x): return _precise_dot(x, x) def _identity(x): return x def _scan_f(c, i): return lax.cond(i < n_squarings, _squaring_precise, _identity, c), None res, _ = lax.scan(_scan_f, R, jnp.arange(max_squarings, dtype=n_squarings.dtype)) return res def _pade3(A: Array) -> tuple[Array, Array]: b = (120., 60., 12., 1.) M, N = A.shape ident = jnp.eye(M, N, dtype=A.dtype) A2 = _precise_dot(A, A) U = _precise_dot(A, (b[3]*A2 + b[1]*ident)) V: Array = b[2]*A2 + b[0]*ident return U, V def _pade5(A: Array) -> tuple[Array, Array]: b = (30240., 15120., 3360., 420., 30., 1.) M, N = A.shape ident = jnp.eye(M, N, dtype=A.dtype) A2 = _precise_dot(A, A) A4 = _precise_dot(A2, A2) U = _precise_dot(A, b[5]*A4 + b[3]*A2 + b[1]*ident) V: Array = b[4]*A4 + b[2]*A2 + b[0]*ident return U, V def _pade7(A: Array) -> tuple[Array, Array]: b = (17297280., 8648640., 1995840., 277200., 25200., 1512., 56., 1.) 
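  # As in _pade3 and _pade5 above, b holds the coefficients of the diagonal
  # Pade approximant to exp(z) (here degree 7): U collects the odd powers of
  # A, V the even powers, and _calc_P_Q forms the numerator P = U + V and
  # denominator Q = -U + V, so expm(A) is approximated by solve(Q, P).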
M, N = A.shape ident = jnp.eye(M, N, dtype=A.dtype) A2 = _precise_dot(A, A) A4 = _precise_dot(A2, A2) A6 = _precise_dot(A4, A2) U = _precise_dot(A, b[7]*A6 + b[5]*A4 + b[3]*A2 + b[1]*ident) V = b[6]*A6 + b[4]*A4 + b[2]*A2 + b[0]*ident return U,V def _pade9(A: Array) -> tuple[Array, Array]: b = (17643225600., 8821612800., 2075673600., 302702400., 30270240., 2162160., 110880., 3960., 90., 1.) M, N = A.shape ident = jnp.eye(M, N, dtype=A.dtype) A2 = _precise_dot(A, A) A4 = _precise_dot(A2, A2) A6 = _precise_dot(A4, A2) A8 = _precise_dot(A6, A2) U = _precise_dot(A, b[9]*A8 + b[7]*A6 + b[5]*A4 + b[3]*A2 + b[1]*ident) V = b[8]*A8 + b[6]*A6 + b[4]*A4 + b[2]*A2 + b[0]*ident return U,V def _pade13(A: Array) -> tuple[Array, Array]: b = (64764752532480000., 32382376266240000., 7771770303897600., 1187353796428800., 129060195264000., 10559470521600., 670442572800., 33522128640., 1323241920., 40840800., 960960., 16380., 182., 1.) M, N = A.shape ident = jnp.eye(M, N, dtype=A.dtype) A2 = _precise_dot(A, A) A4 = _precise_dot(A2, A2) A6 = _precise_dot(A4, A2) U = _precise_dot(A, _precise_dot(A6, b[13]*A6 + b[11]*A4 + b[9]*A2) + b[7]*A6 + b[5]*A4 + b[3]*A2 + b[1]*ident) V = _precise_dot(A6, b[12]*A6 + b[10]*A4 + b[8]*A2) + b[6]*A6 + b[4]*A4 + b[2]*A2 + b[0]*ident return U,V @overload def expm_frechet(A: ArrayLike, E: ArrayLike, *, method: str | None = None, compute_expm: Literal[True] = True) -> tuple[Array, Array]: ... @overload def expm_frechet(A: ArrayLike, E: ArrayLike, *, method: str | None = None, compute_expm: Literal[False]) -> Array: ... @overload def expm_frechet(A: ArrayLike, E: ArrayLike, *, method: str | None = None, compute_expm: bool = True) -> Array | tuple[Array, Array]: ... @partial(jit, static_argnames=('method', 'compute_expm')) def expm_frechet(A: ArrayLike, E: ArrayLike, *, method: str | None = None, compute_expm: bool = True) -> Array | tuple[Array, Array]: """Compute the Frechet derivative of the matrix exponential. JAX implementation of :func:`scipy.linalg.expm_frechet` Args: A: array of shape ``(..., N, N)`` E: array of shape ``(..., N, N)``; specifies the direction of the derivative. compute_expm: if True (default) then compute and return ``expm(A)``. method: ignored by JAX Returns: A tuple ``(expm_A, expm_frechet_AE)`` if ``compute_expm`` is True, else the array ``expm_frechet_AE``. Both returned arrays have shape ``(..., N, N)``. 
See also: :func:`jax.scipy.linalg.expm` Examples: We can use this API to compute the matrix exponential of ``A``, as well as its derivative in the direction ``E``: >>> key1, key2 = jax.random.split(jax.random.key(3372)) >>> A = jax.random.normal(key1, (3, 3)) >>> E = jax.random.normal(key2, (3, 3)) >>> expmA, expm_frechet_AE = jax.scipy.linalg.expm_frechet(A, E) This can be equivalently computed using JAX's automatic differentiation methods; here we'll compute the derivative of :func:`~jax.scipy.linalg.expm` in the direction of ``E`` using :func:`jax.jvp`, and find the same results: >>> expmA2, expm_frechet_AE2 = jax.jvp(jax.scipy.linalg.expm, (A,), (E,)) >>> jnp.allclose(expmA, expmA2) Array(True, dtype=bool) >>> jnp.allclose(expm_frechet_AE, expm_frechet_AE2) Array(True, dtype=bool) """ del method # unused A_arr = jnp.asarray(A) E_arr = jnp.asarray(E) if A_arr.ndim < 2 or A_arr.shape[-2] != A_arr.shape[-1]: raise ValueError(f'expected A to be a (batched) square matrix, got A.shape={A_arr.shape}') if E_arr.ndim < 2 or E_arr.shape[-2] != E_arr.shape[-1]: raise ValueError(f'expected E to be a (batched) square matrix, got E.shape={E_arr.shape}') if A_arr.shape != E_arr.shape: raise ValueError('expected A and E to be the same shape, got ' f'A.shape={A_arr.shape} E.shape={E_arr.shape}') bound_fun = partial(expm, upper_triangular=False, max_squarings=16) expm_A, expm_frechet_AE = jvp(bound_fun, (A_arr,), (E_arr,)) if compute_expm: return expm_A, expm_frechet_AE else: return expm_frechet_AE @jit def block_diag(*arrs: ArrayLike) -> Array: """Create a block diagonal matrix from input arrays. JAX implementation of :func:`scipy.linalg.block_diag`. Args: *arrs: arrays of at most two dimensions Returns: 2D block-diagonal array constructed by placing the input arrays along the diagonal. Examples: >>> A = jnp.ones((1, 1)) >>> B = jnp.ones((2, 2)) >>> C = jnp.ones((3, 3)) >>> jax.scipy.linalg.block_diag(A, B, C) Array([[1., 0., 0., 0., 0., 0.], [0., 1., 1., 0., 0., 0.], [0., 1., 1., 0., 0., 0.], [0., 0., 0., 1., 1., 1.], [0., 0., 0., 1., 1., 1.], [0., 0., 0., 1., 1., 1.]], dtype=float32) """ if len(arrs) == 0: arrs = (jnp.zeros((1, 0)),) arrs = tuple(promote_dtypes(*arrs)) bad_shapes = [i for i, a in enumerate(arrs) if jnp.ndim(a) > 2] if bad_shapes: raise ValueError("Arguments to jax.scipy.linalg.block_diag must have at " "most 2 dimensions, got {} at argument {}." .format(arrs[bad_shapes[0]], bad_shapes[0])) converted_arrs = [jnp.atleast_2d(a) for a in arrs] acc = converted_arrs[0] dtype = lax.dtype(acc) for a in converted_arrs[1:]: _, c = a.shape a = lax.pad(a, dtype.type(0), ((0, 0, 0), (acc.shape[-1], 0, 0))) acc = lax.pad(acc, dtype.type(0), ((0, 0, 0), (0, c, 0))) acc = lax.concatenate([acc, a], dimension=0) return acc @partial(jit, static_argnames=("eigvals_only", "select", "select_range")) def eigh_tridiagonal(d: ArrayLike, e: ArrayLike, *, eigvals_only: bool = False, select: str = 'a', select_range: tuple[float, float] | None = None, tol: float | None = None) -> Array: """Solve the eigenvalue problem for a symmetric real tridiagonal matrix JAX implementation of :func:`scipy.linalg.eigh_tridiagonal`. Args: d: real-valued array of shape ``(N,)`` specifying the diagonal elements. e: real-valued array of shape ``(N - 1,)`` specifying the off-diagonal elements. eigvals_only: If True, return only the eigenvalues (default: False). Computation of eigenvectors is not yet implemented, so ``eigvals_only`` must be set to True. select: specify which eigenvalues to calculate.
Supported values are: - ``'a'``: all eigenvalues - ``'i'``: eigenvalues with indices ``select_range[0] <= i <= select_range[1]`` JAX does not currently implement ``select = 'v'``. select_range: range of values used when ``select='i'``. tol: absolute tolerance to use when solving for the eigenvalues. Returns: An array of eigenvalues with shape ``(N,)``. See also: :func:`jax.scipy.linalg.eigh`: general Hermitian eigenvalue solver Examples: >>> d = jnp.array([1., 2., 3., 4.]) >>> e = jnp.array([1., 1., 1.]) >>> eigvals = jax.scipy.linalg.eigh_tridiagonal(d, e, eigvals_only=True) >>> eigvals Array([0.2547188, 1.8227171, 3.1772828, 4.745281 ], dtype=float32) For comparison, we can construct the full matrix and compute the same result using :func:`~jax.scipy.linalg.eigh`: >>> A = jnp.diag(d) + jnp.diag(e, 1) + jnp.diag(e, -1) >>> A Array([[1., 1., 0., 0.], [1., 2., 1., 0.], [0., 1., 3., 1.], [0., 0., 1., 4.]], dtype=float32) >>> eigvals_full = jax.scipy.linalg.eigh(A, eigvals_only=True) >>> jnp.allclose(eigvals, eigvals_full) Array(True, dtype=bool) """ if not eigvals_only: raise NotImplementedError("Calculation of eigenvectors is not implemented") def _sturm(alpha, beta_sq, pivmin, alpha0_perturbation, x): """Implements the Sturm sequence recurrence.""" n = alpha.shape[0] zeros = jnp.zeros(x.shape, dtype=jnp.int32) ones = jnp.ones(x.shape, dtype=jnp.int32) # The first step in the Sturm sequence recurrence # requires special care if x is equal to alpha[0]. def sturm_step0(): q = alpha[0] - x count = jnp.where(q < 0, ones, zeros) q = jnp.where(alpha[0] == x, alpha0_perturbation, q) return q, count # Subsequent steps all take this form: def sturm_step(i, q, count): q = alpha[i] - beta_sq[i - 1] / q - x count = jnp.where(q <= pivmin, count + 1, count) q = jnp.where(q <= pivmin, jnp.minimum(q, -pivmin), q) return q, count # The first step initializes q and count. q, count = sturm_step0() # Peel off ((n-1) % blocksize) steps from the main loop, so we can run # the bulk of the iterations unrolled by a factor of blocksize. blocksize = 16 i = 1 peel = (n - 1) % blocksize unroll_cnt = peel def unrolled_steps(args): start, q, count = args for j in range(unroll_cnt): q, count = sturm_step(start + j, q, count) return start + unroll_cnt, q, count i, q, count = unrolled_steps((i, q, count)) # Run the remaining steps of the Sturm sequence using a partially # unrolled while loop. unroll_cnt = blocksize def cond(iqc): i, q, count = iqc return jnp.less(i, n) _, _, count = lax.while_loop(cond, unrolled_steps, (i, q, count)) return count alpha = jnp.asarray(d) beta = jnp.asarray(e) supported_dtypes = (jnp.float32, jnp.float64, jnp.complex64, jnp.complex128) if alpha.dtype != beta.dtype: raise TypeError("diagonal and off-diagonal values must have same dtype, " f"got {alpha.dtype} and {beta.dtype}") if alpha.dtype not in supported_dtypes or beta.dtype not in supported_dtypes: raise TypeError("Only float32, float64, complex64, and complex128 inputs " "are supported by jax.scipy.linalg.eigh_tridiagonal, got " f"{alpha.dtype} and {beta.dtype}") n = alpha.shape[0] if n <= 1: return jnp.real(alpha) if jnp.issubdtype(alpha.dtype, np.complexfloating): alpha = jnp.real(alpha) beta_sq = jnp.real(beta * jnp.conj(beta)) beta_abs = jnp.sqrt(beta_sq) else: beta_abs = jnp.abs(beta) beta_sq = jnp.square(beta) # Estimate the largest and smallest eigenvalues of T using the Gershgorin # circle theorem.
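  # For a tridiagonal T, the i-th Gershgorin disc is centered at alpha[i]
  # with radius |beta[i-1]| + |beta[i]| (a single term at the endpoints), so
  # every eigenvalue lies in [min(alpha - r), max(alpha + r)].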
off_diag_abs_row_sum = jnp.concatenate( [beta_abs[:1], beta_abs[:-1] + beta_abs[1:], beta_abs[-1:]], axis=0) lambda_est_max = jnp.amax(alpha + off_diag_abs_row_sum) lambda_est_min = jnp.amin(alpha - off_diag_abs_row_sum) # Upper bound on 2-norm of T. t_norm = jnp.maximum(jnp.abs(lambda_est_min), jnp.abs(lambda_est_max)) # Compute the smallest allowed pivot in the Sturm sequence to avoid # overflow. finfo = np.finfo(alpha.dtype) one = np.ones([], dtype=alpha.dtype) safemin = np.maximum(one / finfo.max, (one + finfo.eps) * finfo.tiny) pivmin = safemin * jnp.maximum(1, jnp.amax(beta_sq)) alpha0_perturbation = jnp.square(finfo.eps * beta_abs[0]) abs_tol = finfo.eps * t_norm if tol is not None: abs_tol = jnp.maximum(tol, abs_tol) # In the worst case, when the absolute tolerance is eps*lambda_est_max and # lambda_est_max = -lambda_est_min, we have to take as many bisection steps # as there are bits in the mantissa plus 1. # The proof is left as an exercise to the reader. max_it = finfo.nmant + 1 # Determine the indices of the desired eigenvalues, based on select and # select_range. if select == 'a': target_counts = jnp.arange(n, dtype=jnp.int32) elif select == 'i': if select_range is None: raise ValueError("for select='i', select_range must be specified.") if select_range[0] > select_range[1]: raise ValueError('Got empty index range in select_range.') target_counts = jnp.arange(select_range[0], select_range[1] + 1, dtype=jnp.int32) elif select == 'v': # TODO(phawkins): requires dynamic shape support. raise NotImplementedError("eigh_tridiagonal(..., select='v') is not " "implemented") else: raise ValueError("'select' must have a value in {'a', 'i', 'v'}.") # Run binary search for all desired eigenvalues in parallel, starting from # the interval slightly wider than the estimated # [lambda_est_min, lambda_est_max]. fudge = 2.1 # We widen the Gershgorin interval a bit. norm_slack = jnp.array(n, alpha.dtype) * fudge * finfo.eps * t_norm lower = lambda_est_min - norm_slack - 2 * fudge * pivmin upper = lambda_est_max + norm_slack + fudge * pivmin # Pre-broadcast the scalars used in the Sturm sequence for improved # performance. target_shape = jnp.shape(target_counts) lower = jnp.broadcast_to(lower, shape=target_shape) upper = jnp.broadcast_to(upper, shape=target_shape) mid = 0.5 * (upper + lower) pivmin = jnp.broadcast_to(pivmin, target_shape) alpha0_perturbation = jnp.broadcast_to(alpha0_perturbation, target_shape) # Start parallel binary searches. def cond(args): i, lower, _, upper = args return jnp.logical_and( jnp.less(i, max_it), jnp.less(abs_tol, jnp.amax(upper - lower))) def body(args): i, lower, mid, upper = args counts = _sturm(alpha, beta_sq, pivmin, alpha0_perturbation, mid) lower = jnp.where(counts <= target_counts, mid, lower) upper = jnp.where(counts > target_counts, mid, upper) mid = 0.5 * (lower + upper) return i + 1, lower, mid, upper _, _, mid, _ = lax.while_loop(cond, body, (0, lower, mid, upper)) return mid @partial(jit, static_argnames=('side', 'method')) @jax.default_matmul_precision("float32") def polar(a: ArrayLike, side: str = 'right', *, method: str = 'qdwh', eps: float | None = None, max_iterations: int | None = None) -> tuple[Array, Array]: r"""Computes the polar decomposition.
Given the :math:`m \times n` matrix :math:`a`, returns the factors of the polar decomposition :math:`u` (also :math:`m \times n`) and :math:`p` such that :math:`a = up` (if side is ``"right"``; :math:`p` is :math:`n \times n`) or :math:`a = pu` (if side is ``"left"``; :math:`p` is :math:`m \times m`), where :math:`p` is positive semidefinite. If :math:`a` is nonsingular, :math:`p` is positive definite and the decomposition is unique. :math:`u` has orthonormal columns unless :math:`n > m`, in which case it has orthonormal rows. Writing the SVD of :math:`a` as :math:`a = u_\mathit{svd} \cdot s_\mathit{svd} \cdot v^h_\mathit{svd}`, we have :math:`u = u_\mathit{svd} \cdot v^h_\mathit{svd}`. Thus the unitary factor :math:`u` can be constructed as the application of the sign function to the singular values of :math:`a`; or, if :math:`a` is Hermitian, the eigenvalues. Several methods exist to compute the polar decomposition. Currently two are supported: * ``method="svd"``: Computes the SVD of :math:`a` and then forms :math:`u = u_\mathit{svd} \cdot v^h_\mathit{svd}`. * ``method="qdwh"``: Applies the `QDWH`_ (QR-based Dynamically Weighted Halley) algorithm. Args: a: The :math:`m \times n` input matrix. side: Determines whether a right or left polar decomposition is computed. If ``side`` is ``"right"`` then :math:`a = up`. If ``side`` is ``"left"`` then :math:`a = pu`. The default is ``"right"``. method: Determines the algorithm used, as described above. eps: The final result will satisfy :math:`\left|x_k - x_{k-1}\right| < \left|x_k\right| (4\epsilon)^{\frac{1}{3}}`, where :math:`x_k` are the QDWH iterates. Ignored if ``method`` is not ``"qdwh"``. max_iterations: Iterations will terminate after this many steps even if the above is unsatisfied. Ignored if ``method`` is not ``"qdwh"``. Returns: A ``(unitary, posdef)`` tuple, where ``unitary`` is the unitary factor (:math:`m \times n`), and ``posdef`` is the positive-semidefinite factor. ``posdef`` is either :math:`n \times n` or :math:`m \times m` depending on whether ``side`` is ``"right"`` or ``"left"``, respectively. Examples: Polar decomposition of a 3x3 matrix: >>> a = jnp.array([[1., 2., 3.], ... [5., 4., 2.], ... [3., 2., 1.]]) >>> U, P = jax.scipy.linalg.polar(a) ``U`` is a unitary matrix: >>> jnp.round(U.T @ U) # doctest: +SKIP Array([[ 1., -0., -0.], [-0., 1., 0.], [-0., 0., 1.]], dtype=float32) ``P`` is a positive-semidefinite matrix: >>> with jnp.printoptions(precision=2, suppress=True): ... print(P) [[4.79 3.25 1.23] [3.25 3.06 2.01] [1.23 2.01 2.91]] The original matrix can be reconstructed by multiplying ``U`` and ``P``: >>> a_reconstructed = U @ P >>> jnp.allclose(a, a_reconstructed) Array(True, dtype=bool) .. _QDWH: https://epubs.siam.org/doi/abs/10.1137/090774999 """ arr = jnp.asarray(a) if arr.ndim != 2: raise ValueError("The input `a` must be a 2-D array.") if side not in ["right", "left"]: raise ValueError("The argument `side` must be either 'right' or 'left'.") m, n = arr.shape if method == "qdwh": # TODO(phawkins): return info also if the user opts in?
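    # qdwh.qdwh factors a tall (m >= n) matrix as arr = U @ P. The wide,
    # side='left' case below is handled by factoring arr^H = U' @ P' and
    # conjugate-transposing back: arr = P'^H @ U'^H, a left decomposition.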
    if m >= n and side == "right":
      unitary, posdef, _, _ = qdwh.qdwh(arr, is_hermitian=False, eps=eps)
    elif m < n and side == "left":
      arr = arr.T.conj()
      unitary, posdef, _, _ = qdwh.qdwh(arr, is_hermitian=False, eps=eps)
      posdef = posdef.T.conj()
      unitary = unitary.T.conj()
    else:
      raise NotImplementedError("method='qdwh' only supports mxn matrices "
                                "where m >= n with side='right' and m < n "
                                f"with side='left'; got {arr.shape} with "
                                f"{side=}")
  elif method == "svd":
    u_svd, s_svd, vh_svd = lax_linalg.svd(arr, full_matrices=False)
    s_svd = s_svd.astype(u_svd.dtype)
    unitary = u_svd @ vh_svd
    if side == "right":
      # a = u * p
      posdef = (vh_svd.T.conj() * s_svd[None, :]) @ vh_svd
    else:
      # a = p * u
      posdef = (u_svd * s_svd[None, :]) @ (u_svd.T.conj())
  else:
    raise ValueError(f"Unknown polar decomposition method {method}.")

  return unitary, posdef


@jit
def _sqrtm_triu(T: Array) -> Array:
  """
  Implements Björck, Å., & Hammarling, S. (1983). "A Schur method for the
  square root of a matrix". Linear algebra and its applications, 52, 127-140.
  """
  diag = jnp.sqrt(jnp.diag(T))
  n = diag.size
  U = jnp.diag(diag)

  def i_loop(l, data):
    j, U = data
    i = j - 1 - l
    # Solve for U[i, j] from the identity (U @ U)[i, j] == T[i, j], using the
    # entries of U that have already been computed.
    s = lax.fori_loop(i + 1, j, lambda k, val: val + U[i, k] * U[k, j], 0.0)
    value = jnp.where(T[i, j] == s, 0.0,
                      (T[i, j] - s) / (diag[i] + diag[j]))
    return j, U.at[i, j].set(value)

  def j_loop(j, U):
    # Fill column j above the diagonal, moving upwards from row j - 1.
    _, U = lax.fori_loop(0, j, i_loop, (j, U))
    return U

  U = lax.fori_loop(0, n, j_loop, U)
  return U


@jit
def _sqrtm(A: ArrayLike) -> Array:
  T, Z = schur(A, output='complex')
  sqrt_T = _sqrtm_triu(T)
  return jnp.matmul(jnp.matmul(Z, sqrt_T, precision=lax.Precision.HIGHEST),
                    jnp.conj(Z.T), precision=lax.Precision.HIGHEST)


def sqrtm(A: ArrayLike, blocksize: int = 1) -> Array:
  """Compute the matrix square root.

  JAX implementation of :func:`scipy.linalg.sqrtm`.

  Args:
    A: array of shape ``(N, N)``
    blocksize: Not supported in JAX; JAX always uses ``blocksize=1``.

  Returns:
    An array of shape ``(N, N)`` containing the matrix square root of ``A``.

  See Also:
    :func:`jax.scipy.linalg.expm`

  Examples:
    >>> a = jnp.array([[1., 2., 3.],
    ...                [2., 4., 2.],
    ...                [3., 2., 1.]])
    >>> sqrt_a = jax.scipy.linalg.sqrtm(a)
    >>> with jnp.printoptions(precision=2, suppress=True):
    ...   print(sqrt_a)
    [[0.92+0.71j 0.54+0.j   0.92-0.71j]
     [0.54+0.j   1.85+0.j   0.54-0.j  ]
     [0.92-0.71j 0.54-0.j   0.92+0.71j]]

    By definition, matrix multiplication of the matrix square root with itself
    should equal the input:

    >>> jnp.allclose(a, sqrt_a @ sqrt_a)
    Array(True, dtype=bool)

  Notes:
    This function implements the complex Schur method described in [1]_. It
    does not use recursive blocking to speed up computations, as a Sylvester
    equation solver is not yet available in JAX.

  References:
    .. [1] Björck, Å., & Hammarling, S. (1983). "A Schur method for the square
       root of a matrix". Linear algebra and its applications, 52, 127-140.
  """
  if blocksize > 1:
    raise NotImplementedError("Blocked version is not implemented yet.")
  return _sqrtm(A)


@partial(jit, static_argnames=('check_finite',))
def rsf2csf(T: ArrayLike, Z: ArrayLike,
            check_finite: bool = True) -> tuple[Array, Array]:
  """Convert real Schur form to complex Schur form.

  JAX implementation of :func:`scipy.linalg.rsf2csf`.

  Args:
    T: array of shape ``(N, N)`` containing the real Schur form of the input.
    Z: array of shape ``(N, N)`` containing the corresponding Schur
      transformation matrix.
    check_finite: unused by JAX

  Returns:
    A tuple of arrays ``(T, Z)`` of the same shape as the inputs, containing
    the complex Schur form and the associated Schur transformation matrix.
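
  The conversion sweeps up the first subdiagonal of the real Schur form: each
  nonzero :math:`2 \\times 2` block is triangularized with a complex Givens
  rotation constructed from the block's eigenvalues, and the same rotation is
  accumulated into ``Z``.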

  See Also:
    :func:`jax.scipy.linalg.schur`: Schur decomposition

  Examples:
    >>> A = jnp.array([[0., 3., 3.],
    ...                [0., 1., 2.],
    ...                [2., 0., 1.]])
    >>> Tr, Zr = jax.scipy.linalg.schur(A)
    >>> Tc, Zc = jax.scipy.linalg.rsf2csf(Tr, Zr)

    Both the real and complex forms can be used to reconstruct the input
    matrix to float32 precision:

    >>> jnp.allclose(Zr @ Tr @ Zr.T, A, atol=1E-5)
    Array(True, dtype=bool)
    >>> jnp.allclose(Zc @ Tc @ Zc.conj().T, A, atol=1E-5)
    Array(True, dtype=bool)

    The real-valued Schur form is only quasi-upper-triangular, as we can see
    in this case:

    >>> with jax.numpy.printoptions(precision=2, suppress=True):
    ...   print(Tr)
    [[ 3.76 -2.17  1.38]
     [ 0.   -0.88 -0.35]
     [ 0.    2.37 -0.88]]

    By contrast, the complex form is truly upper-triangular:

    >>> with jnp.printoptions(precision=2, suppress=True):
    ...   print(Tc)
    [[ 3.76+0.j    1.29-0.78j  2.02-0.5j ]
     [ 0.  +0.j   -0.88+0.91j -2.02+0.j  ]
     [ 0.  +0.j    0.  +0.j   -0.88-0.91j]]
  """
  del check_finite  # unused

  T_arr = jnp.asarray(T)
  Z_arr = jnp.asarray(Z)

  if T_arr.ndim != 2 or T_arr.shape[0] != T_arr.shape[1]:
    raise ValueError("Input 'T' must be square.")
  if Z_arr.ndim != 2 or Z_arr.shape[0] != Z_arr.shape[1]:
    raise ValueError("Input 'Z' must be square.")
  if T_arr.shape[0] != Z_arr.shape[0]:
    raise ValueError(f"Input array shapes must match: Z: {Z_arr.shape} "
                     f"vs. T: {T_arr.shape}")

  T_arr, Z_arr = promote_dtypes_complex(T_arr, Z_arr)
  eps = jnp.finfo(T_arr.dtype).eps
  N = T_arr.shape[0]

  if N == 1:
    return T_arr, Z_arr

  def _update_T_Z(m, T, Z):
    mu = jnp.linalg.eigvals(lax.dynamic_slice(T, (m-1, m-1), (2, 2))) - T[m, m]
    r = jnp.linalg.norm(jnp.array([mu[0], T[m, m-1]])).astype(T.dtype)
    c = mu[0] / r
    s = T[m, m-1] / r
    G = jnp.array([[c.conj(), s], [-s, c]], dtype=T.dtype)

    # T[m-1:m+1, m-1:] = G @ T[m-1:m+1, m-1:]
    T_rows = lax.dynamic_slice_in_dim(T, m-1, 2, axis=0)
    col_mask = jnp.arange(N) >= m-1
    G_dot_T_zeroed_cols = G @ jnp.where(col_mask, T_rows, 0)
    T_rows_new = jnp.where(~col_mask, T_rows, G_dot_T_zeroed_cols)
    T = lax.dynamic_update_slice_in_dim(T, T_rows_new, m-1, axis=0)

    # T[:m+1, m-1:m+1] = T[:m+1, m-1:m+1] @ G.conj().T
    T_cols = lax.dynamic_slice_in_dim(T, m-1, 2, axis=1)
    row_mask = jnp.arange(N)[:, jnp.newaxis] < m+1
    T_zeroed_rows_dot_GH = jnp.where(row_mask, T_cols, 0) @ G.conj().T
    T_cols_new = jnp.where(~row_mask, T_cols, T_zeroed_rows_dot_GH)
    T = lax.dynamic_update_slice_in_dim(T, T_cols_new, m-1, axis=1)

    # Z[:, m-1:m+1] = Z[:, m-1:m+1] @ G.conj().T
    Z_cols = lax.dynamic_slice_in_dim(Z, m-1, 2, axis=1)
    Z = lax.dynamic_update_slice_in_dim(Z, Z_cols @ G.conj().T, m-1, axis=1)
    return T, Z

  def _rsf2csf_iter(i, TZ):
    m = N-i
    T, Z = TZ
    T, Z = lax.cond(
        jnp.abs(T[m, m-1]) > eps*(jnp.abs(T[m-1, m-1]) + jnp.abs(T[m, m])),
        _update_T_Z,
        lambda m, T, Z: (T, Z),
        m, T, Z)
    T = T.at[m, m-1].set(0.0)
    return T, Z

  return lax.fori_loop(1, N, _rsf2csf_iter, (T_arr, Z_arr))


@overload
def hessenberg(a: ArrayLike, *, calc_q: Literal[False], overwrite_a: bool = False,
               check_finite: bool = True) -> Array: ...

@overload
def hessenberg(a: ArrayLike, *, calc_q: Literal[True], overwrite_a: bool = False,
               check_finite: bool = True) -> tuple[Array, Array]: ...

@partial(jit, static_argnames=('calc_q', 'check_finite', 'overwrite_a'))
def hessenberg(a: ArrayLike, *, calc_q: bool = False, overwrite_a: bool = False,
               check_finite: bool = True) -> Array | tuple[Array, Array]:
  """Compute the Hessenberg form of a matrix.

  JAX implementation of :func:`scipy.linalg.hessenberg`.

  The Hessenberg form `H` of a matrix `A` satisfies:

  .. math::

     A = Q H Q^H

  where `Q` is unitary and `H` is zero below the first subdiagonal.

  Args:
    a: array of shape ``(..., N, N)``
    calc_q: if True, calculate the ``Q`` matrix (default: False)
    overwrite_a: unused by JAX
    check_finite: unused by JAX

  Returns:
    A tuple of arrays ``(H, Q)`` if ``calc_q`` is True, else an array ``H``

    - ``H`` has shape ``(..., N, N)`` and is the Hessenberg form of ``a``
    - ``Q`` has shape ``(..., N, N)`` and is the associated unitary matrix

  Examples:
    Computing the Hessenberg form of a 4x4 matrix:

    >>> a = jnp.array([[1., 2., 3., 4.],
    ...                [1., 4., 2., 3.],
    ...                [3., 2., 1., 4.],
    ...                [2., 3., 2., 2.]])
    >>> H, Q = jax.scipy.linalg.hessenberg(a, calc_q=True)
    >>> with jnp.printoptions(suppress=True, precision=3):
    ...   print(H)
    [[ 1.    -5.078  1.167  1.361]
     [-3.742  5.786 -3.613 -1.825]
     [ 0.    -2.992  2.493 -0.577]
     [ 0.     0.    -0.043 -1.279]]

    Notice the zeros below the first subdiagonal. The original matrix can be
    reconstructed using ``Q``:

    >>> a_reconstructed = Q @ H @ Q.conj().T
    >>> jnp.allclose(a_reconstructed, a)
    Array(True, dtype=bool)
  """
  del overwrite_a, check_finite  # unused
  n = jnp.shape(a)[-1]
  if n == 0:
    if calc_q:
      return jnp.zeros_like(a), jnp.zeros_like(a)
    else:
      return jnp.zeros_like(a)
  a_out, taus = lax_linalg.hessenberg(a)
  h = jnp.triu(a_out, -1)
  if calc_q:
    q = lax_linalg.householder_product(a_out[..., 1:, :-1], taus)
    batch_dims = a_out.shape[:-2]
    q = jnp.block([[jnp.ones(batch_dims + (1, 1), dtype=a_out.dtype),
                    jnp.zeros(batch_dims + (1, n - 1), dtype=a_out.dtype)],
                   [jnp.zeros(batch_dims + (n - 1, 1), dtype=a_out.dtype),
                    q]])
    return h, q
  else:
    return h


def toeplitz(c: ArrayLike, r: ArrayLike | None = None) -> Array:
  r"""Construct a Toeplitz matrix.

  JAX implementation of :func:`scipy.linalg.toeplitz`.

  A Toeplitz matrix has equal diagonals: :math:`A_{ij} = k_{i - j}` for
  :math:`0 \le i < n` and :math:`0 \le j < n`. This function specifies the
  diagonals via the first column ``c`` and the first row ``r``, such that for
  row `i` and column `j`:

  .. math::

     A_{ij} = \begin{cases}
       c_{i - j} & i \ge j \\
       r_{j - i} & i < j
     \end{cases}

  Notice this implies that :math:`r_0` is ignored.

  Args:
    c: array of shape ``(..., N)`` specifying the first column.
    r: (optional) array of shape ``(..., M)`` specifying the first row. Leading
      dimensions must be broadcast-compatible with those of ``c``. If not
      specified, ``r`` defaults to ``conj(c)``.

  Returns:
    A Toeplitz matrix of shape ``(..., N, M)``.

  Examples:
    Specifying ``c`` only:

    >>> c = jnp.array([1, 2, 3])
    >>> jax.scipy.linalg.toeplitz(c)
    Array([[1, 2, 3],
           [2, 1, 2],
           [3, 2, 1]], dtype=int32)

    Specifying ``c`` and ``r``:

    >>> r = jnp.array([-1, -2, -3])
    >>> jax.scipy.linalg.toeplitz(c, r)  # Note r[0] is ignored
    Array([[ 1, -2, -3],
           [ 2,  1, -2],
           [ 3,  2,  1]], dtype=int32)

    If specifying only complex-valued ``c``, ``r`` defaults to ``c.conj()``,
    resulting in a Hermitian matrix if ``c[0].imag == 0``:

    >>> c = jnp.array([1, 2+1j, 1+2j])
    >>> M = jax.scipy.linalg.toeplitz(c)
    >>> M
    Array([[1.+0.j, 2.-1.j, 1.-2.j],
           [2.+1.j, 1.+0.j, 2.-1.j],
           [1.+2.j, 2.+1.j, 1.+0.j]], dtype=complex64)
    >>> print("M is Hermitian:", jnp.all(M == M.conj().T))
    M is Hermitian: True

    For N-dimensional ``c`` and/or ``r``, the result is a batch of Toeplitz
    matrices:

    >>> c = jnp.array([[1, 2, 3], [4, 5, 6]])
    >>> jax.scipy.linalg.toeplitz(c)
    Array([[[1, 2, 3],
            [2, 1, 2],
            [3, 2, 1]],
    <BLANKLINE>
           [[4, 5, 6],
            [5, 4, 5],
            [6, 5, 4]]], dtype=int32)
  """
  if r is None:
    check_arraylike("toeplitz", c)
    r = jnp.conjugate(jnp.asarray(c))
  else:
    check_arraylike("toeplitz", c, r)
  return _toeplitz(jnp.atleast_1d(jnp.asarray(c)),
                   jnp.atleast_1d(jnp.asarray(r)))


@partial(jnp.vectorize, signature="(m),(n)->(m,n)")
def _toeplitz(c: Array, r: Array) -> Array:
  # ``c`` holds the first column (its length is the number of output rows) and
  # ``r`` the first row (its length is the number of output columns).
  nrows, = c.shape
  ncols, = r.shape
  if nrows == 0 or ncols == 0:
    return jnp.empty((nrows, ncols),
                     dtype=jnp.promote_types(c.dtype, r.dtype))
  nelems = nrows + ncols - 1
  # Every entry A[i, j] equals elems[nrows - 1 - i + j], so the rows of A are
  # length-ncols sliding windows over elems, taken in reverse order.
  elems = jnp.concatenate((c[::-1], r[1:]))
  patches = lax.conv_general_dilated_patches(
      elems.reshape((1, nelems, 1)), (ncols,), (1,), 'VALID',
      dimension_numbers=('NTC', 'IOT', 'NTC'),
      precision=lax.Precision.HIGHEST)[0]
  return jnp.flip(patches, axis=0)


@partial(jit, static_argnames=("n",))
def hilbert(n: int) -> Array:
  r"""Create a Hilbert matrix of order n.

  JAX implementation of :func:`scipy.linalg.hilbert`.

  The Hilbert matrix is defined by:

  .. math::

     H_{ij} = \frac{1}{i + j + 1}

  for :math:`0 \le i < n` and :math:`0 \le j < n`.

  Args:
    n: the size of the matrix to create.

  Returns:
    A Hilbert matrix of shape ``(n, n)``.

  Examples:
    >>> jax.scipy.linalg.hilbert(2)
    Array([[1.        , 0.5       ],
           [0.5       , 0.33333334]], dtype=float32)
    >>> jax.scipy.linalg.hilbert(3)
    Array([[1.        , 0.5       , 0.33333334],
           [0.5       , 0.33333334, 0.25      ],
           [0.33333334, 0.25      , 0.2       ]], dtype=float32)
  """
  a = lax.broadcasted_iota(jnp.float64, (n, 1), 0)
  return 1 / (a + a.T + 1)
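

# The block below is a minimal usage sketch, not part of the library API: the
# ``__main__`` guard and the tolerances are illustrative assumptions. It
# exercises a few of the routines above against their defining identities.
if __name__ == "__main__":
  a = jnp.array([[1., 2., 3.],
                 [5., 4., 2.],
                 [3., 2., 1.]])
  # Right polar decomposition: a == unitary @ posdef.
  unitary, posdef = polar(a)
  assert jnp.allclose(unitary @ posdef, a, atol=1e-3)
  # Matrix square root: sqrt_a @ sqrt_a recovers a (possibly complex).
  sqrt_a = sqrtm(a)
  assert jnp.allclose(sqrt_a @ sqrt_a, a, atol=1e-3)
  # Toeplitz matrix from a first column and first row; r[0] is ignored.
  t = toeplitz(jnp.array([1., 2., 3.]), jnp.array([9., -2., -3.]))
  assert t[0, 0] == 1.0 and t[0, 1] == -2.0
  # Hilbert matrix entries satisfy H[i, j] == 1 / (i + j + 1).
  h = hilbert(3)
  assert jnp.allclose(h[1, 2], 0.25)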