2021-09-24 07:02:08 -07:00
|
|
|
# Copyright 2018 Google LLC
|
|
|
|
#
|
|
|
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
|
|
# you may not use this file except in compliance with the License.
|
|
|
|
# You may obtain a copy of the License at
|
|
|
|
#
|
|
|
|
# https://www.apache.org/licenses/LICENSE-2.0
|
|
|
|
#
|
|
|
|
# Unless required by applicable law or agreed to in writing, software
|
|
|
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
|
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
|
|
# See the License for the specific language governing permissions and
|
|
|
|
# limitations under the License.
|
|
|
|
|
|
|
|
from contextlib import contextmanager
|
|
|
|
import inspect
|
|
|
|
import functools
|
|
|
|
from functools import partial
|
|
|
|
import re
|
|
|
|
import os
|
|
|
|
import textwrap
|
2022-04-21 13:44:12 -07:00
|
|
|
from typing import Dict, List, Generator, Sequence, Tuple, Union
|
2021-09-24 07:02:08 -07:00
|
|
|
import unittest
|
|
|
|
import warnings
|
|
|
|
import zlib
|
|
|
|
|
|
|
|
from absl.testing import absltest
|
|
|
|
from absl.testing import parameterized
|
|
|
|
|
|
|
|
import numpy as np
|
|
|
|
import numpy.random as npr
|
|
|
|
|
|
|
|
from jax._src import api
|
|
|
|
from jax import core
|
|
|
|
from jax._src import dtypes as _dtypes
|
|
|
|
from jax import lax
|
|
|
|
from jax._src.config import flags, bool_env, config
|
|
|
|
from jax._src.util import prod, unzip2
|
2022-04-04 14:39:43 -07:00
|
|
|
from jax.tree_util import tree_map, tree_all
|
2021-09-24 07:02:08 -07:00
|
|
|
from jax._src.lib import xla_bridge
|
2021-11-22 08:22:10 -08:00
|
|
|
from jax._src import dispatch
|
2022-04-21 13:44:12 -07:00
|
|
|
from jax._src.public_test_util import ( # noqa: F401
|
2022-04-04 14:39:43 -07:00
|
|
|
_assert_numpy_allclose, _check_dtypes_match, _default_tolerance, _dtype, check_close, check_grads,
|
|
|
|
check_jvp, check_vjp, default_gradient_tolerance, default_tolerance, device_under_test, tolerance)
|
2021-11-30 06:08:26 -08:00
|
|
|
from jax.interpreters import mlir
|
2022-02-23 10:46:27 -08:00
|
|
|
from jax.experimental.maps import Mesh
|
2021-09-24 07:02:08 -07:00
|
|
|
|
2022-04-04 14:39:43 -07:00
|
|
|
# This submodule includes private test utilities that are not exported to
|
|
|
|
# jax.test_util. Functionality appearing here is for internal use only, and
|
|
|
|
# may be changed or removed at any time and without any deprecation cycle.
|
2021-09-24 07:02:08 -07:00
|
|
|
|
|
|
|
FLAGS = flags.FLAGS

# Flags that control test selection and parameterized-case generation. The
# defaults of several flags come from JAX_* environment variables, which are
# read once, when this module is imported.
flags.DEFINE_string(
    'jax_test_dut', '',
    help='Describes the device under test in case special consideration is required.')

flags.DEFINE_integer(
    'num_generated_cases',
    int(os.getenv('JAX_NUM_GENERATED_CASES', '10')),
    help='Number of generated cases to test')

flags.DEFINE_integer(
    'max_cases_sampling_retries',
    int(os.getenv('JAX_MAX_CASES_SAMPLING_RETRIES', '100')),
    help=('Number of times a failed test sample should be retried. '
          'When an unseen case cannot be generated in this many trials, the '
          'sampling process is terminated.'))

flags.DEFINE_bool(
    'jax_skip_slow_tests',
    bool_env('JAX_SKIP_SLOW_TESTS', False),
    help='Skip tests marked as slow (> 5 sec).')

flags.DEFINE_string(
    'test_targets', '',
    help=('Regular expression specifying which tests to run, called via re.search on '
          'the test name. If empty or unspecified, run all tests.'))

flags.DEFINE_string(
    'exclude_test_targets', '',
    help=('Regular expression specifying which tests NOT to run, called via re.search '
          'on the test name. If empty or unspecified, run all tests.'))
|
|
|
|
|
|
|
|
def num_float_bits(dtype):
  """Returns the bit width of `dtype` after JAX dtype canonicalization."""
  canonical = _dtypes.canonicalize_dtype(dtype)
  info = _dtypes.finfo(canonical)
  return info.bits
|
|
|
|
|
2021-12-09 16:57:29 -08:00
|
|
|
def to_default_dtype(arr):
  """Casts `arr` to JAX's default dtype for its dtype kind.

  Mostly used to normalize values returned by NumPy reference functions so
  that their dtypes reflect the ``jax_enable_x64`` and
  ``jax_default_dtype_bits`` flags. An input whose dtype kind has no entry in
  JAX's default-type table is returned as a plain ndarray, without casting.
  """
  as_array = np.asarray(arr)
  default_type = _dtypes._default_types.get(as_array.dtype.kind)
  if not default_type:
    return as_array
  return as_array.astype(_dtypes.canonicalize_dtype(default_type))
|
|
|
|
|
|
|
|
def with_jax_dtype_defaults(func, use_defaults=True):
  """Wraps `func` so that its outputs use JAX's default dtypes.

  Tests typically apply this to NumPy reference functions, so that their
  default output dtypes match those of the corresponding JAX functions under
  the current ``jax_enable_x64`` and ``jax_default_dtype_bits`` flags.

  Args:
    func: the callable to wrap.
    use_defaults: which outputs to convert to the default dtype. Either a
      single bool applied to every output, or a pytree of bools with the same
      structure as the output of `func`.

  Returns:
    A wrapper with the same signature as `func`.
  """
  @functools.wraps(func)
  def wrapped(*args, **kwargs):
    outputs = func(*args, **kwargs)
    if not isinstance(use_defaults, bool):
      # Per-leaf selection: `use_defaults` mirrors the output pytree.
      def convert_leaf(leaf, convert):
        return to_default_dtype(leaf) if convert else leaf
      return tree_map(convert_leaf, outputs, use_defaults)
    if use_defaults:
      return tree_map(to_default_dtype, outputs)
    return outputs
  return wrapped
|
2021-09-24 07:02:08 -07:00
|
|
|
|
|
|
|
def is_sequence(x):
  """Returns True when `iter(x)` succeeds, i.e. `x` is iterable."""
  try:
    iter(x)
  except TypeError:
    return False
  return True
|
|
|
|
|
|
|
|
def _normalize_tolerance(tol):
  """Expands a tolerance spec into a dict keyed by dtype.

  A non-empty dict is re-keyed with `np.dtype(key)`. Any other value is
  applied to every dtype in `_default_tolerance`; a falsy value, including an
  empty dict, is applied as 0.
  """
  if not tol:
    tol = 0
  if not isinstance(tol, dict):
    return dict.fromkeys(_default_tolerance, tol)
  return {np.dtype(key): value for key, value in tol.items()}
|
|
|
|
|
|
|
|
def join_tolerance(tol1, tol2):
  """Merges two tolerance specs, keeping the larger value per dtype.

  Both specs are first normalized by `_normalize_tolerance`. Dtypes found
  only in `tol1` keep their value. A dtype found in `tol2` gets
  max(tol2 value, tol1 value or 0).
  """
  joined = _normalize_tolerance(tol1)
  for dtype, tol in _normalize_tolerance(tol2).items():
    joined[dtype] = max(tol, joined.get(dtype, 0))
  return joined
|
|
|
|
|
|
|
|
|
|
|
|
def check_eq(xs, ys, err_msg=''):
  """Asserts that pytrees `xs` and `ys` are numerically close, leaf by leaf."""
  def assert_leaf_close(x, y):
    return _assert_numpy_allclose(x, y, err_msg=err_msg)
  tree_all(tree_map(assert_leaf_close, xs, ys))
