mirror of
https://github.com/ROCm/jax.git
synced 2025-04-16 11:56:07 +00:00
Renable custom_linear_solve and cg with complex values
This commit is contained in:
parent
7cf5a94bba
commit
9cc5e9018c
@ -1603,11 +1603,6 @@ def _custom_linear_solve_jvp(primals, tangents, const_lengths, jaxprs, tree):
|
||||
kwargs = dict(const_lengths=const_lengths, jaxprs=jaxprs, tree=tree)
|
||||
x = linear_solve_p.bind(*primals, **kwargs)
|
||||
|
||||
if any(issubclass(dtypes.dtype(xi).type, onp.complexfloating) for xi in x):
|
||||
raise NotImplementedError(
|
||||
"gradients of complex values are not yet supported in "
|
||||
"custom_linear_solve: https://github.com/google/jax/issues/2572")
|
||||
|
||||
params, _ = _split_linear_solve_args(primals, const_lengths)
|
||||
params_dot, b_dot = _split_linear_solve_args(tangents, const_lengths)
|
||||
|
||||
|
@ -146,6 +146,10 @@ def cg(A, b, x0=None, *, tol=1e-5, atol=0.0, maxiter=None, M=None):
|
||||
|
||||
cg_solve = partial(
|
||||
_cg_solve, x0=x0, tol=tol, atol=atol, maxiter=maxiter, M=M)
|
||||
x = lax.custom_linear_solve(A, b, cg_solve, symmetric=True)
|
||||
# real-valued positive-definite linear operators are symmetric
|
||||
real_valued = lambda x: not issubclass(x.dtype.type, np.complexfloating)
|
||||
symmetric = all(map(real_valued, tree_leaves(b)))
|
||||
x = lax.custom_linear_solve(
|
||||
A, b, solve=cg_solve, transpose_solve=cg_solve, symmetric=symmetric)
|
||||
info = None # TODO(shoyer): return the real iteration count here
|
||||
return x, info
|
||||
|
@ -1655,25 +1655,18 @@ class LaxControlFlowTest(jtu.JaxTestCase):
|
||||
|
||||
def test_custom_linear_solve_complex(self):
|
||||
|
||||
def positive_definite_solve(a, b):
|
||||
def solve(a, b):
|
||||
def solve(matvec, x):
|
||||
return jsp.linalg.solve(a, x)
|
||||
def tr_solve(matvec, x):
|
||||
return jsp.linalg.solve(a.T, x)
|
||||
matvec = partial(high_precision_dot, a)
|
||||
return lax.custom_linear_solve(matvec, b, solve, symmetric=True)
|
||||
return lax.custom_linear_solve(matvec, b, solve, tr_solve)
|
||||
|
||||
rng = onp.random.RandomState(0)
|
||||
a = 0.5 * rng.randn(2, 2) + 0.5j * rng.randn(2, 2)
|
||||
b = 0.5 * rng.randn(2) + 0.5j * rng.randn(2)
|
||||
|
||||
expected = np.linalg.solve(posify(a), b)
|
||||
actual = positive_definite_solve(posify(a), b)
|
||||
self.assertAllClose(expected, actual, check_dtypes=True)
|
||||
|
||||
# TODO(shoyer): remove this error when complex values work
|
||||
with self.assertRaises(NotImplementedError):
|
||||
jtu.check_grads(
|
||||
lambda x, y: positive_definite_solve(posify(x), y),
|
||||
(a, b), order=2, rtol=1e-2)
|
||||
jtu.check_grads(solve, (a, b), order=2, rtol=1e-2)
|
||||
|
||||
@jtu.skip_on_flag("jax_skip_slow_tests", True)
|
||||
def test_custom_linear_solve_lu(self):
|
||||
|
@ -113,7 +113,7 @@ class LaxBackedScipyTests(jtu.JaxTestCase):
|
||||
"_shape={}".format(jtu.format_shape_dtype_string(shape, dtype)),
|
||||
"shape": shape, "dtype": dtype, "rng_factory": rng_factory}
|
||||
for shape in [(2, 2)]
|
||||
for dtype in float_types
|
||||
for dtype in float_types + complex_types
|
||||
for rng_factory in [jtu.rand_default]))
|
||||
def test_cg_as_solve(self, shape, dtype, rng_factory):
|
||||
|
||||
|
Loading…
x
Reference in New Issue
Block a user