# Copyright 2018 The JAX Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from functools import partial
import re
import unittest

from absl.testing import absltest
from absl.testing import parameterized

import numpy as np

import jax
from jax import lax
from jax.ad_checkpoint import checkpoint
from jax._src import test_util as jtu
import jax.numpy as jnp
import jax.scipy as jsp

jax.config.parse_flags_with_absl()


def high_precision_dot(a, b):
  return lax.dot(a, b, precision=lax.Precision.HIGHEST)


def posify(matrix):
  return high_precision_dot(matrix, matrix.T.conj())


# Simple optimization routine for testing custom_root
def binary_search(func, x0, low=0.0, high=100.0):
  del x0  # unused

  def cond(state):
    low, high = state
    midpoint = 0.5 * (low + high)
    return (low < midpoint) & (midpoint < high)

  def body(state):
    low, high = state
    midpoint = 0.5 * (low + high)
    update_upper = func(midpoint) > 0
    low = jnp.where(update_upper, low, midpoint)
    high = jnp.where(update_upper, midpoint, high)
    return (low, high)

  solution, _ = lax.while_loop(cond, body, (low, high))
  return solution


# Optimization routine for testing custom_root.
def newton_raphson(func, x0):
  tol = 1e-16
  max_it = 20

  fx0, dfx0 = func(x0), jax.jacobian(func)(x0)
  initial_state = (0, x0, fx0, dfx0)  # (iteration, x, f(x), grad(f)(x))

  def cond(state):
    it, _, fx, _ = state
    return (jnp.max(jnp.abs(fx)) > tol) & (it < max_it)

  def body(state):
    it, x, fx, dfx = state
    step = jnp.linalg.solve(
        dfx.reshape((-1, fx.size)), fx.ravel()
    ).reshape(fx.shape)
    x_next = x - step
    fx, dfx = func(x_next), jax.jacobian(func)(x_next)
    return (it + 1, x_next, fx, dfx)

  _, x, _, _ = lax.while_loop(cond, body, initial_state)

  return x


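# lax.custom_linear_solve(matvec, b, solve, transpose_solve=None,
# symmetric=False, has_aux=False) wraps a linear solve x = A^-1 b in which A is
# only available through the matrix-vector product `matvec`. Gradients flow
# through the implicit-differentiation rule rather than through the solver's
# internals; roughly (a sketch, not the exact implementation):
#
#   x  = solve(matvec, b)                  # primal solve
#   dx = solve(matvec, db - d(matvec)(x))  # JVP reuses the same solve
#
# Reverse mode solves the transposed system, which is why the tests below
# either pass a transpose_solve or set symmetric=True so that `solve` can be
# reused for it.

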
class CustomLinearSolveTest(jtu.JaxTestCase):

  @parameterized.named_parameters(
      {"testcase_name": "nonsymmetric", "symmetric": False},
      {"testcase_name": "symmetric", "symmetric": True},
  )
  @jtu.skip_on_flag("jax_skip_slow_tests", True)
  def test_custom_linear_solve(self, symmetric):

    def explicit_jacobian_solve(matvec, b):
      # Since matvec is linear, jax.jacobian(matvec)(b) materializes the full
      # matrix A, so this "solver" is just a dense solve. stop_gradient keeps
      # gradients from flowing through the solver's internals, which is exactly
      # what custom_linear_solve is meant to avoid.
      return lax.stop_gradient(jnp.linalg.solve(jax.jacobian(matvec)(b), b))

    def matrix_free_solve(matvec, b):
      return lax.custom_linear_solve(
          matvec, b, explicit_jacobian_solve, explicit_jacobian_solve,
          symmetric=symmetric)

    def linear_solve(a, b):
      return matrix_free_solve(partial(high_precision_dot, a), b)

    rng = self.rng()
    a = rng.randn(3, 3)
    if symmetric:
      a = a + a.T
    b = rng.randn(3)
    jtu.check_grads(linear_solve, (a, b), order=2, rtol=3e-3)

    expected = jnp.linalg.solve(a, b)
    actual = jax.jit(linear_solve)(a, b)
    self.assertAllClose(expected, actual)

    c = rng.randn(3, 2)
    expected = jnp.linalg.solve(a, c)
    actual = jax.vmap(linear_solve, (None, 1), 1)(a, c)
    self.assertAllClose(expected, actual)

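  # With has_aux=True, `solve` returns a (solution, aux) pair and
  # custom_linear_solve returns the aux alongside the solution; under vmap the
  # aux values are batched as well (hence the np.repeat on expected_aux below).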
  @jtu.skip_on_flag("jax_skip_slow_tests", True)
  def test_custom_linear_solve_aux(self):
    def explicit_jacobian_solve_aux(matvec, b):
      x = lax.stop_gradient(jnp.linalg.solve(jax.jacobian(matvec)(b), b))
      return x, array_aux

    def matrix_free_solve_aux(matvec, b):
      return lax.custom_linear_solve(
          matvec, b, explicit_jacobian_solve_aux, explicit_jacobian_solve_aux,
          symmetric=True, has_aux=True)

    def linear_solve_aux(a, b):
      return matrix_free_solve_aux(partial(high_precision_dot, a), b)

    # array aux values, to be able to use jtu.check_grads
    array_aux = {"converged": np.array(1.), "nfev": np.array(12345.)}
    rng = self.rng()
    a = rng.randn(3, 3)
    a = a + a.T
    b = rng.randn(3)

    expected = jnp.linalg.solve(a, b)
    actual_nojit, nojit_aux = linear_solve_aux(a, b)
    actual_jit, jit_aux = jax.jit(linear_solve_aux)(a, b)

    self.assertAllClose(expected, actual_nojit)
    self.assertAllClose(expected, actual_jit)
    # scalar dict equality check
    self.assertDictEqual(nojit_aux, array_aux)
    self.assertDictEqual(jit_aux, array_aux)

    # jvp / vjp test
    jtu.check_grads(linear_solve_aux, (a, b), order=2, rtol=4e-3)

    # vmap test
    c = rng.randn(3, 2)
    expected = jnp.linalg.solve(a, c)
    expected_aux = jax.tree.map(partial(np.repeat, repeats=2), array_aux)
    actual_vmap, vmap_aux = jax.vmap(linear_solve_aux, (None, 1), -1)(a, c)

    self.assertAllClose(expected, actual_vmap)
    jtu.check_eq(expected_aux, vmap_aux)

  @jtu.skip_on_flag("jax_skip_slow_tests", True)
  @unittest.skip("Test is too slow (> 1 minute at time of writing)")
  def test_custom_linear_solve_zeros(self):
    def explicit_jacobian_solve(matvec, b):
      return lax.stop_gradient(jnp.linalg.solve(jax.jacobian(matvec)(b), b))

    def matrix_free_solve(matvec, b):
      return lax.custom_linear_solve(matvec, b, explicit_jacobian_solve,
                                     explicit_jacobian_solve)

    def linear_solve(a, b):
      return matrix_free_solve(partial(high_precision_dot, a), b)

    rng = self.rng()
    a = rng.randn(3, 3)
    b = rng.randn(3)
    jtu.check_grads(lambda x: linear_solve(x, b), (a,), order=2,
                    rtol={np.float32: 5e-3})
    jtu.check_grads(lambda x: linear_solve(a, x), (b,), order=2,
                    rtol={np.float32: 5e-3})

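  # The test below uses Richardson iteration,
  #   x_{k+1} = x_k + omega * (b - A @ x_k),
  # which only converges when the spectral radius of (I - omega * A) is below
  # one; hence the fixed, known-good random seed for the inputs.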
  @jtu.skip_on_flag("jax_skip_slow_tests", True)
  def test_custom_linear_solve_iterative(self):

    def richardson_iteration(matvec, b, omega=0.1, tolerance=1e-6):
      # Equivalent to vanilla gradient descent:
      # https://en.wikipedia.org/wiki/Modified_Richardson_iteration
      def cond(x):
        return jnp.linalg.norm(matvec(x) - b) > tolerance
      def body(x):
        return x + omega * (b - matvec(x))
      return lax.while_loop(cond, body, b)

    def matrix_free_solve(matvec, b):
      return lax.custom_linear_solve(matvec, b, richardson_iteration,
                                     richardson_iteration)

    def build_and_solve(a, b):
      # intentionally non-linear in a and b
      matvec = partial(high_precision_dot, jnp.exp(a))
      return matrix_free_solve(matvec, jnp.cos(b))

    # rng = self.rng()
    # This test is very sensitive to the inputs, so we use a known working seed.
    rng = np.random.RandomState(0)
    a = rng.randn(2, 2)
    b = rng.randn(2)
    expected = jnp.linalg.solve(jnp.exp(a), jnp.cos(b))
    actual = build_and_solve(a, b)
    self.assertAllClose(expected, actual, atol=1e-5)
    jtu.check_grads(build_and_solve, (a, b), atol=1e-5, order=2,
                    rtol={jnp.float32: 6e-2, jnp.float64: 2e-3})

    # vmap across an empty dimension
    jtu.check_grads(
        jax.vmap(build_and_solve), (a[None, :, :], b[None, :]),
        atol=1e-5,
        order=2,
        rtol={jnp.float32: 6e-2, jnp.float64: 2e-3})

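  # Only a forward `solve` is supplied below; because symmetric=True, the same
  # Cholesky-based solve can be reused for the transposed system in reverse
  # mode.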
  def test_custom_linear_solve_cholesky(self):

    def positive_definite_solve(a, b):
      factors = jsp.linalg.cho_factor(a)
      def solve(matvec, x):
        return jsp.linalg.cho_solve(factors, x)
      matvec = partial(high_precision_dot, a)
      return lax.custom_linear_solve(matvec, b, solve, symmetric=True)

    rng = self.rng()
    a = rng.randn(2, 2)
    b = rng.randn(2)

    tol = {np.float32: 1E-3 if jtu.test_device_matches(["tpu"]) else 2E-4,
           np.float64: 1E-12}
    expected = jnp.linalg.solve(np.asarray(posify(a)), b)
    actual = positive_definite_solve(posify(a), b)
    self.assertAllClose(expected, actual, rtol=tol, atol=tol)

    actual = jax.jit(positive_definite_solve)(posify(a), b)
    self.assertAllClose(expected, actual, rtol=tol, atol=tol)

    # numerical gradients are only well defined if ``a`` is guaranteed to be
    # positive definite.
    jtu.check_grads(
        lambda x, y: positive_definite_solve(posify(x), y),
        (a, b), order=2, rtol=0.3)

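  # For complex inputs, the transposed system is the plain (non-conjugate)
  # transpose, so tr_solve below solves against a.T rather than a.conj().T.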
  def test_custom_linear_solve_complex(self):

    def solve(a, b):
      def solve(matvec, x):
        return jsp.linalg.solve(a, x)
      def tr_solve(matvec, x):
        return jsp.linalg.solve(a.T, x)
      matvec = partial(high_precision_dot, a)
      return lax.custom_linear_solve(matvec, b, solve, tr_solve)

    rng = self.rng()
    a = 0.5 * rng.randn(2, 2) + 0.5j * rng.randn(2, 2)
    b = 0.5 * rng.randn(2) + 0.5j * rng.randn(2)
    jtu.check_grads(solve, (a, b), order=2, rtol=1e-2)

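  # The transpose_solve callback receives the transposed operator (vecmat), so
  # the test below factors and solves against a.T for it.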
  @jtu.skip_on_flag("jax_skip_slow_tests", True)
  def test_custom_linear_solve_lu(self):

    def linear_solve(a, b):
      a_factors = jsp.linalg.lu_factor(a)
      at_factors = jsp.linalg.lu_factor(a.T)
      def solve(matvec, x):
        return jsp.linalg.lu_solve(a_factors, x)
      def transpose_solve(vecmat, x):
        return jsp.linalg.lu_solve(at_factors, x)
      return lax.custom_linear_solve(
          partial(high_precision_dot, a), b, solve, transpose_solve)

    rng = self.rng()
    a = rng.randn(3, 3)
    b = rng.randn(3)

    expected = jnp.linalg.solve(a, b)
    actual = linear_solve(a, b)
    self.assertAllClose(expected, actual)

    jtu.check_grads(linear_solve, (a, b), order=2, rtol=2e-3)

    # regression test for https://github.com/jax-ml/jax/issues/1536
    jtu.check_grads(jax.jit(linear_solve), (a, b), order=2,
                    rtol={np.float32: 2e-3})

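  # Without a transpose_solve, only forward-mode differentiation is supported;
  # reverse mode raises a TypeError (checked at the end of the test below).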
  @jtu.skip_on_flag("jax_skip_slow_tests", True)
  def test_custom_linear_solve_without_transpose_solve(self):

    def explicit_jacobian_solve(matvec, b):
      return lax.stop_gradient(jnp.linalg.solve(jax.jacobian(matvec)(b), b))

    def loss(a, b):
      matvec = partial(high_precision_dot, a)
      x = lax.custom_linear_solve(matvec, b, explicit_jacobian_solve)
      return jnp.sum(x)

    rng = self.rng()
    a = rng.randn(2, 2)
    b = rng.randn(2)

    jtu.check_grads(loss, (a, b), order=2, modes=['fwd'],
                    atol={np.float32: 2e-3, np.float64: 1e-11})
    jtu.check_grads(jax.vmap(loss), (a[None,:,:], b[None,:]), order=2,
                    modes=['fwd'], atol={np.float32: 2e-3, np.float64: 1e-11})

    with self.assertRaisesRegex(TypeError, "transpose_solve required"):
      jax.grad(loss)(a, b)

  @jtu.skip_on_flag("jax_skip_slow_tests", True)
  @unittest.skip("Test is too slow (> 2 minutes at time of writing)")
  def test_custom_linear_solve_pytree(self):
    """Test custom linear solve with inputs and outputs that are pytrees."""

    def unrolled_matvec(mat, x):
      """Apply a Python list of lists of scalars to a list of scalars."""
      result = []
      for i in range(len(mat)):
        v = 0
        for j in range(len(x)):
          if mat[i][j] is not None:
            v += mat[i][j] * x[j]
        result.append(v)
      return result

    def unrolled_substitution_solve(matvec, b, lower_tri):
      """Solve a triangular unrolled system with fwd/back substitution."""
      zero = jnp.zeros(())
      one = jnp.ones(())
      x = [zero for _ in b]
      ordering = range(len(b)) if lower_tri else range(len(b) - 1, -1, -1)
      for i in ordering:
        residual = b[i] - matvec(x)[i]
        diagonal = matvec([one if i == j else zero for j in range(len(b))])[i]
        x[i] = residual / diagonal
      return x

    def custom_unrolled_lower_tri_solve(mat, b):
      return lax.custom_linear_solve(
          partial(unrolled_matvec, mat), b,
          partial(unrolled_substitution_solve, lower_tri=True),
          partial(unrolled_substitution_solve, lower_tri=False))

    mat = [[1.0, None, None, None, None, None, None],
           [1.0, 1.0, None, None, None, None, None],
           [None, 1.0, 1.0, None, None, None, None],
           [None, None, 1.0, 1.0, None, None, None],
           [None, None, None, 1.0, 1.0, None, None],
           [None, None, None, None, None, 2.0, None],
           [None, None, None, None, None, 4.0, 3.0]]

    rng = self.rng()
    b = list(rng.randn(7))

    # Non-batched
    jtu.check_grads(custom_unrolled_lower_tri_solve, (mat, b), order=2,
                    rtol={jnp.float32: 2e-2})

    # Batch one element of b (which, because of unrolling, should only affect
    # the first block of outputs)
    b_bat = list(b)
    b_bat[3] = rng.randn(3)
    jtu.check_grads(
        jax.vmap(
            custom_unrolled_lower_tri_solve,
            in_axes=(None, [None, None, None, 0, None, None, None]),
            out_axes=[0, 0, 0, 0, 0, None, None]), (mat, b_bat),
        order=2,
        rtol={jnp.float32: 1e-2})

    # Batch one element of mat (again only affecting first block)
    mat[2][1] = rng.randn(3)
    mat_axis_tree = [
        [0 if i == 2 and j == 1 else None for j in range(7)] for i in range(7)
    ]
    jtu.check_grads(
        jax.vmap(
            custom_unrolled_lower_tri_solve,
            in_axes=(mat_axis_tree, None),
            out_axes=[0, 0, 0, 0, 0, None, None]), (mat, b),
        order=2)

  def test_custom_linear_solve_pytree_with_aux(self):
    # Check that lax.custom_linear_solve handles
    # pytree inputs + has_aux=True
    # https://github.com/jax-ml/jax/pull/13093

    aux_orig = {'a': 1, 'b': 2}
    b = {'c': jnp.ones(2), 'd': jnp.ones(3)}

    def solve_with_aux(matvec, b):
      return b, aux_orig

    sol, aux = lax.custom_linear_solve(
        lambda x: x,
        b,
        solve_with_aux,
        solve_with_aux,
        has_aux=True)

    assert len(aux.keys()) == 2
    assert 'a' in aux
    assert 'b' in aux
    self.assertAllClose(aux['a'], aux_orig['a'], check_dtypes=False)
    self.assertAllClose(aux['b'], aux_orig['b'], check_dtypes=False)

  def test_custom_linear_solve_errors(self):

    solve = lambda f, x: x

    with self.assertRaisesRegex(TypeError, re.escape("matvec() output pytree")):
      lax.custom_linear_solve(lambda x: [x], 1.0, solve, solve)
    with self.assertRaisesRegex(TypeError, re.escape("solve() output pytree")):
      lax.custom_linear_solve(lambda x: x, 1.0, lambda f, x: [x], solve)
    with self.assertRaisesRegex(
        TypeError, re.escape("transpose_solve() output pytree")):
      lax.custom_linear_solve(lambda x: x, 1.0, solve, lambda f, x: [x])

    with self.assertRaisesRegex(ValueError, re.escape("solve() output shapes")):
      lax.custom_linear_solve(lambda x: x, 1.0, lambda f, x: jnp.ones(2), solve)

    def bad_matvec_usage(a):
      return lax.custom_linear_solve(
          lambda x: a * jnp.ones(2), 1.0, solve, solve)
    with self.assertRaisesRegex(ValueError, re.escape("matvec() output shapes")):
      jax.jvp(bad_matvec_usage, (1.0,), (1.0,))

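  # Check that custom_linear_solve composes with jax.checkpoint (new-style
  # remat), both with the default policy and with a policy that saves every
  # intermediate.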
  def test_custom_linear_solve_new_remat(self):

    def explicit_jacobian_solve(matvec, b):
      return lax.stop_gradient(jnp.linalg.solve(jax.jacobian(matvec)(b), b))

    def matrix_free_solve(matvec, b):
      return lax.custom_linear_solve(
          matvec, b, explicit_jacobian_solve, explicit_jacobian_solve,
          symmetric=True)

    @checkpoint
    def linear_solve(a, b):
      return matrix_free_solve(partial(high_precision_dot, a), b)

    rng = self.rng()
    a = rng.randn(3, 3)
    a = a + a.T  # the solve above assumes a symmetric operator
    b = rng.randn(3)
    jtu.check_grads(linear_solve, (a, b), order=1, rtol=3e-3, modes=['rev'])

    @partial(checkpoint, policy=lambda *_, **__: True)
    def linear_solve(a, b):
      return matrix_free_solve(partial(high_precision_dot, a), b)
    jtu.check_grads(linear_solve, (a, b), order=1, rtol=3e-3, modes=['rev'])

  def test_custom_linear_solve_batching_with_aux(self):
    def solve(mv, b):
      aux = (np.array(1.), True, 0)
      return mv(b), aux

    def solve_aux(x):
      matvec = lambda y: jax.tree.map(partial(jnp.dot, A), y)
      return lax.custom_linear_solve(matvec, (x, x), solve, solve,
                                     symmetric=True, has_aux=True)

    rng = self.rng()
    A = rng.randn(3, 3)
    A = A + A.T
    b = rng.randn(3, 3)

    # doesn't crash
    jax.vmap(solve_aux)(b)

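  # Regression test: an ordered debug callback inside matvec should run exactly
  # once per solve, with and without jit.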
  def test_custom_linear_solve_ordered_effects(self):
    # See https://github.com/jax-ml/jax/issues/26087
    def mat_vec(v):
      jax.debug.callback(lambda: print("mat_vec"), ordered=True)
      return v

    def solve(b):
      return lax.custom_linear_solve(mat_vec, b, lambda matvec, x: matvec(x))

    b = self.rng().randn(24)
    with jtu.capture_stdout() as output:
      expected = solve(b)
      jax.effects_barrier()
    self.assertEqual(output(), "mat_vec\n")
    with jtu.capture_stdout() as output:
      computed = jax.jit(solve)(b)
      jax.effects_barrier()
    self.assertEqual(output(), "mat_vec\n")
    self.assertAllClose(computed, expected)


if __name__ == '__main__':
  absltest.main(testLoader=jtu.JaxTestLoader())