rocm_jax/tests/custom_root_test.py
Jake VanderPlas f090074d86 Avoid 'from jax import config' imports
In some environments this appears to import the config module rather than
the config object.
2024-04-11 13:23:27 -07:00

247 lines
8.2 KiB
Python

# Copyright 2018 The JAX Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import re
from absl.testing import absltest
from absl.testing import parameterized
import numpy as np
import jax
from jax import lax
from jax._src import test_util as jtu
import jax.numpy as jnp # scan tests use numpy
import jax.scipy as jsp
# Let JAX pick up its configuration options from absl command-line flags
# before any test runs.
jax.config.parse_flags_with_absl()
def high_precision_dot(a, b):
  """Contract ``a`` with ``b`` via ``lax.dot`` at ``Precision.HIGHEST``.

  Pinning the precision keeps matmul results from being computed at a
  backend's reduced default precision, so the tests' tolerances stay
  meaningful across devices.
  """
  highest = lax.Precision.HIGHEST
  return lax.dot(a, b, precision=highest)
def binary_search(func, x0, low=0.0, high=100.0):
  """Simple bisection solver used as a ``solve`` routine for ``custom_root``.

  The bracket ``[low, high]`` is halved repeatedly. Whenever
  ``func(midpoint) > 0`` the upper bound moves down to the midpoint; otherwise
  the lower bound moves up. This assumes ``func`` is negative below the root
  and positive above it inside the bracket.

  Args:
    func: scalar function whose root is sought.
    x0: initial guess; ignored (the bracket is used instead). It is accepted
      only so the signature matches ``solve(f, initial_guess)``.
    low: lower end of the search bracket.
    high: upper end of the search bracket.

  Returns:
    The lower bound of the bracket once its floating-point midpoint can no
    longer fall strictly between the two bounds.
  """
  del x0  # The bracket [low, high] replaces the initial guess.

  def _midpoint(bounds):
    lo, hi = bounds
    return 0.5 * (lo + hi)

  def keep_going(bounds):
    lo, hi = bounds
    mid = _midpoint(bounds)
    # Stop once the midpoint collapses onto an endpoint (float resolution).
    return (lo < mid) & (mid < hi)

  def bisect(bounds):
    lo, hi = bounds
    mid = _midpoint(bounds)
    go_left = func(mid) > 0
    return (jnp.where(go_left, lo, mid), jnp.where(go_left, mid, hi))

  lower, _ = lax.while_loop(keep_going, bisect, (low, high))
  return lower
def newton_raphson(func, x0, *, tol=1e-16, max_it=20):
  """Newton-Raphson root finder used as a ``solve`` routine for ``custom_root``.

  ``lax.custom_root`` calls this as ``solve(f, initial_guess)``. The optional
  keyword-only arguments expose the previously hard-coded stopping criteria.

  Args:
    func: function whose root is sought. Its output must have the same shape
      as its input, because the Newton step (shaped like ``func(x)``) is
      subtracted directly from ``x``.
    x0: initial guess. ``func`` must have a nonzero derivative here.
    tol: iteration stops once ``max(|func(x)|) <= tol``. The default is tiny,
      so for float32 problems the loop typically stops only on an exact zero
      residual or after ``max_it`` iterations.
    max_it: maximum number of Newton iterations. ``0`` returns ``x0``.

  Returns:
    The final iterate ``x``.
  """
  jac = jax.jacobian(func)
  # Loop state: (iteration, x, f(x), df/dx(x)).
  initial_state = (0, x0, func(x0), jac(x0))

  def cond(state):
    it, _, fx, _ = state
    return (jnp.max(jnp.abs(fx)) > tol) & (it < max_it)

  def body(state):
    it, x, fx, dfx = state
    # Flatten the Jacobian into a 2-D matrix and solve J @ step = f(x). This
    # assumes a square system (fx.size == x.size).
    step = jnp.linalg.solve(
        dfx.reshape((-1, fx.size)), fx.ravel()
    ).reshape(fx.shape)
    x_next = x - step
    return (it + 1, x_next, func(x_next), jac(x_next))

  _, x, _, _ = lax.while_loop(cond, body, initial_state)
  return x
