# Copyright 2020 The JAX Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from functools import partial
import unittest

from absl.testing import absltest
import numpy as np
import scipy.sparse.linalg

import jax
import jax.numpy as jnp
import jax.scipy.sparse.linalg
from jax import jit
from jax import lax
from jax.tree_util import register_pytree_node_class

import jax._src.scipy.sparse.linalg as sp_linalg
from jax._src import config
from jax._src import dtypes
from jax._src import test_util as jtu

config.parse_flags_with_absl()


float_types = jtu.dtypes.floating
complex_types = jtu.dtypes.complex


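# Run matmuls at HIGHEST precision so that comparisons against NumPy/SciPy
# references are not limited by the backend's default matmul precision.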
def matmul_high_precision(a, b):
  return jnp.matmul(a, b, precision=lax.Precision.HIGHEST)


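# posify(A) = A @ A.conj().T is Hermitian and positive semidefinite (almost
# surely positive definite for a random A), as CG requires.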
@jit
def posify(matrix):
  return matmul_high_precision(matrix, matrix.T.conj())


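# Thin wrapper giving the JAX and SciPy solvers a common call signature and
# discarding the returned convergence info.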
def solver(func, A, b, M=None, atol=0.0, **kwargs):
  x, _ = func(A, b, atol=atol, M=M, **kwargs)
  return x


lax_cg = partial(solver, jax.scipy.sparse.linalg.cg)
lax_gmres = partial(solver, jax.scipy.sparse.linalg.gmres)
lax_bicgstab = partial(solver, jax.scipy.sparse.linalg.bicgstab)
scipy_cg = partial(solver, scipy.sparse.linalg.cg)
scipy_gmres = partial(solver, scipy.sparse.linalg.gmres)
scipy_bicgstab = partial(solver, scipy.sparse.linalg.bicgstab)


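# Adding the identity keeps the factor I + R away from singular, so the
# Hermitian product below is (almost surely) positive definite and safe to
# invert.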
def rand_sym_pos_def(rng, shape, dtype):
  matrix = np.eye(N=shape[0], dtype=dtype) + rng(shape, dtype)
  return matrix @ matrix.T.conj()


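# Duck-typed linear operator exposing only `shape` and `__matmul__`, to
# exercise solver support for operators that are neither functions nor
# ndarrays.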
class CustomOperator:
  def __init__(self, A):
    self.A = A
    self.shape = self.A.shape

  def __matmul__(self, x):
    return self.A @ x


class LaxBackedScipyTests(jtu.JaxTestCase):

  def _fetch_preconditioner(self, preconditioner, A, rng=None):
    """Returns one of several preconditioning matrices, selected by the
    identifier `preconditioner`, each approximating the inverse of the
    input matrix A.
    """
    if preconditioner == 'identity':
      M = np.eye(A.shape[0], dtype=A.dtype)
    elif preconditioner == 'random':
      if rng is None:
        rng = jtu.rand_default(self.rng())
      M = np.linalg.inv(rand_sym_pos_def(rng, A.shape, A.dtype))
    elif preconditioner == 'exact':
      M = np.linalg.inv(A)
    else:
      M = None
    return M

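  # With maxiter fixed and the same preconditioner M, the JAX and SciPy CG
  # iterates should agree step for step; the final check compares a
  # converged solve against np.linalg.solve.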
  @jtu.sample_product(
      shape=[(4, 4), (7, 7)],
      dtype=[np.float64, np.complex128],
      preconditioner=[None, 'identity', 'exact', 'random'],
  )
  def test_cg_against_scipy(self, shape, dtype, preconditioner):
    if not config.enable_x64.value:
      raise unittest.SkipTest("requires x64 mode")

    rng = jtu.rand_default(self.rng())
    A = rand_sym_pos_def(rng, shape, dtype)
    b = rng(shape[:1], dtype)
    M = self._fetch_preconditioner(preconditioner, A, rng=rng)

    def args_maker():
      return A, b

    self._CheckAgainstNumpy(
        partial(scipy_cg, M=M, maxiter=1),
        partial(lax_cg, M=M, maxiter=1),
        args_maker,
        tol=1e-12)

    self._CheckAgainstNumpy(
        partial(scipy_cg, M=M, maxiter=3),
        partial(lax_cg, M=M, maxiter=3),
        args_maker,
        tol=1e-12)

    self._CheckAgainstNumpy(
        np.linalg.solve,
        partial(lax_cg, M=M, atol=1e-10),
        args_maker,
        tol=1e-6)

  @jtu.sample_product(
      shape=[(2, 2)],
      dtype=float_types + complex_types,
  )
  def test_cg_as_solve(self, shape, dtype):
    rng = jtu.rand_default(self.rng())
    a = rng(shape, dtype)
    b = rng(shape[:1], dtype)

    expected = np.linalg.solve(posify(a), b)
    actual = lax_cg(posify(a), b)
    self.assertAllClose(expected, actual, atol=1e-5, rtol=1e-5)

    actual = jit(lax_cg)(posify(a), b)
    self.assertAllClose(expected, actual, atol=1e-5, rtol=1e-5)

    # numerical gradients are only well defined if ``a`` is guaranteed to be
    # positive definite.
    jtu.check_grads(
        lambda x, y: lax_cg(posify(x), y),
        (a, b), order=2, rtol=2e-1)

  def test_cg_ndarray(self):
    A = lambda x: 2 * x
    b = jnp.arange(9.0).reshape((3, 3))
    expected = b / 2
    actual, _ = jax.scipy.sparse.linalg.cg(A, b)
    self.assertAllClose(expected, actual)

  def test_cg_pytree(self):
    A = lambda x: {"a": x["a"] + 0.5 * x["b"], "b": 0.5 * x["a"] + x["b"]}
    b = {"a": 1.0, "b": -4.0}
    expected = {"a": 4.0, "b": -6.0}
    actual, _ = jax.scipy.sparse.linalg.cg(A, b)
    self.assertEqual(expected.keys(), actual.keys())
    self.assertAlmostEqual(expected["a"], actual["a"], places=6)
    self.assertAlmostEqual(expected["b"], actual["b"], places=6)

  @jtu.skip_on_devices('tpu')
  def test_cg_matmul(self):
    A = CustomOperator(2 * jnp.eye(3))
    b = jnp.arange(9.0).reshape(3, 3)
    expected = b / 2
    actual, _ = jax.scipy.sparse.linalg.cg(A, b)
    self.assertAllClose(expected, actual)

  def test_cg_errors(self):
    A = lambda x: x
    b = jnp.zeros((2,))
    with self.assertRaisesRegex(
        ValueError, "x0 and b must have matching tree structure"):
      jax.scipy.sparse.linalg.cg(A, {'x': b}, {'y': b})
    with self.assertRaisesRegex(
        ValueError, "x0 and b must have matching shape"):
      jax.scipy.sparse.linalg.cg(A, b, b[:, np.newaxis])
    with self.assertRaisesRegex(ValueError, "must be a square matrix"):
      jax.scipy.sparse.linalg.cg(jnp.zeros((3, 2)), jnp.zeros((2,)))
    with self.assertRaisesRegex(
        TypeError, "linear operator must be either a function or ndarray"):
      jax.scipy.sparse.linalg.cg([[1]], jnp.zeros((1,)))

  def test_cg_without_pytree_equality(self):

    @register_pytree_node_class
    class MinimalPytree:
      def __init__(self, value):
        self.value = value
      def tree_flatten(self):
        return [self.value], None
      @classmethod
      def tree_unflatten(cls, aux_data, children):
        return cls(*children)

    A = lambda x: MinimalPytree(2 * x.value)
    b = MinimalPytree(jnp.arange(5.0))
    expected = b.value / 2
    actual, _ = jax.scipy.sparse.linalg.cg(A, b)
    self.assertAllClose(expected, actual.value)

  def test_cg_weak_types(self):
    x, _ = jax.scipy.sparse.linalg.cg(lambda x: x, 1.0)
    self.assertTrue(dtypes.is_weakly_typed(x))

  # BICGSTAB
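  # Unlike CG, BiCGSTAB does not require a Hermitian positive definite
  # operator, so these tests draw fully random (generally nonsymmetric)
  # matrices A.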
  @jtu.sample_product(
      shape=[(5, 5)],
      dtype=[np.float64, np.complex128],
      preconditioner=[None, 'identity', 'exact', 'random'],
  )
  def test_bicgstab_against_scipy(
      self, shape, dtype, preconditioner):
    if not config.enable_x64.value:
      raise unittest.SkipTest("requires x64 mode")

    rng = jtu.rand_default(self.rng())
    A = rng(shape, dtype)
    b = rng(shape[:1], dtype)
    M = self._fetch_preconditioner(preconditioner, A, rng=rng)

    def args_maker():
      return A, b

    self._CheckAgainstNumpy(
        partial(scipy_bicgstab, M=M, maxiter=1),
        partial(lax_bicgstab, M=M, maxiter=1),
        args_maker,
        tol=1e-5)

    self._CheckAgainstNumpy(
        partial(scipy_bicgstab, M=M, maxiter=2),
        partial(lax_bicgstab, M=M, maxiter=2),
        args_maker,
        tol=1e-4)

    self._CheckAgainstNumpy(
        partial(scipy_bicgstab, M=M, maxiter=1),
        partial(lax_bicgstab, M=M, maxiter=1),
        args_maker,
        tol=1e-4)

    self._CheckAgainstNumpy(
        np.linalg.solve,
        partial(lax_bicgstab, M=M, atol=1e-6),
        args_maker,
        tol=1e-4)

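  # With A = I and b = A @ solution, the solver should recover the solution
  # essentially immediately, whatever the preconditioner.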
  @jtu.sample_product(
      shape=[(2, 2), (7, 7)],
      dtype=float_types + complex_types,
      preconditioner=[None, 'identity', 'exact'],
  )
  @jtu.skip_on_devices("gpu")
  def test_bicgstab_on_identity_system(self, shape, dtype, preconditioner):
    A = jnp.eye(shape[1], dtype=dtype)
    solution = jnp.ones(shape[1], dtype=dtype)
    rng = jtu.rand_default(self.rng())
    M = self._fetch_preconditioner(preconditioner, A, rng=rng)
    b = matmul_high_precision(A, solution)
    tol = shape[0] * float(jnp.finfo(dtype).eps)
    x, info = jax.scipy.sparse.linalg.bicgstab(A, b, tol=tol, atol=tol,
                                               M=M)
    using_x64 = solution.dtype in (np.float64, np.complex128)
    solution_tol = 1e-8 if using_x64 else 1e-4
    self.assertAllClose(x, solution, atol=solution_tol, rtol=solution_tol)

  @jtu.sample_product(
      shape=[(2, 2), (4, 4)],
      dtype=float_types + complex_types,
      preconditioner=[None, 'identity', 'exact'],
  )
  @jtu.skip_on_devices("gpu")
  def test_bicgstab_on_random_system(self, shape, dtype, preconditioner):
    rng = jtu.rand_default(self.rng())
    A = rng(shape, dtype)
    solution = rng(shape[1:], dtype)
    M = self._fetch_preconditioner(preconditioner, A, rng=rng)
    b = matmul_high_precision(A, solution)
    tol = shape[0] * float(jnp.finfo(A.dtype).eps)
    x, info = jax.scipy.sparse.linalg.bicgstab(A, b, tol=tol, atol=tol, M=M)
    using_x64 = solution.dtype in (np.float64, np.complex128)
    solution_tol = 1e-8 if using_x64 else 1e-4
    self.assertAllClose(x, solution, atol=solution_tol, rtol=solution_tol)
    # solve = lambda A, b: jax.scipy.sparse.linalg.bicgstab(A, b)[0]
    # jtu.check_grads(solve, (A, b), order=1, rtol=3e-1)

  def test_bicgstab_pytree(self):
    A = lambda x: {"a": x["a"] + 0.5 * x["b"], "b": 0.5 * x["a"] + x["b"]}
    b = {"a": 1.0, "b": -4.0}
    expected = {"a": 4.0, "b": -6.0}
    actual, _ = jax.scipy.sparse.linalg.bicgstab(A, b)
    self.assertEqual(expected.keys(), actual.keys())
    self.assertAlmostEqual(expected["a"], actual["a"], places=5)
    self.assertAlmostEqual(expected["b"], actual["b"], places=5)

  def test_bicgstab_weak_types(self):
    x, _ = jax.scipy.sparse.linalg.bicgstab(lambda x: x, 1.0)
    self.assertTrue(dtypes.is_weakly_typed(x))

  @jtu.skip_on_devices('tpu')
  def test_bicgstab_matmul(self):
    A = CustomOperator(2 * jnp.eye(3))
    b = jnp.arange(9.0).reshape(3, 3)
    expected = b / 2
    actual, _ = jax.scipy.sparse.linalg.bicgstab(A, b)
    self.assertAllClose(expected, actual)

  # GMRES
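  # GMRES(restart) builds a Krylov basis of at most `restart` vectors per
  # outer iteration; 'incremental' and 'batched' are two internal
  # least-squares strategies that should produce matching results.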
  @jtu.sample_product(
      shape=[(3, 3)],
      dtype=[np.float64, np.complex128],
      preconditioner=[None, 'identity', 'exact', 'random'],
      solve_method=['incremental', 'batched'],
  )
  def test_gmres_against_scipy(
      self, shape, dtype, preconditioner, solve_method):
    if not config.enable_x64.value:
      raise unittest.SkipTest("requires x64 mode")

    rng = jtu.rand_default(self.rng())
    A = rng(shape, dtype)
    b = rng(shape[:1], dtype)
    M = self._fetch_preconditioner(preconditioner, A, rng=rng)

    def args_maker():
      return A, b

    self._CheckAgainstNumpy(
        partial(scipy_gmres, M=M, restart=1, maxiter=1),
        partial(lax_gmres, M=M, restart=1, maxiter=1, solve_method=solve_method),
        args_maker,
        tol=1e-10)

    self._CheckAgainstNumpy(
        partial(scipy_gmres, M=M, restart=1, maxiter=2),
        partial(lax_gmres, M=M, restart=1, maxiter=2, solve_method=solve_method),
        args_maker,
        tol=1e-10)

    self._CheckAgainstNumpy(
        partial(scipy_gmres, M=M, restart=2, maxiter=1),
        partial(lax_gmres, M=M, restart=2, maxiter=1, solve_method=solve_method),
        args_maker,
        tol=1e-9)

    self._CheckAgainstNumpy(
        np.linalg.solve,
        partial(lax_gmres, M=M, atol=1e-6, solve_method=solve_method),
        args_maker,
        tol=1e-8)

  @jtu.sample_product(
      shape=[(2, 2), (7, 7)],
      dtype=float_types + complex_types,
      preconditioner=[None, 'identity', 'exact'],
      solve_method=['batched', 'incremental'],
  )
  @jtu.skip_on_devices("gpu")
  def test_gmres_on_identity_system(self, shape, dtype, preconditioner,
                                    solve_method):
    A = jnp.eye(shape[1], dtype=dtype)
    solution = jnp.ones(shape[1], dtype=dtype)
    rng = jtu.rand_default(self.rng())
    M = self._fetch_preconditioner(preconditioner, A, rng=rng)
    b = matmul_high_precision(A, solution)
    restart = shape[-1]
    tol = shape[0] * float(jnp.finfo(dtype).eps)
    x, info = jax.scipy.sparse.linalg.gmres(A, b, tol=tol, atol=tol,
                                            restart=restart,
                                            M=M, solve_method=solve_method)
    using_x64 = solution.dtype in (np.float64, np.complex128)
    solution_tol = 1e-8 if using_x64 else 1e-4
    self.assertAllClose(x, solution, atol=solution_tol, rtol=solution_tol)

  @jtu.sample_product(
      shape=[(2, 2), (4, 4)],
      dtype=float_types + complex_types,
      preconditioner=[None, 'identity', 'exact'],
      solve_method=['incremental', 'batched'],
  )
  @jtu.skip_on_devices("gpu")
  def test_gmres_on_random_system(self, shape, dtype, preconditioner,
                                  solve_method):
    rng = jtu.rand_default(self.rng())
    A = rng(shape, dtype)
    solution = rng(shape[1:], dtype)
    M = self._fetch_preconditioner(preconditioner, A, rng=rng)
    b = matmul_high_precision(A, solution)
    restart = shape[-1]
    tol = shape[0] * float(jnp.finfo(A.dtype).eps)
    x, info = jax.scipy.sparse.linalg.gmres(A, b, tol=tol, atol=tol,
                                            restart=restart,
                                            M=M, solve_method=solve_method)
    using_x64 = solution.dtype in (np.float64, np.complex128)
    solution_tol = 1e-8 if using_x64 else 1e-4
    self.assertAllClose(x, solution, atol=solution_tol, rtol=solution_tol)
    # solve = lambda A, b: jax.scipy.sparse.linalg.gmres(A, b)[0]
    # jtu.check_grads(solve, (A, b), order=1, rtol=2e-1)

  def test_gmres_pytree(self):
    A = lambda x: {"a": x["a"] + 0.5 * x["b"], "b": 0.5 * x["a"] + x["b"]}
    b = {"a": 1.0, "b": -4.0}
    expected = {"a": 4.0, "b": -6.0}
    actual, _ = jax.scipy.sparse.linalg.gmres(A, b)
    self.assertEqual(expected.keys(), actual.keys())
    self.assertAlmostEqual(expected["a"], actual["a"], places=5)
    self.assertAlmostEqual(expected["b"], actual["b"], places=5)

  @jax.default_matmul_precision("float32")
  def test_gmres_matmul(self):
    A = CustomOperator(2 * jnp.eye(3))
    b = jnp.arange(9.0).reshape(3, 3)
    expected = b / 2
    actual, _ = jax.scipy.sparse.linalg.gmres(A, b)
    self.assertAllClose(expected, actual)

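  # After n Arnoldi steps with an (effectively) identity preconditioner, the
  # columns of Q are orthonormal and Q^H A Q reproduces the stored
  # (transposed) Hessenberg matrix H, up to round-off.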
  @jtu.sample_product(
      shape=[(2, 2), (3, 3)],
      dtype=float_types + complex_types,
      preconditioner=[None, 'identity'],
  )
  def test_gmres_arnoldi_step(self, shape, dtype, preconditioner):
    """Tests that the Arnoldi iteration used inside GMRES produces a correct
    decomposition.
    """
    if not config.enable_x64.value:
      raise unittest.SkipTest("requires x64 mode")

    rng = jtu.rand_default(self.rng())
    A = rng(shape, dtype)
    M = self._fetch_preconditioner(preconditioner, A, rng=rng)
    if preconditioner is None:
      M = lambda x: x
    else:
      M = partial(matmul_high_precision, M)
    n = shape[0]
    x0 = rng(shape[:1], dtype)
    Q = np.zeros((n, n + 1), dtype=dtype)
    Q[:, 0] = x0 / jnp.linalg.norm(x0).astype(dtype)
    Q = jnp.array(Q)
    H = jnp.eye(n, n + 1, dtype=dtype)

    @jax.tree_util.Partial
    def A_mv(x):
      return matmul_high_precision(A, x)
    for k in range(n):
      Q, H, _ = sp_linalg._kth_arnoldi_iteration(
          k, A_mv, M, Q, H)
    QA = matmul_high_precision(Q[:, :n].conj().T, A)
    QAQ = matmul_high_precision(QA, Q[:, :n])
    self.assertAllClose(QAQ, H.T[:n, :], rtol=2e-5, atol=2e-5)

  def test_gmres_weak_types(self):
    x, _ = jax.scipy.sparse.linalg.gmres(lambda x: x, 1.0)
    self.assertTrue(dtypes.is_weakly_typed(x))

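  # jacrev computes the Jacobian by vmapping a VJP over basis cotangents, so
  # this pushes a batch of cotangents through the transposed solve and
  # exercises the batching rule of the underlying linear-solve primitive.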
  def test_linear_solve_batching_via_jacrev(self):
    # See https://github.com/jax-ml/jax/issues/14249
    rng = np.random.RandomState(0)
    M = rng.randn(5, 5)
    A = np.dot(M, M.T)
    matvec = lambda x: (jnp.dot(A, x[0]), jnp.dot(A, x[1]))

    def f(b):
      return jax.scipy.sparse.linalg.cg(matvec, (b, b))[0]

    b = rng.randn(5)
    jax.jacrev(f)(b)  # doesn't crash


if __name__ == "__main__":
  absltest.main(testLoader=jtu.JaxTestLoader())