Reorder top-level functions in lax.linalg, and add/expand docstrings.

PiperOrigin-RevId: 726603731
Dan Foreman-Mackey 2025-02-13 12:57:11 -08:00 committed by jax authors
parent 5ebb7eb55d
commit c6c38fb852
3 changed files with 326 additions and 237 deletions


@@ -238,19 +238,24 @@ Linear algebra operators (jax.lax.linalg)
:toctree: _autosummary
cholesky
cholesky_update
eig
eigh
hessenberg
householder_product
lu
lu_pivots_to_permutation
qdwh
qr
schur
svd
SvdAlgorithm
symmetric_product
triangular_solve
tridiagonal
tridiagonal_solve
Argument classes
----------------


@@ -18,7 +18,7 @@ from collections.abc import Callable
import enum
from functools import partial
import math
from typing import Any, Literal, TypeVar, overload
from typing import Any, Literal, overload
import numpy as np
@@ -57,41 +57,11 @@ from jax._src.lib import gpu_solver # pylint:disable=unused-import # noqa: F40
from jax._src.lib import gpu_sparse # pylint:disable=unused-import # noqa: F401
from jax._src.lib import lapack # pylint:disable=unused-import # noqa: F401
TFun = TypeVar('TFun', bound=Callable[..., Any])
def _broadcasted_iotas(*sizes):
ones = (1,) * (len(sizes) - 1)
shapes = (util.tuple_insert(ones, i, s) for i, s in enumerate(sizes))
return [lax.broadcasted_iota('int32', shape, i) for i, shape in enumerate(shapes)]
def _tril(m: Array, k:int = 0) -> Array:
*_, N, M = m.shape
mask = lax_internal._tri(bool, (N, M), k)
return lax.select(lax.broadcast(mask, m.shape[:-2]), m, lax.zeros_like_array(m))
def _triu(m: Array, k:int = 0) -> Array:
*_, N, M = m.shape
mask = lax_internal._tri(bool, (N, M), k - 1)
return lax.select(lax.broadcast(mask, m.shape[:-2]), lax.zeros_like_array(m), m)
def _construct_diagonal(s: Array) -> Array:
"""Construct a (batched) diagonal matrix"""
i = lax.iota('int32', s.shape[-1])
return lax.full((*s.shape, s.shape[-1]), 0, s.dtype).at[..., i, i].set(s)
def _extract_diagonal(s: Array) -> Array:
"""Extract the diagonal from a batched matrix"""
i = lax.iota('int32', min(s.shape[-2], s.shape[-1]))
return s[..., i, i]
def _broadcast_to(x: Array, shape: tuple[int, ...]) -> Array:
assert x.ndim <= len(shape)
return lax.broadcast_in_dim(x, shape, range(len(shape) - x.ndim, len(shape)))
# traceables
# Top-level functions in alphabetical order.
def cholesky(x: Array, *, symmetrize_input: bool = True) -> Array:
"""Cholesky decomposition.
r"""Cholesky decomposition.
Computes the Cholesky decomposition
@@ -106,7 +76,7 @@ def cholesky(x: Array, *, symmetrize_input: bool = True) -> Array:
x: A batch of square Hermitian (symmetric if real) positive-definite
matrices with shape ``[..., n, n]``.
symmetrize_input: If ``True``, the matrix is symmetrized before Cholesky
decomposition by computing :math:`\\frac{1}{2}(x + x^H)`. If ``False``,
decomposition by computing :math:`\frac{1}{2}(x + x^H)`. If ``False``,
only the lower triangle of ``x`` is used; the upper triangle is ignored
and not accessed.
@@ -120,9 +90,31 @@
return _tril(cholesky_p.bind(x))
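A minimal usage sketch of the API above (illustrative values, not from the commit):

import jax.numpy as jnp
from jax.lax import linalg

a = jnp.array([[4.0, 1.0], [1.0, 3.0]])  # symmetric positive-definite
l = linalg.cholesky(a)                   # lower-triangular factor
assert jnp.allclose(l @ l.T, a, atol=1e-5)  # a == l @ l^T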
def eig(x: ArrayLike, *, compute_left_eigenvectors: bool = True,
compute_right_eigenvectors: bool = True,
use_magma: bool | None = None) -> list[Array]:
def cholesky_update(r_matrix: ArrayLike, w_vector: ArrayLike) -> Array:
r"""Cholesky rank-1 update.
Given a Cholesky decomposition :math:`A = R^T \, R` and a vector :math:`w`,
computes the Cholesky decomposition of :math:`A + w \, w^T` in :math:`O(N^2)`
time.
Args:
r_matrix: An upper-triangular matrix (R) such that :math:`A = R^T \, R`.
w_vector: A vector :math:`w` for rank-1 update.
Returns:
A new upper-triangular matrix :math:`R` defining the Cholesky decomposition
of :math:`A + w \, w^T`.
"""
return cholesky_update_p.bind(r_matrix, w_vector)
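A hedged sketch of a rank-1 update round trip (illustrative values; ``R`` is obtained by transposing the output of ``cholesky``, which returns the lower factor):

import jax.numpy as jnp
from jax.lax import linalg

a = jnp.array([[4.0, 1.0], [1.0, 3.0]])
w = jnp.array([0.5, -1.0])
r = linalg.cholesky(a).T                 # upper-triangular R with A = R^T R
r_new = linalg.cholesky_update(r, w)
assert jnp.allclose(r_new.T @ r_new, a + jnp.outer(w, w), atol=1e-4)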
def eig(
x: ArrayLike,
*,
compute_left_eigenvectors: bool = True,
compute_right_eigenvectors: bool = True,
use_magma: bool | None = None,
) -> list[Array]:
"""Eigendecomposition of a general matrix.
Nonsymmetric eigendecomposition is only implemented on CPU and GPU. On GPU,
@@ -201,10 +193,10 @@ def eigh(
sort_eigenvalues: If ``True``, the eigenvalues will be sorted in ascending
order. If ``False``, the eigenvalues are returned in an
implementation-defined order.
subset_by_index: Optional 2-tuple [start, end] indicating the range of
indices of eigenvalues to compute. For example, if ``subset_by_index`` =
[n-2, n], then ``eigh`` computes the two largest eigenvalues and their
eigenvectors.
Returns:
A tuple ``(v, w)``.
@@ -229,59 +221,47 @@
return v, w
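A minimal sketch of the ``(v, w)`` convention above (illustrative values):

import jax.numpy as jnp
from jax.lax import linalg

a = jnp.array([[2.0, 1.0], [1.0, 2.0]])
v, w = linalg.eigh(a)            # v: eigenvectors as columns, w: eigenvalues
assert jnp.allclose(a @ v, v * w, atol=1e-5)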
def cholesky_update(r_matrix: ArrayLike, w_vector: ArrayLike) -> Array:
"""Given a Cholesky decomposition A = R.T @ R and a vector w,
computes the Cholesky decomposition of A + w @ w.T in O(N^2) time.
def hessenberg(a: ArrayLike) -> tuple[Array, Array]:
"""Reduces a square matrix to upper Hessenberg form.
Currently implemented on CPU only.
Args:
r_matrix: An upper-triangular matrix (R) such that A = R.T @ R.
w_vector: A vector (w) for rank-1 update.
a: A floating point or complex square matrix or batch of matrices.
Returns:
A new R' matrix being the Cholesky decomposition of A + w @ w.T.
A ``(a, taus)`` pair, where the upper triangle and first subdiagonal of
``a`` contain the upper Hessenberg matrix, the elements below the first
subdiagonal contain the Householder reflectors, and ``taus`` contains the
scalar factors of the elementary Householder reflectors.
"""
return cholesky_update_p.bind(r_matrix, w_vector)
return hessenberg_p.bind(a)
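A hedged sketch of unpacking the result (illustrative values; requires a CPU backend per the docstring):

import jax.numpy as jnp
from jax.lax import linalg

m = jnp.arange(16.0).reshape(4, 4)
a, taus = linalg.hessenberg(m)
h = jnp.triu(a, -1)   # keep the upper triangle plus the first subdiagonal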
def symmetric_product(
a_matrix: ArrayLike, c_matrix: ArrayLike,
alpha: float = 1., beta: float = 0.,
symmetrize_output=False):
"""Computes C = alpha * A @ A.T + beta * C (where C is symmetric)."""
result = symmetric_product_p.bind(a_matrix, c_matrix, alpha=alpha, beta=beta)
if symmetrize_output:
upper_half = lax.transpose(
_tril(result, k=-1),
(*range(result.ndim - 2), result.ndim - 1, result.ndim - 2))
result = _tril(result, k=0) + upper_half
return result
def lu_pivots_to_permutation(pivots: ArrayLike, permutation_size: int) -> Array:
"""Converts the pivots (row swaps) returned by LU to a permutation.
We build a permutation rather than applying `pivots` directly to the rows
of a matrix because lax loops aren't differentiable.
def householder_product(a: ArrayLike, taus: ArrayLike) -> Array:
"""Product of elementary Householder reflectors.
Args:
pivots: an int32 array of shape (..., k) of row swaps to perform
permutation_size: the size of the output permutation. Has to be >= k.
a: A matrix with shape ``[..., m, n]``, whose lower triangle contains
elementary Householder reflectors.
taus: A vector with shape ``[..., k]``, where ``k < min(m, n)``, containing
the scalar factors of the elementary Householder reflectors.
Returns:
An int32 array of shape (..., permutation_size).
A batch of orthogonal (unitary) matrices with the same shape as ``a``,
containing the products of the elementary Householder reflectors.
"""
permutation = lu_pivots_to_permutation_p.bind(
pivots, permutation_size=permutation_size)
return permutation
return householder_product_p.bind(a, taus)
def lu(x: ArrayLike) -> tuple[Array, Array, Array]:
"""LU decomposition with partial pivoting.
r"""LU decomposition with partial pivoting.
Computes the matrix decomposition:
.. math::
P.A = L.U
P \, A = L \, U
where :math:`P` is a permutation of the rows of :math:`A`, :math:`L` is a
lower-triangular matrix with unit-diagonal elements, and :math:`U` is an
@@ -305,8 +285,24 @@ def lu(x: ArrayLike) -> tuple[Array, Array, Array]:
swaps as a permutation, represented as an int32 array with shape
``[..., m]``.
"""
lu, pivots, permutation = lu_p.bind(x)
return lu, pivots, permutation
return lu_p.bind(x)
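A minimal reconstruction sketch for the three outputs (illustrative values):

import jax.numpy as jnp
from jax.lax import linalg

a = jnp.array([[4.0, 3.0], [6.0, 3.0]])
lu_mat, pivots, perm = linalg.lu(a)
l = jnp.tril(lu_mat, -1) + jnp.eye(2)   # unit-diagonal L
u = jnp.triu(lu_mat)
assert jnp.allclose(a[perm], l @ u, atol=1e-5)  # P A = L U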
def lu_pivots_to_permutation(pivots: ArrayLike, permutation_size: int) -> Array:
"""Converts the pivots (row swaps) returned by LU to a permutation.
We build a permutation rather than applying `pivots` directly to the rows
of a matrix because lax loops aren't differentiable.
Args:
pivots: an int32 array of shape (..., k) of row swaps to perform
permutation_size: the size of the output permutation. Has to be >= k.
Returns:
An int32 array of shape (..., permutation_size).
"""
return lu_pivots_to_permutation_p.bind(
pivots, permutation_size=permutation_size)
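A sketch showing that the conversion reproduces the permutation already returned by ``lu`` (illustrative values):

import jax.numpy as jnp
from jax.lax import linalg

a = jnp.array([[4.0, 3.0], [6.0, 3.0]])
_, pivots, perm = linalg.lu(a)
perm2 = linalg.lu_pivots_to_permutation(pivots, permutation_size=2)
assert jnp.array_equal(perm, perm2)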
@overload
@@ -328,12 +324,12 @@ def qr(x: ArrayLike, *, pivoting: bool = False, full_matrices: bool = True,
def qr(x: ArrayLike, *, pivoting: bool = False, full_matrices: bool = True,
use_magma: bool | None = None
) -> tuple[Array, Array] | tuple[Array, Array, Array]:
"""QR decomposition.
r"""QR decomposition.
Computes the QR decomposition
.. math::
A = Q . R
A = Q \, R
of matrices :math:`A`, such that :math:`Q` is a unitary (orthogonal) matrix,
and :math:`R` is an upper-triangular matrix.
@@ -379,6 +375,42 @@ def qr(x: ArrayLike, *, pivoting: bool = False, full_matrices: bool = True,
return q, r
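A minimal sketch of the reduced decomposition (illustrative values):

import jax.numpy as jnp
from jax.lax import linalg

a = jnp.array([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]])
q, r = linalg.qr(a, full_matrices=False)
assert jnp.allclose(q @ r, a, atol=1e-5)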
def schur(
x: ArrayLike,
*,
compute_schur_vectors: bool = True,
sort_eig_vals: bool = False,
select_callable: Callable[..., Any] | None = None,
) -> tuple[Array, Array]:
r"""Schur decomposition.
Only implemented on CPU.
Computes the Schur decomposition:
.. math::
A = Q \, U \, Q^H
for a square matrix :math:`A`.
Args:
x: A batch of square matrices with shape ``[..., m, m]``.
compute_schur_vectors: If ``True``, compute the Schur vectors :math:`Q`;
otherwise only :math:`U` is computed.
sort_eig_vals: Unused.
select_callable: Unused.
Returns:
A pair of arrays ``(U, Q)`` if ``compute_schur_vectors=True``; otherwise
only ``U`` is returned.
"""
return schur_p.bind(
x,
compute_schur_vectors=compute_schur_vectors,
sort_eig_vals=sort_eig_vals,
select_callable=select_callable)
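A hedged round-trip sketch (illustrative values; CPU only per the docstring, with the ``(U, Q)`` output order taken from the Returns section above):

import jax.numpy as jnp
from jax.lax import linalg

a = jnp.array([[0.0, 1.0], [-2.0, 3.0]])  # real eigenvalues 1 and 2
u_mat, q_mat = linalg.schur(a)
assert jnp.allclose(q_mat @ u_mat @ q_mat.conj().T, a, atol=1e-4)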
class SvdAlgorithm(enum.Enum):
"""Enum for SVD algorithm."""
DEFAULT = "default"
@@ -433,9 +465,23 @@
) -> Array | tuple[Array, Array, Array]:
"""Singular value decomposition.
Returns the singular values if compute_uv is False, otherwise returns a triple
containing the left singular vectors, the singular values and the adjoint of
the right singular vectors.
Computes the singular value decomposition of an input matrix.
Args:
x: A batch of matrices with shape ``[..., m, n]``.
full_matrices: Determines if full or reduced matrices are returned.
compute_uv: If ``True``, returns the left singular vectors, the singular
values and the adjoint of the right singular vectors. Otherwise, only
the singular values are returned.
subset_by_index: If ``None``, the entire matrix is returned. Otherwise,
returns the singular values and vectors for the given range of indices.
algorithm: The SVD algorithm to use. Must be ``None`` or a value from
:class:`~jax.lax.linalg.SvdAlgorithm`.
Returns:
The singular values if ``compute_uv`` is ``False``, otherwise returns a
triple containing the left singular vectors, the singular values, and the
adjoint of the right singular vectors.
"""
result = svd_p.bind(
x,
@@ -452,10 +498,56 @@
return s
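A minimal sketch of both calling conventions (illustrative values):

import jax.numpy as jnp
from jax.lax import linalg

a = jnp.array([[1.0, 0.0], [0.0, 2.0], [0.0, 0.0]])
u, s, vh = linalg.svd(a, full_matrices=False)
assert jnp.allclose((u * s) @ vh, a, atol=1e-5)
s_only = linalg.svd(a, compute_uv=False)   # singular values only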
def triangular_solve(a: ArrayLike, b: ArrayLike, *,
left_side: bool = False, lower: bool = False,
transpose_a: bool = False, conjugate_a: bool = False,
unit_diagonal: bool = False) -> Array:
def symmetric_product(
a_matrix: ArrayLike,
c_matrix: ArrayLike,
*,
alpha: float = 1.,
beta: float = 0.,
symmetrize_output: bool = False
):
r"""Symmetric product.
Computes the symmetric product
.. math::
\alpha \, A \, A^T + \beta \, C
where :math:`A` is a rectangular matrix and :math:`C` is a symmetric matrix.
Args:
a_matrix: A batch of matrices with shape ``[..., m, n]``.
c_matrix: A batch of matrices with shape ``[..., m, m]``.
alpha: A scalar.
beta: A scalar.
symmetrize_output: If ``True``, the upper triangle of the output is
overwritten with the transpose of the lower triangle.
Returns:
A batch of matrices with shape ``[..., m, m]`` where only the lower
triangle is guaranteed to include the correct values on all platforms. If
``symmetrize_output`` is ``True``, the upper triangle is filled with the
transpose of the lower triangle, and the whole matrix is valid.
"""
result = symmetric_product_p.bind(a_matrix, c_matrix, alpha=alpha, beta=beta)
if symmetrize_output:
upper_half = lax.transpose(
_tril(result, k=-1),
(*range(result.ndim - 2), result.ndim - 1, result.ndim - 2))
result = _tril(result, k=0) + upper_half
return result
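A hedged sketch of the symmetrized output (illustrative values; platform support for the underlying primitive may vary):

import jax.numpy as jnp
from jax.lax import linalg

a = jnp.arange(6.0).reshape(2, 3)
c = jnp.eye(2)
out = linalg.symmetric_product(a, c, alpha=1.0, beta=0.5, symmetrize_output=True)
assert jnp.allclose(out, a @ a.T + 0.5 * c, atol=1e-4)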
def triangular_solve(
a: ArrayLike,
b: ArrayLike,
*,
left_side: bool = False,
lower: bool = False,
transpose_a: bool = False,
conjugate_a: bool = False,
unit_diagonal: bool = False,
) -> Array:
r"""Triangular solve.
Solves either the matrix equation
@@ -502,55 +594,69 @@ def triangular_solve(a: ArrayLike, b: ArrayLike, *,
return out
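A minimal sketch solving :math:`A x = b` with a lower-triangular :math:`A` (illustrative values):

import jax.numpy as jnp
from jax.lax import linalg

a = jnp.array([[2.0, 0.0], [1.0, 3.0]])  # lower-triangular
b = jnp.array([[2.0], [5.0]])
x = linalg.triangular_solve(a, b, left_side=True, lower=True)
assert jnp.allclose(a @ x, b, atol=1e-5)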
# utilities
def _broadcasted_matvec(a: Array, b: Array) -> Array:
# This is a broadcasted dot_general with signature (...,n,m),(...,m)->(...,n)
assert a.ndim >= 2
assert b.ndim >= 1
batch_shape = lax.broadcast_shapes(a.shape[:-2], b.shape[:-1])
n_batch = len(batch_shape)
a = _broadcast_to(a, (*batch_shape, *a.shape[-2:]))
b = _broadcast_to(b, (*batch_shape, b.shape[-1]))
def tridiagonal(
a: ArrayLike, *, lower: bool=True
) -> tuple[Array, Array, Array, Array]:
"""Reduces a symmetric/Hermitian matrix to tridiagonal form.
dimension_numbers = (([a.ndim - 1], [b.ndim - 1]), (list(range(n_batch)), list(range(n_batch))))
return lax.dot_general(a, b, dimension_numbers=dimension_numbers, precision=lax.Precision.HIGHEST)
Currently implemented on CPU and GPU only.
def _check_solve_shapes(a: Array, b: Array):
if not (a.ndim >= 2 and b.ndim in [a.ndim, a.ndim - 1] and
a.shape[-1] == a.shape[-2] == b.shape[a.ndim - 2]):
raise ValueError(
"The arguments to solve must have shapes a=[..., m, m] and "
f"b=[..., m, k] or b=[..., m]; got a={a.shape} and b={b.shape}")
Args:
a: A floating point or complex matrix or batch of matrices.
lower: Describes which triangle of the input matrices to use.
The other triangle is ignored and not accessed.
def _solve(a: Array, b: Array) -> Array:
_check_solve_shapes(a, b)
Returns:
A ``(a, d, e, taus)`` tuple. If ``lower=True``, the diagonal and first
subdiagonal of matrix (or batch of matrices) ``a`` contain the tridiagonal
representation, and elements below the first subdiagonal contain the
elementary Householder reflectors, where additionally ``d`` contains the
diagonal of the matrix and ``e`` contains the first subdiagonal. If
``lower=False`` the diagonal and first superdiagonal of the matrix contains
the tridiagonal representation, and elements above the first superdiagonal
contain the elementary Householder reflectors, where additionally ``d``
contains the diagonal of the matrix and ``e`` contains the first
superdiagonal. ``taus`` contains the scalar factors of the elementary
Householder reflectors.
"""
arr, d, e, taus, info = tridiagonal_p.bind(lax_internal.asarray(a), lower=lower)
def nans_like(arr):
if dtypes.issubdtype(arr.dtype, np.complexfloating):
return lax.full_like(arr, np.nan + 1j * np.nan)
return lax.full_like(arr, np.nan)
mask = lambda x: lax.broadcast_in_dim(info == 0, x.shape, range(info.ndim))
arr = lax.select(mask(arr), arr, nans_like(arr))
d = lax.select(mask(d), d, nans_like(d))
e = lax.select(mask(e), e, nans_like(e))
taus = lax.select(mask(taus), taus, nans_like(taus))
return arr, d, e, taus
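A hedged sketch of the four outputs (illustrative values; CPU and GPU only per the docstring):

import jax.numpy as jnp
from jax.lax import linalg

a = jnp.array([[2.0, 1.0], [1.0, 2.0]])
arr, d, e, taus = linalg.tridiagonal(a, lower=True)
# d: diagonal of the tridiagonal form, e: first subdiagonal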
# Broadcast leading dimensions of b to the shape of a, as is required by
# custom_linear_solve.
out_shape = tuple(d_a if d_b == 1 else d_b
for d_a, d_b in zip(a.shape[:-1] + (1,), b.shape))
b = lax.broadcast_in_dim(b, out_shape, range(b.ndim))
# With custom_linear_solve, we can reuse the same factorization when
# computing sensitivities. This is considerably faster.
lu_, _, permutation = lu(lax.stop_gradient(a))
custom_solve = partial(
lax.custom_linear_solve,
lambda x: _broadcasted_matvec(a, x),
solve=lambda _, x: lu_solve(lu_, permutation, x, trans=0),
transpose_solve=lambda _, x: lu_solve(lu_, permutation, x, trans=1))
if a.ndim == b.ndim + 1:
# b.shape == [..., m]
return custom_solve(b)
else:
# b.shape == [..., m, k]
return api.vmap(custom_solve, b.ndim - 1, max(a.ndim, b.ndim) - 1)(b)
def tridiagonal_solve(dl: Array, d: Array, du: Array, b: Array) -> Array:
r"""Computes the solution of a tridiagonal linear system.
This function computes the solution of a tridiagonal linear system:
.. math::
A \, X = B
Args:
dl: A batch of vectors with shape ``[..., m]``.
The lower diagonal of A: ``dl[i] := A[i, i-1]`` for i in ``[0,m)``.
Note that ``dl[0] = 0``.
d: A batch of vectors with shape ``[..., m]``.
The middle diagonal of A: ``d[i] := A[i, i]`` for i in ``[0,m)``.
du: A batch of vectors with shape ``[..., m]``.
The upper diagonal of A: ``du[i] := A[i, i+1]`` for i in ``[0,m)``.
Note that ``du[m - 1] = 0``.
b: Right hand side matrix.
Returns:
Solution ``X`` of tridiagonal system.
"""
return tridiagonal_solve_p.bind(dl, d, du, b)
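A minimal sketch that rebuilds :math:`A` from its three diagonals to check the solve (illustrative values):

import jax.numpy as jnp
from jax.lax import linalg

dl = jnp.array([0.0, 1.0, 1.0])   # subdiagonal; dl[0] = 0
d = jnp.array([4.0, 4.0, 4.0])    # main diagonal
du = jnp.array([1.0, 1.0, 0.0])   # superdiagonal; du[m-1] = 0
b = jnp.ones((3, 1))
x = linalg.tridiagonal_solve(dl, d, du, b)
a_full = jnp.diag(d) + jnp.diag(dl[1:], -1) + jnp.diag(du[:-1], 1)
assert jnp.allclose(a_full @ x, b, atol=1e-4)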
def _T(x: Array) -> Array:
return lax.transpose(x, (*range(x.ndim - 2), x.ndim - 1, x.ndim - 2))
def _H(x: Array) -> Array:
return _T(x).conj()
def symmetrize(x: Array) -> Array: return (x + _H(x)) / 2
# primitives
@@ -1941,22 +2047,6 @@ mlir.register_lowering(geqp3_p, partial(_geqp3_gpu_lowering, 'hip'), platform="r
# householder_product: product of elementary Householder reflectors
def householder_product(a: ArrayLike, taus: ArrayLike) -> Array:
"""Product of elementary Householder reflectors.
Args:
a: A matrix with shape ``[..., m, n]``, whose lower triangle contains
elementary Householder reflectors.
taus: A vector with shape ``[..., k]``, where ``k < min(m, n)``, containing
the scalar factors of the elementary Householder reflectors.
Returns:
A batch of orthogonal (unitary) matrices with the same shape as ``a``,
containing the products of the elementary Householder reflectors.
"""
return householder_product_p.bind(a, taus)
def _householder_product_abstract_eval(a, taus):
if not isinstance(a, ShapedArray) or not isinstance(taus, ShapedArray):
raise NotImplementedError("Unsupported aval in householder_product_abstract_eval: "
@@ -2681,46 +2771,8 @@ mlir.register_lowering(tridiagonal_solve_p, mlir.lower_fun(
_tridiagonal_solve_jax, multiple_results=False))
def tridiagonal_solve(dl: Array, d: Array, du: Array, b: Array) -> Array:
r"""Computes the solution of a tridiagonal linear system.
This function computes the solution of a tridiagonal linear system:
.. math::
A . X = B
Args:
dl: A batch of vectors with shape ``[..., m]``.
The lower diagonal of A: ``dl[i] := A[i, i-1]`` for i in ``[0,m)``.
Note that ``dl[0] = 0``.
d: A batch of vectors with shape ``[..., m]``.
The middle diagonal of A: ``d[i] := A[i, i]`` for i in ``[0,m)``.
du: A batch of vectors with shape ``[..., m]``.
The upper diagonal of A: ``du[i] := A[i, i+1]`` for i in ``[0,m)``.
Note that ``du[m - 1] = 0``.
b: Right hand side matrix.
Returns:
Solution ``X`` of tridiagonal system.
"""
return tridiagonal_solve_p.bind(dl, d, du, b)
# Schur Decomposition
def schur(x: ArrayLike, *,
compute_schur_vectors: bool = True,
sort_eig_vals: bool = False,
select_callable: Callable[..., Any] | None = None) -> tuple[Array, Array]:
return schur_p.bind(
x,
compute_schur_vectors=compute_schur_vectors,
sort_eig_vals=sort_eig_vals,
select_callable=select_callable)
def _schur_impl(operand, *, compute_schur_vectors, sort_eig_vals,
select_callable):
return dispatch.apply_primitive(
@@ -2827,23 +2879,6 @@ ad.primitive_jvps[schur_p] = _schur_jvp_rule
# hessenberg: Upper Hessenberg reduction
def hessenberg(a: ArrayLike) -> tuple[Array, Array]:
"""Reduces a square matrix to upper Hessenberg form.
Currently implemented on CPU only.
Args:
a: A floating point or complex square matrix or batch of matrices.
Returns:
A ``(a, taus)`` pair, where the upper triangle and first subdiagonal of ``a``
contain the upper Hessenberg matrix, and the elements below the first
subdiagonal contain the Householder reflectors. For each Householder
reflector ``taus`` contains the scalar factors of the elementary Householder
reflectors.
"""
return hessenberg_p.bind(a)
def _hessenberg_abstract_eval(a):
if a.dtype not in (np.float32, np.float64, np.complex64, np.complex128):
raise TypeError("hessenberg requires a.dtype to be float32, float64, "
@@ -2895,41 +2930,6 @@ mlir.register_lowering(hessenberg_p, _hessenberg_cpu_hlo, platform='cpu')
# tridiagonal: Upper Hessenberg reduction
def tridiagonal(a: ArrayLike, *, lower=True
) -> tuple[Array, Array, Array, Array]:
"""Reduces a symmetric/Hermitian matrix to tridiagonal form.
Currently implemented on CPU and GPU only.
Args:
a: A floating point or complex matrix or batch of matrices.
lower: Describes which triangle of the input matrices to use.
The other triangle is ignored and not accessed.
Returns:
A ``(a, d, e, taus)`` pair. If ``lower=True``, the diagonal and first subdiagonal of
matrix (or batch of matrices) ``a`` contain the tridiagonal representation,
and elements below the first subdiagonal contain the elementary Householder
reflectors, where additionally ``d`` contains the diagonal of the matrix and ``e`` contains
the first subdiagonal.If ``lower=False`` the diagonal and first superdiagonal of the
matrix contains the tridiagonal representation, and elements above the first
superdiagonal contain the elementary Householder reflectors, where
additionally ``d`` contains the diagonal of the matrix and ``e`` contains the
first superdiagonal. ``taus`` contains the scalar factors of the elementary
Householder reflectors.
"""
arr, d, e, taus, info = tridiagonal_p.bind(lax_internal.asarray(a), lower=lower)
def nans_like(arr):
if dtypes.issubdtype(arr.dtype, np.complexfloating):
return lax.full_like(arr, np.nan + 1j * np.nan)
return lax.full_like(arr, np.nan)
mask = lambda x: lax.broadcast_in_dim(info == 0, x.shape, range(info.ndim))
arr = lax.select(mask(arr), arr, nans_like(arr))
d = lax.select(mask(d), d, nans_like(d))
e = lax.select(mask(e), e, nans_like(e))
taus = lax.select(mask(taus), taus, nans_like(taus))
return arr, d, e, taus
def _tridiagonal_abstract_eval(a, *, lower):
if a.dtype not in (np.float32, np.float64, np.complex64, np.complex128):
raise TypeError("tridiagonal requires a.dtype to be float32, float64, "
@@ -2994,6 +2994,86 @@ mlir.register_lowering(
# Utilities
def _broadcasted_matvec(a: Array, b: Array) -> Array:
# This is a broadcasted dot_general with signature (...,n,m),(...,m)->(...,n)
assert a.ndim >= 2
assert b.ndim >= 1
batch_shape = lax.broadcast_shapes(a.shape[:-2], b.shape[:-1])
n_batch = len(batch_shape)
a = _broadcast_to(a, (*batch_shape, *a.shape[-2:]))
b = _broadcast_to(b, (*batch_shape, b.shape[-1]))
dimension_numbers = (([a.ndim - 1], [b.ndim - 1]), (list(range(n_batch)), list(range(n_batch))))
return lax.dot_general(a, b, dimension_numbers=dimension_numbers, precision=lax.Precision.HIGHEST)
def _check_solve_shapes(a: Array, b: Array):
if not (a.ndim >= 2 and b.ndim in [a.ndim, a.ndim - 1] and
a.shape[-1] == a.shape[-2] == b.shape[a.ndim - 2]):
raise ValueError(
"The arguments to solve must have shapes a=[..., m, m] and "
f"b=[..., m, k] or b=[..., m]; got a={a.shape} and b={b.shape}")
def _solve(a: Array, b: Array) -> Array:
_check_solve_shapes(a, b)
# Broadcast leading dimensions of b to the shape of a, as is required by
# custom_linear_solve.
out_shape = tuple(d_a if d_b == 1 else d_b
for d_a, d_b in zip(a.shape[:-1] + (1,), b.shape))
b = lax.broadcast_in_dim(b, out_shape, range(b.ndim))
# With custom_linear_solve, we can reuse the same factorization when
# computing sensitivities. This is considerably faster.
lu_, _, permutation = lu(lax.stop_gradient(a))
custom_solve = partial(
lax.custom_linear_solve,
lambda x: _broadcasted_matvec(a, x),
solve=lambda _, x: lu_solve(lu_, permutation, x, trans=0),
transpose_solve=lambda _, x: lu_solve(lu_, permutation, x, trans=1))
if a.ndim == b.ndim + 1:
# b.shape == [..., m]
return custom_solve(b)
else:
# b.shape == [..., m, k]
return api.vmap(custom_solve, b.ndim - 1, max(a.ndim, b.ndim) - 1)(b)
def _T(x: Array) -> Array:
return lax.transpose(x, (*range(x.ndim - 2), x.ndim - 1, x.ndim - 2))
def _H(x: Array) -> Array:
return _T(x).conj()
def symmetrize(x: Array) -> Array: return (x + _H(x)) / 2
def _broadcasted_iotas(*sizes):
ones = (1,) * (len(sizes) - 1)
shapes = (util.tuple_insert(ones, i, s) for i, s in enumerate(sizes))
return [lax.broadcasted_iota('int32', shape, i) for i, shape in enumerate(shapes)]
def _tril(m: Array, k:int = 0) -> Array:
*_, N, M = m.shape
mask = lax_internal._tri(bool, (N, M), k)
return lax.select(lax.broadcast(mask, m.shape[:-2]), m, lax.zeros_like_array(m))
def _triu(m: Array, k:int = 0) -> Array:
*_, N, M = m.shape
mask = lax_internal._tri(bool, (N, M), k - 1)
return lax.select(lax.broadcast(mask, m.shape[:-2]), lax.zeros_like_array(m), m)
def _construct_diagonal(s: Array) -> Array:
"""Construct a (batched) diagonal matrix"""
i = lax.iota('int32', s.shape[-1])
return lax.full((*s.shape, s.shape[-1]), 0, s.dtype).at[..., i, i].set(s)
def _extract_diagonal(s: Array) -> Array:
"""Extract the diagonal from a batched matrix"""
i = lax.iota('int32', min(s.shape[-2], s.shape[-1]))
return s[..., i, i]
def _broadcast_to(x: Array, shape: tuple[int, ...]) -> Array:
assert x.ndim <= len(shape)
return lax.broadcast_in_dim(x, shape, range(len(shape) - x.ndim, len(shape)))
def _nan_like_hlo(ctx: mlir.LoweringRuleContext, aval) -> ir.Value:
if dtypes.issubdtype(aval.dtype, np.complexfloating):
return mlir.full_like_aval(ctx, np.nan + np.nan * 1j, aval)


@@ -15,6 +15,8 @@
from jax._src.lax.linalg import (
cholesky as cholesky,
cholesky_p as cholesky_p,
cholesky_update as cholesky_update,
cholesky_update_p as cholesky_update_p,
eig as eig,
eig_p as eig_p,
eigh as eigh,
@@ -24,6 +26,7 @@ from jax._src.lax.linalg import (
lu as lu,
lu_p as lu_p,
lu_pivots_to_permutation as lu_pivots_to_permutation,
lu_pivots_to_permutation_p as lu_pivots_to_permutation_p,
householder_product as householder_product,
householder_product_p as householder_product_p,
qr as qr,
@@ -39,9 +42,10 @@ from jax._src.lax.linalg import (
tridiagonal_solve_p as tridiagonal_solve_p,
schur as schur,
schur_p as schur_p,
symmetric_product as symmetric_product,
symmetric_product_p as symmetric_product_p,
)
from jax._src.lax.qdwh import (
qdwh as qdwh
)