mirror of
https://github.com/ROCm/jax.git
synced 2025-04-16 03:46:06 +00:00
Corrections to GMRES - now gives correct result.

Co-authored-by: gehring <clement.gehring@gmail.com>
Co-authored-by: Stephan Hoyer <shoyer@google.com>
This commit is contained in:
parent 342cc36051
commit 7ed9fe70ea
@@ -18,42 +18,59 @@ import operator

 import numpy as np

 import jax.numpy as jnp
+from jax import scipy as jsp
-from jax import lax, device_put, jit
-from jax.tree_util import tree_leaves, tree_map, tree_multimap, tree_structure, tree_reduce
+from jax import lax, device_put
+from jax.tree_util import (tree_leaves, tree_map, tree_multimap, tree_structure,
+                           tree_reduce, Partial)
 from jax.util import safe_map as map


 _dot = partial(jnp.dot, precision=lax.Precision.HIGHEST)
 _vdot = partial(jnp.vdot, precision=lax.Precision.HIGHEST)


-# aliases for working with pytrees
 def _vdot_real_part(x, y):
-  """Vector dot-product guaranteed to have a real valued result."""
+  """Vector dot-product guaranteed to have a real valued result despite
+  possibly complex input. Thus neglects the real-imaginary cross-terms.
+  The result is a real float.
+  """
   # all our uses of vdot() in CG are for computing an operator of the form
-  # `z^T M z` where `M` is positive definite and Hermitian, so the result is
+  #  z^H M z
+  # where M is positive definite and Hermitian, so the result is
   # real valued:
   # https://en.wikipedia.org/wiki/Definiteness_of_a_matrix#Definitions_for_complex_matrices
-  result = _vdot(x.real, y.real)
+  vdot = partial(jnp.vdot, precision=lax.Precision.HIGHEST)
+  result = vdot(x.real, y.real)
   if jnp.iscomplexobj(x) or jnp.iscomplexobj(y):
-    result += _vdot(x.imag, y.imag)
+    result += vdot(x.imag, y.imag)
   return result


+# aliases for working with pytrees
+def _vdot_real_tree(x, y):
+  return sum(tree_leaves(tree_multimap(_vdot_real_part, x, y)))
+
+
+def _norm_tree(x):
+  return jnp.sqrt(_vdot_real_tree(x, x))
+
+
+def _vdot_tree(x, y):
+  return sum(tree_leaves(tree_multimap(_vdot, x, y)))
-def _vdot_tree(x, y, assume_real=True):
-  if assume_real:
-    return sum(tree_leaves(tree_multimap(_vdot_real_part, x, y)))
-  else:
-    return sum(tree_leaves(tree_multimap(_vdot, x, y)))


 def _mul(scalar, tree):
   return tree_map(partial(operator.mul, scalar), tree)


+def _div(tree, scalar):
+  return tree_map(partial(lambda v: v / scalar), tree)
+
+
 _add = partial(tree_multimap, operator.add)
 _sub = partial(tree_multimap, operator.sub)
+_dot_tree = partial(tree_multimap, _dot)


+@Partial
 def _identity(x):
   return x
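Note: `_vdot_real_part` treats a complex vector as a real vector of twice the
length, dropping the real-imaginary cross-terms. A quick numerical check of
the underlying identity (plain jax.numpy; not part of the diff):

    import jax.numpy as jnp

    x = jnp.array([1. + 2.j, 3. - 1.j])
    # vdot(x, x) is real: |1+2j|^2 + |3-1j|^2 = 5 + 10 = 15
    print(jnp.vdot(x, x).real)                                  # 15.0
    print(jnp.vdot(x.real, x.real) + jnp.vdot(x.imag, x.imag))  # 15.0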
@@ -61,31 +78,31 @@ def _identity(x):
 def _cg_solve(A, b, x0=None, *, maxiter, tol=1e-5, atol=0.0, M=_identity):

   # tolerance handling uses the "non-legacy" behavior of scipy.sparse.linalg.cg
-  bs = _vdot_tree(b, b)
+  bs = _vdot_real_tree(b, b)
   atol2 = jnp.maximum(jnp.square(tol) * bs, jnp.square(atol))

   # https://en.wikipedia.org/wiki/Conjugate_gradient_method#The_preconditioned_conjugate_gradient_method

   def cond_fun(value):
-    x, r, gamma, p, k = value
-    rs = gamma if M is _identity else _vdot_tree(r, r)
+    _, r, gamma, _, k = value
+    rs = gamma if M is _identity else _vdot_real_tree(r, r)
     return (rs > atol2) & (k < maxiter)

   def body_fun(value):
     x, r, gamma, p, k = value
     Ap = A(p)
-    alpha = gamma / _vdot_tree(p, Ap)
+    alpha = gamma / _vdot_real_tree(p, Ap)
     x_ = _add(x, _mul(alpha, p))
     r_ = _sub(r, _mul(alpha, Ap))
     z_ = M(r_)
-    gamma_ = _vdot_tree(r_, z_)
+    gamma_ = _vdot_real_tree(r_, z_)
     beta_ = gamma_ / gamma
     p_ = _add(z_, _mul(beta_, p))
     return x_, r_, gamma_, p_, k + 1

   r0 = _sub(b, A(x0))
   p0 = z0 = M(r0)
-  gamma0 = _vdot_tree(r0, z0)
+  gamma0 = _vdot_real_tree(r0, z0)
   initial_value = (x0, r0, gamma0, p0, 0)

   x_final, *_ = lax.while_loop(cond_fun, body_fun, initial_value)
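For context, the public `cg` wrapper around `_cg_solve` is called like this
(a small usage sketch; not part of the diff):

    import jax.numpy as jnp
    from jax.scipy.sparse.linalg import cg

    A = jnp.array([[4., 1.], [1., 3.]])   # symmetric positive definite
    b = jnp.array([1., 2.])
    x, info = cg(lambda v: A @ v, b, tol=1e-8)
    print(jnp.allclose(A @ x, b, atol=1e-6))  # True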
@@ -174,8 +191,10 @@ def cg(A, b, x0=None, *, tol=1e-5, atol=0.0, maxiter=None, M=None):

   cg_solve = partial(
       _cg_solve, x0=x0, tol=tol, atol=atol, maxiter=maxiter, M=M)

   # real-valued positive-definite linear operators are symmetric
-  real_valued = lambda x: not issubclass(x.dtype.type, np.complexfloating)
+  def real_valued(x):
+    return not issubclass(x.dtype.type, np.complexfloating)
   symmetric = all(map(real_valued, tree_leaves(b)))
   x = lax.custom_linear_solve(
       A, b, solve=cg_solve, transpose_solve=cg_solve, symmetric=symmetric)
@@ -183,23 +202,25 @@ def cg(A, b, x0=None, *, tol=1e-5, atol=0.0, maxiter=None, M=None):
   return x, info


-def _project_on_columns(A, v):
-  v_proj = tree_multimap(
-      lambda X, y: jnp.einsum("...n,...->n", X.conj(), y),
-      A,
-      v,
-  )
-  return tree_reduce(operator.add, v_proj)
-
-
-def _safe_normalize(x, return_norm=False):
-  norm = jnp.sqrt(_vdot_tree(x, x, assume_real=False))
+def _safe_normalize(x, return_norm=False, thresh=None):
+  """
+  Returns the L2-normalized vector (which can be a pytree) x, and optionally
+  the computed norm. If the computed norm is less than the threshold `thresh`,
+  which by default is the machine precision of x's dtype, it will be
+  taken to be 0, and the normalized x to be the zero vector.
+  """
+  norm = _norm_tree(x)
+  dtype = jnp.result_type(*tree_leaves(x))
+  if thresh is None:
+    thresh = jnp.finfo(norm.dtype).eps
+  thresh = thresh.astype(dtype).real

   normalized_x, norm = lax.cond(
-      norm > 1e-12,
-      lambda y: (tree_map(lambda v: v / norm, y), norm),
-      lambda y: (y, 0.),
+      norm > thresh,
+      lambda y: (_div(y, norm), norm),
+      lambda y: (tree_map(jnp.zeros_like, y), jnp.zeros((),
+                                                        dtype=thresh.dtype)),
       x,
   )
   if return_norm:
     return normalized_x, norm
@@ -207,8 +228,41 @@ def _safe_normalize(x, return_norm=False):
     return normalized_x


-def _iterative_classical_gram_schmidt(Q, x, iterations=2):
-  """Orthogonalize x against the columns of Q."""
+def _project_on_columns(A, v):
+  """
+  Returns A.T.conj() @ v.
+  """
+  v_proj = tree_multimap(
+      lambda X, y: jnp.einsum("...n,...->n", X.conj(), y),
+      A,
+      v,
+  )
+  return tree_reduce(operator.add, v_proj)
+
+
+def _iterative_classical_gram_schmidt(Q, x, max_iterations=2):
+  """
+  Orthogonalize x against the columns of Q. The process is repeated
+  up to `max_iterations` times, or fewer if the condition
+  ||r|| < (1/sqrt(2)) ||x|| is met earlier (see below for the meaning
+  of r and x).
+
+  Parameters
+  ----------
+  Q : array or tree of arrays
+      A matrix of orthonormal columns.
+  x : array or tree of arrays
+      A vector. It will be replaced with a new vector q which is orthonormal
+      to the columns of Q, such that x in span(col(Q), q).
+
+  Returns
+  -------
+  q : array or tree of arrays
+      A unit vector, orthonormal to each column of Q, such that
+      x in span(col(Q), q).
+  r : array
+      Stores the overlaps of x with each vector in Q.
+  """
   # "twice is enough"
   # http://slepc.upv.es/documentation/reports/str1.pdf
@@ -216,122 +270,340 @@ def _iterative_classical_gram_schmidt(Q, x, iterations=2):
   # axis.
   r = jnp.zeros((tree_leaves(Q)[0].shape[-1]))
   q = x
+  _, xnorm = _safe_normalize(x, return_norm=True)
+  xnorm_scaled = xnorm / jnp.sqrt(2)

-  for _ in range(iterations):
+  def body_function(carry):
+    k, q, r, qnorm_scaled = carry
     h = _project_on_columns(Q, q)
-    q = _sub(q, tree_map(lambda X: jnp.dot(X, h), Q))
+    Qh = tree_map(lambda X: _dot_tree(X, h), Q)
+    q = _sub(q, Qh)
     r = _add(r, h)
+
+    def qnorm_cond(carry):
+      k, not_done, _, _ = carry
+      return jnp.logical_and(not_done, k < (max_iterations - 1))
+
+    def qnorm(carry):
+      k, _, q, qnorm_scaled = carry
+      _, qnorm = _safe_normalize(q, return_norm=True)
+      qnorm_scaled = qnorm / jnp.sqrt(2)
+      return (k, False, q, qnorm_scaled)
+
+    init = (k, True, q, qnorm_scaled)
+    _, _, q, qnorm_scaled = lax.while_loop(qnorm_cond, qnorm, init)
+    return (k + 1, q, r, qnorm_scaled)
+
+  def cond_function(carry):
+    k, _, r, qnorm_scaled = carry
+    _, rnorm = _safe_normalize(r, return_norm=True)
+    return jnp.logical_and(k < (max_iterations - 1), rnorm < qnorm_scaled)
+
+  k, q, r, qnorm_scaled = body_function((0, q, r, xnorm_scaled))
+  k, q, r, _ = lax.while_loop(cond_function, body_function,
+                              (k, q, r, qnorm_scaled))
   return q, r
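The loop above is iterated classical Gram-Schmidt with the "twice is enough"
reorthogonalization test. A plain NumPy sketch of the same idea, with
illustrative names (`classical_gram_schmidt` is not part of the diff):

    import numpy as np

    def classical_gram_schmidt(Q, x, max_iterations=2):
        # Q: (n, m) with orthonormal columns; x: (n,).
        r = np.zeros(Q.shape[1], dtype=x.dtype)
        q = x.copy()
        prev_norm = np.linalg.norm(x)
        for _ in range(max_iterations):
            h = Q.conj().T @ q   # overlaps of q with the columns of Q
            q = q - Q @ h        # remove those components from q
            r = r + h
            new_norm = np.linalg.norm(q)
            if new_norm > prev_norm / np.sqrt(2):
                break            # little cancellation: q is orthogonal enough
            prev_norm = new_norm
        return q, r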
-def arnoldi_iteration(A, b, n, M=None):
-  # https://en.wikipedia.org/wiki/Arnoldi_iteration#The_Arnoldi_iteration
-  if M is None:
-    M = _identity
-  q = _safe_normalize(b)
-  Q = tree_map(
-      lambda x: jnp.pad(x[..., None], ((0, 0),) * x.ndim + ((0, n),)),
-      q,
-  )
-  H = jnp.zeros((n, n + 1), jnp.result_type(*tree_leaves(b)))
+def kth_arnoldi_iteration(k, A, M, V, H, tol):
+  """
+  Performs a single (the k'th) step of the Arnoldi process. Thus,
+  adds a new orthonormalized Krylov vector A(M(V[:, k])) to V[:, k+1],
+  and that vector's overlaps with the existing Krylov vectors to
+  H[k, :]. The tolerance 'tol' sets the threshold at which an invariant
+  subspace is declared to have been found, in which case the new
+  vector is taken to be the zero vector.
+  """

-  def step(carry, k):
-    Q, H = carry
-    q = tree_map(lambda x: x[..., k], Q)
-    v = A(M(q))
-    v, h = _iterative_classical_gram_schmidt(Q, v, iterations=1)
-    v, v_norm = _safe_normalize(v, return_norm=True)
-    Q = tree_multimap(lambda X, y: X.at[..., k + 1].set(y), Q, v)
-    h = h.at[k + 1].set(v_norm)
-    H = H.at[k, :].set(h)
-    return (Q, H), None
+  v = tree_map(lambda x: x[..., k], V)  # Gets V[:, k]
+  v = A(M(v))
+  v, h = _iterative_classical_gram_schmidt(V, v, max_iterations=2)
+  unit_v, v_norm = _safe_normalize(v, return_norm=True, thresh=tol)
+  V = tree_multimap(lambda X, y: X.at[..., k + 1].set(y), V, unit_v)

-  (Q, H), _ = lax.scan(step, (Q, H), jnp.arange(n))
-  return Q, H
+  h = h.at[k + 1].set(v_norm)
+  H = H.at[k, :].set(h)
+  breakdown = v_norm == 0.
+  return V, H, breakdown
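After k steps the Arnoldi process satisfies the textbook relation
A V_k = V_{k+1} H_k (this diff stores H transposed, writing rows H[k, :]).
A NumPy sketch of that invariant, illustrative only:

    import numpy as np

    rng = np.random.default_rng(0)
    n, k = 6, 4
    A = rng.standard_normal((n, n))

    # Plain Arnoldi: V has orthonormal columns, H is (k+1) x k Hessenberg.
    V = np.zeros((n, k + 1))
    H = np.zeros((k + 1, k))
    V[:, 0] = rng.standard_normal(n)
    V[:, 0] /= np.linalg.norm(V[:, 0])
    for j in range(k):
        w = A @ V[:, j]
        H[:j + 1, j] = V[:, :j + 1].T @ w   # overlaps with the basis so far
        w -= V[:, :j + 1] @ H[:j + 1, j]    # orthogonalize
        H[j + 1, j] = np.linalg.norm(w)
        V[:, j + 1] = w / H[j + 1, j]

    print(np.allclose(A @ V[:, :k], V @ H))  # True: the Arnoldi relation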
-@jit
-def _lstsq(a, b):
-  return jnp.linalg.lstsq(a, b)[0]
+def apply_givens_rotations(H_row, givens, k):
+  """
+  Applies the Givens rotations stored in `givens` to the vector H_row.
+  Then constructs and applies a new Givens rotation that eliminates
+  H_row's k'th element.
+  """
+  # This call successively applies each of the
+  # Givens rotations stored in givens[:k, :] to H_row.
+  def apply_ith_rotation(i, H_row):
+    cs, sn = givens[i, :]
+    H_i = cs * H_row[i] - sn * H_row[i + 1]
+    H_ip1 = sn * H_row[i] + cs * H_row[i + 1]
+    H_row = H_row.at[i].set(H_i)
+    H_row = H_row.at[i + 1].set(H_ip1)
+    return H_row
+
+  R_row = lax.fori_loop(0, k, apply_ith_rotation, H_row)
+
+  def givens_rotation(v1, v2):
+    t = jnp.sqrt(v1**2 + v2**2)
+    cs = v1 / t
+    sn = -v2 / t
+    return cs, sn
+  givens_factors = givens_rotation(R_row[k], R_row[k + 1])
+  givens = givens.at[k, :].set(givens_factors)
+  cs_k, sn_k = givens_factors
+
+  R_row = R_row.at[k].set(cs_k * R_row[k] - sn_k * R_row[k + 1])
+  R_row = R_row.at[k + 1].set(0.)
+  return R_row, givens
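Each Givens rotation is a 2x2 orthogonal transform chosen so the second
component vanishes; a quick numerical check (illustrative only):

    import numpy as np

    v1, v2 = 3.0, 4.0
    t = np.hypot(v1, v2)
    cs, sn = v1 / t, -v2 / t
    # Applying [[cs, -sn], [sn, cs]] to (v1, v2) yields (t, 0):
    print(cs * v1 - sn * v2)  # 5.0, the norm of (v1, v2)
    print(sn * v1 + cs * v2)  # 0.0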
-def _gmres(A, b, x0, n, M, residual=None):
+def _gmres_qr(A, b, x0, unit_residual, residual_norm, inner_tol, restart, M):
+  """
+  Implements a single restart of GMRES. The restart-dimensional Krylov subspace
+  K(A, x0) = span(A(x0), A@x0, A@A@x0, ..., A^restart @ x0) is built, and the
+  projection of the true solution into this subspace is returned.
+
+  This implementation builds the QR factorization during the Arnoldi process.
+  """
   # https://www-users.cs.umn.edu/~saad/Calais/PREC.pdf
-  Q, H = arnoldi_iteration(A, b, n, M)
-  if residual is None:
-    residual = _sub(b, A(x0))
+  # residual = _sub(b, A(x0))
+  # unit_residual, beta = _safe_normalize(residual, return_norm=True)

-  beta = jnp.sqrt(_vdot_tree(residual, residual, assume_real=False))
-  dtype = beta.dtype
-  e1 = jnp.concatenate([jnp.ones((1,), dtype), jnp.zeros((n,), dtype)])
-  y = _lstsq(H.T, beta * e1)
+  V = tree_map(
+      lambda x: jnp.pad(x[..., None], ((0, 0),) * x.ndim + ((0, restart),)),
+      unit_residual,
+  )
+  dtype = jnp.result_type(*tree_leaves(b))
+  R = jnp.eye(restart, restart + 1, dtype=dtype)  # eye to avoid constructing
+                                                  # a singular matrix in case
+                                                  # of early termination.
+  b_norm = _norm_tree(b)
+
+  givens = jnp.zeros((restart, 2), dtype=dtype)
+  beta_vec = jnp.zeros((restart + 1), dtype=dtype)
+  beta_vec = beta_vec.at[0].set(residual_norm)
+
+  def loop_cond(carry):
+    k, err, _, _, _, _ = carry
+    return jnp.logical_and(k < restart, err > inner_tol)
+
+  def arnoldi_qr_step(carry):
+    k, _, V, R, beta_vec, givens = carry
+    V, H, _ = kth_arnoldi_iteration(k, A, M, V, R, inner_tol)
+    R_row, givens = apply_givens_rotations(H[k, :], givens, k)
+    R = R.at[k, :].set(R_row[:])
+    cs, sn = givens[k, :] * beta_vec[k]
+    beta_vec = beta_vec.at[k].set(cs)
+    beta_vec = beta_vec.at[k + 1].set(sn)
+    err = jnp.abs(sn) / b_norm
+    return k + 1, err, V, R, beta_vec, givens
+
+  carry = (0, residual_norm, V, R, beta_vec, givens)
+  carry = lax.while_loop(loop_cond, arnoldi_qr_step, carry)
+  k, residual_norm, V, R, beta_vec, _ = carry
+  del k  # Until we figure out how to pass this to the user.
+
+  y = jsp.linalg.solve_triangular(R[:, :-1].T, beta_vec[:-1])
+  Vy = tree_map(lambda X: _dot(X[..., :-1], y), V)
+  dx = M(Vy)

-  dx = M(tree_map(lambda X: jnp.dot(X[..., :-1], y), Q))
   x = _add(x0, dx)
-  return x
+
+  residual = _sub(b, A(x))
+  unit_residual, residual_norm = _safe_normalize(residual, return_norm=True)
+  return x, unit_residual, residual_norm
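Both restart implementations compute the correction by minimizing
||beta*e1 - H y|| over y and setting x = x0 + V y. A dense NumPy analogue,
written in the textbook orientation of H (a sketch; not part of the diff):

    import numpy as np

    def gmres_correction(V, H, beta):
        # V: (n, k+1) Arnoldi basis; H: (k+1, k) Hessenberg; beta = ||r0||.
        k = H.shape[1]
        e1 = np.zeros(k + 1)
        e1[0] = beta
        y, *_ = np.linalg.lstsq(H, e1, rcond=None)  # min ||beta*e1 - H y||
        return V[:, :k] @ y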
-def _gmres_solve(A, b, x0, *, tol, atol, restart, maxiter, M):
-  bs = _vdot_tree(b, b, assume_real=False)
-  atol2 = jnp.maximum(jnp.square(tol) * bs, jnp.square(atol))
-  num_restarts = maxiter // restart
+def _gmres_plain(A, b, x0, unit_residual, residual_norm, inner_tol, restart, M):
+  """
+  Implements a single restart of GMRES. The ``restart``-dimensional Krylov
+  subspace
+  K(A, x0) = span(A(x0), A@x0, A@A@x0, ..., A^restart @ x0) is built, and the
+  projection of the true solution into this subspace is returned.
+
+  This implementation solves a dense linear problem instead of building
+  a QR factorization during the Arnoldi process.
+  """
+  # https://www-users.cs.umn.edu/~saad/Calais/PREC.pdf
+  V = tree_map(
+      lambda x: jnp.pad(x[..., None], ((0, 0),) * x.ndim + ((0, restart),)),
+      unit_residual,
+  )
+  dtype = jnp.result_type(*tree_leaves(b))
+  H = jnp.eye(restart, restart + 1, dtype=dtype)
+
+  def loop_cond(carry):
+    _, _, breakdown, k = carry
+    return jnp.logical_and(k < restart, jnp.logical_not(breakdown))
+    # return lax.cond(k < restart,
+    #                 lambda x: ~x,
+    #                 lambda x: False,
+    #                 breakdown)
+
+  def arnoldi_process(carry):
+    V, H, _, k = carry
+    V, H, breakdown = kth_arnoldi_iteration(k, A, M, V, H, inner_tol)
+    return V, H, breakdown, k + 1
+
+  carry = (V, H, False, 0)
+  V, H, _, _ = lax.while_loop(loop_cond, arnoldi_process, carry)
+
+  # The following is equivalent to:
+  beta_vec = jnp.zeros((restart,), dtype=dtype)
+  beta_vec = beta_vec.at[0].set(residual_norm)  # it really is the original value
+  y = jsp.linalg.solve(H[:, :-1].T, beta_vec)
+  Vy = tree_map(lambda X: _dot(X[..., :-1], y), V)
+
+  dx = M(Vy)
+  x = _add(x0, dx)
+
+  residual = _sub(b, A(x))
+  unit_residual, residual_norm = _safe_normalize(residual, return_norm=True)
+  return x, unit_residual, residual_norm
+def _gmres_solve(A, b, x0, outer_tol, inner_tol, restart, maxiter, M,
+                 gmres_func):
+  """
+  The main function call wrapped by custom_linear_solve. Repeatedly calls GMRES
+  to find the projected solution within the order-``restart``
+  Krylov space K(A, x0, restart), using the result of the previous projection
+  in place of x0 each time. Parameters are the same as in ``gmres`` except:
+
+  outer_tol: Tolerance to be used between restarts.
+  inner_tol: Tolerance used within a restart.
+  gmres_func: A function performing a single GMRES restart.
+
+  Returns: The solution.
+  """
+  residual = _sub(b, A(x0))
+  unit_residual, residual_norm = _safe_normalize(residual, return_norm=True)

   def cond_fun(value):
-    x, residual, k = value
-    sqr_error = _vdot_tree(residual, residual, assume_real=False)
-    return (sqr_error > atol2) & (k < num_restarts) & ~jnp.isnan(sqr_error)
+    _, k, _, residual_norm = value
+    return jnp.logical_and(k < maxiter, residual_norm > outer_tol)

   def body_fun(value):
-    x, residual, k = value
-    x = _gmres(A, b, x, restart, M, residual)
-    residual = _sub(b, A(x))
-    return x, residual, k + 1
+    x, k, unit_residual, residual_norm = value
+    x, unit_residual, residual_norm = gmres_func(A, b, x, unit_residual,
+                                                 residual_norm, inner_tol,
+                                                 restart, M)
+    return x, k + 1, unit_residual, residual_norm

-  residual = _sub(b, A(x0))
-  if num_restarts:
-    x, residual, _ = lax.while_loop(
-        cond_fun, body_fun, (x0, residual, 0))
-  else:
-    x = x0
-
-  iters = maxiter % restart
-  sqr_error = _vdot_tree(residual, residual)
-  if iters > 0:
-    x_final = lax.cond(
-        sqr_error > atol2,
-        true_fun=lambda values: _gmres(A, b, values[0], iters, M, values[1]),
-        false_fun=lambda values: values[0],
-        operand=(x, residual),
-    )
-  else:
-    x_final = x
-  return x_final
+  initialization = (x0, 0, unit_residual, residual_norm)
+  x_final, k, _, err = lax.while_loop(cond_fun, body_fun, initialization)
+  _ = k  # Until we can pass this out
+  _ = err
+  # info = lax.cond(converged, lambda y: 0, lambda y: k, 0)
+  return x_final  # , info
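The new control flow amounts to ordinary restarted GMRES; schematically, in
plain Python (names illustrative, not part of the diff):

    def restarted_gmres(single_restart, x0, residual_norm0, maxiter, outer_tol):
        # single_restart(x) returns (x_new, unit_residual, residual_norm),
        # mirroring _gmres_qr / _gmres_plain above.
        x, k, norm = x0, 0, residual_norm0
        while k < maxiter and norm > outer_tol:
            x, _, norm = single_restart(x)
            k += 1
        return x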
 def gmres(A, b, x0=None, *, tol=1e-5, atol=0.0, restart=20, maxiter=None,
-          M=None):
+          M=None, qr_mode=False):
   """
   GMRES solves the linear system A x = b for x, given A and b.

   A is specified as a function performing A(vi) -> vf = A @ vi, and in principle
   need not have any particular special properties, such as symmetry. However,
   convergence is often slow for nearly symmetric operators.

   Parameters
   ----------
   A: function
      Function that calculates the linear map (matrix-vector product)
      ``Ax`` when called like ``A(x)``. ``A`` must return array(s) with the same
      structure and shape as its argument.
   b : array or tree of arrays
       Right hand side of the linear system representing a single vector. Can be
       stored as an array or Python container of array(s) with any shape.

   Returns
   -------
   x : array or tree of arrays
       The converged solution. Has the same structure as ``b``.
   info : None
       Placeholder for convergence information. In the future, JAX will report
       the number of iterations when convergence is not achieved, like SciPy.

   Other Parameters
   ----------------
   x0 : array, optional
       Starting guess for the solution. Must have the same structure as ``b``.
       If this is unspecified, zeroes are used.
   tol, atol : float, optional
       Tolerances for convergence, ``norm(residual) <= max(tol*norm(b), atol)``.
       We do not implement SciPy's "legacy" behavior, so JAX's tolerance will
       differ from SciPy unless you explicitly pass ``atol`` to SciPy's ``gmres``.
   restart : integer, optional
       Size of the Krylov subspace ("number of iterations") built between
       restarts. GMRES works by approximating the true solution x as its
       projection into a Krylov space of this dimension - this parameter
       therefore bounds the maximum accuracy achievable from any guess
       solution. Larger values increase both number of iterations and iteration
       cost, but may be necessary for convergence. The algorithm terminates
       early if convergence is achieved before the full subspace is built.
       Default is 20.
   maxiter : integer
       Maximum number of times to rebuild the size-``restart`` Krylov space
       starting from the solution found at the last iteration. If GMRES
       halts or is very slow, decreasing this parameter may help.
       Default is infinite.
   M : function
       Preconditioner for A. The preconditioner should approximate the
       inverse of A. Effective preconditioning dramatically improves the
       rate of convergence, which implies that fewer iterations are needed
       to reach a given error tolerance.
+  qr_mode : bool
+      If True, the algorithm builds an internal Krylov subspace using a QR
+      based algorithm, which reduces overhead and improves stability. However,
+      it may degrade performance significantly on GPUs or TPUs, in which case
+      this flag should be set False.

   See also
   --------
   scipy.sparse.linalg.gmres
   jax.lax.custom_linear_solve
   """

   if x0 is None:
     x0 = tree_map(jnp.zeros_like, b)
   if M is None:
     M = _identity

-  b, x0 = device_put((b, x0))
   size = sum(bi.size for bi in tree_leaves(b))

   if maxiter is None:
     maxiter = 10 * size  # copied from scipy
-  if restart > size:
-    restart = size
+  restart = min(restart, size)

   if tree_structure(x0) != tree_structure(b):
     raise ValueError(
-        'x0 and b must have matching tree structure: '
-        f'{tree_structure(x0)} vs {tree_structure(b)}')
+      'x0 and b must have matching tree structure: '
+      f'{tree_structure(x0)} vs {tree_structure(b)}')

+  b, x0 = device_put((b, x0))
+  b_norm = _norm_tree(b)
+  if b_norm == 0:
+    return b, 0
+  outer_tol = jnp.maximum(tol * b_norm, atol)

-  def _solve(A, b):
-    return _gmres_solve(A, b, x0, tol=tol, atol=atol, maxiter=maxiter,
-                        restart=restart, M=M)
+  Mb = M(b)
+  Mb_norm = _norm_tree(Mb)
+  inner_tol = Mb_norm * min(1.0, outer_tol / b_norm)
+
+  if qr_mode:
+    def _solve(A, b):
+      return _gmres_solve(A, b, x0, outer_tol, inner_tol, restart, maxiter, M,
+                          _gmres_plain)
+  else:
+    def _solve(A, b):
+      return _gmres_solve(A, b, x0, outer_tol, inner_tol, restart, maxiter, M,
+                          _gmres_qr)

   x = lax.custom_linear_solve(A, b, solve=_solve, transpose_solve=_solve)
-  info = None
+
+  failed = jnp.isnan(_norm_tree(x))
+  info = jnp.where(failed, x=-1, y=0)
   return x, info
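Usage at this commit looks like the following (a sketch; the fixed system is
the one exercised by the new tests below):

    import jax.numpy as jnp
    from jax.scipy.sparse.linalg import gmres

    A = jnp.array([[1., 1.], [3., -4.]])
    b = jnp.array([3., 2.])
    x, info = gmres(lambda v: A @ v, b, tol=1e-8, atol=1e-8)
    print(x)  # approximately [2., 1.]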
@@ -28,6 +28,7 @@ import jax.scipy.sparse.linalg

 from jax.config import config
 config.parse_flags_with_absl()
+config.update('jax_enable_x64', True)

 float_types = [np.float32, np.float64]
 complex_types = [np.complex64, np.complex128]
@@ -60,7 +61,34 @@ def rand_sym_pos_def(rng, shape, dtype):
   return matrix @ matrix.T.conj()


 class LaxBackedScipyTests(jtu.JaxTestCase):
+  def _fetch_preconditioner(self, preconditioner, A, rng=None,
+                            return_function=False):
+    """
+    Returns one of various preconditioning matrices depending on the identifier
+    `preconditioner` and the input matrix A whose inverse it supposedly
+    approximates.
+    """
+    if preconditioner == 'identity':
+      M = np.eye(A.shape[0], dtype=A.dtype)
+    elif preconditioner == 'random':
+      if rng is None:
+        rng = jtu.rand_default(self.rng())
+      M = np.linalg.inv(rand_sym_pos_def(rng, A.shape, A.dtype))
+    elif preconditioner == 'exact':
+      M = np.linalg.inv(A)
+    else:
+      M = None
+
+    if M is None or not return_function:
+      return M
+    else:
+      return lambda x: jnp.dot(M, x, precision=lax.Precision.HIGHEST)

   @parameterized.named_parameters(jtu.cases_from_list(
       {"testcase_name":
        "_shape={}_preconditioner={}".format(
@@ -76,15 +104,7 @@ class LaxBackedScipyTests(jtu.JaxTestCase):
     rng = jtu.rand_default(self.rng())
     A = rand_sym_pos_def(rng, shape, dtype)
     b = rng(shape[:1], dtype)

-    if preconditioner == 'identity':
-      M = np.eye(shape[0], dtype=dtype)
-    elif preconditioner == 'random':
-      M = np.linalg.inv(rand_sym_pos_def(rng, shape, dtype))
-    elif preconditioner == 'exact':
-      M = np.linalg.inv(A)
-    else:
-      M = None
+    M = self._fetch_preconditioner(preconditioner, A, rng=rng)

     def args_maker():
       return A, b
@@ -178,6 +198,153 @@ class LaxBackedScipyTests(jtu.JaxTestCase):
     actual, _ = jax.scipy.sparse.linalg.cg(A, b)
     self.assertAllClose(expected, actual.value)

+  # GMRES
+  @parameterized.named_parameters(jtu.cases_from_list(
+      {"testcase_name":
+       "_shape={}".format(jtu.format_shape_dtype_string(shape, dtype)),
+       "shape": shape, "dtype": dtype}
+      for shape in [(2, 2)]
+      for dtype in float_types + complex_types))
+  def test_gmres_on_small_fixed_problem(self, shape, dtype):
+    """
+    GMRES gives the right answer for a small fixed system.
+    """
+    A = jnp.array(([[1, 1], [3, -4]]), dtype=dtype)
+    b = jnp.array([3, 2], dtype=dtype)
+    x0 = jnp.ones(2, dtype=dtype)
+    restart = 2
+    maxiter = 1
+
+    @jax.tree_util.Partial
+    def A_mv(x):
+      return matmul_high_precision(A, x)
+    tol = A.size * jnp.finfo(dtype).eps
+    x, _ = jax.scipy.sparse.linalg.gmres(A_mv, b, x0=x0, tol=tol, atol=tol,
+                                         restart=restart, maxiter=maxiter)
+    solution = jnp.array([2., 1.], dtype=dtype)
+    self.assertAllClose(solution, x)
+
+  @parameterized.named_parameters(jtu.cases_from_list(
+      {"testcase_name":
+       "_shape={}_preconditioner={}_qr_mode={}".format(
+           jtu.format_shape_dtype_string(shape, dtype),
+           preconditioner,
+           qr_mode),
+       "shape": shape, "dtype": dtype, "preconditioner": preconditioner,
+       "qr_mode": qr_mode}
+      for shape in [(2, 2), (7, 7)]
+      for dtype in float_types + complex_types
+      for preconditioner in [None, 'identity', 'exact']
+      for qr_mode in [True, False]
+  ))
+  def test_gmres_on_identity_system(self, shape, dtype, preconditioner,
+                                    qr_mode):
+    A = jnp.eye(shape[1], dtype=dtype)
+    solution = jnp.ones(shape[1], dtype=dtype)
+
+    @jax.tree_util.Partial
+    def A_mv(x):
+      return matmul_high_precision(A, x)
+    rng = jtu.rand_default(self.rng())
+    M = self._fetch_preconditioner(preconditioner, A, rng=rng,
+                                   return_function=True)
+    b = A_mv(solution)
+    restart = shape[-1]
+    tol = shape[0] * jnp.finfo(dtype).eps
+    x, info = jax.scipy.sparse.linalg.gmres(A_mv, b, tol=tol, atol=tol,
+                                            restart=restart,
+                                            M=M,
+                                            qr_mode=qr_mode)
+    err = jnp.linalg.norm(solution - x) / jnp.linalg.norm(b)
+    rtol = tol * jnp.linalg.norm(b)
+    true_tol = max(rtol, tol)
+    self.assertLessEqual(err, true_tol)
+
+  @parameterized.named_parameters(jtu.cases_from_list(
+      {"testcase_name":
+       "_shape={}_preconditioner={}_qr_mode={}".format(
+           jtu.format_shape_dtype_string(shape, dtype),
+           preconditioner,
+           qr_mode),
+       "shape": shape, "dtype": dtype, "preconditioner": preconditioner,
+       "qr_mode": qr_mode}
+      for shape in [(2, 2), (7, 7)]
+      for dtype in float_types + complex_types
+      for preconditioner in [None, 'identity', 'exact']
+      for qr_mode in [True, False]
+  ))
+  def test_gmres_on_random_system(self, shape, dtype, preconditioner,
+                                  qr_mode):
+    rng = jtu.rand_default(self.rng())
+    A = rng(shape, dtype)
+    solution = rng(shape[1:], dtype)
+
+    @jax.tree_util.Partial
+    def A_mv(x):
+      return matmul_high_precision(A, x)
+    M = self._fetch_preconditioner(preconditioner, A, rng=rng,
+                                   return_function=True)
+    b = A_mv(solution)
+    restart = shape[-1]
+    tol = shape[0] * jnp.finfo(dtype).eps
+    x, info = jax.scipy.sparse.linalg.gmres(A_mv, b, tol=tol, atol=tol,
+                                            restart=restart,
+                                            M=M,
+                                            qr_mode=qr_mode)
+    err = jnp.linalg.norm(solution - x) / jnp.linalg.norm(b)
+    rtol = tol * jnp.linalg.norm(b)
+    true_tol = max(rtol, tol)
+    self.assertLessEqual(err, true_tol)
+
+  def test_gmres_pytree(self):
+    A = lambda x: {"a": x["a"] + 0.5 * x["b"], "b": 0.5 * x["a"] + x["b"]}
+    b = {"a": 1.0, "b": -4.0}
+    expected = {"a": 4.0, "b": -6.0}
+    actual, _ = jax.scipy.sparse.linalg.gmres(A, b)
+    self.assertEqual(expected.keys(), actual.keys())
+    self.assertAlmostEqual(expected["a"], actual["a"], places=6)
+    self.assertAlmostEqual(expected["b"], actual["b"], places=6)
+
+  @parameterized.named_parameters(jtu.cases_from_list(
+      {"testcase_name":
+       "_shape={}_preconditioner={}".format(
+           jtu.format_shape_dtype_string(shape, dtype),
+           preconditioner),
+       "shape": shape, "dtype": dtype, "preconditioner": preconditioner}
+      for shape in [(2, 2), (7, 7), (32, 32)]
+      for dtype in float_types + complex_types
+      for preconditioner in [None, 'identity']))
+  def test_gmres_arnoldi_step(self, shape, dtype, preconditioner):
+    """
+    The Arnoldi decomposition within GMRES is correct.
+    """
+    rng = jtu.rand_default(self.rng())
+    A = rng(shape, dtype)
+    if preconditioner is None:
+      M = lambda x: x
+    else:
+      M = self._fetch_preconditioner(preconditioner, A, rng=rng,
+                                     return_function=True)
+
+    n = shape[0]
+    x0 = rng(shape[:1], dtype)
+    Q = np.zeros((n, n + 1), dtype=dtype)
+    Q[:, 0] = x0 / jax.numpy.linalg.norm(x0)
+    Q = jnp.array(Q)
+    H = jax.numpy.eye(n, n + 1, dtype=dtype)
+    tol = A.size * A.size * jax.numpy.finfo(dtype).eps
+
+    @jax.tree_util.Partial
+    def A_mv(x):
+      return matmul_high_precision(A, x)
+    for k in range(n):
+      Q, H, _ = jax.scipy.sparse.linalg.kth_arnoldi_iteration(k, A_mv, M, Q, H,
+                                                              tol)
+    QAQ = matmul_high_precision(Q[:, :n].conj().T, A)
+    QAQ = matmul_high_precision(QAQ, Q[:, :n])
+    self.assertAllClose(QAQ, H.T[:n, :], rtol=tol, atol=tol)


 if __name__ == "__main__":
   absltest.main(testLoader=jtu.JaxTestLoader())