# Copyright 2021 Google LLC # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # https://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. """See primitives_test docstring for how the Jax2TfLimitations are used.""" import itertools from typing import Any, Callable, Optional, Sequence, Union from jax import lax from jax import numpy as jnp from jax._src import test_util as jtu from jax._src import dtypes from jax.experimental.jax2tf.tests import primitive_harness import numpy as np DType = Any class Jax2TfLimitation(primitive_harness.Limitation): """Specific primitive limitations for jax2tf. See the primitive_test module docstring for details. """ def __init__( self, description: str, *, devices: Union[str, Sequence[str]] = ("cpu", "gpu", "tpu"), dtypes: Union[DType, Sequence[DType]] = (), enabled: bool = True, # jax2tf specific modes=("eager", "graph", "compiled"), skip_tf_run=False, expect_tf_error: bool = True, skip_comparison=False, custom_assert: Optional[Callable] = None, tol=None): """See the primitive_harness.Limitation common arguments. Args : modes: one of "eager", "graph", "compiled" skip_tf_run: if set will skip the TF execution. Use this sparingly, prefer `expect_tf_error`. Use only when the test cannot recover from the TF error. expect_tf_error: if set, then expect a TF error in the given mode when executing the result of jax2tf conversion. If not set, then the limitation must have a custom_assert or non-default tol. skip_comparison: skips the numeric comparison. tol: a tolerance to use for both atol and rtol. We will use the maximum tolerance over all the applicable limitations, irrespective of their order. custom_assert: if given, then execute as `custom_assert(tst, result_jax, result_tf, args=args, tol=tol, err_msg)` , where `tst` is the current TestCase instance, and args are the input arguments that the harness created. The `tol` is the maximum tolerance based on the applicable limitations. `err_msg` is passed to NumPy assert methods. `result_tf` is already converted to NumPy arrays. """ super().__init__( description, devices=devices, dtypes=dtypes, enabled=enabled) if isinstance(modes, str): modes = (modes,) assert all(m in ["eager", "graph", "compiled"] for m in modes), "Invalid modes: {modes}" self.modes = modes self.expect_tf_error = expect_tf_error self.skip_tf_run = skip_tf_run self.custom_assert = custom_assert self.tol = tol self.skip_comparison = skip_comparison def get_max_tolerance_limitation( self, limitations: Sequence["Jax2TfLimitation"] ) -> Optional["Jax2TfLimitation"]: """Pick the tolerance limitation that establishes the maximum tolerance.""" # TODO: it would be best if the limitations with tolerance are mutually exclusive # and we don't have to compute the maximum # TODO: we made this an instance method only so that we don't have to import # this module from tf_test.util. max_tol_lim = None for l in limitations: if l.tol is not None: if max_tol_lim is None or l.tol > max_tol_lim.tol: max_tol_lim = l return max_tol_lim def filter( # type: ignore[override] self, dtype: Optional[DType] = None, device: Optional[str] = None, mode: Optional[str] = None) -> bool: return ((mode is None or mode in self.modes) and super().filter(device=device, dtype=dtype)) @classmethod def limitations_for_harness( cls, harness: primitive_harness.Harness) -> Sequence["Jax2TfLimitation"]: group_method = getattr(cls, harness.group_name, None) if harness.group_name in cls.harness_groups_no_limitations: assert group_method is None, ( f"Harness group '{harness.group_name}' is both in " f"'harness_groups_no_limitations' and has a custom " f"Jax2TfLimitation.classmethod defined (see module docstring)") return [] else: assert group_method is not None, ( f"Harness group '{harness.group_name}' must be either part of " f"'harness_groups_no_limitations' or must have a custom " f"Jax2TfLimitation.classmethod defined (see module docstring)") limitations = group_method(harness) assert isinstance(limitations, (list, tuple)) return limitations # We keep here the explicit set of groups for which we don't have limitations harness_groups_no_limitations = { "abs", "add", "add_any", "and", "atan2", "bitcast_convert_type", "broadcast", "broadcast_in_dim", "cbrt", "ceil", "clamp", "concatenate", "cos", "cosh", "complex", "conj", "convert_element_type", "cummax", "cummin", "device_put", "dynamic_slice", "dynamic_update_slice", "exp", "eq", "floor", "gather", "ge", "gt", "imag", "iota", "is_finite", "le", "lt", "log", "mul", "ne", "neg", "not", "or", "pad", "population_count", "random_categorical", "random_split", "random_uniform", "random_randint", "reduce", "reduce_and", "reduce_prod", "reduce_or", "reduce_sum", "reduce_window_add", "reduce_window_mul", "reduce_window_min", "reduce_window_max", "real", "reshape", "rev", "rsqrt", "scatter_max", "scatter_min", "select_n", "select_and_scatter_add", "shift_left", "shift_right_logical", "shift_right_arithmetic", "sign", "sin", "sinh", "slice", "sqrt", "squeeze", "stop_gradient", "sub", "tie_in", "transpose", "xor", "zeros_like" } @classmethod def helper_get_trig_custom_limitation(cls, np_inverse): def custom_assert(tst, result_jax, result_tf, *, args, tol, err_msg): operand, = args tst.assertAllClose( operand, np_inverse(result_tf), atol=tol, rtol=tol, err_msg=err_msg) return custom_numeric( description="May return different but still correct results", dtypes=[np.complex64, np.complex128], custom_assert=custom_assert) @classmethod def acos(cls, harness: primitive_harness.Harness): return [ custom_numeric( dtypes=np.complex64, devices=("cpu", "gpu"), tol=1e-4, modes=("eager", "graph", "compiled")), custom_numeric( dtypes=np.complex128, devices=("cpu", "gpu"), tol=1e-13, modes=("eager", "graph", "compiled")), ] @classmethod def acosh(cls, harness: primitive_harness.Harness): return [ custom_numeric(dtypes=np.complex64, devices=("cpu", "gpu"), tol=1e-3), custom_numeric(dtypes=np.complex128, devices=("cpu", "gpu"), tol=1e-12), cls.helper_get_trig_custom_limitation(np.cosh) ] @classmethod def approx_max_k(cls, harness: primitive_harness.Harness): supported_dtypes = jtu.supported_dtypes() return Jax2TfLimitation( "eager is not supported in CPU or GPU.", dtypes=[t for t in [jnp.bfloat16, np.float16, np.float32] if t in supported_dtypes], devices=("cpu", "gpu", "tpu"), modes=("graph", "compiled")) @classmethod def argmax(cls, harness: primitive_harness.Harness): return [ Jax2TfLimitation( "different results when the input contains NaN and enable_xla=False", dtypes=jtu.dtypes.all_inexact, devices=("cpu", "gpu", "tpu"), modes=("eager", "graph", "compiled"), expect_tf_error=False, skip_comparison=True, enabled=("nan_" in harness.name and not harness.params["enable_xla"])), ] @classmethod def argmin(cls, harness: primitive_harness.Harness): return cls.argmax(harness) @classmethod def asin(cls, harness: primitive_harness.Harness): return [ custom_numeric(dtypes=np.complex64, devices=("cpu", "gpu"), tol=1e-4), custom_numeric(dtypes=np.complex128, devices=("cpu", "gpu"), tol=1e-12), cls.helper_get_trig_custom_limitation(np.sin) ] @classmethod def asinh(cls, harness: primitive_harness.Harness): return [ custom_numeric(dtypes=np.complex64, devices=("cpu", "gpu"), tol=1e-3), custom_numeric(dtypes=np.complex128, devices=("cpu", "gpu"), tol=1e-12), cls.helper_get_trig_custom_limitation(np.sinh) ] @classmethod def atan(cls, harness: primitive_harness.Harness): return [ custom_numeric(dtypes=np.complex64, devices=("cpu", "gpu"), tol=1e-5), custom_numeric(dtypes=np.complex128, devices=("cpu", "gpu"), tol=1e-12), cls.helper_get_trig_custom_limitation(np.tan) ] @classmethod def atanh(cls, harness: primitive_harness.Harness): return [ custom_numeric(dtypes=np.float64, tol=1e-14), custom_numeric(dtypes=np.complex64, tol=1e-3), custom_numeric(dtypes=np.complex128, devices=("cpu", "gpu"), tol=1e-12), cls.helper_get_trig_custom_limitation(np.tanh) ] @classmethod def bessel_i0e(cls, harness: primitive_harness.Harness): return [ missing_tf_kernel( dtypes=[dtypes.bfloat16], devices=("cpu", "gpu"), modes=("eager", "graph")) ] @classmethod def bessel_i1e(cls, harness: primitive_harness.Harness): return cls.bessel_i0e(harness) @classmethod def cholesky(cls, harness: primitive_harness.Harness): def custom_assert(tst, result_jax, result_tf, *, tol, err_msg, **_): # cholesky_p returns garbage in the strictly upper triangular part of the # result, so we can safely ignore that part. tst.assertAllClose( jnp.tril(result_jax), result_tf, atol=tol, err_msg=err_msg) return [ # TODO: very high tolerance custom_numeric( dtypes=[np.float32, np.complex64], tol=1e-2, devices=("cpu", "gpu"), modes=("eager", "graph", "compiled")), custom_numeric( dtypes=[np.float64, np.complex128], tol=1e-6, devices=("cpu", "gpu"), modes=("eager", "graph", "compiled")), custom_numeric( dtypes=[dtypes.bfloat16, np.float16], tol=5e-2, devices=("cpu", "gpu"), modes=("eager", "graph", "compiled")), custom_numeric( custom_assert=custom_assert, description=( "May return different values in the strictly upper triangular " "part of the result. This does not matter for correctness, " "because this part of the matrix is not considered in the result." ), modes=("eager", "graph", "compiled")) ] @classmethod def conv_general_dilated(cls, harness: primitive_harness.Harness): return [ # Even in compiled mode, for GPU we see a bit of discrepancy but # very minor. custom_numeric(dtypes=np.float32, devices="gpu", modes=("eager", "graph", "compiled"), tol=1e-5), custom_numeric(dtypes=np.float32, devices="cpu", modes=("eager", "graph", "compiled"), tol=1e-4), custom_numeric(description="higher numeric inaccuracy when `enable_xla=False`", modes=("eager", "graph", "compiled"), enabled=(not harness.params["enable_xla"]), tol=5e-3) ] @classmethod def cumprod(cls, harness): return [ # JAX uses a different lowering for CPU and GPU. custom_numeric( dtypes=(np.float16, jnp.bfloat16), devices=("cpu", "gpu"), modes=("eager", "graph", "compiled"), tol=5e-1) ] @classmethod def cumsum(cls, harness): return [ # JAX uses a different lowering for CPU and GPU. custom_numeric( dtypes=(np.float16, jnp.bfloat16), devices=("cpu", "gpu"), modes=("eager", "graph", "compiled"), tol=5e-1) ] @classmethod def custom_linear_solve(cls, harness: primitive_harness.Harness): return [ Jax2TfLimitation( "TODO: large numerical discrepancy", dtypes=np.float32, devices="tpu", expect_tf_error=False, skip_comparison=True), custom_numeric(dtypes=np.float32, devices="tpu", tol=0.01), custom_numeric(tol=1e-3), ] @classmethod def digamma(cls, harness: primitive_harness.Harness): dtype = harness.dtype # In the bfloat16 case, TF and lax both return NaN in undefined cases. # digamma is not defined at 0 and -1 def custom_assert(tst, result_jax, result_tf, *, args, tol, err_msg): # lax.digamma returns NaN and tf.math.digamma returns inf arg, = args special_cases = (arg == 0.) | (arg == -1.) nr_special_cases = np.count_nonzero(special_cases) tst.assertAllClose( np.full((nr_special_cases,), dtype(np.nan)), result_jax[special_cases], err_msg=err_msg) tst.assertAllClose( np.full((nr_special_cases,), dtype(np.inf)), result_tf[special_cases], err_msg=err_msg) # non-special cases are equal tst.assertAllClose( result_jax[~special_cases], result_tf[~special_cases], atol=tol, rtol=tol, err_msg=err_msg) return [ missing_tf_kernel( dtypes=[dtypes.bfloat16], devices=("cpu", "gpu"), modes=("eager", "graph")), custom_numeric(dtypes=np.float64, tol=1e-13), custom_numeric(dtypes=np.float32, devices=["cpu", "gpu"], tol=1e-3), custom_numeric( dtypes=dtypes.bfloat16, custom_assert=custom_assert, description=( "May return different results at singularity points 0 and -1." "JAX returns nan and TF returns inf")) ] @classmethod def div(cls, harness: primitive_harness.Harness): return [ Jax2TfLimitation( "TF integer division fails if divisor contains 0; JAX returns NaN", dtypes=[ np.uint8, np.int8, np.uint16, np.uint32, np.uint64, np.int8, np.int16, np.int32, np.int64 ], # Only the harnesses with "singularity" will have divide by 0 enabled=("singularity" in harness.name)) ] @classmethod def dot_general(cls, harness: primitive_harness.Harness): return [ missing_tf_kernel(dtypes=[np.bool_],), # TODO(b/189287598) Jax2TfLimitation( "Non-deterministic NaN for dot_general with preferred_element_type on GPU (b/189287598)", dtypes=[ jnp.bfloat16, np.float16, np.float32, np.complex64 ], devices="gpu", modes=("eager", "graph", "compiled"), enabled=(harness.params["preferred_element_type"] is not None), skip_comparison=True), # JAX performs float16 matmuls in float32 on CPU, so the JAX result # may be more precise. custom_numeric(dtypes=[np.float16], devices=["cpu"], tol=1e-2, modes=("eager", "graph", "compiled")), ] @classmethod def eig(cls, harness: primitive_harness.Harness): compute_left_eigenvectors = harness.params["compute_left_eigenvectors"] compute_right_eigenvectors = harness.params["compute_right_eigenvectors"] dtype = harness.dtype def custom_assert(tst, result_jax, result_tf, *, args, tol, err_msg): operand, = args inner_dimension = operand.shape[-1] # Test ported from tests.linlag_test.testEig # Norm, adjusted for dimension and type. def norm(x): norm = np.linalg.norm(x, axis=(-2, -1)) return norm / ((inner_dimension + 1) * jnp.finfo(dtype).eps) def check_right_eigenvectors(a, w, vr): tst.assertTrue( np.all(norm(np.matmul(a, vr) - w[..., None, :] * vr) < 100)) def check_left_eigenvectors(a, w, vl): rank = len(a.shape) aH = jnp.conj(a.transpose(list(range(rank - 2)) + [rank - 1, rank - 2])) wC = jnp.conj(w) check_right_eigenvectors(aH, wC, vl) def check_eigenvalue_is_in_array(eigenvalue, eigenvalues_array): tol = None # TODO(bchetioui): numerical discrepancies if dtype in [np.float32, np.complex64]: tol = 1e-4 elif dtype in [np.float64, np.complex128]: tol = 1e-13 closest_diff = min(abs(eigenvalues_array - eigenvalue)) tst.assertAllClose( closest_diff, np.array(0., closest_diff.dtype), atol=tol, err_msg=err_msg) all_w_jax, all_w_tf = result_jax[0], result_tf[0] for idx in itertools.product(*map(range, operand.shape[:-2])): w_jax, w_tf = all_w_jax[idx], all_w_tf[idx] for i in range(inner_dimension): check_eigenvalue_is_in_array(w_jax[i], w_tf) check_eigenvalue_is_in_array(w_tf[i], w_jax) if compute_left_eigenvectors: check_left_eigenvectors(operand, all_w_tf, result_tf[1]) if compute_right_eigenvectors: check_right_eigenvectors(operand, all_w_tf, result_tf[1 + compute_left_eigenvectors]) return [ # Eig does not work in JAX on gpu or tpu Jax2TfLimitation( "function not compilable", modes="compiled", devices="cpu"), Jax2TfLimitation( "TF Conversion of eig is not implemented when both compute_left_eigenvectors and compute_right_eigenvectors are set to True", enabled=(compute_left_eigenvectors and compute_right_eigenvectors)), custom_numeric( custom_assert=custom_assert, description=("May return the eigenvalues and eigenvectors in a " "potentially different order. The eigenvectors may " "also be different, but equally valid.")) ] @classmethod def eigh(cls, harness: primitive_harness.Harness): dtype = harness.dtype def custom_assert(tst, result_jax, result_tf, *, args, tol, err_msg): operand, = args inner_dimension = operand.shape[-1] def check_right_eigenvectors(a, w, vr): tol = 1e-16 # TODO(bchetioui): tolerance needs to be very high in compiled mode, # specifically for eigenvectors. if dtype == np.float64: tol = 2e-5 elif dtype == np.float32: tol = 1e-2 elif dtype in [dtypes.bfloat16, np.complex64]: tol = 1e-3 elif dtype == np.complex128: tol = 2e-5 tst.assertAllClose( np.matmul(a, vr) - w[..., None, :] * vr, np.zeros(a.shape, dtype=vr.dtype), atol=tol, # For bfloat16 the np.matmul returns float32 result. check_dtypes=False, err_msg=err_msg) def check_eigenvalue_is_in_array(eigenvalue, eigenvalues_array): tol = None if dtype in [dtypes.bfloat16, np.float32, np.complex64]: tol = 1e-3 elif dtype in [np.float64, np.complex128]: tol = 1e-5 closest_diff = min(abs(eigenvalues_array - eigenvalue)) tst.assertAllClose( closest_diff, np.array(0., closest_diff.dtype), atol=tol, err_msg=err_msg) _, all_w_jax = result_jax all_vr_tf, all_w_tf = result_tf for idx in itertools.product(*map(range, operand.shape[:-2])): w_jax, w_tf = all_w_jax[idx], all_w_tf[idx] for i in range(inner_dimension): check_eigenvalue_is_in_array(w_jax[i], w_tf) check_eigenvalue_is_in_array(w_tf[i], w_jax) check_right_eigenvectors(operand, all_w_tf, all_vr_tf) return [ missing_tf_kernel( dtypes=dtypes.bfloat16, devices="tpu", enabled=(harness.params["shape"] != (0, 0)), # This actually works! ), Jax2TfLimitation( "TODO: numeric discrepancies", dtypes=np.float16, devices="tpu", expect_tf_error=False, skip_comparison=True), custom_numeric( custom_assert=custom_assert, description=("May return the eigenvalues and eigenvectors in a " "potentially different order. The eigenvectors may " "also be different, but equally valid."), modes=("eager", "graph", "compiled")) ] @classmethod def erf(cls, harness: primitive_harness.Harness): return [ missing_tf_kernel( dtypes=[dtypes.bfloat16], devices=("cpu", "gpu"), modes=("eager", "graph")) ] @classmethod def erfc(cls, harness: primitive_harness.Harness): return [ missing_tf_kernel( dtypes=[dtypes.bfloat16], devices=("cpu", "gpu"), modes=("eager", "graph")) ] @classmethod def erf_inv(cls, harness: primitive_harness.Harness): # erf_inv is not defined for arg <= -1 or arg >= 1 def custom_assert(tst, result_jax, result_tf, *, args, tol, err_msg): # noqa: F811 arg, = args # for arg < -1 or arg > 1 # lax.erf_inv returns NaN; tf.math.erf_inv return +/- inf special_cases = (arg < -1.) | (arg > 1.) # non-special cases are equal tst.assertAllClose( result_jax[~special_cases], result_tf[~special_cases], atol=tol, rtol=tol, err_msg=err_msg) return [ missing_tf_kernel( dtypes=[dtypes.bfloat16, np.float16], devices=("cpu", "gpu"), modes=("eager", "graph")), custom_numeric(dtypes=[np.float32, np.float64], tol=1e-4), custom_numeric( dtypes=[np.float32, np.float64], custom_assert=custom_assert, description=( "May return different results at undefined points (< -1 or > 1):" " JAX returns `NaN` and TF returns `+inf` or `-inf`.")) ] @classmethod def expm1(cls, harness: primitive_harness.Harness): return [custom_numeric(dtypes=np.float64, tol=1e-5)] @classmethod def fft(cls, harness): return [ Jax2TfLimitation( "TF function not compileable", devices=("cpu", "gpu"), dtypes=[np.float64, np.complex128], modes="compiled"), # TODO: very high tolerance custom_numeric(tol=1e-3, modes=("eager", "graph", "compiled")), ] @classmethod def _pow_test_util(cls, harness: primitive_harness.Harness): def custom_assert(tst, result_jax, result_tf, *, args, tol, err_msg): # NaNs are mismatched, but assertAllClose will also behave weirdly for # complex numbers containing np.inf as one of their components. See # https://github.com/numpy/numpy/issues/15959 for more details. mask = ( np.isnan(result_jax) + np.isnan(result_tf) + np.isinf(result_jax) + np.isinf(result_tf)) tst.assertAllClose( result_jax[~mask], result_tf[~mask], rtol=tol, err_msg=err_msg) return [ custom_numeric( dtypes=[np.float32, np.complex64], devices=("cpu", "gpu"), tol=1e-3), custom_numeric( dtypes=[np.float64, np.complex128], devices=("cpu", "gpu"), tol=5e-5), custom_numeric( dtypes=[np.complex64, np.complex128], custom_assert=custom_assert, ) ] @classmethod def igamma(cls, harness: primitive_harness.Harness): dtype = harness.dtype # igamma is not defined when the first argument is <=0 def custom_assert(tst, result_jax, result_tf, *, args, tol, err_msg): arg1, arg2 = args # lax.igamma returns NaN when arg1 == arg2 == 0; tf.math.igamma returns 0 special_cases = (arg1 == 0.) & (arg2 == 0.) nr_special_cases = np.count_nonzero(special_cases) tst.assertAllClose( np.full((nr_special_cases,), np.nan, dtype=dtype), result_jax[special_cases]) tst.assertAllClose( np.full((nr_special_cases,), 0., dtype=dtype), result_tf[special_cases]) # non-special cases are equal tst.assertAllClose( result_jax[~special_cases], result_tf[~special_cases], atol=tol, rtol=tol, err_msg=err_msg) return [ missing_tf_kernel( dtypes=[dtypes.bfloat16, np.float16], devices=("cpu", "gpu"), modes=("eager", "graph")), custom_numeric( custom_assert=custom_assert, description=( "May return different results at undefined points " "(both arguments 0). JAX returns `NaN` and TF returns 0 or " "JAX returns 1 and TF returns `NaN`")) ] @classmethod def igammac(cls, harness: primitive_harness.Harness): dtype = harness.dtype # igammac is not defined when the first argument is <=0 def custom_assert(tst, result_jax, result_tf, *, args, tol, err_msg): # noqa: F811 arg1, arg2 = args # lax.igammac returns 1. when arg1 <= 0; tf.math.igammac returns NaN special_cases = (arg1 <= 0.) | (arg2 <= 0) nr_special_cases = np.count_nonzero(special_cases) tst.assertAllClose( np.full((nr_special_cases,), 1., dtype=dtype), result_jax[special_cases], err_msg=err_msg) tst.assertAllClose( np.full((nr_special_cases,), np.nan, dtype=dtype), result_tf[special_cases], err_msg=err_msg) # non-special cases are equal tst.assertAllClose( result_jax[~special_cases], result_tf[~special_cases], atol=tol, rtol=tol, err_msg=err_msg) return [ missing_tf_kernel( dtypes=[dtypes.bfloat16, np.float16], devices=("cpu", "gpu"), modes=("eager", "graph")), custom_numeric(dtypes=np.float64, tol=1e-9), custom_numeric(devices="gpu", tol=1e-3), custom_numeric( custom_assert=custom_assert, devices=("cpu", "gpu"), description=( "May return different results at undefined points " "(both arguments less or equal 0). JAX returns `NaN` and TF returns 0 or " "JAX returns 1 and TF returns `NaN`")), ] @classmethod def integer_pow(cls, harness: primitive_harness.Harness): y = harness.params["y"] return [ missing_tf_kernel( dtypes=[ np.int8, np.int16, np.uint8, np.uint16, np.uint32, np.uint64 ], modes="graph", enabled=(y not in [0, 1]), # These are special-cased devices=("cpu", "gpu")), # TODO: on TPU, for f16, we get different results with eager mode # than with compiled mode. Jax2TfLimitation( "Different overflow behavior. ", dtypes=[np.float16, jnp.bfloat16], devices="tpu", expect_tf_error=False, modes=("eager", "graph"), skip_comparison=True), Jax2TfLimitation( "Different overflow behavior for large exponents. ", dtypes=[ np.int8, np.int16, np.int32, np.int64, np.float16, jnp.bfloat16, np.float32, np.complex64, np.complex128 ], enabled=(abs(y) > 10), expect_tf_error=False, modes=("eager", "graph"), skip_comparison=True), ] + list(cls._pow_test_util(harness)) @classmethod def pow(cls, harness: primitive_harness.Harness): return cls._pow_test_util(harness) @classmethod def lgamma(cls, harness: primitive_harness.Harness): return [ missing_tf_kernel( dtypes=[dtypes.bfloat16], devices=("cpu", "gpu"), modes=("eager", "graph")), custom_numeric(dtypes=np.float64, tol=1e-11), custom_numeric(dtypes=np.float32, tol=1e-3) ] @classmethod def log1p(cls, harness: primitive_harness.Harness): return [ custom_numeric(dtypes=np.complex128, tol=3e-14), custom_numeric(dtypes=np.float64, tol=1e-10), custom_numeric(dtypes=np.float32, tol=1e-3) ] @classmethod def lu(cls, harness: primitive_harness.Harness): dtype = harness.dtype def custom_assert(tst, result_jax, result_tf, *, args, tol, err_msg): operand, = args lu, pivots, perm = result_tf batch_dims = operand.shape[:-2] m, n = operand.shape[-2], operand.shape[-1] def _make_permutation_matrix(perm): result = [] for idx in itertools.product(*map(range, operand.shape[:-1])): result += [0 if c != perm[idx] else 1 for c in range(m)] result = np.reshape(np.array(result, dtype=dtype), [*batch_dims, m, m]) return result k = min(m, n) l = jnp.tril(lu, -1)[..., :, :k] + jnp.eye(m, k, dtype=dtype) u = jnp.triu(lu)[..., :k, :] p_mat = _make_permutation_matrix(perm) tst.assertArraysEqual( lax.linalg.lu_pivots_to_permutation(pivots, m), perm) tst.assertAllClose( jnp.matmul(p_mat, operand), jnp.matmul(l, u), atol=tol, rtol=tol, err_msg=err_msg) return [ custom_numeric( dtypes=[np.float32, np.complex64], devices="tpu", tol=0.1), custom_numeric( dtypes=[np.float32, np.complex64], devices=("cpu", "gpu"), tol=1e-5), custom_numeric(dtypes=[np.float64, np.complex128], tol=1e-13), custom_numeric( custom_assert=custom_assert, description=("May return different, but also correct, results when " "the decomposition is not unique"), devices=("cpu", "gpu"), modes=("eager", "graph", "compiled")), ] @classmethod def max(cls, harness: primitive_harness.Harness): # TODO(bchetioui): discrepancies between TF & JAX when comparing with NaN; # JAX always returns NaN, while TF returns the value NaN is compared with. def custom_assert(tst, result_jax, result_tf, err_msg, **_): mask = np.isnan(result_jax) tst.assertAllClose(result_jax[~mask], result_tf[~mask], err_msg=err_msg) return [ custom_numeric( custom_assert=custom_assert, description=( "May return different values when one of the values is NaN. " "JAX always returns NaN, while TF returns the value NaN is compared with." ), modes=("eager", "graph", "compiled")) ] @classmethod def min(cls, harness: primitive_harness.Harness): # TODO(bchetioui): discrepancies between TF & JAX when comparing with NaN; # JAX always returns NaN, while TF returns the value NaN is compared with. def custom_assert(tst, result_jax, result_tf, *, err_msg, **_): mask = np.isnan(result_jax) tst.assertAllClose(result_jax[~mask], result_tf[~mask], err_msg=err_msg) return [ custom_numeric( custom_assert=custom_assert, description=( "May return different values when one of the values is NaN. " "JAX always returns NaN, while TF returns the value NaN is compared with." ), modes=("eager", "graph", "compiled")) ] @classmethod def nextafter(cls, harness: primitive_harness.Harness): return [missing_tf_kernel(dtypes=[np.float16, dtypes.bfloat16])] @classmethod def qr(cls, harness: primitive_harness.Harness): # See https://github.com/google/jax/pull/3775#issuecomment-659407824; # # jit_compile=True breaks for complex types. # TODO: see https://github.com/google/jax/pull/3775#issuecomment-659407824. # - for now, the performance of the HLO QR implementation called when # compiling with TF is expected to have worse performance than the # custom calls made in JAX. return [ custom_numeric( dtypes=[np.float64, np.complex128], devices=("cpu", "gpu"), modes=("eager", "graph", "compiled"), tol=1e-13), custom_numeric( dtypes=[np.float32, np.complex64], devices=("cpu", "gpu"), modes=("eager", "graph", "compiled"), tol=1e-4), missing_tf_kernel( dtypes=[dtypes.bfloat16], devices="tpu", ) ] @classmethod def random_gamma(cls, harness: primitive_harness.Harness): return [custom_numeric(devices="tpu", tol=1e-3)] @classmethod def reduce_max(cls, harness: primitive_harness.Harness): # Unlike reduce_window_max, we use a native TF op: tf.reduce_max, which # does not work for complex return [missing_tf_kernel(dtypes=[np.complex64, np.complex128])] @classmethod def reduce_min(cls, harness: primitive_harness.Harness): return cls.reduce_max(harness) @classmethod def regularized_incomplete_beta(cls, harness: primitive_harness.Harness): return [ custom_numeric(dtypes=np.float64, tol=1e-14), missing_tf_kernel(dtypes=[np.float16, dtypes.bfloat16]) ] @classmethod def rem(cls, harness: primitive_harness.Harness): return [ Jax2TfLimitation( "TF integer division fails if divisor contains 0; JAX returns NaN", dtypes=[ np.uint8, np.int8, np.uint16, np.uint32, np.uint64, np.int8, np.int16, np.int32, np.int64 ], # Only the harnesses with "singularity" will have divide by 0 enabled=("singularity" in harness.name)), ] @classmethod def rng_bit_generator(cls, harness: primitive_harness.Harness): return [] @classmethod def round(cls, harness: primitive_harness.Harness): return [ missing_tf_kernel( dtypes=[dtypes.bfloat16], devices=("cpu", "gpu"), modes=("eager", "graph")) ] @classmethod def scatter_add(cls, harness): return [] @classmethod def scatter_mul(cls, harness): return [] @classmethod def select_and_gather_add(cls, harness): return [ # This JAX primitives is not not exposed directly in the JAX API # but arises from JVP of `lax.reduce_window` for reducers # `lax.max` or `lax.min`. It also arises from second-order # VJP of the same. Implemented using XlaReduceWindow. Jax2TfLimitation(( "jax2tf unimplemented for 64-bit inputs because the current implementation " "relies on packing two values into a single value. This can be " "fixed by using a variadic XlaReduceWindow, when available"), dtypes=[np.float64], devices=("cpu", "gpu")) ] @classmethod def sort(cls, harness: primitive_harness.Harness): return [ Jax2TfLimitation( # I think that this is because TF is running on CPU even for GPU tests? "TODO: TF non-stable multiple-array sort", devices="gpu", enabled=(harness.params["num_arrays"] > 1 and not harness.params["is_stable"]), expect_tf_error=False, skip_comparison=True), ] @classmethod def svd(cls, harness: primitive_harness.Harness): # TODO: slow test compute_uv = harness.params["compute_uv"] # Both `r_jax` and `r_tf` are 3-Tuples containing the SVD results: # `S` (singular values), `U` (left singular vectors), and `Vh` (the # adjoint of the right singular vectors). Note that the TF results are # obtained through `_svd` in jax/experimental/jax2tf/jax2tf.py. def custom_assert(tst, r_jax, r_tf, *, args, tol, err_msg): def reconstruct_operand(result): # Reconstructing operand as documented in numpy.linalg.svd (see # https://numpy.org/doc/stable/reference/generated/numpy.linalg.svd.html) s, u, v = result U = u[..., :s.shape[-1]] V = v[..., :s.shape[-1], :] S = s[..., None, :] return jnp.matmul(U * S, V, precision=lax.Precision.HIGHEST) # Compares the shapes. def compare_shapes(r_jax, r_tf): shapes_jax = [result.shape for result in r_jax] shapes_tf = [result.shape for result in r_tf] tst.assertEqual(shapes_jax, shapes_tf) # Compares reconstructed operand. # Computes backward error https://www.netlib.org/lapack/lug/node97.html # and uses the maximum backward error if there are batch dimensions. # The backward error is bounded by some constant multiplying the machine # precision. # TODO: Compares the operand instead of the reconstructed operand. def compare_reconstructed_operand(r_jax, r_tf, tol): operand_jax = reconstruct_operand(r_jax) operand_tf = reconstruct_operand(r_tf) error_norm = jnp.linalg.norm(operand_jax - operand_tf, axis=(-2, -1)) backward_error = (error_norm / jnp.linalg.norm(operand_jax, axis=(-2, -1))) max_backward_error = jnp.amax(backward_error) tst.assertLess(max_backward_error, tol) # Computes the absolute gap between singular value `\sigma_i` and the # nearest other singular value and for all singular values. The absolute # gap is used to approximate the upper bound of angular difference # between the computed and the true singular vectors. If the matrix is # rectangular `m != n`, the gap for the smallest nonzero singular value # should also consider the gap between it and zero. Note that this code # relies on the singular values being in descending order. def compute_absolute_gap(s, m, n): forward_appendant = np.Inf if m == n else 0 forward_diff = jnp.diff(s, axis=-1, append=forward_appendant) backward_diff = jnp.diff( s[..., ::-1], axis=-1, append=np.Inf)[..., ::-1] absolute_gap = jnp.minimum(jnp.abs(forward_diff), jnp.abs(backward_diff)) return absolute_gap # See `CompareSingularVectors` in # tensorflow/python/kernel_tests/linalg/svd_op_test.py def compare_singular_vectors(x, y, *, error_bound): # Singular vectors are only unique up to sign (complex phase factor for # complex matrices), so we normalize the sign first. sum_of_ratios = jnp.sum(jnp.divide(y, x), -2, keepdims=True) phases = jnp.divide(sum_of_ratios, jnp.abs(sum_of_ratios)) x *= phases # Note that in general `sqrt(sum(squares))` is not a stable way to # compute l2 vector norms, but it should be OK for normalization # factors of vectors with norm ~= 1 as here. def dot_column_wise(a, b): output = jnp.sum(jnp.einsum('...ij,...ij->...ij', a.conj(), b, precision=lax.Precision.HIGHEST), axis=-2) return jnp.real(output) cos_angular_diff = ( dot_column_wise(x, y) / jnp.sqrt(dot_column_wise(x, x) * dot_column_wise(y, y))) # Values of `\cos(angular_diff)` outside the interval [0, 1] are clipped # to the interval edges. For example, `\cos(angular_diff)` could contain # values like 1.0000001 on float32, which are clipped to 1.0. It is # possible that anything other than `cos_angular_diff` can be outside # the interval [0, 1] due to roundoff. cos_angular_diff = jnp.clip(cos_angular_diff, a_min=0.0, a_max=1.0) angular_diff = jnp.arccos(cos_angular_diff) # TODO: removes the slack factor on the angular difference. # It is possible that the singular vectors are not accurate to much more # than O(\sqrt(eps)), which is likely a property of the SVD algorithms # in question; revisit with better understanding of the SVD algorithms. if x.dtype in [np.float32, np.complex64]: slack_factor = 2E4 elif x.dtype in [np.float64, np.complex128]: slack_factor = 2E9 np.testing.assert_array_less(angular_diff, slack_factor * error_bound) if compute_uv: # Compares the shapes. compare_shapes(r_jax, r_tf) # Compares the singular values. Each computed singular value `\sigma_i` # differs from the true `\sigma_i`* by at most # `|\sigma_i - \sigma_i*| <= \epsilon \sigma_1`, where `\sigma_1` is the # largest singular value and `\epsilon` denotes the machine precision. s_jax, s_tf = r_jax[0], r_tf[0] tst.assertAllClose(s_jax, s_tf, atol=tol, rtol=tol, err_msg=err_msg) # Compares the reconstructed operand. compare_reconstructed_operand(r_jax, r_tf, tol) # Compares the singular vectors. # We only compare the first `rank` singular vectors since the remainder # forms an arbitrary orthonormal basis for the (row- or column-) null # space, whose exact value depends on implementation details. # TODO: A better estimation on the rank? rank = r_jax[0].shape[-1] # Computes the upper bound for angular difference of singular vectors. # The upper bound has the shape of `[..., k]`, where `...` denotes the # batch dimensions and `k` is the number of nonzero singular values. m = r_jax[1].shape[-2] n = r_jax[2].shape[-2] absolute_gap = compute_absolute_gap(r_jax[0], m, n) epsilon = jnp.finfo(r_jax[0].dtype).eps sigma_largest = (r_jax[0][..., 0])[..., None] upperbound_singular_vectors = epsilon * sigma_largest / absolute_gap upperbound_singular_vectors = upperbound_singular_vectors[..., :rank] # Left singular vectors. u_jax = r_jax[1][..., :rank] u_tf = r_tf[1][..., :rank] compare_singular_vectors(u_jax, u_tf, error_bound=upperbound_singular_vectors) # Right singular vectors. v_jax = jnp.swapaxes(r_jax[2][..., :rank, :], -2, -1).conj() v_tf = jnp.swapaxes(r_tf[2][..., :rank, :], -2, -1).conj() compare_singular_vectors(v_jax, v_tf, error_bound=upperbound_singular_vectors) else: tst.assertAllClose(r_jax, r_tf, atol=tol, rtol=tol, err_msg=err_msg) return [ # Works in JAX for complex due to custom calls on cpu and gpu Jax2TfLimitation( "function not compilable. Implemented using `tf.linalg.svd` and `tf.linalg.adjoint`", dtypes=[np.complex64, np.complex128], devices=("cpu", "gpu"), modes=("compiled",)), Jax2TfLimitation( "Large numerical discrepancy", dtypes=[np.float16], devices=("tpu"), modes=("eager", "graph", "compiled"), skip_comparison=True), missing_tf_kernel(dtypes=[dtypes.bfloat16], devices="tpu"), custom_numeric( tol=1e-4, dtypes=[np.float32, np.complex64], devices=("cpu", "gpu"), modes=("eager", "graph", "compiled")), # TODO: this is very low tolerance for f64 custom_numeric( tol=1e-4, dtypes=[np.float64, np.complex128], devices=("cpu", "gpu"), modes=("eager", "graph", "compiled")), custom_numeric( tol=1e-4, description="custom numeric comparison when compute_uv on CPU/GPU", custom_assert=custom_assert, devices=("cpu", "gpu"), modes=("eager", "graph", "compiled"), enabled=(compute_uv == True)), custom_numeric( tol=1e-2, description="custom numeric comparison when compute_uv on TPU", dtypes=[np.float32, np.float64, np.complex64, np.complex128], custom_assert=custom_assert, devices=("tpu"), modes=("eager", "graph", "compiled"), enabled=(compute_uv == True)), ] @classmethod def tan(cls, harness): return [ custom_numeric(dtypes=np.complex64, devices="tpu", tol=1e-4), custom_numeric(dtypes=np.complex64, devices=("cpu", "gpu"), tol=1e-3), custom_numeric(dtypes=np.complex128, devices=("cpu", "gpu"), tol=1e-12) ] @classmethod def tanh(cls, harness): return [ custom_numeric(dtypes=np.complex128, tol=1e-7), custom_numeric(dtypes=np.complex64, tol=1e-4) ] @classmethod def top_k(cls, harness): def custom_assert(tst, result_jax, result_tf, *, err_msg, **_): assert len(result_jax) == len(result_tf) # TODO: TF and JAX sort [inf, nan] differently. first_arr_jax, first_arr_tf = result_jax[0], result_tf[0] if np.all(first_arr_jax == first_arr_tf): for arr_jax, arr_tf in zip(result_jax, result_tf): tst.assertArraysEqual(arr_jax, arr_tf, err_msg=err_msg) else: mask_jax, mask_tf = np.isnan(first_arr_jax), np.isnan(first_arr_tf) tst.assertArraysEqual( first_arr_jax[~mask_jax], first_arr_tf[~mask_tf], err_msg=err_msg) return [ custom_numeric( dtypes=[np.float16, dtypes.bfloat16, np.float32, np.float64], custom_assert=custom_assert, description=( "Produces different results when the array contains `inf` and `NaN`" " (they are sorted differently in TF vs. XLA).")) ] @classmethod def triangular_solve(cls, harness: primitive_harness.Harness): return [ missing_tf_kernel(dtypes=[dtypes.bfloat16]), missing_tf_kernel( dtypes=[np.float16], devices=("gpu", "cpu"), modes=("eager", "graph")), custom_numeric(dtypes=np.float32, tol=5e-3) ] @classmethod def tridiagonal_solve(cls, harness: primitive_harness.Harness): return [] def custom_numeric( *, description="custom numeric comparison", dtypes=(), # All modes=( "eager", "graph", ), # By default we should not need tolerance for # "compiled" devices=("cpu", "gpu", "tpu"), custom_assert=None, enabled=True, tol=None) -> Jax2TfLimitation: return Jax2TfLimitation( description, expect_tf_error=False, dtypes=dtypes, devices=devices, modes=modes, custom_assert=custom_assert, enabled=enabled, tol=tol) def missing_tf_kernel(*, description="op not defined for dtype", dtypes, modes=("eager", "graph", "compiled"), devices=("cpu", "gpu", "tpu"), enabled=True) -> Jax2TfLimitation: return Jax2TfLimitation( description, dtypes=dtypes, devices=devices, modes=modes, enabled=enabled)