|
2021-09-24 07:02:08 -07:00
|
|
|
|
|
|
|
|
|
|
|
@contextmanager
def count_device_put():
  """Counts calls to `dispatch.device_put` made inside the context.

  Yields a one-element list whose entry is incremented on every call. The
  original `dispatch.device_put` is restored on exit, even when the body
  raises.
  """
  original_device_put = dispatch.device_put
  counter = [0]

  def counting_device_put(*args, **kwargs):
    counter[0] += 1
    return original_device_put(*args, **kwargs)

  dispatch.device_put = counting_device_put
  try:
    yield counter
  finally:
    dispatch.device_put = original_device_put
|
2021-09-24 07:02:08 -07:00
|
|
|
|
|
|
|
|
|
|
|
@contextmanager
def count_primitive_compiles():
  """Counts cache misses of `dispatch.xla_primitive_callable` in the context.

  Clears the cache on entry and yields a one-element list. The entry stays -1
  while the context is active. On exit it is set to the cache's miss count
  since the clear.
  """
  dispatch.xla_primitive_callable.cache_clear()
  result = [-1]
  try:
    yield result
  finally:
    result[0] = dispatch.xla_primitive_callable.cache_info().misses
|
2021-09-24 07:02:08 -07:00
|
|
|
|
|
|
|
|
|
|
|
@contextmanager
def count_jit_and_pmap_compiles():
  """Counts calls to `mlir.jaxpr_subcomp` made inside the context.

  Yields a one-element list incremented on each call; callers use it as a
  proxy for the number of jit/pmap compilations. The original function is
  restored on exit.
  """
  # No need to clear any caches since we generally jit and pmap fresh
  # callables in tests.
  original_subcomp = mlir.jaxpr_subcomp
  counter = [0]

  def counting_subcomp(*args, **kwargs):
    counter[0] += 1
    return original_subcomp(*args, **kwargs)

  mlir.jaxpr_subcomp = counting_subcomp
  try:
    yield counter
  finally:
    mlir.jaxpr_subcomp = original_subcomp
|
2021-09-24 07:02:08 -07:00
|
|
|
|
|
|
|
@contextmanager
def assert_num_jit_and_pmap_compilations(times):
  """Asserts that the wrapped block compiles exactly `times` times.

  The count comes from `count_jit_and_pmap_compiles`. The check runs only if
  the block completes without raising.

  Raises:
    AssertionError: if the observed count differs from `times`.
  """
  with count_jit_and_pmap_compiles() as counter:
    yield
  observed = counter[0]
  if observed != times:
    raise AssertionError(f"Expected exactly {times} XLA compilations, "
                         f"but executed {observed}")
|
|
|
|
|
|
|
|
def if_device_under_test(device_type: Union[str, Sequence[str]],
                         if_true, if_false):
  """Chooses `if_true` or `if_false` based on device_under_test().

  Args:
    device_type: a device name, or a sequence of names, to compare against
      `device_under_test()`.
    if_true: returned when the device under test matches.
    if_false: returned otherwise.
  """
  current = device_under_test()
  candidates = [device_type] if isinstance(device_type, str) else device_type
  return if_true if current in candidates else if_false
|
|
|
|
|
|
|
|
def supported_dtypes():
  """Returns the set of dtypes that tests should exercise on this device.

  The base set depends on `device_under_test()` ("tpu", "iree", or anything
  else). 64-bit types are removed when x64 mode is disabled.
  """
  device = device_under_test()
  if device == "tpu":
    types = {np.bool_, np.int8, np.int16, np.int32, np.uint8, np.uint16,
             np.uint32, _dtypes.bfloat16, np.float16, np.float32, np.complex64}
  elif device == "iree":
    types = {np.bool_, np.int8, np.int16, np.int32, np.uint8, np.uint16,
             np.uint32, np.float32}
  else:
    types = {np.bool_, np.int8, np.int16, np.int32, np.int64,
             np.uint8, np.uint16, np.uint32, np.uint64,
             _dtypes.bfloat16, np.float16, np.float32, np.float64,
             np.complex64, np.complex128}
  if not config.x64_enabled:
    types -= {np.uint64, np.int64, np.float64, np.complex128}
  return types
|
|
|
|
|
|
|
|
def is_device_rocm():
  """Returns True if the default backend's platform version starts with 'rocm'."""
  platform_version = xla_bridge.get_backend().platform_version
  return platform_version.startswith('rocm')
|
|
|
|
|
|
|
|
def is_device_cuda():
  """Returns True if the default backend's platform version starts with 'cuda'."""
  platform_version = xla_bridge.get_backend().platform_version
  return platform_version.startswith('cuda')
|
|
|
|
|
|
|
|
def _get_device_tags():
  """Returns a set of tags defined for the device under test.

  The set always contains `device_under_test()`. On ROCm or CUDA backends it
  also contains "rocm" or "cuda" respectively.
  """
  if is_device_rocm():
    backend_tag = "rocm"
  elif is_device_cuda():
    backend_tag = "cuda"
  else:
    backend_tag = None
  tags = {device_under_test()}
  if backend_tag is not None:
    tags.add(backend_tag)
  return tags
|
|
|
|
|
|
|
|
def skip_on_devices(*disabled_devices):
  """Decorator that skips a test method on devices with any disabled tag.

  Args:
    *disabled_devices: device tags (as produced by `_get_device_tags`) on
      which the decorated test is skipped.
  """
  def decorator(test_method):
    @functools.wraps(test_method)
    def wrapper(self, *args, **kwargs):
      tags = _get_device_tags()
      if not tags.isdisjoint(disabled_devices):
        name = getattr(test_method, '__name__', '[unknown test]')
        raise unittest.SkipTest(
            f"{name} not supported on device with tags {tags}.")
      return test_method(self, *args, **kwargs)
    return wrapper
  return decorator
|
|
|
|
|
|
|
|
def set_host_platform_device_count(nr_devices: int):
  """Requests `nr_devices` host-platform devices through ``XLA_FLAGS``.

  Appends ``--xla_force_host_platform_device_count=<nr_devices>`` to the
  ``XLA_FLAGS`` environment variable unless a device count is already given
  there. Then clears the cached backends so a new CPU backend picks it up.

  Returns:
    A closure that undoes the operation. It restores ``XLA_FLAGS`` to its
    previous value (deleting it if it was unset) and clears the cached
    backends again.
  """
  saved_xla_flags = os.getenv("XLA_FLAGS")
  existing = saved_xla_flags or ""
  # Don't override user-specified device count, or other XLA flags.
  if "xla_force_host_platform_device_count" not in existing:
    count_flag = f" --xla_force_host_platform_device_count={nr_devices}"
    os.environ["XLA_FLAGS"] = existing + count_flag
  # Clear any cached backends so new CPU backend will pick up the env var.
  xla_bridge.get_backend.cache_clear()

  def undo():
    if saved_xla_flags is not None:
      os.environ["XLA_FLAGS"] = saved_xla_flags
    else:
      del os.environ["XLA_FLAGS"]
    xla_bridge.get_backend.cache_clear()

  return undo
|
|
|
|
|
|
|
|
def skip_on_flag(flag_name, skip_value):
  """Decorator that skips a test method when a config flag has a given value.

  Args:
    flag_name: name of the JAX config option, read with `config._read`.
    skip_value: the test is skipped when the option equals this value.
  """
  def decorator(test_method):  # pylint: disable=missing-docstring
    @functools.wraps(test_method)
    def wrapper(self, *args, **kwargs):
      current = config._read(flag_name)
      if current == skip_value:
        name = getattr(test_method, '__name__', '[unknown test]')
        raise unittest.SkipTest(
            f"{name} not supported when FLAGS.{flag_name} is {current}")
      return test_method(self, *args, **kwargs)
    return wrapper
  return decorator
|
|
|
|
|
|
|
|
|
|
|
|
def format_test_name_suffix(opname, shapes, dtypes):
  """Builds a test-name suffix such as ``Add_float32[2,3]_int32[]``."""
  described = [format_shape_dtype_string(shape, dtype)
               for shape, dtype in zip(shapes, dtypes)]
  return f"{opname.capitalize()}_{'_'.join(described)}"
