diff --git a/jax/scipy/sparse/linalg.py b/jax/scipy/sparse/linalg.py index 878505cd8..5a64bc717 100644 --- a/jax/scipy/sparse/linalg.py +++ b/jax/scipy/sparse/linalg.py @@ -26,6 +26,8 @@ from jax.util import safe_map as map _dot = partial(jnp.dot, precision=lax.Precision.HIGHEST) _vdot = partial(jnp.vdot, precision=lax.Precision.HIGHEST) +_einsum = partial(jnp.einsum, precision=lax.Precision.HIGHEST) + # aliases for working with pytrees def _vdot_real_part(x, y): @@ -216,7 +218,6 @@ def _safe_normalize(x, thresh=None): thresh = thresh.astype(dtype).real use_norm = norm > thresh - # where_normalizing = partial(tree_map, partial(jnp.where, norm > thresh)) normalized_x = tree_map(lambda y: jnp.where(use_norm, y / norm, 0.0), x) norm = jnp.where(use_norm, norm, 0.0) return normalized_x, norm @@ -227,9 +228,7 @@ def _project_on_columns(A, v): Returns A.T.conj() @ v. """ v_proj = tree_multimap( - lambda X, y: jnp.einsum("...n,...->n", X.conj(), y), - A, - v, + lambda X, y: _einsum("...n,...->n", X.conj(), y), A, v, ) return tree_reduce(operator.add, v_proj) @@ -299,7 +298,7 @@ def _iterative_classical_gram_schmidt(Q, x, max_iterations=2): return q, r -def kth_arnoldi_iteration(k, A, M, V, H, tol): +def _kth_arnoldi_iteration(k, A, M, V, H, tol): """ Performs a single (the k'th) step of the Arnoldi process. Thus, adds a new orthonormalized Krylov vector A(M(V[:, k])) to V[:, k+1], @@ -321,7 +320,7 @@ def kth_arnoldi_iteration(k, A, M, V, H, tol): return V, H, breakdown -def apply_givens_rotations(H_row, givens, k): +def _apply_givens_rotations(H_row, givens, k): """ Applies the Givens rotations stored in the vectors cs and sn to the vector H_row. Then constructs and applies a new Givens rotation that eliminates @@ -386,8 +385,8 @@ def _gmres_qr(A, b, x0, unit_residual, residual_norm, inner_tol, restart, M): def arnoldi_qr_step(carry): k, _, V, R, beta_vec, givens = carry - V, H, _ = kth_arnoldi_iteration(k, A, M, V, R, inner_tol) - R_row, givens = apply_givens_rotations(H[k, :], givens, k) + V, H, _ = _kth_arnoldi_iteration(k, A, M, V, R, inner_tol) + R_row, givens = _apply_givens_rotations(H[k, :], givens, k) R = R.at[k, :].set(R_row[:]) cs, sn = givens[k, :] * beta_vec[k] beta_vec = beta_vec.at[k].set(cs) @@ -434,7 +433,7 @@ def _gmres_plain(A, b, x0, unit_residual, residual_norm, inner_tol, restart, M): def arnoldi_process(carry): V, H, _, k = carry - V, H, breakdown = kth_arnoldi_iteration(k, A, M, V, H, inner_tol) + V, H, breakdown = _kth_arnoldi_iteration(k, A, M, V, H, inner_tol) return V, H, breakdown, k + 1 carry = (V, H, False, 0) @@ -486,12 +485,11 @@ def _gmres_solve(A, b, x0, outer_tol, inner_tol, restart, maxiter, M, x_final, k, _, err = lax.while_loop(cond_fun, body_fun, initialization) _ = k # Until we can pass this out _ = err - # info = lax.cond(converged, lambda y: 0, lambda y: k, 0) return x_final # , info -def gmres(A, b, x0=None, *, tol=1e-5, atol=0.0, restart=20, maxiter=None, - M=None, qr_mode=False): +def _gmres(A, b, x0=None, *, tol=1e-5, atol=0.0, restart=20, maxiter=None, + M=None, qr_mode=False): """ GMRES solves the linear system A x = b for x, given A and b. diff --git a/tests/lax_scipy_sparse_test.py b/tests/lax_scipy_sparse_test.py index dca995e76..ab55eb5a7 100644 --- a/tests/lax_scipy_sparse_test.py +++ b/tests/lax_scipy_sparse_test.py @@ -13,6 +13,7 @@ # limitations under the License. from functools import partial +import unittest from absl.testing import parameterized from absl.testing import absltest @@ -28,7 +29,7 @@ import jax.scipy.sparse.linalg from jax.config import config config.parse_flags_with_absl() -config.update('jax_enable_x64', True) + float_types = [np.float32, np.float64] complex_types = [np.complex64, np.complex128] @@ -43,19 +44,29 @@ def posify(matrix): return matmul_high_precision(matrix, matrix.T.conj()) -def lax_cg(A, b, M=None, atol=0.0, **kwargs): +def lax_solver(solver_name, A, b, M=None, atol=0.0, **kwargs): A = partial(matmul_high_precision, A) if M is not None: M = partial(matmul_high_precision, M) - x, _ = jax.scipy.sparse.linalg.cg(A, b, atol=atol, M=M, **kwargs) + func = getattr(jax.scipy.sparse.linalg, solver_name) + x, _ = func(A, b, atol=atol, M=M, **kwargs) return x -def scipy_cg(A, b, atol=0.0, **kwargs): - x, _ = scipy.sparse.linalg.cg(A, b, atol=atol, **kwargs) +lax_cg = partial(lax_solver, 'cg') +lax_gmres = partial(lax_solver, '_gmres') + + +def scipy_solver(solver_name, A, b, atol=0.0, **kwargs): + func = getattr(scipy.sparse.linalg, solver_name) + x, _ = func(A, b, atol=atol, **kwargs) return x +scipy_cg = partial(scipy_solver, 'cg') +scipy_gmres = partial(scipy_solver, 'gmres') + + def rand_sym_pos_def(rng, shape, dtype): matrix = np.eye(N=shape[0], dtype=dtype) + rng(shape, dtype) return matrix @ matrix.T.conj() @@ -219,11 +230,59 @@ class LaxBackedScipyTests(jtu.JaxTestCase): def A_mv(x): return matmul_high_precision(A, x) tol = A.size * jnp.finfo(dtype).eps - x, _ = jax.scipy.sparse.linalg.gmres(A_mv, b, x0=x0, tol=tol, atol=tol, - restart=restart, maxiter=maxiter) + x, _ = jax.scipy.sparse.linalg._gmres( + A_mv, b, x0=x0, tol=tol, atol=tol, restart=restart, maxiter=maxiter) solution = jnp.array([2., 1.], dtype=dtype) self.assertAllClose(solution, x) + @parameterized.named_parameters(jtu.cases_from_list( + {"testcase_name": + "_shape={}_preconditioner={}_qr_mode={}".format( + jtu.format_shape_dtype_string(shape, dtype), + preconditioner, + qr_mode), + "shape": shape, "dtype": dtype, "preconditioner": preconditioner, + "qr_mode": qr_mode} + for shape in [(3, 3)] + # TODO(shoyer): get working for np.complex128 and qr_mode=True + for dtype in [np.float64] + for preconditioner in [None, 'identity', 'exact'] + for qr_mode in [False])) + def test_gmres_against_scipy(self, shape, dtype, preconditioner, qr_mode): + if not config.FLAGS.jax_enable_x64: + raise unittest.SkipTest("requires x64 mode") + + rng = jtu.rand_default(self.rng()) + A = rng(shape, dtype) + b = rng(shape[:1], dtype) + M = self._fetch_preconditioner(preconditioner, A, rng=rng) + + def args_maker(): + return A, b + + self._CheckAgainstNumpy( + partial(scipy_gmres, M=M, restart=1, maxiter=1), + partial(lax_gmres, M=M, restart=1, maxiter=1, qr_mode=qr_mode), + args_maker, + tol=1e-3) + + self._CheckAgainstNumpy( + partial(scipy_gmres, M=M, restart=1, maxiter=2), + partial(lax_gmres, M=M, restart=1, maxiter=2, qr_mode=qr_mode), + args_maker, + tol=1e-3) + + self._CheckAgainstNumpy( + partial(scipy_gmres, M=M, restart=3, maxiter=1), + partial(lax_gmres, M=M, restart=3, maxiter=1, qr_mode=qr_mode), + args_maker, + tol=3e-3) + + self._CheckAgainstNumpy( + np.linalg.solve, + partial(lax_gmres, M=M, atol=1e-6, qr_mode=qr_mode), + args_maker, + tol=2e-2) @parameterized.named_parameters(jtu.cases_from_list( {"testcase_name": @@ -252,14 +311,11 @@ class LaxBackedScipyTests(jtu.JaxTestCase): b = A_mv(solution) restart = shape[-1] tol = shape[0] * jnp.finfo(dtype).eps - x, info = jax.scipy.sparse.linalg.gmres(A_mv, b, tol=tol, atol=tol, - restart=restart, - M=M, - qr_mode=qr_mode) - err = jnp.linalg.norm(solution - x) / jnp.linalg.norm(b) - rtol = tol*jnp.linalg.norm(b) - true_tol = max(rtol, tol) - self.assertLessEqual(err, true_tol) + x, info = jax.scipy.sparse.linalg._gmres(A_mv, b, tol=tol, atol=tol, + restart=restart, + M=M, + qr_mode=qr_mode) + self.assertAllClose(x, solution, atol=300*jnp.finfo(x.dtype).eps) @parameterized.named_parameters(jtu.cases_from_list( {"testcase_name": @@ -269,7 +325,7 @@ class LaxBackedScipyTests(jtu.JaxTestCase): qr_mode), "shape": shape, "dtype": dtype, "preconditioner": preconditioner, "qr_mode": qr_mode} - for shape in [(2, 2), (7, 7)] + for shape in [(2, 2), (4, 4)] for dtype in float_types + complex_types for preconditioner in [None, 'identity', 'exact'] for qr_mode in [True, False] @@ -287,21 +343,19 @@ class LaxBackedScipyTests(jtu.JaxTestCase): return_function=True) b = A_mv(solution) restart = shape[-1] - tol = shape[0] * jnp.finfo(dtype).eps - x, info = jax.scipy.sparse.linalg.gmres(A_mv, b, tol=tol, atol=tol, + tol = shape[0] * jnp.finfo(A.dtype).eps + x, info = jax.scipy.sparse.linalg._gmres(A_mv, b, tol=tol, atol=tol, restart=restart, M=M, qr_mode=qr_mode) - err = jnp.linalg.norm(solution - x) / jnp.linalg.norm(b) - rtol = tol*jnp.linalg.norm(b) - true_tol = max(rtol, tol) - self.assertLessEqual(err, true_tol) + solution_tol = 300*jnp.finfo(x.dtype).eps + self.assertAllClose(x, solution, atol=solution_tol, rtol=solution_tol) def test_gmres_pytree(self): A = lambda x: {"a": x["a"] + 0.5 * x["b"], "b": 0.5 * x["a"] + x["b"]} b = {"a": 1.0, "b": -4.0} expected = {"a": 4.0, "b": -6.0} - actual, _ = jax.scipy.sparse.linalg.gmres(A, b) + actual, _ = jax.scipy.sparse.linalg._gmres(A, b) self.assertEqual(expected.keys(), actual.keys()) self.assertAlmostEqual(expected["a"], actual["a"], places=6) self.assertAlmostEqual(expected["b"], actual["b"], places=6) @@ -312,13 +366,16 @@ class LaxBackedScipyTests(jtu.JaxTestCase): jtu.format_shape_dtype_string(shape, dtype), preconditioner), "shape": shape, "dtype": dtype, "preconditioner": preconditioner} - for shape in [(2, 2), (7, 7), (32, 32)] + for shape in [(2, 2), (3, 3)] for dtype in float_types + complex_types for preconditioner in [None, 'identity'])) def test_gmres_arnoldi_step(self, shape, dtype, preconditioner): """ The Arnoldi decomposition within GMRES is correct. """ + if not config.FLAGS.jax_enable_x64: + raise unittest.SkipTest("requires x64 mode") + rng = jtu.rand_default(self.rng()) A = rng(shape, dtype) if preconditioner is None: @@ -339,10 +396,10 @@ class LaxBackedScipyTests(jtu.JaxTestCase): def A_mv(x): return matmul_high_precision(A, x) for k in range(n): - Q, H, _ = jax.scipy.sparse.linalg.kth_arnoldi_iteration(k, A_mv, M, Q, H, - tol) - QAQ = matmul_high_precision(Q[:, :n].conj().T, A) - QAQ = matmul_high_precision(QAQ, Q[:, :n]) + Q, H, _ = jax.scipy.sparse.linalg._kth_arnoldi_iteration( + k, A_mv, M, Q, H, tol) + QA = matmul_high_precision(Q[:, :n].conj().T, A) + QAQ = matmul_high_precision(QA, Q[:, :n]) self.assertAllClose(QAQ, H.T[:n, :], rtol=tol, atol=tol)