Cleanup/fixup jax.scipy.sparse.linalg.gmres and expose it publicly.

Stephan Hoyer 2020-11-27 14:44:06 -08:00
parent 29370c4d50
commit 6cc5b28327
4 changed files with 155 additions and 175 deletions


@ -60,6 +60,7 @@ jax.scipy.sparse.linalg
:toctree: _autosummary
cg
gmres
jax.scipy.special
-----------------


@ -40,10 +40,9 @@ def _vdot_real_part(x, y):
# where M is positive definite and Hermitian, so the result is
# real valued:
# https://en.wikipedia.org/wiki/Definiteness_of_a_matrix#Definitions_for_complex_matrices
vdot = partial(jnp.vdot, precision=lax.Precision.HIGHEST)
result = vdot(x.real, y.real)
result = _vdot(x.real, y.real)
if jnp.iscomplexobj(x) or jnp.iscomplexobj(y):
result += vdot(x.imag, y.imag)
result += _vdot(x.imag, y.imag)
return result
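
As a quick sanity check of the identity above (a sketch, not part of this diff): for complex vectors, the real part of vdot(x, y) splits into real and imaginary contributions.

import jax.numpy as jnp

x = jnp.array([1.0 + 2.0j, 3.0 - 1.0j])
y = jnp.array([0.5 - 1.0j, 2.0 + 0.5j])
lhs = jnp.vdot(x, y).real
rhs = jnp.vdot(x.real, y.real) + jnp.vdot(x.imag, y.imag)
assert jnp.allclose(lhs, rhs)  # Re<x, y> = <Re x, Re y> + <Im x, Im y>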
@ -51,12 +50,9 @@ def _vdot_real_tree(x, y):
return sum(tree_leaves(tree_multimap(_vdot_real_part, x, y)))
def _norm_tree(x):
return jnp.sqrt(_vdot_real_tree(x, x))
def _vdot_tree(x, y):
return sum(tree_leaves(tree_multimap(_vdot, x, y)))
def _norm(x):
xs = tree_leaves(x)
return jnp.sqrt(sum(map(_vdot_real_part, xs, xs)))
def _mul(scalar, tree):
@ -211,7 +207,7 @@ def _safe_normalize(x, thresh=None):
which by default is the machine precision of x's dtype, it will be
taken to be 0, and the normalized x to be the zero vector.
"""
norm = _norm_tree(x)
norm = _norm(x)
dtype = jnp.result_type(*tree_leaves(x))
if thresh is None:
thresh = jnp.finfo(norm.dtype).eps
@ -233,7 +229,7 @@ def _project_on_columns(A, v):
return tree_reduce(operator.add, v_proj)
def _iterative_classical_gram_schmidt(Q, x, max_iterations=2):
def _iterative_classical_gram_schmidt(Q, x, xnorm, max_iterations=2):
"""
Orthogonalize x against the columns of Q. The process is repeated
up to `max_iterations` times, or fewer if the condition
@ -247,6 +243,8 @@ def _iterative_classical_gram_schmidt(Q, x, max_iterations=2):
x : array or tree of arrays
A vector. It will be replaced with a new vector q which is orthonormal
to the columns of Q, such that x in span(col(Q), q).
xnorm : float
Norm of x.
Returns
-------
@ -259,17 +257,18 @@ def _iterative_classical_gram_schmidt(Q, x, max_iterations=2):
# "twice is enough"
# http://slepc.upv.es/documentation/reports/str1.pdf
# TODO(shoyer): consider switching to only one iteration, like SciPy?
# This assumes that Q's leaves all have the same dimension in the last
# axis.
r = jnp.zeros((tree_leaves(Q)[0].shape[-1]))
q = x
_, xnorm = _safe_normalize(x)
xnorm_scaled = xnorm / jnp.sqrt(2)
def body_function(carry):
k, q, r, qnorm_scaled = carry
h = _project_on_columns(Q, q)
Qh = tree_map(lambda X: _dot_tree(X, h), Q)
Qh = tree_map(lambda X: _dot(X, h), Q)
q = _sub(q, Qh)
r = _add(r, h)
@ -298,7 +297,7 @@ def _iterative_classical_gram_schmidt(Q, x, max_iterations=2):
return q, r
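
For intuition, a minimal dense-matrix sketch of this routine (assuming Q is an ordinary 2D array with orthonormal, possibly zero-padded, columns); running the projection twice is the "twice is enough" heuristic cited above:

import jax.numpy as jnp

def cgs2(Q, x):
  # Classical Gram-Schmidt with one reorthogonalization pass.
  r = jnp.zeros(Q.shape[1], dtype=x.dtype)
  q = x
  for _ in range(2):    # "twice is enough"
    h = Q.conj().T @ q  # coefficients of q along the columns of Q
    q = q - Q @ h       # subtract those components from q
    r = r + h           # accumulate the total projection coefficients
  return q, r           # q is orthogonal to col(Q), and x == Q @ r + q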
def _kth_arnoldi_iteration(k, A, M, V, H, tol):
def _kth_arnoldi_iteration(k, A, M, V, H):
"""
Performs a single (the k'th) step of the Arnoldi process. Thus,
adds a new orthonormalized Krylov vector M(A(V[:, k])) to V[:, k+1],
@ -307,19 +306,43 @@ def _kth_arnoldi_iteration(k, A, M, V, H, tol):
subspace is declared to have been found, in which case the new
vector is taken to be the zero vector.
"""
eps = jnp.finfo(jnp.result_type(*tree_leaves(V))).eps
v = tree_map(lambda x: x[..., k], V) # Gets V[:, k]
v = A(M(v))
v, h = _iterative_classical_gram_schmidt(V, v, max_iterations=2)
unit_v, v_norm = _safe_normalize(v, thresh=tol)
v = M(A(v))
_, v_norm_0 = _safe_normalize(v)
v, h = _iterative_classical_gram_schmidt(V, v, v_norm_0, max_iterations=2)
tol = eps * v_norm_0
unit_v, v_norm_1 = _safe_normalize(v, thresh=tol)
V = tree_multimap(lambda X, y: X.at[..., k + 1].set(y), V, unit_v)
h = h.at[k + 1].set(v_norm)
h = h.at[k + 1].set(v_norm_1)
H = H.at[k, :].set(h)
breakdown = v_norm == 0.
breakdown = v_norm_1 == 0.
return V, H, breakdown
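
To make the bookkeeping concrete, here is a dense, unpreconditioned sketch of the same step (an illustration, not the pytree-aware implementation above); note H is stored as the transpose of the usual Hessenberg matrix, with one row per Arnoldi step:

import jax.numpy as jnp

def arnoldi_step(A, V, H, k):
  # V: (n, restart + 1), orthonormal in columns 0..k; H: (restart, restart + 1).
  v = A @ V[:, k]                     # expand the Krylov subspace
  h = V.conj().T @ v                  # project onto the existing basis
  v = v - V @ h                       # orthogonalize against it
  v_norm = jnp.linalg.norm(v)
  V = V.at[:, k + 1].set(v / v_norm)  # append the new unit Krylov vector
  H = H.at[k, :].set(h.at[k + 1].set(v_norm))
  return V, H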
def _rotate_vectors(H, i, cs, sn):
x1 = H[i]
y1 = H[i + 1]
x2 = cs.conj() * x1 - sn.conj() * y1
y2 = sn * x1 + cs * y1
H = H.at[i].set(x2)
H = H.at[i + 1].set(y2)
return H
def _givens_rotation(a, b):
b_zero = abs(b) == 0
a_lt_b = abs(a) < abs(b)
t = -jnp.where(a_lt_b, a, b) / jnp.where(a_lt_b, b, a)
r = lax.rsqrt(1 + abs(t) ** 2)
cs = jnp.where(b_zero, 1, jnp.where(a_lt_b, r * t, r))
sn = jnp.where(b_zero, 0, jnp.where(a_lt_b, r, r * t))
return cs, sn
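
A quick numerical check of this construction (a sketch, assuming the two helpers above are in scope): the rotation computed for the pair (a, b) zeroes the second entry when applied with _rotate_vectors.

import jax.numpy as jnp

a, b = 3.0, 4.0
cs, sn = _givens_rotation(a, b)
rotated = _rotate_vectors(jnp.array([a, b]), 0, cs, sn)
# abs(rotated[0]) == 5.0 == hypot(a, b), and rotated[1] == 0.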
def _apply_givens_rotations(H_row, givens, k):
"""
Applies the Givens rotations stored in the vectors cs and sn to the vector
@ -330,30 +353,16 @@ def _apply_givens_rotations(H_row, givens, k):
# Givens rotations stored in givens[:, :k] to H_col.
def apply_ith_rotation(i, H_row):
cs, sn = givens[i, :]
H_i = cs * H_row[i] - sn * H_row[i + 1]
H_ip1 = sn * H_row[i] + cs * H_row[i + 1]
H_row = H_row.at[i].set(H_i)
H_row = H_row.at[i + 1].set(H_ip1)
return H_row
return _rotate_vectors(H_row, i, *givens[i, :])
R_row = lax.fori_loop(0, k, apply_ith_rotation, H_row)
def givens_rotation(v1, v2):
t = jnp.sqrt(v1**2 + v2**2)
cs = v1 / t
sn = -v2 / t
return cs, sn
givens_factors = givens_rotation(R_row[k], R_row[k + 1])
givens_factors = _givens_rotation(R_row[k], R_row[k + 1])
givens = givens.at[k, :].set(givens_factors)
cs_k, sn_k = givens_factors
R_row = R_row.at[k].set(cs_k * R_row[k] - sn_k * R_row[k + 1])
R_row = R_row.at[k + 1].set(0.)
R_row = _rotate_vectors(R_row, k, *givens_factors)
return R_row, givens
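
Chaining these rotations row by row amounts to an incremental QR factorization of the transposed Hessenberg matrix. A small dense sketch reusing the helpers defined above (values are illustrative only):

import jax.numpy as jnp

H = jnp.array([[2.0, 1.0, 0.0],    # transposed Hessenberg: row k holds the
               [3.0, 2.0, 1.0]])   # coefficients produced by Arnoldi step k
givens = jnp.zeros((2, 2))
R = jnp.zeros((2, 3))
beta_vec = jnp.zeros(3).at[0].set(1.0)  # right-hand side beta * e1, rotated in place
for k in range(2):
  R_row, givens = _apply_givens_rotations(H[k, :], givens, k)
  R = R.at[k, :].set(R_row)
  beta_vec = _rotate_vectors(beta_vec, k, *givens[k, :])
# R.T is now upper triangular, and abs(beta_vec[2]) equals the residual norm of
# the least-squares problem min_y ||H.T @ y - beta * e1||, obtained for free.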
def _gmres_qr(A, b, x0, unit_residual, residual_norm, inner_tol, restart, M):
def _gmres_incremental(A, b, x0, unit_residual, residual_norm, ptol, restart, M):
"""
Implements a single restart of GMRES. The restart-dimensional Krylov subspace
K(A, r0) = span(r0, A@r0, A@A@r0, ..., A^(restart-1)@r0), where r0 is the
current residual, is built, and the
@ -362,18 +371,15 @@ def _gmres_qr(A, b, x0, unit_residual, residual_norm, inner_tol, restart, M):
This implementation builds the QR factorization during the Arnoldi process.
"""
# https://www-users.cs.umn.edu/~saad/Calais/PREC.pdf
# residual = _sub(b, A(x0))
# unit_residual, beta = _safe_normalize(residual)
V = tree_map(
lambda x: jnp.pad(x[..., None], ((0, 0),) * x.ndim + ((0, restart),)),
unit_residual,
)
dtype = jnp.result_type(*tree_leaves(b))
R = jnp.eye(restart, restart + 1, dtype=dtype) # eye to avoid constructing
# a singular matrix in case
# of early termination.
b_norm = _norm_tree(b)
# use eye() to avoid constructing a singular matrix in case of early
# termination
R = jnp.eye(restart, restart + 1, dtype=dtype)
givens = jnp.zeros((restart, 2), dtype=dtype)
beta_vec = jnp.zeros((restart + 1), dtype=dtype)
@ -381,17 +387,15 @@ def _gmres_qr(A, b, x0, unit_residual, residual_norm, inner_tol, restart, M):
def loop_cond(carry):
k, err, _, _, _, _ = carry
return jnp.logical_and(k < restart, err > inner_tol)
return jnp.logical_and(k < restart, err > ptol)
def arnoldi_qr_step(carry):
k, _, V, R, beta_vec, givens = carry
V, H, _ = _kth_arnoldi_iteration(k, A, M, V, R, inner_tol)
V, H, _ = _kth_arnoldi_iteration(k, A, M, V, R)
R_row, givens = _apply_givens_rotations(H[k, :], givens, k)
R = R.at[k, :].set(R_row[:])
cs, sn = givens[k, :] * beta_vec[k]
beta_vec = beta_vec.at[k].set(cs)
beta_vec = beta_vec.at[k + 1].set(sn)
err = jnp.abs(sn) / b_norm
R = R.at[k, :].set(R_row)
beta_vec = _rotate_vectors(beta_vec, k, *givens[k, :])
err = abs(beta_vec[k + 1])
return k + 1, err, V, R, beta_vec, givens
carry = (0, residual_norm, V, R, beta_vec, givens)
@ -400,16 +404,23 @@ def _gmres_qr(A, b, x0, unit_residual, residual_norm, inner_tol, restart, M):
del k # Until we figure out how to pass this to the user.
y = jsp.linalg.solve_triangular(R[:, :-1].T, beta_vec[:-1])
Vy = tree_map(lambda X: _dot(X[..., :-1], y), V)
dx = M(Vy)
dx = tree_map(lambda X: _dot(X[..., :-1], y), V)
x = _add(x0, dx)
residual = _sub(b, A(x))
residual = M(_sub(b, A(x)))
unit_residual, residual_norm = _safe_normalize(residual)
# TODO(shoyer): "Inner loop tolerance control" on ptol, like SciPy
return x, unit_residual, residual_norm
def _gmres_plain(A, b, x0, unit_residual, residual_norm, inner_tol, restart, M):
def _lstsq(a, b):
# faster than jsp.linalg.lstsq
a2 = _dot(a.T.conj(), a)
b2 = _dot(a.T.conj(), b)
return jsp.linalg.solve(a2, b2, sym_pos=True)
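
This solves the least-squares problem through the normal equations, which squares the condition number of `a` but is cheap for the small (restart + 1) x restart systems that arise here. A usage sketch with illustrative values:

import jax.numpy as jnp

a = jnp.array([[1.0, 0.0],
               [1.0, 1.0],
               [0.0, 1.0]])
b = jnp.array([1.0, 2.0, 1.0])
y = _lstsq(a, b)  # argmin_y ||a @ y - b||; same result as jnp.linalg.lstsq(a, b)[0]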
def _gmres_batched(A, b, x0, unit_residual, residual_norm, ptol, restart, M):
"""
Implements a single restart of GMRES. The ``restart``-dimensional Krylov
subspace
@ -419,6 +430,7 @@ def _gmres_plain(A, b, x0, unit_residual, residual_norm, inner_tol, restart, M):
This implementation solves a dense linear problem instead of building
a QR factorization during the Arnoldi process.
"""
del ptol # unused
# https://www-users.cs.umn.edu/~saad/Calais/PREC.pdf
V = tree_map(
lambda x: jnp.pad(x[..., None], ((0, 0),) * x.ndim + ((0, restart),)),
@ -433,63 +445,59 @@ def _gmres_plain(A, b, x0, unit_residual, residual_norm, inner_tol, restart, M):
def arnoldi_process(carry):
V, H, _, k = carry
V, H, breakdown = _kth_arnoldi_iteration(k, A, M, V, H, inner_tol)
V, H, breakdown = _kth_arnoldi_iteration(k, A, M, V, H)
return V, H, breakdown, k + 1
carry = (V, H, False, 0)
V, H, _, _ = lax.while_loop(loop_cond, arnoldi_process, carry)
# The following is equivalent to:
beta_vec = jnp.zeros((restart,), dtype=dtype)
beta_vec = beta_vec.at[0].set(residual_norm) # it really is the original value
y = jsp.linalg.solve(H[:, :-1].T, beta_vec)
Vy = tree_map(lambda X: _dot(X[..., :-1], y), V)
beta_vec = jnp.zeros((restart + 1,), dtype=dtype)
beta_vec = beta_vec.at[0].set(residual_norm)
y = _lstsq(H.T, beta_vec)
dx = tree_map(lambda X: _dot(X[..., :-1], y), V)
dx = M(Vy)
x = _add(x0, dx)
residual = _sub(b, A(x))
residual = M(_sub(b, A(x)))
unit_residual, residual_norm = _safe_normalize(residual)
return x, unit_residual, residual_norm
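
In dense terms, a batched restart runs `restart` Arnoldi steps and then solves a single small least-squares problem for the coefficients in the Krylov basis. A sketch, assuming a matrix A and the hypothetical arnoldi_step helper sketched earlier:

import jax.numpy as jnp

def gmres_restart_batched(A, x0, r0, restart):
  beta = jnp.linalg.norm(r0)
  V = jnp.zeros((r0.shape[0], restart + 1), r0.dtype).at[:, 0].set(r0 / beta)
  H = jnp.zeros((restart, restart + 1), r0.dtype)
  for k in range(restart):
    V, H = arnoldi_step(A, V, H, k)
  e1 = jnp.zeros(restart + 1, r0.dtype).at[0].set(beta)
  y, *_ = jnp.linalg.lstsq(H.T, e1)  # coefficients in the Krylov basis
  return x0 + V[:, :-1] @ y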
def _gmres_solve(A, b, x0, outer_tol, inner_tol, restart, maxiter, M,
gmres_func):
def _gmres_solve(A, b, x0, atol, ptol, restart, maxiter, M, gmres_func):
"""
The main function call wrapped by custom_linear_solve. Repeatedly calls GMRES
to find the projected solution within the order-``restart``
Krylov space K(A, x0, restart), using the result of the previous projection
in place of x0 each time. Parameters are the same as in ``gmres`` except:
outer_tol: Tolerance to be used between restarts.
inner_tol: Tolerance used within a restart.
atol: Tolerance for norm(A(x) - b), used between restarts.
ptol: Tolerance for norm(M(A(x) - b)), used within a restart.
gmres_func: A function performing a single GMRES restart.
Returns: The solution.
"""
residual = _sub(b, A(x0))
residual = M(_sub(b, A(x0)))
unit_residual, residual_norm = _safe_normalize(residual)
def cond_fun(value):
_, k, _, residual_norm = value
return jnp.logical_and(k < maxiter, residual_norm > outer_tol)
return jnp.logical_and(k < maxiter, residual_norm > atol)
def body_fun(value):
x, k, unit_residual, residual_norm = value
x, unit_residual, residual_norm = gmres_func(A, b, x, unit_residual,
residual_norm, inner_tol,
restart, M)
x, unit_residual, residual_norm = gmres_func(
A, b, x, unit_residual, residual_norm, ptol, restart, M)
return x, k + 1, unit_residual, residual_norm
initialization = (x0, 0, unit_residual, residual_norm)
x_final, k, _, err = lax.while_loop(cond_fun, body_fun, initialization)
_ = k # Until we can pass this out
_ = k # Until we can pass this out
_ = err
return x_final # , info
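
The driver is then plain restarted iteration: re-solve in a fresh Krylov space around the current iterate until the residual drops below atol or maxiter restarts have run. A dense sketch with preconditioning omitted, reusing the hypothetical restart function above:

import jax.numpy as jnp

def gmres_solve_sketch(A, b, x0, atol, restart, maxiter):
  x = x0
  for _ in range(maxiter):  # the real code uses lax.while_loop
    r = b - A @ x
    if jnp.linalg.norm(r) <= atol:
      break
    x = gmres_restart_batched(A, x, r, restart)
  return x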
def _gmres(A, b, x0=None, *, tol=1e-5, atol=0.0, restart=20, maxiter=None,
M=None, qr_mode=False):
def gmres(A, b, x0=None, *, tol=1e-5, atol=0.0, restart=20, maxiter=None,
M=None, solve_method='batched'):
"""
GMRES solves the linear system A x = b for x, given A and b.
@ -543,11 +551,14 @@ def _gmres(A, b, x0=None, *, tol=1e-5, atol=0.0, restart=20, maxiter=None,
inverse of A. Effective preconditioning dramatically improves the
rate of convergence, which implies that fewer iterations are needed
to reach a given error tolerance.
qr_mode : bool
If True, the algorithm builds an internal Krylov subspace using a QR
based algorithm, which reduces overhead and improves stability. However,
it may degrade performance significantly on GPUs or TPUs, in which case
this flag should be set to False.
solve_method : 'incremental' or 'batched'
The 'incremental' solve method builds a QR decomposition for the Krylov
subspace incrementally during the GMRES process using Givens rotations.
This improves numerical stability and gives a free estimate of the
residual norm that allows for early termination within a single "restart".
In contrast, the 'batched' solve method solves the least squares problem
from scratch at the end of each GMRES iteration. It does not allow for
early termination, but has much less overhead on GPUs.
See also
--------
@ -572,26 +583,25 @@ def _gmres(A, b, x0=None, *, tol=1e-5, atol=0.0, restart=20, maxiter=None,
'x0 and b must have matching tree structure: '
f'{tree_structure(x0)} vs {tree_structure(b)}')
b_norm = _norm_tree(b)
if b_norm == 0:
return b, 0
outer_tol = jnp.maximum(tol * b_norm, atol)
b_norm = _norm(b)
atol = jnp.maximum(tol * b_norm, atol)
Mb = M(b)
Mb_norm = _norm_tree(Mb)
inner_tol = Mb_norm * min(1.0, outer_tol / b_norm)
Mb_norm = _norm(Mb)
ptol = Mb_norm * jnp.minimum(1.0, atol / b_norm)
if qr_mode:
def _solve(A, b):
return _gmres_solve(A, b, x0, outer_tol, inner_tol, restart, maxiter, M,
_gmres_plain)
if solve_method == 'incremental':
gmres_func = _gmres_incremental
elif solve_method == 'batched':
gmres_func = _gmres_batched
else:
def _solve(A, b):
return _gmres_solve(A, b, x0, outer_tol, inner_tol, restart, maxiter, M,
_gmres_qr)
raise ValueError(f"invalid solve_method {solve_method}, must be either "
"'incremental' or 'batched'")
def _solve(A, b):
return _gmres_solve(A, b, x0, atol, ptol, restart, maxiter, M, gmres_func)
x = lax.custom_linear_solve(A, b, solve=_solve, transpose_solve=_solve)
failed = jnp.isnan(_norm_tree(x))
failed = jnp.isnan(_norm(x))
info = jnp.where(failed, x=-1, y=0)
return x, info
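
With gmres now public, a minimal usage sketch (the linear operator is passed as a function; jax.tree_util.Partial makes it a pytree-compatible callable, as in the tests below):

import jax
import jax.numpy as jnp
from jax.scipy.sparse.linalg import gmres

A = jnp.array([[2.0, 1.0],
               [1.0, 3.0]])
b = jnp.array([1.0, 2.0])

@jax.tree_util.Partial
def A_mv(x):
  return A @ x

x, info = gmres(A_mv, b, tol=1e-6, solve_method='incremental')
# x approximates jnp.linalg.solve(A, b); info is 0 on success, -1 if the
# solve produced nans.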


@ -15,5 +15,5 @@
# flake8: noqa: F401
from jax._src.scipy.sparse.linalg import (
cg,
_gmres,
gmres,
)


@ -55,7 +55,7 @@ def lax_solver(solver_name, A, b, M=None, atol=0.0, **kwargs):
lax_cg = partial(lax_solver, 'cg')
lax_gmres = partial(lax_solver, '_gmres')
lax_gmres = partial(lax_solver, 'gmres')
def scipy_solver(solver_name, A, b, atol=0.0, **kwargs):
@ -73,10 +73,6 @@ def rand_sym_pos_def(rng, shape, dtype):
return matrix @ matrix.T.conj()
class LaxBackedScipyTests(jtu.JaxTestCase):
def _fetch_preconditioner(self, preconditioner, A, rng=None,
return_function=False):
@ -107,11 +103,12 @@ class LaxBackedScipyTests(jtu.JaxTestCase):
jtu.format_shape_dtype_string(shape, dtype),
preconditioner),
"shape": shape, "dtype": dtype, "preconditioner": preconditioner}
for shape in [(4, 4), (7, 7), (32, 32)]
for dtype in float_types + complex_types
for preconditioner in [None, 'identity', 'exact']))
# TODO(#2951): reenable 'random' preconditioner.
for shape in [(4, 4), (7, 7)]
for dtype in [np.float64, np.complex128]
for preconditioner in [None, 'identity', 'exact', 'random']))
def test_cg_against_scipy(self, shape, dtype, preconditioner):
if not config.FLAGS.jax_enable_x64:
raise unittest.SkipTest("requires x64 mode")
rng = jtu.rand_default(self.rng())
A = rand_sym_pos_def(rng, shape, dtype)
@ -125,21 +122,19 @@ class LaxBackedScipyTests(jtu.JaxTestCase):
partial(scipy_cg, M=M, maxiter=1),
partial(lax_cg, M=M, maxiter=1),
args_maker,
tol=1e-3)
tol=1e-12)
# TODO(shoyer,mattjj): I had to loosen the tolerance for complex64[7,7]
# with preconditioner=random
self._CheckAgainstNumpy(
partial(scipy_cg, M=M, maxiter=3),
partial(lax_cg, M=M, maxiter=3),
args_maker,
tol=3e-3)
tol=1e-12)
self._CheckAgainstNumpy(
np.linalg.solve,
partial(lax_cg, M=M, atol=1e-6),
partial(lax_cg, M=M, atol=1e-10),
args_maker,
tol=2e-2)
tol=1e-6)
@parameterized.named_parameters(jtu.cases_from_list(
{"testcase_name":
@ -213,43 +208,18 @@ class LaxBackedScipyTests(jtu.JaxTestCase):
# GMRES
@parameterized.named_parameters(jtu.cases_from_list(
{"testcase_name":
"_shape={}".format(jtu.format_shape_dtype_string(shape, dtype)),
"shape": shape, "dtype": dtype}
for shape in [(2, 2)]
for dtype in float_types + complex_types))
def test_gmres_on_small_fixed_problem(self, shape, dtype):
"""
GMRES gives the right answer for a small fixed system.
"""
A = jnp.array(([[1, 1], [3, -4]]), dtype=dtype)
b = jnp.array([3, 2], dtype=dtype)
x0 = jnp.ones(2, dtype=dtype)
restart = 2
maxiter = 1
@jax.tree_util.Partial
def A_mv(x):
return matmul_high_precision(A, x)
tol = A.size * jnp.finfo(dtype).eps
x, _ = jax.scipy.sparse.linalg._gmres(
A_mv, b, x0=x0, tol=tol, atol=tol, restart=restart, maxiter=maxiter)
solution = jnp.array([2., 1.], dtype=dtype)
self.assertAllClose(solution, x)
@parameterized.named_parameters(jtu.cases_from_list(
{"testcase_name":
"_shape={}_preconditioner={}_qr_mode={}".format(
"_shape={}_preconditioner={}_solve_method={}".format(
jtu.format_shape_dtype_string(shape, dtype),
preconditioner,
qr_mode),
solve_method),
"shape": shape, "dtype": dtype, "preconditioner": preconditioner,
"qr_mode": qr_mode}
"solve_method": solve_method}
for shape in [(3, 3)]
# TODO(shoyer): get working for np.complex128 and qr_mode=True
for dtype in [np.float64]
for preconditioner in [None, 'identity', 'exact']
for qr_mode in [False]))
def test_gmres_against_scipy(self, shape, dtype, preconditioner, qr_mode):
for dtype in [np.float64, np.complex128]
for preconditioner in [None, 'identity', 'exact', 'random']
for solve_method in ['incremental', 'batched']))
def test_gmres_against_scipy(
self, shape, dtype, preconditioner, solve_method):
if not config.FLAGS.jax_enable_x64:
raise unittest.SkipTest("requires x64 mode")
@ -263,43 +233,43 @@ class LaxBackedScipyTests(jtu.JaxTestCase):
self._CheckAgainstNumpy(
partial(scipy_gmres, M=M, restart=1, maxiter=1),
partial(lax_gmres, M=M, restart=1, maxiter=1, qr_mode=qr_mode),
partial(lax_gmres, M=M, restart=1, maxiter=1, solve_method=solve_method),
args_maker,
tol=1e-3)
tol=1e-10)
self._CheckAgainstNumpy(
partial(scipy_gmres, M=M, restart=1, maxiter=2),
partial(lax_gmres, M=M, restart=1, maxiter=2, qr_mode=qr_mode),
partial(lax_gmres, M=M, restart=1, maxiter=2, solve_method=solve_method),
args_maker,
tol=1e-3)
tol=1e-10)
self._CheckAgainstNumpy(
partial(scipy_gmres, M=M, restart=3, maxiter=1),
partial(lax_gmres, M=M, restart=3, maxiter=1, qr_mode=qr_mode),
partial(scipy_gmres, M=M, restart=2, maxiter=1),
partial(lax_gmres, M=M, restart=2, maxiter=1, solve_method=solve_method),
args_maker,
tol=3e-3)
tol=1e-10)
self._CheckAgainstNumpy(
np.linalg.solve,
partial(lax_gmres, M=M, atol=1e-6, qr_mode=qr_mode),
partial(lax_gmres, M=M, atol=1e-6, solve_method=solve_method),
args_maker,
tol=2e-2)
tol=1e-10)
@parameterized.named_parameters(jtu.cases_from_list(
{"testcase_name":
"_shape={}_preconditioner={}_qr_mode={}".format(
"_shape={}_preconditioner={}_solve_method={}".format(
jtu.format_shape_dtype_string(shape, dtype),
preconditioner,
qr_mode),
solve_method),
"shape": shape, "dtype": dtype, "preconditioner": preconditioner,
"qr_mode": qr_mode}
"solve_method": solve_method}
for shape in [(2, 2), (7, 7)]
for dtype in float_types + complex_types
for preconditioner in [None, 'identity', 'exact']
for qr_mode in [True, False]
for solve_method in ['batched', 'incremental']
))
def test_gmres_on_identity_system(self, shape, dtype, preconditioner,
qr_mode):
solve_method):
A = jnp.eye(shape[1], dtype=dtype)
solution = jnp.ones(shape[1], dtype=dtype)
@ -312,28 +282,28 @@ class LaxBackedScipyTests(jtu.JaxTestCase):
b = A_mv(solution)
restart = shape[-1]
tol = shape[0] * jnp.finfo(dtype).eps
x, info = jax.scipy.sparse.linalg._gmres(A_mv, b, tol=tol, atol=tol,
restart=restart,
M=M, qr_mode=qr_mode)
x, info = jax.scipy.sparse.linalg.gmres(A_mv, b, tol=tol, atol=tol,
restart=restart,
M=M, solve_method=solve_method)
using_x64 = solution.dtype.kind in {np.float64, np.complex128}
solution_tol = 1e-8 if using_x64 else 1e-4
self.assertAllClose(x, solution, atol=solution_tol, rtol=solution_tol)
@parameterized.named_parameters(jtu.cases_from_list(
{"testcase_name":
"_shape={}_preconditioner={}_qr_mode={}".format(
"_shape={}_preconditioner={}_solve_method={}".format(
jtu.format_shape_dtype_string(shape, dtype),
preconditioner,
qr_mode),
solve_method),
"shape": shape, "dtype": dtype, "preconditioner": preconditioner,
"qr_mode": qr_mode}
"solve_method": solve_method}
for shape in [(2, 2), (4, 4)]
for dtype in float_types + complex_types
for preconditioner in [None, 'identity', 'exact']
for qr_mode in [True, False]
for solve_method in ['incremental', 'batched']
))
def test_gmres_on_random_system(self, shape, dtype, preconditioner,
qr_mode):
solve_method):
rng = jtu.rand_default(self.rng())
A = rng(shape, dtype)
@ -346,9 +316,9 @@ class LaxBackedScipyTests(jtu.JaxTestCase):
b = A_mv(solution)
restart = shape[-1]
tol = shape[0] * jnp.finfo(A.dtype).eps
x, info = jax.scipy.sparse.linalg._gmres(A_mv, b, tol=tol, atol=tol,
restart=restart,
M=M, qr_mode=qr_mode)
x, info = jax.scipy.sparse.linalg.gmres(A_mv, b, tol=tol, atol=tol,
restart=restart,
M=M, solve_method=solve_method)
using_x64 = solution.dtype.kind in {np.float64, np.complex128}
solution_tol = 1e-8 if using_x64 else 1e-4
self.assertAllClose(x, solution, atol=solution_tol, rtol=solution_tol)
@ -357,10 +327,10 @@ class LaxBackedScipyTests(jtu.JaxTestCase):
A = lambda x: {"a": x["a"] + 0.5 * x["b"], "b": 0.5 * x["a"] + x["b"]}
b = {"a": 1.0, "b": -4.0}
expected = {"a": 4.0, "b": -6.0}
actual, _ = jax.scipy.sparse.linalg._gmres(A, b)
actual, _ = jax.scipy.sparse.linalg.gmres(A, b)
self.assertEqual(expected.keys(), actual.keys())
self.assertAlmostEqual(expected["a"], actual["a"], places=6)
self.assertAlmostEqual(expected["b"], actual["b"], places=6)
self.assertAlmostEqual(expected["a"], actual["a"], places=5)
self.assertAlmostEqual(expected["b"], actual["b"], places=5)
@parameterized.named_parameters(jtu.cases_from_list(
{"testcase_name":
@ -389,20 +359,19 @@ class LaxBackedScipyTests(jtu.JaxTestCase):
n = shape[0]
x0 = rng(shape[:1], dtype)
Q = np.zeros((n, n + 1), dtype=dtype)
Q[:, 0] = x0/jax.numpy.linalg.norm(x0)
Q[:, 0] = x0/jnp.linalg.norm(x0)
Q = jnp.array(Q)
H = jax.numpy.eye(n, n + 1, dtype=dtype)
tol = A.size*A.size*jax.numpy.finfo(dtype).eps
H = jnp.eye(n, n + 1, dtype=dtype)
@jax.tree_util.Partial
def A_mv(x):
return matmul_high_precision(A, x)
for k in range(n):
Q, H, _ = jax._src.scipy.sparse.linalg._kth_arnoldi_iteration(
k, A_mv, M, Q, H, tol)
k, A_mv, M, Q, H)
QA = matmul_high_precision(Q[:, :n].conj().T, A)
QAQ = matmul_high_precision(QA, Q[:, :n])
self.assertAllClose(QAQ, H.T[:n, :], rtol=tol, atol=tol)
self.assertAllClose(QAQ, H.T[:n, :], rtol=1e-5, atol=1e-5)
if __name__ == "__main__":