Support ndarrays as arguments to cg and gmres

This is consistent with SciPy, and makes things a little bit less
surprising for users.
This commit is contained in:
Stephan Hoyer 2020-12-04 12:00:59 -08:00
parent 1464142b1f
commit cd9f6cccbf
2 changed files with 52 additions and 54 deletions

View File

@ -73,6 +73,21 @@ def _identity(x):
return x
def _normalize_matvec(f):
"""Normalize an argument for computing matrix-vector products."""
if callable(f):
return f
elif isinstance(f, (np.ndarray, jnp.ndarray)):
if f.ndim != 2 or f.shape[0] != f.shape[1]:
raise ValueError(
f'linear operator must be a square matrix, but has shape: {f.shape}')
return partial(_dot, f)
else:
# TODO(shoyer): handle sparse arrays?
raise TypeError(
f'linear operator must be either a function or ndarray: {f}')
def _cg_solve(A, b, x0=None, *, maxiter, tol=1e-5, atol=0.0, M=_identity):
# tolerance handling uses the "non-legacy" behavior of scipy.sparse.linalg.cg
@ -126,11 +141,11 @@ def cg(A, b, x0=None, *, tol=1e-5, atol=0.0, maxiter=None, M=None):
Parameters
----------
A : function
Function that calculates the matrix-vector product ``Ax`` when called
like ``A(x)``. ``A`` must represent a hermitian, positive definite
matrix, and must return array(s) with the same structure and shape as its
argument.
A: ndarray or function
2D array or function that calculates the linear map (matrix-vector
product) ``Ax`` when called like ``A(x)``. ``A`` must represent a
hermitian, positive definite matrix, and must return array(s) with the
same structure and shape as its argument.
b : array or tree of arrays
Right hand side of the linear system representing a single vector. Can be
stored as an array or Python container of array(s) with any shape.
@ -154,7 +169,7 @@ def cg(A, b, x0=None, *, tol=1e-5, atol=0.0, maxiter=None, M=None):
maxiter : integer
Maximum number of iterations. Iteration will stop after maxiter
steps even if the specified tolerance has not been achieved.
M : function
M : ndarray or function
Preconditioner for A. The preconditioner should approximate the
inverse of A. Effective preconditioning dramatically improves the
rate of convergence, which implies that fewer iterations are needed
@ -176,6 +191,8 @@ def cg(A, b, x0=None, *, tol=1e-5, atol=0.0, maxiter=None, M=None):
if M is None:
M = _identity
A = _normalize_matvec(A)
M = _normalize_matvec(M)
if tree_structure(x0) != tree_structure(b):
raise ValueError(
@ -507,10 +524,10 @@ def gmres(A, b, x0=None, *, tol=1e-5, atol=0.0, restart=20, maxiter=None,
Parameters
----------
A: function
Function that calculates the linear map (matrix-vector product)
``Ax`` when called like ``A(x)``. ``A`` must return array(s) with the same
structure and shape as its argument.
A: ndarray or function
2D array or function that calculates the linear map (matrix-vector
product) ``Ax`` when called like ``A(x)``. ``A`` must return array(s) with
the same structure and shape as its argument.
b : array or tree of arrays
Right hand side of the linear system representing a single vector. Can be
stored as an array or Python container of array(s) with any shape.
@ -526,8 +543,8 @@ def gmres(A, b, x0=None, *, tol=1e-5, atol=0.0, restart=20, maxiter=None,
Other Parameters
----------------
x0 : array, optional
Starting guess for the solution. Must have the same structure as ``b``.
If this is unspecified, zeroes are used.
Starting guess for the solution. Must have the same structure as ``b``.
If this is unspecified, zeroes are used.
tol, atol : float, optional
Tolerances for convergence, ``norm(residual) <= max(tol*norm(b), atol)``.
We do not implement SciPy's "legacy" behavior, so JAX's tolerance will
@ -546,7 +563,7 @@ def gmres(A, b, x0=None, *, tol=1e-5, atol=0.0, restart=20, maxiter=None,
starting from the solution found at the last iteration. If GMRES
halts or is very slow, decreasing this parameter may help.
Default is infinite.
M : function
M : ndarray or function
Preconditioner for A. The preconditioner should approximate the
inverse of A. Effective preconditioning dramatically improves the
rate of convergence, which implies that fewer iterations are needed
@ -570,6 +587,8 @@ def gmres(A, b, x0=None, *, tol=1e-5, atol=0.0, restart=20, maxiter=None,
x0 = tree_map(jnp.zeros_like, b)
if M is None:
M = _identity
A = _normalize_matvec(A)
M = _normalize_matvec(M)
b, x0 = device_put((b, x0))
size = sum(bi.size for bi in tree_leaves(b))

View File

@ -45,27 +45,15 @@ def posify(matrix):
return matmul_high_precision(matrix, matrix.T.conj())
def lax_solver(solver_name, A, b, M=None, atol=0.0, **kwargs):
A = partial(matmul_high_precision, A)
if M is not None:
M = partial(matmul_high_precision, M)
func = getattr(jax.scipy.sparse.linalg, solver_name)
def solver(func, A, b, M=None, atol=0.0, **kwargs):
x, _ = func(A, b, atol=atol, M=M, **kwargs)
return x
lax_cg = partial(lax_solver, 'cg')
lax_gmres = partial(lax_solver, 'gmres')
def scipy_solver(solver_name, A, b, atol=0.0, **kwargs):
func = getattr(scipy.sparse.linalg, solver_name)
x, _ = func(A, b, atol=atol, **kwargs)
return x
scipy_cg = partial(scipy_solver, 'cg')
scipy_gmres = partial(scipy_solver, 'gmres')
lax_cg = partial(solver, jax.scipy.sparse.linalg.cg)
lax_gmres = partial(solver, jax.scipy.sparse.linalg.gmres)
scipy_cg = partial(solver, scipy.sparse.linalg.cg)
scipy_gmres = partial(solver, scipy.sparse.linalg.gmres)
def rand_sym_pos_def(rng, shape, dtype):
@ -74,8 +62,7 @@ def rand_sym_pos_def(rng, shape, dtype):
class LaxBackedScipyTests(jtu.JaxTestCase):
def _fetch_preconditioner(self, preconditioner, A, rng=None,
return_function=False):
def _fetch_preconditioner(self, preconditioner, A, rng=None):
"""
Returns one of various preconditioning matrices depending on the identifier
`preconditioner' and the input matrix A whose inverse it supposedly
@ -91,11 +78,7 @@ class LaxBackedScipyTests(jtu.JaxTestCase):
M = np.linalg.inv(A)
else:
M = None
if M is None or not return_function:
return M
else:
return lambda x: jnp.dot(M, x, precision=lax.Precision.HIGHEST)
return M
@parameterized.named_parameters(jtu.cases_from_list(
{"testcase_name":
@ -186,6 +169,11 @@ class LaxBackedScipyTests(jtu.JaxTestCase):
with self.assertRaisesRegex(
ValueError, "x0 and b must have matching shape"):
jax.scipy.sparse.linalg.cg(A, b, b[:, np.newaxis])
with self.assertRaisesRegex(ValueError, "must be a square matrix"):
jax.scipy.sparse.linalg.cg(jnp.zeros((3, 2)), jnp.zeros((2,)))
with self.assertRaisesRegex(
TypeError, "linear operator must be either a function or ndarray"):
jax.scipy.sparse.linalg.cg([[1]], jnp.zeros((1,)))
def test_cg_without_pytree_equality(self):
@ -273,16 +261,12 @@ class LaxBackedScipyTests(jtu.JaxTestCase):
A = jnp.eye(shape[1], dtype=dtype)
solution = jnp.ones(shape[1], dtype=dtype)
@jax.tree_util.Partial
def A_mv(x):
return matmul_high_precision(A, x)
rng = jtu.rand_default(self.rng())
M = self._fetch_preconditioner(preconditioner, A, rng=rng,
return_function=True)
b = A_mv(solution)
M = self._fetch_preconditioner(preconditioner, A, rng=rng)
b = matmul_high_precision(A, solution)
restart = shape[-1]
tol = shape[0] * jnp.finfo(dtype).eps
x, info = jax.scipy.sparse.linalg.gmres(A_mv, b, tol=tol, atol=tol,
x, info = jax.scipy.sparse.linalg.gmres(A, b, tol=tol, atol=tol,
restart=restart,
M=M, solve_method=solve_method)
using_x64 = solution.dtype.kind in {np.float64, np.complex128}
@ -308,15 +292,11 @@ class LaxBackedScipyTests(jtu.JaxTestCase):
A = rng(shape, dtype)
solution = rng(shape[1:], dtype)
@jax.tree_util.Partial
def A_mv(x):
return matmul_high_precision(A, x)
M = self._fetch_preconditioner(preconditioner, A, rng=rng,
return_function=True)
b = A_mv(solution)
M = self._fetch_preconditioner(preconditioner, A, rng=rng)
b = matmul_high_precision(A, solution)
restart = shape[-1]
tol = shape[0] * jnp.finfo(A.dtype).eps
x, info = jax.scipy.sparse.linalg.gmres(A_mv, b, tol=tol, atol=tol,
x, info = jax.scipy.sparse.linalg.gmres(A, b, tol=tol, atol=tol,
restart=restart,
M=M, solve_method=solve_method)
using_x64 = solution.dtype.kind in {np.float64, np.complex128}
@ -350,12 +330,11 @@ class LaxBackedScipyTests(jtu.JaxTestCase):
rng = jtu.rand_default(self.rng())
A = rng(shape, dtype)
M = self._fetch_preconditioner(preconditioner, A, rng=rng)
if preconditioner is None:
M = lambda x: x
else:
M = self._fetch_preconditioner(preconditioner, A, rng=rng,
return_function=True)
M = partial(matmul_high_precision, M)
n = shape[0]
x0 = rng(shape[:1], dtype)
Q = np.zeros((n, n + 1), dtype=dtype)