|
|
|
|
|
|
|
|
|
|
|
|
# Singleton marker objects stand in for a shape when a test argument should be
# a scalar. They let NumPy scalars, Python scalars, and 0-D arrays be told
# apart.
class ScalarShape:
  """Base class of the scalar-shape markers; has length 0 like ``()``."""

  def __len__(self):
    return 0


class _NumpyScalar(ScalarShape):
  """Marker type for NumPy scalar arguments."""


class _PythonScalar(ScalarShape):
  """Marker type for Python scalar arguments."""


NUMPY_SCALAR_SHAPE = _NumpyScalar()
PYTHON_SCALAR_SHAPE = _PythonScalar()
|
|
|
|
|
|
|
|
|
2022-06-06 21:28:13 -07:00
|
|
|
# Some shape combinations don't make sense.
def is_valid_shape(shape, dtype):
  """Returns whether `shape` is meaningful for values of `dtype`.

  A Python-scalar shape is valid only when a `dtype` value converted to a
  Python scalar (via `.item()`) maps back to the same dtype. For example,
  float64 qualifies but float32 does not. Every other shape is valid.
  """
  if shape is PYTHON_SCALAR_SHAPE:
    return dtype == np.dtype(type(np.array(0, dtype=dtype).item()))
  # Without this, the function fell through and returned None (falsy),
  # rejecting every ordinary shape.
  return True
|
|
|
|
|
|
|
|
|
2021-09-24 07:02:08 -07:00
|
|
|
def _dims_of_shape(shape):
  """Returns the dimensions described by `shape`.

  Lists and tuples are returned unchanged. Scalar-shape markers map to ``()``.
  A 0-d value such as an int ``n`` maps to ``(n,)``.

  Raises:
    TypeError: for any other kind of `shape`.
  """
  if type(shape) in (list, tuple):
    return shape
  if isinstance(shape, ScalarShape):
    return ()
  if np.ndim(shape) != 0:
    raise TypeError(type(shape))
  return (shape,)
|
|
|
|
|
|
|
|
|
|
|
|
def _cast_to_shape(value, shape, dtype):
  """Converts `value` into the kind of object implied by `shape` and `dtype`.

  - NUMPY_SCALAR_SHAPE: a NumPy scalar of `dtype`.
  - PYTHON_SCALAR_SHAPE: a plain Python scalar.
  - list/tuple or 0-d (e.g. int) shape: `value` itself, after asserting that
    its shape matches.

  Raises:
    TypeError: for any other kind of `shape`.
  """
  if shape is NUMPY_SCALAR_SHAPE:
    # Build a NumPy scalar even when `value` is a Python scalar.
    return np.dtype(dtype).type(value)
  if shape is PYTHON_SCALAR_SHAPE:
    # .item() yields a Python scalar; see https://stackoverflow.com/a/11389998
    return np.asarray(value).item()
  if type(shape) in (list, tuple):
    expected = tuple(shape)
  elif np.ndim(shape) == 0:
    expected = (shape,)
  else:
    raise TypeError(type(shape))
  assert np.shape(value) == expected
  return value
|
|
|
|
|
|
|
|
|
|
|
|
def dtype_str(dtype):
  """Returns the NumPy name of `dtype`, e.g. 'float32'."""
  normalized = np.dtype(dtype)
  return normalized.name
|
|
|
|
|
|
|
|
|
|
|
|
def format_shape_dtype_string(shape, dtype):
  """Formats a (shape, dtype) pair for use in test names.

  ndarray shapes are interpolated directly. Lists become tuples so they can be
  passed to the cached `_format_shape_dtype_string`.
  """
  if isinstance(shape, np.ndarray):
    return f'{dtype_str(dtype)}[{shape}]'
  normalized = tuple(shape) if isinstance(shape, list) else shape
  return _format_shape_dtype_string(normalized, dtype)
|
|
|
|
|
|
|
|
@functools.lru_cache(maxsize=64)
def _format_shape_dtype_string(shape, dtype):
  """Cached formatter producing e.g. 'float32', 'pyfloat64', 'int8[2,3]', 'int8[4,]'.

  Raises:
    TypeError: if `shape` is not a scalar marker, tuple, or int.
  """
  if shape is NUMPY_SCALAR_SHAPE:
    return dtype_str(dtype)
  if shape is PYTHON_SCALAR_SHAPE:
    return f'py{dtype_str(dtype)}'
  if type(shape) is tuple:
    dims = ','.join(map(str, shape))
    return f'{dtype_str(dtype)}[{dims}]'
  if type(shape) is int:
    return f'{dtype_str(dtype)}[{shape},]'
  raise TypeError(type(shape))
|
|
|
|
|
|
|
|
|
|
|
|
def _rand_dtype(rand, shape, dtype, scale=1., post=lambda x: x):
  """Produce random values given shape, dtype, scale, and post-processor.

  Args:
    rand: a function for producing random values of a given shape, e.g. a
      bound version of either np.RandomState.randn or np.RandomState.rand.
    shape: a tuple/list of dimensions, an int, or one of the scalar-shape
      markers.
    dtype: a numpy dtype.
    scale: optional, a multiplicative scale for the random values (default 1).
    post: optional, a callable for post-processing the random values (default
      identity).

  Returns:
    Values built from `rand`, scaled by `scale`, and converted to `dtype`.
    Absolute values are taken first for unsigned dtypes, and complex dtypes
    combine two independent draws. The result is then post-processed and
    shaped according to `shape` (see `_cast_to_shape`).
  """
  is_unsigned = _dtypes.issubdtype(dtype, np.unsignedinteger)

  def draw():
    raw = rand(*_dims_of_shape(shape))
    magnitude = abs(raw) if is_unsigned else raw
    return np.asarray(scale * magnitude, dtype)

  if _dtypes.issubdtype(dtype, np.complexfloating):
    vals = draw() + 1.0j * draw()
  else:
    vals = draw()
  return _cast_to_shape(np.asarray(post(vals), dtype), shape, dtype)
|
|
|
|
|
|
|
|
|
|
|
|
def rand_fullrange(rng, standardize_nans=False):
  """Random numbers that span the full range of available bits.

  Uniformly random bytes are drawn and reinterpreted as `dtype`, so every bit
  pattern of the dtype can occur.

  Args:
    rng: a `numpy.random.RandomState`-like object providing `randint`.
    standardize_nans: if True, NaNs produced for floating dtypes are replaced
      by the canonical `np.nan`.

  Returns:
    A function ``gen(shape, dtype, post=identity)``. `post` is applied to the
    raw uint8 byte array before it is reinterpreted as `dtype`.
  """
  def gen(shape, dtype, post=lambda x: x):
    dtype = np.dtype(dtype)
    dims = _dims_of_shape(shape)
    # np.prod(()) is the float 1.0, but randint needs an integer size.
    size = dtype.itemsize * int(np.prod(dims))
    # randint's upper bound is exclusive. Use max + 1 so byte 0xFF can occur;
    # without it e.g. uint8 255, int8 -1, and any value whose bytes include
    # 0xFF were never generated.
    vals = rng.randint(0, np.iinfo(np.uint8).max + 1, size=size, dtype=np.uint8)
    # Reshape with the resolved dims: scalar-shape markers are not valid
    # arguments to ndarray.reshape.
    vals = post(vals).view(dtype).reshape(dims)
    # Non-standard NaNs cause errors in numpy equality assertions.
    if standardize_nans and np.issubdtype(dtype, np.floating):
      vals[np.isnan(vals)] = np.nan
    return _cast_to_shape(vals, shape, dtype)
  return gen
|
|
|
|
|
|
|
|
|
|
|
|
def rand_default(rng, scale=3):
  """Sampler of normally distributed values (`rng.randn`) multiplied by `scale`."""
  normal = rng.randn
  return partial(_rand_dtype, normal, scale=scale)
|
|
|
|
|
|
|
|
|
|
|
|
def rand_nonzero(rng):
  """Sampler of `rng.randn` values scaled by 3, with exact zeros replaced by 1."""
  def replace_zeros(x):
    one = np.array(1, dtype=x.dtype)
    return np.where(x == 0, one, x)
  return partial(_rand_dtype, rng.randn, scale=3, post=replace_zeros)
