Reorder top-level functions in lax.linalg, and add/expand docstrings.

PiperOrigin-RevId: 726603731

commit c6c38fb852 (parent 5ebb7eb55d)

@@ -238,19 +238,24 @@ Linear algebra operators (jax.lax.linalg)

  :toctree: _autosummary

  cholesky
  cholesky_update
  eig
  eigh
  hessenberg
  lu
  householder_product
  lu
  lu_pivots_to_permutation
  qdwh
  qr
  schur
  svd
  SvdAlgorithm
  symmetric_product
  triangular_solve
  tridiagonal
  tridiagonal_solve


Argument classes
----------------

@@ -18,7 +18,7 @@ from collections.abc import Callable

import enum
from functools import partial
import math
from typing import Any, Literal, TypeVar, overload
from typing import Any, Literal, overload

import numpy as np

@@ -57,41 +57,11 @@ from jax._src.lib import gpu_solver # pylint:disable=unused-import # noqa: F40

from jax._src.lib import gpu_sparse # pylint:disable=unused-import # noqa: F401
from jax._src.lib import lapack # pylint:disable=unused-import # noqa: F401

TFun = TypeVar('TFun', bound=Callable[..., Any])


def _broadcasted_iotas(*sizes):
  ones = (1,) * (len(sizes) - 1)
  shapes = (util.tuple_insert(ones, i, s) for i, s in enumerate(sizes))
  return [lax.broadcasted_iota('int32', shape, i) for i, shape in enumerate(shapes)]


def _tril(m: Array, k:int = 0) -> Array:
  *_, N, M = m.shape
  mask = lax_internal._tri(bool, (N, M), k)
  return lax.select(lax.broadcast(mask, m.shape[:-2]), m, lax.zeros_like_array(m))


def _triu(m: Array, k:int = 0) -> Array:
  *_, N, M = m.shape
  mask = lax_internal._tri(bool, (N, M), k - 1)
  return lax.select(lax.broadcast(mask, m.shape[:-2]), lax.zeros_like_array(m), m)


def _construct_diagonal(s: Array) -> Array:
  """Construct a (batched) diagonal matrix"""
  i = lax.iota('int32', s.shape[-1])
  return lax.full((*s.shape, s.shape[-1]), 0, s.dtype).at[..., i, i].set(s)


def _extract_diagonal(s: Array) -> Array:
  """Extract the diagonal from a batched matrix"""
  i = lax.iota('int32', min(s.shape[-2], s.shape[-1]))
  return s[..., i, i]


def _broadcast_to(x: Array, shape: tuple[int, ...]) -> Array:
  assert x.ndim <= len(shape)
  return lax.broadcast_in_dim(x, shape, range(len(shape) - x.ndim, len(shape)))


# traceables
# Top-level functions in alphabetical order.


def cholesky(x: Array, *, symmetrize_input: bool = True) -> Array:
  """Cholesky decomposition.
  r"""Cholesky decomposition.

  Computes the Cholesky decomposition

@@ -106,7 +76,7 @@ def cholesky(x: Array, *, symmetrize_input: bool = True) -> Array:

    x: A batch of square Hermitian (symmetric if real) positive-definite
      matrices with shape ``[..., n, n]``.
    symmetrize_input: If ``True``, the matrix is symmetrized before Cholesky
      decomposition by computing :math:`\\frac{1}{2}(x + x^H)`. If ``False``,
      decomposition by computing :math:`\frac{1}{2}(x + x^H)`. If ``False``,
      only the lower triangle of ``x`` is used; the upper triangle is ignored
      and not accessed.

@@ -120,9 +90,31 @@ def cholesky(x: Array, *, symmetrize_input: bool = True) -> Array:

  return _tril(cholesky_p.bind(x))
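
Editor's usage sketch (not from the commit), assuming the public ``jax.lax.linalg`` wrappers and a small symmetric positive-definite input:

  import jax.numpy as jnp
  from jax.lax import linalg

  a = jnp.array([[4.0, 2.0],
                 [2.0, 3.0]])   # Hermitian (symmetric) positive definite
  l = linalg.cholesky(a)        # lower-triangular factor, a ~= l @ l.T
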

def eig(x: ArrayLike, *, compute_left_eigenvectors: bool = True,
        compute_right_eigenvectors: bool = True,
        use_magma: bool | None = None) -> list[Array]:
def cholesky_update(r_matrix: ArrayLike, w_vector: ArrayLike) -> Array:
  r"""Cholesky rank-1 update.

  Given a Cholesky decomposition :math:`A = R.T \, R` and a vector :math:`w`,
  computes the Cholesky decomposition of :math:`A + w \, w.T` in :math:`O(N^2)`
  time.

  Args:
    r_matrix: An upper-triangular matrix (R) such that :math:`A = R^T \, R`.
    w_vector: A vector :math:`w` for rank-1 update.

  Returns:
    A new upper-triangular matrix :math:`R` defining the Cholesky decomposition
    of :math:`A + w \, w^T`.
  """
  return cholesky_update_p.bind(r_matrix, w_vector)
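
Editor's usage sketch (not from the commit); it assumes the running backend provides a ``cholesky_update`` lowering and reuses the matrix from the ``cholesky`` example above:

  import jax.numpy as jnp
  from jax.lax import linalg

  a = jnp.array([[4.0, 2.0], [2.0, 3.0]])
  r = linalg.cholesky(a).T              # upper-triangular R with a ~= r.T @ r
  w = jnp.array([1.0, 0.5])
  r_new = linalg.cholesky_update(r, w)  # factor of a + jnp.outer(w, w)
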

def eig(
    x: ArrayLike,
    *,
    compute_left_eigenvectors: bool = True,
    compute_right_eigenvectors: bool = True,
    use_magma: bool | None = None,
) -> list[Array]:
  """Eigendecomposition of a general matrix.

  Nonsymmetric eigendecomposition is only implemented on CPU and GPU. On GPU,

@@ -201,10 +193,10 @@ def eigh(

    sort_eigenvalues: If ``True``, the eigenvalues will be sorted in ascending
      order. If ``False`` the eigenvalues are returned in an
      implementation-defined order.
    subset_by_index: Optional 2-tuple [start, end] indicating the range of
      indices of eigenvalues to compute. For example, if ``subset_by_index`` =
      [n-2, n], then ``eigh`` computes the two largest eigenvalues and their
      eigenvectors.

  Returns:
    A tuple ``(v, w)``.

@@ -229,59 +221,47 @@ def eigh(

  return v, w
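
Editor's usage sketch (not from the commit); note the ``(v, w)`` return order documented above, eigenvectors first:

  import jax.numpy as jnp
  from jax.lax import linalg

  a = jnp.array([[2.0, 1.0],
                 [1.0, 2.0]])
  v, w = linalg.eigh(a)   # columns of v are eigenvectors, w holds eigenvalues
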

def cholesky_update(r_matrix: ArrayLike, w_vector: ArrayLike) -> Array:
  """Given a Cholesky decomposition A = R.T @ R and a vector w,
  computes the Cholesky decomposition of A + w @ w.T in O(N^2) time.
def hessenberg(a: ArrayLike) -> tuple[Array, Array]:
  """Reduces a square matrix to upper Hessenberg form.

  Currently implemented on CPU only.

  Args:
    r_matrix: An upper-triangular matrix (R) such that A = R.T @ R.
    w_vector: A vector (w) for rank-1 update.
    a: A floating point or complex square matrix or batch of matrices.

  Returns:
    A new R' matrix being the Cholesky decomposition of A + w @ w.T.
    A ``(a, taus)`` pair, where the upper triangle and first subdiagonal of
    ``a`` contain the upper Hessenberg matrix, and the elements below the first
    subdiagonal contain the Householder reflectors. ``taus`` contains the
    scalar factor of each elementary Householder reflector.
  """
  return cholesky_update_p.bind(r_matrix, w_vector)
  return hessenberg_p.bind(a)
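
Editor's usage sketch (not from the commit); per the docstring this path is CPU-only and the result uses the packed LAPACK-style layout:

  import jax.numpy as jnp
  from jax.lax import linalg

  a = jnp.arange(16.0).reshape(4, 4)
  h, taus = linalg.hessenberg(a)   # packed Hessenberg matrix plus reflector scalars
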

def symmetric_product(
    a_matrix: ArrayLike, c_matrix: ArrayLike,
    alpha: float = 1., beta: float = 0.,
    symmetrize_output=False):
  """Computes C = alpha * A @ A.T + beta * C (where C is symmetric)."""
  result = symmetric_product_p.bind(a_matrix, c_matrix, alpha=alpha, beta=beta)
  if symmetrize_output:
    upper_half = lax.transpose(
        _tril(result, k=-1),
        (*range(result.ndim - 2), result.ndim - 1, result.ndim - 2))
    result = _tril(result, k=0) + upper_half
  return result


def lu_pivots_to_permutation(pivots: ArrayLike, permutation_size: int) -> Array:
  """Converts the pivots (row swaps) returned by LU to a permutation.

  We build a permutation rather than applying `pivots` directly to the rows
  of a matrix because lax loops aren't differentiable.
def householder_product(a: ArrayLike, taus: ArrayLike) -> Array:
  """Product of elementary Householder reflectors.

  Args:
    pivots: an int32 array of shape (..., k) of row swaps to perform
    permutation_size: the size of the output permutation. Has to be >= k.
    a: A matrix with shape ``[..., m, n]``, whose lower triangle contains
      elementary Householder reflectors.
    taus: A vector with shape ``[..., k]``, where ``k < min(m, n)``, containing
      the scalar factors of the elementary Householder reflectors.

  Returns:
    An int32 array of shape (..., permutation_size).
    A batch of orthogonal (unitary) matrices with the same shape as ``a``,
    containing the products of the elementary Householder reflectors.
  """
  permutation = lu_pivots_to_permutation_p.bind(
      pivots, permutation_size=permutation_size)
  return permutation
  return householder_product_p.bind(a, taus)
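
Editor's usage sketch (not from the commit); it assumes ``geqrf`` from the same module to produce the packed reflectors that ``householder_product`` assembles into Q:

  import jax.numpy as jnp
  from jax.lax import linalg

  a = jnp.array([[1.0, 2.0],
                 [3.0, 4.0],
                 [5.0, 6.0]])
  packed, taus = linalg.geqrf(a)                # packed QR: R above, reflectors below
  q = linalg.householder_product(packed, taus)  # first n columns of the orthogonal Q
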

def lu(x: ArrayLike) -> tuple[Array, Array, Array]:
  """LU decomposition with partial pivoting.
  r"""LU decomposition with partial pivoting.

  Computes the matrix decomposition:

  .. math::
    P.A = L.U
    P \, A = L \, U

  where :math:`P` is a permutation of the rows of :math:`A`, :math:`L` is a
  lower-triangular matrix with unit-diagonal elements, and :math:`U` is an

@@ -305,8 +285,24 @@ def lu(x: ArrayLike) -> tuple[Array, Array, Array]:

      swaps as a permutation, represented as an int32 array with shape
      ``[..., m]``.
  """
  lu, pivots, permutation = lu_p.bind(x)
  return lu, pivots, permutation
  return lu_p.bind(x)
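
Editor's usage sketch (not from the commit), unpacking the combined LU factor the way the docstring describes:

  import jax.numpy as jnp
  from jax.lax import linalg

  a = jnp.array([[0.0, 1.0],
                 [2.0, 3.0]])
  lu_mat, pivots, perm = linalg.lu(a)
  l = jnp.tril(lu_mat, -1) + jnp.eye(2)   # unit lower-triangular factor
  u = jnp.triu(lu_mat)                    # upper-triangular factor
  # a[perm] ~= l @ u
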

def lu_pivots_to_permutation(pivots: ArrayLike, permutation_size: int) -> Array:
  """Converts the pivots (row swaps) returned by LU to a permutation.

  We build a permutation rather than applying `pivots` directly to the rows
  of a matrix because lax loops aren't differentiable.

  Args:
    pivots: an int32 array of shape (..., k) of row swaps to perform
    permutation_size: the size of the output permutation. Has to be >= k.

  Returns:
    An int32 array of shape (..., permutation_size).
  """
  return lu_pivots_to_permutation_p.bind(
      pivots, permutation_size=permutation_size)
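
Editor's usage sketch (not from the commit), showing how sequential row swaps become a permutation:

  import jax.numpy as jnp
  from jax.lax import linalg

  pivots = jnp.array([2, 2, 2], dtype=jnp.int32)     # swap row i with row pivots[i]
  perm = linalg.lu_pivots_to_permutation(pivots, 3)  # -> [2, 0, 1]
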

@overload

@@ -328,12 +324,12 @@ def qr(x: ArrayLike, *, pivoting: bool = False, full_matrices: bool = True,

def qr(x: ArrayLike, *, pivoting: bool = False, full_matrices: bool = True,
       use_magma: bool | None = None
       ) -> tuple[Array, Array] | tuple[Array, Array, Array]:
  """QR decomposition.
  r"""QR decomposition.

  Computes the QR decomposition

  .. math::
    A = Q . R
    A = Q \, R

  of matrices :math:`A`, such that :math:`Q` is a unitary (orthogonal) matrix,
  and :math:`R` is an upper-triangular matrix.

@@ -379,6 +375,42 @@ def qr(x: ArrayLike, *, pivoting: bool = False, full_matrices: bool = True,

  return q, r
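
Editor's usage sketch (not from the commit), using the reduced form:

  import jax.numpy as jnp
  from jax.lax import linalg

  a = jnp.array([[1.0, 2.0],
                 [3.0, 4.0],
                 [5.0, 6.0]])
  q, r = linalg.qr(a, full_matrices=False)   # q: (3, 2), r: (2, 2), a ~= q @ r
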

def schur(
    x: ArrayLike,
    *,
    compute_schur_vectors: bool = True,
    sort_eig_vals: bool = False,
    select_callable: Callable[..., Any] | None = None,
) -> tuple[Array, Array]:
  r"""Schur decomposition.

  Only implemented on CPU.

  Computes the Schur decomposition:

  .. math::
    A = Q \, U \, Q^{-H}

  for a square matrix :math:`A`.

  Args:
    x: A batch of square matrices with shape ``[..., m, m]``.
    compute_schur_vectors: If ``True``, compute the Schur vectors :math:`Q`,
      otherwise only :math:`U` is computed.
    sort_eig_vals: Unused.
    select_callable: Unused.

  Returns:
    A pair of arrays ``(U, Q)`` if ``compute_schur_vectors=True``; otherwise
    only ``U`` is returned.
  """
  return schur_p.bind(
      x,
      compute_schur_vectors=compute_schur_vectors,
      sort_eig_vals=sort_eig_vals,
      select_callable=select_callable)
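
Editor's usage sketch (not from the commit); CPU-only per the docstring, with the ``(U, Q)`` return order documented above:

  import jax.numpy as jnp
  from jax.lax import linalg

  a = jnp.array([[0.0, 1.0],
                 [-2.0, -3.0]])
  u, q = linalg.schur(a)   # a ~= q @ u @ q.T with u quasi-upper-triangular
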

class SvdAlgorithm(enum.Enum):
  """Enum for SVD algorithm."""
  DEFAULT = "default"

@@ -433,9 +465,23 @@ def svd(

) -> Array | tuple[Array, Array, Array]:
  """Singular value decomposition.

  Returns the singular values if compute_uv is False, otherwise returns a triple
  containing the left singular vectors, the singular values and the adjoint of
  the right singular vectors.
  Computes the singular value decomposition of an input matrix.

  Args:
    x: A batch of matrices with shape ``[..., m, n]``.
    full_matrices: Determines if full or reduced matrices are returned.
    compute_uv: If ``True``, returns the left singular vectors, the singular
      values and the adjoint of the right singular vectors. Otherwise, only
      the singular values are returned.
    subset_by_index: If ``None``, all singular values and vectors are computed.
      Otherwise, only the singular values and vectors for the given range of
      indices are returned.
    algorithm: The SVD algorithm to use. Must be ``None`` or a value from
      :class:`~jax.lax.linalg.SvdAlgorithm`.

  Returns:
    The singular values if ``compute_uv`` is ``False``, otherwise returns a
    triple containing the left singular vectors, the singular values, and the
    adjoint of the right singular vectors.
  """
  result = svd_p.bind(
      x,

@@ -452,10 +498,56 @@ def svd(

  return s
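
Editor's usage sketch (not from the commit), showing both return modes described above:

  import jax.numpy as jnp
  from jax.lax import linalg

  a = jnp.array([[3.0, 0.0],
                 [0.0, 2.0],
                 [0.0, 0.0]])
  u, s, vt = linalg.svd(a, full_matrices=False)   # a ~= u @ jnp.diag(s) @ vt
  s_only = linalg.svd(a, compute_uv=False)        # singular values only
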

def triangular_solve(a: ArrayLike, b: ArrayLike, *,
                     left_side: bool = False, lower: bool = False,
                     transpose_a: bool = False, conjugate_a: bool = False,
                     unit_diagonal: bool = False) -> Array:
def symmetric_product(
    a_matrix: ArrayLike,
    c_matrix: ArrayLike,
    *,
    alpha: float = 1.,
    beta: float = 0.,
    symmetrize_output: bool = False
):
  r"""Symmetric product.

  Computes the symmetric product

  .. math::
    \alpha \, A \, A^T + \beta \, C

  where :math:`A` is a rectangular matrix and :math:`C` is a symmetric matrix.

  Args:
    a_matrix: A batch of matrices with shape ``[..., m, n]``.
    c_matrix: A batch of matrices with shape ``[..., m, m]``.
    alpha: A scalar.
    beta: A scalar.
    symmetrize_output: If ``True``, the upper triangle of the output is
      replaced with its transpose.

  Returns:
    A batch of matrices with shape ``[..., m, m]`` where only the lower
    triangle is guaranteed to include the correct values on all platforms. If
    ``symmetrize_output`` is ``True``, the upper triangle is filled with the
    transpose of the lower triangle, and the whole matrix is valid.
  """
  result = symmetric_product_p.bind(a_matrix, c_matrix, alpha=alpha, beta=beta)
  if symmetrize_output:
    upper_half = lax.transpose(
        _tril(result, k=-1),
        (*range(result.ndim - 2), result.ndim - 1, result.ndim - 2))
    result = _tril(result, k=0) + upper_half
  return result
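
Editor's usage sketch (not from the commit); platform support for this primitive may vary, so treat it as illustrative only:

  import jax.numpy as jnp
  from jax.lax import linalg

  a = jnp.array([[1.0, 2.0],
                 [3.0, 4.0]])
  c = jnp.zeros((2, 2))
  out = linalg.symmetric_product(a, c, symmetrize_output=True)   # a @ a.T
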

def triangular_solve(
    a: ArrayLike,
    b: ArrayLike,
    *,
    left_side: bool = False,
    lower: bool = False,
    transpose_a: bool = False,
    conjugate_a: bool = False,
    unit_diagonal: bool = False,
) -> Array:
  r"""Triangular solve.

  Solves either the matrix equation

@@ -502,55 +594,69 @@ def triangular_solve(a: ArrayLike, b: ArrayLike, *,

  return out
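
Editor's usage sketch (not from the commit), solving ``a @ x = b`` with a lower-triangular ``a``:

  import jax.numpy as jnp
  from jax.lax import linalg

  a = jnp.array([[2.0, 0.0],
                 [1.0, 3.0]])
  b = jnp.array([[2.0],
                 [5.0]])
  x = linalg.triangular_solve(a, b, left_side=True, lower=True)   # solves a @ x = b
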

# utilities
def _broadcasted_matvec(a: Array, b: Array) -> Array:
  # This is a broadcasted dot_general with signature (...,n,m),(...,m)->(...,n)
  assert a.ndim >= 2
  assert b.ndim >= 1
  batch_shape = lax.broadcast_shapes(a.shape[:-2], b.shape[:-1])
  n_batch = len(batch_shape)
  a = _broadcast_to(a, (*batch_shape, *a.shape[-2:]))
  b = _broadcast_to(b, (*batch_shape, b.shape[-1]))
def tridiagonal(
    a: ArrayLike, *, lower: bool=True
) -> tuple[Array, Array, Array, Array]:
  """Reduces a symmetric/Hermitian matrix to tridiagonal form.

  dimension_numbers = (([a.ndim - 1], [b.ndim - 1]), (list(range(n_batch)), list(range(n_batch))))
  return lax.dot_general(a, b, dimension_numbers=dimension_numbers, precision=lax.Precision.HIGHEST)
  Currently implemented on CPU and GPU only.

def _check_solve_shapes(a: Array, b: Array):
  if not (a.ndim >= 2 and b.ndim in [a.ndim, a.ndim - 1] and
          a.shape[-1] == a.shape[-2] == b.shape[a.ndim - 2]):
    raise ValueError(
        "The arguments to solve must have shapes a=[..., m, m] and "
        f"b=[..., m, k] or b=[..., m]; got a={a.shape} and b={b.shape}")
  Args:
    a: A floating point or complex matrix or batch of matrices.
    lower: Describes which triangle of the input matrices to use.
      The other triangle is ignored and not accessed.

def _solve(a: Array, b: Array) -> Array:
  _check_solve_shapes(a, b)
  Returns:
    A ``(a, d, e, taus)`` tuple. If ``lower=True``, the diagonal and first
    subdiagonal of matrix (or batch of matrices) ``a`` contain the tridiagonal
    representation, and elements below the first subdiagonal contain the
    elementary Householder reflectors, where additionally ``d`` contains the
    diagonal of the matrix and ``e`` contains the first subdiagonal. If
    ``lower=False``, the diagonal and first superdiagonal of the matrix contain
    the tridiagonal representation, and elements above the first superdiagonal
    contain the elementary Householder reflectors, where additionally ``d``
    contains the diagonal of the matrix and ``e`` contains the first
    superdiagonal. ``taus`` contains the scalar factors of the elementary
    Householder reflectors.
  """
  arr, d, e, taus, info = tridiagonal_p.bind(lax_internal.asarray(a), lower=lower)
  def nans_like(arr):
    if dtypes.issubdtype(arr.dtype, np.complexfloating):
      return lax.full_like(arr, np.nan + 1j * np.nan)
    return lax.full_like(arr, np.nan)
  mask = lambda x: lax.broadcast_in_dim(info == 0, x.shape, range(info.ndim))
  arr = lax.select(mask(arr), arr, nans_like(arr))
  d = lax.select(mask(d), d, nans_like(d))
  e = lax.select(mask(e), e, nans_like(e))
  taus = lax.select(mask(taus), taus, nans_like(taus))
  return arr, d, e, taus
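
Editor's usage sketch (not from the commit); CPU/GPU-only per the docstring, and the result uses the packed layout described above:

  import jax.numpy as jnp
  from jax.lax import linalg

  a = jnp.array([[2.0, 1.0, 0.0],
                 [1.0, 2.0, 1.0],
                 [0.0, 1.0, 2.0]])
  packed, d, e, taus = linalg.tridiagonal(a, lower=True)   # d: diagonal, e: off-diagonal
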

  # Broadcast leading dimensions of b to the shape of a, as is required by
  # custom_linear_solve.
  out_shape = tuple(d_a if d_b == 1 else d_b
                    for d_a, d_b in zip(a.shape[:-1] + (1,), b.shape))
  b = lax.broadcast_in_dim(b, out_shape, range(b.ndim))

  # With custom_linear_solve, we can reuse the same factorization when
  # computing sensitivities. This is considerably faster.
  lu_, _, permutation = lu(lax.stop_gradient(a))
  custom_solve = partial(
      lax.custom_linear_solve,
      lambda x: _broadcasted_matvec(a, x),
      solve=lambda _, x: lu_solve(lu_, permutation, x, trans=0),
      transpose_solve=lambda _, x: lu_solve(lu_, permutation, x, trans=1))
  if a.ndim == b.ndim + 1:
    # b.shape == [..., m]
    return custom_solve(b)
  else:
    # b.shape == [..., m, k]
    return api.vmap(custom_solve, b.ndim - 1, max(a.ndim, b.ndim) - 1)(b)
def tridiagonal_solve(dl: Array, d: Array, du: Array, b: Array) -> Array:
  r"""Computes the solution of a tridiagonal linear system.

  This function computes the solution of a tridiagonal linear system:

  .. math::
    A \, X = B

  Args:
    dl: A batch of vectors with shape ``[..., m]``.
      The lower diagonal of A: ``dl[i] := A[i, i-1]`` for i in ``[0,m)``.
      Note that ``dl[0] = 0``.
    d: A batch of vectors with shape ``[..., m]``.
      The middle diagonal of A: ``d[i] := A[i, i]`` for i in ``[0,m)``.
    du: A batch of vectors with shape ``[..., m]``.
      The upper diagonal of A: ``du[i] := A[i, i+1]`` for i in ``[0,m)``.
      Note that ``du[m - 1] = 0``.
    b: Right hand side matrix.

  Returns:
    Solution ``X`` of tridiagonal system.
  """
  return tridiagonal_solve_p.bind(dl, d, du, b)
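
Editor's usage sketch (not from the commit); the right-hand side carries a trailing column dimension, and the unused corner entries of ``dl``/``du`` are zero as documented:

  import jax.numpy as jnp
  from jax.lax import linalg

  dl = jnp.array([0.0, 1.0, 1.0])          # sub-diagonal, dl[0] == 0
  d  = jnp.array([2.0, 2.0, 2.0])          # main diagonal
  du = jnp.array([1.0, 1.0, 0.0])          # super-diagonal, du[-1] == 0
  b  = jnp.array([[3.0], [4.0], [3.0]])    # shape [m, 1]
  x = linalg.tridiagonal_solve(dl, d, du, b)   # -> [[1.], [1.], [1.]]
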

def _T(x: Array) -> Array:
  return lax.transpose(x, (*range(x.ndim - 2), x.ndim - 1, x.ndim - 2))
def _H(x: Array) -> Array:
  return _T(x).conj()
def symmetrize(x: Array) -> Array: return (x + _H(x)) / 2

# primitives

@@ -1941,22 +2047,6 @@ mlir.register_lowering(geqp3_p, partial(_geqp3_gpu_lowering, 'hip'), platform="r

# householder_product: product of elementary Householder reflectors

def householder_product(a: ArrayLike, taus: ArrayLike) -> Array:
  """Product of elementary Householder reflectors.

  Args:
    a: A matrix with shape ``[..., m, n]``, whose lower triangle contains
      elementary Householder reflectors.
    taus: A vector with shape ``[..., k]``, where ``k < min(m, n)``, containing
      the scalar factors of the elementary Householder reflectors.

  Returns:
    A batch of orthogonal (unitary) matrices with the same shape as ``a``,
    containing the products of the elementary Householder reflectors.
  """
  return householder_product_p.bind(a, taus)


def _householder_product_abstract_eval(a, taus):
  if not isinstance(a, ShapedArray) or not isinstance(taus, ShapedArray):
    raise NotImplementedError("Unsupported aval in householder_product_abstract_eval: "

@@ -2681,46 +2771,8 @@ mlir.register_lowering(tridiagonal_solve_p, mlir.lower_fun(

    _tridiagonal_solve_jax, multiple_results=False))


def tridiagonal_solve(dl: Array, d: Array, du: Array, b: Array) -> Array:
  r"""Computes the solution of a tridiagonal linear system.

  This function computes the solution of a tridiagonal linear system:

  .. math::
    A . X = B

  Args:

    dl: A batch of vectors with shape ``[..., m]``.
      The lower diagonal of A: ``dl[i] := A[i, i-1]`` for i in ``[0,m)``.
      Note that ``dl[0] = 0``.
    d: A batch of vectors with shape ``[..., m]``.
      The middle diagonal of A: ``d[i] := A[i, i]`` for i in ``[0,m)``.
    du: A batch of vectors with shape ``[..., m]``.
      The upper diagonal of A: ``du[i] := A[i, i+1]`` for i in ``[0,m)``.
      Note that ``dl[m - 1] = 0``.
    b: Right hand side matrix.

  Returns:
    Solution ``X`` of tridiagonal system.
  """
  return tridiagonal_solve_p.bind(dl, d, du, b)


# Schur Decomposition


def schur(x: ArrayLike, *,
          compute_schur_vectors: bool = True,
          sort_eig_vals: bool = False,
          select_callable: Callable[..., Any] | None = None) -> tuple[Array, Array]:
  return schur_p.bind(
      x,
      compute_schur_vectors=compute_schur_vectors,
      sort_eig_vals=sort_eig_vals,
      select_callable=select_callable)


def _schur_impl(operand, *, compute_schur_vectors, sort_eig_vals,
                select_callable):
  return dispatch.apply_primitive(

@@ -2827,23 +2879,6 @@ ad.primitive_jvps[schur_p] = _schur_jvp_rule

# hessenberg: Upper Hessenberg reduction

def hessenberg(a: ArrayLike) -> tuple[Array, Array]:
  """Reduces a square matrix to upper Hessenberg form.

  Currently implemented on CPU only.

  Args:
    a: A floating point or complex square matrix or batch of matrices.

  Returns:
    A ``(a, taus)`` pair, where the upper triangle and first subdiagonal of ``a``
    contain the upper Hessenberg matrix, and the elements below the first
    subdiagonal contain the Householder reflectors. For each Householder
    reflector ``taus`` contains the scalar factors of the elementary Householder
    reflectors.
  """
  return hessenberg_p.bind(a)

def _hessenberg_abstract_eval(a):
  if a.dtype not in (np.float32, np.float64, np.complex64, np.complex128):
    raise TypeError("hessenberg requires a.dtype to be float32, float64, "

@@ -2895,41 +2930,6 @@ mlir.register_lowering(hessenberg_p, _hessenberg_cpu_hlo, platform='cpu')

# tridiagonal: Upper Hessenberg reduction

def tridiagonal(a: ArrayLike, *, lower=True
                ) -> tuple[Array, Array, Array, Array]:
  """Reduces a symmetric/Hermitian matrix to tridiagonal form.

  Currently implemented on CPU and GPU only.

  Args:
    a: A floating point or complex matrix or batch of matrices.
    lower: Describes which triangle of the input matrices to use.
      The other triangle is ignored and not accessed.

  Returns:
    A ``(a, d, e, taus)`` pair. If ``lower=True``, the diagonal and first subdiagonal of
    matrix (or batch of matrices) ``a`` contain the tridiagonal representation,
    and elements below the first subdiagonal contain the elementary Householder
    reflectors, where additionally ``d`` contains the diagonal of the matrix and ``e`` contains
    the first subdiagonal.If ``lower=False`` the diagonal and first superdiagonal of the
    matrix contains the tridiagonal representation, and elements above the first
    superdiagonal contain the elementary Householder reflectors, where
    additionally ``d`` contains the diagonal of the matrix and ``e`` contains the
    first superdiagonal. ``taus`` contains the scalar factors of the elementary
    Householder reflectors.
  """
  arr, d, e, taus, info = tridiagonal_p.bind(lax_internal.asarray(a), lower=lower)
  def nans_like(arr):
    if dtypes.issubdtype(arr.dtype, np.complexfloating):
      return lax.full_like(arr, np.nan + 1j * np.nan)
    return lax.full_like(arr, np.nan)
  mask = lambda x: lax.broadcast_in_dim(info == 0, x.shape, range(info.ndim))
  arr = lax.select(mask(arr), arr, nans_like(arr))
  d = lax.select(mask(d), d, nans_like(d))
  e = lax.select(mask(e), e, nans_like(e))
  taus = lax.select(mask(taus), taus, nans_like(taus))
  return arr, d, e, taus

def _tridiagonal_abstract_eval(a, *, lower):
  if a.dtype not in (np.float32, np.float64, np.complex64, np.complex128):
    raise TypeError("tridiagonal requires a.dtype to be float32, float64, "

@@ -2994,6 +2994,86 @@ mlir.register_lowering(

# Utilities

def _broadcasted_matvec(a: Array, b: Array) -> Array:
  # This is a broadcasted dot_general with signature (...,n,m),(...,m)->(...,n)
  assert a.ndim >= 2
  assert b.ndim >= 1
  batch_shape = lax.broadcast_shapes(a.shape[:-2], b.shape[:-1])
  n_batch = len(batch_shape)
  a = _broadcast_to(a, (*batch_shape, *a.shape[-2:]))
  b = _broadcast_to(b, (*batch_shape, b.shape[-1]))

  dimension_numbers = (([a.ndim - 1], [b.ndim - 1]), (list(range(n_batch)), list(range(n_batch))))
  return lax.dot_general(a, b, dimension_numbers=dimension_numbers, precision=lax.Precision.HIGHEST)


def _check_solve_shapes(a: Array, b: Array):
  if not (a.ndim >= 2 and b.ndim in [a.ndim, a.ndim - 1] and
          a.shape[-1] == a.shape[-2] == b.shape[a.ndim - 2]):
    raise ValueError(
        "The arguments to solve must have shapes a=[..., m, m] and "
        f"b=[..., m, k] or b=[..., m]; got a={a.shape} and b={b.shape}")


def _solve(a: Array, b: Array) -> Array:
  _check_solve_shapes(a, b)

  # Broadcast leading dimensions of b to the shape of a, as is required by
  # custom_linear_solve.
  out_shape = tuple(d_a if d_b == 1 else d_b
                    for d_a, d_b in zip(a.shape[:-1] + (1,), b.shape))
  b = lax.broadcast_in_dim(b, out_shape, range(b.ndim))

  # With custom_linear_solve, we can reuse the same factorization when
  # computing sensitivities. This is considerably faster.
  lu_, _, permutation = lu(lax.stop_gradient(a))
  custom_solve = partial(
      lax.custom_linear_solve,
      lambda x: _broadcasted_matvec(a, x),
      solve=lambda _, x: lu_solve(lu_, permutation, x, trans=0),
      transpose_solve=lambda _, x: lu_solve(lu_, permutation, x, trans=1))
  if a.ndim == b.ndim + 1:
    # b.shape == [..., m]
    return custom_solve(b)
  else:
    # b.shape == [..., m, k]
    return api.vmap(custom_solve, b.ndim - 1, max(a.ndim, b.ndim) - 1)(b)


def _T(x: Array) -> Array:
  return lax.transpose(x, (*range(x.ndim - 2), x.ndim - 1, x.ndim - 2))


def _H(x: Array) -> Array:
  return _T(x).conj()


def symmetrize(x: Array) -> Array: return (x + _H(x)) / 2


def _broadcasted_iotas(*sizes):
  ones = (1,) * (len(sizes) - 1)
  shapes = (util.tuple_insert(ones, i, s) for i, s in enumerate(sizes))
  return [lax.broadcasted_iota('int32', shape, i) for i, shape in enumerate(shapes)]


def _tril(m: Array, k:int = 0) -> Array:
  *_, N, M = m.shape
  mask = lax_internal._tri(bool, (N, M), k)
  return lax.select(lax.broadcast(mask, m.shape[:-2]), m, lax.zeros_like_array(m))


def _triu(m: Array, k:int = 0) -> Array:
  *_, N, M = m.shape
  mask = lax_internal._tri(bool, (N, M), k - 1)
  return lax.select(lax.broadcast(mask, m.shape[:-2]), lax.zeros_like_array(m), m)


def _construct_diagonal(s: Array) -> Array:
  """Construct a (batched) diagonal matrix"""
  i = lax.iota('int32', s.shape[-1])
  return lax.full((*s.shape, s.shape[-1]), 0, s.dtype).at[..., i, i].set(s)


def _extract_diagonal(s: Array) -> Array:
  """Extract the diagonal from a batched matrix"""
  i = lax.iota('int32', min(s.shape[-2], s.shape[-1]))
  return s[..., i, i]


def _broadcast_to(x: Array, shape: tuple[int, ...]) -> Array:
  assert x.ndim <= len(shape)
  return lax.broadcast_in_dim(x, shape, range(len(shape) - x.ndim, len(shape)))


def _nan_like_hlo(ctx: mlir.LoweringRuleContext, aval) -> ir.Value:
  if dtypes.issubdtype(aval.dtype, np.complexfloating):
    return mlir.full_like_aval(ctx, np.nan + np.nan * 1j, aval)

@@ -15,6 +15,8 @@

from jax._src.lax.linalg import (
  cholesky as cholesky,
  cholesky_p as cholesky_p,
  cholesky_update as cholesky_update,
  cholesky_update_p as cholesky_update_p,
  eig as eig,
  eig_p as eig_p,
  eigh as eigh,

@@ -24,6 +26,7 @@ from jax._src.lax.linalg import (

  lu as lu,
  lu_p as lu_p,
  lu_pivots_to_permutation as lu_pivots_to_permutation,
  lu_pivots_to_permutation_p as lu_pivots_to_permutation_p,
  householder_product as householder_product,
  householder_product_p as householder_product_p,
  qr as qr,

@@ -39,9 +42,10 @@ from jax._src.lax.linalg import (

  tridiagonal_solve_p as tridiagonal_solve_p,
  schur as schur,
  schur_p as schur_p,
  symmetric_product as symmetric_product,
  symmetric_product_p as symmetric_product_p,
)


from jax._src.lax.qdwh import (
  qdwh as qdwh
)