Implement scipy.sparse.linalg.cg (second try) (#2566)

* super minimal starter code

* Update optimizers.py

* implement flip with axis = None

* Create sparse.py

* fix some imports

* Update sparse.py

* add partial function & test

* Update lax_scipy_sparse_test.py

* Update lax_scipy_sparse_test.py

* add a test case for sparse pd matrix & add bigger dim

* address comments

* fix info return & create matrix with rng_factory

* Update lax_scipy_sparse_test.py

* Update lax_scipy_sparse_test.py

* Update sparse.py

* Update sparse.py

* Update sparse.py

* Update lax_scipy_sparse_test.py

* Update lax_scipy_sparse_test.py

* cast jax arrays into numpy array for scipy compatibility

* Update sparse.py

* Update sparse.py

* fix None issue, but algo is not working

* fix return of build_and_solve and output of while_loop

* fix condition func of while loop

* clearer variable names

* mismatch error

* Update lax_scipy_sparse_test.py

* Fixes to jax.experimental.sparse.cg

* Fix tests for gradients

* Add support for preconditioners to cg

* Move cg into scipy, update docs

* doc tweak

Co-authored-by: Tuan Nguyen <anhtuan277@gmail.com>
Stephan Hoyer, 2020-04-03 13:37:11 -07:00, committed by GitHub
parent 2b3befff32
commit 1b93bb51a8

4 changed files with 297 additions and 0 deletions

docs/jax.scipy.rst

@@ -37,6 +37,16 @@ jax.scipy.ndimage
   map_coordinates


jax.scipy.sparse.linalg
-----------------------

.. automodule:: jax.scipy.sparse.linalg

.. autosummary::
  :toctree: _autosummary

   cg


jax.scipy.special
-----------------

jax/scipy/sparse/__init__.py

@@ -0,0 +1,15 @@
# Copyright 2020 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from . import linalg

jax/scipy/sparse/linalg.py

@@ -0,0 +1,133 @@
# Copyright 2020 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from functools import partial

import jax.numpy as jnp
from jax import lax


def _vdot(x, y):
  return jnp.vdot(x, y, precision=lax.Precision.HIGHEST)


def _identity(x):
  return x


def _cg_solve(A, b, x0=None, *, maxiter, tol=1e-5, atol=0.0, M=_identity):

  # tolerance handling uses the "non-legacy" behavior of scipy.sparse.linalg.cg
  bs = _vdot(b, b)
  atol2 = jnp.maximum(tol ** 2 * bs, atol ** 2)

  # https://en.wikipedia.org/wiki/Conjugate_gradient_method#The_preconditioned_conjugate_gradient_method

  def cond_fun(value):
    x, r, gamma, p, k = value
    rs = gamma if M is _identity else _vdot(r, r)
    return (rs > atol2) & (k < maxiter)

  def body_fun(value):
    x, r, gamma, p, k = value
    Ap = A(p)
    alpha = gamma / _vdot(p, Ap)
    x_ = x + alpha * p
    r_ = r - alpha * Ap
    z_ = M(r_)
    gamma_ = _vdot(r_, z_)
    beta_ = gamma_ / gamma
    p_ = z_ + beta_ * p
    return x_, r_, gamma_, p_, k + 1

  r0 = b - A(x0)
  p0 = z0 = M(r0)
  gamma0 = _vdot(r0, z0)
  initial_value = (x0, r0, gamma0, p0, 0)

  x_final, *_ = lax.while_loop(cond_fun, body_fun, initial_value)
  return x_final
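
For reference (editor's note, not part of the commit): body_fun implements the
standard preconditioned CG recurrences, with gamma_k = r_k^* z_k carried
across iterations so each step needs only one matvec and one preconditioner
apply. cond_fun likewise compares squared norms against
atol2 = max(tol^2 * b^* b, atol^2), avoiding a square root:

\alpha_k = \frac{r_k^* z_k}{p_k^* A p_k}, \qquad
x_{k+1} = x_k + \alpha_k p_k, \qquad
r_{k+1} = r_k - \alpha_k A p_k,

z_{k+1} = M r_{k+1}, \qquad
\beta_k = \frac{r_{k+1}^* z_{k+1}}{r_k^* z_k}, \qquad
p_{k+1} = z_{k+1} + \beta_k p_k.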

def cg(A, b, x0=None, *, tol=1e-5, atol=0.0, maxiter=None, M=None):
  """Use Conjugate Gradient iteration to solve ``Ax = b``.

  The numerics of JAX's ``cg`` should exactly match SciPy's ``cg`` (up to
  numerical precision), but note that the interface is slightly different: you
  need to supply the linear operator ``A`` as a function instead of a sparse
  matrix or ``LinearOperator``.

  Parameters
  ----------
  A : function
      Function that calculates the matrix-vector product ``Ax`` when called
      like ``A(x)``. ``A`` must represent a hermitian, positive definite
      matrix.
  b : array
      Right hand side of the linear system. Has shape (N,).

  Returns
  -------
  x : array
      The converged solution.
  info : None
      Placeholder for convergence information. In the future, JAX will report
      the number of iterations when convergence is not achieved, like SciPy.

  Other Parameters
  ----------------
  x0 : array
      Starting guess for the solution.
  tol, atol : float, optional
      Tolerances for convergence, ``norm(residual) <= max(tol*norm(b), atol)``.
      We do not implement SciPy's "legacy" behavior, so JAX's tolerance will
      differ from SciPy unless you explicitly pass ``atol`` to SciPy's ``cg``.
  maxiter : integer
      Maximum number of iterations. Iteration will stop after maxiter
      steps even if the specified tolerance has not been achieved.
  M : function
      Preconditioner for ``A``. The preconditioner should approximate the
      inverse of ``A``. Effective preconditioning dramatically improves the
      rate of convergence, which implies that fewer iterations are needed
      to reach a given error tolerance.

  See also
  --------
  scipy.sparse.linalg.cg
  """
  if x0 is None:
    x0 = jnp.zeros_like(b)
  if maxiter is None:
    maxiter = 10 * len(b)  # copied from scipy
  if M is None:
    M = _identity

  if x0.shape != b.shape:
    raise ValueError(
        f'x0 and b must have matching shape: {x0.shape} vs {b.shape}')
  if b.ndim != 1:
    raise ValueError(
        f'b must be one-dimensional, but has shape {b.shape}')

  cg_solve = partial(
      _cg_solve, x0=x0, tol=tol, atol=atol, maxiter=maxiter, M=M)
  # custom_linear_solve differentiates the solve implicitly, via another
  # (transposed) linear solve, rather than through the unrolled CG iterations.
  x = lax.custom_linear_solve(A, b, cg_solve, symmetric=True)
  info = None  # TODO(shoyer): return the real iteration count here
  return x, info
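
A hypothetical usage sketch (editor's note, not part of the commit) showing
the function-valued ``A`` that ``cg`` expects; the matrix, right-hand side,
and the ``matvec`` helper are illustrative names only:

import jax.numpy as jnp
import jax.scipy.sparse.linalg

A = jnp.array([[4.0, 1.0],
               [1.0, 3.0]])  # symmetric positive definite
b = jnp.array([1.0, 2.0])

def matvec(v):
  return A @ v  # cg takes the operator as a matvec function, not a matrix

x, info = jax.scipy.sparse.linalg.cg(matvec, b, tol=1e-6)
# x should be close to jnp.linalg.solve(A, b); info is currently None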

tests/lax_scipy_sparse_test.py

@@ -0,0 +1,139 @@
# Copyright 2020 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from functools import partial

from absl.testing import absltest
from absl.testing import parameterized
import numpy as np
import scipy.sparse.linalg

from jax import jit
from jax import lax
from jax import test_util as jtu
import jax.numpy as jnp
import jax.scipy.sparse.linalg


float_types = [np.float32, np.float64]
complex_types = [np.complex64, np.complex128]


def matmul_high_precision(a, b):
  return jnp.matmul(a, b, precision=lax.Precision.HIGHEST)


@jit
def posify(matrix):
  return matmul_high_precision(matrix, matrix.T.conj())


def lax_cg(A, b, M=None, tol=0.0, atol=0.0, **kwargs):
  A = partial(matmul_high_precision, A)
  if M is not None:
    M = partial(matmul_high_precision, M)
  x, _ = jax.scipy.sparse.linalg.cg(A, b, tol=tol, atol=atol, M=M, **kwargs)
  return x


def scipy_cg(A, b, tol=0.0, atol=0.0, **kwargs):
  x, _ = scipy.sparse.linalg.cg(A, b, tol=tol, atol=atol, **kwargs)
  return x


def rand_sym_pos_def(rng, shape, dtype):
  matrix = np.eye(N=shape[0], dtype=dtype) + rng(shape, dtype)
  return matrix @ matrix.T.conj()
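
A quick sanity sketch (editor's note, not part of the commit) of why
rand_sym_pos_def returns a Hermitian positive definite matrix: for
m = I + R, the product m @ m^* is Hermitian by construction and positive
definite whenever m is nonsingular, which holds almost surely for random R:

import numpy as np

rng = np.random.default_rng(0)  # hypothetical stand-in for jtu's rng
m = np.eye(4) + rng.standard_normal((4, 4))
a = m @ m.T.conj()

assert np.allclose(a, a.T.conj())          # Hermitian
assert (np.linalg.eigvalsh(a) > 0).all()   # positive definite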

class LaxBackedScipyTests(jtu.JaxTestCase):

  @parameterized.named_parameters(jtu.cases_from_list(
      {"testcase_name":
       "_shape={}_preconditioner={}".format(
           jtu.format_shape_dtype_string(shape, dtype),
           preconditioner),
       "shape": shape, "dtype": dtype, "rng_factory": rng_factory,
       "preconditioner": preconditioner}
      for shape in [(4, 4), (7, 7), (32, 32)]
      for dtype in float_types + complex_types
      for rng_factory in [jtu.rand_default]
      for preconditioner in [None, 'random', 'identity', 'exact']))
  def test_cg_against_scipy(self, shape, dtype, rng_factory, preconditioner):
    rng = rng_factory()
    A = rand_sym_pos_def(rng, shape, dtype)
    b = rng(shape[:1], dtype)

    if preconditioner == 'identity':
      M = np.eye(shape[0], dtype=dtype)
    elif preconditioner == 'random':
      M = np.linalg.inv(rand_sym_pos_def(rng, shape, dtype))
    elif preconditioner == 'exact':
      M = np.linalg.inv(A)
    else:
      M = None

    def args_maker():
      return A, b

    self._CheckAgainstNumpy(
        partial(scipy_cg, M=M, maxiter=1),
        partial(lax_cg, M=M, maxiter=1),
        args_maker,
        check_dtypes=True,
        tol=3e-5)

    self._CheckAgainstNumpy(
        partial(scipy_cg, M=M, maxiter=3),
        partial(lax_cg, M=M, maxiter=3),
        args_maker,
        check_dtypes=True,
        tol=1e-4)

    self._CheckAgainstNumpy(
        np.linalg.solve,
        partial(lax_cg, M=M, atol=1e-6),
        args_maker,
        check_dtypes=True,
        tol=2e-4)
  @parameterized.named_parameters(jtu.cases_from_list(
      {"testcase_name":
       "_shape={}".format(jtu.format_shape_dtype_string(shape, dtype)),
       "shape": shape, "dtype": dtype, "rng_factory": rng_factory}
      for shape in [(2, 2)]
      for dtype in float_types
      for rng_factory in [jtu.rand_default]))
  def test_cg_as_solve(self, shape, dtype, rng_factory):
    rng = rng_factory()
    a = rng(shape, dtype)
    b = rng(shape[:1], dtype)

    expected = np.linalg.solve(posify(a), b)
    actual = lax_cg(posify(a), b)
    self.assertAllClose(expected, actual, check_dtypes=True)

    actual = jit(lax_cg)(posify(a), b)
    self.assertAllClose(expected, actual, check_dtypes=True)

    # numerical gradients are only well defined if ``a`` is guaranteed to be
    # positive definite.
    jtu.check_grads(
        lambda x, y: lax_cg(posify(x), y),
        (a, b), order=2, rtol=1e-2)


if __name__ == "__main__":
  absltest.main()
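
One practical instance of the ``M`` argument exercised by these tests (a
sketch, not from the commit): a Jacobi (diagonal) preconditioner, which
approximates A^{-1} by the reciprocal of A's diagonal; all names here are
illustrative:

import jax.numpy as jnp
import jax.scipy.sparse.linalg

A = jnp.diag(jnp.array([1.0, 10.0, 100.0])) + 0.1  # SPD, but poorly scaled
b = jnp.ones(3)
inv_diag = 1.0 / jnp.diag(A)  # reciprocal of A's diagonal entries

def jacobi(r):
  return inv_diag * r  # approximates A^{-1} r using only the diagonal

x, _ = jax.scipy.sparse.linalg.cg(lambda v: A @ v, b, M=jacobi)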