|
|
|
|
|
|
|
|
|
|
|
|
def rand_positive(rng):
  """Sampler of positive values: uniform `rng.rand` draws times 2, plus 1."""
  def shift_up(x):
    return x + 1
  return partial(_rand_dtype, rng.rand, scale=2, post=shift_up)
|
|
|
|
|
|
|
|
|
|
|
|
def rand_small(rng):
  """Sampler of small values: `rng.randn` draws multiplied by 1e-3."""
  tiny_scale = 1e-3
  return partial(_rand_dtype, rng.randn, scale=tiny_scale)
|
|
|
|
|
|
|
|
|
|
|
|
def rand_not_small(rng, offset=10.):
  """Sampler of values pushed away from zero by `offset`.

  Draws `rng.randn` values scaled by 3, then adds `offset` to positive entries
  and subtracts it from all others.
  """
  def push_from_zero(x):
    return x + np.where(x > 0, offset, -offset)
  return partial(_rand_dtype, rng.randn, scale=3., post=push_from_zero)
|
|
|
|
|
|
|
|
|
|
|
|
def rand_small_positive(rng):
  """Sampler of small non-negative values: uniform `rng.rand` times 2e-5."""
  uniform = rng.rand
  return partial(_rand_dtype, uniform, scale=2e-5)
|
|
|
|
|
|
|
|
def rand_uniform(rng, low=0.0, high=1.0):
  """Sampler that rescales uniform `rng.rand` draws from [0, 1) to [low, high)."""
  assert low < high
  width = high - low

  def rescale(x):
    return x * width + low

  return partial(_rand_dtype, rng.rand, post=rescale)
|
|
|
|
|
|
|
|
|
|
|
|
def rand_some_equal(rng):
  """Sampler in which each entry independently (p=0.5) copies the first entry."""

  def duplicate_first(x):
    flat = x.ravel()
    if len(flat) == 0:
      return x
    use_first = rng.rand(*np.shape(x)) < 0.5
    return np.where(use_first, flat[0], x)

  return partial(_rand_dtype, rng.randn, scale=100., post=duplicate_first)
|
|
|
|
|
|
|
|
|
|
|
|
def rand_some_inf(rng):
  """Return a random sampler that produces infinities in floating types.

  For real floating dtypes, each entry of a `rand_default` sample is
  overwritten with +inf (p=0.1) and then with -inf (p=0.1). Other dtypes get
  plain `rand_default` samples.
  """
  base_rand = rand_default(rng)

  # TODO: Complex numbers are not correctly tested
  # If blocks should be switched in order, and relevant tests should be fixed
  def rand(shape, dtype):
    """The random sampler function."""
    if not _dtypes.issubdtype(dtype, np.floating):
      # only float types have inf. NOTE(review): complex dtypes presumably
      # also land here, leaving the complex branch below unreachable (see TODO).
      return base_rand(shape, dtype)

    if _dtypes.issubdtype(dtype, np.complexfloating):
      base_dtype = np.real(np.array(0, dtype=dtype)).dtype
      out = (rand(shape, base_dtype) +
             np.array(1j, dtype) * rand(shape, base_dtype))
      return _cast_to_shape(out, shape, dtype)

    dims = _dims_of_shape(shape)
    pos_mask = rng.rand(*dims) < 0.1
    neg_mask = rng.rand(*dims) < 0.1

    samples = base_rand(shape, dtype)
    samples = np.where(pos_mask, np.array(np.inf, dtype=dtype), samples)
    samples = np.where(neg_mask, np.array(-np.inf, dtype=dtype), samples)
    return _cast_to_shape(np.asarray(samples, dtype=dtype), shape, dtype)

  return rand
|
|
|
|
|
|
|
|
def rand_some_nan(rng):
  """Return a random sampler that produces nans in floating types.

  For real floating dtypes, about 10% of the entries of a `rand_default`
  sample become NaN. Half of those (about 5% overall) are written as
  `-np.nan` instead of `np.nan`. Complex dtypes combine two independent real
  samples as real and imaginary parts. All other dtypes get plain
  `rand_default` samples.
  """
  base_rand = rand_default(rng)

  def rand(shape, dtype):
    """The random sampler function."""
    if _dtypes.issubdtype(dtype, np.complexfloating):
      base_dtype = np.real(np.array(0, dtype=dtype)).dtype
      out = (rand(shape, base_dtype) +
             np.array(1j, dtype) * rand(shape, base_dtype))
      return _cast_to_shape(out, shape, dtype)

    if not _dtypes.issubdtype(dtype, np.floating):
      # only float types have nan
      return base_rand(shape, dtype)

    dims = _dims_of_shape(shape)
    r = rng.rand(*dims)
    nan_flips = r < 0.1
    # A subset of nan_flips (same draws), so these entries end up as -nan.
    neg_nan_flips = r < 0.05

    vals = base_rand(shape, dtype)
    vals = np.where(nan_flips, np.array(np.nan, dtype=dtype), vals)
    vals = np.where(neg_nan_flips, np.array(-np.nan, dtype=dtype), vals)

    return _cast_to_shape(np.asarray(vals, dtype=dtype), shape, dtype)

  return rand
|
|
|
|
|
|
|
|
def rand_some_inf_and_nan(rng):
  """Return a random sampler that produces infinities and nans in floating types.

  For real floating dtypes, each entry of a `rand_default` sample is
  overwritten, in order, with +inf, -inf, and nan, each with independent
  probability 0.1; later overwrites win. Other dtypes get plain `rand_default`
  samples, and so do complex ones in practice (see the TODO below).
  """
  base_rand = rand_default(rng)

  # TODO: Complex numbers are not correctly tested
  # If blocks should be switched in order, and relevant tests should be fixed
  def rand(shape, dtype):
    """The random sampler function."""
    if not _dtypes.issubdtype(dtype, np.floating):
      # only float types have inf and nan
      return base_rand(shape, dtype)

    if _dtypes.issubdtype(dtype, np.complexfloating):
      base_dtype = np.real(np.array(0, dtype=dtype)).dtype
      out = (rand(shape, base_dtype) +
             np.array(1j, dtype) * rand(shape, base_dtype))
      return _cast_to_shape(out, shape, dtype)

    dims = _dims_of_shape(shape)
    posinf_flips = rng.rand(*dims) < 0.1
    neginf_flips = rng.rand(*dims) < 0.1
    nan_flips = rng.rand(*dims) < 0.1

    vals = base_rand(shape, dtype)
    vals = np.where(posinf_flips, np.array(np.inf, dtype=dtype), vals)
    vals = np.where(neginf_flips, np.array(-np.inf, dtype=dtype), vals)
    vals = np.where(nan_flips, np.array(np.nan, dtype=dtype), vals)

    return _cast_to_shape(np.asarray(vals, dtype=dtype), shape, dtype)

  return rand
|
|
|
|
|
|
|
|
# TODO(mattjj): doesn't handle complex types
def rand_some_zero(rng):
  """Return a random sampler that produces some zeros.

  Each entry of a `rand_default` sample is independently set to zero with
  probability 0.5.
  """
  base_rand = rand_default(rng)

  def rand(shape, dtype):
    """The random sampler function."""
    zero_mask = rng.rand(*_dims_of_shape(shape)) < 0.5
    samples = base_rand(shape, dtype)
    samples = np.where(zero_mask, np.array(0, dtype=dtype), samples)
    return _cast_to_shape(np.asarray(samples, dtype=dtype), shape, dtype)

  return rand
