Renable custom_linear_solve and cg with complex values

This commit is contained in:
Stephan Hoyer 2020-04-09 00:46:10 -07:00
parent 7cf5a94bba
commit 9cc5e9018c
4 changed files with 11 additions and 19 deletions

View File

@ -1603,11 +1603,6 @@ def _custom_linear_solve_jvp(primals, tangents, const_lengths, jaxprs, tree):
kwargs = dict(const_lengths=const_lengths, jaxprs=jaxprs, tree=tree)
x = linear_solve_p.bind(*primals, **kwargs)
if any(issubclass(dtypes.dtype(xi).type, onp.complexfloating) for xi in x):
raise NotImplementedError(
"gradients of complex values are not yet supported in "
"custom_linear_solve: https://github.com/google/jax/issues/2572")
params, _ = _split_linear_solve_args(primals, const_lengths)
params_dot, b_dot = _split_linear_solve_args(tangents, const_lengths)

View File

@ -146,6 +146,10 @@ def cg(A, b, x0=None, *, tol=1e-5, atol=0.0, maxiter=None, M=None):
cg_solve = partial(
_cg_solve, x0=x0, tol=tol, atol=atol, maxiter=maxiter, M=M)
x = lax.custom_linear_solve(A, b, cg_solve, symmetric=True)
# real-valued positive-definite linear operators are symmetric
real_valued = lambda x: not issubclass(x.dtype.type, np.complexfloating)
symmetric = all(map(real_valued, tree_leaves(b)))
x = lax.custom_linear_solve(
A, b, solve=cg_solve, transpose_solve=cg_solve, symmetric=symmetric)
info = None # TODO(shoyer): return the real iteration count here
return x, info

View File

@ -1655,25 +1655,18 @@ class LaxControlFlowTest(jtu.JaxTestCase):
def test_custom_linear_solve_complex(self):
def positive_definite_solve(a, b):
def solve(a, b):
def solve(matvec, x):
return jsp.linalg.solve(a, x)
def tr_solve(matvec, x):
return jsp.linalg.solve(a.T, x)
matvec = partial(high_precision_dot, a)
return lax.custom_linear_solve(matvec, b, solve, symmetric=True)
return lax.custom_linear_solve(matvec, b, solve, tr_solve)
rng = onp.random.RandomState(0)
a = 0.5 * rng.randn(2, 2) + 0.5j * rng.randn(2, 2)
b = 0.5 * rng.randn(2) + 0.5j * rng.randn(2)
expected = np.linalg.solve(posify(a), b)
actual = positive_definite_solve(posify(a), b)
self.assertAllClose(expected, actual, check_dtypes=True)
# TODO(shoyer): remove this error when complex values work
with self.assertRaises(NotImplementedError):
jtu.check_grads(
lambda x, y: positive_definite_solve(posify(x), y),
(a, b), order=2, rtol=1e-2)
jtu.check_grads(solve, (a, b), order=2, rtol=1e-2)
@jtu.skip_on_flag("jax_skip_slow_tests", True)
def test_custom_linear_solve_lu(self):

View File

@ -113,7 +113,7 @@ class LaxBackedScipyTests(jtu.JaxTestCase):
"_shape={}".format(jtu.format_shape_dtype_string(shape, dtype)),
"shape": shape, "dtype": dtype, "rng_factory": rng_factory}
for shape in [(2, 2)]
for dtype in float_types
for dtype in float_types + complex_types
for rng_factory in [jtu.rand_default]))
def test_cg_as_solve(self, shape, dtype, rng_factory):