diff --git a/jax/_src/scipy/sparse/linalg.py b/jax/_src/scipy/sparse/linalg.py index 385d5b62d..a0263c94a 100644 --- a/jax/_src/scipy/sparse/linalg.py +++ b/jax/_src/scipy/sparse/linalg.py @@ -73,6 +73,21 @@ def _identity(x): return x +def _normalize_matvec(f): + """Normalize an argument for computing matrix-vector products.""" + if callable(f): + return f + elif isinstance(f, (np.ndarray, jnp.ndarray)): + if f.ndim != 2 or f.shape[0] != f.shape[1]: + raise ValueError( + f'linear operator must be a square matrix, but has shape: {f.shape}') + return partial(_dot, f) + else: + # TODO(shoyer): handle sparse arrays? + raise TypeError( + f'linear operator must be either a function or ndarray: {f}') + + def _cg_solve(A, b, x0=None, *, maxiter, tol=1e-5, atol=0.0, M=_identity): # tolerance handling uses the "non-legacy" behavior of scipy.sparse.linalg.cg @@ -126,11 +141,11 @@ def cg(A, b, x0=None, *, tol=1e-5, atol=0.0, maxiter=None, M=None): Parameters ---------- - A : function - Function that calculates the matrix-vector product ``Ax`` when called - like ``A(x)``. ``A`` must represent a hermitian, positive definite - matrix, and must return array(s) with the same structure and shape as its - argument. + A: ndarray or function + 2D array or function that calculates the linear map (matrix-vector + product) ``Ax`` when called like ``A(x)``. ``A`` must represent a + hermitian, positive definite matrix, and must return array(s) with the + same structure and shape as its argument. b : array or tree of arrays Right hand side of the linear system representing a single vector. Can be stored as an array or Python container of array(s) with any shape. @@ -154,7 +169,7 @@ def cg(A, b, x0=None, *, tol=1e-5, atol=0.0, maxiter=None, M=None): maxiter : integer Maximum number of iterations. Iteration will stop after maxiter steps even if the specified tolerance has not been achieved. - M : function + M : ndarray or function Preconditioner for A. The preconditioner should approximate the inverse of A. Effective preconditioning dramatically improves the rate of convergence, which implies that fewer iterations are needed @@ -176,6 +191,8 @@ def cg(A, b, x0=None, *, tol=1e-5, atol=0.0, maxiter=None, M=None): if M is None: M = _identity + A = _normalize_matvec(A) + M = _normalize_matvec(M) if tree_structure(x0) != tree_structure(b): raise ValueError( @@ -507,10 +524,10 @@ def gmres(A, b, x0=None, *, tol=1e-5, atol=0.0, restart=20, maxiter=None, Parameters ---------- - A: function - Function that calculates the linear map (matrix-vector product) - ``Ax`` when called like ``A(x)``. ``A`` must return array(s) with the same - structure and shape as its argument. + A: ndarray or function + 2D array or function that calculates the linear map (matrix-vector + product) ``Ax`` when called like ``A(x)``. ``A`` must return array(s) with + the same structure and shape as its argument. b : array or tree of arrays Right hand side of the linear system representing a single vector. Can be stored as an array or Python container of array(s) with any shape. @@ -526,8 +543,8 @@ def gmres(A, b, x0=None, *, tol=1e-5, atol=0.0, restart=20, maxiter=None, Other Parameters ---------------- x0 : array, optional - Starting guess for the solution. Must have the same structure as ``b``. - If this is unspecified, zeroes are used. + Starting guess for the solution. Must have the same structure as ``b``. + If this is unspecified, zeroes are used. tol, atol : float, optional Tolerances for convergence, ``norm(residual) <= max(tol*norm(b), atol)``. We do not implement SciPy's "legacy" behavior, so JAX's tolerance will @@ -546,7 +563,7 @@ def gmres(A, b, x0=None, *, tol=1e-5, atol=0.0, restart=20, maxiter=None, starting from the solution found at the last iteration. If GMRES halts or is very slow, decreasing this parameter may help. Default is infinite. - M : function + M : ndarray or function Preconditioner for A. The preconditioner should approximate the inverse of A. Effective preconditioning dramatically improves the rate of convergence, which implies that fewer iterations are needed @@ -570,6 +587,8 @@ def gmres(A, b, x0=None, *, tol=1e-5, atol=0.0, restart=20, maxiter=None, x0 = tree_map(jnp.zeros_like, b) if M is None: M = _identity + A = _normalize_matvec(A) + M = _normalize_matvec(M) b, x0 = device_put((b, x0)) size = sum(bi.size for bi in tree_leaves(b)) diff --git a/tests/lax_scipy_sparse_test.py b/tests/lax_scipy_sparse_test.py index fc9a6a38a..7367db6ea 100644 --- a/tests/lax_scipy_sparse_test.py +++ b/tests/lax_scipy_sparse_test.py @@ -45,27 +45,15 @@ def posify(matrix): return matmul_high_precision(matrix, matrix.T.conj()) -def lax_solver(solver_name, A, b, M=None, atol=0.0, **kwargs): - A = partial(matmul_high_precision, A) - if M is not None: - M = partial(matmul_high_precision, M) - func = getattr(jax.scipy.sparse.linalg, solver_name) +def solver(func, A, b, M=None, atol=0.0, **kwargs): x, _ = func(A, b, atol=atol, M=M, **kwargs) return x -lax_cg = partial(lax_solver, 'cg') -lax_gmres = partial(lax_solver, 'gmres') - - -def scipy_solver(solver_name, A, b, atol=0.0, **kwargs): - func = getattr(scipy.sparse.linalg, solver_name) - x, _ = func(A, b, atol=atol, **kwargs) - return x - - -scipy_cg = partial(scipy_solver, 'cg') -scipy_gmres = partial(scipy_solver, 'gmres') +lax_cg = partial(solver, jax.scipy.sparse.linalg.cg) +lax_gmres = partial(solver, jax.scipy.sparse.linalg.gmres) +scipy_cg = partial(solver, scipy.sparse.linalg.cg) +scipy_gmres = partial(solver, scipy.sparse.linalg.gmres) def rand_sym_pos_def(rng, shape, dtype): @@ -74,8 +62,7 @@ def rand_sym_pos_def(rng, shape, dtype): class LaxBackedScipyTests(jtu.JaxTestCase): - def _fetch_preconditioner(self, preconditioner, A, rng=None, - return_function=False): + def _fetch_preconditioner(self, preconditioner, A, rng=None): """ Returns one of various preconditioning matrices depending on the identifier `preconditioner' and the input matrix A whose inverse it supposedly @@ -91,11 +78,7 @@ class LaxBackedScipyTests(jtu.JaxTestCase): M = np.linalg.inv(A) else: M = None - - if M is None or not return_function: - return M - else: - return lambda x: jnp.dot(M, x, precision=lax.Precision.HIGHEST) + return M @parameterized.named_parameters(jtu.cases_from_list( {"testcase_name": @@ -186,6 +169,11 @@ class LaxBackedScipyTests(jtu.JaxTestCase): with self.assertRaisesRegex( ValueError, "x0 and b must have matching shape"): jax.scipy.sparse.linalg.cg(A, b, b[:, np.newaxis]) + with self.assertRaisesRegex(ValueError, "must be a square matrix"): + jax.scipy.sparse.linalg.cg(jnp.zeros((3, 2)), jnp.zeros((2,))) + with self.assertRaisesRegex( + TypeError, "linear operator must be either a function or ndarray"): + jax.scipy.sparse.linalg.cg([[1]], jnp.zeros((1,))) def test_cg_without_pytree_equality(self): @@ -273,16 +261,12 @@ class LaxBackedScipyTests(jtu.JaxTestCase): A = jnp.eye(shape[1], dtype=dtype) solution = jnp.ones(shape[1], dtype=dtype) - @jax.tree_util.Partial - def A_mv(x): - return matmul_high_precision(A, x) rng = jtu.rand_default(self.rng()) - M = self._fetch_preconditioner(preconditioner, A, rng=rng, - return_function=True) - b = A_mv(solution) + M = self._fetch_preconditioner(preconditioner, A, rng=rng) + b = matmul_high_precision(A, solution) restart = shape[-1] tol = shape[0] * jnp.finfo(dtype).eps - x, info = jax.scipy.sparse.linalg.gmres(A_mv, b, tol=tol, atol=tol, + x, info = jax.scipy.sparse.linalg.gmres(A, b, tol=tol, atol=tol, restart=restart, M=M, solve_method=solve_method) using_x64 = solution.dtype.kind in {np.float64, np.complex128} @@ -308,15 +292,11 @@ class LaxBackedScipyTests(jtu.JaxTestCase): A = rng(shape, dtype) solution = rng(shape[1:], dtype) - @jax.tree_util.Partial - def A_mv(x): - return matmul_high_precision(A, x) - M = self._fetch_preconditioner(preconditioner, A, rng=rng, - return_function=True) - b = A_mv(solution) + M = self._fetch_preconditioner(preconditioner, A, rng=rng) + b = matmul_high_precision(A, solution) restart = shape[-1] tol = shape[0] * jnp.finfo(A.dtype).eps - x, info = jax.scipy.sparse.linalg.gmres(A_mv, b, tol=tol, atol=tol, + x, info = jax.scipy.sparse.linalg.gmres(A, b, tol=tol, atol=tol, restart=restart, M=M, solve_method=solve_method) using_x64 = solution.dtype.kind in {np.float64, np.complex128} @@ -350,12 +330,11 @@ class LaxBackedScipyTests(jtu.JaxTestCase): rng = jtu.rand_default(self.rng()) A = rng(shape, dtype) + M = self._fetch_preconditioner(preconditioner, A, rng=rng) if preconditioner is None: M = lambda x: x else: - M = self._fetch_preconditioner(preconditioner, A, rng=rng, - return_function=True) - + M = partial(matmul_high_precision, M) n = shape[0] x0 = rng(shape[:1], dtype) Q = np.zeros((n, n + 1), dtype=dtype)