|
|
|
|
|
|
|
|
|
|
|
|
def rand_int(rng, low=0, high=None):
  """Returns a sampler ``fn(shape, dtype)`` of random integers from `rng.randint`.

  Values lie in ``[low, high)``. If `low` is 0 and `high` is None, the upper
  bound defaults on every call to the maximum of the requested integer dtype.
  If `high` is None and `low` is nonzero, NumPy's `randint` semantics apply
  (values in ``[0, low)``).

  The sampler raises:
    ValueError: if `low == 0`, `high is None`, and `dtype` is not an integer
      type.
  """
  def fn(shape, dtype):
    # Resolve the bound per call instead of rebinding the closure's `high`.
    # Rebinding meant the bound derived from the first dtype seen (e.g.
    # int8's 127) was silently reused for every later call and dtype.
    upper = high
    if low == 0 and upper is None:
      if not np.issubdtype(dtype, np.integer):
        raise ValueError("rand_int requires an explicit `high` value for "
                         "non-integer types.")
      upper = np.iinfo(dtype).max
    return rng.randint(low, high=upper, size=shape, dtype=dtype)
  return fn
|
|
|
|
|
|
|
|
def rand_unique_int(rng, high=None):
  """Returns a sampler of distinct integers drawn without replacement.

  Values come from ``range(high)``, or from ``range(prod(shape))`` when `high`
  is falsy. In that case the result is a permutation of that range.
  """
  def fn(shape, dtype):
    pool = np.arange(high or prod(shape), dtype=dtype)
    return rng.choice(pool, size=shape, replace=False)
  return fn
|
|
|
|
|
|
|
|
def rand_bool(rng):
  """Returns a sampler of fair coin flips (True with probability 0.5).

  The flips are shaped and cast according to `shape` and `dtype`.
  """
  def generator(shape, dtype):
    coin_flips = rng.rand(*_dims_of_shape(shape)) < 0.5
    return _cast_to_shape(coin_flips, shape, dtype)
  return generator
|
|
|
|
|
|
|
|
def check_raises(thunk, err_type, msg):
  """Checks that `thunk()` raises `err_type` with a message starting with `msg`.

  Exceptions that are not instances of `err_type` propagate unchanged.

  Raises:
    AssertionError: if `thunk` returns normally, or if the raised error's
      message does not start with `msg`.
  """
  try:
    thunk()
  except err_type as e:
    if not str(e).startswith(msg):
      raise AssertionError(f"\n{e}\n\n{msg}\n") from e
  else:
    # Raised outside the `try` body so that an `err_type` of AssertionError
    # (or a base class of it) cannot swallow this failure. Unlike `assert`,
    # it is also not stripped under `python -O`.
    raise AssertionError(
        f"Expected {err_type} to be raised, but no exception was raised.")
|
2021-09-24 07:02:08 -07:00
|
|
|
|
|
|
|
def check_raises_regexp(thunk, err_type, pattern):
  """Checks that `thunk()` raises `err_type` with a message matching `pattern`.

  Matching uses `re.match`, i.e. it is anchored at the start of the message.
  Exceptions that are not instances of `err_type` propagate unchanged.

  Raises:
    AssertionError: if `thunk` returns normally, or if the raised error's
      message does not match `pattern`.
  """
  try:
    thunk()
  except err_type as e:
    if not re.match(pattern, str(e)):
      raise AssertionError(f"{e}\n\n{pattern}\n") from e
  else:
    # Raised outside the `try` body so that an `err_type` of AssertionError
    # (or a base class of it) cannot swallow this failure. Unlike `assert`,
    # it is also not stripped under `python -O`.
    raise AssertionError(
        f"Expected {err_type} to be raised, but no exception was raised.")
|
2021-09-24 07:02:08 -07:00
|
|
|
|
|
|
|
|
|
|
|
def iter_eqns(jaxpr):
  """Yields the equations of `jaxpr`, then recursively those of its sub-jaxprs.

  Sub-jaxprs are those returned by `core.subjaxprs(jaxpr)`.
  """
  # TODO(necula): why doesn't this search in params?
  for eqn in jaxpr.eqns:
    yield eqn
  for sub in core.subjaxprs(jaxpr):
    yield from iter_eqns(sub)
|
|
|
|
|
|
|
|
def assert_dot_precision(expected_precision, fun, *args):
  """Asserts that every `dot_general` traced from `fun(*args)` uses `expected_precision`.

  When an equation's precision is a tuple, both of its first two entries must
  equal `expected_precision`.
  """
  closed_jaxpr = api.make_jaxpr(fun)(*args)
  dot_precisions = [eqn.params['precision']
                    for eqn in iter_eqns(closed_jaxpr.jaxpr)
                    if eqn.primitive == lax.dot_general_p]
  for precision in dot_precisions:
    msg = f"Unexpected precision: {expected_precision} != {precision}"
    if not isinstance(precision, tuple):
      assert precision == expected_precision, msg
      continue
    assert precision[0] == expected_precision, msg
    assert precision[1] == expected_precision, msg
|
|
|
|
|
|
|
|
|
|
|
|
# Fixed-seed permutations of range(n), memoized by n.
_CACHED_INDICES: Dict[int, Sequence[int]] = {}


def cases_from_list(xs):
  """Returns a deterministic sample of at most FLAGS.num_generated_cases items.

  The sample takes the leading positions of a fixed-seed (42) permutation of
  ``range(len(xs))``. Inputs of the same length are therefore always sampled
  at the same positions.
  """
  items = list(xs)
  count = len(items)
  limit = min(count, FLAGS.num_generated_cases)
  # Random sampling for every parameterized test is expensive. Do it once and
  # cache the result.
  if count not in _CACHED_INDICES:
    _CACHED_INDICES[count] = npr.RandomState(42).permutation(count)
  order = _CACHED_INDICES[count]
  return [items[i] for i in order[:limit]]
|
|
|
|
|
|
|
|
def cases_from_gens(*gens):
  """Yields cases built by calling every generator with sizes 1, 3 and 10.

  Each case is ``('_{size}_{i}', gen_1(size), gen_2(size), ...)``. Each size
  contributes int(N / 3) + 1 cases, where N is FLAGS.num_generated_cases.
  """
  sizes = [1, 3, 10]
  per_size = int(FLAGS.num_generated_cases / len(sizes)) + 1
  for size in sizes:
    for index in range(per_size):
      name = f'_{size}_{index}'
      yield (name, *(gen(size) for gen in gens))
|
2021-09-24 07:02:08 -07:00
|
|
|
|
|
|
|
def named_cases_from_sampler(gen):
  """Yields distinct test cases by repeatedly sampling `gen`.

  `gen` is called with a `choose_one` function. `choose_one` returns a
  one-element list holding a random element of its argument. Each call of
  `gen` must produce at most one case: a dict with a "testcase_name" key.

  Sampling stops when either condition holds:
  - FLAGS.num_generated_cases distinct names have been yielded;
  - FLAGS.max_cases_sampling_retries consecutive attempts produced nothing
    new.

  Raises:
    RuntimeError: if a single call to `gen` produces more than one case.
  """
  seen_names = set()
  failed_attempts = 0
  rng = npr.RandomState(42)

  def choose_one(options):
    if not isinstance(options, (list, tuple)):
      options = list(options)
    return [options[rng.randint(len(options))]]

  while (len(seen_names) < FLAGS.num_generated_cases and
         failed_attempts < FLAGS.max_cases_sampling_retries):
    failed_attempts += 1
    cases = list(gen(choose_one))
    if not cases:
      continue
    if len(cases) > 1:
      raise RuntimeError("Generator is expected to only return a single case when sampling")
    case = cases[0]
    name = case["testcase_name"]
    if name in seen_names:
      continue
    failed_attempts = 0
    seen_names.add(name)
    yield case
|
|
|
|
|
|
|
|
|
|
|
|
class JaxTestLoader(absltest.TestLoader):
  """Test loader that filters test names by --test_targets / --exclude_test_targets.

  Both flags are regular expressions, applied with `re.search` to
  ``"<TestCaseClass>.<test_method>"``.
  """

  def getTestCaseNames(self, testCaseClass):
    names = super().getTestCaseNames(testCaseClass)
    class_name = testCaseClass.__name__
    if FLAGS.test_targets:
      include = re.compile(FLAGS.test_targets)
      names = [n for n in names if include.search(f"{class_name}.{n}")]
    if FLAGS.exclude_test_targets:
      exclude = re.compile(FLAGS.exclude_test_targets)
      names = [n for n in names if not exclude.search(f"{class_name}.{n}")]
    return names
