mirror of
https://github.com/ROCm/jax.git
synced 2025-04-18 12:56:07 +00:00

In some environments this appears to import the config module rather than the config object.
247 lines
8.2 KiB
Python
247 lines
8.2 KiB
Python
# Copyright 2018 The JAX Authors.
|
|
#
|
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
# you may not use this file except in compliance with the License.
|
|
# You may obtain a copy of the License at
|
|
#
|
|
# https://www.apache.org/licenses/LICENSE-2.0
|
|
#
|
|
# Unless required by applicable law or agreed to in writing, software
|
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
# See the License for the specific language governing permissions and
|
|
# limitations under the License.
|
|
|
|
import re
|
|
|
|
from absl.testing import absltest
|
|
from absl.testing import parameterized
|
|
|
|
import numpy as np
|
|
|
|
import jax
|
|
from jax import lax
|
|
from jax._src import test_util as jtu
|
|
import jax.numpy as jnp # scan tests use numpy
|
|
import jax.scipy as jsp
|
|
|
|
# Hook JAX's configuration flags into absl's flag parsing before any test runs.
# NOTE(review): called via the `jax.config` attribute rather than a separately
# imported `config` name — presumably deliberate; confirm before changing.
jax.config.parse_flags_with_absl()
|
|
|
|
|
|
def high_precision_dot(a, b):
  """Return `lax.dot(a, b)` evaluated with `lax.Precision.HIGHEST`.

  Args:
    a: left operand, forwarded unchanged to `lax.dot`.
    b: right operand, forwarded unchanged to `lax.dot`.

  Returns:
    Whatever `lax.dot` returns for the two operands.
  """
  highest = lax.Precision.HIGHEST
  product = lax.dot(a, b, precision=highest)
  return product
|
|
|
|
|
|
# Simple optimization routine for testing custom_root
|
|
def binary_search(func, x0, low=0.0, high=100.0):
  """Bisect the bracket `[low, high]` using the sign of `func`.

  The bracket is halved inside a `lax.while_loop` until its midpoint is no
  longer strictly between the two ends. Whenever `func(midpoint) > 0` the
  upper end moves to the midpoint; otherwise the lower end does.

  Args:
    func: callable evaluated at each midpoint; only the sign test
      `func(midpoint) > 0` is used.
    x0: ignored. Present so the signature matches `solve(f, initial_guess)`.
    low: initial lower end of the bracket.
    high: initial upper end of the bracket.

  Returns:
    The final lower end of the bracket.
  """
  del x0  # only here for signature compatibility; never read

  def _midpoint(lo, hi):
    return 0.5 * (lo + hi)

  def bracket_can_shrink(bracket):
    lo, hi = bracket
    mid = _midpoint(lo, hi)
    # Stops once the midpoint coincides with one of the ends.
    return (lo < mid) & (mid < hi)

  def halve_bracket(bracket):
    lo, hi = bracket
    mid = _midpoint(lo, hi)
    positive_at_mid = func(mid) > 0
    next_lo = jnp.where(positive_at_mid, lo, mid)
    next_hi = jnp.where(positive_at_mid, mid, hi)
    return (next_lo, next_hi)

  final_lo, _ = lax.while_loop(bracket_can_shrink, halve_bracket, (low, high))
  return final_lo
|
|
|
|
# Optimization routine for testing custom_root.
|
|
def newton_raphson(func, x0, tol=1e-16, max_it=20):
  """Find a root of `func` by Newton-Raphson iteration starting from `x0`.

  Each step solves the linear system `J(x) @ step = f(x)`, where `J` is
  `jax.jacobian(func)`, and updates `x <- x - step`. The iteration runs in a
  `lax.while_loop`.

  Args:
    func: callable whose root is sought; it is differentiated with
      `jax.jacobian`.
    x0: starting point of the iteration.
    tol: the loop continues only while `max(|f(x)|) > tol`.
    max_it: the loop continues only while fewer than `max_it` steps have
      been taken; `max_it=0` returns `x0` unchanged.

  Returns:
    The last iterate `x`.
  """
  # Build the Jacobian function once; it is reused at every iterate.
  jac = jax.jacobian(func)

  def evaluate(x):
    return func(x), jac(x)

  fx0, dfx0 = evaluate(x0)
  initial_state = (0, x0, fx0, dfx0)  # (iteration, x, f(x), jacobian(f)(x))

  def cond(state):
    it, _, fx, _ = state
    return (jnp.max(jnp.abs(fx)) > tol) & (it < max_it)

  def body(state):
    it, x, fx, dfx = state
    # Flatten to a square system so inputs of any shape (including scalars)
    # go through the same linear solve, then restore the original shape.
    step = jnp.linalg.solve(
        dfx.reshape((-1, fx.size)), fx.ravel()
    ).reshape(fx.shape)
    x_next = x - step
    fx_next, dfx_next = evaluate(x_next)
    return (it + 1, x_next, fx_next, dfx_next)

  _, x, _, _ = lax.while_loop(cond, body, initial_state)

  return x
