mirror of
https://github.com/ROCm/jax.git
synced 2025-04-16 03:46:06 +00:00
Implement scipy.sparse.linalg.cg (second try) (#2566)
* super minimal starter code
* Update optimizers.py
* implement flip with axis = None
* Create sparse.py
* fix some imports
* Update sparse.py
* add partial function & test
* Update lax_scipy_sparse_test.py
* Update lax_scipy_sparse_test.py
* add a test case for sparse pd matrix & add bigger dim
* address comments
* fix info return & create matrix with rng_factory
* Update lax_scipy_sparse_test.py
* Update lax_scipy_sparse_test.py
* Update sparse.py
* Update sparse.py
* Update sparse.py
* Update lax_scipy_sparse_test.py
* Update lax_scipy_sparse_test.py
* cast jax arrays into numpy array for scipy compatibility
* Update sparse.py
* Update sparse.py
* fix None issue, but algo is not working
* fix return of build_and_solve and output of while_loop
* fix condition func of while loop
* clearer variable names
* mismatch error
* Update lax_scipy_sparse_test.py
* Fixes to jax.experimental.sparse.cg
* Fix tests for gradients
* Add support for preconditioners to cg
* Move cg into scipy, update docs
* doc tweak

Co-authored-by: Tuan Nguyen <anhtuan277@gmail.com>
This commit is contained in:
parent 2b3befff32
commit 1b93bb51a8
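The commit message above summarizes the new API: ``cg`` lives under ``jax.scipy.sparse.linalg`` and takes the linear operator as a function rather than a sparse matrix or ``LinearOperator``. A minimal usage sketch (the matrix and right-hand side below are illustrative, not taken from the commit):

    import jax.numpy as jnp
    from jax.scipy.sparse.linalg import cg

    # A small symmetric positive definite system, written out by hand.
    A_matrix = jnp.array([[4.0, 1.0],
                          [1.0, 3.0]])
    b = jnp.array([1.0, 2.0])

    def A(x):
      # cg expects the operator as a matrix-vector product function.
      return A_matrix @ x

    x, info = cg(A, b, tol=1e-5)  # info is currently always None
    # x is approximately [1/11, 7/11] = [0.0909..., 0.6363...]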
@ -37,6 +37,16 @@ jax.scipy.ndimage

   map_coordinates

jax.scipy.sparse.linalg
-----------------------

.. automodule:: jax.scipy.sparse.linalg

.. autosummary::
  :toctree: _autosummary

   cg

jax.scipy.special
-----------------
jax/scipy/sparse/__init__.py (new file, 15 lines)
@ -0,0 +1,15 @@
# Copyright 2020 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from . import linalg
jax/scipy/sparse/linalg.py (new file, 133 lines)
@ -0,0 +1,133 @@
# Copyright 2020 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from functools import partial
import textwrap

import scipy.sparse.linalg
import jax.numpy as jnp
import numpy as np
from jax.numpy.lax_numpy import _wraps
from jax import lax


def _vdot(x, y):
  return jnp.vdot(x, y, precision=lax.Precision.HIGHEST)


def _identity(x):
  return x


def _cg_solve(A, b, x0=None, *, maxiter, tol=1e-5, atol=0.0, M=_identity):

  # tolerance handling uses the "non-legacy" behavior of scipy.sparse.linalg.cg
  bs = _vdot(b, b)
  atol2 = jnp.maximum(tol ** 2 * bs, atol ** 2)

  # https://en.wikipedia.org/wiki/Conjugate_gradient_method#The_preconditioned_conjugate_gradient_method

  def cond_fun(value):
    x, r, gamma, p, k = value
    rs = gamma if M is _identity else _vdot(r, r)
    return (rs > atol2) & (k < maxiter)

  def body_fun(value):
    x, r, gamma, p, k = value
    Ap = A(p)
    alpha = gamma / _vdot(p, Ap)
    x_ = x + alpha * p
    r_ = r - alpha * Ap
    z_ = M(r_)
    gamma_ = _vdot(r_, z_)
    beta_ = gamma_ / gamma
    p_ = z_ + beta_ * p
    return x_, r_, gamma_, p_, k + 1

  r0 = b - A(x0)
  p0 = z0 = M(r0)
  gamma0 = _vdot(r0, z0)
  initial_value = (x0, r0, gamma0, p0, 0)

  x_final, *_ = lax.while_loop(cond_fun, body_fun, initial_value)

  return x_final


def cg(A, b, x0=None, *, tol=1e-5, atol=0.0, maxiter=None, M=None):
  """Use Conjugate Gradient iteration to solve ``Ax = b``.

  The numerics of JAX's ``cg`` should exactly match SciPy's ``cg`` (up to
  numerical precision), but note that the interface is slightly different: you
  need to supply the linear operator ``A`` as a function instead of a sparse
  matrix or ``LinearOperator``.

  Parameters
  ----------
  A : function
      Function that calculates the matrix-vector product ``Ax`` when called
      like ``A(x)``. ``A`` must represent a hermitian, positive definite
      matrix.
  b : array
      Right hand side of the linear system. Has shape (N,).

  Returns
  -------
  x : array
      The converged solution.
  info : None
      Placeholder for convergence information. In the future, JAX will report
      the number of iterations when convergence is not achieved, like SciPy.

  Other Parameters
  ----------------
  x0 : array
      Starting guess for the solution.
  tol, atol : float, optional
      Tolerances for convergence, ``norm(residual) <= max(tol*norm(b), atol)``.
      We do not implement SciPy's "legacy" behavior, so JAX's tolerance will
      differ from SciPy unless you explicitly pass ``atol`` to SciPy's ``cg``.
  maxiter : integer
      Maximum number of iterations. Iteration will stop after maxiter
      steps even if the specified tolerance has not been achieved.
  M : function
      Preconditioner for A. The preconditioner should approximate the
      inverse of A. Effective preconditioning dramatically improves the
      rate of convergence, which implies that fewer iterations are needed
      to reach a given error tolerance.

  See also
  --------
  scipy.sparse.linalg.cg
  """
  if x0 is None:
    x0 = jnp.zeros_like(b)

  if maxiter is None:
    maxiter = 10 * len(b)  # copied from scipy

  if M is None:
    M = _identity

  if x0.shape != b.shape:
    raise ValueError(
        f'x0 and b must have matching shape: {x0.shape} vs {b.shape}')
  if b.ndim != 1:
    raise ValueError(
        f'b must be one-dimensional, but has shape {b.shape}')

  cg_solve = partial(
      _cg_solve, x0=x0, tol=tol, atol=atol, maxiter=maxiter, M=M)
  x = lax.custom_linear_solve(A, b, cg_solve, symmetric=True)
  info = None  # TODO(shoyer): return the real iteration count here
  return x, info
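The docstring describes ``M`` as an approximate inverse of ``A``. As a concrete illustration, here is a sketch of Jacobi (diagonal) preconditioning; the matrix and the ``inv_diag`` helper are assumptions for the example, not part of the commit:

    import jax.numpy as jnp
    from jax.scipy.sparse.linalg import cg

    # A diagonally dominant symmetric positive definite matrix, for which
    # inverting the diagonal is a cheap approximation to inverting A.
    A_matrix = jnp.array([[10.0, 1.0, 0.0],
                          [1.0, 8.0, 2.0],
                          [0.0, 2.0, 6.0]])
    b = jnp.ones(3)

    def A(x):
      return A_matrix @ x

    inv_diag = 1.0 / jnp.diag(A_matrix)

    def M(x):
      # Jacobi preconditioner: multiply by diag(A)**-1 elementwise.
      return inv_diag * x

    x, _ = cg(A, b, M=M, atol=1e-6)

Because ``M`` only needs to be a function, any cheap approximation to the inverse of ``A`` can be plugged in the same way.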
tests/lax_scipy_sparse_test.py (new file, 139 lines)
@ -0,0 +1,139 @@
# Copyright 2020 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from functools import partial

from absl.testing import parameterized
from absl.testing import absltest

from jax import jit
import jax.numpy as jnp
import numpy as np
import scipy.sparse.linalg
from jax import lax
from jax import test_util as jtu
import jax.scipy.sparse.linalg


float_types = [np.float32, np.float64]
complex_types = [np.complex64, np.complex128]


def matmul_high_precision(a, b):
  return jnp.matmul(a, b, precision=lax.Precision.HIGHEST)


@jit
def posify(matrix):
  return matmul_high_precision(matrix, matrix.T.conj())


def lax_cg(A, b, M=None, tol=0.0, atol=0.0, **kwargs):
  A = partial(matmul_high_precision, A)
  if M is not None:
    M = partial(matmul_high_precision, M)
  x, _ = jax.scipy.sparse.linalg.cg(A, b, tol=tol, atol=atol, M=M, **kwargs)
  return x


def scipy_cg(A, b, tol=0.0, atol=0.0, **kwargs):
  x, _ = scipy.sparse.linalg.cg(A, b, tol=tol, atol=atol, **kwargs)
  return x


def rand_sym_pos_def(rng, shape, dtype):
  matrix = np.eye(N=shape[0], dtype=dtype) + rng(shape, dtype)
  return matrix @ matrix.T.conj()


class LaxBackedScipyTests(jtu.JaxTestCase):

  @parameterized.named_parameters(jtu.cases_from_list(
      {"testcase_name":
           "_shape={}_preconditioner={}".format(
               jtu.format_shape_dtype_string(shape, dtype),
               preconditioner),
       "shape": shape, "dtype": dtype, "rng_factory": rng_factory,
       "preconditioner": preconditioner}
      for shape in [(4, 4), (7, 7), (32, 32)]
      for dtype in float_types + complex_types
      for rng_factory in [jtu.rand_default]
      for preconditioner in [None, 'random', 'identity', 'exact']))
  def test_cg_against_scipy(self, shape, dtype, rng_factory, preconditioner):

    rng = rng_factory()
    A = rand_sym_pos_def(rng, shape, dtype)
    b = rng(shape[:1], dtype)

    if preconditioner == 'identity':
      M = np.eye(shape[0], dtype=dtype)
    elif preconditioner == 'random':
      M = np.linalg.inv(rand_sym_pos_def(rng, shape, dtype))
    elif preconditioner == 'exact':
      M = np.linalg.inv(A)
    else:
      M = None

    def args_maker():
      return A, b

    self._CheckAgainstNumpy(
        partial(scipy_cg, M=M, maxiter=1),
        partial(lax_cg, M=M, maxiter=1),
        args_maker,
        check_dtypes=True,
        tol=3e-5)

    self._CheckAgainstNumpy(
        partial(scipy_cg, M=M, maxiter=3),
        partial(lax_cg, M=M, maxiter=3),
        args_maker,
        check_dtypes=True,
        tol=1e-4)

    self._CheckAgainstNumpy(
        np.linalg.solve,
        partial(lax_cg, M=M, atol=1e-6),
        args_maker,
        check_dtypes=True,
        tol=2e-4)

  @parameterized.named_parameters(jtu.cases_from_list(
      {"testcase_name":
           "_shape={}".format(jtu.format_shape_dtype_string(shape, dtype)),
       "shape": shape, "dtype": dtype, "rng_factory": rng_factory}
      for shape in [(2, 2)]
      for dtype in float_types
      for rng_factory in [jtu.rand_default]))
  def test_cg_as_solve(self, shape, dtype, rng_factory):

    rng = rng_factory()
    a = rng(shape, dtype)
    b = rng(shape[:1], dtype)

    expected = np.linalg.solve(posify(a), b)
    actual = lax_cg(posify(a), b)
    self.assertAllClose(expected, actual, check_dtypes=True)

    actual = jit(lax_cg)(posify(a), b)
    self.assertAllClose(expected, actual, check_dtypes=True)

    # numerical gradients are only well defined if ``a`` is guaranteed to be
    # positive definite.
    jtu.check_grads(
        lambda x, y: lax_cg(posify(x), y),
        (a, b), order=2, rtol=1e-2)


if __name__ == "__main__":
  absltest.main()
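Since the test module ends with ``absltest.main()``, the new suite can presumably be run directly from a checkout, e.g.:

    python tests/lax_scipy_sparse_test.py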