More unit-tests + mark gmres as internal for now

This commit is contained in:
Stephan Hoyer 2020-11-09 16:02:18 -08:00
parent c5c71a0a37
commit 7e62270e5a
2 changed files with 95 additions and 40 deletions

View File

@ -26,6 +26,8 @@ from jax.util import safe_map as map
_dot = partial(jnp.dot, precision=lax.Precision.HIGHEST)
_vdot = partial(jnp.vdot, precision=lax.Precision.HIGHEST)
_einsum = partial(jnp.einsum, precision=lax.Precision.HIGHEST)
# aliases for working with pytrees
def _vdot_real_part(x, y):
@ -216,7 +218,6 @@ def _safe_normalize(x, thresh=None):
thresh = thresh.astype(dtype).real
use_norm = norm > thresh
# where_normalizing = partial(tree_map, partial(jnp.where, norm > thresh))
normalized_x = tree_map(lambda y: jnp.where(use_norm, y / norm, 0.0), x)
norm = jnp.where(use_norm, norm, 0.0)
return normalized_x, norm
@ -227,9 +228,7 @@ def _project_on_columns(A, v):
Returns A.T.conj() @ v.
"""
v_proj = tree_multimap(
lambda X, y: jnp.einsum("...n,...->n", X.conj(), y),
A,
v,
lambda X, y: _einsum("...n,...->n", X.conj(), y), A, v,
)
return tree_reduce(operator.add, v_proj)
@ -299,7 +298,7 @@ def _iterative_classical_gram_schmidt(Q, x, max_iterations=2):
return q, r
def kth_arnoldi_iteration(k, A, M, V, H, tol):
def _kth_arnoldi_iteration(k, A, M, V, H, tol):
"""
Performs a single (the k'th) step of the Arnoldi process. Thus,
adds a new orthonormalized Krylov vector A(M(V[:, k])) to V[:, k+1],
@ -321,7 +320,7 @@ def kth_arnoldi_iteration(k, A, M, V, H, tol):
return V, H, breakdown
def apply_givens_rotations(H_row, givens, k):
def _apply_givens_rotations(H_row, givens, k):
"""
Applies the Givens rotations stored in the vectors cs and sn to the vector
H_row. Then constructs and applies a new Givens rotation that eliminates
@ -386,8 +385,8 @@ def _gmres_qr(A, b, x0, unit_residual, residual_norm, inner_tol, restart, M):
def arnoldi_qr_step(carry):
k, _, V, R, beta_vec, givens = carry
V, H, _ = kth_arnoldi_iteration(k, A, M, V, R, inner_tol)
R_row, givens = apply_givens_rotations(H[k, :], givens, k)
V, H, _ = _kth_arnoldi_iteration(k, A, M, V, R, inner_tol)
R_row, givens = _apply_givens_rotations(H[k, :], givens, k)
R = R.at[k, :].set(R_row[:])
cs, sn = givens[k, :] * beta_vec[k]
beta_vec = beta_vec.at[k].set(cs)
@ -434,7 +433,7 @@ def _gmres_plain(A, b, x0, unit_residual, residual_norm, inner_tol, restart, M):
def arnoldi_process(carry):
V, H, _, k = carry
V, H, breakdown = kth_arnoldi_iteration(k, A, M, V, H, inner_tol)
V, H, breakdown = _kth_arnoldi_iteration(k, A, M, V, H, inner_tol)
return V, H, breakdown, k + 1
carry = (V, H, False, 0)
@ -486,11 +485,10 @@ def _gmres_solve(A, b, x0, outer_tol, inner_tol, restart, maxiter, M,
x_final, k, _, err = lax.while_loop(cond_fun, body_fun, initialization)
_ = k # Until we can pass this out
_ = err
# info = lax.cond(converged, lambda y: 0, lambda y: k, 0)
return x_final # , info
def gmres(A, b, x0=None, *, tol=1e-5, atol=0.0, restart=20, maxiter=None,
def _gmres(A, b, x0=None, *, tol=1e-5, atol=0.0, restart=20, maxiter=None,
M=None, qr_mode=False):
"""
GMRES solves the linear system A x = b for x, given A and b.

View File

@ -13,6 +13,7 @@
# limitations under the License.
from functools import partial
import unittest
from absl.testing import parameterized
from absl.testing import absltest
@ -28,7 +29,7 @@ import jax.scipy.sparse.linalg
from jax.config import config
config.parse_flags_with_absl()
config.update('jax_enable_x64', True)
float_types = [np.float32, np.float64]
complex_types = [np.complex64, np.complex128]
@ -43,19 +44,29 @@ def posify(matrix):
return matmul_high_precision(matrix, matrix.T.conj())
def lax_cg(A, b, M=None, atol=0.0, **kwargs):
def lax_solver(solver_name, A, b, M=None, atol=0.0, **kwargs):
A = partial(matmul_high_precision, A)
if M is not None:
M = partial(matmul_high_precision, M)
x, _ = jax.scipy.sparse.linalg.cg(A, b, atol=atol, M=M, **kwargs)
func = getattr(jax.scipy.sparse.linalg, solver_name)
x, _ = func(A, b, atol=atol, M=M, **kwargs)
return x
def scipy_cg(A, b, atol=0.0, **kwargs):
x, _ = scipy.sparse.linalg.cg(A, b, atol=atol, **kwargs)
lax_cg = partial(lax_solver, 'cg')
lax_gmres = partial(lax_solver, '_gmres')
def scipy_solver(solver_name, A, b, atol=0.0, **kwargs):
  """Run a ``scipy.sparse.linalg`` solver chosen by name; return the solution.

  The solver is looked up as an attribute of ``scipy.sparse.linalg`` and
  invoked as ``solver(A, b, atol=atol, **kwargs)``. Its result is unpacked as a
  two-element pair and only the first element is returned; the second is
  discarded.
  """
  solve = getattr(scipy.sparse.linalg, solver_name)
  solution, _ = solve(A, b, atol=atol, **kwargs)
  return solution
scipy_cg = partial(scipy_solver, 'cg')
scipy_gmres = partial(scipy_solver, 'gmres')
def rand_sym_pos_def(rng, shape, dtype):
  """Return ``B @ B.T.conj()`` where ``B = I + rng(shape, dtype)``.

  ``I`` is the ``shape[0]``-sized identity in ``dtype``. The product of a
  matrix with its own conjugate transpose is Hermitian by construction.
  """
  identity = np.eye(N=shape[0], dtype=dtype)
  base = identity + rng(shape, dtype)
  return base @ base.T.conj()
@ -219,11 +230,59 @@ class LaxBackedScipyTests(jtu.JaxTestCase):
def A_mv(x):
return matmul_high_precision(A, x)
tol = A.size * jnp.finfo(dtype).eps
x, _ = jax.scipy.sparse.linalg.gmres(A_mv, b, x0=x0, tol=tol, atol=tol,
restart=restart, maxiter=maxiter)
x, _ = jax.scipy.sparse.linalg._gmres(
A_mv, b, x0=x0, tol=tol, atol=tol, restart=restart, maxiter=maxiter)
solution = jnp.array([2., 1.], dtype=dtype)
self.assertAllClose(solution, x)
@parameterized.named_parameters(jtu.cases_from_list(
    {"testcase_name":
     "_shape={}_preconditioner={}_qr_mode={}".format(
        jtu.format_shape_dtype_string(shape, dtype),
        preconditioner,
        qr_mode),
     "shape": shape, "dtype": dtype, "preconditioner": preconditioner,
     "qr_mode": qr_mode}
    for shape in [(3, 3)]
    # TODO(shoyer): get working for np.complex128 and qr_mode=True
    for dtype in [np.float64]
    for preconditioner in [None, 'identity', 'exact']
    for qr_mode in [False]))
def test_gmres_against_scipy(self, shape, dtype, preconditioner, qr_mode):
  # Compares the lax-backed GMRES wrapper with SciPy's GMRES at a few fixed
  # (restart, maxiter) settings, then with a dense direct solve.
  if not config.FLAGS.jax_enable_x64:
    raise unittest.SkipTest("requires x64 mode")

  rng = jtu.rand_default(self.rng())
  A = rng(shape, dtype)
  b = rng(shape[:1], dtype)
  M = self._fetch_preconditioner(preconditioner, A, rng=rng)

  def args_maker():
    return A, b

  # Each entry is (restart, maxiter, comparison tolerance); both
  # implementations are given the same restart/maxiter values.
  scipy_comparisons = [
      (1, 1, 1e-3),
      (1, 2, 1e-3),
      (3, 1, 3e-3),
  ]
  for restart, maxiter, tol in scipy_comparisons:
    reference = partial(scipy_gmres, M=M, restart=restart, maxiter=maxiter)
    candidate = partial(lax_gmres, M=M, restart=restart, maxiter=maxiter,
                        qr_mode=qr_mode)
    self._CheckAgainstNumpy(reference, candidate, args_maker, tol=tol)

  # With default restart/maxiter, compare against the direct solution of
  # A x = b.
  self._CheckAgainstNumpy(
      np.linalg.solve,
      partial(lax_gmres, M=M, atol=1e-6, qr_mode=qr_mode),
      args_maker,
      tol=2e-2)
@parameterized.named_parameters(jtu.cases_from_list(
{"testcase_name":
@ -252,14 +311,11 @@ class LaxBackedScipyTests(jtu.JaxTestCase):
b = A_mv(solution)
restart = shape[-1]
tol = shape[0] * jnp.finfo(dtype).eps
x, info = jax.scipy.sparse.linalg.gmres(A_mv, b, tol=tol, atol=tol,
x, info = jax.scipy.sparse.linalg._gmres(A_mv, b, tol=tol, atol=tol,
restart=restart,
M=M,
qr_mode=qr_mode)
err = jnp.linalg.norm(solution - x) / jnp.linalg.norm(b)
rtol = tol*jnp.linalg.norm(b)
true_tol = max(rtol, tol)
self.assertLessEqual(err, true_tol)
self.assertAllClose(x, solution, atol=300*jnp.finfo(x.dtype).eps)
@parameterized.named_parameters(jtu.cases_from_list(
{"testcase_name":
@ -269,7 +325,7 @@ class LaxBackedScipyTests(jtu.JaxTestCase):
qr_mode),
"shape": shape, "dtype": dtype, "preconditioner": preconditioner,
"qr_mode": qr_mode}
for shape in [(2, 2), (7, 7)]
for shape in [(2, 2), (4, 4)]
for dtype in float_types + complex_types
for preconditioner in [None, 'identity', 'exact']
for qr_mode in [True, False]
@ -287,21 +343,19 @@ class LaxBackedScipyTests(jtu.JaxTestCase):
return_function=True)
b = A_mv(solution)
restart = shape[-1]
tol = shape[0] * jnp.finfo(dtype).eps
x, info = jax.scipy.sparse.linalg.gmres(A_mv, b, tol=tol, atol=tol,
tol = shape[0] * jnp.finfo(A.dtype).eps
x, info = jax.scipy.sparse.linalg._gmres(A_mv, b, tol=tol, atol=tol,
restart=restart,
M=M,
qr_mode=qr_mode)
err = jnp.linalg.norm(solution - x) / jnp.linalg.norm(b)
rtol = tol*jnp.linalg.norm(b)
true_tol = max(rtol, tol)
self.assertLessEqual(err, true_tol)
solution_tol = 300*jnp.finfo(x.dtype).eps
self.assertAllClose(x, solution, atol=solution_tol, rtol=solution_tol)
def test_gmres_pytree(self):
A = lambda x: {"a": x["a"] + 0.5 * x["b"], "b": 0.5 * x["a"] + x["b"]}
b = {"a": 1.0, "b": -4.0}
expected = {"a": 4.0, "b": -6.0}
actual, _ = jax.scipy.sparse.linalg.gmres(A, b)
actual, _ = jax.scipy.sparse.linalg._gmres(A, b)
self.assertEqual(expected.keys(), actual.keys())
self.assertAlmostEqual(expected["a"], actual["a"], places=6)
self.assertAlmostEqual(expected["b"], actual["b"], places=6)
@ -312,13 +366,16 @@ class LaxBackedScipyTests(jtu.JaxTestCase):
jtu.format_shape_dtype_string(shape, dtype),
preconditioner),
"shape": shape, "dtype": dtype, "preconditioner": preconditioner}
for shape in [(2, 2), (7, 7), (32, 32)]
for shape in [(2, 2), (3, 3)]
for dtype in float_types + complex_types
for preconditioner in [None, 'identity']))
def test_gmres_arnoldi_step(self, shape, dtype, preconditioner):
"""
The Arnoldi decomposition within GMRES is correct.
"""
if not config.FLAGS.jax_enable_x64:
raise unittest.SkipTest("requires x64 mode")
rng = jtu.rand_default(self.rng())
A = rng(shape, dtype)
if preconditioner is None:
@ -339,10 +396,10 @@ class LaxBackedScipyTests(jtu.JaxTestCase):
def A_mv(x):
return matmul_high_precision(A, x)
for k in range(n):
Q, H, _ = jax.scipy.sparse.linalg.kth_arnoldi_iteration(k, A_mv, M, Q, H,
tol)
QAQ = matmul_high_precision(Q[:, :n].conj().T, A)
QAQ = matmul_high_precision(QAQ, Q[:, :n])
Q, H, _ = jax.scipy.sparse.linalg._kth_arnoldi_iteration(
k, A_mv, M, Q, H, tol)
QA = matmul_high_precision(Q[:, :n].conj().T, A)
QAQ = matmul_high_precision(QA, Q[:, :n])
self.assertAllClose(QAQ, H.T[:n, :], rtol=tol, atol=tol)