|
|
|
|
|
|
class CustomRootTest(jtu.JaxTestCase):
  """Tests for `lax.custom_root`: values, gradients, vmap/jit, aux, errors."""

  @parameterized.named_parameters(
      {"testcase_name": "binary_search", "solve_method": binary_search},
      {"testcase_name": "newton_raphson", "solve_method": newton_raphson},
  )
  def test_custom_root_scalar(self, solve_method):
    """Root of `y**2 - x**3` must match `x**1.5` in value, grad, vmap, jit."""

    def scalar_solve(f, y):
      # For a scalar linear function, f(1.0) is its slope.
      return y / f(1.0)

    def sqrt_cubed(x, tangent_solve=scalar_solve):
      f = lambda y: y ** 2 - x ** 3
      # Note: Nonzero derivative at x0 required for newton_raphson
      return lax.custom_root(f, 1.0, solve_method, tangent_solve)

    value, grad = jax.value_and_grad(sqrt_cubed)(5.0)
    self.assertAllClose(value, 5 ** 1.5, check_dtypes=False, rtol=1e-6)
    rtol = 5e-6 if jtu.test_device_matches(["tpu"]) else 1e-7
    self.assertAllClose(grad, jax.grad(pow)(5.0, 1.5), check_dtypes=False,
                        rtol=rtol)
    jtu.check_grads(sqrt_cubed, (5.0,), order=2,
                    rtol={jnp.float32: 1e-2, jnp.float64: 1e-3})

    inputs = jnp.array([4.0, 5.0])
    results = jax.vmap(sqrt_cubed)(inputs)
    self.assertAllClose(
        results, inputs ** 1.5, check_dtypes=False,
        atol={jnp.float32: 1e-3, jnp.float64: 1e-6},
        rtol={jnp.float32: 1e-3, jnp.float64: 1e-6},
    )

    results = jax.jit(sqrt_cubed)(5.0)
    self.assertAllClose(
        results, 5.0**1.5, check_dtypes=False, rtol={np.float64: 1e-7})

  @jtu.skip_on_flag("jax_skip_slow_tests", True)
  def test_custom_root_vector_with_solve_closure(self):
    """The solver may close over a precomputed solution of `a @ y = b`."""

    def vector_solve(f, y):
      return jnp.linalg.solve(jax.jacobian(f)(y), y)

    def linear_solve(a, b):
      f = lambda y: high_precision_dot(a, y) - b
      x0 = jnp.zeros_like(b)
      solution = jnp.linalg.solve(a, b)
      # The "solver" ignores its arguments and returns the closed-over answer.
      oracle = lambda func, x0: solution
      return lax.custom_root(f, x0, oracle, vector_solve)

    rng = self.rng()
    a = rng.randn(2, 2)
    b = rng.randn(2)
    jtu.check_grads(linear_solve, (a, b), order=2,
                    atol={np.float32: 1e-2, np.float64: 1e-11})

    actual = jax.jit(linear_solve)(a, b)
    expected = jnp.linalg.solve(a, b)
    self.assertAllClose(expected, actual)

  def test_custom_root_vector_nonlinear(self):
    """Newton-Raphson on a nonlinear vector function whose root is `x == y`."""

    def nonlinear_func(x, y):
      # func(x, y) == 0 if and only if x == y.
      return (x - y) * (x**2 + y**2 + 1)

    def tangent_solve(g, y):
      return jnp.linalg.solve(
          jax.jacobian(g)(y).reshape(-1, y.size),
          y.ravel()
      ).reshape(y.shape)

    def nonlinear_solve(y):
      f = lambda x: nonlinear_func(x, y)
      x0 = -jnp.ones_like(y)
      return lax.custom_root(f, x0, newton_raphson, tangent_solve)

    y = self.rng().randn(3, 1)
    jtu.check_grads(nonlinear_solve, (y,), order=2,
                    rtol={jnp.float32: 1e-2, jnp.float64: 1e-3})

    actual = jax.jit(nonlinear_solve)(y)
    self.assertAllClose(y, actual, rtol=1e-5, atol=1e-5)

  def test_custom_root_with_custom_linear_solve(self):
    """`custom_root` whose tangent solve is itself a `custom_linear_solve`."""

    def linear_solve(a, b):
      f = lambda x: high_precision_dot(a, x) - b
      factors = jsp.linalg.cho_factor(a)
      cho_solve = lambda f, b: jsp.linalg.cho_solve(factors, b)
      def pos_def_solve(g, b):
        return lax.custom_linear_solve(g, b, cho_solve, symmetric=True)
      return lax.custom_root(f, b, cho_solve, pos_def_solve)

    rng = self.rng()
    a = rng.randn(2, 2)
    b = rng.randn(2)

    # The matrix handed to linear_solve is always the product a @ a.T.
    actual = linear_solve(high_precision_dot(a, a.T), b)
    expected = jnp.linalg.solve(high_precision_dot(a, a.T), b)
    self.assertAllClose(expected, actual)

    actual = jax.jit(linear_solve)(high_precision_dot(a, a.T), b)
    expected = jnp.linalg.solve(high_precision_dot(a, a.T), b)
    self.assertAllClose(expected, actual)

    jtu.check_grads(lambda x, y: linear_solve(high_precision_dot(x, x.T), y),
                    (a, b), order=2, rtol={jnp.float32: 1e-2})

  def test_custom_root_with_aux(self):
    """`has_aux=True` returns the solver's aux output alongside the root."""

    def root_aux(a, b):
      f = lambda x: high_precision_dot(a, x) - b
      factors = jsp.linalg.cho_factor(a)
      # `orig_aux` is looked up when the lambda runs; it is assigned below.
      cho_solve = lambda f, b: (jsp.linalg.cho_solve(factors, b), orig_aux)

      def pos_def_solve(g, b):
        # prune aux to allow use as tangent_solve
        cho_solve_noaux = lambda f, b: cho_solve(f, b)[0]
        return lax.custom_linear_solve(g, b, cho_solve_noaux, symmetric=True)

      return lax.custom_root(f, b, cho_solve, pos_def_solve, has_aux=True)

    orig_aux = {
        "converged": np.array(1.),
        "nfev": np.array(12345.),
        "grad": np.array([1.0, 2.0, 3.0]),
    }

    rng = self.rng()
    a = rng.randn(2, 2)
    b = rng.randn(2)

    actual, actual_aux = root_aux(high_precision_dot(a, a.T), b)
    actual_jit, actual_jit_aux = jax.jit(root_aux)(high_precision_dot(a, a.T), b)
    expected = jnp.linalg.solve(high_precision_dot(a, a.T), b)

    self.assertAllClose(expected, actual)
    self.assertAllClose(expected, actual_jit)
    # Aux must come back unchanged from both the eager and the jitted call.
    jtu.check_eq(actual_aux, orig_aux)
    jtu.check_eq(actual_jit_aux, orig_aux)

    # grad check with aux
    jtu.check_grads(lambda x, y: root_aux(high_precision_dot(x, x.T), y),
                    (a, b), order=2, rtol={jnp.float32: 1e-2, np.float64: 3e-5})

    # test vmap and jvp combined by jacfwd
    fwd = jax.jacfwd(
        lambda x, y: root_aux(high_precision_dot(x, x.T), y), argnums=(0, 1))
    expected_fwd = jax.jacfwd(
        lambda x, y: jnp.linalg.solve(high_precision_dot(x, x.T), y),
        argnums=(0, 1))

    fwd_val, fwd_aux = fwd(a, b)
    expected_fwd_val = expected_fwd(a, b)
    self.assertAllClose(fwd_val, expected_fwd_val,
                        rtol={np.float32: 5E-6, np.float64: 5E-12})

    # The Jacobian of the aux output is expected to be all zeros.
    jtu.check_close(fwd_aux, jax.tree.map(jnp.zeros_like, fwd_aux))

  def test_custom_root_errors(self):
    """Mismatched output pytrees from f/solve/tangent_solve raise TypeError."""
    with self.assertRaisesRegex(TypeError, re.escape("f() output pytree")):
      lax.custom_root(lambda x: (x, x), 0.0, lambda f, x: x, lambda f, x: x)
    with self.assertRaisesRegex(TypeError, re.escape("solve() output pytree")):
      lax.custom_root(lambda x: x, 0.0, lambda f, x: (x, x), lambda f, x: x)

    def dummy_root_usage(x):
      f = lambda y: x - y
      return lax.custom_root(f, 0.0, lambda f, x: x, lambda f, x: (x, x))

    # The tangent_solve error is checked under jvp, which is where the
    # original test expects it to be raised.
    with self.assertRaisesRegex(
        TypeError, re.escape("tangent_solve() output pytree")):
      jax.jvp(dummy_root_usage, (0.0,), (0.0,))
|
|
|
|
|
|
# Script entry point: hand control to absl's test runner, using the test
# loader supplied by JAX's test utilities.
if __name__ == '__main__':
  absltest.main(testLoader=jtu.JaxTestLoader())
|