mirror of
https://github.com/ROCm/jax.git
synced 2025-04-16 03:46:06 +00:00
Support ndarrays as arguments to cg and gmres
This is consistent with SciPy, and makes things a little bit less surprising for users.
This commit is contained in:
parent
1464142b1f
commit
cd9f6cccbf
@ -73,6 +73,21 @@ def _identity(x):
|
||||
return x
|
||||
|
||||
|
||||
def _normalize_matvec(f):
|
||||
"""Normalize an argument for computing matrix-vector products."""
|
||||
if callable(f):
|
||||
return f
|
||||
elif isinstance(f, (np.ndarray, jnp.ndarray)):
|
||||
if f.ndim != 2 or f.shape[0] != f.shape[1]:
|
||||
raise ValueError(
|
||||
f'linear operator must be a square matrix, but has shape: {f.shape}')
|
||||
return partial(_dot, f)
|
||||
else:
|
||||
# TODO(shoyer): handle sparse arrays?
|
||||
raise TypeError(
|
||||
f'linear operator must be either a function or ndarray: {f}')
|
||||
|
||||
|
||||
def _cg_solve(A, b, x0=None, *, maxiter, tol=1e-5, atol=0.0, M=_identity):
|
||||
|
||||
# tolerance handling uses the "non-legacy" behavior of scipy.sparse.linalg.cg
|
||||
@ -126,11 +141,11 @@ def cg(A, b, x0=None, *, tol=1e-5, atol=0.0, maxiter=None, M=None):
|
||||
|
||||
Parameters
|
||||
----------
|
||||
A : function
|
||||
Function that calculates the matrix-vector product ``Ax`` when called
|
||||
like ``A(x)``. ``A`` must represent a hermitian, positive definite
|
||||
matrix, and must return array(s) with the same structure and shape as its
|
||||
argument.
|
||||
A: ndarray or function
|
||||
2D array or function that calculates the linear map (matrix-vector
|
||||
product) ``Ax`` when called like ``A(x)``. ``A`` must represent a
|
||||
hermitian, positive definite matrix, and must return array(s) with the
|
||||
same structure and shape as its argument.
|
||||
b : array or tree of arrays
|
||||
Right hand side of the linear system representing a single vector. Can be
|
||||
stored as an array or Python container of array(s) with any shape.
|
||||
@ -154,7 +169,7 @@ def cg(A, b, x0=None, *, tol=1e-5, atol=0.0, maxiter=None, M=None):
|
||||
maxiter : integer
|
||||
Maximum number of iterations. Iteration will stop after maxiter
|
||||
steps even if the specified tolerance has not been achieved.
|
||||
M : function
|
||||
M : ndarray or function
|
||||
Preconditioner for A. The preconditioner should approximate the
|
||||
inverse of A. Effective preconditioning dramatically improves the
|
||||
rate of convergence, which implies that fewer iterations are needed
|
||||
@ -176,6 +191,8 @@ def cg(A, b, x0=None, *, tol=1e-5, atol=0.0, maxiter=None, M=None):
|
||||
|
||||
if M is None:
|
||||
M = _identity
|
||||
A = _normalize_matvec(A)
|
||||
M = _normalize_matvec(M)
|
||||
|
||||
if tree_structure(x0) != tree_structure(b):
|
||||
raise ValueError(
|
||||
@ -507,10 +524,10 @@ def gmres(A, b, x0=None, *, tol=1e-5, atol=0.0, restart=20, maxiter=None,
|
||||
|
||||
Parameters
|
||||
----------
|
||||
A: function
|
||||
Function that calculates the linear map (matrix-vector product)
|
||||
``Ax`` when called like ``A(x)``. ``A`` must return array(s) with the same
|
||||
structure and shape as its argument.
|
||||
A: ndarray or function
|
||||
2D array or function that calculates the linear map (matrix-vector
|
||||
product) ``Ax`` when called like ``A(x)``. ``A`` must return array(s) with
|
||||
the same structure and shape as its argument.
|
||||
b : array or tree of arrays
|
||||
Right hand side of the linear system representing a single vector. Can be
|
||||
stored as an array or Python container of array(s) with any shape.
|
||||
@ -526,8 +543,8 @@ def gmres(A, b, x0=None, *, tol=1e-5, atol=0.0, restart=20, maxiter=None,
|
||||
Other Parameters
|
||||
----------------
|
||||
x0 : array, optional
|
||||
Starting guess for the solution. Must have the same structure as ``b``.
|
||||
If this is unspecified, zeroes are used.
|
||||
Starting guess for the solution. Must have the same structure as ``b``.
|
||||
If this is unspecified, zeroes are used.
|
||||
tol, atol : float, optional
|
||||
Tolerances for convergence, ``norm(residual) <= max(tol*norm(b), atol)``.
|
||||
We do not implement SciPy's "legacy" behavior, so JAX's tolerance will
|
||||
@ -546,7 +563,7 @@ def gmres(A, b, x0=None, *, tol=1e-5, atol=0.0, restart=20, maxiter=None,
|
||||
starting from the solution found at the last iteration. If GMRES
|
||||
halts or is very slow, decreasing this parameter may help.
|
||||
Default is infinite.
|
||||
M : function
|
||||
M : ndarray or function
|
||||
Preconditioner for A. The preconditioner should approximate the
|
||||
inverse of A. Effective preconditioning dramatically improves the
|
||||
rate of convergence, which implies that fewer iterations are needed
|
||||
@ -570,6 +587,8 @@ def gmres(A, b, x0=None, *, tol=1e-5, atol=0.0, restart=20, maxiter=None,
|
||||
x0 = tree_map(jnp.zeros_like, b)
|
||||
if M is None:
|
||||
M = _identity
|
||||
A = _normalize_matvec(A)
|
||||
M = _normalize_matvec(M)
|
||||
|
||||
b, x0 = device_put((b, x0))
|
||||
size = sum(bi.size for bi in tree_leaves(b))
|
||||
|
@ -45,27 +45,15 @@ def posify(matrix):
|
||||
return matmul_high_precision(matrix, matrix.T.conj())
|
||||
|
||||
|
||||
def lax_solver(solver_name, A, b, M=None, atol=0.0, **kwargs):
|
||||
A = partial(matmul_high_precision, A)
|
||||
if M is not None:
|
||||
M = partial(matmul_high_precision, M)
|
||||
func = getattr(jax.scipy.sparse.linalg, solver_name)
|
||||
def solver(func, A, b, M=None, atol=0.0, **kwargs):
|
||||
x, _ = func(A, b, atol=atol, M=M, **kwargs)
|
||||
return x
|
||||
|
||||
|
||||
lax_cg = partial(lax_solver, 'cg')
|
||||
lax_gmres = partial(lax_solver, 'gmres')
|
||||
|
||||
|
||||
def scipy_solver(solver_name, A, b, atol=0.0, **kwargs):
|
||||
func = getattr(scipy.sparse.linalg, solver_name)
|
||||
x, _ = func(A, b, atol=atol, **kwargs)
|
||||
return x
|
||||
|
||||
|
||||
scipy_cg = partial(scipy_solver, 'cg')
|
||||
scipy_gmres = partial(scipy_solver, 'gmres')
|
||||
lax_cg = partial(solver, jax.scipy.sparse.linalg.cg)
|
||||
lax_gmres = partial(solver, jax.scipy.sparse.linalg.gmres)
|
||||
scipy_cg = partial(solver, scipy.sparse.linalg.cg)
|
||||
scipy_gmres = partial(solver, scipy.sparse.linalg.gmres)
|
||||
|
||||
|
||||
def rand_sym_pos_def(rng, shape, dtype):
|
||||
@ -74,8 +62,7 @@ def rand_sym_pos_def(rng, shape, dtype):
|
||||
|
||||
|
||||
class LaxBackedScipyTests(jtu.JaxTestCase):
|
||||
def _fetch_preconditioner(self, preconditioner, A, rng=None,
|
||||
return_function=False):
|
||||
def _fetch_preconditioner(self, preconditioner, A, rng=None):
|
||||
"""
|
||||
Returns one of various preconditioning matrices depending on the identifier
|
||||
`preconditioner' and the input matrix A whose inverse it supposedly
|
||||
@ -91,11 +78,7 @@ class LaxBackedScipyTests(jtu.JaxTestCase):
|
||||
M = np.linalg.inv(A)
|
||||
else:
|
||||
M = None
|
||||
|
||||
if M is None or not return_function:
|
||||
return M
|
||||
else:
|
||||
return lambda x: jnp.dot(M, x, precision=lax.Precision.HIGHEST)
|
||||
return M
|
||||
|
||||
@parameterized.named_parameters(jtu.cases_from_list(
|
||||
{"testcase_name":
|
||||
@ -186,6 +169,11 @@ class LaxBackedScipyTests(jtu.JaxTestCase):
|
||||
with self.assertRaisesRegex(
|
||||
ValueError, "x0 and b must have matching shape"):
|
||||
jax.scipy.sparse.linalg.cg(A, b, b[:, np.newaxis])
|
||||
with self.assertRaisesRegex(ValueError, "must be a square matrix"):
|
||||
jax.scipy.sparse.linalg.cg(jnp.zeros((3, 2)), jnp.zeros((2,)))
|
||||
with self.assertRaisesRegex(
|
||||
TypeError, "linear operator must be either a function or ndarray"):
|
||||
jax.scipy.sparse.linalg.cg([[1]], jnp.zeros((1,)))
|
||||
|
||||
def test_cg_without_pytree_equality(self):
|
||||
|
||||
@ -273,16 +261,12 @@ class LaxBackedScipyTests(jtu.JaxTestCase):
|
||||
A = jnp.eye(shape[1], dtype=dtype)
|
||||
|
||||
solution = jnp.ones(shape[1], dtype=dtype)
|
||||
@jax.tree_util.Partial
|
||||
def A_mv(x):
|
||||
return matmul_high_precision(A, x)
|
||||
rng = jtu.rand_default(self.rng())
|
||||
M = self._fetch_preconditioner(preconditioner, A, rng=rng,
|
||||
return_function=True)
|
||||
b = A_mv(solution)
|
||||
M = self._fetch_preconditioner(preconditioner, A, rng=rng)
|
||||
b = matmul_high_precision(A, solution)
|
||||
restart = shape[-1]
|
||||
tol = shape[0] * jnp.finfo(dtype).eps
|
||||
x, info = jax.scipy.sparse.linalg.gmres(A_mv, b, tol=tol, atol=tol,
|
||||
x, info = jax.scipy.sparse.linalg.gmres(A, b, tol=tol, atol=tol,
|
||||
restart=restart,
|
||||
M=M, solve_method=solve_method)
|
||||
using_x64 = solution.dtype.kind in {np.float64, np.complex128}
|
||||
@ -308,15 +292,11 @@ class LaxBackedScipyTests(jtu.JaxTestCase):
|
||||
A = rng(shape, dtype)
|
||||
|
||||
solution = rng(shape[1:], dtype)
|
||||
@jax.tree_util.Partial
|
||||
def A_mv(x):
|
||||
return matmul_high_precision(A, x)
|
||||
M = self._fetch_preconditioner(preconditioner, A, rng=rng,
|
||||
return_function=True)
|
||||
b = A_mv(solution)
|
||||
M = self._fetch_preconditioner(preconditioner, A, rng=rng)
|
||||
b = matmul_high_precision(A, solution)
|
||||
restart = shape[-1]
|
||||
tol = shape[0] * jnp.finfo(A.dtype).eps
|
||||
x, info = jax.scipy.sparse.linalg.gmres(A_mv, b, tol=tol, atol=tol,
|
||||
x, info = jax.scipy.sparse.linalg.gmres(A, b, tol=tol, atol=tol,
|
||||
restart=restart,
|
||||
M=M, solve_method=solve_method)
|
||||
using_x64 = solution.dtype.kind in {np.float64, np.complex128}
|
||||
@ -350,12 +330,11 @@ class LaxBackedScipyTests(jtu.JaxTestCase):
|
||||
|
||||
rng = jtu.rand_default(self.rng())
|
||||
A = rng(shape, dtype)
|
||||
M = self._fetch_preconditioner(preconditioner, A, rng=rng)
|
||||
if preconditioner is None:
|
||||
M = lambda x: x
|
||||
else:
|
||||
M = self._fetch_preconditioner(preconditioner, A, rng=rng,
|
||||
return_function=True)
|
||||
|
||||
M = partial(matmul_high_precision, M)
|
||||
n = shape[0]
|
||||
x0 = rng(shape[:1], dtype)
|
||||
Q = np.zeros((n, n + 1), dtype=dtype)
|
||||
|
Loading…
x
Reference in New Issue
Block a user