# Copyright 2018 The JAX Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import annotations

from collections.abc import Sequence
from functools import partial
import itertools
import math

import numpy as np
import operator
from typing import Literal, NamedTuple, overload

import jax
from jax import jit, custom_jvp
from jax import lax

from jax._src import deprecations
from jax._src.lax import lax as lax_internal
from jax._src.lax.lax import PrecisionLike
from jax._src.lax import linalg as lax_linalg
from jax._src.numpy import lax_numpy as jnp
from jax._src.numpy import reductions, ufuncs
from jax._src.numpy.util import promote_dtypes_inexact, check_arraylike
from jax._src.util import canonicalize_axis, set_module
from jax._src.typing import ArrayLike, Array, DTypeLike, DeprecatedArg


export = set_module('jax.numpy.linalg')


class EighResult(NamedTuple):
  eigenvalues: jax.Array
  eigenvectors: jax.Array


class QRResult(NamedTuple):
  Q: jax.Array
  R: jax.Array


class SlogdetResult(NamedTuple):
  sign: jax.Array
  logabsdet: jax.Array


class SVDResult(NamedTuple):
  U: jax.Array
  S: jax.Array
  Vh: jax.Array


def _H(x: ArrayLike) -> Array:
  return ufuncs.conjugate(jnp.matrix_transpose(x))


def _symmetrize(x: Array) -> Array: return (x + _H(x)) / 2


@export
@partial(jit, static_argnames=['upper'])
def cholesky(a: ArrayLike, *, upper: bool = False) -> Array:
  """Compute the Cholesky decomposition of a matrix.

  JAX implementation of :func:`numpy.linalg.cholesky`.

  The Cholesky decomposition of a matrix `A` is:

  .. math::

     A = U^HU

  or

  .. math::

     A = LL^H

  where `U` is an upper-triangular matrix and `L` is a lower-triangular matrix, and
  :math:`X^H` is the Hermitian transpose of `X`.

  Args:
    a: input array, representing a (batched) positive-definite hermitian matrix.
      Must have shape ``(..., N, N)``.
    upper: if True, compute the upper Cholesky decomposition `U`. If False
      (default), compute the lower Cholesky decomposition `L`.

  Returns:
    array of shape ``(..., N, N)`` representing the Cholesky decomposition
    of the input. If the input is not Hermitian positive-definite, the result
    will contain NaN entries.

  See also:
    - :func:`jax.scipy.linalg.cholesky`: SciPy-style Cholesky API
    - :func:`jax.lax.linalg.cholesky`: XLA-style Cholesky API

  Examples:
    A small real Hermitian positive-definite matrix:

    >>> x = jnp.array([[2., 1.],
    ...                [1., 2.]])

    Lower Cholesky factorization:

    >>> jnp.linalg.cholesky(x)
    Array([[1.4142135 , 0.        ],
           [0.70710677, 1.2247449 ]], dtype=float32)

    Upper Cholesky factorization:

    >>> jnp.linalg.cholesky(x, upper=True)
    Array([[1.4142135 , 0.70710677],
           [0.        , 1.2247449 ]], dtype=float32)

    Reconstructing ``x`` from its factorization:

    >>> L = jnp.linalg.cholesky(x)
    >>> jnp.allclose(x, L @ L.T)
    Array(True, dtype=bool)
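
    The upper factorization can be verified the same way; for this real-valued
    input, :math:`U^H = U^T`:

    >>> U = jnp.linalg.cholesky(x, upper=True)
    >>> jnp.allclose(x, U.T @ U)
    Array(True, dtype=bool)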
  """
  check_arraylike("jnp.linalg.cholesky", a)
  a, = promote_dtypes_inexact(jnp.asarray(a))
  L = lax_linalg.cholesky(a)
  return L.mT.conj() if upper else L


@overload
def svd(
    a: ArrayLike,
    full_matrices: bool = True,
    *,
    compute_uv: Literal[True],
    hermitian: bool = False,
    subset_by_index: tuple[int, int] | None = None,
) -> SVDResult:
  ...


@overload
def svd(
    a: ArrayLike,
    full_matrices: bool,
    compute_uv: Literal[True],
    hermitian: bool = False,
    subset_by_index: tuple[int, int] | None = None,
) -> SVDResult:
  ...


@overload
def svd(
    a: ArrayLike,
    full_matrices: bool = True,
    *,
    compute_uv: Literal[False],
    hermitian: bool = False,
    subset_by_index: tuple[int, int] | None = None,
) -> Array:
  ...


@overload
def svd(
    a: ArrayLike,
    full_matrices: bool,
    compute_uv: Literal[False],
    hermitian: bool = False,
    subset_by_index: tuple[int, int] | None = None,
) -> Array:
  ...


@overload
def svd(
    a: ArrayLike,
    full_matrices: bool = True,
    compute_uv: bool = True,
    hermitian: bool = False,
    subset_by_index: tuple[int, int] | None = None,
) -> Array | SVDResult:
  ...


@export
@partial(
    jit,
    static_argnames=(
        "full_matrices",
        "compute_uv",
        "hermitian",
        "subset_by_index",
    ),
)
def svd(
    a: ArrayLike,
    full_matrices: bool = True,
    compute_uv: bool = True,
    hermitian: bool = False,
    subset_by_index: tuple[int, int] | None = None,
) -> Array | SVDResult:
  r"""Compute the singular value decomposition.

  JAX implementation of :func:`numpy.linalg.svd`, implemented in terms of
  :func:`jax.lax.linalg.svd`.

  The SVD of a matrix `A` is given by

  .. math::

     A = U\Sigma V^H

  - :math:`U` contains the left singular vectors and satisfies :math:`U^HU=I`
  - :math:`V` contains the right singular vectors and satisfies :math:`V^HV=I`
  - :math:`\Sigma` is a diagonal matrix of singular values.

  Args:
    a: input array of shape ``(..., N, M)``
    full_matrices: if True (default) compute the full matrices; i.e. ``u`` and ``vh`` have
      shape ``(..., N, N)`` and ``(..., M, M)``. If False, then the shapes are
      ``(..., N, K)`` and ``(..., K, M)`` with ``K = min(N, M)``.
    compute_uv: if True (default), return the full SVD ``(u, s, vh)``. If False then return
      only the singular values ``s``.
    hermitian: if True, assume the matrix is hermitian, which allows for a more efficient
      implementation (default=False)
    subset_by_index: (TPU-only) Optional 2-tuple [start, end] indicating the range of
      indices of singular values to compute. For example, if ``[n-2, n]`` then
      ``svd`` computes the two largest singular values and their singular vectors.
      Only compatible with ``full_matrices=False``.

  Returns:
    A tuple of arrays ``(u, s, vh)`` if ``compute_uv`` is True, otherwise the array ``s``.

    - ``u``: left singular vectors of shape ``(..., N, N)`` if ``full_matrices`` is True
      or ``(..., N, K)`` otherwise.
    - ``s``: singular values of shape ``(..., K)``
    - ``vh``: conjugate-transposed right singular vectors of shape ``(..., M, M)``
      if ``full_matrices`` is True or ``(..., K, M)`` otherwise.

    where ``K = min(N, M)``.

  See also:
    - :func:`jax.scipy.linalg.svd`: SciPy-style SVD API
    - :func:`jax.lax.linalg.svd`: XLA-style SVD API

  Examples:
    Consider the SVD of a small real-valued array:

    >>> x = jnp.array([[1., 2., 3.],
    ...                [6., 5., 4.]])
    >>> u, s, vt = jnp.linalg.svd(x, full_matrices=False)
    >>> s  # doctest: +SKIP
    Array([9.361919 , 1.8315067], dtype=float32)

    The singular vectors are in the columns of ``u`` and ``v = vt.T``. These vectors are
    orthonormal, which can be demonstrated by comparing the matrix product with the
    identity matrix:

    >>> jnp.allclose(u.T @ u, jnp.eye(2), atol=1E-5)
    Array(True, dtype=bool)
    >>> v = vt.T
    >>> jnp.allclose(v.T @ v, jnp.eye(2), atol=1E-5)
    Array(True, dtype=bool)

    Given the SVD, ``x`` can be reconstructed via matrix multiplication:

    >>> x_reconstructed = u @ jnp.diag(s) @ vt
    >>> jnp.allclose(x_reconstructed, x)
    Array(True, dtype=bool)
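
    If only the singular values are needed, ``compute_uv=False`` returns them
    directly; they match the ``s`` computed above:

    >>> s_only = jnp.linalg.svd(x, compute_uv=False)
    >>> jnp.allclose(s_only, s)
    Array(True, dtype=bool)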
  """
  check_arraylike("jnp.linalg.svd", a)
  a, = promote_dtypes_inexact(jnp.asarray(a))
  if hermitian:
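    # Note on naming: lax_linalg.eigh returns (eigenvectors, eigenvalues), so
    # here `w` receives the eigenvectors and `v` the eigenvalues; the absolute
    # values of the eigenvalues are the singular values.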
    w, v = lax_linalg.eigh(a, subset_by_index=subset_by_index)
    s = lax.abs(v)
    if compute_uv:
      sign = lax.sign(v)
      idxs = lax.broadcasted_iota(np.int64, s.shape, dimension=s.ndim - 1)
      s, idxs, sign = lax.sort((s, idxs, sign), dimension=-1, num_keys=1)
      s = lax.rev(s, dimensions=[s.ndim - 1])
      idxs = lax.rev(idxs, dimensions=[s.ndim - 1])
      sign = lax.rev(sign, dimensions=[s.ndim - 1])
      u = jnp.take_along_axis(w, idxs[..., None, :], axis=-1)
      vh = _H(u * sign[..., None, :].astype(u.dtype))
      return SVDResult(u, s, vh)
    else:
      return lax.rev(lax.sort(s, dimension=-1), dimensions=[s.ndim-1])

  if compute_uv:
    u, s, vh = lax_linalg.svd(
        a,
        full_matrices=full_matrices,
        compute_uv=True,
        subset_by_index=subset_by_index,
    )
    return SVDResult(u, s, vh)
  else:
    return lax_linalg.svd(
        a,
        full_matrices=full_matrices,
        compute_uv=False,
        subset_by_index=subset_by_index,
    )


@export
@partial(jit, static_argnames=('n',))
def matrix_power(a: ArrayLike, n: int) -> Array:
  """Raise a square matrix to an integer power.

  JAX implementation of :func:`numpy.linalg.matrix_power`, implemented via
  repeated squarings.

  Args:
    a: array of shape ``(..., M, M)`` to be raised to the power `n`.
    n: the integer exponent to which the matrix should be raised.

  Returns:
    Array of shape ``(..., M, M)`` containing ``a`` raised to the power ``n``.

  Examples:
    >>> a = jnp.array([[1., 2.],
    ...                [3., 4.]])
    >>> jnp.linalg.matrix_power(a, 3)
    Array([[ 37.,  54.],
           [ 81., 118.]], dtype=float32)
    >>> a @ a @ a  # equivalent evaluated directly
    Array([[ 37.,  54.],
           [ 81., 118.]], dtype=float32)

    This also supports zero powers:

    >>> jnp.linalg.matrix_power(a, 0)
    Array([[1., 0.],
           [0., 1.]], dtype=float32)

    and also supports negative powers:

    >>> with jnp.printoptions(precision=3):
    ...   jnp.linalg.matrix_power(a, -2)
    Array([[ 5.5 , -2.5 ],
           [-3.75,  1.75]], dtype=float32)

    Negative powers are equivalent to matmul of the inverse:

    >>> inv_a = jnp.linalg.inv(a)
    >>> with jnp.printoptions(precision=3):
    ...   inv_a @ inv_a
    Array([[ 5.5 , -2.5 ],
           [-3.75,  1.75]], dtype=float32)
  """
  check_arraylike("jnp.linalg.matrix_power", a)
  arr, = promote_dtypes_inexact(jnp.asarray(a))

  if arr.ndim < 2:
    raise TypeError("{}-dimensional array given. Array must be at least "
                    "two-dimensional".format(arr.ndim))
  if arr.shape[-2] != arr.shape[-1]:
    raise TypeError("Last 2 dimensions of the array must be square")
  try:
    n = operator.index(n)
  except TypeError as err:
    raise TypeError(f"exponent must be an integer, got {n}") from err

  if n == 0:
    return jnp.broadcast_to(jnp.eye(arr.shape[-2], dtype=arr.dtype), arr.shape)
  elif n < 0:
    arr = inv(arr)
    n = abs(n)

  if n == 1:
    return arr
  elif n == 2:
    return arr @ arr
  elif n == 3:
    return (arr @ arr) @ arr
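
  # Exponentiation by squaring: each pass through the loop peels off one bit
  # of n; z holds arr**(2**i) at step i, and result accumulates the factors
  # for the set bits. For example, n = 13 = 0b1101 multiplies together
  # arr**1, arr**4, and arr**8.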
  z = result = None
  while n > 0:
    z = arr if z is None else (z @ z)  # type: ignore[operator]
    n, bit = divmod(n, 2)
    if bit:
      result = z if result is None else (result @ z)
  assert result is not None
  return result


@export
@jit
def matrix_rank(
    M: ArrayLike, rtol: ArrayLike | None = None, *,
    tol: ArrayLike | DeprecatedArg | None = DeprecatedArg()) -> Array:
  """Compute the rank of a matrix.

  JAX implementation of :func:`numpy.linalg.matrix_rank`.

  The rank is calculated via the Singular Value Decomposition (SVD), and determined
  by the number of singular values greater than the specified tolerance.

  Args:
    M: array of shape ``(..., N, K)`` whose rank is to be computed.
    rtol: optional array of shape ``(...)`` specifying the tolerance. Singular values
      smaller than ``rtol * largest_singular_value`` are considered to be zero. If
      ``rtol`` is None (the default), a reasonable default is chosen based on the
      floating point precision of the input.
    tol: deprecated alias of the ``rtol`` argument. Will result in a
      :class:`DeprecationWarning` if used.

  Returns:
    array of shape ``M.shape[:-2]`` giving the matrix rank.

  Notes:
    The rank calculation may be inaccurate for matrices with very small singular
    values or those that are numerically ill-conditioned. Consider adjusting the
    ``rtol`` parameter or using a more specialized rank computation method in such cases.

  Examples:
    >>> a = jnp.array([[1, 2],
    ...                [3, 4]])
    >>> jnp.linalg.matrix_rank(a)
    Array(2, dtype=int32)

    >>> b = jnp.array([[1, 0],  # Rank-deficient matrix
    ...                [0, 0]])
    >>> jnp.linalg.matrix_rank(b)
    Array(1, dtype=int32)
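
    The computation is batched over leading dimensions; stacking the two
    matrices above gives both ranks at once:

    >>> jnp.linalg.matrix_rank(jnp.stack([a, b]))
    Array([2, 1], dtype=int32)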
  """
  check_arraylike("jnp.linalg.matrix_rank", M)
  # TODO(micky774): deprecated 2024-5-14, remove after deprecation expires.
  if not isinstance(tol, DeprecatedArg):
    rtol = tol
    del tol
    deprecations.warn(
      "jax-numpy-linalg-matrix_rank-tol",
      ("The tol argument for linalg.matrix_rank is deprecated. "
       "Please use rtol instead."),
      stacklevel=2
    )
  M, = promote_dtypes_inexact(jnp.asarray(M))
  if M.ndim < 2:
    return (M != 0).any().astype(jnp.int32)
  S = svd(M, full_matrices=False, compute_uv=False)
  if rtol is None:
    rtol = S.max(-1) * np.max(M.shape[-2:]).astype(S.dtype) * jnp.finfo(S.dtype).eps
  rtol = jnp.expand_dims(rtol, np.ndim(rtol))
  return reductions.sum(S > rtol, axis=-1)


@custom_jvp
def _slogdet_lu(a: Array) -> tuple[Array, Array]:
  dtype = lax.dtype(a)
  lu, pivot, _ = lax_linalg.lu(a)
  diag = jnp.diagonal(lu, axis1=-2, axis2=-1)
  is_zero = reductions.any(diag == jnp.array(0, dtype=dtype), axis=-1)
  iota = lax.expand_dims(jnp.arange(a.shape[-1], dtype=pivot.dtype),
                         range(pivot.ndim - 1))
  parity = reductions.count_nonzero(pivot != iota, axis=-1)
  if jnp.iscomplexobj(a):
    sign = reductions.prod(diag / ufuncs.abs(diag).astype(diag.dtype), axis=-1)
  else:
    sign = jnp.array(1, dtype=dtype)
    parity = parity + reductions.count_nonzero(diag < 0, axis=-1)
  sign = jnp.where(is_zero,
                   jnp.array(0, dtype=dtype),
                   sign * jnp.array(-2 * (parity % 2) + 1, dtype=dtype))
  logdet = jnp.where(
      is_zero, jnp.array(-jnp.inf, dtype=dtype),
      reductions.sum(ufuncs.log(ufuncs.abs(diag)).astype(dtype), axis=-1))
  return sign, ufuncs.real(logdet)


@custom_jvp
def _slogdet_qr(a: Array) -> tuple[Array, Array]:
  # Implementation of slogdet using QR decomposition. One reason we might prefer
  # QR decomposition is that it is more amenable to a fast batched
  # implementation on TPU because of the lack of row pivoting.
  if jnp.issubdtype(lax.dtype(a), jnp.complexfloating):
    raise NotImplementedError("slogdet method='qr' not implemented for complex "
                              "inputs")
  n = a.shape[-1]
  a, taus = lax_linalg.geqrf(a)
  # The determinant of a triangular matrix is the product of its diagonal
  # elements. We are working in log space, so we compute the magnitude as the
  # sum of the log-absolute values of the diagonal, and we compute the sign
  # separately.
  a_diag = jnp.diagonal(a, axis1=-2, axis2=-1)
  log_abs_det = reductions.sum(ufuncs.log(ufuncs.abs(a_diag)), axis=-1)
  sign_diag = reductions.prod(ufuncs.sign(a_diag), axis=-1)
  # The determinant of a Householder reflector is -1. So whenever we actually
  # made a reflection (tau != 0), multiply the result by -1.
  sign_taus = reductions.prod(jnp.where(taus[..., :(n-1)] != 0, -1, 1), axis=-1).astype(sign_diag.dtype)
  return sign_diag * sign_taus, log_abs_det


@export
@partial(jit, static_argnames=('method',))
def slogdet(a: ArrayLike, *, method: str | None = None) -> SlogdetResult:
  """
  Compute the sign and (natural) logarithm of the determinant of an array.

  JAX implementation of :func:`numpy.linalg.slogdet`.

  Args:
    a: array of shape ``(..., M, M)`` for which to compute the sign and log determinant.
    method: the method to use for determinant computation. Options are

      - ``'lu'`` (default): use the LU decomposition.
      - ``'qr'``: use the QR decomposition.

  Returns:
    A tuple of arrays ``(sign, logabsdet)``, each of shape ``a.shape[:-2]``

    - ``sign`` is the sign of the determinant.
    - ``logabsdet`` is the natural log of the determinant's absolute value.

  See also:
    :func:`jax.numpy.linalg.det`: direct computation of determinant

  Examples:
    >>> a = jnp.array([[1, 2],
    ...                [3, 4]])
    >>> sign, logabsdet = jnp.linalg.slogdet(a)
    >>> sign  # -1 indicates negative determinant
    Array(-1., dtype=float32)
    >>> jnp.exp(logabsdet)  # Absolute value of determinant
    Array(2., dtype=float32)
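
    The determinant itself is recovered as ``sign * exp(logabsdet)``:

    >>> sign * jnp.exp(logabsdet)
    Array(-2., dtype=float32)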
  """
  check_arraylike("jnp.linalg.slogdet", a)
  a, = promote_dtypes_inexact(jnp.asarray(a))
  a_shape = jnp.shape(a)
  if len(a_shape) < 2 or a_shape[-1] != a_shape[-2]:
    raise ValueError(f"Argument to slogdet() must have shape [..., n, n], got {a_shape}")
  if method is None or method == "lu":
    return SlogdetResult(*_slogdet_lu(a))
  elif method == "qr":
    return SlogdetResult(*_slogdet_qr(a))
  else:
    raise ValueError(f"Unknown slogdet method '{method}'. Supported methods "
                     "are 'lu' (`None`), and 'qr'.")


def _slogdet_jvp(primals, tangents):
  x, = primals
  g, = tangents
  sign, ans = slogdet(x)
  ans_dot = jnp.trace(solve(x, g), axis1=-1, axis2=-2)
  if jnp.issubdtype(jnp._dtype(x), jnp.complexfloating):
    sign_dot = (ans_dot - ufuncs.real(ans_dot).astype(ans_dot.dtype)) * sign
    ans_dot = ufuncs.real(ans_dot)
  else:
    sign_dot = jnp.zeros_like(sign)
  return (sign, ans), (sign_dot, ans_dot)

_slogdet_lu.defjvp(_slogdet_jvp)
_slogdet_qr.defjvp(_slogdet_jvp)


def _cofactor_solve(a: ArrayLike, b: ArrayLike) -> tuple[Array, Array]:
  """Equivalent to det(a)*solve(a, b) for nonsingular matrices.

  Intermediate function used for jvp and vjp of det.
  This function borrows heavily from jax.numpy.linalg.solve and
  jax.numpy.linalg.slogdet to compute the gradient of the determinant
  in a way that is well defined even for low rank matrices.

  This function handles two different cases:
  * rank(a) == n or n-1
  * rank(a) < n-1

  For rank n-1 matrices, the gradient of the determinant is a rank 1 matrix.
  Rather than computing det(a)*solve(a, b), which would return NaN, we work
  directly with the LU decomposition. If a = p @ l @ u, then
  det(a)*solve(a, b) =
  prod(diag(u)) * u^-1 @ l^-1 @ p^-1 b =
  prod(diag(u)) * triangular_solve(u, solve(p @ l, b))
  If a is rank n-1, then the lower right corner of u will be zero and the
  triangular_solve will fail.
  Let x = solve(p @ l, b) and y = det(a)*solve(a, b).
  Then y_{n} =
    x_{n} / u_{nn} * prod_{i=1...n}(u_{ii}) =
    x_{n} * prod_{i=1...n-1}(u_{ii})
  So by replacing the lower-right corner of u with prod_{i=1...n-1}(u_{ii})^-1
  we can avoid the triangular_solve failing.
  To correctly compute the rest of y_{i} for i != n, we simply multiply
  x_{i} by det(a) for all i != n, which will be zero if rank(a) = n-1.

  For the second case, a check is done on the matrix to see if `solve`
  returns NaN or Inf, and gives a matrix of zeros as a result, as the
  gradient of the determinant of a matrix with rank less than n-1 is 0.
  This will still return the correct value for rank n-1 matrices, as the check
  is applied *after* the lower right corner of u has been updated.

  Args:
    a: A square matrix or batch of matrices, possibly singular.
    b: A matrix, or batch of matrices of the same dimension as a.

  Returns:
    det(a) and cofactor(a)^T*b, aka adjugate(a)*b
  """
  a, = promote_dtypes_inexact(jnp.asarray(a))
  b, = promote_dtypes_inexact(jnp.asarray(b))
  a_shape = jnp.shape(a)
  b_shape = jnp.shape(b)
  a_ndims = len(a_shape)
  if not (a_ndims >= 2 and a_shape[-1] == a_shape[-2]
          and b_shape[-2:] == a_shape[-2:]):
    msg = ("The arguments to _cofactor_solve must have shapes "
           "a=[..., m, m] and b=[..., m, m]; got a={} and b={}")
    raise ValueError(msg.format(a_shape, b_shape))
  if a_shape[-1] == 1:
    return a[..., 0, 0], b
  # lu contains u in the upper triangular matrix and l in the strict lower
  # triangular matrix.
  # The diagonal of l is set to ones without loss of generality.
  lu, pivots, permutation = lax_linalg.lu(a)
  dtype = lax.dtype(a)
  batch_dims = lax.broadcast_shapes(lu.shape[:-2], b.shape[:-2])
  x = jnp.broadcast_to(b, batch_dims + b.shape[-2:])
  lu = jnp.broadcast_to(lu, batch_dims + lu.shape[-2:])
  # Compute (partial) determinant, ignoring last diagonal of LU
  diag = jnp.diagonal(lu, axis1=-2, axis2=-1)
  iota = lax.expand_dims(jnp.arange(a_shape[-1], dtype=pivots.dtype),
                         range(pivots.ndim - 1))
  parity = reductions.count_nonzero(pivots != iota, axis=-1)
  sign = jnp.asarray(-2 * (parity % 2) + 1, dtype=dtype)
  # partial_det[:, -1] contains the full determinant and
  # partial_det[:, -2] contains det(u) / u_{nn}.
  partial_det = reductions.cumprod(diag, axis=-1) * sign[..., None]
  lu = lu.at[..., -1, -1].set(1.0 / partial_det[..., -2])
  permutation = jnp.broadcast_to(permutation, (*batch_dims, a_shape[-1]))
  iotas = jnp.ix_(*(lax.iota(jnp.int32, b) for b in (*batch_dims, 1)))
  # filter out any matrices that are not full rank
  d = jnp.ones(x.shape[:-1], x.dtype)
  d = lax_linalg.triangular_solve(lu, d, left_side=True, lower=False)
  d = reductions.any(ufuncs.logical_or(ufuncs.isnan(d), ufuncs.isinf(d)), axis=-1)
  d = jnp.tile(d[..., None, None], d.ndim*(1,) + x.shape[-2:])
  x = jnp.where(d, jnp.zeros_like(x), x)  # first filter
  x = x[iotas[:-1] + (permutation, slice(None))]
  x = lax_linalg.triangular_solve(lu, x, left_side=True, lower=True,
                                  unit_diagonal=True)
  x = jnp.concatenate((x[..., :-1, :] * partial_det[..., -1, None, None],
                       x[..., -1:, :]), axis=-2)
  x = lax_linalg.triangular_solve(lu, x, left_side=True, lower=False)
  x = jnp.where(d, jnp.zeros_like(x), x)  # second filter

  return partial_det[..., -1], x


def _det_2x2(a: Array) -> Array:
  return (a[..., 0, 0] * a[..., 1, 1] -
          a[..., 0, 1] * a[..., 1, 0])


def _det_3x3(a: Array) -> Array:
  return (a[..., 0, 0] * a[..., 1, 1] * a[..., 2, 2] +
          a[..., 0, 1] * a[..., 1, 2] * a[..., 2, 0] +
          a[..., 0, 2] * a[..., 1, 0] * a[..., 2, 1] -
          a[..., 0, 2] * a[..., 1, 1] * a[..., 2, 0] -
          a[..., 0, 0] * a[..., 1, 2] * a[..., 2, 1] -
          a[..., 0, 1] * a[..., 1, 0] * a[..., 2, 2])


@custom_jvp
def _det(a):
  sign, logdet = slogdet(a)
  return sign * ufuncs.exp(logdet).astype(sign.dtype)


@_det.defjvp
def _det_jvp(primals, tangents):
  x, = primals
  g, = tangents
  y, z = _cofactor_solve(x, g)
  return y, jnp.trace(z, axis1=-1, axis2=-2)


@export
@jit
def det(a: ArrayLike) -> Array:
  """
  Compute the determinant of an array.

  JAX implementation of :func:`numpy.linalg.det`.

  Args:
    a: array of shape ``(..., M, M)`` for which to compute the determinant.

  Returns:
    An array of determinants of shape ``a.shape[:-2]``.

  See also:
    :func:`jax.scipy.linalg.det`: SciPy-style API for determinant.

  Examples:
    >>> a = jnp.array([[1, 2],
    ...                [3, 4]])
    >>> jnp.linalg.det(a)
    Array(-2., dtype=float32)
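
    Determinants of a stack of matrices are computed in one batched call:

    >>> a_batch = jnp.array([[[1., 2.],
    ...                       [3., 4.]],
    ...                      [[5., 6.],
    ...                       [7., 8.]]])
    >>> jnp.linalg.det(a_batch)
    Array([-2., -2.], dtype=float32)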
  """
  check_arraylike("jnp.linalg.det", a)
  a, = promote_dtypes_inexact(jnp.asarray(a))
  a_shape = jnp.shape(a)
  if len(a_shape) >= 2 and a_shape[-1] == 2 and a_shape[-2] == 2:
    return _det_2x2(a)
  elif len(a_shape) >= 2 and a_shape[-1] == 3 and a_shape[-2] == 3:
    return _det_3x3(a)
  elif len(a_shape) >= 2 and a_shape[-1] == a_shape[-2]:
    return _det(a)
  else:
    msg = "Argument to det() must have shape [..., n, n], got {}"
    raise ValueError(msg.format(a_shape))


@export
def eig(a: ArrayLike) -> tuple[Array, Array]:
  """
  Compute the eigenvalues and eigenvectors of a square array.

  JAX implementation of :func:`numpy.linalg.eig`.

  Args:
    a: array of shape ``(..., M, M)`` for which to compute the eigenvalues and vectors.

  Returns:
    A tuple ``(eigenvalues, eigenvectors)`` with

    - ``eigenvalues``: an array of shape ``(..., M)`` containing the eigenvalues.
    - ``eigenvectors``: an array of shape ``(..., M, M)``, where column ``v[:, i]`` is the
      eigenvector corresponding to the eigenvalue ``w[i]``.

  Notes:
    - This differs from :func:`numpy.linalg.eig` in that the return type of
      :func:`jax.numpy.linalg.eig` is always complex64 for 32-bit input, and complex128
      for 64-bit input.
    - At present, non-symmetric eigendecomposition is only implemented on the CPU and
      GPU backends. For more details about the GPU implementation, see the
      documentation for :func:`jax.lax.linalg.eig`.

  See also:
    - :func:`jax.numpy.linalg.eigh`: eigenvectors and eigenvalues of a Hermitian matrix.
    - :func:`jax.numpy.linalg.eigvals`: compute eigenvalues only.

  Examples:
    >>> a = jnp.array([[1., 2.],
    ...                [2., 1.]])
    >>> w, v = jnp.linalg.eig(a)
    >>> with jax.numpy.printoptions(precision=4):
    ...   w
    Array([ 3.+0.j, -1.+0.j], dtype=complex64)
    >>> v
    Array([[ 0.70710677+0.j, -0.70710677+0.j],
           [ 0.70710677+0.j,  0.70710677+0.j]], dtype=complex64)
  """
  check_arraylike("jnp.linalg.eig", a)
  a, = promote_dtypes_inexact(jnp.asarray(a))
  w, v = lax_linalg.eig(a, compute_left_eigenvectors=False)
  return w, v


@export
@jit
def eigvals(a: ArrayLike) -> Array:
  """
  Compute the eigenvalues of a general matrix.

  JAX implementation of :func:`numpy.linalg.eigvals`.

  Args:
    a: array of shape ``(..., M, M)`` for which to compute the eigenvalues.

  Returns:
    An array of shape ``(..., M)`` containing the eigenvalues.

  See also:
    - :func:`jax.numpy.linalg.eig`: computes eigenvalues and eigenvectors of a
      general matrix.
    - :func:`jax.numpy.linalg.eigh`: computes eigenvalues and eigenvectors of a
      Hermitian matrix.

  Notes:
    - This differs from :func:`numpy.linalg.eigvals` in that the return type of
      :func:`jax.numpy.linalg.eigvals` is always complex64 for 32-bit input, and
      complex128 for 64-bit input.
    - At present, non-symmetric eigendecomposition is only implemented on the CPU and
      GPU backends.

  Examples:
    >>> a = jnp.array([[1., 2.],
    ...                [2., 1.]])
    >>> w = jnp.linalg.eigvals(a)
    >>> with jnp.printoptions(precision=2):
    ...   w
    Array([ 3.+0.j, -1.+0.j], dtype=complex64)
  """
  check_arraylike("jnp.linalg.eigvals", a)
  a, = promote_dtypes_inexact(jnp.asarray(a))
  return lax_linalg.eig(a, compute_left_eigenvectors=False,
                        compute_right_eigenvectors=False)[0]


@export
@partial(jit, static_argnames=('UPLO', 'symmetrize_input'))
def eigh(a: ArrayLike, UPLO: str | None = None,
         symmetrize_input: bool = True) -> EighResult:
  """
  Compute the eigenvalues and eigenvectors of a Hermitian matrix.

  JAX implementation of :func:`numpy.linalg.eigh`.

  Args:
    a: array of shape ``(..., M, M)``, containing the Hermitian (if complex)
      or symmetric (if real) matrix.
    UPLO: specifies whether the calculation is done with the lower triangular
      part of ``a`` (``'L'``, default) or the upper triangular part (``'U'``).
    symmetrize_input: if True (default) then input is symmetrized, which leads
      to better behavior under automatic differentiation.

  Returns:
    A namedtuple ``(eigenvalues, eigenvectors)`` where

    - ``eigenvalues``: an array of shape ``(..., M)`` containing the eigenvalues,
      sorted in ascending order.
    - ``eigenvectors``: an array of shape ``(..., M, M)``, where column ``v[:, i]`` is the
      normalized eigenvector corresponding to the eigenvalue ``w[i]``.

  See also:
    - :func:`jax.numpy.linalg.eig`: general eigenvalue decomposition.
    - :func:`jax.numpy.linalg.eigvalsh`: compute eigenvalues only.
    - :func:`jax.scipy.linalg.eigh`: SciPy API for Hermitian eigendecomposition.
    - :func:`jax.lax.linalg.eigh`: XLA API for Hermitian eigendecomposition.

  Examples:
    >>> a = jnp.array([[1, -2j],
    ...                [2j, 1]])
    >>> w, v = jnp.linalg.eigh(a)
    >>> w
    Array([-1.,  3.], dtype=float32)
    >>> with jnp.printoptions(precision=3):
    ...   v
    Array([[-0.707+0.j   , -0.707+0.j   ],
           [ 0.   +0.707j,  0.   -0.707j]], dtype=complex64)
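
    The pair satisfies the defining identity ``a @ v = v * w`` up to
    floating-point tolerance:

    >>> jnp.allclose(a @ v, v * w, atol=1E-5)
    Array(True, dtype=bool)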
  """
  check_arraylike("jnp.linalg.eigh", a)
  if UPLO is None or UPLO == "L":
    lower = True
  elif UPLO == "U":
    lower = False
  else:
    msg = f"UPLO must be one of None, 'L', or 'U', got {UPLO}"
    raise ValueError(msg)

  a, = promote_dtypes_inexact(jnp.asarray(a))
  v, w = lax_linalg.eigh(a, lower=lower, symmetrize_input=symmetrize_input)
  return EighResult(w, v)


@export
@partial(jit, static_argnames=('UPLO',))
def eigvalsh(a: ArrayLike, UPLO: str | None = 'L') -> Array:
  """
  Compute the eigenvalues of a Hermitian matrix.

  JAX implementation of :func:`numpy.linalg.eigvalsh`.

  Args:
    a: array of shape ``(..., M, M)``, containing the Hermitian (if complex)
      or symmetric (if real) matrix.
    UPLO: specifies whether the calculation is done with the lower triangular
      part of ``a`` (``'L'``, default) or the upper triangular part (``'U'``).

  Returns:
    An array of shape ``(..., M)`` containing the eigenvalues, sorted in
    ascending order.

  See also:
    - :func:`jax.numpy.linalg.eig`: general eigenvalue decomposition.
    - :func:`jax.numpy.linalg.eigh`: computes eigenvalues and eigenvectors of a
      Hermitian matrix.

  Examples:
    >>> a = jnp.array([[1, -2j],
    ...                [2j, 1]])
    >>> w = jnp.linalg.eigvalsh(a)
    >>> w
    Array([-1.,  3.], dtype=float32)
  """
  check_arraylike("jnp.linalg.eigvalsh", a)
  a, = promote_dtypes_inexact(jnp.asarray(a))
  w, _ = eigh(a, UPLO)
  return w


# TODO(micky774): deprecated 2024-5-14, remove wrapper after deprecation expires.
@export
def pinv(a: ArrayLike, rtol: ArrayLike | None = None,
         hermitian: bool = False, *,
         rcond: ArrayLike | DeprecatedArg | None = DeprecatedArg()) -> Array:
  """Compute the (Moore-Penrose) pseudo-inverse of a matrix.

  JAX implementation of :func:`numpy.linalg.pinv`.

  Args:
    a: array of shape ``(..., M, N)`` containing matrices to pseudo-invert.
    rtol: float or array_like of shape ``a.shape[:-2]``. Cutoff for small
      singular values; singular values smaller than
      ``rtol * largest_singular_value`` are treated as zero. The default is
      determined based on the floating point precision of the dtype.
    hermitian: if True, then the input is assumed to be Hermitian, and a more
      efficient algorithm is used (default: False)
    rcond: deprecated alias of the ``rtol`` argument. Will result in a
      :class:`DeprecationWarning` if used.

  Returns:
    An array of shape ``(..., N, M)`` containing the pseudo-inverse of ``a``.

  See also:
    - :func:`jax.numpy.linalg.inv`: multiplicative inverse of a square matrix.

  Notes:
    :func:`jax.numpy.linalg.pinv` differs from :func:`numpy.linalg.pinv` in the
    default value of ``rcond``: in NumPy, the default is ``1e-15``. In JAX, the
    default is ``10. * max(num_rows, num_cols) * jnp.finfo(dtype).eps``.

  Examples:
    >>> a = jnp.array([[1, 2],
    ...                [3, 4],
    ...                [5, 6]])
    >>> a_pinv = jnp.linalg.pinv(a)
    >>> a_pinv  # doctest: +SKIP
    Array([[-1.333332  , -0.33333257,  0.6666657 ],
           [ 1.0833322 ,  0.33333272, -0.41666582]], dtype=float32)

    The pseudo-inverse operates as a multiplicative inverse so long as the
    output is not rank-deficient:

    >>> jnp.allclose(a_pinv @ a, jnp.eye(2), atol=1E-4)
    Array(True, dtype=bool)
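
    For a square, invertible matrix, the pseudo-inverse coincides with the
    ordinary inverse:

    >>> b = jnp.array([[1., 2.],
    ...                [3., 4.]])
    >>> jnp.allclose(jnp.linalg.pinv(b), jnp.linalg.inv(b), atol=1E-4)
    Array(True, dtype=bool)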
  """
  if not isinstance(rcond, DeprecatedArg):
    rtol = rcond
    del rcond
    deprecations.warn(
      "jax-numpy-linalg-pinv-rcond",
      ("The rcond argument for linalg.pinv is deprecated. "
       "Please use rtol instead."),
      stacklevel=2
    )

  return _pinv(a, rtol, hermitian)


@partial(custom_jvp, nondiff_argnums=(1, 2))
@partial(jit, static_argnames=('hermitian'))
def _pinv(a: ArrayLike, rtol: ArrayLike | None = None, hermitian: bool = False) -> Array:
  # Uses same algorithm as
  # https://github.com/numpy/numpy/blob/v1.17.0/numpy/linalg/linalg.py#L1890-L1979
  check_arraylike("jnp.linalg.pinv", a)
  arr, = promote_dtypes_inexact(jnp.asarray(a))
  m, n = arr.shape[-2:]
  if m == 0 or n == 0:
    return jnp.empty(arr.shape[:-2] + (n, m), arr.dtype)
  arr = ufuncs.conj(arr)
  if rtol is None:
    max_rows_cols = max(arr.shape[-2:])
    rtol = 10. * max_rows_cols * jnp.array(jnp.finfo(arr.dtype).eps)
  rtol = jnp.asarray(rtol)
  u, s, vh = svd(arr, full_matrices=False, hermitian=hermitian)
  # Singular values less than or equal to ``rtol * largest_singular_value``
  # are set to zero.
  rtol = lax.expand_dims(rtol[..., jnp.newaxis], range(s.ndim - rtol.ndim - 1))
  cutoff = rtol * s[..., 0:1]
  s = jnp.where(s > cutoff, s, jnp.inf).astype(u.dtype)
  res = jnp.matmul(vh.mT, ufuncs.divide(u.mT, s[..., jnp.newaxis]),
                   precision=lax.Precision.HIGHEST)
  return lax.convert_element_type(res, arr.dtype)


@_pinv.defjvp
@jax.default_matmul_precision("float32")
def _pinv_jvp(rtol, hermitian, primals, tangents):
  # The Differentiation of Pseudo-Inverses and Nonlinear Least Squares Problems
  # Whose Variables Separate. Author(s): G. H. Golub and V. Pereyra. SIAM
  # Journal on Numerical Analysis, Vol. 10, No. 2 (Apr., 1973), pp. 413-432.
  # (via https://en.wikipedia.org/wiki/Moore%E2%80%93Penrose_inverse#Derivative)
  a, = primals  # m x n
  a_dot, = tangents
  p = pinv(a, rtol=rtol, hermitian=hermitian)  # n x m
  if hermitian:
    # svd(..., hermitian=True) symmetrizes its input, and the JVP must match.
    a = _symmetrize(a)
    a_dot = _symmetrize(a_dot)

  # TODO(phawkins): this could be simplified in the Hermitian case if we
  # supported triangular matrix multiplication.
  m, n = a.shape[-2:]
  if m >= n:
    s = (p @ _H(p)) @ _H(a_dot)  # nxm
    t = (_H(a_dot) @ _H(p)) @ p  # nxm
    p_dot = -(p @ a_dot) @ p + s - (s @ a) @ p + t - (p @ a) @ t
  else:  # m < n
    s = p @ (_H(p) @ _H(a_dot))
    t = _H(a_dot) @ (_H(p) @ p)
    p_dot = -p @ (a_dot @ p) + s - s @ (a @ p) + t - p @ (a @ t)
  return p, p_dot


@export
@jit
def inv(a: ArrayLike) -> Array:
  """Return the inverse of a square matrix

  JAX implementation of :func:`numpy.linalg.inv`.

  Args:
    a: array of shape ``(..., N, N)`` specifying square array(s) to be inverted.

  Returns:
    Array of shape ``(..., N, N)`` containing the inverse of the input.

  Notes:
    In most cases, explicitly computing the inverse of a matrix is ill-advised. For
    example, to compute ``x = inv(A) @ b``, it is more performant and numerically
    precise to use a direct solve, such as :func:`jax.scipy.linalg.solve`.

  See Also:
    - :func:`jax.scipy.linalg.inv`: SciPy-style API for matrix inverse
    - :func:`jax.numpy.linalg.solve`: direct linear solver

  Examples:
    Compute the inverse of a 3x3 matrix

    >>> a = jnp.array([[1., 2., 3.],
    ...                [2., 4., 2.],
    ...                [3., 2., 1.]])
    >>> a_inv = jnp.linalg.inv(a)
    >>> a_inv  # doctest: +SKIP
    Array([[ 0.        , -0.25      ,  0.5       ],
           [-0.25      ,  0.5       , -0.25000003],
           [ 0.5       , -0.25      ,  0.        ]], dtype=float32)

    Check that multiplying with the inverse gives the identity:

    >>> jnp.allclose(a @ a_inv, jnp.eye(3), atol=1E-5)
    Array(True, dtype=bool)

    Multiply the inverse by a vector ``b``, to find a solution to ``a @ x = b``

    >>> b = jnp.array([1., 4., 2.])
    >>> a_inv @ b
    Array([ 0.  ,  1.25, -0.5 ], dtype=float32)

    Note, however, that explicitly computing the inverse in such a case can lead
    to poor performance and loss of precision as the size of the problem grows.
    Instead, you should use a direct solver like :func:`jax.numpy.linalg.solve`:

    >>> jnp.linalg.solve(a, b)
    Array([ 0.  ,  1.25, -0.5 ], dtype=float32)
  """
  check_arraylike("jnp.linalg.inv", a)
  arr = jnp.asarray(a)
  if arr.ndim < 2 or arr.shape[-1] != arr.shape[-2]:
    raise ValueError(
        f"Argument to inv must have shape [..., n, n], got {arr.shape}.")
  return solve(
      arr, lax.broadcast(jnp.eye(arr.shape[-1], dtype=arr.dtype), arr.shape[:-2]))


@export
@partial(jit, static_argnames=('ord', 'axis', 'keepdims'))
def norm(x: ArrayLike, ord: int | str | None = None,
         axis: None | tuple[int, ...] | int = None,
         keepdims: bool = False) -> Array:
  """Compute the norm of a matrix or vector.

  JAX implementation of :func:`numpy.linalg.norm`.

  Args:
    x: N-dimensional array for which the norm will be computed.
    ord: specify the kind of norm to take. Default is Frobenius norm for matrices,
      and the 2-norm for vectors. For other options, see Notes below.
    axis: integer or sequence of integers specifying the axes over which the norm
      will be computed. Defaults to all axes of ``x``.
    keepdims: if True, the output array will have the same number of dimensions as
      the input, with the size of reduced axes replaced by ``1`` (default: False).

  Returns:
    array containing the specified norm of x.

  Notes:
    The flavor of norm computed depends on the value of ``ord`` and the number of
    axes being reduced.

    For **vector norms** (i.e. a single axis reduction):

    - ``ord=None`` (default) computes the 2-norm
    - ``ord=inf`` computes ``max(abs(x))``
    - ``ord=-inf`` computes ``min(abs(x))``
    - ``ord=0`` computes ``sum(x!=0)``
    - for other numerical values, computes ``sum(abs(x) ** ord)**(1/ord)``

    For **matrix norms** (i.e. two axes reductions):

    - ``ord='fro'`` or ``ord=None`` (default) computes the Frobenius norm
    - ``ord='nuc'`` computes the nuclear norm, or the sum of the singular values
    - ``ord=1`` computes ``max(abs(x).sum(0))``
    - ``ord=-1`` computes ``min(abs(x).sum(0))``
    - ``ord=2`` computes the 2-norm, i.e. the largest singular value
    - ``ord=-2`` computes the smallest singular value

  Examples:
    Vector norms:

    >>> x = jnp.array([3., 4., 12.])
    >>> jnp.linalg.norm(x)
    Array(13., dtype=float32)
    >>> jnp.linalg.norm(x, ord=1)
    Array(19., dtype=float32)
    >>> jnp.linalg.norm(x, ord=0)
    Array(3., dtype=float32)

    Matrix norms:

    >>> x = jnp.array([[1., 2., 3.],
    ...                [4., 5., 7.]])
    >>> jnp.linalg.norm(x)  # Frobenius norm
    Array(10.198039, dtype=float32)
    >>> jnp.linalg.norm(x, ord='nuc')  # nuclear norm
    Array(10.762535, dtype=float32)
    >>> jnp.linalg.norm(x, ord=1)  # 1-norm
    Array(10., dtype=float32)
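
    The infinity norm is the maximum absolute row sum:

    >>> jnp.linalg.norm(x, ord=jnp.inf)
    Array(16., dtype=float32)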

    Batched vector norm:

    >>> jnp.linalg.norm(x, axis=1)
    Array([3.7416575, 9.486833 ], dtype=float32)
  """
  check_arraylike("jnp.linalg.norm", x)
  x, = promote_dtypes_inexact(jnp.asarray(x))
  x_shape = jnp.shape(x)
  ndim = len(x_shape)

  if axis is None:
    # NumPy has an undocumented behavior that admits arbitrary rank inputs if
    # `ord` is None: https://github.com/numpy/numpy/issues/14215
    if ord is None:
      return ufuncs.sqrt(reductions.sum(ufuncs.real(x * ufuncs.conj(x)), keepdims=keepdims))
    axis = tuple(range(ndim))
  elif isinstance(axis, tuple):
    axis = tuple(canonicalize_axis(x, ndim) for x in axis)
  else:
    axis = (canonicalize_axis(axis, ndim),)

  num_axes = len(axis)
  if num_axes == 1:
    return vector_norm(x, ord=2 if ord is None else ord, axis=axis, keepdims=keepdims)

  elif num_axes == 2:
    row_axis, col_axis = axis  # pytype: disable=bad-unpacking
    if ord is None or ord in ('f', 'fro'):
      return ufuncs.sqrt(reductions.sum(ufuncs.real(x * ufuncs.conj(x)), axis=axis,
                                        keepdims=keepdims))
    elif ord == 1:
      if not keepdims and col_axis > row_axis:
        col_axis -= 1
      return reductions.amax(reductions.sum(ufuncs.abs(x), axis=row_axis, keepdims=keepdims),
                             axis=col_axis, keepdims=keepdims)
    elif ord == -1:
      if not keepdims and col_axis > row_axis:
        col_axis -= 1
      return reductions.amin(reductions.sum(ufuncs.abs(x), axis=row_axis, keepdims=keepdims),
                             axis=col_axis, keepdims=keepdims)
    elif ord == jnp.inf:
      if not keepdims and row_axis > col_axis:
        row_axis -= 1
      return reductions.amax(reductions.sum(ufuncs.abs(x), axis=col_axis, keepdims=keepdims),
                             axis=row_axis, keepdims=keepdims)
    elif ord == -jnp.inf:
      if not keepdims and row_axis > col_axis:
        row_axis -= 1
      return reductions.amin(reductions.sum(ufuncs.abs(x), axis=col_axis, keepdims=keepdims),
                             axis=row_axis, keepdims=keepdims)
    elif ord in ('nuc', 2, -2):
      x = jnp.moveaxis(x, axis, (-2, -1))
      if ord == 2:
        reducer = reductions.amax
      elif ord == -2:
        reducer = reductions.amin
      else:
        # `sum` takes an extra dtype= argument, unlike `amax` and `amin`.
        reducer = reductions.sum  # type: ignore[assignment]
      y = reducer(svd(x, compute_uv=False), axis=-1)
      if keepdims:
        y = jnp.expand_dims(y, axis)
      return y
    else:
      raise ValueError(f"Invalid order '{ord}' for matrix norm.")
  else:
    raise ValueError(
        f"Invalid axis values ({axis}) for jnp.linalg.norm.")


@overload
def qr(a: ArrayLike, mode: Literal["r"]) -> Array: ...
@overload
def qr(a: ArrayLike, mode: str = "reduced") -> Array | QRResult: ...

@export
@partial(jit, static_argnames=('mode',))
def qr(a: ArrayLike, mode: str = "reduced") -> Array | QRResult:
  """Compute the QR decomposition of an array

  JAX implementation of :func:`numpy.linalg.qr`.

  The QR decomposition of a matrix `A` is given by

  .. math::

     A = QR

  where `Q` is a unitary matrix (i.e. :math:`Q^HQ=I`) and `R` is an upper-triangular
  matrix.

  Args:
    a: array of shape (..., M, N)
    mode: Computational mode. Supported values are:

      - ``"reduced"`` (default): return `Q` of shape ``(..., M, K)`` and `R` of shape
        ``(..., K, N)``, where ``K = min(M, N)``.
      - ``"complete"``: return `Q` of shape ``(..., M, M)`` and `R` of shape ``(..., M, N)``.
      - ``"raw"``: return LAPACK-internal representations of shape ``(..., M, N)`` and ``(..., K)``.
      - ``"r"``: return `R` only.

  Returns:
    A tuple ``(Q, R)`` (if ``mode`` is not ``"r"``) otherwise an array ``R``,
    where:

    - ``Q`` is an orthogonal matrix of shape ``(..., M, K)`` (if ``mode`` is ``"reduced"``)
      or ``(..., M, M)`` (if ``mode`` is ``"complete"``).
    - ``R`` is an upper-triangular matrix of shape ``(..., M, N)`` (if ``mode`` is
      ``"complete"``) or ``(..., K, N)`` (if ``mode`` is ``"reduced"`` or ``"r"``)

    with ``K = min(M, N)``.

  See also:
    - :func:`jax.scipy.linalg.qr`: SciPy-style QR decomposition API
    - :func:`jax.lax.linalg.qr`: XLA-style QR decomposition API

  Examples:
    Compute the QR decomposition of a matrix:

    >>> a = jnp.array([[1., 2., 3., 4.],
    ...                [5., 4., 2., 1.],
    ...                [6., 3., 1., 5.]])
    >>> Q, R = jnp.linalg.qr(a)
    >>> Q  # doctest: +SKIP
    Array([[-0.12700021, -0.7581426 , -0.6396022 ],
           [-0.63500065, -0.43322435,  0.63960224],
           [-0.7620008 ,  0.48737738, -0.42640156]], dtype=float32)
    >>> R  # doctest: +SKIP
    Array([[-7.8740077, -5.080005 , -2.4130025, -4.953006 ],
           [ 0.       , -1.7870499, -2.6534991, -1.028908 ],
           [ 0.       ,  0.       , -1.0660033, -4.050814 ]], dtype=float32)

    Check that ``Q`` is orthonormal:

    >>> jnp.allclose(Q.T @ Q, jnp.eye(3), atol=1E-5)
    Array(True, dtype=bool)

    Reconstruct the input:

    >>> jnp.allclose(Q @ R, a)
    Array(True, dtype=bool)
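
    With ``mode="r"``, only the ``R`` factor is computed:

    >>> R2 = jnp.linalg.qr(a, mode="r")
    >>> jnp.allclose(R, R2)
    Array(True, dtype=bool)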
  """
  check_arraylike("jnp.linalg.qr", a)
  a, = promote_dtypes_inexact(jnp.asarray(a))
  if mode == "raw":
    a, taus = lax_linalg.geqrf(a)
    return QRResult(a.mT, taus)
  if mode in ("reduced", "r", "full"):
    full_matrices = False
  elif mode == "complete":
    full_matrices = True
  else:
    raise ValueError(f"Unsupported QR decomposition mode '{mode}'")
  q, r = lax_linalg.qr(a, pivoting=False, full_matrices=full_matrices)
  if mode == "r":
    return r
  return QRResult(q, r)


@export
@jit
def solve(a: ArrayLike, b: ArrayLike) -> Array:
  """Solve a linear system of equations

  JAX implementation of :func:`numpy.linalg.solve`.

  This solves a (batched) linear system of equations ``a @ x = b``
  for ``x`` given ``a`` and ``b``.

  Args:
    a: array of shape ``(..., N, N)``.
    b: array of shape ``(N,)`` (for 1-dimensional right-hand-side) or
      ``(..., N, M)`` (for batched 2-dimensional right-hand-side).

  Returns:
    An array containing the result of the linear solve. The result has shape ``(..., N)``
    if ``b`` is of shape ``(N,)``, and has shape ``(..., N, M)`` otherwise.

  See also:
    - :func:`jax.scipy.linalg.solve`: SciPy-style API for solving linear systems.
    - :func:`jax.lax.custom_linear_solve`: matrix-free linear solver.

  Examples:
    A simple 3x3 linear system:

    >>> A = jnp.array([[1., 2., 3.],
    ...                [2., 4., 2.],
    ...                [3., 2., 1.]])
    >>> b = jnp.array([14., 16., 10.])
    >>> x = jnp.linalg.solve(A, b)
    >>> x
    Array([1., 2., 3.], dtype=float32)

    Confirming that the result solves the system:

    >>> jnp.allclose(A @ x, b)
    Array(True, dtype=bool)
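
    A matrix right-hand side solves several systems that share ``A`` at once:

    >>> B = jnp.stack([b, 2 * b], axis=1)
    >>> X = jnp.linalg.solve(A, B)
    >>> jnp.allclose(A @ X, B, atol=1E-4)
    Array(True, dtype=bool)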
  """
  check_arraylike("jnp.linalg.solve", a, b)
  a, b = promote_dtypes_inexact(jnp.asarray(a), jnp.asarray(b))

  if a.ndim < 2:
    raise ValueError(
        f"left hand array must be at least two dimensional; got {a.shape=}")

  # Check for invalid inputs that previously would have led to a batched 1D solve:
  if (b.ndim > 1 and a.ndim == b.ndim + 1 and
      a.shape[-1] == b.shape[-1] and a.shape[-1] != b.shape[-2]):
    raise ValueError(
        f"Invalid shapes for solve: {a.shape}, {b.shape}. Prior to JAX v0.5.0,"
        " this would have been treated as a batched 1-dimensional solve."
        " To recover this behavior, use solve(a, b[..., None]).squeeze(-1).")

  signature = "(m,m),(m)->(m)" if b.ndim == 1 else "(m,m),(m,n)->(m,n)"
  return jnp.vectorize(lax_linalg._solve, signature=signature)(a, b)


def _lstsq(a: ArrayLike, b: ArrayLike, rcond: float | None, *,
           numpy_resid: bool = False) -> tuple[Array, Array, Array, Array]:
  # TODO: add lstsq to lax_linalg and implement this function via those wrappers.
  # TODO: add custom jvp rule for more robust lstsq differentiation
  a, b = promote_dtypes_inexact(a, b)
  if a.shape[0] != b.shape[0]:
    raise ValueError("Leading dimensions of input arrays must match")
  b_orig_ndim = b.ndim
  if b_orig_ndim == 1:
    b = b[:, None]
  if a.ndim != 2:
    raise TypeError(
        f"{a.ndim}-dimensional array given. Array must be two-dimensional")
  if b.ndim != 2:
    raise TypeError(
        f"{b.ndim}-dimensional array given. Array must be one or two-dimensional")
  m, n = a.shape
  dtype = a.dtype
  if a.size == 0:
    s = jnp.empty(0, dtype=a.dtype)
    rank = jnp.array(0, dtype=int)
    x = jnp.empty((n, *b.shape[1:]), dtype=a.dtype)
  else:
    if rcond is None:
      rcond = float(jnp.finfo(dtype).eps) * max(n, m)
    else:
      rcond = jnp.where(rcond < 0, jnp.finfo(dtype).eps, rcond)
    u, s, vt = svd(a, full_matrices=False)
    mask = s >= jnp.array(rcond, dtype=s.dtype) * s[0]
    rank = mask.sum()
    safe_s = jnp.where(mask, s, 1).astype(a.dtype)
    s_inv = jnp.where(mask, 1 / safe_s, 0)[:, jnp.newaxis]
    uTb = jnp.matmul(u.conj().T, b, precision=lax.Precision.HIGHEST)
    x = jnp.matmul(vt.conj().T, s_inv * uTb, precision=lax.Precision.HIGHEST)
  # Numpy returns empty residuals in some cases. To allow compilation, we
  # default to returning full residuals in all cases.
  if numpy_resid and (rank < n or m <= n):
    resid = jnp.asarray([])
  else:
    b_estimate = jnp.matmul(a, x, precision=lax.Precision.HIGHEST)
    resid = norm(b - b_estimate, axis=0) ** 2
  if b_orig_ndim == 1:
    x = x.ravel()
  return x, resid, rank, s


_jit_lstsq = jit(partial(_lstsq, numpy_resid=False))


@export
def lstsq(a: ArrayLike, b: ArrayLike, rcond: float | None = None, *,
          numpy_resid: bool = False) -> tuple[Array, Array, Array, Array]:
  """
  Return the least-squares solution to a linear equation.

  JAX implementation of :func:`numpy.linalg.lstsq`.

  Args:
    a: array of shape ``(M, N)`` representing the coefficient matrix.
    b: array of shape ``(M,)`` or ``(M, K)`` representing the right-hand side.
    rcond: Cut-off ratio for small singular values. Singular values smaller than
      ``rcond * largest_singular_value`` are treated as zero. If None (default),
      the optimal value will be used to reduce floating point errors.
    numpy_resid: If True, compute and return residuals in the same way as NumPy's
      `linalg.lstsq`. This is necessary if you want to precisely replicate NumPy's
      behavior. If False (default), a more efficient method is used to compute residuals.

  Returns:
    Tuple of arrays ``(x, resid, rank, s)`` where

    - ``x`` is a shape ``(N,)`` or ``(N, K)`` array containing the least-squares solution.
    - ``resid`` is the sum of squared residual of shape ``()`` or ``(K,)``.
    - ``rank`` is the rank of the matrix ``a``.
    - ``s`` is the singular values of the matrix ``a``.

  Examples:
    >>> a = jnp.array([[1, 2],
    ...                [3, 4]])
    >>> b = jnp.array([5, 6])
    >>> x, _, _, _ = jnp.linalg.lstsq(a, b)
    >>> with jnp.printoptions(precision=3):
    ...   print(x)
    [-4.   4.5]
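
    For an overdetermined system, the solution minimizes the squared residual
    ``norm(a @ x - b)**2``; here the three data points are exactly collinear,
    so the fit is exact:

    >>> a = jnp.array([[1., 1.],
    ...                [1., 2.],
    ...                [1., 3.]])
    >>> b = jnp.array([6., 9., 12.])
    >>> x, resid, rank, s = jnp.linalg.lstsq(a, b)
    >>> jnp.allclose(a @ x, b, atol=1E-4)
    Array(True, dtype=bool)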
  """
  check_arraylike("jnp.linalg.lstsq", a, b)
  if numpy_resid:
    return _lstsq(a, b, rcond, numpy_resid=True)
  return _jit_lstsq(a, b, rcond)


@export
def cross(x1: ArrayLike, x2: ArrayLike, /, *, axis=-1):
  r"""Compute the cross-product of two 3D vectors

  JAX implementation of :func:`numpy.linalg.cross`

  Args:
    x1: N-dimensional array, with ``x1.shape[axis] == 3``
    x2: N-dimensional array, with ``x2.shape[axis] == 3``, and other axes
      broadcast-compatible with ``x1``.
    axis: axis along which to take the cross product (default: -1).

  Returns:
    array containing the result of the cross-product

  See Also:
    :func:`jax.numpy.cross`: more flexible cross-product API.

  Examples:

    Showing that :math:`\hat{x} \times \hat{y} = \hat{z}`:

    >>> x = jnp.array([1., 0., 0.])
    >>> y = jnp.array([0., 1., 0.])
    >>> jnp.linalg.cross(x, y)
    Array([0., 0., 1.], dtype=float32)

    Cross product of :math:`\hat{x}` with all three standard unit vectors,
    via broadcasting:

    >>> xyz = jnp.eye(3)
    >>> jnp.linalg.cross(x, xyz, axis=-1)
    Array([[ 0.,  0.,  0.],
           [ 0.,  0.,  1.],
           [ 0., -1.,  0.]], dtype=float32)
  """
  check_arraylike("jnp.linalg.cross", x1, x2)
  x1, x2 = jnp.asarray(x1), jnp.asarray(x2)
  if x1.shape[axis] != 3 or x2.shape[axis] != 3:
    raise ValueError(
        "Both input arrays must be (arrays of) 3-dimensional vectors, "
        f"but they have {x1.shape[axis]=} and {x2.shape[axis]=}"
    )
  return jnp.cross(x1, x2, axis=axis)


@export
def outer(x1: ArrayLike, x2: ArrayLike, /) -> Array:
  """Compute the outer product of two 1-dimensional arrays.

  JAX implementation of :func:`numpy.linalg.outer`.

  Args:
    x1: array
    x2: array

  Returns:
    array containing the outer product of ``x1`` and ``x2``

  See also:
    :func:`jax.numpy.outer`: similar function in the main :mod:`jax.numpy` module.

  Examples:
    >>> x1 = jnp.array([1, 2, 3])
    >>> x2 = jnp.array([4, 5, 6])
    >>> jnp.linalg.outer(x1, x2)
    Array([[ 4,  5,  6],
           [ 8, 10, 12],
           [12, 15, 18]], dtype=int32)
  """
  check_arraylike("jnp.linalg.outer", x1, x2)
  x1, x2 = jnp.asarray(x1), jnp.asarray(x2)
  if x1.ndim != 1 or x2.ndim != 1:
    raise ValueError(f"Input arrays must be one-dimensional, but they are {x1.ndim=} {x2.ndim=}")
  return x1[:, None] * x2[None, :]


@export
def matrix_norm(x: ArrayLike, /, *, keepdims: bool = False, ord: str | int = 'fro') -> Array:
  """Compute the norm of a matrix or stack of matrices.

  JAX implementation of :func:`numpy.linalg.matrix_norm`

  Args:
    x: array of shape ``(..., M, N)`` for which to take the norm.
    keepdims: if True, keep the reduced dimensions in the output.
    ord: A string or int specifying the type of norm; default is the Frobenius norm.
      See :func:`numpy.linalg.norm` for details on available options.

  Returns:
    array containing the norm of ``x``. Has shape ``x.shape[:-2]`` if ``keepdims`` is
    False, or shape ``(..., 1, 1)`` if ``keepdims`` is True.

  See also:
    - :func:`jax.numpy.linalg.vector_norm`: Norm of a vector or stack of vectors.
    - :func:`jax.numpy.linalg.norm`: More general matrix or vector norm.

  Examples:
    >>> x = jnp.array([[1, 2, 3],
    ...                [4, 5, 6],
    ...                [7, 8, 9]])
    >>> jnp.linalg.matrix_norm(x)
    Array(16.881943, dtype=float32)
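
    With ``keepdims=True``, the two reduced dimensions are retained with size 1:

    >>> jnp.linalg.matrix_norm(x, keepdims=True)
    Array([[16.881943]], dtype=float32)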
  """
  check_arraylike('jnp.linalg.matrix_norm', x)
  return norm(x, ord=ord, keepdims=keepdims, axis=(-2, -1))


@export
def matrix_transpose(x: ArrayLike, /) -> Array:
  """Transpose a matrix or stack of matrices.

  JAX implementation of :func:`numpy.linalg.matrix_transpose`.

  Args:
    x: array of shape ``(..., M, N)``

  Returns:
    array of shape ``(..., N, M)`` containing the matrix transpose of ``x``.

  See also:
    :func:`jax.numpy.transpose`: more general transpose operation.

  Examples:
    Transpose of a single matrix:

    >>> x = jnp.array([[1, 2, 3],
    ...                [4, 5, 6]])
    >>> jnp.linalg.matrix_transpose(x)
    Array([[1, 4],
           [2, 5],
           [3, 6]], dtype=int32)

    Transpose of a stack of matrices:

    >>> x = jnp.array([[[1, 2],
    ...                 [3, 4]],
    ...                [[5, 6],
    ...                 [7, 8]]])
    >>> jnp.linalg.matrix_transpose(x)
    Array([[[1, 3],
            [2, 4]],
    <BLANKLINE>
           [[5, 7],
            [6, 8]]], dtype=int32)

    For convenience, the same computation can be done via the
    :attr:`~jax.Array.mT` property of JAX array objects:

    >>> x.mT
    Array([[[1, 3],
            [2, 4]],
    <BLANKLINE>
           [[5, 7],
            [6, 8]]], dtype=int32)
  """
  check_arraylike('jnp.linalg.matrix_transpose', x)
  x_arr = jnp.asarray(x)
  ndim = x_arr.ndim
  if ndim < 2:
    raise ValueError(f"matrix_transpose requires at least 2 dimensions; got {ndim=}")
  return jax.lax.transpose(x_arr, (*range(ndim - 2), ndim - 1, ndim - 2))
|
|
|
|
|
|
@export
def vector_norm(x: ArrayLike, /, *, axis: int | tuple[int, ...] | None = None, keepdims: bool = False,
                ord: int | str = 2) -> Array:
  """Compute the vector norm of a vector or batch of vectors.

  JAX implementation of :func:`numpy.linalg.vector_norm`.

  Args:
    x: N-dimensional array for which to take the norm.
    axis: optional axis along which to compute the vector norm. If None (default)
      then ``x`` is flattened and the norm is taken over all values.
    keepdims: if True, keep the reduced dimensions in the output.
    ord: A string or int specifying the type of norm; default is the 2-norm.
      See :func:`numpy.linalg.norm` for details on available options.

  Returns:
    array containing the norm of ``x``.

  See also:
    - :func:`jax.numpy.linalg.matrix_norm`: Norm of a matrix or stack of matrices.
    - :func:`jax.numpy.linalg.norm`: More general matrix or vector norm.

  Examples:
    Norm of a single vector:

    >>> x = jnp.array([1., 2., 3.])
    >>> jnp.linalg.vector_norm(x)
    Array(3.7416575, dtype=float32)

    Norm of a batch of vectors:

    >>> x = jnp.array([[1., 2., 3.],
    ...                [4., 5., 7.]])
    >>> jnp.linalg.vector_norm(x, axis=1)
    Array([3.7416575, 9.486833 ], dtype=float32)
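
    Other norms can be selected via ``ord``; for example, the infinity norm,
    which is the maximum absolute value along each vector:

    >>> jnp.linalg.vector_norm(x, axis=1, ord=jnp.inf)
    Array([3., 7.], dtype=float32)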
  """
  check_arraylike('jnp.linalg.vector_norm', x)
  if ord is None or ord == 2:
    return ufuncs.sqrt(reductions.sum(ufuncs.real(x * ufuncs.conj(x)), axis=axis,
                                      keepdims=keepdims))
  elif ord == jnp.inf:
    return reductions.amax(ufuncs.abs(x), axis=axis, keepdims=keepdims)
  elif ord == -jnp.inf:
    return reductions.amin(ufuncs.abs(x), axis=axis, keepdims=keepdims)
  elif ord == 0:
    return reductions.sum(x != 0, dtype=jnp.finfo(lax.dtype(x)).dtype,
                          axis=axis, keepdims=keepdims)
  elif ord == 1:
    # Numpy has a special case for ord == 1 as an optimization. We don't
    # really need the optimization (XLA could do it for us), but the Numpy
    # code has slightly different type promotion semantics, so we need a
    # special case too.
    return reductions.sum(ufuncs.abs(x), axis=axis, keepdims=keepdims)
  elif isinstance(ord, str):
    msg = f"Invalid order '{ord}' for vector norm."
    if ord == "inf":
      msg += " Use 'jax.numpy.inf' instead."
    if ord == "-inf":
      msg += " Use '-jax.numpy.inf' instead."
    raise ValueError(msg)
  else:
    abs_x = ufuncs.abs(x)
    ord_arr = lax_internal._const(abs_x, ord)
    ord_inv = lax_internal._const(abs_x, 1. / ord_arr)
    out = reductions.sum(abs_x ** ord_arr, axis=axis, keepdims=keepdims)
    return ufuncs.power(out, ord_inv)


@export
def vecdot(x1: ArrayLike, x2: ArrayLike, /, *, axis: int = -1,
           precision: PrecisionLike = None,
           preferred_element_type: DTypeLike | None = None) -> Array:
  """Compute the (batched) vector conjugate dot product of two arrays.

  JAX implementation of :func:`numpy.linalg.vecdot`.

  Args:
    x1: left-hand side array.
    x2: right-hand side array. Size of ``x2[axis]`` must match size of ``x1[axis]``,
      and remaining dimensions must be broadcast-compatible.
    axis: axis along which to compute the dot product (default: -1)
    precision: either ``None`` (default), which means the default precision for
      the backend, a :class:`~jax.lax.Precision` enum value (``Precision.DEFAULT``,
      ``Precision.HIGH`` or ``Precision.HIGHEST``) or a tuple of two
      such values indicating precision of ``x1`` and ``x2``.
    preferred_element_type: either ``None`` (default), which means the default
      accumulation type for the input types, or a datatype, indicating to
      accumulate results to and return a result with that datatype.

  Returns:
    array containing the conjugate dot product of ``x1`` and ``x2`` along ``axis``.
    The non-contracted dimensions are broadcast together.

  See also:
    - :func:`jax.numpy.vecdot`: similar API in the ``jax.numpy`` namespace.
    - :func:`jax.numpy.linalg.matmul`: matrix multiplication.
    - :func:`jax.numpy.linalg.tensordot`: general tensor dot product.

  Examples:
    Vector dot product of two 1D arrays:

    >>> x1 = jnp.array([1, 2, 3])
    >>> x2 = jnp.array([4, 5, 6])
    >>> jnp.linalg.vecdot(x1, x2)
    Array(32, dtype=int32)

    Batched vector dot product of two 2D arrays:

    >>> x1 = jnp.array([[1, 2, 3],
    ...                 [4, 5, 6]])
    >>> x2 = jnp.array([[2, 3, 4]])
    >>> jnp.linalg.vecdot(x1, x2, axis=-1)
    Array([20, 47], dtype=int32)
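
    As this is a conjugate dot product, the left-hand input is conjugated,
    which matters for complex inputs:

    >>> x1 = jnp.array([1j, 2j])
    >>> x2 = jnp.array([1j, 2j])
    >>> jnp.linalg.vecdot(x1, x2)
    Array(5.+0.j, dtype=complex64)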
  """
  check_arraylike('jnp.linalg.vecdot', x1, x2)
  return jnp.vecdot(x1, x2, axis=axis, precision=precision,
                    preferred_element_type=preferred_element_type)


@export
def matmul(x1: ArrayLike, x2: ArrayLike, /, *,
           precision: PrecisionLike = None,
           preferred_element_type: DTypeLike | None = None) -> Array:
  """Perform a matrix multiplication.

  JAX implementation of :func:`numpy.linalg.matmul`.

  Args:
    x1: first input array, of shape ``(..., N)``.
    x2: second input array. Must have shape ``(N,)`` or ``(..., N, M)``.
      In the multi-dimensional case, leading dimensions must be broadcast-compatible
      with the leading dimensions of ``x1``.
    precision: either ``None`` (default), which means the default precision for
      the backend, a :class:`~jax.lax.Precision` enum value (``Precision.DEFAULT``,
      ``Precision.HIGH`` or ``Precision.HIGHEST``) or a tuple of two
      such values indicating precision of ``x1`` and ``x2``.
    preferred_element_type: either ``None`` (default), which means the default
      accumulation type for the input types, or a datatype, indicating to
      accumulate results to and return a result with that datatype.

  Returns:
    array containing the matrix product of the inputs. Shape is ``x1.shape[:-1]``
    if ``x2.ndim == 1``, otherwise the shape is ``(..., M)``.

  See Also:
    - :func:`jax.numpy.matmul`: NumPy API for this function.
    - :func:`jax.numpy.linalg.vecdot`: batched vector product.
    - :func:`jax.numpy.linalg.tensordot`: batched tensor product.

  Examples:
    Vector dot products:

    >>> x1 = jnp.array([1, 2, 3])
    >>> x2 = jnp.array([4, 5, 6])
    >>> jnp.linalg.matmul(x1, x2)
    Array(32, dtype=int32)

    Matrix dot product:

    >>> x1 = jnp.array([[1, 2, 3],
    ...                 [4, 5, 6]])
    >>> x2 = jnp.array([[1, 2],
    ...                 [3, 4],
    ...                 [5, 6]])
    >>> jnp.linalg.matmul(x1, x2)
    Array([[22, 28],
           [49, 64]], dtype=int32)

    For convenience, in all cases you can do the same computation using
    the ``@`` operator:

    >>> x1 @ x2
    Array([[22, 28],
           [49, 64]], dtype=int32)
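
    Leading dimensions are broadcast against each other, so batched matrix
    products work directly; for example:

    >>> x1 = jnp.ones((2, 3, 4))
    >>> x2 = jnp.ones((4, 5))
    >>> jnp.linalg.matmul(x1, x2).shape
    (2, 3, 5)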
  """
  check_arraylike('jnp.linalg.matmul', x1, x2)
  return jnp.matmul(x1, x2, precision=precision,
                    preferred_element_type=preferred_element_type)


@export
def tensordot(x1: ArrayLike, x2: ArrayLike, /, *,
              axes: int | tuple[Sequence[int], Sequence[int]] = 2,
              precision: PrecisionLike = None,
              preferred_element_type: DTypeLike | None = None) -> Array:
  """Compute the tensor dot product of two N-dimensional arrays.

  JAX implementation of :func:`numpy.linalg.tensordot`.

  Args:
    x1: N-dimensional array
    x2: M-dimensional array
    axes: integer or tuple of sequences of integers. If an integer `k`, then
      sum over the last `k` axes of ``x1`` and the first `k` axes of ``x2``,
      in order. If a tuple, then ``axes[0]`` specifies the axes of ``x1`` and
      ``axes[1]`` specifies the axes of ``x2``.
    precision: either ``None`` (default), which means the default precision for
      the backend, a :class:`~jax.lax.Precision` enum value (``Precision.DEFAULT``,
      ``Precision.HIGH`` or ``Precision.HIGHEST``) or a tuple of two
      such values indicating precision of ``x1`` and ``x2``.
    preferred_element_type: either ``None`` (default), which means the default
      accumulation type for the input types, or a datatype, indicating to
      accumulate results to and return a result with that datatype.

  Returns:
    array containing the tensor dot product of the inputs

  See also:
    - :func:`jax.numpy.tensordot`: equivalent API in the :mod:`jax.numpy` namespace.
    - :func:`jax.numpy.einsum`: NumPy API for more general tensor contractions.
    - :func:`jax.lax.dot_general`: XLA API for more general tensor contractions.

  Examples:
    >>> x1 = jnp.arange(24.).reshape(2, 3, 4)
    >>> x2 = jnp.ones((3, 4, 5))
    >>> jnp.linalg.tensordot(x1, x2)
    Array([[ 66.,  66.,  66.,  66.,  66.],
           [210., 210., 210., 210., 210.]], dtype=float32)

    Equivalent result when specifying the axes as explicit sequences:

    >>> jnp.linalg.tensordot(x1, x2, axes=([1, 2], [0, 1]))
    Array([[ 66.,  66.,  66.,  66.,  66.],
           [210., 210., 210., 210., 210.]], dtype=float32)

    Equivalent result via :func:`~jax.numpy.einsum`:

    >>> jnp.einsum('ijk,jkm->im', x1, x2)
    Array([[ 66.,  66.,  66.,  66.,  66.],
           [210., 210., 210., 210., 210.]], dtype=float32)

    Setting ``axes=1`` for two-dimensional inputs is equivalent to a matrix
    multiplication:

    >>> x1 = jnp.array([[1, 2],
    ...                 [3, 4]])
    >>> x2 = jnp.array([[1, 2, 3],
    ...                 [4, 5, 6]])
    >>> jnp.linalg.tensordot(x1, x2, axes=1)
    Array([[ 9, 12, 15],
           [19, 26, 33]], dtype=int32)
    >>> x1 @ x2
    Array([[ 9, 12, 15],
           [19, 26, 33]], dtype=int32)

    Setting ``axes=0`` for one-dimensional inputs is equivalent to
    :func:`jax.numpy.linalg.outer`:

    >>> x1 = jnp.array([1, 2])
    >>> x2 = jnp.array([1, 2, 3])
    >>> jnp.linalg.tensordot(x1, x2, axes=0)
    Array([[1, 2, 3],
           [2, 4, 6]], dtype=int32)
    >>> jnp.linalg.outer(x1, x2)
    Array([[1, 2, 3],
           [2, 4, 6]], dtype=int32)
  """
  check_arraylike('jnp.linalg.tensordot', x1, x2)
  return jnp.tensordot(x1, x2, axes=axes, precision=precision,
                       preferred_element_type=preferred_element_type)


@export
def svdvals(x: ArrayLike, /) -> Array:
  """Compute the singular values of a matrix.

  JAX implementation of :func:`numpy.linalg.svdvals`.

  Args:
    x: array of shape ``(..., M, N)`` for which singular values will be computed.

  Returns:
    array of singular values of shape ``(..., K)`` with ``K = min(M, N)``.

  See also:
    :func:`jax.numpy.linalg.svd`: compute singular values and singular vectors

  Examples:
    >>> x = jnp.array([[1, 2, 3],
    ...                [4, 5, 6]])
    >>> jnp.linalg.svdvals(x)
    Array([9.508031 , 0.7728694], dtype=float32)
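
    Because this is a convenience wrapper, the result matches calling
    :func:`~jax.numpy.linalg.svd` with ``compute_uv=False``:

    >>> jnp.linalg.svd(x, compute_uv=False)
    Array([9.508031 , 0.7728694], dtype=float32)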
  """
  check_arraylike('jnp.linalg.svdvals', x)
  return svd(x, compute_uv=False, hermitian=False)


@export
def diagonal(x: ArrayLike, /, *, offset: int = 0) -> Array:
  """Extract the diagonal of a matrix or stack of matrices.

  JAX implementation of :func:`numpy.linalg.diagonal`.

  Args:
    x: array of shape ``(..., M, N)`` from which the diagonal will be extracted.
    offset: positive or negative offset from the main diagonal.

  Returns:
    Array of shape ``(..., K)`` where ``K`` is the length of the specified diagonal.

  See Also:
    - :func:`jax.numpy.diagonal`: more general functionality for extracting diagonals.
    - :func:`jax.numpy.diag`: create a diagonal matrix from values.

  Examples:
    Diagonals of a single matrix:

    >>> x = jnp.array([[1, 2, 3, 4],
    ...                [5, 6, 7, 8],
    ...                [9, 10, 11, 12]])
    >>> jnp.linalg.diagonal(x)
    Array([ 1,  6, 11], dtype=int32)
    >>> jnp.linalg.diagonal(x, offset=1)
    Array([ 2,  7, 12], dtype=int32)
    >>> jnp.linalg.diagonal(x, offset=-1)
    Array([ 5, 10], dtype=int32)

    Batched diagonals:

    >>> x = jnp.arange(24).reshape(2, 3, 4)
    >>> jnp.linalg.diagonal(x)
    Array([[ 0,  5, 10],
           [12, 17, 22]], dtype=int32)
  """
  check_arraylike('jnp.linalg.diagonal', x)
  return jnp.diagonal(x, offset=offset, axis1=-2, axis2=-1)


@export
def tensorinv(a: ArrayLike, ind: int = 2) -> Array:
  """Compute the tensor inverse of an array.

  JAX implementation of :func:`numpy.linalg.tensorinv`.

  This computes the inverse of the :func:`~jax.numpy.linalg.tensordot`
  operation with the same ``ind`` value.

  Args:
    a: array to be inverted. Must have ``prod(a.shape[:ind]) == prod(a.shape[ind:])``.
    ind: positive integer specifying the number of indices in the tensor product.

  Returns:
    array of shape ``(*a.shape[ind:], *a.shape[:ind])`` containing the
    tensor inverse of ``a``.

  See also:
    - :func:`jax.numpy.linalg.tensordot`
    - :func:`jax.numpy.linalg.tensorsolve`

  Examples:
    >>> key = jax.random.key(1337)
    >>> x = jax.random.normal(key, shape=(2, 2, 4))
    >>> xinv = jnp.linalg.tensorinv(x, 2)
    >>> xinv_x = jnp.linalg.tensordot(xinv, x, axes=2)
    >>> jnp.allclose(xinv_x, jnp.eye(4), atol=1E-4)
    Array(True, dtype=bool)
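
    Per the shape contract above, the result swaps the contracting and
    remaining dimensions of the input:

    >>> xinv.shape
    (4, 2, 2)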
  """
  check_arraylike("tensorinv", a)
  arr = jnp.asarray(a)
  ind = operator.index(ind)
  if ind <= 0:
    raise ValueError(f"ind must be a positive integer; got {ind=}")
  contracting_shape, batch_shape = arr.shape[:ind], arr.shape[ind:]
  flatshape = (math.prod(contracting_shape), math.prod(batch_shape))
  if flatshape[0] != flatshape[1]:
    raise ValueError("tensorinv is only possible when the product of the first"
                     " `ind` dimensions equals that of the remaining dimensions."
                     f" Got {arr.shape=} with {ind=}.")
  return inv(arr.reshape(flatshape)).reshape(*batch_shape, *contracting_shape)


@export
def tensorsolve(a: ArrayLike, b: ArrayLike, axes: tuple[int, ...] | None = None) -> Array:
  """Solve the tensor equation ``a x = b`` for x.

  JAX implementation of :func:`numpy.linalg.tensorsolve`.

  Args:
    a: input array. After reordering via ``axes`` (see below), shape must be
      ``(*b.shape, *x.shape)``.
    b: right-hand-side array.
    axes: optional tuple specifying axes of ``a`` that should be moved to the end.

  Returns:
    array ``x`` such that after reordering of axes of ``a``, ``tensordot(a, x, x.ndim)``
    is equivalent to ``b``.

  See also:
    - :func:`jax.numpy.linalg.tensordot`
    - :func:`jax.numpy.linalg.tensorinv`

  Examples:
    >>> key1, key2 = jax.random.split(jax.random.key(8675309))
    >>> a = jax.random.normal(key1, shape=(2, 2, 4))
    >>> b = jax.random.normal(key2, shape=(2, 2))
    >>> x = jnp.linalg.tensorsolve(a, b)
    >>> x.shape
    (4,)

    Now show that ``x`` can be used to reconstruct ``b`` using
    :func:`~jax.numpy.linalg.tensordot`:

    >>> b_reconstructed = jnp.linalg.tensordot(a, x, axes=x.ndim)
    >>> jnp.allclose(b, b_reconstructed)
    Array(True, dtype=bool)
  """
  check_arraylike("tensorsolve", a, b)
  a_arr, b_arr = jnp.asarray(a), jnp.asarray(b)
  if axes is not None:
    a_arr = jnp.moveaxis(a_arr, axes, len(axes) * (a_arr.ndim - 1,))
  out_shape = a_arr.shape[b_arr.ndim:]
  if a_arr.shape[:b_arr.ndim] != b_arr.shape:
    raise ValueError("After moving axes to end, leading shape of a must match shape of b."
                     f" Got a.shape={a_arr.shape}, b.shape={b_arr.shape}")
  if b_arr.size != math.prod(out_shape):
    raise ValueError("Input arrays must have prod(a.shape[:b.ndim]) == prod(a.shape[b.ndim:]);"
                     f" got a.shape={a_arr.shape}, b.ndim={b_arr.ndim}.")
  a_arr = a_arr.reshape(b_arr.size, math.prod(out_shape))
  return solve(a_arr, b_arr.ravel()).reshape(out_shape)


@export
def multi_dot(arrays: Sequence[ArrayLike], *, precision: PrecisionLike = None) -> Array:
  """Efficiently compute matrix products between a sequence of arrays.

  JAX implementation of :func:`numpy.linalg.multi_dot`.

  JAX internally uses the opt_einsum library to compute the most efficient
  operation order.

  Args:
    arrays: sequence of arrays. All must be two-dimensional, except the first
      and last which may be one-dimensional.
    precision: either ``None`` (default), which means the default precision for
      the backend, a :class:`~jax.lax.Precision` enum value (``Precision.DEFAULT``,
      ``Precision.HIGH`` or ``Precision.HIGHEST``).

  Returns:
    an array representing the equivalent of ``reduce(jnp.matmul, arrays)``, but
    evaluated in the optimal order.

  This function exists because the cost of computing sequences of matmul operations
  can differ vastly depending on the order in which the operations are evaluated.
  For a single matmul, the number of floating point operations (flops) required to
  compute a matrix product can be approximated this way:

  >>> def approx_flops(x, y):
  ...   # for 2D x and y, with x.shape[1] == y.shape[0]
  ...   return 2 * x.shape[0] * x.shape[1] * y.shape[1]

  Suppose we have three matrices that we'd like to multiply in sequence:

  >>> key1, key2, key3 = jax.random.split(jax.random.key(0), 3)
  >>> x = jax.random.normal(key1, shape=(200, 5))
  >>> y = jax.random.normal(key2, shape=(5, 100))
  >>> z = jax.random.normal(key3, shape=(100, 10))

  Because of associativity of matrix products, there are two orders in which we might
  evaluate the product ``x @ y @ z``, and both produce equivalent outputs up to floating
  point precision:

  >>> result1 = (x @ y) @ z
  >>> result2 = x @ (y @ z)
  >>> jnp.allclose(result1, result2, atol=1E-4)
  Array(True, dtype=bool)

  But the computational costs of these differ greatly:

  >>> print("(x @ y) @ z flops:", approx_flops(x, y) + approx_flops(x @ y, z))
  (x @ y) @ z flops: 600000
  >>> print("x @ (y @ z) flops:", approx_flops(y, z) + approx_flops(x, y @ z))
  x @ (y @ z) flops: 30000

  The second approach is about 20x more efficient in terms of estimated flops!

  ``multi_dot`` is a function that will automatically choose the fastest
  computational path for such problems:

  >>> result3 = jnp.linalg.multi_dot([x, y, z])
  >>> jnp.allclose(result1, result3, atol=1E-4)
  Array(True, dtype=bool)

  We can use JAX's :ref:`ahead-of-time-lowering` tools to estimate the total flops
  of each approach, and confirm that ``multi_dot`` is choosing the more efficient
  option:

  >>> jax.jit(lambda x, y, z: (x @ y) @ z).lower(x, y, z).cost_analysis()['flops']
  600000.0
  >>> jax.jit(lambda x, y, z: x @ (y @ z)).lower(x, y, z).cost_analysis()['flops']
  30000.0
  >>> jax.jit(jnp.linalg.multi_dot).lower([x, y, z]).cost_analysis()['flops']
  30000.0
  """
  check_arraylike('jnp.linalg.multi_dot', *arrays)
  arrs: list[Array] = list(map(jnp.asarray, arrays))
  if len(arrs) < 2:
    raise ValueError(f"multi_dot requires at least two arrays; got len(arrays)={len(arrs)}")
  if not (arrs[0].ndim in (1, 2) and arrs[-1].ndim in (1, 2) and
          all(a.ndim == 2 for a in arrs[1:-1])):
    raise ValueError("multi_dot: input arrays must all be two-dimensional, except for"
                     " the first and last array which may be 1 or 2 dimensional."
                     f" Got array shapes {[a.shape for a in arrs]}")
  if any(a.shape[-1] != b.shape[0] for a, b in zip(arrs[:-1], arrs[1:])):
    raise ValueError("multi_dot: last dimension of each array must match first dimension"
                     f" of following array. Got array shapes {[a.shape for a in arrs]}")
  einsum_axes: list[tuple[int, ...]] = [(i, i+1) for i in range(len(arrs))]
  if arrs[0].ndim == 1:
    einsum_axes[0] = einsum_axes[0][1:]
  if arrs[-1].ndim == 1:
    einsum_axes[-1] = einsum_axes[-1][:1]
  return jnp.einsum(*itertools.chain(*zip(arrs, einsum_axes)),  # type: ignore[call-overload]
                    optimize='auto', precision=precision)


@export
@partial(jit, static_argnames=['p'])
def cond(x: ArrayLike, p=None):
  """Compute the condition number of a matrix.

  JAX implementation of :func:`numpy.linalg.cond`.

  The condition number is defined as ``norm(x, p) * norm(inv(x), p)``. For ``p = 2``
  (the default), the condition number is the ratio of the largest to the smallest
  singular value.

  Args:
    x: array of shape ``(..., M, N)`` for which to compute the condition number.
    p: the order of the norm to use. One of ``{None, 1, -1, 2, -2, inf, -inf, 'fro'}``;
      see :func:`jax.numpy.linalg.norm` for the meaning of these. The default is ``p = None``,
      which is equivalent to ``p = 2``. If not in ``{None, 2, -2}`` then ``x`` must be square,
      i.e. ``M = N``.

  Returns:
    array of shape ``x.shape[:-2]`` containing the condition number.

  See also:
    :func:`jax.numpy.linalg.norm`

  Examples:

    Well-conditioned matrix:

    >>> x = jnp.array([[1, 2],
    ...                [2, 1]])
    >>> jnp.linalg.cond(x)
    Array(3., dtype=float32)

    Ill-conditioned matrix:

    >>> x = jnp.array([[1, 2],
    ...                [0, 0]])
    >>> jnp.linalg.cond(x)
    Array(inf, dtype=float32)
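
    For other values of ``p`` the matrix must be square. For example, the
    1-norm condition number, shown here for a diagonal matrix where the
    result is exact:

    >>> x = jnp.array([[2., 0.],
    ...                [0., 1.]])
    >>> jnp.linalg.cond(x, p=1)
    Array(2., dtype=float32)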
  """
  check_arraylike("cond", x)
  arr = jnp.asarray(x)
  if arr.ndim < 2:
    raise ValueError(f"jnp.linalg.cond: input array must be at least 2D; got {arr.shape=}")
  if arr.shape[-1] == 0 or arr.shape[-2] == 0:
    raise ValueError(f"jnp.linalg.cond: input array must not be empty; got {arr.shape=}")
  if p is None or p == 2:
    s = svdvals(x)
    return s[..., 0] / s[..., -1]
  elif p == -2:
    s = svdvals(x)
    r = s[..., -1] / s[..., 0]
  else:
    if arr.shape[-2] != arr.shape[-1]:
      raise ValueError(f"jnp.linalg.cond: for {p=}, array must be square; got {arr.shape=}")
    r = norm(x, ord=p, axis=(-2, -1)) * norm(inv(x), ord=p, axis=(-2, -1))
  # Convert NaNs to infs where the original array has no NaNs.
  return jnp.where(ufuncs.isnan(r) & ~ufuncs.isnan(x).any(axis=(-2, -1)), jnp.inf, r)


@export
def trace(x: ArrayLike, /, *,
          offset: int = 0, dtype: DTypeLike | None = None) -> Array:
  """Compute the trace of a matrix.

  JAX implementation of :func:`numpy.linalg.trace`.

  Args:
    x: array of shape ``(..., M, N)`` and whose innermost two
      dimensions form MxN matrices for which to take the trace.
    offset: positive or negative offset from the main diagonal
      (default: 0).
    dtype: data type of the returned array (default: ``None``). If ``None``,
      then output dtype will match the dtype of ``x``, promoted to default
      precision in the case of integer types.

  Returns:
    array of batched traces with shape ``x.shape[:-2]``

  See also:
    - :func:`jax.numpy.trace`: similar API in the ``jax.numpy`` namespace.

  Examples:
    Trace of a single matrix:

    >>> x = jnp.array([[1, 2, 3, 4],
    ...                [5, 6, 7, 8],
    ...                [9, 10, 11, 12]])
    >>> jnp.linalg.trace(x)
    Array(18, dtype=int32)
    >>> jnp.linalg.trace(x, offset=1)
    Array(21, dtype=int32)
    >>> jnp.linalg.trace(x, offset=-1, dtype="float32")
    Array(15., dtype=float32)

    Batched traces:

    >>> x = jnp.arange(24).reshape(2, 3, 4)
    >>> jnp.linalg.trace(x)
    Array([15, 51], dtype=int32)
  """
  check_arraylike('jnp.linalg.trace', x)
  return jnp.trace(x, offset=offset, axis1=-2, axis2=-1, dtype=dtype)