|
|
|
|
|
|
|
|
|
|
|
|
def with_config(**kwds):
  """Test case decorator for subclasses of JaxTestCase.

  Sets the decorated class's `_default_config` to
  `JaxTestCase._default_config` updated with `kwds`. `JaxTestCase.setUp`
  applies that config around every test.
  """
  def decorator(cls):
    assert inspect.isclass(cls) and issubclass(cls, JaxTestCase), "@with_config can only wrap JaxTestCase class definitions."
    merged = dict(JaxTestCase._default_config)
    merged.update(kwds)
    cls._default_config = merged
    return cls
  return decorator
|
|
|
|
|
|
|
|
|
|
|
|
class JaxTestCase(parameterized.TestCase):
|
|
|
|
"""Base class for JAX tests including numerical checks and boilerplate."""
|
2022-02-14 09:22:05 -08:00
|
|
|
_default_config = {
|
|
|
|
'jax_enable_checks': True,
|
2022-02-15 02:42:30 -08:00
|
|
|
'jax_numpy_rank_promotion': 'raise',
|
|
|
|
'jax_traceback_filtering': 'off',
|
2022-02-14 09:22:05 -08:00
|
|
|
}
|
2021-09-24 07:02:08 -07:00
|
|
|
|
|
|
|
# TODO(mattjj): this obscures the error messages from failures, figure out how
|
|
|
|
# to re-enable it
|
|
|
|
# def tearDown(self) -> None:
|
|
|
|
# assert core.reset_trace_state()
|
|
|
|
|
|
|
|
def setUp(self):
|
|
|
|
super().setUp()
|
|
|
|
self._original_config = {}
|
|
|
|
for key, value in self._default_config.items():
|
2022-01-26 15:49:59 -08:00
|
|
|
self._original_config[key] = config._read(key)
|
2021-09-24 07:02:08 -07:00
|
|
|
config.update(key, value)
|
|
|
|
|
|
|
|
# We use the adler32 hash for two reasons.
|
|
|
|
# a) it is deterministic run to run, unlike hash() which is randomized.
|
|
|
|
# b) it returns values in int32 range, which RandomState requires.
|
|
|
|
self._rng = npr.RandomState(zlib.adler32(self._testMethodName.encode()))
|
|
|
|
|
|
|
|
def tearDown(self):
|
|
|
|
for key, value in self._original_config.items():
|
|
|
|
config.update(key, value)
|
|
|
|
super().tearDown()
|
|
|
|
|
|
|
|
def rng(self):
|
|
|
|
return self._rng
|
|
|
|
|
|
|
|
def assertArraysEqual(self, x, y, *, check_dtypes=True, err_msg=''):
|
|
|
|
"""Assert that x and y arrays are exactly equal."""
|
|
|
|
if check_dtypes:
|
|
|
|
self.assertDtypesMatch(x, y)
|
|
|
|
# Work around https://github.com/numpy/numpy/issues/18992
|
|
|
|
with np.errstate(over='ignore'):
|
|
|
|
np.testing.assert_array_equal(x, y, err_msg=err_msg)
|
|
|
|
|
|
|
|
def assertArraysAllClose(self, x, y, *, check_dtypes=True, atol=None,
|
|
|
|
rtol=None, err_msg=''):
|
|
|
|
"""Assert that x and y are close (up to numerical tolerances)."""
|
|
|
|
self.assertEqual(x.shape, y.shape)
|
|
|
|
atol = max(tolerance(_dtype(x), atol), tolerance(_dtype(y), atol))
|
|
|
|
rtol = max(tolerance(_dtype(x), rtol), tolerance(_dtype(y), rtol))
|
|
|
|
|
|
|
|
_assert_numpy_allclose(x, y, atol=atol, rtol=rtol, err_msg=err_msg)
|
|
|
|
|
|
|
|
if check_dtypes:
|
|
|
|
self.assertDtypesMatch(x, y)
|
|
|
|
|
|
|
|
def assertDtypesMatch(self, x, y, *, canonicalize_dtypes=True):
|
|
|
|
if not config.x64_enabled and canonicalize_dtypes:
|
|
|
|
self.assertEqual(_dtypes.canonicalize_dtype(_dtype(x)),
|
|
|
|
_dtypes.canonicalize_dtype(_dtype(y)))
|
|
|
|
else:
|
|
|
|
self.assertEqual(_dtype(x), _dtype(y))
|
|
|
|
|
|
|
|
def assertAllClose(self, x, y, *, check_dtypes=True, atol=None, rtol=None,
                   canonicalize_dtypes=True, err_msg=''):
  """Recursively assert that ``x`` and ``y`` are close.

  Supported structures:
    * dicts: same key set, values compared pairwise;
    * non-array sequences: same length, elements compared pairwise;
    * arrays and scalars: optional dtype check, then a numeric comparison;
    * anything else: must compare equal with ``==``, otherwise a TypeError
      naming both types is raised.
  """
  recurse = functools.partial(
      self.assertAllClose, check_dtypes=check_dtypes, atol=atol, rtol=rtol,
      canonicalize_dtypes=canonicalize_dtypes, err_msg=err_msg)

  if isinstance(x, dict):
    self.assertIsInstance(y, dict)
    self.assertEqual(set(x.keys()), set(y.keys()))
    for key in x.keys():
      recurse(x[key], y[key])
  elif is_sequence(x) and not hasattr(x, '__array__'):
    self.assertTrue(is_sequence(y) and not hasattr(y, '__array__'))
    self.assertEqual(len(x), len(y))
    for x_item, y_item in zip(x, y):
      recurse(x_item, y_item)
  elif hasattr(x, '__array__') or np.isscalar(x):
    self.assertTrue(hasattr(y, '__array__') or np.isscalar(y))
    if check_dtypes:
      self.assertDtypesMatch(x, y, canonicalize_dtypes=canonicalize_dtypes)
    # Dtypes were handled above, so the array comparison skips them.
    self.assertArraysAllClose(np.asarray(x), np.asarray(y), check_dtypes=False,
                              atol=atol, rtol=rtol, err_msg=err_msg)
  elif not (x == y):
    raise TypeError((type(x), type(y)))
|
|
|
|
|
|
|
|
def assertMultiLineStrippedEqual(self, expected, what):
  """Assert two strings match after dedenting and stripping whitespace per line.

  Both strings are dedented and stripped at the ends, and every run of
  whitespace around a newline is collapsed to a single newline. On mismatch,
  the actual text is printed so it can be pasted into the test.
  """
  want = textwrap.dedent(expected)
  got = textwrap.dedent(what)
  newline_ws = re.compile(r'\s*\n\s*')
  want_clean = newline_ws.sub('\n', want.strip())
  got_clean = newline_ws.sub('\n', got.strip())
  if want_clean != got_clean:
    # Print it so we can copy-and-paste it into the test
    print(f"Found\n{got}\n")
  self.assertMultiLineEqual(want_clean, got_clean,
                            msg=f"Found\n{got}\nExpecting\n{want}")
