# Copyright 2018 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import absolute_import

import functools

from absl import flags
from absl.testing import absltest
from absl.testing import parameterized

import numpy as onp
import numpy.random as npr

from . import api
from .util import partial
from .tree_util import tree_multimap, tree_all, tree_map, tree_reduce

FLAGS = flags.FLAGS
flags.DEFINE_enum(
    'jax_test_dut',
    None,
    enum_values=['cpu', 'gpu', 'tpu'],
    help='Describes the device under test in case special consideration is '
         'required.')


EPS = 1e-4
ATOL = 1e-4
RTOL = 1e-4

_dtype = lambda x: getattr(x, 'dtype', None) or onp.asarray(x).dtype


def numpy_eq(x, y):
  testing_tpu = FLAGS.jax_test_dut and FLAGS.jax_test_dut.startswith("tpu")
  testing_x32 = not FLAGS.jax_enable_x64
  if testing_tpu or testing_x32:
    return onp.allclose(x, y, 1e-3, 1e-3)
  else:
    return onp.allclose(x, y)


def numpy_close(a, b, atol=ATOL, rtol=RTOL, equal_nan=False):
  testing_tpu = FLAGS.jax_test_dut and FLAGS.jax_test_dut.startswith("tpu")
  testing_x32 = not FLAGS.jax_enable_x64
  if testing_tpu or testing_x32:
    atol = max(atol, 1e-1)
    rtol = max(rtol, 1e-1)
  return onp.allclose(a, b, atol=atol, rtol=rtol, equal_nan=equal_nan)


def check_eq(xs, ys):
  assert tree_all(tree_multimap(numpy_eq, xs, ys)), \
      '\n{} != \n{}'.format(xs, ys)


def check_close(xs, ys, atol=ATOL, rtol=RTOL):
  close = partial(numpy_close, atol=atol, rtol=rtol)
  assert tree_all(tree_multimap(close, xs, ys)), '\n{} != \n{}'.format(xs, ys)
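
# Illustrative usage (not part of the library API): check_eq and check_close
# accept arbitrary matching pytrees of arrays, e.g.
#
#   check_close((onp.array([1., 2.]), onp.array(3.)),
#               (onp.array([1., 2.]), onp.array(3. + 1e-6)))
#
# passes because every pair of leaves agrees within atol/rtol.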


def inner_prod(xs, ys):
  contract = lambda x, y: onp.real(onp.vdot(x, y))
  return tree_reduce(onp.add, tree_multimap(contract, xs, ys))


add = partial(tree_multimap, onp.add)
sub = partial(tree_multimap, onp.subtract)
conj = partial(tree_map, onp.conj)
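
# Illustrative example: these helpers apply numpy ops leafwise over matching
# pytrees, which is how the numerical JVP below perturbs structured arguments:
#
#   xs = (onp.ones(2), onp.zeros(3))
#   ys = (onp.ones(2), onp.ones(3))
#   add(xs, ys)         # -> (array([2., 2.]), array([1., 1., 1.]))
#   inner_prod(xs, ys)  # -> 2.0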


def scalar_mul(xs, a):
  return tree_map(lambda x: onp.multiply(x, a, dtype=_dtype(x)), xs)


def rand_like(rng, x):
  shape = onp.shape(x)
  dtype = _dtype(x)
  randn = lambda: onp.asarray(rng.randn(*shape), dtype=dtype)
  if onp.issubdtype(dtype, onp.complexfloating):
    return randn() + 1.0j * randn()
  else:
    return randn()


def numerical_jvp(f, primals, tangents, eps=EPS):
  delta = scalar_mul(tangents, eps)
  f_pos = f(*add(primals, delta))
  f_neg = f(*sub(primals, delta))
  return scalar_mul(sub(f_pos, f_neg), 0.5 / eps)
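
# numerical_jvp approximates the directional derivative with a central
# difference, (f(x + eps*v) - f(x - eps*v)) / (2*eps), applied leafwise over
# pytrees. Illustrative sketch:
#
#   f = lambda x: x ** 3
#   numerical_jvp(f, (onp.array(2.),), (onp.array(1.),))  # approx. 12.0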


def check_jvp(f, f_jvp, args, atol=ATOL, rtol=RTOL, eps=EPS):
  rng = onp.random.RandomState(0)
  tangent = tree_map(partial(rand_like, rng), args)
  v_out, t_out = f_jvp(args, tangent)
  v_out_expected = f(*args)
  t_out_expected = numerical_jvp(f, args, tangent, eps=eps)
  check_eq(v_out, v_out_expected)
  check_close(t_out, t_out_expected, atol=atol, rtol=rtol)
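
# Illustrative usage (assuming f is built from JAX-traceable operations so
# that api.jvp can differentiate it): f_jvp takes (primals_tuple,
# tangents_tuple) and returns (primal_out, tangent_out), which is the
# signature partial(api.jvp, f) provides:
#
#   f = lambda x: x * x
#   check_jvp(f, partial(api.jvp, f), (onp.array(3.),))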


def check_vjp(f, f_vjp, args, atol=ATOL, rtol=RTOL, eps=EPS):
  _rand_like = partial(rand_like, onp.random.RandomState(0))
  v_out, vjpfun = f_vjp(*args)
  v_out_expected = f(*args)
  check_eq(v_out, v_out_expected)
  tangent = tree_map(_rand_like, args)
  tangent_out = numerical_jvp(f, args, tangent, eps=eps)
  cotangent = tree_map(_rand_like, v_out)
  cotangent_out = conj(vjpfun(conj(cotangent)))
  ip = inner_prod(tangent, cotangent_out)
  ip_expected = inner_prod(tangent_out, cotangent)
  check_close(ip, ip_expected, atol=atol, rtol=rtol)
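
# Illustrative usage, mirroring check_jvp (again assuming f is JAX-traceable):
# f_vjp(*args) must return (primal_out, vjp_function), which is what
# partial(api.vjp, f) provides; the identity dot(J v, w) == dot(v, J^T w) is
# then checked against the numerical JVP.
#
#   f = lambda x: x * x
#   check_vjp(f, partial(api.vjp, f), (onp.array(3.),))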


def skip_on_devices(*disabled_devices):
  """A decorator for test methods to skip the test on certain devices."""
  def skip(test_method):
    @functools.wraps(test_method)
    def test_method_wrapper(self, *args, **kwargs):
      device = FLAGS.jax_test_dut
      if device in disabled_devices:
        test_name = getattr(test_method, '__name__', '[unknown test]')
        raise absltest.unittest.SkipTest(
            '{} not supported on {}.'.format(test_name, device.upper()))
      return test_method(self, *args, **kwargs)
    return test_method_wrapper
  return skip
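
# Example (hypothetical test names, for illustration only):
#
#   class FooTest(JaxTestCase):
#     @skip_on_devices('tpu')
#     def testBar(self):
#       ...
#
# The test is skipped when --jax_test_dut names one of the disabled devices.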


def format_test_name_suffix(opname, shapes, dtypes):
  arg_descriptions = (format_shape_dtype_string(shape, dtype)
                      for shape, dtype in zip(shapes, dtypes))
  return '{}_{}'.format(opname.capitalize(), '_'.join(arg_descriptions))


def format_shape_dtype_string(shape, dtype):
  if onp.isscalar(shape):
    shapestr = str(shape) + ','
  else:
    shapestr = ','.join(str(dim) for dim in shape)
  return '{}[{}]'.format(onp.dtype(dtype).name, shapestr)
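
# Example outputs (illustrative):
#   format_shape_dtype_string((3, 4), onp.float32)
#     -> 'float32[3,4]'
#   format_test_name_suffix('add', [(3,), (3,)], [onp.float32, onp.float32])
#     -> 'Add_float32[3]_float32[3]'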


def _rand_dtype(rand, shape, dtype, scale=1., post=lambda x: x):
  """Produce random values given shape, dtype, scale, and post-processor.

  Args:
    rand: a function for producing random values of a given shape, e.g. a
      bound version of either onp.random.RandomState.randn or
      onp.random.RandomState.rand.
    shape: a shape value as a tuple of positive integers.
    dtype: a numpy dtype.
    scale: optional, a multiplicative scale for the random values (default 1).
    post: optional, a callable for post-processing the random values (default
      identity).

  Returns:
    An ndarray of the given shape and dtype using random values based on a
    call to rand but scaled, converted to the appropriate dtype, and
    post-processed.
  """
  r = lambda: onp.asarray(scale * rand(*shape), dtype)
  if onp.issubdtype(dtype, onp.complexfloating):
    vals = r() + 1.0j * r()
  else:
    vals = r()
  return onp.asarray(post(vals), dtype)


def rand_default():
  randn = npr.RandomState(0).randn
  return partial(_rand_dtype, randn, scale=3)
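
# Illustrative usage: each rand_* factory below returns a sampler with the
# signature sampler(shape, dtype), e.g.
#
#   rand = rand_default()
#   x = rand((3, 4), onp.float32)  # float32 array of shape (3, 4)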


def rand_nonzero():
  post = lambda x: onp.where(x == 0, 1, x)
  randn = npr.RandomState(0).randn
  return partial(_rand_dtype, randn, scale=3, post=post)


def rand_positive():
  post = lambda x: x + 1
  rand = npr.RandomState(0).rand
  return partial(_rand_dtype, rand, scale=2, post=post)


def rand_small():
  randn = npr.RandomState(0).randn
  return partial(_rand_dtype, randn, scale=1e-3)


def rand_not_small():
  post = lambda x: x + onp.where(x > 0, 10., -10.)
  randn = npr.RandomState(0).randn
  return partial(_rand_dtype, randn, scale=3., post=post)


def rand_small_positive():
  rand = npr.RandomState(0).rand
  return partial(_rand_dtype, rand, scale=2e-5)


def rand_some_equal():
  randn = npr.RandomState(0).randn
  rng = npr.RandomState(0)

  def post(x):
    flips = rng.rand(*onp.shape(x)) < 0.5
    return onp.where(flips, x.ravel()[0], x)

  return partial(_rand_dtype, randn, scale=100., post=post)


# TODO(mattjj): doesn't handle complex types
def rand_some_inf():
  """Return a random sampler that produces infinities in floating types."""
  rng = npr.RandomState(1)
  base_rand = rand_default()

  def rand(shape, dtype):
    """The random sampler function."""
    if not onp.issubdtype(dtype, onp.floating):
      # only float types have inf
      return base_rand(shape, dtype)

    posinf_flips = rng.rand(*shape) < 0.1
    neginf_flips = rng.rand(*shape) < 0.1

    vals = base_rand(shape, dtype)
    vals = onp.where(posinf_flips, onp.inf, vals)
    vals = onp.where(neginf_flips, -onp.inf, vals)

    return onp.asarray(vals, dtype=dtype)

  return rand


# TODO(mattjj): doesn't handle complex types
def rand_some_zero():
  """Return a random sampler that produces some zeros."""
  rng = npr.RandomState(1)
  base_rand = rand_default()

  def rand(shape, dtype):
    """The random sampler function."""
    zeros = rng.rand(*shape) < 0.5

    vals = base_rand(shape, dtype)
    vals = onp.where(zeros, 0, vals)

    return onp.asarray(vals, dtype=dtype)

  return rand


def rand_bool():
  rng = npr.RandomState(0)
  return lambda shape, dtype: rng.rand(*shape) < 0.5


def check_raises(thunk, err_type, msg):
  try:
    thunk()
    assert False
  except err_type as e:
    assert str(e) == msg, "{}\n\n{}\n".format(e, msg)
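
# Illustrative usage: check_raises asserts both the exception type and the
# exact error message, e.g.
#
#   check_raises(lambda: int('x'), ValueError,
#                "invalid literal for int() with base 10: 'x'")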


class JaxTestCase(parameterized.TestCase):
  """Base class for JAX tests including numerical checks and boilerplate."""

  def assertArraysAllClose(self, x, y, check_dtypes, atol=None, rtol=None):
    """Assert that x and y are close (up to numerical tolerances)."""
    dtype = lambda x: str(onp.asarray(x).dtype)
    tol = 1e-2 if str(onp.dtype(onp.float32)) in {dtype(x), dtype(y)} else 1e-5
    atol = atol or tol
    rtol = rtol or tol

    if FLAGS.jax_test_dut == 'tpu':
      atol = max(atol, 0.5)
      rtol = max(rtol, 1e-1)

    if not onp.allclose(x, y, atol=atol, rtol=rtol, equal_nan=True):
      msg = ('Arguments x and y not equal to tolerance atol={}, rtol={}:\n'
             'x:\n{}\n'
             'y:\n{}\n').format(atol, rtol, x, y)
      raise self.failureException(msg)

    if check_dtypes:
      self.assertDtypesMatch(x, y)

  def assertDtypesMatch(self, x, y):
    if FLAGS.jax_enable_x64:
      self.assertEqual(onp.asarray(x).dtype, onp.asarray(y).dtype)

  def assertAllClose(self, x, y, check_dtypes, atol=None, rtol=None):
    """Assert that x and y, either arrays or nested tuples/lists, are close."""
    if isinstance(x, (tuple, list)):
      self.assertIsInstance(y, (tuple, list))
      self.assertEqual(len(x), len(y))
      for x_elt, y_elt in zip(x, y):
        self.assertAllClose(x_elt, y_elt, check_dtypes, atol=atol, rtol=rtol)
    else:
      is_array = lambda x: hasattr(x, '__array__') or onp.isscalar(x)
      self.assertTrue(is_array(x))
      self.assertTrue(is_array(y))
      x = onp.asarray(x)
      y = onp.asarray(y)
      self.assertArraysAllClose(x, y, check_dtypes, atol=atol, rtol=rtol)

  def _CompileAndCheck(self, fun, args_maker, check_dtypes,
                       rtol=None, atol=None):
    """Helper method for running JAX compilation and allclose assertions."""
    args = args_maker()

    def wrapped_fun(*args):
      self.assertTrue(python_should_be_executing)
      return fun(*args)

    python_should_be_executing = True
    python_ans = fun(*args)

    cfun = api.jit(wrapped_fun)
    python_should_be_executing = True
    monitored_ans = cfun(*args)

    python_should_be_executing = False
    compiled_ans = cfun(*args)

    self.assertAllClose(python_ans, monitored_ans, check_dtypes,
                        atol=atol, rtol=rtol)
    self.assertAllClose(python_ans, compiled_ans, check_dtypes,
                        atol=atol, rtol=rtol)

    args = args_maker()

    python_should_be_executing = True
    python_ans = fun(*args)

    python_should_be_executing = False
    compiled_ans = cfun(*args)

    self.assertAllClose(python_ans, compiled_ans, check_dtypes,
                        atol=atol, rtol=rtol)

  def _CheckAgainstNumpy(self, lax_op, numpy_reference_op, args_maker,
                         check_dtypes=False, tol=1e-5):
    args = args_maker()
    lax_ans = lax_op(*args)
    numpy_ans = numpy_reference_op(*args)
    self.assertAllClose(lax_ans, numpy_ans, check_dtypes=check_dtypes,
                        atol=tol, rtol=tol)
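

# Illustrative sketch of how a test module might use JaxTestCase (the test
# names below are hypothetical, and jax.numpy is assumed to be importable):
#
#   import numpy as onp
#   import jax.numpy as np
#   from jax import test_util as jtu
#
#   class TanhTest(jtu.JaxTestCase):
#
#     def testTanhAgainstNumpy(self):
#       rand = jtu.rand_default()
#       args_maker = lambda: [rand((3, 4), onp.float32)]
#       self._CheckAgainstNumpy(np.tanh, onp.tanh, args_maker)
#       self._CompileAndCheck(np.tanh, args_maker, check_dtypes=True)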