rocm_jax/tests/jet_test.py
Jake VanderPlas f090074d86 Avoid 'from jax import config' imports
In some environments this appears to import the config module rather than
the config object.
2024-04-11 13:23:27 -07:00

433 lines
16 KiB
Python

# Copyright 2020 The JAX Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from functools import reduce, partial
from absl.testing import absltest
import numpy as np
import unittest
import jax
from jax._src import test_util as jtu
import jax.numpy as jnp
import jax.scipy.special
from jax import random
from jax import jacfwd, jit
from jax.example_libraries import stax
from jax.experimental.jet import jet, fact, zero_series
from jax import lax
# Let JAX's config read its options from absl command-line flags.
jax.config.parse_flags_with_absl()
def jvp_taylor(fun, primals, series):
  """Computes the Taylor series the slow way, with nested `jacfwd`.

  Args:
    fun: function evaluated at the perturbed arguments.
    primals: sequence of expansion points, one per argument of `fun`.
    series: sequence (one entry per primal) of lists of series terms. Every
      list must have the same length, which is taken as the order.

  Returns:
    `(primal_out, terms_out)`: `fun(*primals)` and a list whose k-th entry
    (k = 1..order) is the k-fold `jacfwd` derivative, evaluated at 0., of
    `fun` along the perturbed inputs.
  """
  # Single-element unpacking: raises if the series lengths disagree.
  (order,) = {len(terms) for terms in series}
  primals = tuple(map(jnp.asarray, primals))

  def composition(eps):
    nudged_args = []
    for x, terms in zip(primals, series):
      # Truncated polynomial in eps: sum over k >= 1 of eps**k * terms[k-1] / k!.
      offset = sum(eps ** k * coeff / fact(k)
                   for k, coeff in enumerate(terms, start=1))
      # Cast back so the perturbed argument keeps the primal's dtype.
      nudged_args.append((x + offset).astype(x.dtype))
    return fun(*nudged_args)

  primal_out = fun(*primals)
  terms_out = []
  for k in range(1, order + 1):
    kth_derivative = repeated(jacfwd, k)(composition)
    terms_out.append(kth_derivative(0.))
  return primal_out, terms_out
def repeated(f, n):
  """Returns a function that applies `f` to its argument `n` times in a row.

  With `n == 0` the returned function hands its argument back unchanged.
  """
  def rfun(p):
    result = p
    for _ in range(n):
      result = f(result)
    return result
  return rfun
def transform(lims, x):
  """Affinely rescales `x` so that 0 maps to `lims[0]` and 1 maps to `lims[1]`."""
  lower = lims[0]
  width = lims[1] - lower
  return x * width + lower