class CustomRootTest(jtu.JaxTestCase):
  """Tests for ``lax.custom_root``: values, derivatives, jit/vmap, aux, errors."""

  @parameterized.named_parameters(
      {"testcase_name": "binary_search", "solve_method": binary_search},
      {"testcase_name": "newton_raphson", "solve_method": newton_raphson},
  )
  def test_custom_root_scalar(self, solve_method):
    """sqrt(x**3) defined implicitly as the root of y**2 - x**3."""

    def scalar_solve(f, y):
      # For a scalar linear map f, f(1.0) is its slope, so y / f(1.0)
      # solves f(x) == y.
      return y / f(1.0)

    def sqrt_cubed(x, tangent_solve=scalar_solve):
      f = lambda y: y ** 2 - x ** 3
      # Note: Nonzero derivative at x0 required for newton_raphson
      return lax.custom_root(f, 1.0, solve_method, tangent_solve)

    value, grad = jax.value_and_grad(sqrt_cubed)(5.0)
    self.assertAllClose(value, 5 ** 1.5, check_dtypes=False, rtol=1e-6)
    # TPU gets a looser tolerance for the gradient comparison.
    rtol = 5e-6 if jtu.test_device_matches(["tpu"]) else 1e-7
    # Expected gradient: d/dx x**1.5 evaluated at x = 5.
    self.assertAllClose(grad, jax.grad(pow)(5.0, 1.5), check_dtypes=False,
                        rtol=rtol)
    jtu.check_grads(sqrt_cubed, (5.0,), order=2,
                    rtol={jnp.float32: 1e-2, jnp.float64: 1e-3})

    # Batched roots via vmap.
    inputs = jnp.array([4.0, 5.0])
    results = jax.vmap(sqrt_cubed)(inputs)
    self.assertAllClose(
        results, inputs ** 1.5, check_dtypes=False,
        atol={jnp.float32: 1e-3, jnp.float64: 1e-6},
        rtol={jnp.float32: 1e-3, jnp.float64: 1e-6},
    )

    # Compiled path.
    results = jax.jit(sqrt_cubed)(5.0)
    self.assertAllClose(
        results, 5.0**1.5, check_dtypes=False, rtol={np.float64: 1e-7})

  @jtu.skip_on_flag("jax_skip_slow_tests", True)
  def test_custom_root_vector_with_solve_closure(self):
    """Linear system a @ y == b, solved by an oracle closing over the answer."""

    def vector_solve(f, y):
      # Dense tangent solve: materialize f's Jacobian and solve against y.
      # NOTE(review): the Jacobian is evaluated at y, which presumably is fine
      # because f is the (linear) tangent map here.
      return jnp.linalg.solve(jax.jacobian(f)(y), y)

    def linear_solve(a, b):
      f = lambda y: high_precision_dot(a, y) - b
      x0 = jnp.zeros_like(b)
      solution = jnp.linalg.solve(a, b)
      # "Solver" that ignores its arguments and returns the closed-over answer.
      oracle = lambda func, x0: solution
      return lax.custom_root(f, x0, oracle, vector_solve)

    rng = self.rng()
    a = rng.randn(2, 2)
    b = rng.randn(2)
    jtu.check_grads(linear_solve, (a, b), order=2,
                    atol={np.float32: 1e-2, np.float64: 1e-11})
    actual = jax.jit(linear_solve)(a, b)
    expected = jnp.linalg.solve(a, b)
    self.assertAllClose(expected, actual)

  def test_custom_root_vector_nonlinear(self):
    """Nonlinear root found by Newton; the root should reproduce y."""

    def nonlinear_func(x, y):
      # func(x, y) == 0 if and only if x == y.
      return (x - y) * (x**2 + y**2 + 1)

    def tangent_solve(g, y):
      # Flatten g's Jacobian to 2-D so y may have any shape (here (3, 1)).
      return jnp.linalg.solve(
          jax.jacobian(g)(y).reshape(-1, y.size),
          y.ravel()
      ).reshape(y.shape)

    def nonlinear_solve(y):
      f = lambda x: nonlinear_func(x, y)
      x0 = -jnp.ones_like(y)
      return lax.custom_root(f, x0, newton_raphson, tangent_solve)

    y = self.rng().randn(3, 1)
    jtu.check_grads(nonlinear_solve, (y,), order=2,
                    rtol={jnp.float32: 1e-2, jnp.float64: 1e-3})
    actual = jax.jit(nonlinear_solve)(y)
    self.assertAllClose(y, actual, rtol=1e-5, atol=1e-5)

  def test_custom_root_with_custom_linear_solve(self):
    """Forward and tangent solves both reuse one Cholesky factorization."""

    def linear_solve(a, b):
      f = lambda x: high_precision_dot(a, x) - b
      factors = jsp.linalg.cho_factor(a)
      # Ignores f and returns a^{-1} b, which is exactly the root of f.
      cho_solve = lambda f, b: jsp.linalg.cho_solve(factors, b)
      def pos_def_solve(g, b):
        return lax.custom_linear_solve(g, b, cho_solve, symmetric=True)
      return lax.custom_root(f, b, cho_solve, pos_def_solve)

    rng = self.rng()
    a = rng.randn(2, 2)
    b = rng.randn(2)
    # a @ a.T is symmetric positive semi-definite (almost surely definite for
    # random a), as the Cholesky factorization requires.
    actual = linear_solve(high_precision_dot(a, a.T), b)
    expected = jnp.linalg.solve(high_precision_dot(a, a.T), b)
    self.assertAllClose(expected, actual)

    actual = jax.jit(linear_solve)(high_precision_dot(a, a.T), b)
    expected = jnp.linalg.solve(high_precision_dot(a, a.T), b)
    self.assertAllClose(expected, actual)

    jtu.check_grads(lambda x, y: linear_solve(high_precision_dot(x, x.T), y),
                    (a, b), order=2, rtol={jnp.float32: 1e-2})

  def test_custom_root_with_aux(self):
    """has_aux=True: the solver returns (solution, aux) and aux passes through."""

    def root_aux(a, b):
      f = lambda x: high_precision_dot(a, x) - b
      factors = jsp.linalg.cho_factor(a)
      # orig_aux is bound later in this test, before root_aux is first called.
      cho_solve = lambda f, b: (jsp.linalg.cho_solve(factors, b), orig_aux)
      def pos_def_solve(g, b):
        # prune aux to allow use as tangent_solve
        cho_solve_noaux = lambda f, b: cho_solve(f, b)[0]
        return lax.custom_linear_solve(g, b, cho_solve_noaux, symmetric=True)
      return lax.custom_root(f, b, cho_solve, pos_def_solve, has_aux=True)

    orig_aux = {"converged": np.array(1.), "nfev": np.array(12345.), "grad": np.array([1.0, 2.0, 3.0])}

    rng = self.rng()
    a = rng.randn(2, 2)
    b = rng.randn(2)
    # NOTE(review): actual_aux (eager path) is never checked below; only the
    # jitted aux is compared against orig_aux.
    actual, actual_aux = root_aux(high_precision_dot(a, a.T), b)
    actual_jit, actual_jit_aux = jax.jit(root_aux)(high_precision_dot(a, a.T), b)
    expected = jnp.linalg.solve(high_precision_dot(a, a.T), b)

    self.assertAllClose(expected, actual)
    self.assertAllClose(expected, actual_jit)
    jtu.check_eq(actual_jit_aux, orig_aux)

    # grad check with aux
    jtu.check_grads(lambda x, y: root_aux(high_precision_dot(x, x.T), y),
                    (a, b), order=2, rtol={jnp.float32: 1e-2, np.float64: 3e-5})

    # test vmap and jvp combined by jacfwd
    fwd = jax.jacfwd(lambda x, y: root_aux(high_precision_dot(x, x.T), y), argnums=(0, 1))
    expected_fwd = jax.jacfwd(lambda x, y: jnp.linalg.solve(high_precision_dot(x, x.T), y), argnums=(0, 1))

    fwd_val, fwd_aux = fwd(a, b)
    expected_fwd_val = expected_fwd(a, b)
    self.assertAllClose(fwd_val, expected_fwd_val, rtol={np.float32: 5E-6, np.float64: 5E-12})

    # aux is a constant independent of (a, b), so its Jacobian must be zero.
    jtu.check_close(fwd_aux, jax.tree.map(jnp.zeros_like, fwd_aux))

  def test_custom_root_errors(self):
    """Mismatched output pytrees from f/solve/tangent_solve raise TypeError."""
    # f returns a tuple while the initial guess is a scalar.
    with self.assertRaisesRegex(TypeError, re.escape("f() output pytree")):
      lax.custom_root(lambda x: (x, x), 0.0, lambda f, x: x, lambda f, x: x)
    # solve returns a tuple instead of a scalar.
    with self.assertRaisesRegex(TypeError, re.escape("solve() output pytree")):
      lax.custom_root(lambda x: x, 0.0, lambda f, x: (x, x), lambda f, x: x)

    def dummy_root_usage(x):
      f = lambda y: x - y
      return lax.custom_root(f, 0.0, lambda f, x: x, lambda f, x: (x, x))

    # The bad tangent_solve output is exercised by differentiating, via jvp.
    with self.assertRaisesRegex(
        TypeError, re.escape("tangent_solve() output pytree")):
      jax.jvp(dummy_root_usage, (0.0,), (0.0,))
if __name__ == '__main__':
  # Run the tests directly using JAX's custom test loader.
  absltest.main(testLoader=jtu.JaxTestLoader())