|
2021-09-24 07:02:08 -07:00
|
|
|
|
|
|
|
def _CompileAndCheck(self, fun, args_maker, *, check_dtypes=True,
                     rtol=None, atol=None, check_cache_misses=True):
  """Check that ``api.jit(fun)`` agrees with eager execution of ``fun``.

  Steps, in order:
    1. Call ``fun`` eagerly and check that each output leaf's ``np.shape``
       equals the shape of ``np.asarray`` of that leaf.
    2. Call ``fun`` eagerly again. When ``check_cache_misses`` is set, assert
       that the miss count of ``dispatch.xla_primitive_callable``'s cache did
       not change during that call.
    3. Call ``api.jit`` of a wrapper around ``fun`` twice. Python may run on
       the first call but must not run on the second. Both results must be
       close to the eager result.
    4. Repeat with fresh arguments from ``args_maker``. The jitted call must
       again not execute Python, and must match eager execution.
  """
  # Read by `wrapped_fun` at call time: the test fails if Python runs while
  # this flag is False.
  python_should_be_executing = True

  def wrapped_fun(*call_args):
    self.assertTrue(python_should_be_executing)
    return fun(*call_args)

  # Step 1: eager run plus shape sanity check.
  args = args_maker()
  python_ans = fun(*args)
  self.assertEqual(tree_map(np.shape, python_ans),
                   tree_map(lambda leaf: np.shape(np.asarray(leaf)), python_ans))

  # Step 2: a repeated eager call should not add op-by-op cache misses.
  misses_before = dispatch.xla_primitive_callable.cache_info().misses
  python_ans = fun(*args)
  if check_cache_misses:
    self.assertEqual(
        misses_before, dispatch.xla_primitive_callable.cache_info().misses,
        f"Compilation detected during second call of {fun} in op-by-op mode.")

  # Step 3: jitted calls versus the eager answer.
  cfun = api.jit(wrapped_fun)
  python_should_be_executing = True
  monitored_ans = cfun(*args)
  python_should_be_executing = False
  compiled_ans = cfun(*args)
  for jitted_ans in (monitored_ans, compiled_ans):
    self.assertAllClose(python_ans, jitted_ans, check_dtypes=check_dtypes,
                        atol=atol, rtol=rtol)

  # Step 4: fresh arguments must reuse the compiled function.
  args = args_maker()
  python_should_be_executing = True
  python_ans = fun(*args)
  python_should_be_executing = False
  compiled_ans = cfun(*args)
  self.assertAllClose(python_ans, compiled_ans, check_dtypes=check_dtypes,
                      atol=atol, rtol=rtol)
|
|
|
|
|
|
|
|
def _CheckAgainstNumpy(self, numpy_reference_op, lax_op, args_maker,
|
|
|
|
check_dtypes=True, tol=None, atol=None, rtol=None,
|
|
|
|
canonicalize_dtypes=True):
|
|
|
|
args = args_maker()
|
|
|
|
lax_ans = lax_op(*args)
|
|
|
|
numpy_ans = numpy_reference_op(*args)
|
|
|
|
self.assertAllClose(numpy_ans, lax_ans, check_dtypes=check_dtypes,
|
|
|
|
atol=atol or tol, rtol=rtol or tol,
|
|
|
|
canonicalize_dtypes=canonicalize_dtypes)
|
|
|
|
|
2022-05-06 16:28:24 +01:00
|
|
|
# Alternative `jit` implementations that tests can be parameterized over.
# Each is tagged with a short `_name` string.
_CPP_JIT_IMPLEMENTATION = functools.partial(api._jit, True)
_PYTHON_JIT_IMPLEMENTATION = functools.partial(api._jit, False)
# Returns its first argument unchanged, i.e. performs no jitting at all.
_NOOP_JIT_IMPLEMENTATION = lambda x, *args, **kwargs: x

for _impl, _impl_name in ((_CPP_JIT_IMPLEMENTATION, "cpp"),
                          (_PYTHON_JIT_IMPLEMENTATION, "python"),
                          (_NOOP_JIT_IMPLEMENTATION, "noop")):
  _impl._name = _impl_name
del _impl, _impl_name

JIT_IMPLEMENTATION = (
    _CPP_JIT_IMPLEMENTATION,
    _PYTHON_JIT_IMPLEMENTATION,
    _NOOP_JIT_IMPLEMENTATION,
)
|
2021-09-24 07:02:08 -07:00
|
|
|
|
|
|
|
class BufferDonationTestCase(JaxTestCase):
  """JaxTestCase with assertions on whether an array's buffers are deleted."""

  def assertDeleted(self, x):
    """Assert that every device buffer backing ``x`` has been deleted."""
    self._assertDeleted(x, True)

  def assertNotDeleted(self, x):
    """Assert that no device buffer backing ``x`` has been deleted."""
    self._assertDeleted(x, False)

  def _assertDeleted(self, x, deleted):
    # Single-buffer arrays expose `device_buffer`; otherwise fall back to the
    # `device_buffers` collection.
    if hasattr(x, "device_buffer"):
      buffers = [x.device_buffer]
    else:
      buffers = x.device_buffers
    for buf in buffers:
      self.assertEqual(buf.is_deleted(), deleted)
|
|
|
|
|
|
|
|
|
|
|
|
@contextmanager
def ignore_warning(**filter_kwargs):
  """Suppress warnings matching ``filter_kwargs`` inside the ``with`` body.

  ``filter_kwargs`` are forwarded to ``warnings.filterwarnings`` (for example
  ``message``, ``category`` or ``module``). ``warnings.catch_warnings`` restores
  the previous filter state on exit.
  """
  with warnings.catch_warnings():
    warnings.filterwarnings(action="ignore", **filter_kwargs)
    yield
|
|
|
|
|
|
|
|
# -------------------- Mesh parametrization helpers --------------------
|
|
|
|
|
|
|
|
# A mesh description: a list of (axis_name, axis_size) pairs.
MeshSpec = List[Tuple[str, int]]
|
|
|
|
|
|
|
|
@contextmanager
def with_mesh(named_shape: MeshSpec) -> Generator[None, None, None]:
  """Activate a device ``Mesh`` built from ``named_shape`` for the enclosed code.

  Args:
    named_shape: sequence of ``(axis_name, axis_size)`` pairs, e.g. a test's
      ``mesh`` parameter.

  Raises:
    unittest.SkipTest: if there are fewer local devices than the mesh needs.

  Because this is a ``contextlib.contextmanager``, it can be used either as
  ``with with_mesh(spec): ...`` or as a decorator ``with_mesh(spec)(fn)``.
  """
  axis_names, shape = unzip2(named_shape)
  size = prod(shape)
  local_devices = list(api.local_devices())
  if len(local_devices) < size:
    raise unittest.SkipTest(f"Test requires {size} local devices")
  # Use the first `size` local devices, arranged into the requested shape.
  mesh_devices = np.array(local_devices[:size]).reshape(shape)
  with Mesh(mesh_devices, axis_names):
    yield
|
|
|
|
|
|
|
|
def with_mesh_from_kwargs(f):
  """Decorate ``f`` so that each call runs under ``with_mesh(kwargs['mesh'])``.

  The mesh spec is read from the ``mesh`` keyword argument of every call.
  ``functools.wraps`` preserves ``f``'s name and docstring on the wrapper and
  exposes ``f`` as ``__wrapped__``.
  """
  @functools.wraps(f)
  def wrapper(*args, **kwargs):
    return with_mesh(kwargs['mesh'])(f)(*args, **kwargs)
  return wrapper
|
|
|
|
|
|
|
|
def with_and_without_mesh(f):
  """Parameterize test ``f`` to run once without a mesh and once on a mesh.

  Two named cases are generated:
    * ``''``: empty mesh and no axis resources;
    * ``'Mesh'``: mesh ``(('x', 2),)`` with axis resources ``(('i', 'x'),)``.
  Each case supplies ``mesh`` and ``axis_resources`` keyword arguments, and
  ``f`` is wrapped with ``with_mesh_from_kwargs`` so it runs under that mesh.
  """
  cases = (
      ('', (), ()),
      ('Mesh', (('x', 2),), (('i', 'x'),)),
  )
  decorate = parameterized.named_parameters(
      {"testcase_name": name, "mesh": mesh, "axis_resources": axis_resources}
      for name, mesh, axis_resources in cases)
  return decorate(with_mesh_from_kwargs(f))