class JetTest(jtu.JaxTestCase):
def check_jet(self, fun, primals, series, atol=1e-5, rtol=1e-5,
check_dtypes=True):
# Convert to jax arrays to ensure dtype canonicalization.
primals = jax.tree.map(jnp.asarray, primals)
series = jax.tree.map(jnp.asarray, series)
y, terms = jet(fun, primals, series)
expected_y, expected_terms = jvp_taylor(fun, primals, series)
self.assertAllClose(y, expected_y, atol=atol, rtol=rtol,
check_dtypes=check_dtypes)
self.assertAllClose(terms, expected_terms, atol=atol, rtol=rtol,
check_dtypes=check_dtypes)
def check_jet_finite(self, fun, primals, series, atol=1e-5, rtol=1e-5,
check_dtypes=True):
# Convert to jax arrays to ensure dtype canonicalization.
primals = jax.tree.map(jnp.asarray, primals)
series = jax.tree.map(jnp.asarray, series)
y, terms = jet(fun, primals, series)
expected_y, expected_terms = jvp_taylor(fun, primals, series)
def _convert(x):
return jnp.where(jnp.isfinite(x), x, jnp.nan)
y = _convert(y)
expected_y = _convert(expected_y)
terms = _convert(jnp.asarray(terms))
expected_terms = _convert(jnp.asarray(expected_terms))
self.assertAllClose(y, expected_y, atol=atol, rtol=rtol,
check_dtypes=check_dtypes)
self.assertAllClose(terms, expected_terms, atol=atol, rtol=rtol,
check_dtypes=check_dtypes)
@jtu.skip_on_devices("tpu")
def test_dot(self):
M, K, N = 2, 3, 4
order = 3
rng = self.rng()
x1 = rng.randn(M, K)
x2 = rng.randn(K, N)
primals = (x1, x2)
terms_in1 = [rng.randn(*x1.shape) for _ in range(order)]
terms_in2 = [rng.randn(*x2.shape) for _ in range(order)]
series_in = (terms_in1, terms_in2)
self.check_jet(jnp.dot, primals, series_in)
@jtu.skip_on_devices("tpu")
@jax.legacy_prng_key('allow')
def test_conv(self):
order = 3
input_shape = (1, 5, 5, 1)
key = random.PRNGKey(0)
# TODO(duvenaud): Check all types of padding
init_fun, apply_fun = stax.Conv(3, (2, 2), padding='VALID')
_, (W, b) = init_fun(key, input_shape)
rng = self.rng()
x = rng.randn(*input_shape).astype(W.dtype)
primals = (W, b, x)
series_in1 = [rng.randn(*W.shape).astype(W.dtype) for _ in range(order)]
series_in2 = [rng.randn(*b.shape).astype(W.dtype) for _ in range(order)]
series_in3 = [rng.randn(*x.shape).astype(W.dtype) for _ in range(order)]
series_in = (series_in1, series_in2, series_in3)
def f(W, b, x):
return apply_fun((W, b), x)
self.check_jet(f, primals, series_in, check_dtypes=False)
def unary_check(self, fun, lims=(-2, 2), order=3, dtype=None, atol=1e-3,
rtol=1e-3):
dims = 2, 3
rng = self.rng()
if dtype is None:
primal_in = transform(lims, rng.rand(*dims))
terms_in = [rng.randn(*dims) for _ in range(order)]
else:
rng = jtu.rand_uniform(rng, *lims)
primal_in = rng(dims, dtype)
terms_in = [rng(dims, dtype) for _ in range(order)]
self.check_jet(fun, (primal_in,), (terms_in,), atol, rtol)
def binary_check(self, fun, lims=None, order=3, finite=True, dtype=None):
lims = lims or [-2, 2]
dims = 2, 3
rng = self.rng()
if isinstance(lims, tuple):
x_lims, y_lims = lims
else:
x_lims, y_lims = lims, lims
if dtype is None:
primal_in = (transform(x_lims, rng.rand(*dims)),
transform(y_lims, rng.rand(*dims)))
series_in = ([rng.randn(*dims) for _ in range(order)],
[rng.randn(*dims) for _ in range(order)])
else:
rng = jtu.rand_uniform(rng, *lims)
primal_in = (rng(dims, dtype),
rng(dims, dtype))
series_in = ([rng(dims, dtype) for _ in range(order)],
[rng(dims, dtype) for _ in range(order)])
if finite:
self.check_jet(fun, primal_in, series_in, atol=1e-4, rtol=1e-4)
else:
self.check_jet_finite(fun, primal_in, series_in, atol=1e-4, rtol=1e-4)
def unary_check_float0(self, fun, lims=(-2, 2), order=3, dtype=None):
# like unary_check but for functions that output integers (so their tangent
# type is float0 arrays)
raise unittest.SkipTest("jet tests must be adapted for integer-output functions")
def binary_check_float0(self, fun, lims=(-2, 2), order=3, finite=True, dtype=None):
# like binary_check but for functions that output integers (so their tangent
# type is float0 arrays)
raise unittest.SkipTest("jet tests must be adapted for integer-output functions")
def expit_check(self, lims=(-2, 2), order=3):
dims = 2, 3
rng = self.rng()
primal_in = transform(lims, rng.rand(*dims))
terms_in = [rng.randn(*dims) for _ in range(order)]
primals = (primal_in, )
series = (terms_in, )
y, terms = jax.experimental.jet._logistic_taylor(primals, series)
expected_y, expected_terms = jvp_taylor(jax.scipy.special.expit, primals, series)
atol = 1e-4
rtol = 1e-4
self.assertAllClose(y, expected_y, atol=atol, rtol=rtol)
self.assertAllClose(terms, expected_terms, atol=atol, rtol=rtol)
@jtu.skip_on_devices("tpu")
def test_int_pow(self):
for p in range(6):
self.unary_check(lambda x: x ** p, lims=[-2, 2])
self.unary_check(lambda x: x ** 10, lims=[0, 0])
@jtu.skip_on_devices("tpu")
def test_is_finite(self): self.unary_check_float0(lax.is_finite)
@jtu.skip_on_devices("tpu")
def test_and(self): self.binary_check_float0(lax.bitwise_and, dtype=np.bool_)
@jtu.skip_on_devices("tpu")
def test_or(self): self.binary_check_float0(lax.bitwise_or, dtype=np.bool_)
@jtu.skip_on_devices("tpu")
def test_xor(self): self.binary_check_float0(jnp.bitwise_xor, dtype=np.bool_)
@jtu.skip_on_devices("tpu")
def test_shift_left(self): self.binary_check_float0(lax.shift_left, dtype=np.int32)
@jtu.skip_on_devices("tpu")
def test_shift_right_a(self): self.binary_check_float0(lax.shift_right_arithmetic, dtype=np.int32)
@jtu.skip_on_devices("tpu")
def test_shift_right_l(self): self.binary_check_float0(lax.shift_right_logical, dtype=np.int32)
@jtu.skip_on_devices("tpu")
def test_le(self): self.binary_check_float0(lambda x, y: x <= y)
@jtu.skip_on_devices("tpu")
def test_gt(self): self.binary_check_float0(lambda x, y: x > y)
@jtu.skip_on_devices("tpu")
def test_lt(self): self.binary_check_float0(lambda x, y: x < y)
@jtu.skip_on_devices("tpu")
def test_ge(self): self.binary_check_float0(lambda x, y: x >= y)
@jtu.skip_on_devices("tpu")
def test_eq(self): self.binary_check_float0(lambda x, y: x == y)
@jtu.skip_on_devices("tpu")
def test_ne(self): self.binary_check_float0(lambda x, y: x != y)
@jtu.skip_on_devices("tpu")
def test_not(self): self.unary_check_float0(lax.bitwise_not, dtype=np.bool_)
@jtu.skip_on_devices("tpu")
def test_exp(self): self.unary_check(jnp.exp)
@jtu.skip_on_devices("tpu")
def test_neg(self): self.unary_check(jnp.negative)
@jtu.skip_on_devices("tpu")
def test_floor(self): self.unary_check(jnp.floor)
@jtu.skip_on_devices("tpu")
def test_ceil(self): self.unary_check(jnp.ceil)
@jtu.skip_on_devices("tpu")
def test_round(self): self.unary_check(lax.round)
@jtu.skip_on_devices("tpu")
def test_sign(self): self.unary_check(lax.sign)
@jtu.skip_on_devices("tpu")
def test_real(self): self.unary_check(lax.real, dtype=np.complex64)
@jtu.skip_on_devices("tpu")
def test_conj(self): self.unary_check(lax.conj, dtype=np.complex64)
@jtu.skip_on_devices("tpu")
def test_imag(self): self.unary_check(lax.imag, dtype=np.complex64)
@jtu.skip_on_devices("tpu")
def test_log(self): self.unary_check(jnp.log, lims=[0.8, 4.0])
@jtu.skip_on_devices("tpu")
def test_gather(self): self.unary_check(lambda x: x[1:])
@jtu.skip_on_devices("tpu")
def test_reduce_max(self): self.unary_check(lambda x: x.max(axis=1))
@jtu.skip_on_devices("tpu")
def test_reduce_min(self): self.unary_check(lambda x: x.min(axis=1))
@jtu.skip_on_devices("tpu")
def test_all_max(self): self.unary_check(jnp.max)
@jtu.skip_on_devices("tpu")
def test_all_min(self): self.unary_check(jnp.min)
@jtu.skip_on_devices("tpu")
def test_stopgrad(self): self.unary_check(lax.stop_gradient)
@jtu.skip_on_devices("tpu")
def test_abs(self): self.unary_check(jnp.abs)
@jtu.skip_on_devices("tpu")
def test_fft(self): self.unary_check(jnp.fft.fft)
@jtu.skip_on_devices("tpu")
def test_log1p(self): self.unary_check(jnp.log1p, lims=[0, 4.])
@jtu.skip_on_devices("tpu")
def test_expm1(self): self.unary_check(jnp.expm1)
@jtu.skip_on_devices("tpu")
def test_sin(self): self.unary_check(jnp.sin)
@jtu.skip_on_devices("tpu")
def test_cos(self): self.unary_check(jnp.cos)
@jtu.skip_on_devices("tpu")
def test_sinh(self): self.unary_check(jnp.sinh)
@jtu.skip_on_devices("tpu")
def test_cosh(self): self.unary_check(jnp.cosh)
@jtu.skip_on_devices("tpu")
def test_tanh(self): self.unary_check(jnp.tanh, lims=[-500, 500], order=5,
atol=5e-3)
@jtu.skip_on_devices("tpu")
def test_logistic(self): self.unary_check(lax.logistic, lims=[-100, 100], order=5)
@jtu.skip_on_devices("tpu")
def test_expit2(self): self.expit_check(lims=[-500, 500], order=5)
@jtu.skip_on_devices("tpu")
def test_sqrt(self): self.unary_check(jnp.sqrt, lims=[0, 5.])
@jtu.skip_on_devices("tpu")
def test_rsqrt(self): self.unary_check(lax.rsqrt, lims=[0, 5000.])
@jtu.skip_on_devices("tpu")
def test_asinh(self): self.unary_check(lax.asinh, lims=[-100, 100])
@jtu.skip_on_devices("tpu")
def test_acosh(self): self.unary_check(lax.acosh, lims=[-100, 100])
@jtu.skip_on_devices("tpu")
def test_atanh(self): self.unary_check(lax.atanh, lims=[-1, 1])
@jtu.skip_on_devices("tpu")
def test_erf(self): self.unary_check(lax.erf)
@jtu.skip_on_devices("tpu")
def test_erfc(self): self.unary_check(lax.erfc)
@jtu.skip_on_devices("tpu")
def test_erf_inv(self): self.unary_check(lax.erf_inv, lims=[-1, 1])
@jtu.skip_on_devices("tpu")
def test_cumsum(self): self.unary_check(jnp.cumsum)
@jtu.skip_on_devices("tpu")
def test_cumprod(self): self.unary_check(jnp.cumprod)
@jtu.skip_on_devices("tpu")
def test_cummax(self): self.unary_check(partial(lax.cummax, axis=0))
@jtu.skip_on_devices("tpu")
def test_cummin(self): self.unary_check(partial(lax.cummin, axis=0))
@jtu.skip_on_devices("tpu")
def test_dynamic_slice(self): self.unary_check(partial(lax.dynamic_slice, start_indices=(1,2), slice_sizes=(1,1)))
@jtu.skip_on_devices("tpu")
def test_dynamic_update_slice(self): self.unary_check(partial(lax.dynamic_update_slice, start_indices=(1,2), update=np.arange(6.0).reshape(2, 3)))
@jtu.skip_on_devices("tpu")
def test_div(self): self.binary_check(lambda x, y: x / y, lims=[0.8, 4.0])
@jtu.skip_on_devices("tpu")
def test_rem(self): self.binary_check(lax.rem, lims=[0.8, 4.0])
@jtu.skip_on_devices("tpu")
def test_complex(self): self.binary_check(lax.complex)
@jtu.skip_on_devices("tpu")
def test_sub(self): self.binary_check(lambda x, y: x - y)
@jtu.skip_on_devices("tpu")
def test_add(self): self.binary_check(lambda x, y: x + y)
@jtu.skip_on_devices("tpu")
def test_mul(self): self.binary_check(lambda x, y: x * y)
@jtu.skip_on_devices("tpu")
def test_max(self): self.binary_check(lax.max)
@jtu.skip_on_devices("tpu")
def test_min(self): self.binary_check(lax.min)
@jtu.skip_on_devices("tpu")
@jtu.ignore_warning(message="overflow encountered in power")
def test_pow(self): self.binary_check(lambda x, y: x ** y, lims=([0.2, 500], [-500, 500]), finite=False)
@jtu.skip_on_devices("tpu")
def test_atan2(self): self.binary_check(lax.atan2, lims=[-40, 40])
@jtu.skip_on_devices("tpu")
def test_clamp(self):
lims = [-1, 1]
order = 3
dims = 2, 3
# TODO(jakevdp): This test is very sensitive to the inputs, so we use a known
# working seed. We should instead use self.rng(), and make sure that the primal
# points lie outside an epsilon ball of the two critical points in the function.
rng = np.random.RandomState(0)
primal_in = (transform(lims, rng.rand(*dims)),
transform(lims, rng.rand(*dims)),
transform(lims, rng.rand(*dims)))
series_in = ([rng.randn(*dims) for _ in range(order)],
[rng.randn(*dims) for _ in range(order)],
[rng.randn(*dims) for _ in range(order)])
self.check_jet(lax.clamp, primal_in, series_in, atol=1e-4, rtol=1e-4)
def test_process_call(self):
def f(x):
return jit(lambda x: x * x)(x)
self.unary_check(f, rtol=2e-4)
def test_post_process_call(self):
def f(x):
return jit(lambda y: x * y)(2.)
self.unary_check(f, rtol=5e-4)
def test_select(self):
M, K = 2, 3
order = 3
rng = self.rng()
b = rng.rand(M, K) < 0.5
x = rng.randn(M, K)
y = rng.randn(M, K)
primals = (b, x, y)
terms_b = [rng.randn(*b.shape) for _ in range(order)]
terms_x = [rng.randn(*x.shape) for _ in range(order)]
terms_y = [rng.randn(*y.shape) for _ in range(order)]
series_in = (terms_b, terms_x, terms_y)
# Since this nudges bool inputs, we need to allow promotion to float.
with jax.numpy_dtype_promotion('standard'):
self.check_jet(jnp.where, primals, series_in, rtol=5e-4)
def test_inst_zero(self):
def f(x):
return jnp.full_like(x, 2.)
def g(x):
return 2. + 0 * x
x = jnp.ones(1)
order = 3
f_out_primals, f_out_series = jet(f, (x, ), ([jnp.ones_like(x) for _ in range(order)], ))
assert f_out_series is not zero_series
g_out_primals, g_out_series = jet(g, (x, ), ([jnp.ones_like(x) for _ in range(order)], ))
self.assertArraysEqual(g_out_primals, f_out_primals)
self.assertArraysEqual(g_out_series, f_out_series)
def test_add_any(self):
# https://github.com/google/jax/issues/5217
f = lambda x, eps: x * eps + eps + x
def g(eps):
x = jnp.array(1.)
return jax.grad(f)(x, eps)
jet(g, (1.,), ([1.],)) # doesn't crash
def test_scatter_add(self):
# very basic test from https://github.com/google/jax/issues/5365
def f(x):
x0 = x[0]
x1 = x[1]
return (x0**5 + x1**5).sum()
def h(eps):
from jax import jacfwd, grad
x = jnp.array([1., 1.])
μ = eps * x
def F(t):
return f(x + t * μ)
return grad(jacfwd(F))(0.)
self.check_jet(h, (0.,), ([1., 2., 3.],), rtol=1e-3)
# Script entry point: run this module's tests through absltest with JAX's test loader.
if __name__ == '__main__':
  absltest.main(testLoader=jtu.JaxTestLoader())