Fixup complex values and tol in tests for jax.scipy.linalg.sparse.cg (#2717)

* Fixup complex values and tol in tests for jax.scipy.linalg.sparse.cg

The tests for CG were failing on TPUs:

- `test_cg_pytree` is fixed by requiring slightly less precision than the
  unit-test default.
- `test_cg_against_scipy` is fixed for complex values in two independent ways:
  1. We no longer set both `tol=0` and `atol=0`, which made the termination
     behavior of CG (convergence or NaN) depend on exactly how XLA handles
     arithmetic with denormals.
  2. We make use of *real valued* inner products inside `cg`, even for complex
     values. It turns out that all of these inner products are mathematically
     guaranteed to yield a real number anyway, so we can save some flops and
     avoid ill-defined comparisons of complex values (see
     https://github.com/numpy/numpy/issues/15981) by ignoring the imaginary
     part of the result from `jnp.vdot`. (Real numbers also happen to have the
     desired rounding behavior for denormals on TPUs, so this on its own would
     also fix these failures.) See the sketch just after this list for a quick
     numerical check.
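
To make point 2 concrete, here is a minimal sketch (not part of the commit;
the matrix construction and variable names are illustrative) checking that
`z^H M z` is real for a Hermitian positive definite `M`, and that the real
part of `jnp.vdot(x, y)` decomposes into two real-valued dot products:

    import jax.numpy as jnp
    from jax import random

    key = random.PRNGKey(0)
    k1, k2, k3, k4 = random.split(key, 4)
    a = random.normal(k1, (4, 4)) + 1j * random.normal(k2, (4, 4))
    m = a @ a.T.conj() + 4 * jnp.eye(4)  # Hermitian positive definite
    z = random.normal(k3, (4,)) + 1j * random.normal(k4, (4,))

    full = jnp.vdot(z, m @ z)  # z^H M z
    print(abs(full.imag) < 1e-4 * abs(full.real))  # True: imag part is rounding noise

    # Re(vdot(x, y)) assembled from real-valued dot products, as in the patch:
    x, y = z, m @ z
    real_part = jnp.vdot(x.real, y.real) + jnp.vdot(x.imag, y.imag)
    print(jnp.allclose(real_part, full.real, rtol=1e-4))  # True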

* comment fixup

* fix my comment
Stephan Hoyer 2020-04-14 22:35:48 -07:00 committed by GitHub
parent 5baa59fe71
commit 8fa707af98
2 changed files with 34 additions and 16 deletions

@@ -23,10 +23,23 @@ from jax import lax, device_put
 from jax.tree_util import tree_leaves, tree_map, tree_multimap


+def _vdot_real_part(x, y):
+  """Vector dot-product guaranteed to have a real valued result."""
+  # all our uses of vdot() in CG are for computing an operator of the form
+  # `z^T M z` where `M` is positive definite and Hermitian, so the result is
+  # real valued:
+  # https://en.wikipedia.org/wiki/Definiteness_of_a_matrix#Definitions_for_complex_matrices
+  vdot = partial(jnp.vdot, precision=lax.Precision.HIGHEST)
+  result = vdot(x.real, y.real)
+  if jnp.iscomplexobj(x) or jnp.iscomplexobj(y):
+    result += vdot(x.imag, y.imag)
+  return result
+
+
 # aliases for working with pytrees

-def _vdot(x, y):
-  f = partial(jnp.vdot, precision=lax.Precision.HIGHEST)
-  return sum(tree_leaves(tree_multimap(f, x, y)))
+def _vdot_tree(x, y):
+  return sum(tree_leaves(tree_multimap(_vdot_real_part, x, y)))


 def _mul(scalar, tree):
   return tree_map(partial(operator.mul, scalar), tree)
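
For context outside the diff, a self-contained sketch of how the new helpers
behave on a pytree input (this reuses the definitions added above, written
against the 2020-era `tree_multimap` API shown in the diff; the example
values are mine):

    from functools import partial

    import jax.numpy as jnp
    from jax import lax
    from jax.tree_util import tree_leaves, tree_multimap

    def _vdot_real_part(x, y):
      vdot = partial(jnp.vdot, precision=lax.Precision.HIGHEST)
      result = vdot(x.real, y.real)
      if jnp.iscomplexobj(x) or jnp.iscomplexobj(y):
        result += vdot(x.imag, y.imag)
      return result

    def _vdot_tree(x, y):
      # one real-valued scalar per leaf, summed across the pytree
      return sum(tree_leaves(tree_multimap(_vdot_real_part, x, y)))

    b = {"a": jnp.array([3.0, 4.0]), "b": jnp.array([1.0 + 2.0j])}
    print(_vdot_tree(b, b))  # 9 + 16 + (1 + 4) = 30.0, always real
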
@@ -42,31 +55,31 @@ def _identity(x):
 def _cg_solve(A, b, x0=None, *, maxiter, tol=1e-5, atol=0.0, M=_identity):

   # tolerance handling uses the "non-legacy" behavior of scipy.sparse.linalg.cg
-  bs = _vdot(b, b)
+  bs = _vdot_tree(b, b)
   atol2 = jnp.maximum(tol ** 2 * bs, atol ** 2)

   # https://en.wikipedia.org/wiki/Conjugate_gradient_method#The_preconditioned_conjugate_gradient_method

   def cond_fun(value):
     x, r, gamma, p, k = value
-    rs = gamma if M is _identity else _vdot(r, r)
+    rs = gamma if M is _identity else _vdot_tree(r, r)
     return (rs > atol2) & (k < maxiter)

   def body_fun(value):
     x, r, gamma, p, k = value
     Ap = A(p)
-    alpha = gamma / _vdot(p, Ap)
+    alpha = gamma / _vdot_tree(p, Ap)
     x_ = _add(x, _mul(alpha, p))
     r_ = _sub(r, _mul(alpha, Ap))
     z_ = M(r_)
-    gamma_ = _vdot(r_, z_)
+    gamma_ = _vdot_tree(r_, z_)
     beta_ = gamma_ / gamma
     p_ = _add(z_, _mul(beta_, p))
     return x_, r_, gamma_, p_, k + 1

   r0 = _sub(b, A(x0))
   p0 = z0 = M(r0)
-  gamma0 = _vdot(r0, z0)
+  gamma0 = _vdot_tree(r0, z0)
   initial_value = (x0, r0, gamma0, p0, 0)

   x_final, *_ = lax.while_loop(cond_fun, body_fun, initial_value)
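
An end-to-end usage sketch in the spirit of `test_cg_against_scipy` (again
illustrative: the names and tolerances are mine, and JAX's float32 defaults
may require loosening them):

    import numpy as np
    import scipy.sparse.linalg
    import jax.numpy as jnp
    import jax.scipy.sparse.linalg

    rng = np.random.RandomState(0)
    a = rng.randn(5, 5) + 1j * rng.randn(5, 5)
    m = a @ a.T.conj() + 5 * np.eye(5)  # Hermitian positive definite
    b = rng.randn(5) + 1j * rng.randn(5)

    expected, _ = scipy.sparse.linalg.cg(m, b, atol=1e-10)
    actual, _ = jax.scipy.sparse.linalg.cg(lambda x: jnp.dot(m, x), jnp.asarray(b))
    print(np.allclose(expected, actual, atol=1e-3))  # True if both converged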

@@ -13,16 +13,21 @@
 # limitations under the License.

 from functools import partial

 from absl.testing import parameterized
+from absl.testing import absltest
+import numpy as np
+import scipy.sparse.linalg

 from jax import jit
 import jax.numpy as jnp
-import numpy as np
-import scipy.sparse.linalg
 from jax import lax
 from jax import test_util as jtu
 import jax.scipy.sparse.linalg

 from jax.config import config
 config.parse_flags_with_absl()
@@ -40,16 +45,16 @@ def posify(matrix):
   return matmul_high_precision(matrix, matrix.T.conj())


-def lax_cg(A, b, M=None, tol=0.0, atol=0.0, **kwargs):
+def lax_cg(A, b, M=None, atol=0.0, **kwargs):
   A = partial(matmul_high_precision, A)
   if M is not None:
     M = partial(matmul_high_precision, M)
-  x, _ = jax.scipy.sparse.linalg.cg(A, b, tol=tol, atol=atol, M=M, **kwargs)
+  x, _ = jax.scipy.sparse.linalg.cg(A, b, atol=atol, M=M, **kwargs)
   return x


-def scipy_cg(A, b, tol=0.0, atol=0.0, **kwargs):
-  x, _ = scipy.sparse.linalg.cg(A, b, tol=tol, atol=atol, **kwargs)
+def scipy_cg(A, b, atol=0.0, **kwargs):
+  x, _ = scipy.sparse.linalg.cg(A, b, atol=atol, **kwargs)
   return x
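
Dropping the explicit `tol=0.0` from these helpers matters because of how
`_cg_solve` builds its stopping threshold: `atol2 = max(tol**2 * ||b||**2,
atol**2)`, and the loop runs while the (preconditioned) residual norm squared
exceeds `atol2`. A worked sketch of the arithmetic, with illustrative values:

    import jax.numpy as jnp

    b = jnp.array([3.0, 4.0])
    bs = jnp.vdot(b, b)  # ||b||^2 = 25.0
    tol, atol = 1e-5, 0.0  # the non-legacy scipy-style defaults
    atol2 = jnp.maximum(tol ** 2 * bs, atol ** 2)
    print(atol2)  # 2.5e-09: a nonzero target, so CG can stop by convergence
    # With tol = atol = 0, atol2 == 0: the loop only exits via maxiter or an
    # exactly zero residual, so termination hinges on denormal arithmetic.
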
@@ -149,8 +154,8 @@ class LaxBackedScipyTests(jtu.JaxTestCase):
     expected = {"a": 4.0, "b": -6.0}
     actual, _ = jax.scipy.sparse.linalg.cg(A, b)
     self.assertEqual(expected.keys(), actual.keys())
-    self.assertAlmostEqual(expected["a"], actual["a"])
-    self.assertAlmostEqual(expected["b"], actual["b"])
+    self.assertAlmostEqual(expected["a"], actual["a"], places=6)
+    self.assertAlmostEqual(expected["b"], actual["b"], places=6)

   def test_cg_errors(self):
     A = lambda x: x
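
On the `places=6` change: `assertAlmostEqual` defaults to `places=7`, i.e. it
requires `round(a - b, 7) == 0`. Passing `places=6` is the commit message's
"slightly less precision than the unit-test default". A minimal stdlib
illustration (a hypothetical test, not from the diff):

    import unittest

    class PlacesDemo(unittest.TestCase):
      def test_places(self):
        # a difference of 3e-7 fails the default places=7 but passes places=6
        self.assertAlmostEqual(4.0, 4.0 + 3e-7, places=6)

    if __name__ == "__main__":
      unittest.main()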