# Copyright 2020 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
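
"""Tests for jax.scipy.sparse.linalg (CG, BiCGSTAB, and GMRES solvers)."""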

from functools import partial
import unittest

from absl.testing import parameterized
from absl.testing import absltest
import numpy as np
import scipy.sparse.linalg

import jax
from jax import jit
import jax.numpy as jnp
from jax import lax
from jax import test_util as jtu
from jax.tree_util import register_pytree_node_class
import jax.scipy.sparse.linalg
import jax._src.scipy.sparse.linalg

from jax.config import config
config.parse_flags_with_absl()


float_types = jtu.dtypes.floating
complex_types = jtu.dtypes.complex


def matmul_high_precision(a, b):
  # Request the highest-precision matmul so accelerator results are directly
  # comparable with the NumPy/SciPy references used below.
  return jnp.matmul(a, b, precision=lax.Precision.HIGHEST)


@jit
def posify(matrix):
  # matrix @ matrix^H is Hermitian positive semi-definite, and positive
  # definite with probability one for the random matrices used here.
  return matmul_high_precision(matrix, matrix.T.conj())


def solver(func, A, b, M=None, atol=0.0, **kwargs):
  # Wrap a SciPy- or JAX-style solver so both take the same arguments and
  # return only the solution, discarding the convergence info.
  x, _ = func(A, b, atol=atol, M=M, **kwargs)
  return x


lax_cg = partial(solver, jax.scipy.sparse.linalg.cg)
lax_gmres = partial(solver, jax.scipy.sparse.linalg.gmres)
lax_bicgstab = partial(solver, jax.scipy.sparse.linalg.bicgstab)
scipy_cg = partial(solver, scipy.sparse.linalg.cg)
scipy_gmres = partial(solver, scipy.sparse.linalg.gmres)
scipy_bicgstab = partial(solver, scipy.sparse.linalg.bicgstab)


def rand_sym_pos_def(rng, shape, dtype):
  # The identity shift keeps the random factor away from singularity, so the
  # product below is a well-conditioned Hermitian positive definite matrix.
  matrix = np.eye(N=shape[0], dtype=dtype) + rng(shape, dtype)
  return matrix @ matrix.T.conj()
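

# A minimal usage sketch of the wrappers above (illustrative values only; not
# executed by the test runner):
#
#   rng = jtu.rand_default(np.random.RandomState(0))
#   A = rand_sym_pos_def(rng, (4, 4), np.float64)
#   b = np.ones(4, np.float64)
#   x_jax = lax_cg(A, b)       # JAX solution
#   x_scipy = scipy_cg(A, b)   # SciPy solution, same call signature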


class LaxBackedScipyTests(jtu.JaxTestCase):

  def _fetch_preconditioner(self, preconditioner, A, rng=None):
    """Returns a preconditioning matrix for the identifier `preconditioner`.

    The returned matrix approximates (or, for 'exact', equals) the inverse of
    the input matrix A.
    """
    if preconditioner == 'identity':
      M = np.eye(A.shape[0], dtype=A.dtype)
    elif preconditioner == 'random':
      # Inverse of a random SPD matrix, so M is itself SPD as CG requires.
      if rng is None:
        rng = jtu.rand_default(self.rng())
      M = np.linalg.inv(rand_sym_pos_def(rng, A.shape, A.dtype))
    elif preconditioner == 'exact':
      M = np.linalg.inv(A)
    else:
      M = None
    return M
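
  # The tests below compare each JAX solver iterate-for-iterate against its
  # SciPy counterpart at small fixed iteration counts, then check the
  # converged solution against a dense direct solve.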

  @parameterized.named_parameters(jtu.cases_from_list(
      {"testcase_name":
       "_shape={}_preconditioner={}".format(
           jtu.format_shape_dtype_string(shape, dtype),
           preconditioner),
       "shape": shape, "dtype": dtype, "preconditioner": preconditioner}
      for shape in [(4, 4), (7, 7)]
      for dtype in [np.float64, np.complex128]
      for preconditioner in [None, 'identity', 'exact', 'random']))
  def test_cg_against_scipy(self, shape, dtype, preconditioner):
    if not config.x64_enabled:
      raise unittest.SkipTest("requires x64 mode")

    rng = jtu.rand_default(self.rng())
    A = rand_sym_pos_def(rng, shape, dtype)
    b = rng(shape[:1], dtype)
    M = self._fetch_preconditioner(preconditioner, A, rng=rng)

    def args_maker():
      return A, b

    self._CheckAgainstNumpy(
        partial(scipy_cg, M=M, maxiter=1),
        partial(lax_cg, M=M, maxiter=1),
        args_maker,
        tol=1e-12)

    self._CheckAgainstNumpy(
        partial(scipy_cg, M=M, maxiter=3),
        partial(lax_cg, M=M, maxiter=3),
        args_maker,
        tol=1e-12)

    self._CheckAgainstNumpy(
        np.linalg.solve,
        partial(lax_cg, M=M, atol=1e-10),
        args_maker,
        tol=1e-6)

  @parameterized.named_parameters(jtu.cases_from_list(
      {"testcase_name":
       "_shape={}".format(jtu.format_shape_dtype_string(shape, dtype)),
       "shape": shape, "dtype": dtype}
      for shape in [(2, 2)]
      for dtype in float_types + complex_types))
  def test_cg_as_solve(self, shape, dtype):
    rng = jtu.rand_default(self.rng())
    a = rng(shape, dtype)
    b = rng(shape[:1], dtype)

    expected = np.linalg.solve(posify(a), b)
    actual = lax_cg(posify(a), b)
    self.assertAllClose(expected, actual, atol=1e-5, rtol=1e-5)

    actual = jit(lax_cg)(posify(a), b)
    self.assertAllClose(expected, actual, atol=1e-5, rtol=1e-5)

    # Numerical gradients are only well defined if ``a`` is guaranteed to be
    # positive definite.
    jtu.check_grads(
        lambda x, y: lax_cg(posify(x), y),
        (a, b), order=2, rtol=2e-1)

  def test_cg_ndarray(self):
    A = lambda x: 2 * x
    b = jnp.arange(9.0).reshape((3, 3))
    expected = b / 2
    actual, _ = jax.scipy.sparse.linalg.cg(A, b)
    self.assertAllClose(expected, actual)

  def test_cg_pytree(self):
    # The linear operator corresponds to the symmetric positive definite
    # matrix [[1, 0.5], [0.5, 1]] acting on the pytree {"a", "b"}.
    A = lambda x: {"a": x["a"] + 0.5 * x["b"], "b": 0.5 * x["a"] + x["b"]}
    b = {"a": 1.0, "b": -4.0}
    expected = {"a": 4.0, "b": -6.0}
    actual, _ = jax.scipy.sparse.linalg.cg(A, b)
    self.assertEqual(expected.keys(), actual.keys())
    self.assertAlmostEqual(expected["a"], actual["a"], places=6)
    self.assertAlmostEqual(expected["b"], actual["b"], places=6)

  def test_cg_errors(self):
    A = lambda x: x
    b = jnp.zeros((2,))
    with self.assertRaisesRegex(
        ValueError, "x0 and b must have matching tree structure"):
      jax.scipy.sparse.linalg.cg(A, {'x': b}, {'y': b})
    with self.assertRaisesRegex(
        ValueError, "x0 and b must have matching shape"):
      jax.scipy.sparse.linalg.cg(A, b, b[:, np.newaxis])
    with self.assertRaisesRegex(ValueError, "must be a square matrix"):
      jax.scipy.sparse.linalg.cg(jnp.zeros((3, 2)), jnp.zeros((2,)))
    with self.assertRaisesRegex(
        TypeError, "linear operator must be either a function or ndarray"):
      jax.scipy.sparse.linalg.cg([[1]], jnp.zeros((1,)))

  def test_cg_without_pytree_equality(self):

    @register_pytree_node_class
    class MinimalPytree:
      def __init__(self, value):
        self.value = value
      def tree_flatten(self):
        return [self.value], None
      @classmethod
      def tree_unflatten(cls, aux_data, children):
        return cls(*children)

    A = lambda x: MinimalPytree(2 * x.value)
    b = MinimalPytree(jnp.arange(5.0))
    expected = b.value / 2
    actual, _ = jax.scipy.sparse.linalg.cg(A, b)
    self.assertAllClose(expected, actual.value)

  # BICGSTAB
  @parameterized.named_parameters(jtu.cases_from_list(
      {"testcase_name":
       "_shape={}_preconditioner={}".format(
           jtu.format_shape_dtype_string(shape, dtype),
           preconditioner),
       "shape": shape, "dtype": dtype, "preconditioner": preconditioner}
      for shape in [(5, 5)]
      for dtype in [np.float64, np.complex128]
      for preconditioner in [None, 'identity', 'exact', 'random']))
  def test_bicgstab_against_scipy(self, shape, dtype, preconditioner):
    if not config.x64_enabled:
      raise unittest.SkipTest("requires x64 mode")

    rng = jtu.rand_default(self.rng())
    A = rng(shape, dtype)
    b = rng(shape[:1], dtype)
    M = self._fetch_preconditioner(preconditioner, A, rng=rng)

    def args_maker():
      return A, b

    self._CheckAgainstNumpy(
        partial(scipy_bicgstab, M=M, maxiter=1),
        partial(lax_bicgstab, M=M, maxiter=1),
        args_maker,
        tol=1e-5)

    self._CheckAgainstNumpy(
        partial(scipy_bicgstab, M=M, maxiter=2),
        partial(lax_bicgstab, M=M, maxiter=2),
        args_maker,
        tol=1e-4)

    self._CheckAgainstNumpy(
        partial(scipy_bicgstab, M=M, maxiter=1),
        partial(lax_bicgstab, M=M, maxiter=1),
        args_maker,
        tol=1e-4)

    self._CheckAgainstNumpy(
        np.linalg.solve,
        partial(lax_bicgstab, M=M, atol=1e-6),
        args_maker,
        tol=1e-4)

  @parameterized.named_parameters(jtu.cases_from_list(
      {"testcase_name":
       "_shape={}_preconditioner={}".format(
           jtu.format_shape_dtype_string(shape, dtype),
           preconditioner),
       "shape": shape, "dtype": dtype, "preconditioner": preconditioner}
      for shape in [(2, 2), (7, 7)]
      for dtype in float_types + complex_types
      for preconditioner in [None, 'identity', 'exact']))
  @jtu.skip_on_devices("gpu")
  def test_bicgstab_on_identity_system(self, shape, dtype, preconditioner):
    A = jnp.eye(shape[1], dtype=dtype)
    solution = jnp.ones(shape[1], dtype=dtype)
    rng = jtu.rand_default(self.rng())
    M = self._fetch_preconditioner(preconditioner, A, rng=rng)
    b = matmul_high_precision(A, solution)
    tol = shape[0] * jnp.finfo(dtype).eps
    x, info = jax.scipy.sparse.linalg.bicgstab(A, b, tol=tol, atol=tol, M=M)
    # `dtype.kind` is a one-character code ('f', 'c', ...), so testing it for
    # membership in a set of NumPy scalar types is always False; compare the
    # dtype itself against a tuple (set membership hashes differently).
    using_x64 = solution.dtype in (np.float64, np.complex128)
    solution_tol = 1e-8 if using_x64 else 1e-4
    self.assertAllClose(x, solution, atol=solution_tol, rtol=solution_tol)

  @parameterized.named_parameters(jtu.cases_from_list(
      {"testcase_name":
       "_shape={}_preconditioner={}".format(
           jtu.format_shape_dtype_string(shape, dtype),
           preconditioner),
       "shape": shape, "dtype": dtype, "preconditioner": preconditioner}
      for shape in [(2, 2), (4, 4)]
      for dtype in float_types + complex_types
      for preconditioner in [None, 'identity', 'exact']))
  @jtu.skip_on_devices("gpu")
  def test_bicgstab_on_random_system(self, shape, dtype, preconditioner):
    rng = jtu.rand_default(self.rng())
    A = rng(shape, dtype)
    solution = rng(shape[1:], dtype)
    M = self._fetch_preconditioner(preconditioner, A, rng=rng)
    b = matmul_high_precision(A, solution)
    tol = shape[0] * jnp.finfo(A.dtype).eps
    x, info = jax.scipy.sparse.linalg.bicgstab(A, b, tol=tol, atol=tol, M=M)
    using_x64 = solution.dtype in (np.float64, np.complex128)
    solution_tol = 1e-8 if using_x64 else 1e-4
    self.assertAllClose(x, solution, atol=solution_tol, rtol=solution_tol)
    # solve = lambda A, b: jax.scipy.sparse.linalg.bicgstab(A, b)[0]
    # jtu.check_grads(solve, (A, b), order=1, rtol=3e-1)

  def test_bicgstab_pytree(self):
    A = lambda x: {"a": x["a"] + 0.5 * x["b"], "b": 0.5 * x["a"] + x["b"]}
    b = {"a": 1.0, "b": -4.0}
    expected = {"a": 4.0, "b": -6.0}
    actual, _ = jax.scipy.sparse.linalg.bicgstab(A, b)
    self.assertEqual(expected.keys(), actual.keys())
    self.assertAlmostEqual(expected["a"], actual["a"], places=5)
    self.assertAlmostEqual(expected["b"], actual["b"], places=5)
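
  # The GMRES tests below also vary `solve_method`: 'incremental' updates the
  # least-squares solution as the Krylov basis grows, while 'batched' solves
  # the least-squares problem once per restart cycle (see the
  # jax.scipy.sparse.linalg.gmres docstring).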

  # GMRES
  @parameterized.named_parameters(jtu.cases_from_list(
      {"testcase_name":
       "_shape={}_preconditioner={}_solve_method={}".format(
           jtu.format_shape_dtype_string(shape, dtype),
           preconditioner,
           solve_method),
       "shape": shape, "dtype": dtype, "preconditioner": preconditioner,
       "solve_method": solve_method}
      for shape in [(3, 3)]
      for dtype in [np.float64, np.complex128]
      for preconditioner in [None, 'identity', 'exact', 'random']
      for solve_method in ['incremental', 'batched']))
  # TODO(b/186133663): test fails on CPU after LLVM change.
  @jtu.skip_on_devices("cpu")
  def test_gmres_against_scipy(
      self, shape, dtype, preconditioner, solve_method):
    if not config.x64_enabled:
      raise unittest.SkipTest("requires x64 mode")

    rng = jtu.rand_default(self.rng())
    A = rng(shape, dtype)
    b = rng(shape[:1], dtype)
    M = self._fetch_preconditioner(preconditioner, A, rng=rng)

    def args_maker():
      return A, b

    self._CheckAgainstNumpy(
        partial(scipy_gmres, M=M, restart=1, maxiter=1),
        partial(lax_gmres, M=M, restart=1, maxiter=1,
                solve_method=solve_method),
        args_maker,
        tol=1e-10)

    self._CheckAgainstNumpy(
        partial(scipy_gmres, M=M, restart=1, maxiter=2),
        partial(lax_gmres, M=M, restart=1, maxiter=2,
                solve_method=solve_method),
        args_maker,
        tol=1e-10)

    self._CheckAgainstNumpy(
        partial(scipy_gmres, M=M, restart=2, maxiter=1),
        partial(lax_gmres, M=M, restart=2, maxiter=1,
                solve_method=solve_method),
        args_maker,
        tol=1e-10)

    self._CheckAgainstNumpy(
        np.linalg.solve,
        partial(lax_gmres, M=M, atol=1e-6, solve_method=solve_method),
        args_maker,
        tol=1e-10)

  @parameterized.named_parameters(jtu.cases_from_list(
      {"testcase_name":
       "_shape={}_preconditioner={}_solve_method={}".format(
           jtu.format_shape_dtype_string(shape, dtype),
           preconditioner,
           solve_method),
       "shape": shape, "dtype": dtype, "preconditioner": preconditioner,
       "solve_method": solve_method}
      for shape in [(2, 2), (7, 7)]
      for dtype in float_types + complex_types
      for preconditioner in [None, 'identity', 'exact']
      for solve_method in ['batched', 'incremental']))
  @jtu.skip_on_devices("gpu")
  def test_gmres_on_identity_system(self, shape, dtype, preconditioner,
                                    solve_method):
    A = jnp.eye(shape[1], dtype=dtype)
    solution = jnp.ones(shape[1], dtype=dtype)
    rng = jtu.rand_default(self.rng())
    M = self._fetch_preconditioner(preconditioner, A, rng=rng)
    b = matmul_high_precision(A, solution)
    restart = shape[-1]
    tol = shape[0] * jnp.finfo(dtype).eps
    x, info = jax.scipy.sparse.linalg.gmres(A, b, tol=tol, atol=tol,
                                            restart=restart,
                                            M=M, solve_method=solve_method)
    using_x64 = solution.dtype in (np.float64, np.complex128)
    solution_tol = 1e-8 if using_x64 else 1e-4
    self.assertAllClose(x, solution, atol=solution_tol, rtol=solution_tol)

  @parameterized.named_parameters(jtu.cases_from_list(
      {"testcase_name":
       "_shape={}_preconditioner={}_solve_method={}".format(
           jtu.format_shape_dtype_string(shape, dtype),
           preconditioner,
           solve_method),
       "shape": shape, "dtype": dtype, "preconditioner": preconditioner,
       "solve_method": solve_method}
      for shape in [(2, 2), (4, 4)]
      for dtype in float_types + complex_types
      for preconditioner in [None, 'identity', 'exact']
      for solve_method in ['incremental', 'batched']))
  @jtu.skip_on_devices("gpu")
  def test_gmres_on_random_system(self, shape, dtype, preconditioner,
                                  solve_method):
    rng = jtu.rand_default(self.rng())
    A = rng(shape, dtype)
    solution = rng(shape[1:], dtype)
    M = self._fetch_preconditioner(preconditioner, A, rng=rng)
    b = matmul_high_precision(A, solution)
    restart = shape[-1]
    tol = shape[0] * jnp.finfo(A.dtype).eps
    x, info = jax.scipy.sparse.linalg.gmres(A, b, tol=tol, atol=tol,
                                            restart=restart,
                                            M=M, solve_method=solve_method)
    using_x64 = solution.dtype in (np.float64, np.complex128)
    solution_tol = 1e-8 if using_x64 else 1e-4
    self.assertAllClose(x, solution, atol=solution_tol, rtol=solution_tol)
    # solve = lambda A, b: jax.scipy.sparse.linalg.gmres(A, b)[0]
    # jtu.check_grads(solve, (A, b), order=1, rtol=2e-1)

  def test_gmres_pytree(self):
    A = lambda x: {"a": x["a"] + 0.5 * x["b"], "b": 0.5 * x["a"] + x["b"]}
    b = {"a": 1.0, "b": -4.0}
    expected = {"a": 4.0, "b": -6.0}
    actual, _ = jax.scipy.sparse.linalg.gmres(A, b)
    self.assertEqual(expected.keys(), actual.keys())
    self.assertAlmostEqual(expected["a"], actual["a"], places=5)
    self.assertAlmostEqual(expected["b"], actual["b"], places=5)

  @parameterized.named_parameters(jtu.cases_from_list(
      {"testcase_name":
       "_shape={}_preconditioner={}".format(
           jtu.format_shape_dtype_string(shape, dtype),
           preconditioner),
       "shape": shape, "dtype": dtype, "preconditioner": preconditioner}
      for shape in [(2, 2), (3, 3)]
      for dtype in float_types + complex_types
      for preconditioner in [None, 'identity']))
  def test_gmres_arnoldi_step(self, shape, dtype, preconditioner):
    """Tests that the Arnoldi iteration used within GMRES is correct."""
    if not config.x64_enabled:
      raise unittest.SkipTest("requires x64 mode")

    rng = jtu.rand_default(self.rng())
    A = rng(shape, dtype)
    M = self._fetch_preconditioner(preconditioner, A, rng=rng)
    if preconditioner is None:
      M = lambda x: x
    else:
      M = partial(matmul_high_precision, M)
    n = shape[0]
    x0 = rng(shape[:1], dtype)
    Q = np.zeros((n, n + 1), dtype=dtype)
    Q[:, 0] = x0 / jnp.linalg.norm(x0)
    Q = jnp.array(Q)
    H = jnp.eye(n, n + 1, dtype=dtype)

    @jax.tree_util.Partial
    def A_mv(x):
      return matmul_high_precision(A, x)

    for k in range(n):
      Q, H, _ = jax._src.scipy.sparse.linalg._kth_arnoldi_iteration(
          k, A_mv, M, Q, H)
    # After n steps the columns of Q[:, :n] are orthonormal, and the projected
    # operator Q^H A Q must match the Hessenberg matrix stored (transposed)
    # in H.
    QA = matmul_high_precision(Q[:, :n].conj().T, A)
    QAQ = matmul_high_precision(QA, Q[:, :n])
    self.assertAllClose(QAQ, H.T[:n, :], rtol=1e-5, atol=1e-5)
if __name__ == "__main__":
|
|
absltest.main(testLoader=jtu.JaxTestLoader())
|