Corrections to GMRES - now gives correct result.

Co-authored-by: gehring <clement.gehring@gmail.com>

Co-authored-by: Stephan Hoyer <shoyer@google.com>
This commit is contained in:
Adam GM Lewis 2020-08-28 14:40:31 -04:00 committed by Stephan Hoyer
parent 342cc36051
commit 7ed9fe70ea
2 changed files with 561 additions and 122 deletions

View File

@ -18,42 +18,59 @@ import operator
import numpy as np
import jax.numpy as jnp
from jax import scipy as jsp
from jax import lax, device_put, jit
from jax.tree_util import tree_leaves, tree_map, tree_multimap, tree_structure, tree_reduce
from jax import lax, device_put
from jax.tree_util import (tree_leaves, tree_map, tree_multimap, tree_structure,
tree_reduce, Partial)
from jax.util import safe_map as map
_dot = partial(jnp.dot, precision=lax.Precision.HIGHEST)
_vdot = partial(jnp.vdot, precision=lax.Precision.HIGHEST)
# aliases for working with pytrees
def _vdot_real_part(x, y):
"""Vector dot-product guaranteed to have a real valued result."""
"""Vector dot-product guaranteed to have a real valued result despite
possibly complex input. Thus neglects the real-imaginary cross-terms.
The result is a real float.
"""
# all our uses of vdot() in CG are for computing an operator of the form
# `z^T M z` where `M` is positive definite and Hermitian, so the result is
# z^H M z
# where M is positive definite and Hermitian, so the result is
# real valued:
# https://en.wikipedia.org/wiki/Definiteness_of_a_matrix#Definitions_for_complex_matrices
result = _vdot(x.real, y.real)
vdot = partial(jnp.vdot, precision=lax.Precision.HIGHEST)
result = vdot(x.real, y.real)
if jnp.iscomplexobj(x) or jnp.iscomplexobj(y):
result += _vdot(x.imag, y.imag)
result += vdot(x.imag, y.imag)
return result
# aliases for working with pytrees
def _vdot_real_tree(x, y):
return sum(tree_leaves(tree_multimap(_vdot_real_part, x, y)))
def _norm_tree(x):
return jnp.sqrt(_vdot_real_tree(x, x))
def _vdot_tree(x, y):
return sum(tree_leaves(tree_multimap(_vdot, x, y)))
def _vdot_tree(x, y, assume_real=True):
if assume_real:
return sum(tree_leaves(tree_multimap(_vdot_real_part, x, y)))
else:
return sum(tree_leaves(tree_multimap(_vdot, x, y)))
def _mul(scalar, tree):
return tree_map(partial(operator.mul, scalar), tree)
def _div(tree, scalar):
return tree_map(partial(lambda v: v / scalar), tree)
_add = partial(tree_multimap, operator.add)
_sub = partial(tree_multimap, operator.sub)
_dot_tree = partial(tree_multimap, _dot)
@Partial
def _identity(x):
return x
@ -61,31 +78,31 @@ def _identity(x):
def _cg_solve(A, b, x0=None, *, maxiter, tol=1e-5, atol=0.0, M=_identity):
# tolerance handling uses the "non-legacy" behavior of scipy.sparse.linalg.cg
bs = _vdot_tree(b, b)
bs = _vdot_real_tree(b, b)
atol2 = jnp.maximum(jnp.square(tol) * bs, jnp.square(atol))
# https://en.wikipedia.org/wiki/Conjugate_gradient_method#The_preconditioned_conjugate_gradient_method
def cond_fun(value):
x, r, gamma, p, k = value
rs = gamma if M is _identity else _vdot_tree(r, r)
_, r, gamma, _, k = value
rs = gamma if M is _identity else _vdot_real_tree(r, r)
return (rs > atol2) & (k < maxiter)
def body_fun(value):
x, r, gamma, p, k = value
Ap = A(p)
alpha = gamma / _vdot_tree(p, Ap)
alpha = gamma / _vdot_real_tree(p, Ap)
x_ = _add(x, _mul(alpha, p))
r_ = _sub(r, _mul(alpha, Ap))
z_ = M(r_)
gamma_ = _vdot_tree(r_, z_)
gamma_ = _vdot_real_tree(r_, z_)
beta_ = gamma_ / gamma
p_ = _add(z_, _mul(beta_, p))
return x_, r_, gamma_, p_, k + 1
r0 = _sub(b, A(x0))
p0 = z0 = M(r0)
gamma0 = _vdot_tree(r0, z0)
gamma0 = _vdot_real_tree(r0, z0)
initial_value = (x0, r0, gamma0, p0, 0)
x_final, *_ = lax.while_loop(cond_fun, body_fun, initial_value)
@ -174,8 +191,10 @@ def cg(A, b, x0=None, *, tol=1e-5, atol=0.0, maxiter=None, M=None):
cg_solve = partial(
_cg_solve, x0=x0, tol=tol, atol=atol, maxiter=maxiter, M=M)
# real-valued positive-definite linear operators are symmetric
real_valued = lambda x: not issubclass(x.dtype.type, np.complexfloating)
def real_valued(x):
return not issubclass(x.dtype.type, np.complexfloating)
symmetric = all(map(real_valued, tree_leaves(b)))
x = lax.custom_linear_solve(
A, b, solve=cg_solve, transpose_solve=cg_solve, symmetric=symmetric)
@ -183,23 +202,25 @@ def cg(A, b, x0=None, *, tol=1e-5, atol=0.0, maxiter=None, M=None):
return x, info
def _project_on_columns(A, v):
v_proj = tree_multimap(
lambda X, y: jnp.einsum("...n,...->n", X.conj(), y),
A,
v,
)
return tree_reduce(operator.add, v_proj)
def _safe_normalize(x, return_norm=False):
norm = jnp.sqrt(_vdot_tree(x, x, assume_real=False))
def _safe_normalize(x, return_norm=False, thresh=None):
"""
Returns the L2-normalized vector (which can be a pytree) x, and optionally
the computed norm. If the computed norm is less than the threshold `thresh`,
which by default is the machine precision of x's dtype, it will be
taken to be 0, and the normalized x to be the zero vector.
"""
norm = _norm_tree(x)
dtype = jnp.result_type(*tree_leaves(x))
if thresh is None:
thresh = jnp.finfo(norm.dtype).eps
thresh = thresh.astype(dtype).real
normalized_x, norm = lax.cond(
norm > 1e-12,
lambda y: (tree_map(lambda v: v / norm, y), norm),
lambda y: (y, 0.),
x,
norm > thresh,
lambda y: (_div(y, norm), norm),
lambda y: (tree_map(jnp.zeros_like, y), jnp.zeros((),
dtype=thresh.dtype)),
x,
)
if return_norm:
return normalized_x, norm
@ -207,8 +228,41 @@ def _safe_normalize(x, return_norm=False):
return normalized_x
def _iterative_classical_gram_schmidt(Q, x, iterations=2):
"""Orthogonalize x against the columns of Q."""
def _project_on_columns(A, v):
"""
Returns A.T.conj() @ v.
"""
v_proj = tree_multimap(
lambda X, y: jnp.einsum("...n,...->n", X.conj(), y),
A,
v,
)
return tree_reduce(operator.add, v_proj)
def _iterative_classical_gram_schmidt(Q, x, max_iterations=2):
"""
Orthogonalize x against the columns of Q. The process is repeated
up to `max_iterations` times, or fewer if the condition
||r|| < (1/sqrt(2)) ||x|| is met earlier (see below for the meaning
of r and x).
Parameters
----------
Q : array or tree of arrays
A matrix of orthonormal columns.
x : array or tree of arrays
A vector. It will be replaced with a new vector q which is orthonormal
to the columns of Q, such that x in span(col(Q), q).
Returns
-------
q : array or tree of arrays
A unit vector, orthonormal to each column of Q, such that
x in span(col(Q), q).
r : array
Stores the overlaps of x with each vector in Q.
"""
# "twice is enough"
# http://slepc.upv.es/documentation/reports/str1.pdf
@ -216,122 +270,340 @@ def _iterative_classical_gram_schmidt(Q, x, iterations=2):
# axis.
r = jnp.zeros((tree_leaves(Q)[0].shape[-1]))
q = x
_, xnorm = _safe_normalize(x, return_norm=True)
xnorm_scaled = xnorm / jnp.sqrt(2)
for _ in range(iterations):
def body_function(carry):
k, q, r, qnorm_scaled = carry
h = _project_on_columns(Q, q)
q = _sub(q, tree_map(lambda X: jnp.dot(X, h), Q))
Qh = tree_map(lambda X: _dot_tree(X, h), Q)
q = _sub(q, Qh)
r = _add(r, h)
def qnorm_cond(carry):
k, not_done, _, _ = carry
return jnp.logical_and(not_done, k < (max_iterations - 1))
def qnorm(carry):
k, _, q, qnorm_scaled = carry
_, qnorm = _safe_normalize(q, return_norm=True)
qnorm_scaled = qnorm / jnp.sqrt(2)
return (k, False, q, qnorm_scaled)
init = (k, True, q, qnorm_scaled)
_, _, q, qnorm_scaled = lax.while_loop(qnorm_cond, qnorm, init)
return (k + 1, q, r, qnorm_scaled)
def cond_function(carry):
k, _, r, qnorm_scaled = carry
_, rnorm = _safe_normalize(r, return_norm=True)
return jnp.logical_and(k < (max_iterations - 1), rnorm < qnorm_scaled)
k, q, r, qnorm_scaled = body_function((0, q, r, xnorm_scaled))
k, q, r, _ = lax.while_loop(cond_function, body_function,
(k, q, r, qnorm_scaled))
return q, r
def arnoldi_iteration(A, b, n, M=None):
# https://en.wikipedia.org/wiki/Arnoldi_iteration#The_Arnoldi_iteration
if M is None:
M = _identity
q = _safe_normalize(b)
Q = tree_map(
lambda x: jnp.pad(x[..., None], ((0, 0),) * x.ndim + ((0, n),)),
q,
)
H = jnp.zeros((n, n + 1), jnp.result_type(*tree_leaves(b)))
def kth_arnoldi_iteration(k, A, M, V, H, tol):
"""
Performs a single (the k'th) step of the Arnoldi process. Thus,
adds a new orthonormalized Krylov vector A(M(V[:, k])) to V[:, k+1],
and that vectors overlaps with the existing Krylov vectors to
H[k, :]. The tolerance 'tol' sets the threshold at which an invariant
subspace is declared to have been found, in which case in which case the new
vector is taken to be the zero vector.
"""
def step(carry, k):
Q, H = carry
q = tree_map(lambda x: x[..., k], Q)
v = A(M(q))
v, h = _iterative_classical_gram_schmidt(Q, v, iterations=1)
v, v_norm = _safe_normalize(v, return_norm=True)
Q = tree_multimap(lambda X, y: X.at[..., k + 1].set(y), Q, v)
h = h.at[k + 1].set(v_norm)
H = H.at[k, :].set(h)
return (Q, H), None
v = tree_map(lambda x: x[..., k], V) # Gets V[:, k]
v = A(M(v))
v, h = _iterative_classical_gram_schmidt(V, v, max_iterations=2)
unit_v, v_norm = _safe_normalize(v, return_norm=True, thresh=tol)
V = tree_multimap(lambda X, y: X.at[..., k + 1].set(y), V, unit_v)
(Q, H), _ = lax.scan(step, (Q, H), jnp.arange(n))
return Q, H
h = h.at[k + 1].set(v_norm)
H = H.at[k, :].set(h)
breakdown = v_norm == 0.
return V, H, breakdown
@jit
def _lstsq(a, b):
return jnp.linalg.lstsq(a, b)[0]
def apply_givens_rotations(H_row, givens, k):
"""
Applies the Givens rotations stored in the vectors cs and sn to the vector
H_row. Then constructs and applies a new Givens rotation that eliminates
H_row's k'th element.
"""
# This call successively applies each of the
# Givens rotations stored in givens[:, :k] to H_col.
def apply_ith_rotation(i, H_row):
cs, sn = givens[i, :]
H_i = cs * H_row[i] - sn * H_row[i + 1]
H_ip1 = sn * H_row[i] + cs * H_row[i + 1]
H_row = H_row.at[i].set(H_i)
H_row = H_row.at[i + 1].set(H_ip1)
return H_row
R_row = lax.fori_loop(0, k, apply_ith_rotation, H_row)
def givens_rotation(v1, v2):
t = jnp.sqrt(v1**2 + v2**2)
cs = v1 / t
sn = -v2 / t
return cs, sn
givens_factors = givens_rotation(R_row[k], R_row[k + 1])
givens = givens.at[k, :].set(givens_factors)
cs_k, sn_k = givens_factors
R_row = R_row.at[k].set(cs_k * R_row[k] - sn_k * R_row[k + 1])
R_row = R_row.at[k + 1].set(0.)
return R_row, givens
def _gmres(A, b, x0, n, M, residual=None):
def _gmres_qr(A, b, x0, unit_residual, residual_norm, inner_tol, restart, M):
"""
Implements a single restart of GMRES. The restart-dimensional Krylov subspace
K(A, x0) = span(A(x0), A@x0, A@A@x0, ..., A^restart @ x0) is built, and the
projection of the true solution into this subspace is returned.
This implementation builds the QR factorization during the Arnoldi process.
"""
# https://www-users.cs.umn.edu/~saad/Calais/PREC.pdf
Q, H = arnoldi_iteration(A, b, n, M)
if residual is None:
residual = _sub(b, A(x0))
# residual = _sub(b, A(x0))
# unit_residual, beta = _safe_normalize(residual, return_norm=True)
beta = jnp.sqrt(_vdot_tree(residual, residual, assume_real=False))
dtype = beta.dtype
e1 = jnp.concatenate([jnp.ones((1,), dtype), jnp.zeros((n,), dtype)])
y = _lstsq(H.T, beta * e1)
V = tree_map(
lambda x: jnp.pad(x[..., None], ((0, 0),) * x.ndim + ((0, restart),)),
unit_residual,
)
dtype = jnp.result_type(*tree_leaves(b))
R = jnp.eye(restart, restart + 1, dtype=dtype) # eye to avoid constructing
# a singular matrix in case
# of early termination.
b_norm = _norm_tree(b)
givens = jnp.zeros((restart, 2), dtype=dtype)
beta_vec = jnp.zeros((restart + 1), dtype=dtype)
beta_vec = beta_vec.at[0].set(residual_norm)
def loop_cond(carry):
k, err, _, _, _, _ = carry
return jnp.logical_and(k < restart, err > inner_tol)
def arnoldi_qr_step(carry):
k, _, V, R, beta_vec, givens = carry
V, H, _ = kth_arnoldi_iteration(k, A, M, V, R, inner_tol)
R_row, givens = apply_givens_rotations(H[k, :], givens, k)
R = R.at[k, :].set(R_row[:])
cs, sn = givens[k, :] * beta_vec[k]
beta_vec = beta_vec.at[k].set(cs)
beta_vec = beta_vec.at[k + 1].set(sn)
err = jnp.abs(sn) / b_norm
return k + 1, err, V, R, beta_vec, givens
carry = (0, residual_norm, V, R, beta_vec, givens)
carry = lax.while_loop(loop_cond, arnoldi_qr_step, carry)
k, residual_norm, V, R, beta_vec, _ = carry
del k # Until we figure out how to pass this to the user.
y = jsp.linalg.solve_triangular(R[:, :-1].T, beta_vec[:-1])
Vy = tree_map(lambda X: _dot(X[..., :-1], y), V)
dx = M(Vy)
dx = M(tree_map(lambda X: jnp.dot(X[..., :-1], y), Q))
x = _add(x0, dx)
return x
residual = _sub(b, A(x))
unit_residual, residual_norm = _safe_normalize(residual, return_norm=True)
return x, unit_residual, residual_norm
def _gmres_solve(A, b, x0, *, tol, atol, restart, maxiter, M):
bs = _vdot_tree(b, b, assume_real=False)
atol2 = jnp.maximum(jnp.square(tol) * bs, jnp.square(atol))
num_restarts = maxiter // restart
def _gmres_plain(A, b, x0, unit_residual, residual_norm, inner_tol, restart, M):
"""
Implements a single restart of GMRES. The ``restart``-dimensional Krylov
subspace
K(A, x0) = span(A(x0), A@x0, A@A@x0, ..., A^restart @ x0) is built, and the
projection of the true solution into this subspace is returned.
This implementation solves a dense linear problem instead of building
a QR factorization during the Arnoldi process.
"""
# https://www-users.cs.umn.edu/~saad/Calais/PREC.pdf
V = tree_map(
lambda x: jnp.pad(x[..., None], ((0, 0),) * x.ndim + ((0, restart),)),
unit_residual,
)
dtype = jnp.result_type(*tree_leaves(b))
H = jnp.eye(restart, restart + 1, dtype=dtype)
def loop_cond(carry):
_, _, breakdown, k = carry
return jnp.logical_and(k < restart, jnp.logical_not(breakdown))
# return lax.cond(k < restart,
# lambda x: ~x,
# lambda x: False,
# breakdown)
def arnoldi_process(carry):
V, H, _, k = carry
V, H, breakdown = kth_arnoldi_iteration(k, A, M, V, H, inner_tol)
return V, H, breakdown, k + 1
carry = (V, H, False, 0)
V, H, _, _ = lax.while_loop(loop_cond, arnoldi_process, carry)
# The following is equivalent to:
beta_vec = jnp.zeros((restart,), dtype=dtype)
beta_vec = beta_vec.at[0].set(residual_norm) # it really is the original value
y = jsp.linalg.solve(H[:, :-1].T, beta_vec)
Vy = tree_map(lambda X: _dot(X[..., :-1], y), V)
dx = M(Vy)
x = _add(x0, dx)
residual = _sub(b, A(x))
unit_residual, residual_norm = _safe_normalize(residual, return_norm=True)
return x, unit_residual, residual_norm
def _gmres_solve(A, b, x0, outer_tol, inner_tol, restart, maxiter, M,
gmres_func):
"""
The main function call wrapped by custom_linear_solve. Repeatedly calls GMRES
to find the projected solution within the order-``restart``
Krylov space K(A, x0, restart), using the result of the previous projection
in place of x0 each time. Parameters are the same as in ``gmres`` except:
outer_tol: Tolerance to be used between restarts.
inner_tol: Tolerance used within a restart.
gmres_func: A function performing a single GMRES restart.
Returns: The solution.
"""
residual = _sub(b, A(x0))
unit_residual, residual_norm = _safe_normalize(residual, return_norm=True)
def cond_fun(value):
x, residual, k = value
sqr_error = _vdot_tree(residual, residual, assume_real=False)
return (sqr_error > atol2) & (k < num_restarts) & ~jnp.isnan(sqr_error)
_, k, _, residual_norm = value
return jnp.logical_and(k < maxiter, residual_norm > outer_tol)
def body_fun(value):
x, residual, k = value
x = _gmres(A, b, x, restart, M, residual)
residual = _sub(b, A(x))
return x, residual, k + 1
x, k, unit_residual, residual_norm = value
x, unit_residual, residual_norm = gmres_func(A, b, x, unit_residual,
residual_norm, inner_tol,
restart, M)
return x, k + 1, unit_residual, residual_norm
residual = _sub(b, A(x0))
if num_restarts:
x, residual, _ = lax.while_loop(
cond_fun, body_fun, (x0, residual, 0))
else:
x = x0
iters = maxiter % restart
sqr_error = _vdot_tree(residual, residual)
if iters > 0:
x_final = lax.cond(
sqr_error > atol2,
true_fun=lambda values: _gmres(A, b, values[0], iters, M, values[1]),
false_fun=lambda values: values[0],
operand=(x, residual),
)
else:
x_final = x
return x_final
initialization = (x0, 0, unit_residual, residual_norm)
x_final, k, _, err = lax.while_loop(cond_fun, body_fun, initialization)
_ = k # Until we can pass this out
_ = err
# info = lax.cond(converged, lambda y: 0, lambda y: k, 0)
return x_final # , info
def gmres(A, b, x0=None, *, tol=1e-5, atol=0.0, restart=20, maxiter=None,
M=None):
M=None, qr_mode=False):
"""
GMRES solves the linear system A x = b for x, given A and b.
A is specified as a function performing A(vi) -> vf = A @ vi, and in principle
need not have any particular special properties, such as symmetry. However,
convergence is often slow for nearly symmetric operators.
Parameters
----------
A: function
Function that calculates the linear map (matrix-vector product)
``Ax`` when called like ``A(x)``. ``A`` must return array(s) with the same
structure and shape as its argument.
b : array or tree of arrays
Right hand side of the linear system representing a single vector. Can be
stored as an array or Python container of array(s) with any shape.
Returns
-------
x : array or tree of arrays
The converged solution. Has the same structure as ``b``.
info : None
Placeholder for convergence information. In the future, JAX will report
the number of iterations when convergence is not achieved, like SciPy.
Other Parameters
----------------
x0 : array, optional
Starting guess for the solution. Must have the same structure as ``b``.
If this is unspecified, zeroes are used.
tol, atol : float, optional
Tolerances for convergence, ``norm(residual) <= max(tol*norm(b), atol)``.
We do not implement SciPy's "legacy" behavior, so JAX's tolerance will
differ from SciPy unless you explicitly pass ``atol`` to SciPy's ``gmres``.
restart : integer, optional
Size of the Krylov subspace ("number of iterations") built between
restarts. GMRES works by approximating the true solution x as its
projection into a Krylov space of this dimension - this parameter
therefore bounds the maximum accuracy achievable from any guess
solution. Larger values increase both number of iterations and iteration
cost, but may be necessary for convergence. The algorithm terminates
early if convergence is achieved before the full subspace is built.
Default is 20.
maxiter : integer
Maximum number of times to rebuild the size-``restart`` Krylov space
starting from the solution found at the last iteration. If GMRES
halts or is very slow, decreasing this parameter may help.
Default is infinite.
M : function
Preconditioner for A. The preconditioner should approximate the
inverse of A. Effective preconditioning dramatically improves the
rate of convergence, which implies that fewer iterations are needed
to reach a given error tolerance.
qr_mode : bool
If True, the algorithm builds an internal Krylov subspace using a QR
based algorithm, which reduces overhead and improved stability. However,
it may degrade performance significantly on GPUs or TPUs, in which case
this flag should be set False.
See also
--------
scipy.sparse.linalg.gmres
jax.lax.custom_linear_solve
"""
if x0 is None:
x0 = tree_map(jnp.zeros_like, b)
if M is None:
M = _identity
b, x0 = device_put((b, x0))
size = sum(bi.size for bi in tree_leaves(b))
if maxiter is None:
maxiter = 10 * size # copied from scipy
if restart > size:
restart = size
restart = min(restart, size)
if tree_structure(x0) != tree_structure(b):
raise ValueError(
'x0 and b must have matching tree structure: '
f'{tree_structure(x0)} vs {tree_structure(b)}')
'x0 and b must have matching tree structure: '
f'{tree_structure(x0)} vs {tree_structure(b)}')
b, x0 = device_put((b, x0))
b_norm = _norm_tree(b)
if b_norm == 0:
return b, 0
outer_tol = jnp.maximum(tol * b_norm, atol)
def _solve(A, b):
return _gmres_solve(A, b, x0, tol=tol, atol=atol, maxiter=maxiter,
restart=restart, M=M)
Mb = M(b)
Mb_norm = _norm_tree(Mb)
inner_tol = Mb_norm * min(1.0, outer_tol / b_norm)
if qr_mode:
def _solve(A, b):
return _gmres_solve(A, b, x0, outer_tol, inner_tol, restart, maxiter, M,
_gmres_plain)
else:
def _solve(A, b):
return _gmres_solve(A, b, x0, outer_tol, inner_tol, restart, maxiter, M,
_gmres_qr)
x = lax.custom_linear_solve(A, b, solve=_solve, transpose_solve=_solve)
info = None
failed = jnp.isnan(_norm_tree(x))
info = jnp.where(failed, x=-1, y=0)
return x, info

