Mirror of https://github.com/ROCm/jax.git, synced 2025-04-16 03:46:06 +00:00
Fixup complex values and tol in tests for jax.scipy.linalg.sparse.cg (#2717)
* Fixup complex values and tol in tests for jax.scipy.linalg.sparse.cg

The tests for CG were failing on TPUs:

- `test_cg_pytree` is fixed by requiring slightly less precision than the
  unit-test default.
- `test_cg_against_scipy` is fixed for complex values in two independent ways:
  1. We don't set both `tol=0` and `atol=0`, which made the termination
     behavior of CG (convergence or NaN) dependent on exactly how XLA handles
     arithmetic with denormals.
  2. We make use of *real valued* inner products inside `cg`, even for complex
     values. It turns out that all these inner products are mathematically
     guaranteed to yield a real number anyway, so we can save some flops and
     avoid ill-defined comparisons of complex values (see
     https://github.com/numpy/numpy/issues/15981) by ignoring the imaginary
     part of the result from `jnp.vdot`. (Real numbers also happen to have the
     desired rounding behavior for denormals on TPUs, so this on its own would
     also fix these failures.)

* comment fixup

* fix my comment
This commit is contained in:
parent
5baa59fe71
commit
8fa707af98
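The core claim in the commit message, that every inner product CG evaluates has the form z^H M z with M Hermitian positive definite and is therefore real, is easy to check numerically. A minimal numpy sketch (the matrix and vector below are invented for illustration and are not part of the commit):

import numpy as np

rng = np.random.RandomState(0)

# A random Hermitian positive-definite matrix, M = B @ B^H.
B = rng.randn(4, 4) + 1j * rng.randn(4, 4)
M = B @ B.conj().T
z = rng.randn(4) + 1j * rng.randn(4)

# z^H M z is mathematically real; any imaginary part is rounding noise.
quad = np.vdot(z, M @ z)
assert abs(quad.imag) < 1e-10 * abs(quad.real)

# vdot conjugates its first argument, so the real part decomposes into two
# real-valued dot products. This is the identity the commit exploits.
w = M @ z
assert np.allclose(quad.real, np.vdot(z.real, w.real) + np.vdot(z.imag, w.imag))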
@@ -23,10 +23,23 @@ from jax import lax, device_put
 from jax.tree_util import tree_leaves, tree_map, tree_multimap
 
 
-def _vdot(x, y):
-  f = partial(jnp.vdot, precision=lax.Precision.HIGHEST)
-  return sum(tree_leaves(tree_multimap(f, x, y)))
+def _vdot_real_part(x, y):
+  """Vector dot-product guaranteed to have a real valued result."""
+  # all our uses of vdot() in CG are for computing an operator of the form
+  # `z^T M z` where `M` is positive definite and Hermitian, so the result is
+  # real valued:
+  # https://en.wikipedia.org/wiki/Definiteness_of_a_matrix#Definitions_for_complex_matrices
+  vdot = partial(jnp.vdot, precision=lax.Precision.HIGHEST)
+  result = vdot(x.real, y.real)
+  if jnp.iscomplexobj(x) or jnp.iscomplexobj(y):
+    result += vdot(x.imag, y.imag)
+  return result
+
+
+# aliases for working with pytrees
+def _vdot_tree(x, y):
+  return sum(tree_leaves(tree_multimap(_vdot_real_part, x, y)))
 
 
 def _mul(scalar, tree):
   return tree_map(partial(operator.mul, scalar), tree)
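A quick sanity check of the new helper, with made-up vectors (a sketch in plain jax.numpy, not code from the commit):

import jax.numpy as jnp

x = jnp.array([1.0 + 2.0j, 3.0 - 1.0j])
y = jnp.array([0.5 - 1.0j, 2.0 + 4.0j])

# jnp.vdot conjugates its first argument, so
#   Re(vdot(x, y)) == vdot(x.real, y.real) + vdot(x.imag, y.imag).
# _vdot_real_part computes the right-hand side directly: the result has a
# real dtype and the imaginary cross-terms are never evaluated.
full = jnp.vdot(x, y).real
split = jnp.vdot(x.real, y.real) + jnp.vdot(x.imag, y.imag)
print(full, split)  # both print 0.5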
@@ -42,31 +55,31 @@ def _identity(x):
 def _cg_solve(A, b, x0=None, *, maxiter, tol=1e-5, atol=0.0, M=_identity):
 
   # tolerance handling uses the "non-legacy" behavior of scipy.sparse.linalg.cg
-  bs = _vdot(b, b)
+  bs = _vdot_tree(b, b)
   atol2 = jnp.maximum(tol ** 2 * bs, atol ** 2)
 
   # https://en.wikipedia.org/wiki/Conjugate_gradient_method#The_preconditioned_conjugate_gradient_method
 
   def cond_fun(value):
     x, r, gamma, p, k = value
-    rs = gamma if M is _identity else _vdot(r, r)
+    rs = gamma if M is _identity else _vdot_tree(r, r)
     return (rs > atol2) & (k < maxiter)
 
   def body_fun(value):
     x, r, gamma, p, k = value
     Ap = A(p)
-    alpha = gamma / _vdot(p, Ap)
+    alpha = gamma / _vdot_tree(p, Ap)
     x_ = _add(x, _mul(alpha, p))
     r_ = _sub(r, _mul(alpha, Ap))
     z_ = M(r_)
-    gamma_ = _vdot(r_, z_)
+    gamma_ = _vdot_tree(r_, z_)
     beta_ = gamma_ / gamma
     p_ = _add(z_, _mul(beta_, p))
     return x_, r_, gamma_, p_, k + 1
 
   r0 = _sub(b, A(x0))
   p0 = z0 = M(r0)
-  gamma0 = _vdot(r0, z0)
+  gamma0 = _vdot_tree(r0, z0)
   initial_value = (x0, r0, gamma0, p0, 0)
 
   x_final, *_ = lax.while_loop(cond_fun, body_fun, initial_value)
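Note what the tolerance logic above implies: the loop stops once the squared residual falls below max(tol**2 * <b, b>, atol**2), so passing both tol=0 and atol=0 demands an exactly zero residual, which is what made termination sensitive to denormal handling. For orientation, a sketch of calling the public wrapper on a tiny SPD system (the matrix and right-hand side are invented; this version of the API takes the operator as a callable):

import jax.numpy as jnp
import jax.scipy.sparse.linalg

# A tiny symmetric positive-definite system, the classic textbook CG example.
M = jnp.array([[4.0, 1.0],
               [1.0, 3.0]])
b = jnp.array([1.0, 2.0])

# The operator is passed as a function v -> M @ v.
x, _ = jax.scipy.sparse.linalg.cg(lambda v: M @ v, b, tol=1e-6)

print(x)          # approximately [0.0909, 0.6364], i.e. (1/11) * [1, 7]
print(M @ x - b)  # residual near zero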
@@ -13,16 +13,21 @@
 # limitations under the License.
 
 from functools import partial
 
 from absl.testing import parameterized
 from absl.testing import absltest
-import numpy as np
-import scipy.sparse.linalg
-
-from jax import jit
 import jax.numpy as jnp
+import numpy as np
+import scipy.sparse.linalg
+
+from jax import lax
 from jax import test_util as jtu
 import jax.scipy.sparse.linalg
+from jax.config import config
+
+
+config.parse_flags_with_absl()
 
-from jax.config import config
-config.parse_flags_with_absl()
@@ -40,16 +45,16 @@ def posify(matrix):
   return matmul_high_precision(matrix, matrix.T.conj())
 
 
-def lax_cg(A, b, M=None, tol=0.0, atol=0.0, **kwargs):
+def lax_cg(A, b, M=None, atol=0.0, **kwargs):
   A = partial(matmul_high_precision, A)
   if M is not None:
     M = partial(matmul_high_precision, M)
-  x, _ = jax.scipy.sparse.linalg.cg(A, b, tol=tol, atol=atol, M=M, **kwargs)
+  x, _ = jax.scipy.sparse.linalg.cg(A, b, atol=atol, M=M, **kwargs)
   return x
 
 
-def scipy_cg(A, b, tol=0.0, atol=0.0, **kwargs):
-  x, _ = scipy.sparse.linalg.cg(A, b, tol=tol, atol=atol, **kwargs)
+def scipy_cg(A, b, atol=0.0, **kwargs):
+  x, _ = scipy.sparse.linalg.cg(A, b, atol=atol, **kwargs)
   return x
 
 
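Together with posify above, these helpers let tests compare the JAX and SciPy solvers on identical operands. A hedged sketch of that pattern (the random system is invented here; note that after this change only atol is pinned, so each library keeps its own default tol and termination no longer hinges on denormal arithmetic):

import numpy as np

rng = np.random.RandomState(0)

# A matrix times its conjugate transpose is positive semi-definite (and,
# for a random square matrix, almost surely positive definite), which is
# the same construction posify uses.
raw = rng.randn(5, 5)
A = raw @ raw.T.conj()
b = rng.randn(5)

expected = scipy_cg(A, b, maxiter=100)
actual = lax_cg(A, b, maxiter=100)
np.testing.assert_allclose(actual, expected, rtol=1e-4, atol=1e-4)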
@@ -149,8 +154,8 @@ class LaxBackedScipyTests(jtu.JaxTestCase):
     expected = {"a": 4.0, "b": -6.0}
     actual, _ = jax.scipy.sparse.linalg.cg(A, b)
     self.assertEqual(expected.keys(), actual.keys())
-    self.assertAlmostEqual(expected["a"], actual["a"])
-    self.assertAlmostEqual(expected["b"], actual["b"])
+    self.assertAlmostEqual(expected["a"], actual["a"], places=6)
+    self.assertAlmostEqual(expected["b"], actual["b"], places=6)
 
   def test_cg_errors(self):
     A = lambda x: x
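The hunk shows only the assertions of test_cg_pytree; its setup is outside the diff context. The sketch below is a reconstruction consistent with the expected values, since [[1.0, 0.5], [0.5, 1.0]]^-1 applied to [1, -4] is exactly [4, -6] (treat the operator and right-hand side as illustrative, not verbatim from the test):

import jax.scipy.sparse.linalg

# CG over a dict pytree: the operator maps dicts to dicts, and b and x
# share the same tree structure.
A = lambda x: {"a": x["a"] + 0.5 * x["b"], "b": 0.5 * x["a"] + x["b"]}
b = {"a": 1.0, "b": -4.0}

expected = {"a": 4.0, "b": -6.0}
actual, _ = jax.scipy.sparse.linalg.cg(A, b)

# assertAlmostEqual defaults to 7 decimal places; places=6 loosens that by
# one digit, enough slack for float32 arithmetic on TPUs.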