|
|
|
|
|
Add reverse-mode AD support for pjit
This is a somewhat big patch, because the transposition process turns out to be
quite difficult. The biggest issue appears when we do partial evaluation and we have
to add a whole bunch of intermediate values as outputs of the primal computation,
but we don't have any partition specs for them!
A simple workaround would be to mark all of them as replicated, but that would
likely tank performance which is why we didn't go with that option. Instead, we use
a newly added XLA option called `allow_spmd_sharding_propagation_to_output` to compile
a throwaway executable that lets us query output sharding that XLA considers convenient
for the computation.
However, there's one more difficulty: XLA's `OpSharding` is much less constrained
than our `PartitionSpec`s. In particular, while `PartitionSpec`s can only represent
"block permutations" of devices (with blocks delineated by mesh axes), `OpSharding`
allows arbitrary assignment (permutation) of tensor chunks to devices. This means that
not every `OpSharding` has a corresponding `PartitionSpec`, but I did implement a
(somewhat involved) procedure that should recover one whenever it exists.
Unfortunately this makes our support for reverse-mode AD partial, because we might
be unable to handle `OpSharding` returned by XLA. But this will only happen if XLA
actually comes up with sharding specifications on its own. If it merely propagates
the sharding obtained from `PartitionSpec`s into the middle of the computation, then
we should be good. In any case, if we end up seeing failures in this path, we should
consider relaxing `PartitionSpec`s, but that would be a pretty large change, so I decided
to avoid it unless there's no other way.
PiperOrigin-RevId: 399680306
2021-09-29 07:19:28 -07:00
|
|
|
# Value of `experimental_xmap_spmd_lowering` saved by the last call to
# `set_spmd_lowering_flag`; None until that function has been called.
old_spmd_lowering_flag = None


def set_spmd_lowering_flag(val: bool):
  """Set ``experimental_xmap_spmd_lowering`` to ``val``, saving the old value.

  The previous value is stored in the module global ``old_spmd_lowering_flag``.
  """
  global old_spmd_lowering_flag
  previous = config.experimental_xmap_spmd_lowering
  old_spmd_lowering_flag = previous
  config.update('experimental_xmap_spmd_lowering', val)
|
|
|
|
|
|
|
|
def restore_spmd_lowering_flag():
  """Reinstate the saved ``experimental_xmap_spmd_lowering`` value, if any."""
  if old_spmd_lowering_flag is not None:
    config.update('experimental_xmap_spmd_lowering', old_spmd_lowering_flag)
|
|
|
|
|
2022-02-04 11:16:54 -08:00
|
|
|
# Value of `experimental_xmap_spmd_lowering_manual` saved by the last call to
# `set_spmd_manual_lowering_flag`; None until that function has been called.
old_spmd_manual_lowering_flag = None


def set_spmd_manual_lowering_flag(val: bool):
  """Set ``experimental_xmap_spmd_lowering_manual`` to ``val``, saving the old value."""
  global old_spmd_manual_lowering_flag
  previous = config.experimental_xmap_spmd_lowering_manual
  old_spmd_manual_lowering_flag = previous
  config.update('experimental_xmap_spmd_lowering_manual', val)


def restore_spmd_manual_lowering_flag():
  """Reinstate the saved ``experimental_xmap_spmd_lowering_manual`` value, if any."""
  if old_spmd_manual_lowering_flag is not None:
    config.update('experimental_xmap_spmd_lowering_manual',
                  old_spmd_manual_lowering_flag)
|
|
|
|
|
2022-01-11 15:42:31 -08:00
|
|
|
def create_global_mesh(mesh_shape, axis_names):
  """Build a ``Mesh`` of shape ``mesh_shape`` from the global device list.

  Devices are sorted by ``id`` and the first ``prod(mesh_shape)`` are used.
  Raises ``unittest.SkipTest`` when fewer devices are available.
  """
  needed = prod(mesh_shape)
  if len(api.devices()) < needed:
    raise unittest.SkipTest(f"Test requires {needed} global devices.")
  ordered = sorted(api.devices(), key=lambda device: device.id)
  return Mesh(np.array(ordered[:needed]).reshape(mesh_shape), axis_names)
|
|
|
|
|
|
|
|
|
2021-09-24 07:02:08 -07:00
|
|
|
class _cached_property:
|
|
|
|
null = object()
|
|
|
|
|
|
|
|
def __init__(self, method):
|
|
|
|
self._method = method
|
|
|
|
self._value = self.null
|
|
|
|
|
|
|
|
def __get__(self, obj, cls):
|
|
|
|
if self._value is self.null:
|
|
|
|
self._value = self._method(obj)
|
|
|
|
return self._value
|
|
|
|
|
|
|
|
|
|
|
|
class _LazyDtypes:
  """Lists of dtypes, filtered to those the current backend supports.

  These could be module-level constants, but ``device_under_test()`` is not
  always known at import time. Each list is therefore built on first access
  and then cached by ``_cached_property``.
  """

  def supported(self, dtypes):
    """Return the entries of ``dtypes`` that appear in ``supported_dtypes()``.

    Order is preserved, and the result has the same container type as the input.
    """
    available = supported_dtypes()
    return type(dtypes)(dt for dt in dtypes if dt in available)

  # ---- Basic categories ----

  @_cached_property
  def floating(self):
    return self.supported([np.float32, np.float64])

  @_cached_property
  def all_floating(self):
    # Adds the 16-bit types bfloat16 and float16 to `floating`.
    return self.supported([_dtypes.bfloat16, np.float16, np.float32, np.float64])

  @_cached_property
  def integer(self):
    return self.supported([np.int32, np.int64])

  @_cached_property
  def all_integer(self):
    return self.supported([np.int8, np.int16, np.int32, np.int64])

  @_cached_property
  def unsigned(self):
    return self.supported([np.uint32, np.uint64])

  @_cached_property
  def all_unsigned(self):
    return self.supported([np.uint8, np.uint16, np.uint32, np.uint64])

  @_cached_property
  def complex(self):
    return self.supported([np.complex64, np.complex128])

  @_cached_property
  def boolean(self):
    return self.supported([np.bool_])

  # ---- Unions of the categories above ----

  @_cached_property
  def inexact(self):
    return [*self.floating, *self.complex]

  @_cached_property
  def all_inexact(self):
    return [*self.all_floating, *self.complex]

  @_cached_property
  def numeric(self):
    return [*self.floating, *self.integer, *self.unsigned, *self.complex]

  @_cached_property
  def all(self):
    return [*self.all_floating, *self.all_integer, *self.all_unsigned,
            *self.complex, *self.boolean]
|
|
|
|
|
|
|
|
|
|
|
|
# Shared instance; each dtype list is computed on first attribute access.
dtypes = _LazyDtypes()
|
2022-04-04 14:39:43 -07:00
|
|
|
|
|
|
|
|
|
|
|
class DeprecatedJaxTestCase(JaxTestCase):
  """JaxTestCase subclass that emits a DeprecationWarning when instantiated."""

  def __init__(self, *args, **kwargs):
    # The suggested replacement is absl's `parameterized.TestCase`; the module
    # is named `parameterized` (not `parametrized`).
    warnings.warn(textwrap.dedent("""\
      jax.test_util.JaxTestCase is deprecated as of jax version 0.3.1:
      The suggested replacement is to use parameterized.TestCase directly.
      For tests that rely on custom asserts such as JaxTestCase.assertAllClose(),
      the suggested replacement is to use standard numpy testing utilities such
      as np.testing.assert_allclose(), which work directly with JAX arrays."""),
                  category=DeprecationWarning)
    super().__init__(*args, **kwargs)
|
|
|
|
|
|
|
|
|
|
|
|
class DeprecatedJaxTestLoader(JaxTestLoader):
  """JaxTestLoader subclass that emits a DeprecationWarning when instantiated."""

  def __init__(self, *args, **kwargs):
    message = ("jax.test_util.JaxTestLoader is deprecated as of jax version "
               "0.3.1. Use absltest.TestLoader directly.")
    warnings.warn(message, category=DeprecationWarning)
    super().__init__(*args, **kwargs)
|
|
|
|
|
|
|
|
|
|
|
|
class DeprecatedBufferDonationTestCase(BufferDonationTestCase):
  """BufferDonationTestCase subclass that emits a DeprecationWarning when instantiated."""

  def __init__(self, *args, **kwargs):
    # Name the class actually being deprecated (the message previously said
    # JaxTestCase), and spell absl's module `parameterized` correctly.
    warnings.warn(textwrap.dedent("""\
      jax.test_util.BufferDonationTestCase is deprecated as of jax version 0.3.1:
      The suggested replacement is to use parameterized.TestCase directly.
      For tests that rely on custom asserts such as JaxTestCase.assertAllClose(),
      the suggested replacement is to use standard numpy testing utilities such
      as np.testing.assert_allclose(), which work directly with JAX arrays."""),
                  category=DeprecationWarning)
    super().__init__(*args, **kwargs)
|