View File

@ -28,6 +28,7 @@ import jax.scipy.sparse.linalg
from jax.config import config
config.parse_flags_with_absl()
config.update('jax_enable_x64', True)
float_types = [np.float32, np.float64]
complex_types = [np.complex64, np.complex128]
@ -60,7 +61,34 @@ def rand_sym_pos_def(rng, shape, dtype):
return matrix @ matrix.T.conj()
class LaxBackedScipyTests(jtu.JaxTestCase):
def _fetch_preconditioner(self, preconditioner, A, rng=None,
return_function=False):
"""
Returns one of various preconditioning matrices depending on the identifier
`preconditioner' and the input matrix A whose inverse it supposedly
approximates.
"""
if preconditioner == 'identity':
M = np.eye(A.shape[0], dtype=A.dtype)
elif preconditioner == 'random':
if rng is None:
rng = jtu.rand_default(self.rng())
M = np.linalg.inv(rand_sym_pos_def(rng, A.shape, A.dtype))
elif preconditioner == 'exact':
M = np.linalg.inv(A)
else:
M = None
if M is None or not return_function:
return M
else:
return lambda x: jnp.dot(M, x, precision=lax.Precision.HIGHEST)
@parameterized.named_parameters(jtu.cases_from_list(
{"testcase_name":
"_shape={}_preconditioner={}".format(
@ -76,15 +104,7 @@ class LaxBackedScipyTests(jtu.JaxTestCase):
rng = jtu.rand_default(self.rng())
A = rand_sym_pos_def(rng, shape, dtype)
b = rng(shape[:1], dtype)
if preconditioner == 'identity':
M = np.eye(shape[0], dtype=dtype)
elif preconditioner == 'random':
M = np.linalg.inv(rand_sym_pos_def(rng, shape, dtype))
elif preconditioner == 'exact':
M = np.linalg.inv(A)
else:
M = None
M = self._fetch_preconditioner(preconditioner, A, rng=rng)
def args_maker():
return A, b
@ -178,6 +198,153 @@ class LaxBackedScipyTests(jtu.JaxTestCase):
actual, _ = jax.scipy.sparse.linalg.cg(A, b)
self.assertAllClose(expected, actual.value)
# GMRES
@parameterized.named_parameters(jtu.cases_from_list(
{"testcase_name":
"_shape={}".format(jtu.format_shape_dtype_string(shape, dtype)),
"shape": shape, "dtype": dtype}
for shape in [(2, 2)]
for dtype in float_types + complex_types))
def test_gmres_on_small_fixed_problem(self, shape, dtype):
"""
GMRES gives the right answer for a small fixed system.
"""
A = jnp.array(([[1, 1], [3, -4]]), dtype=dtype)
b = jnp.array([3, 2], dtype=dtype)
x0 = jnp.ones(2, dtype=dtype)
restart = 2
maxiter = 1
@jax.tree_util.Partial
def A_mv(x):
return matmul_high_precision(A, x)
tol = A.size * jnp.finfo(dtype).eps
x, _ = jax.scipy.sparse.linalg.gmres(A_mv, b, x0=x0, tol=tol, atol=tol,
restart=restart, maxiter=maxiter)
solution = jnp.array([2., 1.], dtype=dtype)
self.assertAllClose(solution, x)
@parameterized.named_parameters(jtu.cases_from_list(
{"testcase_name":
"_shape={}_preconditioner={}_qr_mode={}".format(
jtu.format_shape_dtype_string(shape, dtype),
preconditioner,
qr_mode),
"shape": shape, "dtype": dtype, "preconditioner": preconditioner,
"qr_mode": qr_mode}
for shape in [(2, 2), (7, 7)]
for dtype in float_types + complex_types
for preconditioner in [None, 'identity', 'exact']
for qr_mode in [True, False]
))
def test_gmres_on_identity_system(self, shape, dtype, preconditioner,
qr_mode):
A = jnp.eye(shape[1], dtype=dtype)
solution = jnp.ones(shape[1], dtype=dtype)
@jax.tree_util.Partial
def A_mv(x):
return matmul_high_precision(A, x)
rng = jtu.rand_default(self.rng())
M = self._fetch_preconditioner(preconditioner, A, rng=rng,
return_function=True)
b = A_mv(solution)
restart = shape[-1]
tol = shape[0] * jnp.finfo(dtype).eps
x, info = jax.scipy.sparse.linalg.gmres(A_mv, b, tol=tol, atol=tol,
restart=restart,
M=M,
qr_mode=qr_mode)
err = jnp.linalg.norm(solution - x) / jnp.linalg.norm(b)
rtol = tol*jnp.linalg.norm(b)
true_tol = max(rtol, tol)
self.assertLessEqual(err, true_tol)
@parameterized.named_parameters(jtu.cases_from_list(
{"testcase_name":
"_shape={}_preconditioner={}_qr_mode={}".format(
jtu.format_shape_dtype_string(shape, dtype),
preconditioner,
qr_mode),
"shape": shape, "dtype": dtype, "preconditioner": preconditioner,
"qr_mode": qr_mode}
for shape in [(2, 2), (7, 7)]
for dtype in float_types + complex_types
for preconditioner in [None, 'identity', 'exact']
for qr_mode in [True, False]
))
def test_gmres_on_random_system(self, shape, dtype, preconditioner,
qr_mode):
rng = jtu.rand_default(self.rng())
A = rng(shape, dtype)
solution = rng(shape[1:], dtype)
@jax.tree_util.Partial
def A_mv(x):
return matmul_high_precision(A, x)
M = self._fetch_preconditioner(preconditioner, A, rng=rng,
return_function=True)
b = A_mv(solution)
restart = shape[-1]
tol = shape[0] * jnp.finfo(dtype).eps
x, info = jax.scipy.sparse.linalg.gmres(A_mv, b, tol=tol, atol=tol,
restart=restart,
M=M,
qr_mode=qr_mode)
err = jnp.linalg.norm(solution - x) / jnp.linalg.norm(b)
rtol = tol*jnp.linalg.norm(b)
true_tol = max(rtol, tol)
self.assertLessEqual(err, true_tol)
def test_gmres_pytree(self):
A = lambda x: {"a": x["a"] + 0.5 * x["b"], "b": 0.5 * x["a"] + x["b"]}
b = {"a": 1.0, "b": -4.0}
expected = {"a": 4.0, "b": -6.0}
actual, _ = jax.scipy.sparse.linalg.gmres(A, b)
self.assertEqual(expected.keys(), actual.keys())
self.assertAlmostEqual(expected["a"], actual["a"], places=6)
self.assertAlmostEqual(expected["b"], actual["b"], places=6)
@parameterized.named_parameters(jtu.cases_from_list(
{"testcase_name":
"_shape={}_preconditioner={}".format(
jtu.format_shape_dtype_string(shape, dtype),
preconditioner),
"shape": shape, "dtype": dtype, "preconditioner": preconditioner}
for shape in [(2, 2), (7, 7), (32, 32)]
for dtype in float_types + complex_types
for preconditioner in [None, 'identity']))
def test_gmres_arnoldi_step(self, shape, dtype, preconditioner):
"""
The Arnoldi decomposition within GMRES is correct.
"""
rng = jtu.rand_default(self.rng())
A = rng(shape, dtype)
if preconditioner is None:
M = lambda x: x
else:
M = self._fetch_preconditioner(preconditioner, A, rng=rng,
return_function=True)
n = shape[0]
x0 = rng(shape[:1], dtype)
Q = np.zeros((n, n + 1), dtype=dtype)
Q[:, 0] = x0/jax.numpy.linalg.norm(x0)
Q = jnp.array(Q)
H = jax.numpy.eye(n, n + 1, dtype=dtype)
tol = A.size*A.size*jax.numpy.finfo(dtype).eps
@jax.tree_util.Partial
def A_mv(x):
return matmul_high_precision(A, x)
for k in range(n):
Q, H, _ = jax.scipy.sparse.linalg.kth_arnoldi_iteration(k, A_mv, M, Q, H,
tol)
QAQ = matmul_high_precision(Q[:, :n].conj().T, A)
QAQ = matmul_high_precision(QAQ, Q[:, :n])
self.assertAllClose(QAQ, H.T[:n, :], rtol=tol, atol=tol)
if __name__ == "__main__":
absltest.main(testLoader=jtu.JaxTestLoader())