mirror of
https://github.com/ROCm/jax.git
synced 2025-04-16 03:46:06 +00:00
Cleanup/fixup jax.scipy.sparse.linalg.gmres and expose it publicly.
This commit is contained in:
  parent 29370c4d50
  commit 6cc5b28327
@@ -60,6 +60,7 @@ jax.scipy.sparse.linalg
   :toctree: _autosummary
 
   cg
+  gmres
 
 jax.scipy.special
 -----------------
@@ -40,10 +40,9 @@ def _vdot_real_part(x, y):
   # where M is positive definite and Hermitian, so the result is
   # real valued:
   # https://en.wikipedia.org/wiki/Definiteness_of_a_matrix#Definitions_for_complex_matrices
-  vdot = partial(jnp.vdot, precision=lax.Precision.HIGHEST)
-  result = vdot(x.real, y.real)
+  result = _vdot(x.real, y.real)
   if jnp.iscomplexobj(x) or jnp.iscomplexobj(y):
-    result += vdot(x.imag, y.imag)
+    result += _vdot(x.imag, y.imag)
   return result
 
 
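The change above routes both pieces through the module's shared high-precision `_vdot` helper. For illustration (not part of the commit), a minimal check of the identity `_vdot_real_part` relies on, Re⟨x, y⟩ = ⟨Re x, Re y⟩ + ⟨Im x, Im y⟩:

```python
import jax.numpy as jnp

# Minimal check of the identity used by _vdot_real_part:
# Re<x, y> = <Re x, Re y> + <Im x, Im y> for the conjugate-linear vdot.
x = jnp.array([1.0 + 2.0j, 3.0 - 1.0j])
y = jnp.array([0.5 - 1.0j, 2.0 + 4.0j])

lhs = jnp.vdot(x, y).real
rhs = jnp.vdot(x.real, y.real) + jnp.vdot(x.imag, y.imag)
assert jnp.allclose(lhs, rhs)
```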
@@ -51,12 +50,9 @@ def _vdot_real_tree(x, y):
   return sum(tree_leaves(tree_multimap(_vdot_real_part, x, y)))
 
 
-def _norm_tree(x):
-  return jnp.sqrt(_vdot_real_tree(x, x))
-
-
-def _vdot_tree(x, y):
-  return sum(tree_leaves(tree_multimap(_vdot, x, y)))
+def _norm(x):
+  xs = tree_leaves(x)
+  return jnp.sqrt(sum(map(_vdot_real_part, xs, xs)))
 
 
 def _mul(scalar, tree):
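A hedged sketch (illustrative names, not from the diff) of what the consolidated `_norm` computes: a single global 2-norm over all leaves of a pytree, accumulated as a real scalar even for complex leaves:

```python
import numpy as np
import jax.numpy as jnp
from jax.tree_util import tree_leaves

# Sketch of the new _norm: flatten the pytree, sum the squared
# magnitudes of every leaf, and take one square root at the end.
tree = {"a": jnp.array([3.0, 0.0]), "b": jnp.array([[0.0, 4.0]])}
total = sum(jnp.vdot(leaf, leaf).real for leaf in tree_leaves(tree))
norm = jnp.sqrt(total)
assert np.isclose(norm, 5.0)  # sqrt(9 + 16)
```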
@@ -211,7 +207,7 @@ def _safe_normalize(x, thresh=None):
   which by default is the machine precision of x's dtype, it will be
   taken to be 0, and the normalized x to be the zero vector.
   """
-  norm = _norm_tree(x)
+  norm = _norm(x)
   dtype = jnp.result_type(*tree_leaves(x))
   if thresh is None:
     thresh = jnp.finfo(norm.dtype).eps
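The docstring context above describes the contract: vectors whose norm falls below `thresh` (by default, machine epsilon) are treated as numerically zero. A simplified, illustrative sketch of that behavior for flat arrays (the real helper operates on pytrees and uses `jnp.where` so it stays jit-friendly):

```python
import jax.numpy as jnp

def safe_normalize_sketch(x, thresh=None):
  # Illustration only: if ||x|| falls below thresh, treat x as
  # numerically zero and return the zero vector together with norm 0.
  norm = jnp.sqrt(jnp.vdot(x, x).real)
  if thresh is None:
    thresh = jnp.finfo(norm.dtype).eps
  use_zero = norm < thresh
  unit = jnp.where(use_zero, 0.0, x / jnp.where(use_zero, 1.0, norm))
  return unit, jnp.where(use_zero, 0.0, norm)

unit, norm = safe_normalize_sketch(jnp.array([1e-40, 0.0], dtype=jnp.float32))
assert norm == 0.0 and (unit == 0).all()
```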
@@ -233,7 +229,7 @@ def _project_on_columns(A, v):
   return tree_reduce(operator.add, v_proj)
 
 
-def _iterative_classical_gram_schmidt(Q, x, max_iterations=2):
+def _iterative_classical_gram_schmidt(Q, x, xnorm, max_iterations=2):
   """
   Orthogonalize x against the columns of Q. The process is repeated
   up to `max_iterations` times, or fewer if the condition
@@ -247,6 +243,8 @@ def _iterative_classical_gram_schmidt(Q, x, max_iterations=2):
   x : array or tree of arrays
     A vector. It will be replaced with a new vector q which is orthonormal
     to the columns of Q, such that x in span(col(Q), q).
+  xnorm : float
+    Norm of x.
 
   Returns
   -------
@@ -259,17 +257,18 @@ def _iterative_classical_gram_schmidt(Q, x, max_iterations=2):
   # "twice is enough"
   # http://slepc.upv.es/documentation/reports/str1.pdf
 
+  # TODO(shoyer): consider switching to only one iteration, like SciPy?
+
   # This assumes that Q's leaves all have the same dimension in the last
   # axis.
   r = jnp.zeros((tree_leaves(Q)[0].shape[-1]))
   q = x
-  _, xnorm = _safe_normalize(x)
   xnorm_scaled = xnorm / jnp.sqrt(2)
 
   def body_function(carry):
     k, q, r, qnorm_scaled = carry
     h = _project_on_columns(Q, q)
-    Qh = tree_map(lambda X: _dot_tree(X, h), Q)
+    Qh = tree_map(lambda X: _dot(X, h), Q)
     q = _sub(q, Qh)
     r = _add(r, h)
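The loop body above implements iterative classical Gram-Schmidt: project onto the columns of Q, subtract, and repeat when cancellation has eaten too much of the vector's norm (the "twice is enough" rule). A simplified NumPy sketch, for illustration only, without the early-exit condition:

```python
import numpy as np

def icgs_sketch(Q, x, max_iterations=2):
  # Classical Gram-Schmidt with reorthogonalization ("twice is enough"):
  # project x onto col(Q), subtract, and repeat to clean up the roundoff
  # left behind by heavy cancellation in the first pass.
  r = np.zeros(Q.shape[1])
  q = x.copy()
  for _ in range(max_iterations):
    h = Q.T @ q      # coefficients of q along the columns of Q
    q = q - Q @ h    # remove those components
    r = r + h
  return q, r

rng = np.random.default_rng(0)
Q, _ = np.linalg.qr(rng.standard_normal((6, 3)))  # orthonormal columns
q, r = icgs_sketch(Q, rng.standard_normal(6))
assert np.allclose(Q.T @ q, 0.0, atol=1e-12)      # q is orthogonal to col(Q)
```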
@@ -298,7 +297,7 @@ def _iterative_classical_gram_schmidt(Q, x, max_iterations=2):
   return q, r
 
 
-def _kth_arnoldi_iteration(k, A, M, V, H, tol):
+def _kth_arnoldi_iteration(k, A, M, V, H):
   """
   Performs a single (the k'th) step of the Arnoldi process. Thus,
   adds a new orthonormalized Krylov vector A(M(V[:, k])) to V[:, k+1],
@@ -307,19 +306,43 @@ def _kth_arnoldi_iteration(k, A, M, V, H, tol):
   subspace is declared to have been found, in which case the new
   vector is taken to be the zero vector.
   """
+  eps = jnp.finfo(jnp.result_type(*tree_leaves(V))).eps
+
   v = tree_map(lambda x: x[..., k], V)  # Gets V[:, k]
-  v = A(M(v))
-  v, h = _iterative_classical_gram_schmidt(V, v, max_iterations=2)
-  unit_v, v_norm = _safe_normalize(v, thresh=tol)
+  v = M(A(v))
+  _, v_norm_0 = _safe_normalize(v)
+  v, h = _iterative_classical_gram_schmidt(V, v, v_norm_0, max_iterations=2)
+
+  tol = eps * v_norm_0
+  unit_v, v_norm_1 = _safe_normalize(v, thresh=tol)
   V = tree_multimap(lambda X, y: X.at[..., k + 1].set(y), V, unit_v)
 
-  h = h.at[k + 1].set(v_norm)
+  h = h.at[k + 1].set(v_norm_1)
   H = H.at[k, :].set(h)
-  breakdown = v_norm == 0.
+  breakdown = v_norm_1 == 0.
   return V, H, breakdown
 
 
+def _rotate_vectors(H, i, cs, sn):
+  x1 = H[i]
+  y1 = H[i + 1]
+  x2 = cs.conj() * x1 - sn.conj() * y1
+  y2 = sn * x1 + cs * y1
+  H = H.at[i].set(x2)
+  H = H.at[i + 1].set(y2)
+  return H
+
+
+def _givens_rotation(a, b):
+  b_zero = abs(b) == 0
+  a_lt_b = abs(a) < abs(b)
+  t = -jnp.where(a_lt_b, a, b) / jnp.where(a_lt_b, b, a)
+  r = lax.rsqrt(1 + abs(t) ** 2)
+  cs = jnp.where(b_zero, 1, jnp.where(a_lt_b, r * t, r))
+  sn = jnp.where(b_zero, 0, jnp.where(a_lt_b, r, r * t))
+  return cs, sn
+
+
 def _apply_givens_rotations(H_row, givens, k):
   """
   Applies the Givens rotations stored in the vectors cs and sn to the vector
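The added `_givens_rotation` picks (cs, sn) so that rotating the pair (a, b) annihilates the second entry, branching on which of |a|, |b| is larger for numerical stability. A standalone check of that property, mirroring the added code (float32 scalars assumed for illustration):

```python
import jax.numpy as jnp
from jax import lax

def givens_sketch(a, b):
  # Mirrors the logic of the added _givens_rotation: pivot on the
  # larger of |a| and |b| so the ratio t stays <= 1 in magnitude.
  b_zero = abs(b) == 0
  a_lt_b = abs(a) < abs(b)
  t = -jnp.where(a_lt_b, a, b) / jnp.where(a_lt_b, b, a)
  r = lax.rsqrt(1 + abs(t) ** 2)
  cs = jnp.where(b_zero, 1, jnp.where(a_lt_b, r * t, r))
  sn = jnp.where(b_zero, 0, jnp.where(a_lt_b, r, r * t))
  return cs, sn

a, b = jnp.float32(3.0), jnp.float32(-4.0)
cs, sn = givens_sketch(a, b)
# Applying the rotation the way _rotate_vectors does zeroes the second entry:
assert jnp.allclose(sn * a + cs * b, 0.0, atol=1e-6)
assert jnp.allclose(cs ** 2 + sn ** 2, 1.0, atol=1e-6)
```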
@@ -330,30 +353,16 @@ def _apply_givens_rotations(H_row, givens, k):
   # Givens rotations stored in givens[:, :k] to H_col.
 
   def apply_ith_rotation(i, H_row):
-    cs, sn = givens[i, :]
-    H_i = cs * H_row[i] - sn * H_row[i + 1]
-    H_ip1 = sn * H_row[i] + cs * H_row[i + 1]
-    H_row = H_row.at[i].set(H_i)
-    H_row = H_row.at[i + 1].set(H_ip1)
-    return H_row
+    return _rotate_vectors(H_row, i, *givens[i, :])
 
   R_row = lax.fori_loop(0, k, apply_ith_rotation, H_row)
 
-  def givens_rotation(v1, v2):
-    t = jnp.sqrt(v1**2 + v2**2)
-    cs = v1 / t
-    sn = -v2 / t
-    return cs, sn
-
-  givens_factors = givens_rotation(R_row[k], R_row[k + 1])
+  givens_factors = _givens_rotation(R_row[k], R_row[k + 1])
   givens = givens.at[k, :].set(givens_factors)
-  cs_k, sn_k = givens_factors
-
-  R_row = R_row.at[k].set(cs_k * R_row[k] - sn_k * R_row[k + 1])
-  R_row = R_row.at[k + 1].set(0.)
+  R_row = _rotate_vectors(R_row, k, *givens_factors)
   return R_row, givens
 
 
-def _gmres_qr(A, b, x0, unit_residual, residual_norm, inner_tol, restart, M):
+def _gmres_incremental(A, b, x0, unit_residual, residual_norm, ptol, restart, M):
   """
   Implements a single restart of GMRES. The restart-dimensional Krylov subspace
   K(A, x0) = span(A(x0), A@x0, A@A@x0, ..., A^restart @ x0) is built, and the
@@ -362,18 +371,15 @@ def _gmres_qr(A, b, x0, unit_residual, residual_norm, inner_tol, restart, M):
   This implementation builds the QR factorization during the Arnoldi process.
   """
   # https://www-users.cs.umn.edu/~saad/Calais/PREC.pdf
-  # residual = _sub(b, A(x0))
-  # unit_residual, beta = _safe_normalize(residual)
-
   V = tree_map(
       lambda x: jnp.pad(x[..., None], ((0, 0),) * x.ndim + ((0, restart),)),
       unit_residual,
   )
   dtype = jnp.result_type(*tree_leaves(b))
-  R = jnp.eye(restart, restart + 1, dtype=dtype) # eye to avoid constructing
-                                                 # a singular matrix in case
-                                                 # of early termination.
-  b_norm = _norm_tree(b)
+  # use eye() to avoid constructing a singular matrix in case of early
+  # termination
+  R = jnp.eye(restart, restart + 1, dtype=dtype)
 
   givens = jnp.zeros((restart, 2), dtype=dtype)
   beta_vec = jnp.zeros((restart + 1), dtype=dtype)
@@ -381,17 +387,15 @@ def _gmres_qr(A, b, x0, unit_residual, residual_norm, inner_tol, restart, M):
 
   def loop_cond(carry):
     k, err, _, _, _, _ = carry
-    return jnp.logical_and(k < restart, err > inner_tol)
+    return jnp.logical_and(k < restart, err > ptol)
 
   def arnoldi_qr_step(carry):
     k, _, V, R, beta_vec, givens = carry
-    V, H, _ = _kth_arnoldi_iteration(k, A, M, V, R, inner_tol)
+    V, H, _ = _kth_arnoldi_iteration(k, A, M, V, R)
     R_row, givens = _apply_givens_rotations(H[k, :], givens, k)
-    R = R.at[k, :].set(R_row[:])
-    cs, sn = givens[k, :] * beta_vec[k]
-    beta_vec = beta_vec.at[k].set(cs)
-    beta_vec = beta_vec.at[k + 1].set(sn)
-    err = jnp.abs(sn) / b_norm
+    R = R.at[k, :].set(R_row)
+    beta_vec = _rotate_vectors(beta_vec, k, *givens[k, :])
+    err = abs(beta_vec[k + 1])
    return k + 1, err, V, R, beta_vec, givens
 
   carry = (0, residual_norm, V, R, beta_vec, givens)
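For context on the new `err = abs(beta_vec[k + 1])`: rotating `beta_vec` with the same Givens rotations that triangularize H leaves the current least-squares residual norm in the last touched entry, which is what makes the free early-termination estimate possible. An illustrative NumPy check of that identity, using a dense QR as a stand-in for the accumulated rotations:

```python
import numpy as np

# For the (k+1) x k Hessenberg least-squares problem min ||beta*e1 - H y||,
# applying the orthogonal factor of H's QR to both sides triangularizes the
# system, and the magnitude of the leftover entry of the rotated right-hand
# side equals the least-squares residual norm.
rng = np.random.default_rng(0)
H = np.triu(rng.standard_normal((4, 3)), k=-1)  # Hessenberg, shape (k+1, k)
beta_e1 = np.zeros(4); beta_e1[0] = 2.5

_sol, res, *_ = np.linalg.lstsq(H, beta_e1, rcond=None)
q, r = np.linalg.qr(H, mode='complete')  # stand-in for the Givens rotations
rotated = q.T @ beta_e1
assert np.isclose(abs(rotated[-1]), np.sqrt(res[0]))
```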
@@ -400,16 +404,23 @@ def _gmres_qr(A, b, x0, unit_residual, residual_norm, inner_tol, restart, M):
   del k  # Until we figure out how to pass this to the user.
 
   y = jsp.linalg.solve_triangular(R[:, :-1].T, beta_vec[:-1])
-  Vy = tree_map(lambda X: _dot(X[..., :-1], y), V)
-  dx = M(Vy)
+  dx = tree_map(lambda X: _dot(X[..., :-1], y), V)
 
   x = _add(x0, dx)
-  residual = _sub(b, A(x))
+  residual = M(_sub(b, A(x)))
   unit_residual, residual_norm = _safe_normalize(residual)
+  # TODO(shoyer): "Inner loop tolerance control" on ptol, like SciPy
   return x, unit_residual, residual_norm
 
 
-def _gmres_plain(A, b, x0, unit_residual, residual_norm, inner_tol, restart, M):
+def _lstsq(a, b):
+  # faster than jsp.linalg.lstsq
+  a2 = _dot(a.T.conj(), a)
+  b2 = _dot(a.T.conj(), b)
+  return jsp.linalg.solve(a2, b2, sym_pos=True)
+
+
+def _gmres_batched(A, b, x0, unit_residual, residual_norm, ptol, restart, M):
   """
   Implements a single restart of GMRES. The ``restart``-dimensional Krylov
   subspace
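The new `_lstsq` helper solves the small projected least-squares problem via the normal equations instead of an SVD-based routine. A hedged sketch of the idea checked against a reference solver (note the normal equations square the condition number, which is acceptable for the small, well-conditioned ``restart``-sized systems used here; `assume_a='pos'` is the modern spelling of the diff's `sym_pos=True`):

```python
import numpy as np
import jax.numpy as jnp
import jax.scipy as jsp

def lstsq_normal_eq(a, b):
  # Same idea as the added _lstsq: solve (a^H a) y = a^H b with a
  # positive-definite solve, which is cheaper than a full lstsq.
  a2 = jnp.dot(a.T.conj(), a)
  b2 = jnp.dot(a.T.conj(), b)
  return jsp.linalg.solve(a2, b2, assume_a='pos')

rng = np.random.default_rng(0)
a = jnp.asarray(rng.standard_normal((5, 3)))
b = jnp.asarray(rng.standard_normal(5))
expected, *_ = jnp.linalg.lstsq(a, b)
assert jnp.allclose(lstsq_normal_eq(a, b), expected, atol=1e-4)
```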
@@ -419,6 +430,7 @@ def _gmres_plain(A, b, x0, unit_residual, residual_norm, inner_tol, restart, M):
   This implementation solves a dense linear problem instead of building
   a QR factorization during the Arnoldi process.
   """
+  del ptol  # unused
   # https://www-users.cs.umn.edu/~saad/Calais/PREC.pdf
   V = tree_map(
       lambda x: jnp.pad(x[..., None], ((0, 0),) * x.ndim + ((0, restart),)),
|
||||
|
||||
def arnoldi_process(carry):
|
||||
V, H, _, k = carry
|
||||
V, H, breakdown = _kth_arnoldi_iteration(k, A, M, V, H, inner_tol)
|
||||
V, H, breakdown = _kth_arnoldi_iteration(k, A, M, V, H)
|
||||
return V, H, breakdown, k + 1
|
||||
|
||||
carry = (V, H, False, 0)
|
||||
V, H, _, _ = lax.while_loop(loop_cond, arnoldi_process, carry)
|
||||
|
||||
# The following is equivalent to:
|
||||
beta_vec = jnp.zeros((restart,), dtype=dtype)
|
||||
beta_vec = beta_vec.at[0].set(residual_norm) # it really is the original value
|
||||
y = jsp.linalg.solve(H[:, :-1].T, beta_vec)
|
||||
Vy = tree_map(lambda X: _dot(X[..., :-1], y), V)
|
||||
beta_vec = jnp.zeros((restart + 1,), dtype=dtype)
|
||||
beta_vec = beta_vec.at[0].set(residual_norm)
|
||||
y = _lstsq(H.T, beta_vec)
|
||||
dx = tree_map(lambda X: _dot(X[..., :-1], y), V)
|
||||
|
||||
dx = M(Vy)
|
||||
x = _add(x0, dx)
|
||||
|
||||
residual = _sub(b, A(x))
|
||||
residual = M(_sub(b, A(x)))
|
||||
unit_residual, residual_norm = _safe_normalize(residual)
|
||||
return x, unit_residual, residual_norm
|
||||
|
||||
|
||||
def _gmres_solve(A, b, x0, outer_tol, inner_tol, restart, maxiter, M,
|
||||
gmres_func):
|
||||
def _gmres_solve(A, b, x0, atol, ptol, restart, maxiter, M, gmres_func):
|
||||
"""
|
||||
The main function call wrapped by custom_linear_solve. Repeatedly calls GMRES
|
||||
to find the projected solution within the order-``restart``
|
||||
Krylov space K(A, x0, restart), using the result of the previous projection
|
||||
in place of x0 each time. Parameters are the same as in ``gmres`` except:
|
||||
|
||||
outer_tol: Tolerance to be used between restarts.
|
||||
inner_tol: Tolerance used within a restart.
|
||||
atol: Tolerance for norm(A(x) - b), used between restarts.
|
||||
ptol: Tolerance for norm(M(A(x) - b)), used within a restart.
|
||||
gmres_func: A function performing a single GMRES restart.
|
||||
|
||||
Returns: The solution.
|
||||
"""
|
||||
residual = _sub(b, A(x0))
|
||||
residual = M(_sub(b, A(x0)))
|
||||
unit_residual, residual_norm = _safe_normalize(residual)
|
||||
|
||||
def cond_fun(value):
|
||||
_, k, _, residual_norm = value
|
||||
return jnp.logical_and(k < maxiter, residual_norm > outer_tol)
|
||||
return jnp.logical_and(k < maxiter, residual_norm > atol)
|
||||
|
||||
def body_fun(value):
|
||||
x, k, unit_residual, residual_norm = value
|
||||
x, unit_residual, residual_norm = gmres_func(A, b, x, unit_residual,
|
||||
residual_norm, inner_tol,
|
||||
restart, M)
|
||||
x, unit_residual, residual_norm = gmres_func(
|
||||
A, b, x, unit_residual, residual_norm, ptol, restart, M)
|
||||
return x, k + 1, unit_residual, residual_norm
|
||||
|
||||
initialization = (x0, 0, unit_residual, residual_norm)
|
||||
x_final, k, _, err = lax.while_loop(cond_fun, body_fun, initialization)
|
||||
_ = k # Until we can pass this out
|
||||
_ = k # Until we can pass this out
|
||||
_ = err
|
||||
return x_final # , info
|
||||
|
||||
|
||||
def _gmres(A, b, x0=None, *, tol=1e-5, atol=0.0, restart=20, maxiter=None,
|
||||
M=None, qr_mode=False):
|
||||
def gmres(A, b, x0=None, *, tol=1e-5, atol=0.0, restart=20, maxiter=None,
|
||||
M=None, solve_method='batched'):
|
||||
"""
|
||||
GMRES solves the linear system A x = b for x, given A and b.
|
||||
|
||||
@@ -543,11 +551,14 @@ def _gmres(A, b, x0=None, *, tol=1e-5, atol=0.0, restart=20, maxiter=None,
       inverse of A. Effective preconditioning dramatically improves the
       rate of convergence, which implies that fewer iterations are needed
       to reach a given error tolerance.
-  qr_mode : bool
-      If True, the algorithm builds an internal Krylov subspace using a QR
-      based algorithm, which reduces overhead and improved stability. However,
-      it may degrade performance significantly on GPUs or TPUs, in which case
-      this flag should be set False.
+  solve_method : 'incremental' or 'batched'
+      The 'incremental' solve method builds a QR decomposition for the Krylov
+      subspace incrementally during the GMRES process using Givens rotations.
+      This improves numerical stability and gives a free estimate of the
+      residual norm that allows for early termination within a single "restart".
+      In contrast, the 'batched' solve method solves the least squares problem
+      from scratch at the end of each GMRES iteration. It does not allow for
+      early termination, but has much less overhead on GPUs.
 
   See also
   --------
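A usage sketch of the newly public API and the two solve methods (hypothetical 2x2 system; the operator is wrapped with `jax.tree_util.Partial` the same way the tests below drive the solver):

```python
import jax.numpy as jnp
from jax.tree_util import Partial
from jax.scipy.sparse.linalg import gmres

# Hypothetical small system, with A supplied as a matrix-vector product.
mat = jnp.array([[4.0, 1.0],
                 [1.0, 3.0]])
b = jnp.array([1.0, 2.0])
A = Partial(jnp.dot, mat)

x_batched, info = gmres(A, b, solve_method='batched')  # the default
x_incr, _ = gmres(A, b, solve_method='incremental')    # allows early exit
assert jnp.allclose(mat @ x_batched, b, atol=1e-4)
assert jnp.allclose(x_batched, x_incr, atol=1e-4)
```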
@@ -572,26 +583,25 @@ def _gmres(A, b, x0=None, *, tol=1e-5, atol=0.0, restart=20, maxiter=None,
         'x0 and b must have matching tree structure: '
         f'{tree_structure(x0)} vs {tree_structure(b)}')
 
-  b_norm = _norm_tree(b)
-  if b_norm == 0:
-    return b, 0
-  outer_tol = jnp.maximum(tol * b_norm, atol)
+  b_norm = _norm(b)
+  atol = jnp.maximum(tol * b_norm, atol)
 
   Mb = M(b)
-  Mb_norm = _norm_tree(Mb)
-  inner_tol = Mb_norm * min(1.0, outer_tol / b_norm)
+  Mb_norm = _norm(Mb)
+  ptol = Mb_norm * jnp.minimum(1.0, atol / b_norm)
 
-  if qr_mode:
-    def _solve(A, b):
-      return _gmres_solve(A, b, x0, outer_tol, inner_tol, restart, maxiter, M,
-                          _gmres_plain)
+  if solve_method == 'incremental':
+    gmres_func = _gmres_incremental
+  elif solve_method == 'batched':
+    gmres_func = _gmres_batched
   else:
-    def _solve(A, b):
-      return _gmres_solve(A, b, x0, outer_tol, inner_tol, restart, maxiter, M,
-                          _gmres_qr)
+    raise ValueError(f"invalid solve_method {solve_method}, must be either "
+                     "'incremental' or 'batched'")
+
+  def _solve(A, b):
+    return _gmres_solve(A, b, x0, atol, ptol, restart, maxiter, M, gmres_func)
 
   x = lax.custom_linear_solve(A, b, solve=_solve, transpose_solve=_solve)
 
-  failed = jnp.isnan(_norm_tree(x))
+  failed = jnp.isnan(_norm(x))
   info = jnp.where(failed, x=-1, y=0)
   return x, info
@@ -15,5 +15,5 @@
 # flake8: noqa: F401
 from jax._src.scipy.sparse.linalg import (
   cg,
-  _gmres,
+  gmres,
 )
@@ -55,7 +55,7 @@ def lax_solver(solver_name, A, b, M=None, atol=0.0, **kwargs):
 
 
 lax_cg = partial(lax_solver, 'cg')
-lax_gmres = partial(lax_solver, '_gmres')
+lax_gmres = partial(lax_solver, 'gmres')
 
 
 def scipy_solver(solver_name, A, b, atol=0.0, **kwargs):
@@ -73,10 +73,6 @@ def rand_sym_pos_def(rng, shape, dtype):
   return matrix @ matrix.T.conj()
 
 
-
-
-
-
 class LaxBackedScipyTests(jtu.JaxTestCase):
   def _fetch_preconditioner(self, preconditioner, A, rng=None,
                             return_function=False):
@@ -107,11 +103,12 @@ class LaxBackedScipyTests(jtu.JaxTestCase):
            jtu.format_shape_dtype_string(shape, dtype),
            preconditioner),
        "shape": shape, "dtype": dtype, "preconditioner": preconditioner}
-      for shape in [(4, 4), (7, 7), (32, 32)]
-      for dtype in float_types + complex_types
-      for preconditioner in [None, 'identity', 'exact']))
+      # TODO(#2951): reenable 'random' preconditioner.
+      for shape in [(4, 4), (7, 7)]
+      for dtype in [np.float64, np.complex128]
+      for preconditioner in [None, 'identity', 'exact', 'random']))
   def test_cg_against_scipy(self, shape, dtype, preconditioner):
+    if not config.FLAGS.jax_enable_x64:
+      raise unittest.SkipTest("requires x64 mode")
+
     rng = jtu.rand_default(self.rng())
     A = rand_sym_pos_def(rng, shape, dtype)
@@ -125,21 +122,19 @@ class LaxBackedScipyTests(jtu.JaxTestCase):
         partial(scipy_cg, M=M, maxiter=1),
         partial(lax_cg, M=M, maxiter=1),
         args_maker,
-        tol=1e-3)
+        tol=1e-12)
 
-    # TODO(shoyer,mattjj): I had to loosen the tolerance for complex64[7,7]
-    # with preconditioner=random
     self._CheckAgainstNumpy(
         partial(scipy_cg, M=M, maxiter=3),
         partial(lax_cg, M=M, maxiter=3),
         args_maker,
-        tol=3e-3)
+        tol=1e-12)
 
     self._CheckAgainstNumpy(
         np.linalg.solve,
-        partial(lax_cg, M=M, atol=1e-6),
+        partial(lax_cg, M=M, atol=1e-10),
         args_maker,
-        tol=2e-2)
+        tol=1e-6)
 
   @parameterized.named_parameters(jtu.cases_from_list(
       {"testcase_name":
@@ -213,43 +208,18 @@ class LaxBackedScipyTests(jtu.JaxTestCase):
   # GMRES
   @parameterized.named_parameters(jtu.cases_from_list(
       {"testcase_name":
-       "_shape={}".format(jtu.format_shape_dtype_string(shape, dtype)),
-       "shape": shape, "dtype": dtype}
-      for shape in [(2, 2)]
-      for dtype in float_types + complex_types))
-  def test_gmres_on_small_fixed_problem(self, shape, dtype):
-    """
-    GMRES gives the right answer for a small fixed system.
-    """
-    A = jnp.array(([[1, 1], [3, -4]]), dtype=dtype)
-    b = jnp.array([3, 2], dtype=dtype)
-    x0 = jnp.ones(2, dtype=dtype)
-    restart = 2
-    maxiter = 1
-
-    @jax.tree_util.Partial
-    def A_mv(x):
-      return matmul_high_precision(A, x)
-    tol = A.size * jnp.finfo(dtype).eps
-    x, _ = jax.scipy.sparse.linalg._gmres(
-        A_mv, b, x0=x0, tol=tol, atol=tol, restart=restart, maxiter=maxiter)
-    solution = jnp.array([2., 1.], dtype=dtype)
-    self.assertAllClose(solution, x)
-
-  @parameterized.named_parameters(jtu.cases_from_list(
-      {"testcase_name":
-       "_shape={}_preconditioner={}_qr_mode={}".format(
+       "_shape={}_preconditioner={}_solve_method={}".format(
            jtu.format_shape_dtype_string(shape, dtype),
            preconditioner,
-           qr_mode),
+           solve_method),
        "shape": shape, "dtype": dtype, "preconditioner": preconditioner,
-       "qr_mode": qr_mode}
+       "solve_method": solve_method}
       for shape in [(3, 3)]
-      # TODO(shoyer): get working for np.complex128 and qr_mode=True
-      for dtype in [np.float64]
-      for preconditioner in [None, 'identity', 'exact']
-      for qr_mode in [False]))
-  def test_gmres_against_scipy(self, shape, dtype, preconditioner, qr_mode):
+      for dtype in [np.float64, np.complex128]
+      for preconditioner in [None, 'identity', 'exact', 'random']
+      for solve_method in ['incremental', 'batched']))
+  def test_gmres_against_scipy(
+      self, shape, dtype, preconditioner, solve_method):
     if not config.FLAGS.jax_enable_x64:
       raise unittest.SkipTest("requires x64 mode")
 
@@ -263,43 +233,43 @@ class LaxBackedScipyTests(jtu.JaxTestCase):
 
     self._CheckAgainstNumpy(
         partial(scipy_gmres, M=M, restart=1, maxiter=1),
-        partial(lax_gmres, M=M, restart=1, maxiter=1, qr_mode=qr_mode),
+        partial(lax_gmres, M=M, restart=1, maxiter=1, solve_method=solve_method),
         args_maker,
-        tol=1e-3)
+        tol=1e-10)
 
     self._CheckAgainstNumpy(
         partial(scipy_gmres, M=M, restart=1, maxiter=2),
-        partial(lax_gmres, M=M, restart=1, maxiter=2, qr_mode=qr_mode),
+        partial(lax_gmres, M=M, restart=1, maxiter=2, solve_method=solve_method),
         args_maker,
-        tol=1e-3)
+        tol=1e-10)
 
     self._CheckAgainstNumpy(
-        partial(scipy_gmres, M=M, restart=3, maxiter=1),
-        partial(lax_gmres, M=M, restart=3, maxiter=1, qr_mode=qr_mode),
+        partial(scipy_gmres, M=M, restart=2, maxiter=1),
+        partial(lax_gmres, M=M, restart=2, maxiter=1, solve_method=solve_method),
         args_maker,
-        tol=3e-3)
+        tol=1e-10)
 
     self._CheckAgainstNumpy(
         np.linalg.solve,
-        partial(lax_gmres, M=M, atol=1e-6, qr_mode=qr_mode),
+        partial(lax_gmres, M=M, atol=1e-6, solve_method=solve_method),
         args_maker,
-        tol=2e-2)
+        tol=1e-10)
 
   @parameterized.named_parameters(jtu.cases_from_list(
       {"testcase_name":
-       "_shape={}_preconditioner={}_qr_mode={}".format(
+       "_shape={}_preconditioner={}_solve_method={}".format(
            jtu.format_shape_dtype_string(shape, dtype),
            preconditioner,
-           qr_mode),
+           solve_method),
        "shape": shape, "dtype": dtype, "preconditioner": preconditioner,
-       "qr_mode": qr_mode}
+       "solve_method": solve_method}
       for shape in [(2, 2), (7, 7)]
       for dtype in float_types + complex_types
       for preconditioner in [None, 'identity', 'exact']
-      for qr_mode in [True, False]
+      for solve_method in ['batched', 'incremental']
       ))
   def test_gmres_on_identity_system(self, shape, dtype, preconditioner,
-                                    qr_mode):
+                                    solve_method):
     A = jnp.eye(shape[1], dtype=dtype)
 
     solution = jnp.ones(shape[1], dtype=dtype)
@@ -312,28 +282,28 @@ class LaxBackedScipyTests(jtu.JaxTestCase):
     b = A_mv(solution)
     restart = shape[-1]
     tol = shape[0] * jnp.finfo(dtype).eps
-    x, info = jax.scipy.sparse.linalg._gmres(A_mv, b, tol=tol, atol=tol,
-                                             restart=restart,
-                                             M=M, qr_mode=qr_mode)
+    x, info = jax.scipy.sparse.linalg.gmres(A_mv, b, tol=tol, atol=tol,
+                                            restart=restart,
+                                            M=M, solve_method=solve_method)
     using_x64 = solution.dtype.kind in {np.float64, np.complex128}
     solution_tol = 1e-8 if using_x64 else 1e-4
     self.assertAllClose(x, solution, atol=solution_tol, rtol=solution_tol)
 
   @parameterized.named_parameters(jtu.cases_from_list(
       {"testcase_name":
-       "_shape={}_preconditioner={}_qr_mode={}".format(
+       "_shape={}_preconditioner={}_solve_method={}".format(
            jtu.format_shape_dtype_string(shape, dtype),
            preconditioner,
-           qr_mode),
+           solve_method),
        "shape": shape, "dtype": dtype, "preconditioner": preconditioner,
-       "qr_mode": qr_mode}
+       "solve_method": solve_method}
       for shape in [(2, 2), (4, 4)]
       for dtype in float_types + complex_types
      for preconditioner in [None, 'identity', 'exact']
-      for qr_mode in [True, False]
+      for solve_method in ['incremental', 'batched']
       ))
   def test_gmres_on_random_system(self, shape, dtype, preconditioner,
-                                  qr_mode):
+                                  solve_method):
     rng = jtu.rand_default(self.rng())
     A = rng(shape, dtype)
 
@@ -346,9 +316,9 @@ class LaxBackedScipyTests(jtu.JaxTestCase):
     b = A_mv(solution)
     restart = shape[-1]
     tol = shape[0] * jnp.finfo(A.dtype).eps
-    x, info = jax.scipy.sparse.linalg._gmres(A_mv, b, tol=tol, atol=tol,
-                                             restart=restart,
-                                             M=M, qr_mode=qr_mode)
+    x, info = jax.scipy.sparse.linalg.gmres(A_mv, b, tol=tol, atol=tol,
+                                            restart=restart,
+                                            M=M, solve_method=solve_method)
     using_x64 = solution.dtype.kind in {np.float64, np.complex128}
     solution_tol = 1e-8 if using_x64 else 1e-4
     self.assertAllClose(x, solution, atol=solution_tol, rtol=solution_tol)
@@ -357,10 +327,10 @@ class LaxBackedScipyTests(jtu.JaxTestCase):
     A = lambda x: {"a": x["a"] + 0.5 * x["b"], "b": 0.5 * x["a"] + x["b"]}
     b = {"a": 1.0, "b": -4.0}
     expected = {"a": 4.0, "b": -6.0}
-    actual, _ = jax.scipy.sparse.linalg._gmres(A, b)
+    actual, _ = jax.scipy.sparse.linalg.gmres(A, b)
     self.assertEqual(expected.keys(), actual.keys())
-    self.assertAlmostEqual(expected["a"], actual["a"], places=6)
-    self.assertAlmostEqual(expected["b"], actual["b"], places=6)
+    self.assertAlmostEqual(expected["a"], actual["a"], places=5)
+    self.assertAlmostEqual(expected["b"], actual["b"], places=5)
 
   @parameterized.named_parameters(jtu.cases_from_list(
       {"testcase_name":
@@ -389,20 +359,19 @@ class LaxBackedScipyTests(jtu.JaxTestCase):
     n = shape[0]
     x0 = rng(shape[:1], dtype)
     Q = np.zeros((n, n + 1), dtype=dtype)
-    Q[:, 0] = x0/jax.numpy.linalg.norm(x0)
+    Q[:, 0] = x0/jnp.linalg.norm(x0)
     Q = jnp.array(Q)
-    H = jax.numpy.eye(n, n + 1, dtype=dtype)
-    tol = A.size*A.size*jax.numpy.finfo(dtype).eps
+    H = jnp.eye(n, n + 1, dtype=dtype)
 
     @jax.tree_util.Partial
     def A_mv(x):
       return matmul_high_precision(A, x)
     for k in range(n):
       Q, H, _ = jax._src.scipy.sparse.linalg._kth_arnoldi_iteration(
-          k, A_mv, M, Q, H, tol)
+          k, A_mv, M, Q, H)
     QA = matmul_high_precision(Q[:, :n].conj().T, A)
     QAQ = matmul_high_precision(QA, Q[:, :n])
-    self.assertAllClose(QAQ, H.T[:n, :], rtol=tol, atol=tol)
+    self.assertAllClose(QAQ, H.T[:n, :], rtol=1e-5, atol=1e-5)
 
 
 if __name__ == "__main__":