rocm_jax/jax/test_util.py

1085 lines
35 KiB
Python
Raw Normal View History

2018-11-17 18:03:33 -08:00
# Copyright 2018 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from contextlib import contextmanager
2018-11-17 18:03:33 -08:00
import functools
import re
import os
import textwrap
from typing import Dict, Sequence, Union
import unittest
import warnings
import zlib
2018-11-17 18:03:33 -08:00
from absl.testing import absltest
2018-11-17 18:03:33 -08:00
from absl.testing import parameterized
import numpy as np
2018-11-17 18:03:33 -08:00
import numpy.random as npr
from ._src import api
from . import core
from ._src import dtypes as _dtypes
from . import lax
from ._src.config import flags, bool_env, config
from ._src.util import partial, prod
2018-11-17 18:03:33 -08:00
from .tree_util import tree_multimap, tree_all, tree_map, tree_reduce
from .lib import xla_bridge
from .interpreters import xla
2018-11-17 18:03:33 -08:00
2018-11-17 18:03:33 -08:00
FLAGS = flags.FLAGS
flags.DEFINE_enum(
'jax_test_dut', '',
enum_values=['', 'cpu', 'gpu', 'tpu'],
2018-11-17 18:03:33 -08:00
help=
'Describes the device under test in case special consideration is required.'
)
flags.DEFINE_integer(
'num_generated_cases',
int(os.getenv('JAX_NUM_GENERATED_CASES', 10)),
help='Number of generated cases to test')
2018-11-17 18:03:33 -08:00
flags.DEFINE_integer(
'max_cases_sampling_retries',
int(os.getenv('JAX_MAX_CASES_SAMPLING_RETRIES', 100)),
'Number of times a failed test sample should be retried. '
'When an unseen case cannot be generated in this many trials, the '
'sampling process is terminated.'
)
flags.DEFINE_bool(
'jax_skip_slow_tests',
bool_env('JAX_SKIP_SLOW_TESTS', False),
help='Skip tests marked as slow (> 5 sec).'
)
flags.DEFINE_string(
'test_targets', '',
'Regular expression specifying which tests to run, called via re.match on '
'the test name. If empty or unspecified, run all tests.'
)
flags.DEFINE_string(
'exclude_test_targets', '',
'Regular expression specifying which tests NOT to run, called via re.match '
'on the test name. If empty or unspecified, run all tests.'
)
2018-11-17 18:03:33 -08:00
EPS = 1e-4
def _dtype(x):
return (getattr(x, 'dtype', None) or
np.dtype(_dtypes.python_scalar_dtypes.get(type(x), None)) or
np.asarray(x).dtype)
2018-11-17 18:03:33 -08:00
def num_float_bits(dtype):
return _dtypes.finfo(_dtypes.canonicalize_dtype(dtype)).bits
def is_sequence(x):
try:
iter(x)
except TypeError:
return False
else:
return True
_default_tolerance = {
np.dtype(np.bool_): 0,
np.dtype(np.int8): 0,
np.dtype(np.int16): 0,
np.dtype(np.int32): 0,
np.dtype(np.int64): 0,
np.dtype(np.uint8): 0,
np.dtype(np.uint16): 0,
np.dtype(np.uint32): 0,
np.dtype(np.uint64): 0,
np.dtype(_dtypes.bfloat16): 1e-2,
np.dtype(np.float16): 1e-3,
np.dtype(np.float32): 1e-6,
np.dtype(np.float64): 1e-15,
np.dtype(np.complex64): 1e-6,
np.dtype(np.complex128): 1e-15,
}
def default_tolerance():
if device_under_test() != "tpu":
return _default_tolerance
tol = _default_tolerance.copy()
tol[np.dtype(np.float32)] = 1e-3
tol[np.dtype(np.complex64)] = 1e-3
return tol
default_gradient_tolerance = {
np.dtype(_dtypes.bfloat16): 1e-1,
np.dtype(np.float16): 1e-2,
np.dtype(np.float32): 2e-3,
np.dtype(np.float64): 1e-5,
np.dtype(np.complex64): 1e-3,
np.dtype(np.complex128): 1e-5,
}
def _assert_numpy_allclose(a, b, atol=None, rtol=None, err_msg=''):
a = a.astype(np.float32) if a.dtype == _dtypes.bfloat16 else a
b = b.astype(np.float32) if b.dtype == _dtypes.bfloat16 else b
kw = {}
if atol: kw["atol"] = atol
if rtol: kw["rtol"] = rtol
with np.errstate(invalid='ignore'):
# TODO(phawkins): surprisingly, assert_allclose sometimes reports invalid
# value errors. It should not do that.
np.testing.assert_allclose(a, b, **kw, err_msg=err_msg)
def tolerance(dtype, tol=None):
tol = {} if tol is None else tol
if not isinstance(tol, dict):
return tol
tol = {np.dtype(key): value for key, value in tol.items()}
dtype = _dtypes.canonicalize_dtype(np.dtype(dtype))
return tol.get(dtype, default_tolerance()[dtype])
def _normalize_tolerance(tol):
tol = tol or 0
if isinstance(tol, dict):
return {np.dtype(k): v for k, v in tol.items()}
else:
return {k: tol for k in _default_tolerance}
def join_tolerance(tol1, tol2):
tol1 = _normalize_tolerance(tol1)
tol2 = _normalize_tolerance(tol2)
out = tol1
for k, v in tol2.items():
out[k] = max(v, tol1.get(k, 0))
return out
def _assert_numpy_close(a, b, atol=None, rtol=None, err_msg=''):
assert a.shape == b.shape
atol = max(tolerance(a.dtype, atol), tolerance(b.dtype, atol))
rtol = max(tolerance(a.dtype, rtol), tolerance(b.dtype, rtol))
_assert_numpy_allclose(a, b, atol=atol * a.size, rtol=rtol * b.size,
err_msg=err_msg)
2018-11-17 18:03:33 -08:00
def check_eq(xs, ys, err_msg=''):
assert_close = partial(_assert_numpy_allclose, err_msg=err_msg)
tree_all(tree_multimap(assert_close, xs, ys))
2018-11-17 18:03:33 -08:00
def check_close(xs, ys, atol=None, rtol=None, err_msg=''):
assert_close = partial(_assert_numpy_close, atol=atol, rtol=rtol,
err_msg=err_msg)
tree_all(tree_multimap(assert_close, xs, ys))
2018-11-17 18:03:33 -08:00
def _check_dtypes_match(xs, ys):
def _assert_dtypes_match(x, y):
if config.x64_enabled:
assert _dtype(x) == _dtype(y)
else:
assert (_dtypes.canonicalize_dtype(_dtype(x)) ==
_dtypes.canonicalize_dtype(_dtype(y)))
tree_all(tree_multimap(_assert_dtypes_match, xs, ys))
2018-11-17 18:03:33 -08:00
def inner_prod(xs, ys):
def contract(x, y):
return np.real(np.dot(np.conj(x).reshape(-1), y.reshape(-1)))
return tree_reduce(np.add, tree_multimap(contract, xs, ys))
2018-11-17 18:03:33 -08:00
def _safe_subtract(x, y, *, dtype):
"""Subtraction that with `inf - inf == 0` semantics."""
with np.errstate(invalid='ignore'):
return np.where(np.equal(x, y), np.array(0, dtype),
np.subtract(x, y, dtype=dtype))
add = partial(tree_multimap, lambda x, y: np.add(x, y, dtype=_dtype(x)))
sub = partial(tree_multimap, lambda x, y: np.subtract(x, y, dtype=_dtype(x)))
safe_sub = partial(tree_multimap,
lambda x, y: _safe_subtract(x, y, dtype=_dtype(x)))
conj = partial(tree_map, lambda x: np.conj(x, dtype=_dtype(x)))
2018-11-17 18:03:33 -08:00
def scalar_mul(xs, a):
return tree_map(lambda x: np.multiply(x, a, dtype=_dtype(x)), xs)
2018-11-17 18:03:33 -08:00
def rand_like(rng, x):
shape = np.shape(x)
2018-11-17 18:03:33 -08:00
dtype = _dtype(x)
randn = lambda: np.asarray(rng.randn(*shape), dtype=dtype)
if _dtypes.issubdtype(dtype, np.complexfloating):
return randn() + dtype.type(1.0j) * randn()
2018-11-17 18:03:33 -08:00
else:
return randn()
def numerical_jvp(f, primals, tangents, eps=EPS):
delta = scalar_mul(tangents, eps)
2018-11-17 18:03:33 -08:00
f_pos = f(*add(primals, delta))
f_neg = f(*sub(primals, delta))
return scalar_mul(safe_sub(f_pos, f_neg), 0.5 / eps)
2018-11-17 18:03:33 -08:00
def _merge_tolerance(tol, default):
if tol is None:
return default
if not isinstance(tol, dict):
return tol
out = default.copy()
for k, v in tol.items():
out[np.dtype(k)] = v
return out
def check_jvp(f, f_jvp, args, atol=None, rtol=None, eps=EPS, err_msg=''):
atol = _merge_tolerance(atol, default_gradient_tolerance)
rtol = _merge_tolerance(rtol, default_gradient_tolerance)
rng = np.random.RandomState(0)
2018-11-17 18:03:33 -08:00
tangent = tree_map(partial(rand_like, rng), args)
v_out, t_out = f_jvp(args, tangent)
_check_dtypes_match(v_out, t_out)
2018-11-17 18:03:33 -08:00
v_out_expected = f(*args)
_check_dtypes_match(v_out, v_out_expected)
2018-11-17 18:03:33 -08:00
t_out_expected = numerical_jvp(f, args, tangent, eps=eps)
# In principle we should expect exact equality of v_out and v_out_expected,
# but due to nondeterminism especially on GPU (e.g., due to convolution
# autotuning) we only require "close".
check_close(v_out, v_out_expected, atol=atol, rtol=rtol,
err_msg=f'{err_msg} primal' if err_msg else 'primal')
check_close(t_out, t_out_expected, atol=atol, rtol=rtol,
err_msg=f'{err_msg} tangent' if err_msg else 'tangent')
2018-11-17 18:03:33 -08:00
def check_vjp(f, f_vjp, args, atol=None, rtol=None, eps=EPS, err_msg=''):
atol = _merge_tolerance(atol, default_gradient_tolerance)
rtol = _merge_tolerance(rtol, default_gradient_tolerance)
_rand_like = partial(rand_like, np.random.RandomState(0))
2018-11-17 18:03:33 -08:00
v_out, vjpfun = f_vjp(*args)
v_out_expected = f(*args)
check_close(v_out, v_out_expected, atol=atol, rtol=rtol,
err_msg=f'{err_msg} primal' if err_msg else 'primal')
2018-11-17 18:03:33 -08:00
tangent = tree_map(_rand_like, args)
tangent_out = numerical_jvp(f, args, tangent, eps=eps)
2018-11-17 18:03:33 -08:00
cotangent = tree_map(_rand_like, v_out)
cotangent_out = conj(vjpfun(conj(cotangent)))
ip = inner_prod(tangent, cotangent_out)
ip_expected = inner_prod(tangent_out, cotangent)
check_close(ip, ip_expected, atol=atol, rtol=rtol,
err_msg=(f'{err_msg} cotangent projection'
if err_msg else 'cotangent projection'))
2018-11-17 18:03:33 -08:00
def check_grads(f, args, order,
modes=["fwd", "rev"], atol=None, rtol=None, eps=None):
"""Check gradients from automatic differentiation against finite differences.
Gradients are only checked in a single randomly chosen direction, which
ensures that the finite difference calculation does not become prohibitively
expensive even for large input/output spaces.
Args:
f: function to check at ``f(*args)``.
args: tuple of argument values.
order: forward and backwards gradients up to this order are checked.
modes: lists of gradient modes to check ('fwd' and/or 'rev').
atol: absolute tolerance for gradient equality.
rtol: relative tolerance for gradient equality.
eps: step size used for finite differences.
Raises:
AssertionError: if gradients do not match.
"""
args = tuple(args)
eps = eps or EPS
_check_jvp = partial(check_jvp, atol=atol, rtol=rtol, eps=eps)
_check_vjp = partial(check_vjp, atol=atol, rtol=rtol, eps=eps)
def _check_grads(f, args, order, err_msg=''):
if "fwd" in modes:
fwd_msg = f'JVP of {err_msg}' if err_msg else 'JVP'
_check_jvp(f, partial(api.jvp, f), args, err_msg=fwd_msg)
if order > 1:
_check_grads(partial(api.jvp, f), (args, args), order - 1, fwd_msg)
if "rev" in modes:
rev_msg = f'VJP of {err_msg}' if err_msg else 'VJP'
_check_vjp(f, partial(api.vjp, f), args, err_msg=rev_msg)
if order > 1:
def f_vjp(*args):
out_primal_py, vjp_py = api.vjp(f, *args)
return vjp_py(out_primal_py)
_check_grads(f_vjp, args, order - 1, rev_msg)
_check_grads(f, args, order)
@contextmanager
def count_device_put():
device_put = xla.device_put
count = [0]
def device_put_and_count(*args, **kwargs):
count[0] += 1
return device_put(*args, **kwargs)
xla.device_put = device_put_and_count
try:
yield count
finally:
xla.device_put = device_put
@contextmanager
def count_primitive_compiles():
xla.xla_primitive_callable.cache_clear()
# We count how many times we call primitive_computation (which is called
# inside xla_primitive_callable) instead of xla_primitive_callable so we don't
# count cache hits.
primitive_computation = xla.primitive_computation
count = [0]
def primitive_computation_and_count(*args, **kwargs):
count[0] += 1
return primitive_computation(*args, **kwargs)
xla.primitive_computation = primitive_computation_and_count
try:
yield count
finally:
xla.primitive_computation = primitive_computation
@contextmanager
def count_jit_and_pmap_compiles():
# No need to clear any caches since we generally jit and pmap fresh callables
# in tests.
jaxpr_subcomp = xla.jaxpr_subcomp
count = [0]
def jaxpr_subcomp_and_count(*args, **kwargs):
count[0] += 1
return jaxpr_subcomp(*args, **kwargs)
xla.jaxpr_subcomp = jaxpr_subcomp_and_count
try:
yield count
finally:
xla.jaxpr_subcomp = jaxpr_subcomp
@contextmanager
def assert_num_jit_and_pmap_compilations(times):
with count_jit_and_pmap_compiles() as count:
yield
if count[0] != times:
raise AssertionError(f"Expected exactly {times} XLA compilations, "
f"but executed {count[0]}")
def device_under_test():
return FLAGS.jax_test_dut or xla_bridge.get_backend().platform
def if_device_under_test(device_type: Union[str, Sequence[str]],
if_true, if_false):
"""Chooses `if_true` of `if_false` based on device_under_test."""
if device_under_test() in ([device_type] if isinstance(device_type, str)
else device_type):
return if_true
else:
return if_false
def supported_dtypes():
if device_under_test() == "tpu":
types = {np.bool_, np.int8, np.int16, np.int32, np.uint8, np.uint16,
np.uint32, _dtypes.bfloat16, np.float16, np.float32, np.complex64}
else:
types = {np.bool_, np.int8, np.int16, np.int32, np.int64,
np.uint8, np.uint16, np.uint32, np.uint64,
_dtypes.bfloat16, np.float16, np.float32, np.float64,
np.complex64, np.complex128}
if not config.x64_enabled:
types -= {np.uint64, np.int64, np.float64, np.complex128}
return types
2020-04-02 11:13:40 +02:00
def skip_if_unsupported_type(dtype):
dtype = np.dtype(dtype)
if dtype.type not in supported_dtypes():
raise unittest.SkipTest(
f"Type {dtype.name} not supported on {device_under_test()}")
2020-04-02 11:13:40 +02:00
2018-11-17 18:03:33 -08:00
def skip_on_devices(*disabled_devices):
"""A decorator for test methods to skip the test on certain devices."""
def skip(test_method):
@functools.wraps(test_method)
def test_method_wrapper(self, *args, **kwargs):
device = device_under_test()
2018-11-17 18:03:33 -08:00
if device in disabled_devices:
test_name = getattr(test_method, '__name__', '[unknown test]')
raise unittest.SkipTest(
f"{test_name} not supported on {device.upper()}.")
2018-11-17 18:03:33 -08:00
return test_method(self, *args, **kwargs)
return test_method_wrapper
return skip
def set_host_platform_device_count(nr_devices: int):
"""Returns a closure that undoes the operation."""
prev_xla_flags = os.getenv("XLA_FLAGS")
flags_str = prev_xla_flags or ""
# Don't override user-specified device count, or other XLA flags.
if "xla_force_host_platform_device_count" not in flags_str:
os.environ["XLA_FLAGS"] = (flags_str +
f" --xla_force_host_platform_device_count={nr_devices}")
# Clear any cached backends so new CPU backend will pick up the env var.
xla_bridge.get_backend.cache_clear()
def undo():
if prev_xla_flags is None:
del os.environ["XLA_FLAGS"]
else:
os.environ["XLA_FLAGS"] = prev_xla_flags
xla_bridge.get_backend.cache_clear()
return undo
2018-11-17 18:03:33 -08:00
def skip_on_flag(flag_name, skip_value):
"""A decorator for test methods to skip the test when flags are set."""
def skip(test_method): # pylint: disable=missing-docstring
@functools.wraps(test_method)
def test_method_wrapper(self, *args, **kwargs):
2021-03-19 13:49:38 -07:00
flag_value = config._read(flag_name)
if flag_value == skip_value:
test_name = getattr(test_method, '__name__', '[unknown test]')
raise unittest.SkipTest(
f"{test_name} not supported when FLAGS.{flag_name} is {flag_value}")
return test_method(self, *args, **kwargs)
return test_method_wrapper
return skip
2018-11-17 18:03:33 -08:00
def format_test_name_suffix(opname, shapes, dtypes):
arg_descriptions = (format_shape_dtype_string(shape, dtype)
for shape, dtype in zip(shapes, dtypes))
return '{}_{}'.format(opname.capitalize(), '_'.join(arg_descriptions))
# We use special symbols, represented as singleton objects, to distinguish
# between NumPy scalars, Python scalars, and 0-D arrays.
class ScalarShape(object):
def __len__(self): return 0
class _NumpyScalar(ScalarShape): pass
class _PythonScalar(ScalarShape): pass
NUMPY_SCALAR_SHAPE = _NumpyScalar()
PYTHON_SCALAR_SHAPE = _PythonScalar()
def _dims_of_shape(shape):
"""Converts `shape` to a tuple of dimensions."""
if type(shape) in (list, tuple):
return shape
elif isinstance(shape, ScalarShape):
return ()
elif np.ndim(shape) == 0:
return (shape,)
else:
raise TypeError(type(shape))
def _cast_to_shape(value, shape, dtype):
"""Casts `value` to the correct Python type for `shape` and `dtype`."""
if shape is NUMPY_SCALAR_SHAPE:
# explicitly cast to NumPy scalar in case `value` is a Python scalar.
return np.dtype(dtype).type(value)
elif shape is PYTHON_SCALAR_SHAPE:
# explicitly cast to Python scalar via https://stackoverflow.com/a/11389998
return np.asarray(value).item()
elif type(shape) in (list, tuple):
assert np.shape(value) == tuple(shape)
return value
elif np.ndim(shape) == 0:
assert np.shape(value) == (shape,)
return value
else:
raise TypeError(type(shape))
def dtype_str(dtype):
return np.dtype(dtype).name
2018-11-17 18:03:33 -08:00
def format_shape_dtype_string(shape, dtype):
if isinstance(shape, np.ndarray):
return f'{dtype_str(dtype)}[{shape}]'
elif isinstance(shape, list):
shape = tuple(shape)
return _format_shape_dtype_string(shape, dtype)
@functools.lru_cache(maxsize=64)
def _format_shape_dtype_string(shape, dtype):
if shape is NUMPY_SCALAR_SHAPE:
return dtype_str(dtype)
elif shape is PYTHON_SCALAR_SHAPE:
return 'py' + dtype_str(dtype)
elif type(shape) is tuple:
2018-11-17 18:03:33 -08:00
shapestr = ','.join(str(dim) for dim in shape)
return '{}[{}]'.format(dtype_str(dtype), shapestr)
elif type(shape) is int:
return '{}[{},]'.format(dtype_str(dtype), shape)
else:
raise TypeError(type(shape))
2018-11-17 18:03:33 -08:00
def _rand_dtype(rand, shape, dtype, scale=1., post=lambda x: x):
"""Produce random values given shape, dtype, scale, and post-processor.
Args:
rand: a function for producing random values of a given shape, e.g. a
bound version of either np.RandomState.randn or np.RandomState.rand.
2018-11-17 18:03:33 -08:00
shape: a shape value as a tuple of positive integers.
dtype: a numpy dtype.
scale: optional, a multiplicative scale for the random values (default 1).
post: optional, a callable for post-processing the random values (default
identity).
Returns:
An ndarray of the given shape and dtype using random values based on a call
to rand but scaled, converted to the appropriate dtype, and post-processed.
"""
r = lambda: np.asarray(scale * rand(*_dims_of_shape(shape)), dtype)
if _dtypes.issubdtype(dtype, np.complexfloating):
2018-11-17 18:03:33 -08:00
vals = r() + 1.0j * r()
else:
vals = r()
return _cast_to_shape(np.asarray(post(vals), dtype), shape, dtype)
2018-11-17 18:03:33 -08:00
def rand_fullrange(rng, standardize_nans=False):
"""Random numbers that span the full range of available bits."""
def gen(shape, dtype, post=lambda x: x):
dtype = np.dtype(dtype)
size = dtype.itemsize * np.prod(_dims_of_shape(shape))
vals = rng.randint(0, np.iinfo(np.uint8).max, size=size, dtype=np.uint8)
vals = post(vals).view(dtype).reshape(shape)
# Non-standard NaNs cause errors in numpy equality assertions.
if standardize_nans and np.issubdtype(dtype, np.floating):
vals[np.isnan(vals)] = np.nan
return _cast_to_shape(vals, shape, dtype)
return gen
def rand_default(rng, scale=3):
return partial(_rand_dtype, rng.randn, scale=scale)
2018-11-17 18:03:33 -08:00
def rand_nonzero(rng):
post = lambda x: np.where(x == 0, np.array(1, dtype=x.dtype), x)
return partial(_rand_dtype, rng.randn, scale=3, post=post)
2018-11-17 18:03:33 -08:00
def rand_positive(rng):
2018-11-17 18:03:33 -08:00
post = lambda x: x + 1
return partial(_rand_dtype, rng.rand, scale=2, post=post)
2018-11-17 18:03:33 -08:00
def rand_small(rng):
return partial(_rand_dtype, rng.randn, scale=1e-3)
2018-11-17 18:03:33 -08:00
def rand_not_small(rng, offset=10.):
post = lambda x: x + np.where(x > 0, offset, -offset)
return partial(_rand_dtype, rng.randn, scale=3., post=post)
2018-11-17 18:03:33 -08:00
def rand_small_positive(rng):
return partial(_rand_dtype, rng.rand, scale=2e-5)
2018-11-17 18:03:33 -08:00
def rand_uniform(rng, low=0.0, high=1.0):
assert low < high
2019-01-14 16:39:30 -05:00
post = lambda x: x * (high - low) + low
return partial(_rand_dtype, rng.rand, post=post)
2018-11-17 18:03:33 -08:00
def rand_some_equal(rng):
2018-11-17 18:03:33 -08:00
def post(x):
x_ravel = x.ravel()
if len(x_ravel) == 0:
return x
flips = rng.rand(*np.shape(x)) < 0.5
return np.where(flips, x_ravel[0], x)
2018-11-17 18:03:33 -08:00
return partial(_rand_dtype, rng.randn, scale=100., post=post)
2018-11-17 18:03:33 -08:00
def rand_some_inf(rng):
2018-11-17 18:03:33 -08:00
"""Return a random sampler that produces infinities in floating types."""
base_rand = rand_default(rng)
2018-11-17 18:03:33 -08:00
2019-09-23 10:04:31 -07:00
"""
TODO: Complex numbers are not correctly tested
If blocks should be switched in order, and relevant tests should be fixed
"""
2018-11-17 18:03:33 -08:00
def rand(shape, dtype):
"""The random sampler function."""
if not _dtypes.issubdtype(dtype, np.floating):
2018-11-17 18:03:33 -08:00
# only float types have inf
return base_rand(shape, dtype)
if _dtypes.issubdtype(dtype, np.complexfloating):
base_dtype = np.real(np.array(0, dtype=dtype)).dtype
out = (rand(shape, base_dtype) +
np.array(1j, dtype) * rand(shape, base_dtype))
return _cast_to_shape(out, shape, dtype)
dims = _dims_of_shape(shape)
posinf_flips = rng.rand(*dims) < 0.1
neginf_flips = rng.rand(*dims) < 0.1
2018-11-17 18:03:33 -08:00
vals = base_rand(shape, dtype)
vals = np.where(posinf_flips, np.array(np.inf, dtype=dtype), vals)
vals = np.where(neginf_flips, np.array(-np.inf, dtype=dtype), vals)
2018-11-17 18:03:33 -08:00
return _cast_to_shape(np.asarray(vals, dtype=dtype), shape, dtype)
2018-11-17 18:03:33 -08:00
return rand
def rand_some_nan(rng):
"""Return a random sampler that produces nans in floating types."""
base_rand = rand_default(rng)
def rand(shape, dtype):
"""The random sampler function."""
if _dtypes.issubdtype(dtype, np.complexfloating):
base_dtype = np.real(np.array(0, dtype=dtype)).dtype
out = (rand(shape, base_dtype) +
np.array(1j, dtype) * rand(shape, base_dtype))
return _cast_to_shape(out, shape, dtype)
if not _dtypes.issubdtype(dtype, np.floating):
2019-09-23 08:53:49 -07:00
# only float types have inf
return base_rand(shape, dtype)
dims = _dims_of_shape(shape)
r = rng.rand(*dims)
nan_flips = r < 0.1
neg_nan_flips = r < 0.05
vals = base_rand(shape, dtype)
vals = np.where(nan_flips, np.array(np.nan, dtype=dtype), vals)
vals = np.where(neg_nan_flips, np.array(-np.nan, dtype=dtype), vals)
return _cast_to_shape(np.asarray(vals, dtype=dtype), shape, dtype)
2018-11-17 18:03:33 -08:00
return rand
def rand_some_inf_and_nan(rng):
"""Return a random sampler that produces infinities in floating types."""
base_rand = rand_default(rng)
2019-09-23 10:04:31 -07:00
"""
TODO: Complex numbers are not correctly tested
If blocks should be switched in order, and relevant tests should be fixed
"""
def rand(shape, dtype):
"""The random sampler function."""
if not _dtypes.issubdtype(dtype, np.floating):
# only float types have inf
return base_rand(shape, dtype)
if _dtypes.issubdtype(dtype, np.complexfloating):
base_dtype = np.real(np.array(0, dtype=dtype)).dtype
out = (rand(shape, base_dtype) +
np.array(1j, dtype) * rand(shape, base_dtype))
return _cast_to_shape(out, shape, dtype)
dims = _dims_of_shape(shape)
posinf_flips = rng.rand(*dims) < 0.1
neginf_flips = rng.rand(*dims) < 0.1
nan_flips = rng.rand(*dims) < 0.1
vals = base_rand(shape, dtype)
vals = np.where(posinf_flips, np.array(np.inf, dtype=dtype), vals)
vals = np.where(neginf_flips, np.array(-np.inf, dtype=dtype), vals)
vals = np.where(nan_flips, np.array(np.nan, dtype=dtype), vals)
return _cast_to_shape(np.asarray(vals, dtype=dtype), shape, dtype)
return rand
2018-11-17 18:03:33 -08:00
# TODO(mattjj): doesn't handle complex types
def rand_some_zero(rng):
2018-11-17 18:03:33 -08:00
"""Return a random sampler that produces some zeros."""
base_rand = rand_default(rng)
2018-11-17 18:03:33 -08:00
def rand(shape, dtype):
"""The random sampler function."""
dims = _dims_of_shape(shape)
zeros = rng.rand(*dims) < 0.5
2018-11-17 18:03:33 -08:00
vals = base_rand(shape, dtype)
vals = np.where(zeros, np.array(0, dtype=dtype), vals)
2018-11-17 18:03:33 -08:00
return _cast_to_shape(np.asarray(vals, dtype=dtype), shape, dtype)
2018-11-17 18:03:33 -08:00
return rand
def rand_int(rng, low=0, high=None):
def fn(shape, dtype):
nonlocal high
if low == 0 and high is None:
if np.issubdtype(dtype, np.integer):
high = np.iinfo(dtype).max
else:
raise ValueError("rand_int requires an explicit `high` value for "
"non-integer types.")
return rng.randint(low, high=high, size=shape, dtype=dtype)
return fn
def rand_unique_int(rng, high=None):
2020-04-19 11:49:15 -07:00
def fn(shape, dtype):
return rng.choice(np.arange(high or prod(shape), dtype=dtype),
2020-04-19 11:49:15 -07:00
size=shape, replace=False)
return fn
def rand_bool(rng):
def generator(shape, dtype):
return _cast_to_shape(rng.rand(*_dims_of_shape(shape)) < 0.5, shape, dtype)
return generator
2018-11-17 18:03:33 -08:00
def check_raises(thunk, err_type, msg):
try:
thunk()
assert False
except err_type as e:
2018-12-06 21:47:47 -05:00
assert str(e).startswith(msg), "\n{}\n\n{}\n".format(e, msg)
2018-11-17 18:03:33 -08:00
def check_raises_regexp(thunk, err_type, pattern):
try:
thunk()
assert False
except err_type as e:
assert re.match(pattern, str(e)), "{}\n\n{}\n".format(e, pattern)
def iter_eqns(jaxpr):
# TODO(necula): why doesn't this search in params?
for eqn in jaxpr.eqns:
yield eqn
for subjaxpr in core.subjaxprs(jaxpr):
yield from iter_eqns(subjaxpr)
def assert_dot_precision(expected_precision, fun, *args):
jaxpr = api.make_jaxpr(fun)(*args)
precisions = [eqn.params['precision'] for eqn in iter_eqns(jaxpr.jaxpr)
if eqn.primitive == lax.dot_general_p]
for precision in precisions:
msg = "Unexpected precision: {} != {}".format(expected_precision, precision)
if isinstance(precision, tuple):
assert precision[0] == expected_precision, msg
assert precision[1] == expected_precision, msg
else:
assert precision == expected_precision, msg
_CACHED_INDICES: Dict[int, Sequence[int]] = {}
def cases_from_list(xs):
xs = list(xs)
n = len(xs)
k = min(n, FLAGS.num_generated_cases)
# Random sampling for every parameterized test is expensive. Do it once and
# cache the result.
indices = _CACHED_INDICES.get(n)
if indices is None:
rng = npr.RandomState(42)
_CACHED_INDICES[n] = indices = rng.permutation(n)
return [xs[i] for i in indices[:k]]
def cases_from_gens(*gens):
sizes = [1, 3, 10]
cases_per_size = int(FLAGS.num_generated_cases / len(sizes)) + 1
for size in sizes:
for i in range(cases_per_size):
yield ('_{}_{}'.format(size, i),) + tuple(gen(size) for gen in gens)
def named_cases_from_sampler(gen):
seen = set()
retries = 0
rng = npr.RandomState(42)
def choose_one(x):
if not isinstance(x, (list, tuple)):
x = list(x)
return [x[rng.randint(len(x))]]
while (len(seen) < FLAGS.num_generated_cases and
retries < FLAGS.max_cases_sampling_retries):
retries += 1
cases = list(gen(choose_one))
if not cases:
continue
if len(cases) > 1:
raise RuntimeError("Generator is expected to only return a single case when sampling")
case = cases[0]
if case["testcase_name"] in seen:
continue
retries = 0
seen.add(case["testcase_name"])
yield case
class JaxTestLoader(absltest.TestLoader):
def getTestCaseNames(self, testCaseClass):
names = super().getTestCaseNames(testCaseClass)
if FLAGS.test_targets:
pattern = re.compile(FLAGS.test_targets)
names = [name for name in names
if pattern.search(f"{testCaseClass.__name__}.{name}")]
if FLAGS.exclude_test_targets:
pattern = re.compile(FLAGS.exclude_test_targets)
names = [name for name in names
if not pattern.search(f"{testCaseClass.__name__}.{name}")]
return names
2018-11-17 18:03:33 -08:00
class JaxTestCase(parameterized.TestCase):
"""Base class for JAX tests including numerical checks and boilerplate."""
# TODO(mattjj): this obscures the error messages from failures, figure out how
# to re-enable it
# def tearDown(self) -> None:
# assert core.reset_trace_state()
def setUp(self):
super(JaxTestCase, self).setUp()
2021-03-19 13:49:38 -07:00
config.update('jax_enable_checks', True)
# We use the adler32 hash for two reasons.
# a) it is deterministic run to run, unlike hash() which is randomized.
# b) it returns values in int32 range, which RandomState requires.
self._rng = npr.RandomState(zlib.adler32(self._testMethodName.encode()))
def rng(self):
return self._rng
def assertArraysEqual(self, x, y, *, check_dtypes=True):
"""Assert that x and y arrays are exactly equal."""
if check_dtypes:
self.assertDtypesMatch(x, y)
np.testing.assert_array_equal(x, y)
def assertArraysAllClose(self, x, y, *, check_dtypes=True, atol=None,
rtol=None):
2018-11-17 18:03:33 -08:00
"""Assert that x and y are close (up to numerical tolerances)."""
self.assertEqual(x.shape, y.shape)
atol = max(tolerance(_dtype(x), atol), tolerance(_dtype(y), atol))
rtol = max(tolerance(_dtype(x), rtol), tolerance(_dtype(y), rtol))
2018-11-17 18:03:33 -08:00
_assert_numpy_allclose(x, y, atol=atol, rtol=rtol)
2018-11-17 18:03:33 -08:00
if check_dtypes:
self.assertDtypesMatch(x, y)
def assertDtypesMatch(self, x, y, *, canonicalize_dtypes=True):
if not config.x64_enabled and canonicalize_dtypes:
self.assertEqual(_dtypes.canonicalize_dtype(_dtype(x)),
_dtypes.canonicalize_dtype(_dtype(y)))
else:
self.assertEqual(_dtype(x), _dtype(y))
2018-11-17 18:03:33 -08:00
def assertAllClose(self, x, y, *, check_dtypes=True, atol=None, rtol=None,
canonicalize_dtypes=True):
2018-11-17 18:03:33 -08:00
"""Assert that x and y, either arrays or nested tuples/lists, are close."""
if isinstance(x, dict):
self.assertIsInstance(y, dict)
self.assertEqual(set(x.keys()), set(y.keys()))
for k in x.keys():
self.assertAllClose(x[k], y[k], check_dtypes=check_dtypes, atol=atol,
rtol=rtol, canonicalize_dtypes=canonicalize_dtypes)
elif is_sequence(x) and not hasattr(x, '__array__'):
self.assertTrue(is_sequence(y) and not hasattr(y, '__array__'))
self.assertEqual(len(x), len(y))
for x_elt, y_elt in zip(x, y):
self.assertAllClose(x_elt, y_elt, check_dtypes=check_dtypes, atol=atol,
rtol=rtol, canonicalize_dtypes=canonicalize_dtypes)
elif hasattr(x, '__array__') or np.isscalar(x):
self.assertTrue(hasattr(y, '__array__') or np.isscalar(y))
if check_dtypes:
self.assertDtypesMatch(x, y, canonicalize_dtypes=canonicalize_dtypes)
x = np.asarray(x)
y = np.asarray(y)
self.assertArraysAllClose(x, y, check_dtypes=False, atol=atol, rtol=rtol)
elif x == y:
return
else:
raise TypeError((type(x), type(y)))
2018-11-17 18:03:33 -08:00
2019-11-26 13:56:58 +01:00
def assertMultiLineStrippedEqual(self, expected, what):
"""Asserts two strings are equal, after dedenting and stripping each line."""
expected = textwrap.dedent(expected)
what = textwrap.dedent(what)
2019-11-26 13:56:58 +01:00
ignore_space_re = re.compile(r'\s*\n\s*')
expected_clean = re.sub(ignore_space_re, '\n', expected.strip())
what_clean = re.sub(ignore_space_re, '\n', what.strip())
self.assertMultiLineEqual(expected_clean, what_clean,
msg="Found\n{}\nExpecting\n{}".format(what, expected))
def _CompileAndCheck(self, fun, args_maker, *, check_dtypes=True,
2021-03-09 13:25:38 -08:00
rtol=None, atol=None, check_cache_misses=True):
2018-11-17 18:03:33 -08:00
"""Helper method for running JAX compilation and allclose assertions."""
args = args_maker()
def wrapped_fun(*args):
self.assertTrue(python_should_be_executing)
return fun(*args)
python_should_be_executing = True
python_ans = fun(*args)
python_shapes = tree_map(lambda x: np.shape(x), python_ans)
np_shapes = tree_map(lambda x: np.shape(np.asarray(x)), python_ans)
self.assertEqual(python_shapes, np_shapes)
cache_misses = xla.xla_primitive_callable.cache_info().misses
python_ans = fun(*args)
2021-03-09 13:25:38 -08:00
if check_cache_misses:
self.assertEqual(
cache_misses, xla.xla_primitive_callable.cache_info().misses,
"Compilation detected during second call of {} in op-by-op "
"mode.".format(fun))
2018-11-17 18:03:33 -08:00
cfun = api.jit(wrapped_fun)
python_should_be_executing = True
monitored_ans = cfun(*args)
python_should_be_executing = False
compiled_ans = cfun(*args)
self.assertAllClose(python_ans, monitored_ans, check_dtypes=check_dtypes,
atol=atol, rtol=rtol)
self.assertAllClose(python_ans, compiled_ans, check_dtypes=check_dtypes,
atol=atol, rtol=rtol)
2018-11-17 18:03:33 -08:00
args = args_maker()
python_should_be_executing = True
python_ans = fun(*args)
python_should_be_executing = False
compiled_ans = cfun(*args)
self.assertAllClose(python_ans, compiled_ans, check_dtypes=check_dtypes,
atol=atol, rtol=rtol)
2018-11-17 18:03:33 -08:00
def _CheckAgainstNumpy(self, numpy_reference_op, lax_op, args_maker,
check_dtypes=True, tol=None,
canonicalize_dtypes=True):
2018-11-17 18:03:33 -08:00
args = args_maker()
lax_ans = lax_op(*args)
numpy_ans = numpy_reference_op(*args)
self.assertAllClose(numpy_ans, lax_ans, check_dtypes=check_dtypes,
atol=tol, rtol=tol,
canonicalize_dtypes=canonicalize_dtypes)
class BufferDonationTestCase(JaxTestCase):
assertDeleted = lambda self, x: self._assertDeleted(x, True)
assertNotDeleted = lambda self, x: self._assertDeleted(x, False)
def _assertDeleted(self, x, deleted):
if hasattr(x, "device_buffer"):
self.assertEqual(x.device_buffer.is_deleted(), deleted)
else:
for buffer in x.device_buffers:
self.assertEqual(buffer.is_deleted(), deleted)
@contextmanager
def ignore_warning(**kw):
with warnings.catch_warnings():
warnings.filterwarnings("ignore", **kw)
yield
class _cached_property:
null = object()
def __init__(self, method):
self._method = method
self._value = self.null
def __get__(self, obj, cls):
if self._value is self.null:
self._value = self._method(obj)
return self._value
class _LazyDtypes:
"""A class that unifies lists of supported dtypes.
These could be module-level constants, but device_under_test() is not always
known at import time, so we need to define these lists lazily.
"""
def supported(self, dtypes):
supported = supported_dtypes()
return type(dtypes)(d for d in dtypes if d in supported)
@_cached_property
def floating(self):
return self.supported([np.float32, np.float64])
@_cached_property
def all_floating(self):
return self.supported([_dtypes.bfloat16, np.float16, np.float32, np.float64])
@_cached_property
def integer(self):
return self.supported([np.int32, np.int64])
@_cached_property
def all_integer(self):
return self.supported([np.int8, np.int16, np.int32, np.int64])
@_cached_property
def unsigned(self):
return self.supported([np.uint32, np.uint64])
@_cached_property
def all_unsigned(self):
return self.supported([np.uint8, np.uint16, np.uint32, np.uint64])
@_cached_property
def complex(self):
return self.supported([np.complex64, np.complex128])
@_cached_property
def boolean(self):
return self.supported([np.bool_])
@_cached_property
def inexact(self):
return self.floating + self.complex
@_cached_property
def all_inexact(self):
return self.all_floating + self.complex
@_cached_property
def numeric(self):
return self.floating + self.integer + self.unsigned + self.complex
@_cached_property
def all(self):
return (self.all_floating + self.all_integer + self.all_unsigned +
self.complex + self.boolean)
dtypes = _LazyDtypes()