mirror of
https://github.com/ROCm/jax.git
synced 2025-04-16 20:06:05 +00:00
More unit-tests + mark gmres as internal for now
This commit is contained in:
parent
c5c71a0a37
commit
7e62270e5a
@ -26,6 +26,8 @@ from jax.util import safe_map as map
|
||||
|
||||
_dot = partial(jnp.dot, precision=lax.Precision.HIGHEST)
|
||||
_vdot = partial(jnp.vdot, precision=lax.Precision.HIGHEST)
|
||||
_einsum = partial(jnp.einsum, precision=lax.Precision.HIGHEST)
|
||||
|
||||
|
||||
# aliases for working with pytrees
|
||||
def _vdot_real_part(x, y):
|
||||
@ -216,7 +218,6 @@ def _safe_normalize(x, thresh=None):
|
||||
thresh = thresh.astype(dtype).real
|
||||
|
||||
use_norm = norm > thresh
|
||||
# where_normalizing = partial(tree_map, partial(jnp.where, norm > thresh))
|
||||
normalized_x = tree_map(lambda y: jnp.where(use_norm, y / norm, 0.0), x)
|
||||
norm = jnp.where(use_norm, norm, 0.0)
|
||||
return normalized_x, norm
|
||||
@ -227,9 +228,7 @@ def _project_on_columns(A, v):
|
||||
Returns A.T.conj() @ v.
|
||||
"""
|
||||
v_proj = tree_multimap(
|
||||
lambda X, y: jnp.einsum("...n,...->n", X.conj(), y),
|
||||
A,
|
||||
v,
|
||||
lambda X, y: _einsum("...n,...->n", X.conj(), y), A, v,
|
||||
)
|
||||
return tree_reduce(operator.add, v_proj)
|
||||
|
||||
@ -299,7 +298,7 @@ def _iterative_classical_gram_schmidt(Q, x, max_iterations=2):
|
||||
return q, r
|
||||
|
||||
|
||||
def kth_arnoldi_iteration(k, A, M, V, H, tol):
|
||||
def _kth_arnoldi_iteration(k, A, M, V, H, tol):
|
||||
"""
|
||||
Performs a single (the k'th) step of the Arnoldi process. Thus,
|
||||
adds a new orthonormalized Krylov vector A(M(V[:, k])) to V[:, k+1],
|
||||
@ -321,7 +320,7 @@ def kth_arnoldi_iteration(k, A, M, V, H, tol):
|
||||
return V, H, breakdown
|
||||
|
||||
|
||||
def apply_givens_rotations(H_row, givens, k):
|
||||
def _apply_givens_rotations(H_row, givens, k):
|
||||
"""
|
||||
Applies the Givens rotations stored in the vectors cs and sn to the vector
|
||||
H_row. Then constructs and applies a new Givens rotation that eliminates
|
||||
@ -386,8 +385,8 @@ def _gmres_qr(A, b, x0, unit_residual, residual_norm, inner_tol, restart, M):
|
||||
|
||||
def arnoldi_qr_step(carry):
|
||||
k, _, V, R, beta_vec, givens = carry
|
||||
V, H, _ = kth_arnoldi_iteration(k, A, M, V, R, inner_tol)
|
||||
R_row, givens = apply_givens_rotations(H[k, :], givens, k)
|
||||
V, H, _ = _kth_arnoldi_iteration(k, A, M, V, R, inner_tol)
|
||||
R_row, givens = _apply_givens_rotations(H[k, :], givens, k)
|
||||
R = R.at[k, :].set(R_row[:])
|
||||
cs, sn = givens[k, :] * beta_vec[k]
|
||||
beta_vec = beta_vec.at[k].set(cs)
|
||||
@ -434,7 +433,7 @@ def _gmres_plain(A, b, x0, unit_residual, residual_norm, inner_tol, restart, M):
|
||||
|
||||
def arnoldi_process(carry):
|
||||
V, H, _, k = carry
|
||||
V, H, breakdown = kth_arnoldi_iteration(k, A, M, V, H, inner_tol)
|
||||
V, H, breakdown = _kth_arnoldi_iteration(k, A, M, V, H, inner_tol)
|
||||
return V, H, breakdown, k + 1
|
||||
|
||||
carry = (V, H, False, 0)
|
||||
@ -486,11 +485,10 @@ def _gmres_solve(A, b, x0, outer_tol, inner_tol, restart, maxiter, M,
|
||||
x_final, k, _, err = lax.while_loop(cond_fun, body_fun, initialization)
|
||||
_ = k # Until we can pass this out
|
||||
_ = err
|
||||
# info = lax.cond(converged, lambda y: 0, lambda y: k, 0)
|
||||
return x_final # , info
|
||||
|
||||
|
||||
def gmres(A, b, x0=None, *, tol=1e-5, atol=0.0, restart=20, maxiter=None,
|
||||
def _gmres(A, b, x0=None, *, tol=1e-5, atol=0.0, restart=20, maxiter=None,
|
||||
M=None, qr_mode=False):
|
||||
"""
|
||||
GMRES solves the linear system A x = b for x, given A and b.
|
||||
|
@ -13,6 +13,7 @@
|
||||
# limitations under the License.
|
||||
|
||||
from functools import partial
|
||||
import unittest
|
||||
|
||||
from absl.testing import parameterized
|
||||
from absl.testing import absltest
|
||||
@ -28,7 +29,7 @@ import jax.scipy.sparse.linalg
|
||||
|
||||
from jax.config import config
|
||||
config.parse_flags_with_absl()
|
||||
config.update('jax_enable_x64', True)
|
||||
|
||||
|
||||
float_types = [np.float32, np.float64]
|
||||
complex_types = [np.complex64, np.complex128]
|
||||
@ -43,19 +44,29 @@ def posify(matrix):
|
||||
return matmul_high_precision(matrix, matrix.T.conj())
|
||||
|
||||
|
||||
def lax_cg(A, b, M=None, atol=0.0, **kwargs):
|
||||
def lax_solver(solver_name, A, b, M=None, atol=0.0, **kwargs):
|
||||
A = partial(matmul_high_precision, A)
|
||||
if M is not None:
|
||||
M = partial(matmul_high_precision, M)
|
||||
x, _ = jax.scipy.sparse.linalg.cg(A, b, atol=atol, M=M, **kwargs)
|
||||
func = getattr(jax.scipy.sparse.linalg, solver_name)
|
||||
x, _ = func(A, b, atol=atol, M=M, **kwargs)
|
||||
return x
|
||||
|
||||
|
||||
def scipy_cg(A, b, atol=0.0, **kwargs):
|
||||
x, _ = scipy.sparse.linalg.cg(A, b, atol=atol, **kwargs)
|
||||
lax_cg = partial(lax_solver, 'cg')
|
||||
lax_gmres = partial(lax_solver, '_gmres')
|
||||
|
||||
|
||||
def scipy_solver(solver_name, A, b, atol=0.0, **kwargs):
|
||||
func = getattr(scipy.sparse.linalg, solver_name)
|
||||
x, _ = func(A, b, atol=atol, **kwargs)
|
||||
return x
|
||||
|
||||
|
||||
scipy_cg = partial(scipy_solver, 'cg')
|
||||
scipy_gmres = partial(scipy_solver, 'gmres')
|
||||
|
||||
|
||||
def rand_sym_pos_def(rng, shape, dtype):
|
||||
matrix = np.eye(N=shape[0], dtype=dtype) + rng(shape, dtype)
|
||||
return matrix @ matrix.T.conj()
|
||||
@ -219,11 +230,59 @@ class LaxBackedScipyTests(jtu.JaxTestCase):
|
||||
def A_mv(x):
|
||||
return matmul_high_precision(A, x)
|
||||
tol = A.size * jnp.finfo(dtype).eps
|
||||
x, _ = jax.scipy.sparse.linalg.gmres(A_mv, b, x0=x0, tol=tol, atol=tol,
|
||||
restart=restart, maxiter=maxiter)
|
||||
x, _ = jax.scipy.sparse.linalg._gmres(
|
||||
A_mv, b, x0=x0, tol=tol, atol=tol, restart=restart, maxiter=maxiter)
|
||||
solution = jnp.array([2., 1.], dtype=dtype)
|
||||
self.assertAllClose(solution, x)
|
||||
|
||||
@parameterized.named_parameters(jtu.cases_from_list(
|
||||
{"testcase_name":
|
||||
"_shape={}_preconditioner={}_qr_mode={}".format(
|
||||
jtu.format_shape_dtype_string(shape, dtype),
|
||||
preconditioner,
|
||||
qr_mode),
|
||||
"shape": shape, "dtype": dtype, "preconditioner": preconditioner,
|
||||
"qr_mode": qr_mode}
|
||||
for shape in [(3, 3)]
|
||||
# TODO(shoyer): get working for np.complex128 and qr_mode=True
|
||||
for dtype in [np.float64]
|
||||
for preconditioner in [None, 'identity', 'exact']
|
||||
for qr_mode in [False]))
|
||||
def test_gmres_against_scipy(self, shape, dtype, preconditioner, qr_mode):
|
||||
if not config.FLAGS.jax_enable_x64:
|
||||
raise unittest.SkipTest("requires x64 mode")
|
||||
|
||||
rng = jtu.rand_default(self.rng())
|
||||
A = rng(shape, dtype)
|
||||
b = rng(shape[:1], dtype)
|
||||
M = self._fetch_preconditioner(preconditioner, A, rng=rng)
|
||||
|
||||
def args_maker():
|
||||
return A, b
|
||||
|
||||
self._CheckAgainstNumpy(
|
||||
partial(scipy_gmres, M=M, restart=1, maxiter=1),
|
||||
partial(lax_gmres, M=M, restart=1, maxiter=1, qr_mode=qr_mode),
|
||||
args_maker,
|
||||
tol=1e-3)
|
||||
|
||||
self._CheckAgainstNumpy(
|
||||
partial(scipy_gmres, M=M, restart=1, maxiter=2),
|
||||
partial(lax_gmres, M=M, restart=1, maxiter=2, qr_mode=qr_mode),
|
||||
args_maker,
|
||||
tol=1e-3)
|
||||
|
||||
self._CheckAgainstNumpy(
|
||||
partial(scipy_gmres, M=M, restart=3, maxiter=1),
|
||||
partial(lax_gmres, M=M, restart=3, maxiter=1, qr_mode=qr_mode),
|
||||
args_maker,
|
||||
tol=3e-3)
|
||||
|
||||
self._CheckAgainstNumpy(
|
||||
np.linalg.solve,
|
||||
partial(lax_gmres, M=M, atol=1e-6, qr_mode=qr_mode),
|
||||
args_maker,
|
||||
tol=2e-2)
|
||||
|
||||
@parameterized.named_parameters(jtu.cases_from_list(
|
||||
{"testcase_name":
|
||||
@ -252,14 +311,11 @@ class LaxBackedScipyTests(jtu.JaxTestCase):
|
||||
b = A_mv(solution)
|
||||
restart = shape[-1]
|
||||
tol = shape[0] * jnp.finfo(dtype).eps
|
||||
x, info = jax.scipy.sparse.linalg.gmres(A_mv, b, tol=tol, atol=tol,
|
||||
x, info = jax.scipy.sparse.linalg._gmres(A_mv, b, tol=tol, atol=tol,
|
||||
restart=restart,
|
||||
M=M,
|
||||
qr_mode=qr_mode)
|
||||
err = jnp.linalg.norm(solution - x) / jnp.linalg.norm(b)
|
||||
rtol = tol*jnp.linalg.norm(b)
|
||||
true_tol = max(rtol, tol)
|
||||
self.assertLessEqual(err, true_tol)
|
||||
self.assertAllClose(x, solution, atol=300*jnp.finfo(x.dtype).eps)
|
||||
|
||||
@parameterized.named_parameters(jtu.cases_from_list(
|
||||
{"testcase_name":
|
||||
@ -269,7 +325,7 @@ class LaxBackedScipyTests(jtu.JaxTestCase):
|
||||
qr_mode),
|
||||
"shape": shape, "dtype": dtype, "preconditioner": preconditioner,
|
||||
"qr_mode": qr_mode}
|
||||
for shape in [(2, 2), (7, 7)]
|
||||
for shape in [(2, 2), (4, 4)]
|
||||
for dtype in float_types + complex_types
|
||||
for preconditioner in [None, 'identity', 'exact']
|
||||
for qr_mode in [True, False]
|
||||
@ -287,21 +343,19 @@ class LaxBackedScipyTests(jtu.JaxTestCase):
|
||||
return_function=True)
|
||||
b = A_mv(solution)
|
||||
restart = shape[-1]
|
||||
tol = shape[0] * jnp.finfo(dtype).eps
|
||||
x, info = jax.scipy.sparse.linalg.gmres(A_mv, b, tol=tol, atol=tol,
|
||||
tol = shape[0] * jnp.finfo(A.dtype).eps
|
||||
x, info = jax.scipy.sparse.linalg._gmres(A_mv, b, tol=tol, atol=tol,
|
||||
restart=restart,
|
||||
M=M,
|
||||
qr_mode=qr_mode)
|
||||
err = jnp.linalg.norm(solution - x) / jnp.linalg.norm(b)
|
||||
rtol = tol*jnp.linalg.norm(b)
|
||||
true_tol = max(rtol, tol)
|
||||
self.assertLessEqual(err, true_tol)
|
||||
solution_tol = 300*jnp.finfo(x.dtype).eps
|
||||
self.assertAllClose(x, solution, atol=solution_tol, rtol=solution_tol)
|
||||
|
||||
def test_gmres_pytree(self):
|
||||
A = lambda x: {"a": x["a"] + 0.5 * x["b"], "b": 0.5 * x["a"] + x["b"]}
|
||||
b = {"a": 1.0, "b": -4.0}
|
||||
expected = {"a": 4.0, "b": -6.0}
|
||||
actual, _ = jax.scipy.sparse.linalg.gmres(A, b)
|
||||
actual, _ = jax.scipy.sparse.linalg._gmres(A, b)
|
||||
self.assertEqual(expected.keys(), actual.keys())
|
||||
self.assertAlmostEqual(expected["a"], actual["a"], places=6)
|
||||
self.assertAlmostEqual(expected["b"], actual["b"], places=6)
|
||||
@ -312,13 +366,16 @@ class LaxBackedScipyTests(jtu.JaxTestCase):
|
||||
jtu.format_shape_dtype_string(shape, dtype),
|
||||
preconditioner),
|
||||
"shape": shape, "dtype": dtype, "preconditioner": preconditioner}
|
||||
for shape in [(2, 2), (7, 7), (32, 32)]
|
||||
for shape in [(2, 2), (3, 3)]
|
||||
for dtype in float_types + complex_types
|
||||
for preconditioner in [None, 'identity']))
|
||||
def test_gmres_arnoldi_step(self, shape, dtype, preconditioner):
|
||||
"""
|
||||
The Arnoldi decomposition within GMRES is correct.
|
||||
"""
|
||||
if not config.FLAGS.jax_enable_x64:
|
||||
raise unittest.SkipTest("requires x64 mode")
|
||||
|
||||
rng = jtu.rand_default(self.rng())
|
||||
A = rng(shape, dtype)
|
||||
if preconditioner is None:
|
||||
@ -339,10 +396,10 @@ class LaxBackedScipyTests(jtu.JaxTestCase):
|
||||
def A_mv(x):
|
||||
return matmul_high_precision(A, x)
|
||||
for k in range(n):
|
||||
Q, H, _ = jax.scipy.sparse.linalg.kth_arnoldi_iteration(k, A_mv, M, Q, H,
|
||||
tol)
|
||||
QAQ = matmul_high_precision(Q[:, :n].conj().T, A)
|
||||
QAQ = matmul_high_precision(QAQ, Q[:, :n])
|
||||
Q, H, _ = jax.scipy.sparse.linalg._kth_arnoldi_iteration(
|
||||
k, A_mv, M, Q, H, tol)
|
||||
QA = matmul_high_precision(Q[:, :n].conj().T, A)
|
||||
QAQ = matmul_high_precision(QA, Q[:, :n])
|
||||
self.assertAllClose(QAQ, H.T[:n, :], rtol=tol, atol=tol)
|
||||
|
||||
|
||||
|
Loading…
x
Reference in New Issue
Block a user