# Copyright 2021 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""See primitives_test docstring for how the Jax2TfLimitations are used."""
|
|
|
|
import itertools
|
|
from typing import Any, Callable, Optional, Sequence, Union
|
|
|
|
from jax import lax
|
|
from jax import numpy as jnp
|
|
from jax._src import test_util as jtu
|
|
from jax._src import dtypes
|
|
from jax.experimental.jax2tf.tests import primitive_harness
|
|
import numpy as np
|
|
|
|
DType = Any
|
|
|
|
|
|


class Jax2TfLimitation(primitive_harness.Limitation):
  """Specific primitive limitations for jax2tf.

  See the primitives_test module docstring for details.
  """

  def __init__(
      self,
      description: str,
      *,
      devices: Union[str, Sequence[str]] = ("cpu", "gpu", "tpu"),
      dtypes: Union[DType, Sequence[DType]] = (),
      enabled: bool = True,
      # jax2tf specific
      modes=("eager", "graph", "compiled"),
      skip_tf_run=False,
      expect_tf_error: bool = True,
      skip_comparison=False,
      custom_assert: Optional[Callable] = None,
      tol=None):
    """See the primitive_harness.Limitation common arguments.

    Args:
      modes: one of "eager", "graph", "compiled".
      skip_tf_run: if set, skip the TF execution. Use this sparingly and
        prefer `expect_tf_error`. Use only when the test cannot recover from
        the TF error.
      expect_tf_error: if set, expect a TF error in the given mode when
        executing the result of jax2tf conversion. If not set, then the
        limitation must have a custom_assert or a non-default tol.
      skip_comparison: skips the numeric comparison.
      tol: a tolerance to use for both atol and rtol. We will use the maximum
        tolerance over all the applicable limitations, irrespective of their
        order.
      custom_assert: if given, then execute as
        `custom_assert(tst, result_jax, result_tf, args=args, tol=tol, err_msg)`,
        where `tst` is the current TestCase instance, `args` are the input
        arguments that the harness created, `tol` is the maximum tolerance
        based on the applicable limitations, and `err_msg` is passed to the
        NumPy assert methods. `result_tf` is already converted to NumPy
        arrays. See the illustrative sketch after this constructor.
    """
    super().__init__(
        description, devices=devices, dtypes=dtypes, enabled=enabled)
    if isinstance(modes, str):
      modes = (modes,)
    assert all(m in ["eager", "graph", "compiled"] for m in modes), f"Invalid modes: {modes}"
    self.modes = modes
    self.expect_tf_error = expect_tf_error
    self.skip_tf_run = skip_tf_run
    self.custom_assert = custom_assert
    self.tol = tol
    self.skip_comparison = skip_comparison
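
  # A minimal usage sketch of the constructor arguments documented in
  # `__init__` above. The name `_example_custom_assert` and the values below
  # are hypothetical and are not used anywhere in this module:
  #
  #   def _example_custom_assert(tst, result_jax, result_tf, *, args, tol,
  #                              err_msg):
  #     tst.assertAllClose(result_jax, result_tf, atol=tol, rtol=tol,
  #                        err_msg=err_msg)
  #
  #   _example_limitation = Jax2TfLimitation(
  #       "example: small numeric discrepancy on GPU",
  #       dtypes=[np.float32],
  #       devices="gpu",
  #       modes=("eager", "graph"),
  #       expect_tf_error=False,
  #       custom_assert=_example_custom_assert,
  #       tol=1e-5)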

  def get_max_tolerance_limitation(
      self, limitations: Sequence["Jax2TfLimitation"]
  ) -> Optional["Jax2TfLimitation"]:
    """Pick the tolerance limitation that establishes the maximum tolerance."""
    # TODO: it would be best if the limitations with tolerance were mutually
    # exclusive, so that we don't have to compute the maximum.
    # TODO: we made this an instance method only so that we don't have to
    # import this module from tf_test_util.
    max_tol_lim = None
    for l in limitations:
      if l.tol is not None:
        if max_tol_lim is None or l.tol > max_tol_lim.tol:
          max_tol_lim = l
    return max_tol_lim

  def filter(  # type: ignore[override]
      self,
      dtype: Optional[DType] = None,
      device: Optional[str] = None,
      mode: Optional[str] = None) -> bool:
    return ((mode is None or mode in self.modes) and
            super().filter(device=device, dtype=dtype))

  @classmethod
  def limitations_for_harness(
      cls, harness: primitive_harness.Harness) -> Sequence["Jax2TfLimitation"]:
    group_method = getattr(cls, harness.group_name, None)
    if harness.group_name in cls.harness_groups_no_limitations:
      assert group_method is None, (
          f"Harness group '{harness.group_name}' is both in "
          f"'harness_groups_no_limitations' and has a custom "
          f"Jax2TfLimitation.classmethod defined (see module docstring)")
      return []
    else:
      assert group_method is not None, (
          f"Harness group '{harness.group_name}' must be either part of "
          f"'harness_groups_no_limitations' or must have a custom "
          f"Jax2TfLimitation.classmethod defined (see module docstring)")
      limitations = group_method(harness)
      assert isinstance(limitations, (list, tuple))
      return limitations

  # We keep here the explicit set of groups for which we don't have limitations
  harness_groups_no_limitations = {
      "abs", "add", "add_any", "and", "atan2",
      "bitcast_convert_type", "broadcast", "broadcast_in_dim", "cbrt", "ceil",
      "clamp", "concatenate", "cos", "cosh", "complex", "conj",
      "convert_element_type",
      "cummax", "cummin", "device_put", "dynamic_slice",
      "dynamic_update_slice", "exp", "eq", "floor", "gather", "ge", "gt",
      "imag",
      "iota", "is_finite", "le", "lt", "log", "mul", "ne", "neg", "not",
      "or", "pad", "population_count",
      "random_categorical", "random_split", "random_uniform", "random_randint",
      "reduce",
      "reduce_and", "reduce_prod", "reduce_or", "reduce_sum",
      "reduce_window_add", "reduce_window_mul", "reduce_window_min",
      "reduce_window_max",
      "real", "reshape", "rev", "rsqrt", "scatter_max", "scatter_min",
      "select_n", "select_and_scatter_add",
      "shift_left", "shift_right_logical", "shift_right_arithmetic", "sign",
      "sin", "sinh", "slice", "sqrt", "squeeze", "stop_gradient", "sub",
      "tie_in", "transpose", "xor", "zeros_like"
  }
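
  # How the dispatch in `limitations_for_harness` above resolves: a harness
  # group either appears in `harness_groups_no_limitations`, or a classmethod
  # with the same name as the group returns its list of limitations. A
  # hypothetical sketch for a new group named "my_group" (the name and
  # tolerance are illustrative only, not part of this module):
  #
  #   @classmethod
  #   def my_group(cls, harness: primitive_harness.Harness):
  #     return [custom_numeric(dtypes=np.float32, tol=1e-5)]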

  @classmethod
  def helper_get_trig_custom_limitation(cls, np_inverse):

    def custom_assert(tst, result_jax, result_tf, *, args, tol, err_msg):
      operand, = args
      tst.assertAllClose(
          operand, np_inverse(result_tf), atol=tol, rtol=tol, err_msg=err_msg)

    return custom_numeric(
        description="May return different but still correct results",
        dtypes=[np.complex64, np.complex128],
        custom_assert=custom_assert)

  @classmethod
  def acos(cls, harness: primitive_harness.Harness):
    return [
        custom_numeric(
            dtypes=np.complex64,
            devices=("cpu", "gpu"),
            tol=1e-4,
            modes=("eager", "graph", "compiled")),
        custom_numeric(
            dtypes=np.complex128,
            devices=("cpu", "gpu"),
            tol=1e-13,
            modes=("eager", "graph", "compiled")),
    ]

  @classmethod
  def acosh(cls, harness: primitive_harness.Harness):
    return [
        custom_numeric(dtypes=np.complex64, devices=("cpu", "gpu"), tol=1e-3),
        custom_numeric(dtypes=np.complex128, devices=("cpu", "gpu"), tol=1e-12),
        cls.helper_get_trig_custom_limitation(np.cosh)
    ]

  @classmethod
  def approx_max_k(cls, harness: primitive_harness.Harness):
    supported_dtypes = jtu.supported_dtypes()
    return [
        Jax2TfLimitation(
            "eager is not supported in CPU or GPU.",
            dtypes=[t for t in [jnp.bfloat16, np.float16, np.float32]
                    if t in supported_dtypes],
            devices=("cpu", "gpu", "tpu"),
            modes=("graph", "compiled"))
    ]

  @classmethod
  def argmax(cls, harness: primitive_harness.Harness):
    return [
        Jax2TfLimitation(
            "different results when the input contains NaN and enable_xla=False",
            dtypes=jtu.dtypes.all_inexact,
            devices=("cpu", "gpu", "tpu"),
            modes=("eager", "graph", "compiled"),
            expect_tf_error=False,
            skip_comparison=True,
            enabled=("nan_" in harness.name and not harness.params["enable_xla"])),
    ]

  @classmethod
  def argmin(cls, harness: primitive_harness.Harness):
    return cls.argmax(harness)

  @classmethod
  def asin(cls, harness: primitive_harness.Harness):
    return [
        custom_numeric(dtypes=np.complex64, devices=("cpu", "gpu"), tol=1e-4),
        custom_numeric(dtypes=np.complex128, devices=("cpu", "gpu"), tol=1e-12),
        cls.helper_get_trig_custom_limitation(np.sin)
    ]

  @classmethod
  def asinh(cls, harness: primitive_harness.Harness):
    return [
        custom_numeric(dtypes=np.complex64, devices=("cpu", "gpu"), tol=1e-3),
        custom_numeric(dtypes=np.complex128, devices=("cpu", "gpu"), tol=1e-12),
        cls.helper_get_trig_custom_limitation(np.sinh)
    ]

  @classmethod
  def atan(cls, harness: primitive_harness.Harness):
    return [
        custom_numeric(dtypes=np.complex64, devices=("cpu", "gpu"), tol=1e-5),
        custom_numeric(dtypes=np.complex128, devices=("cpu", "gpu"), tol=1e-12),
        cls.helper_get_trig_custom_limitation(np.tan)
    ]

  @classmethod
  def atanh(cls, harness: primitive_harness.Harness):
    return [
        custom_numeric(dtypes=np.float64, tol=1e-14),
        custom_numeric(dtypes=np.complex64, tol=1e-3),
        custom_numeric(dtypes=np.complex128, devices=("cpu", "gpu"), tol=1e-12),
        cls.helper_get_trig_custom_limitation(np.tanh)
    ]

  @classmethod
  def bessel_i0e(cls, harness: primitive_harness.Harness):
    return [
        missing_tf_kernel(
            dtypes=[dtypes.bfloat16],
            devices=("cpu", "gpu"),
            modes=("eager", "graph"))
    ]

  @classmethod
  def bessel_i1e(cls, harness: primitive_harness.Harness):
    return cls.bessel_i0e(harness)

  @classmethod
  def cholesky(cls, harness: primitive_harness.Harness):

    def custom_assert(tst, result_jax, result_tf, *, tol, err_msg, **_):
      # cholesky_p returns garbage in the strictly upper triangular part of the
      # result, so we can safely ignore that part.
      tst.assertAllClose(
          jnp.tril(result_jax), result_tf, atol=tol, err_msg=err_msg)

    return [
        # TODO: very high tolerance
        custom_numeric(
            dtypes=[np.float32, np.complex64],
            tol=1e-2,
            devices=("cpu", "gpu"),
            modes=("eager", "graph", "compiled")),
        custom_numeric(
            dtypes=[np.float64, np.complex128],
            tol=1e-6,
            devices=("cpu", "gpu"),
            modes=("eager", "graph", "compiled")),
        custom_numeric(
            dtypes=[dtypes.bfloat16, np.float16],
            tol=5e-2,
            devices=("cpu", "gpu"),
            modes=("eager", "graph", "compiled")),
        custom_numeric(
            custom_assert=custom_assert,
            description=(
                "May return different values in the strictly upper triangular "
                "part of the result. This does not matter for correctness, "
                "because this part of the matrix is not considered in the result."
            ),
            modes=("eager", "graph", "compiled"))
    ]

  @classmethod
  def conv_general_dilated(cls, harness: primitive_harness.Harness):
    return [
        # Even in compiled mode, for GPU we see a bit of discrepancy but
        # very minor.
        custom_numeric(dtypes=np.float32, devices="gpu",
                       modes=("eager", "graph", "compiled"),
                       tol=1e-5),
        custom_numeric(dtypes=np.float32, devices="cpu",
                       modes=("eager", "graph", "compiled"),
                       tol=1e-4),
        custom_numeric(description="higher numeric inaccuracy when `enable_xla=False`",
                       modes=("eager", "graph", "compiled"),
                       enabled=(not harness.params["enable_xla"]),
                       tol=5e-3)
    ]

  @classmethod
  def cumprod(cls, harness):
    return [
        # JAX uses a different lowering for CPU and GPU.
        custom_numeric(
            dtypes=(np.float16, jnp.bfloat16),
            devices=("cpu", "gpu"),
            modes=("eager", "graph", "compiled"),
            tol=5e-1)
    ]

  @classmethod
  def cumsum(cls, harness):
    return [
        # JAX uses a different lowering for CPU and GPU.
        custom_numeric(
            dtypes=(np.float16, jnp.bfloat16),
            devices=("cpu", "gpu"),
            modes=("eager", "graph", "compiled"),
            tol=5e-1)
    ]

  @classmethod
  def custom_linear_solve(cls, harness: primitive_harness.Harness):
    return [
        Jax2TfLimitation(
            "TODO: large numerical discrepancy",
            dtypes=np.float32,
            devices="tpu",
            expect_tf_error=False,
            skip_comparison=True),
        custom_numeric(dtypes=np.float32, devices="tpu", tol=0.01),
        custom_numeric(tol=1e-3),
    ]

  @classmethod
  def digamma(cls, harness: primitive_harness.Harness):
    dtype = harness.dtype

    # In the bfloat16 case, TF and lax both return NaN in undefined cases.
    # digamma is not defined at 0 and -1
    def custom_assert(tst, result_jax, result_tf, *, args, tol, err_msg):
      # lax.digamma returns NaN and tf.math.digamma returns inf
      arg, = args
      special_cases = (arg == 0.) | (arg == -1.)
      nr_special_cases = np.count_nonzero(special_cases)
      tst.assertAllClose(
          np.full((nr_special_cases,), dtype(np.nan)),
          result_jax[special_cases],
          err_msg=err_msg)
      tst.assertAllClose(
          np.full((nr_special_cases,), dtype(np.inf)),
          result_tf[special_cases],
          err_msg=err_msg)
      # non-special cases are equal
      tst.assertAllClose(
          result_jax[~special_cases],
          result_tf[~special_cases],
          atol=tol,
          rtol=tol,
          err_msg=err_msg)

    return [
        missing_tf_kernel(
            dtypes=[dtypes.bfloat16],
            devices=("cpu", "gpu"),
            modes=("eager", "graph")),
        custom_numeric(dtypes=np.float64, tol=1e-13),
        custom_numeric(dtypes=np.float32, devices=["cpu", "gpu"], tol=1e-3),
        custom_numeric(
            dtypes=dtypes.bfloat16,
            custom_assert=custom_assert,
            description=(
                "May return different results at singularity points 0 and -1. "
                "JAX returns nan and TF returns inf"))
    ]

  @classmethod
  def div(cls, harness: primitive_harness.Harness):
    return [
        Jax2TfLimitation(
            "TF integer division fails if divisor contains 0; JAX returns NaN",
            dtypes=[
                np.uint8, np.int8, np.uint16, np.uint32, np.uint64,
                np.int16, np.int32, np.int64
            ],
            # Only the harnesses with "singularity" will have divide by 0
            enabled=("singularity" in harness.name))
    ]

  @classmethod
  def dot_general(cls, harness: primitive_harness.Harness):
    return [
        missing_tf_kernel(dtypes=[np.bool_],),
        # TODO(b/189287598)
        Jax2TfLimitation(
            "Non-deterministic NaN for dot_general with preferred_element_type on GPU (b/189287598)",
            dtypes=[
                jnp.bfloat16, np.float16, np.float32, np.complex64
            ],
            devices="gpu",
            modes=("eager", "graph", "compiled"),
            enabled=(harness.params["preferred_element_type"] is not None),
            skip_comparison=True),
        # JAX performs float16 matmuls in float32 on CPU, so the JAX result
        # may be more precise.
        custom_numeric(dtypes=[np.float16], devices=["cpu"], tol=1e-2,
                       modes=("eager", "graph", "compiled")),
    ]

  @classmethod
  def eig(cls, harness: primitive_harness.Harness):
    compute_left_eigenvectors = harness.params["compute_left_eigenvectors"]
    compute_right_eigenvectors = harness.params["compute_right_eigenvectors"]
    dtype = harness.dtype

    def custom_assert(tst, result_jax, result_tf, *, args, tol, err_msg):
      operand, = args
      inner_dimension = operand.shape[-1]

      # Test ported from tests.linalg_test.testEig
      # Norm, adjusted for dimension and type.
      def norm(x):
        norm = np.linalg.norm(x, axis=(-2, -1))
        return norm / ((inner_dimension + 1) * jnp.finfo(dtype).eps)

      def check_right_eigenvectors(a, w, vr):
        tst.assertTrue(
            np.all(norm(np.matmul(a, vr) - w[..., None, :] * vr) < 100))

      def check_left_eigenvectors(a, w, vl):
        rank = len(a.shape)
        aH = jnp.conj(a.transpose(list(range(rank - 2)) + [rank - 1, rank - 2]))
        wC = jnp.conj(w)
        check_right_eigenvectors(aH, wC, vl)

      def check_eigenvalue_is_in_array(eigenvalue, eigenvalues_array):
        tol = None
        # TODO(bchetioui): numerical discrepancies
        if dtype in [np.float32, np.complex64]:
          tol = 1e-4
        elif dtype in [np.float64, np.complex128]:
          tol = 1e-13
        closest_diff = min(abs(eigenvalues_array - eigenvalue))
        tst.assertAllClose(
            closest_diff,
            np.array(0., closest_diff.dtype),
            atol=tol,
            err_msg=err_msg)

      all_w_jax, all_w_tf = result_jax[0], result_tf[0]
      for idx in itertools.product(*map(range, operand.shape[:-2])):
        w_jax, w_tf = all_w_jax[idx], all_w_tf[idx]
        for i in range(inner_dimension):
          check_eigenvalue_is_in_array(w_jax[i], w_tf)
          check_eigenvalue_is_in_array(w_tf[i], w_jax)

      if compute_left_eigenvectors:
        check_left_eigenvectors(operand, all_w_tf, result_tf[1])
      if compute_right_eigenvectors:
        check_right_eigenvectors(operand, all_w_tf,
                                 result_tf[1 + compute_left_eigenvectors])

    return [
        # Eig does not work in JAX on gpu or tpu
        Jax2TfLimitation(
            "function not compilable", modes="compiled", devices="cpu"),
        Jax2TfLimitation(
            "TF Conversion of eig is not implemented when both compute_left_eigenvectors and compute_right_eigenvectors are set to True",
            enabled=(compute_left_eigenvectors and compute_right_eigenvectors)),
        custom_numeric(
            custom_assert=custom_assert,
            description=("May return the eigenvalues and eigenvectors in a "
                         "potentially different order. The eigenvectors may "
                         "also be different, but equally valid."))
    ]

  @classmethod
  def eigh(cls, harness: primitive_harness.Harness):
    dtype = harness.dtype

    def custom_assert(tst, result_jax, result_tf, *, args, tol, err_msg):
      operand, = args
      inner_dimension = operand.shape[-1]

      def check_right_eigenvectors(a, w, vr):
        tol = 1e-16
        # TODO(bchetioui): tolerance needs to be very high in compiled mode,
        # specifically for eigenvectors.
        if dtype == np.float64:
          tol = 2e-5
        elif dtype == np.float32:
          tol = 1e-2
        elif dtype in [dtypes.bfloat16, np.complex64]:
          tol = 1e-3
        elif dtype == np.complex128:
          tol = 2e-5
        tst.assertAllClose(
            np.matmul(a, vr) - w[..., None, :] * vr,
            np.zeros(a.shape, dtype=vr.dtype),
            atol=tol,
            # For bfloat16 the np.matmul returns float32 result.
            check_dtypes=False,
            err_msg=err_msg)

      def check_eigenvalue_is_in_array(eigenvalue, eigenvalues_array):
        tol = None
        if dtype in [dtypes.bfloat16, np.float32, np.complex64]:
          tol = 1e-3
        elif dtype in [np.float64, np.complex128]:
          tol = 1e-5
        closest_diff = min(abs(eigenvalues_array - eigenvalue))
        tst.assertAllClose(
            closest_diff,
            np.array(0., closest_diff.dtype),
            atol=tol,
            err_msg=err_msg)

      _, all_w_jax = result_jax
      all_vr_tf, all_w_tf = result_tf

      for idx in itertools.product(*map(range, operand.shape[:-2])):
        w_jax, w_tf = all_w_jax[idx], all_w_tf[idx]
        for i in range(inner_dimension):
          check_eigenvalue_is_in_array(w_jax[i], w_tf)
          check_eigenvalue_is_in_array(w_tf[i], w_jax)

      check_right_eigenvectors(operand, all_w_tf, all_vr_tf)

    return [
        missing_tf_kernel(
            dtypes=dtypes.bfloat16,
            devices="tpu",
            enabled=(harness.params["shape"] != (0, 0)),  # This actually works!
        ),
        Jax2TfLimitation(
            "TODO: numeric discrepancies",
            dtypes=np.float16,
            devices="tpu",
            expect_tf_error=False,
            skip_comparison=True),
        custom_numeric(
            custom_assert=custom_assert,
            description=("May return the eigenvalues and eigenvectors in a "
                         "potentially different order. The eigenvectors may "
                         "also be different, but equally valid."),
            modes=("eager", "graph", "compiled"))
    ]

  @classmethod
  def erf(cls, harness: primitive_harness.Harness):
    return [
        missing_tf_kernel(
            dtypes=[dtypes.bfloat16],
            devices=("cpu", "gpu"),
            modes=("eager", "graph"))
    ]

  @classmethod
  def erfc(cls, harness: primitive_harness.Harness):
    return [
        missing_tf_kernel(
            dtypes=[dtypes.bfloat16],
            devices=("cpu", "gpu"),
            modes=("eager", "graph"))
    ]

  @classmethod
  def erf_inv(cls, harness: primitive_harness.Harness):
    # erf_inv is not defined for arg <= -1 or arg >= 1
    def custom_assert(tst, result_jax, result_tf, *, args, tol,
                      err_msg):  # noqa: F811
      arg, = args
      # for arg < -1 or arg > 1
      # lax.erf_inv returns NaN; tf.math.erf_inv returns +/- inf
      special_cases = (arg < -1.) | (arg > 1.)
      # non-special cases are equal
      tst.assertAllClose(
          result_jax[~special_cases],
          result_tf[~special_cases],
          atol=tol,
          rtol=tol,
          err_msg=err_msg)

    return [
        missing_tf_kernel(
            dtypes=[dtypes.bfloat16, np.float16],
            devices=("cpu", "gpu"),
            modes=("eager", "graph")),
        custom_numeric(dtypes=[np.float32, np.float64], tol=1e-4),
        custom_numeric(
            dtypes=[np.float32, np.float64],
            custom_assert=custom_assert,
            description=(
                "May return different results at undefined points (< -1 or > 1):"
                " JAX returns `NaN` and TF returns `+inf` or `-inf`."))
    ]

  @classmethod
  def expm1(cls, harness: primitive_harness.Harness):
    return [custom_numeric(dtypes=np.float64, tol=1e-5)]

  @classmethod
  def fft(cls, harness):
    return [
        Jax2TfLimitation(
            "TF function not compilable",
            devices=("cpu", "gpu"),
            dtypes=[np.float64, np.complex128],
            modes="compiled"),
        # TODO: very high tolerance
        custom_numeric(tol=1e-3, modes=("eager", "graph", "compiled")),
    ]

  @classmethod
  def _pow_test_util(cls, harness: primitive_harness.Harness):

    def custom_assert(tst, result_jax, result_tf, *, args, tol, err_msg):
      # NaNs are mismatched, but assertAllClose will also behave weirdly for
      # complex numbers containing np.inf as one of their components. See
      # https://github.com/numpy/numpy/issues/15959 for more details.
      mask = (
          np.isnan(result_jax) + np.isnan(result_tf) + np.isinf(result_jax) +
          np.isinf(result_tf))
      tst.assertAllClose(
          result_jax[~mask], result_tf[~mask], rtol=tol, err_msg=err_msg)

    return [
        custom_numeric(
            dtypes=[np.float32, np.complex64], devices=("cpu", "gpu"),
            tol=1e-3),
        custom_numeric(
            dtypes=[np.float64, np.complex128],
            devices=("cpu", "gpu"),
            tol=5e-5),
        custom_numeric(
            dtypes=[np.complex64, np.complex128],
            custom_assert=custom_assert,
        )
    ]

  @classmethod
  def igamma(cls, harness: primitive_harness.Harness):
    dtype = harness.dtype

    # igamma is not defined when the first argument is <=0
    def custom_assert(tst, result_jax, result_tf, *, args, tol, err_msg):
      arg1, arg2 = args
      # lax.igamma returns NaN when arg1 == arg2 == 0; tf.math.igamma returns 0
      special_cases = (arg1 == 0.) & (arg2 == 0.)
      nr_special_cases = np.count_nonzero(special_cases)
      tst.assertAllClose(
          np.full((nr_special_cases,), np.nan, dtype=dtype),
          result_jax[special_cases])
      tst.assertAllClose(
          np.full((nr_special_cases,), 0., dtype=dtype),
          result_tf[special_cases])
      # non-special cases are equal
      tst.assertAllClose(
          result_jax[~special_cases],
          result_tf[~special_cases],
          atol=tol,
          rtol=tol,
          err_msg=err_msg)

    return [
        missing_tf_kernel(
            dtypes=[dtypes.bfloat16, np.float16],
            devices=("cpu", "gpu"),
            modes=("eager", "graph")),
        custom_numeric(
            custom_assert=custom_assert,
            description=(
                "May return different results at undefined points "
                "(both arguments 0). JAX returns `NaN` and TF returns 0 or "
                "JAX returns 1 and TF returns `NaN`"))
    ]

  @classmethod
  def igammac(cls, harness: primitive_harness.Harness):
    dtype = harness.dtype

    # igammac is not defined when the first argument is <=0
    def custom_assert(tst, result_jax, result_tf, *, args, tol,
                      err_msg):  # noqa: F811
      arg1, arg2 = args
      # lax.igammac returns 1. when arg1 <= 0; tf.math.igammac returns NaN
      special_cases = (arg1 <= 0.) | (arg2 <= 0)
      nr_special_cases = np.count_nonzero(special_cases)
      tst.assertAllClose(
          np.full((nr_special_cases,), 1., dtype=dtype),
          result_jax[special_cases],
          err_msg=err_msg)
      tst.assertAllClose(
          np.full((nr_special_cases,), np.nan, dtype=dtype),
          result_tf[special_cases],
          err_msg=err_msg)
      # non-special cases are equal
      tst.assertAllClose(
          result_jax[~special_cases],
          result_tf[~special_cases],
          atol=tol,
          rtol=tol,
          err_msg=err_msg)

    return [
        missing_tf_kernel(
            dtypes=[dtypes.bfloat16, np.float16],
            devices=("cpu", "gpu"),
            modes=("eager", "graph")),
        custom_numeric(dtypes=np.float64, tol=1e-9),
        custom_numeric(devices="gpu", tol=1e-3),
        custom_numeric(
            custom_assert=custom_assert,
            devices=("cpu", "gpu"),
            description=(
                "May return different results at undefined points "
                "(both arguments less than or equal to 0). JAX returns `NaN` "
                "and TF returns 0 or JAX returns 1 and TF returns `NaN`")),
    ]

  @classmethod
  def integer_pow(cls, harness: primitive_harness.Harness):
    y = harness.params["y"]
    return [
        missing_tf_kernel(
            dtypes=[
                np.int8, np.int16, np.uint8, np.uint16, np.uint32, np.uint64
            ],
            modes="graph",
            enabled=(y not in [0, 1]),  # These are special-cased
            devices=("cpu", "gpu")),
        # TODO: on TPU, for f16, we get different results with eager mode
        # than with compiled mode.
        Jax2TfLimitation(
            "Different overflow behavior. ",
            dtypes=[np.float16, jnp.bfloat16],
            devices="tpu",
            expect_tf_error=False,
            modes=("eager", "graph"),
            skip_comparison=True),
        Jax2TfLimitation(
            "Different overflow behavior for large exponents. ",
            dtypes=[
                np.int8, np.int16, np.int32, np.int64, np.float16, jnp.bfloat16,
                np.float32, np.complex64, np.complex128
            ],
            enabled=(abs(y) > 10),
            expect_tf_error=False,
            modes=("eager", "graph"),
            skip_comparison=True),
    ] + list(cls._pow_test_util(harness))

  @classmethod
  def pow(cls, harness: primitive_harness.Harness):
    return cls._pow_test_util(harness)

  @classmethod
  def lgamma(cls, harness: primitive_harness.Harness):
    return [
        missing_tf_kernel(
            dtypes=[dtypes.bfloat16],
            devices=("cpu", "gpu"),
            modes=("eager", "graph")),
        custom_numeric(dtypes=np.float64, tol=1e-11),
        custom_numeric(dtypes=np.float32, tol=1e-3)
    ]

  @classmethod
  def log1p(cls, harness: primitive_harness.Harness):
    return [
        custom_numeric(dtypes=np.complex128, tol=3e-14),
        custom_numeric(dtypes=np.float64, tol=1e-10),
        custom_numeric(dtypes=np.float32, tol=1e-3)
    ]

  @classmethod
  def lu(cls, harness: primitive_harness.Harness):
    dtype = harness.dtype

    def custom_assert(tst, result_jax, result_tf, *, args, tol, err_msg):
      operand, = args
      lu, pivots, perm = result_tf
      batch_dims = operand.shape[:-2]
      m, n = operand.shape[-2], operand.shape[-1]

      def _make_permutation_matrix(perm):
        result = []
        for idx in itertools.product(*map(range, operand.shape[:-1])):
          result += [0 if c != perm[idx] else 1 for c in range(m)]
        result = np.reshape(np.array(result, dtype=dtype), [*batch_dims, m, m])
        return result

      k = min(m, n)
      l = jnp.tril(lu, -1)[..., :, :k] + jnp.eye(m, k, dtype=dtype)
      u = jnp.triu(lu)[..., :k, :]
      p_mat = _make_permutation_matrix(perm)

      tst.assertArraysEqual(
          lax.linalg.lu_pivots_to_permutation(pivots, m), perm)
      tst.assertAllClose(
          jnp.matmul(p_mat, operand),
          jnp.matmul(l, u),
          atol=tol,
          rtol=tol,
          err_msg=err_msg)

    return [
        custom_numeric(
            dtypes=[np.float32, np.complex64], devices="tpu", tol=0.1),
        custom_numeric(
            dtypes=[np.float32, np.complex64], devices=("cpu", "gpu"),
            tol=1e-5),
        custom_numeric(dtypes=[np.float64, np.complex128], tol=1e-13),
        custom_numeric(
            custom_assert=custom_assert,
            description=("May return different, but also correct, results when "
                         "the decomposition is not unique"),
            devices=("cpu", "gpu"),
            modes=("eager", "graph", "compiled")),
    ]

  @classmethod
  def max(cls, harness: primitive_harness.Harness):
    # TODO(bchetioui): discrepancies between TF & JAX when comparing with NaN;
    # JAX always returns NaN, while TF returns the value NaN is compared with.
    def custom_assert(tst, result_jax, result_tf, err_msg, **_):
      mask = np.isnan(result_jax)
      tst.assertAllClose(result_jax[~mask], result_tf[~mask], err_msg=err_msg)

    return [
        custom_numeric(
            custom_assert=custom_assert,
            description=(
                "May return different values when one of the values is NaN. "
                "JAX always returns NaN, while TF returns the value NaN is compared with."
            ),
            modes=("eager", "graph", "compiled"))
    ]

  @classmethod
  def min(cls, harness: primitive_harness.Harness):
    # TODO(bchetioui): discrepancies between TF & JAX when comparing with NaN;
    # JAX always returns NaN, while TF returns the value NaN is compared with.
    def custom_assert(tst, result_jax, result_tf, *, err_msg, **_):
      mask = np.isnan(result_jax)
      tst.assertAllClose(result_jax[~mask], result_tf[~mask], err_msg=err_msg)

    return [
        custom_numeric(
            custom_assert=custom_assert,
            description=(
                "May return different values when one of the values is NaN. "
                "JAX always returns NaN, while TF returns the value NaN is compared with."
            ),
            modes=("eager", "graph", "compiled"))
    ]

  @classmethod
  def nextafter(cls, harness: primitive_harness.Harness):
    return [missing_tf_kernel(dtypes=[np.float16, dtypes.bfloat16])]

  @classmethod
  def qr(cls, harness: primitive_harness.Harness):
    # See https://github.com/google/jax/pull/3775#issuecomment-659407824;
    # jit_compile=True breaks for complex types.
    # TODO: see https://github.com/google/jax/pull/3775#issuecomment-659407824.
    # - for now, the HLO QR implementation called when compiling with TF
    #   is expected to have worse performance than the custom calls made in
    #   JAX.
    return [
        custom_numeric(
            dtypes=[np.float64, np.complex128],
            devices=("cpu", "gpu"),
            modes=("eager", "graph", "compiled"),
            tol=1e-13),
        custom_numeric(
            dtypes=[np.float32, np.complex64],
            devices=("cpu", "gpu"),
            modes=("eager", "graph", "compiled"),
            tol=1e-4),
        missing_tf_kernel(
            dtypes=[dtypes.bfloat16],
            devices="tpu",
        )
    ]

  @classmethod
  def random_gamma(cls, harness: primitive_harness.Harness):
    return [custom_numeric(devices="tpu", tol=1e-3)]

  @classmethod
  def reduce_max(cls, harness: primitive_harness.Harness):
    # Unlike reduce_window_max, we use a native TF op: tf.reduce_max, which
    # does not work for complex
    return [missing_tf_kernel(dtypes=[np.complex64, np.complex128])]

  @classmethod
  def reduce_min(cls, harness: primitive_harness.Harness):
    return cls.reduce_max(harness)

  @classmethod
  def regularized_incomplete_beta(cls, harness: primitive_harness.Harness):
    return [
        custom_numeric(dtypes=np.float64, tol=1e-14),
        missing_tf_kernel(dtypes=[np.float16, dtypes.bfloat16])
    ]

  @classmethod
  def rem(cls, harness: primitive_harness.Harness):
    return [
        Jax2TfLimitation(
            "TF integer division fails if divisor contains 0; JAX returns NaN",
            dtypes=[
                np.uint8, np.int8, np.uint16, np.uint32, np.uint64,
                np.int16, np.int32, np.int64
            ],
            # Only the harnesses with "singularity" will have divide by 0
            enabled=("singularity" in harness.name)),
    ]

  @classmethod
  def rng_bit_generator(cls, harness: primitive_harness.Harness):
    return []

  @classmethod
  def round(cls, harness: primitive_harness.Harness):
    return [
        missing_tf_kernel(
            dtypes=[dtypes.bfloat16],
            devices=("cpu", "gpu"),
            modes=("eager", "graph"))
    ]

  @classmethod
  def scatter_add(cls, harness):
    return []

  @classmethod
  def scatter_mul(cls, harness):
    return []

  @classmethod
  def select_and_gather_add(cls, harness):
    return [
        # This JAX primitive is not exposed directly in the JAX API
        # but arises from JVP of `lax.reduce_window` for reducers
        # `lax.max` or `lax.min`. It also arises from second-order
        # VJP of the same. Implemented using XlaReduceWindow.
        Jax2TfLimitation((
            "jax2tf unimplemented for 64-bit inputs because the current implementation "
            "relies on packing two values into a single value. This can be "
            "fixed by using a variadic XlaReduceWindow, when available"),
            dtypes=[np.float64],
            devices=("cpu", "gpu"))
    ]

  @classmethod
  def sort(cls, harness: primitive_harness.Harness):
    return [
        Jax2TfLimitation(
            # I think that this is because TF is running on CPU even for GPU tests?
            "TODO: TF non-stable multiple-array sort",
            devices="gpu",
            enabled=(harness.params["num_arrays"] > 1 and
                     not harness.params["is_stable"]),
            expect_tf_error=False,
            skip_comparison=True),
    ]

  @classmethod
  def svd(cls, harness: primitive_harness.Harness):
    # TODO: slow test
    compute_uv = harness.params["compute_uv"]

    # Both `r_jax` and `r_tf` are 3-Tuples containing the SVD results:
    # `S` (singular values), `U` (left singular vectors), and `Vh` (the
    # adjoint of the right singular vectors). Note that the TF results are
    # obtained through `_svd` in jax/experimental/jax2tf/jax2tf.py.
    def custom_assert(tst, r_jax, r_tf, *, args, tol, err_msg):

      def reconstruct_operand(result):
        # Reconstructing operand as documented in numpy.linalg.svd (see
        # https://numpy.org/doc/stable/reference/generated/numpy.linalg.svd.html)
        s, u, v = result
        U = u[..., :s.shape[-1]]
        V = v[..., :s.shape[-1], :]
        S = s[..., None, :]
        return jnp.matmul(U * S, V, precision=lax.Precision.HIGHEST)

      # Compares the shapes.
      def compare_shapes(r_jax, r_tf):
        shapes_jax = [result.shape for result in r_jax]
        shapes_tf = [result.shape for result in r_tf]
        tst.assertEqual(shapes_jax, shapes_tf)

      # Compares reconstructed operand.
      # Computes backward error https://www.netlib.org/lapack/lug/node97.html
      # and uses the maximum backward error if there are batch dimensions.
      # The backward error is bounded by some constant multiplying the machine
      # precision.
      # TODO: Compares the operand instead of the reconstructed operand.
      def compare_reconstructed_operand(r_jax, r_tf, tol):
        operand_jax = reconstruct_operand(r_jax)
        operand_tf = reconstruct_operand(r_tf)
        error_norm = jnp.linalg.norm(operand_jax - operand_tf,
                                     axis=(-2, -1))
        backward_error = (error_norm /
                          jnp.linalg.norm(operand_jax, axis=(-2, -1)))
        max_backward_error = jnp.amax(backward_error)
        tst.assertLess(max_backward_error, tol)

      # Computes the absolute gap between singular value `\sigma_i` and the
      # nearest other singular value and for all singular values. The absolute
      # gap is used to approximate the upper bound of angular difference
      # between the computed and the true singular vectors. If the matrix is
      # rectangular `m != n`, the gap for the smallest nonzero singular value
      # should also consider the gap between it and zero. Note that this code
      # relies on the singular values being in descending order.
      def compute_absolute_gap(s, m, n):
        forward_appendant = np.Inf if m == n else 0
        forward_diff = jnp.diff(s, axis=-1, append=forward_appendant)
        backward_diff = jnp.diff(
            s[..., ::-1], axis=-1, append=np.Inf)[..., ::-1]
        absolute_gap = jnp.minimum(jnp.abs(forward_diff),
                                   jnp.abs(backward_diff))
        return absolute_gap

      # See `CompareSingularVectors` in
      # tensorflow/python/kernel_tests/linalg/svd_op_test.py
      def compare_singular_vectors(x, y, *, error_bound):
        # Singular vectors are only unique up to sign (complex phase factor for
        # complex matrices), so we normalize the sign first.
        sum_of_ratios = jnp.sum(jnp.divide(y, x), -2, keepdims=True)
        phases = jnp.divide(sum_of_ratios, jnp.abs(sum_of_ratios))
        x *= phases

        # Note that in general `sqrt(sum(squares))` is not a stable way to
        # compute l2 vector norms, but it should be OK for normalization
        # factors of vectors with norm ~= 1 as here.
        def dot_column_wise(a, b):
          output = jnp.sum(jnp.einsum('...ij,...ij->...ij', a.conj(), b,
                                      precision=lax.Precision.HIGHEST),
                           axis=-2)
          return jnp.real(output)

        cos_angular_diff = (
            dot_column_wise(x, y) /
            jnp.sqrt(dot_column_wise(x, x) * dot_column_wise(y, y)))

        # Values of `\cos(angular_diff)` outside the interval [0, 1] are
        # clipped to the interval edges. For example, `\cos(angular_diff)`
        # could contain values like 1.0000001 on float32, which are clipped to
        # 1.0. It is possible that anything other than `cos_angular_diff` can
        # be outside the interval [0, 1] due to roundoff.
        cos_angular_diff = jnp.clip(cos_angular_diff, a_min=0.0, a_max=1.0)

        angular_diff = jnp.arccos(cos_angular_diff)

        # TODO: removes the slack factor on the angular difference.
        # It is possible that the singular vectors are not accurate to much
        # more than O(\sqrt(eps)), which is likely a property of the SVD
        # algorithms in question; revisit with better understanding of the
        # SVD algorithms.
        if x.dtype in [np.float32, np.complex64]:
          slack_factor = 2E4
        elif x.dtype in [np.float64, np.complex128]:
          slack_factor = 2E9

        np.testing.assert_array_less(angular_diff,
                                     slack_factor * error_bound)

      if compute_uv:
        # Compares the shapes.
        compare_shapes(r_jax, r_tf)

        # Compares the singular values. Each computed singular value `\sigma_i`
        # differs from the true `\sigma_i`* by at most
        # `|\sigma_i - \sigma_i*| <= \epsilon \sigma_1`, where `\sigma_1` is
        # the largest singular value and `\epsilon` denotes the machine
        # precision.
        s_jax, s_tf = r_jax[0], r_tf[0]
        tst.assertAllClose(s_jax, s_tf, atol=tol, rtol=tol, err_msg=err_msg)

        # Compares the reconstructed operand.
        compare_reconstructed_operand(r_jax, r_tf, tol)

        # Compares the singular vectors.
        # We only compare the first `rank` singular vectors since the remainder
        # forms an arbitrary orthonormal basis for the (row- or column-) null
        # space, whose exact value depends on implementation details.
        # TODO: A better estimation on the rank?
        rank = r_jax[0].shape[-1]

        # Computes the upper bound for angular difference of singular vectors.
        # The upper bound has the shape of `[..., k]`, where `...` denotes the
        # batch dimensions and `k` is the number of nonzero singular values.
        m = r_jax[1].shape[-2]
        n = r_jax[2].shape[-2]
        absolute_gap = compute_absolute_gap(r_jax[0], m, n)
        epsilon = jnp.finfo(r_jax[0].dtype).eps
        sigma_largest = (r_jax[0][..., 0])[..., None]
        upperbound_singular_vectors = epsilon * sigma_largest / absolute_gap
        upperbound_singular_vectors = upperbound_singular_vectors[..., :rank]

        # Left singular vectors.
        u_jax = r_jax[1][..., :rank]
        u_tf = r_tf[1][..., :rank]
        compare_singular_vectors(u_jax, u_tf,
                                 error_bound=upperbound_singular_vectors)

        # Right singular vectors.
        v_jax = jnp.swapaxes(r_jax[2][..., :rank, :], -2, -1).conj()
        v_tf = jnp.swapaxes(r_tf[2][..., :rank, :], -2, -1).conj()
        compare_singular_vectors(v_jax, v_tf,
                                 error_bound=upperbound_singular_vectors)
      else:
        tst.assertAllClose(r_jax, r_tf, atol=tol, rtol=tol, err_msg=err_msg)

    return [
        # Works in JAX for complex due to custom calls on cpu and gpu
        Jax2TfLimitation(
            "function not compilable. Implemented using `tf.linalg.svd` and `tf.linalg.adjoint`",
            dtypes=[np.complex64, np.complex128],
            devices=("cpu", "gpu"),
            modes=("compiled",)),
        Jax2TfLimitation(
            "Large numerical discrepancy",
            dtypes=[np.float16],
            devices=("tpu"),
            modes=("eager", "graph", "compiled"),
            skip_comparison=True),
        missing_tf_kernel(dtypes=[dtypes.bfloat16], devices="tpu"),
        custom_numeric(
            tol=1e-4,
            dtypes=[np.float32, np.complex64],
            devices=("cpu", "gpu"),
            modes=("eager", "graph", "compiled")),
        # TODO: this is very low tolerance for f64
        custom_numeric(
            tol=1e-4,
            dtypes=[np.float64, np.complex128],
            devices=("cpu", "gpu"),
            modes=("eager", "graph", "compiled")),
        custom_numeric(
            tol=1e-4,
            description="custom numeric comparison when compute_uv on CPU/GPU",
            custom_assert=custom_assert,
            devices=("cpu", "gpu"),
            modes=("eager", "graph", "compiled"),
            enabled=(compute_uv == True)),
        custom_numeric(
            tol=1e-2,
            description="custom numeric comparison when compute_uv on TPU",
            dtypes=[np.float32, np.float64, np.complex64, np.complex128],
            custom_assert=custom_assert,
            devices=("tpu"),
            modes=("eager", "graph", "compiled"),
            enabled=(compute_uv == True)),
    ]

  @classmethod
  def tan(cls, harness):
    return [
        custom_numeric(dtypes=np.complex64, devices="tpu", tol=1e-4),
        custom_numeric(dtypes=np.complex64, devices=("cpu", "gpu"), tol=1e-3),
        custom_numeric(dtypes=np.complex128, devices=("cpu", "gpu"), tol=1e-12)
    ]

  @classmethod
  def tanh(cls, harness):
    return [
        custom_numeric(dtypes=np.complex128, tol=1e-7),
        custom_numeric(dtypes=np.complex64, tol=1e-4)
    ]

  @classmethod
  def top_k(cls, harness):

    def custom_assert(tst, result_jax, result_tf, *, err_msg, **_):
      assert len(result_jax) == len(result_tf)
      # TODO: TF and JAX sort [inf, nan] differently.
      first_arr_jax, first_arr_tf = result_jax[0], result_tf[0]
      if np.all(first_arr_jax == first_arr_tf):
        for arr_jax, arr_tf in zip(result_jax, result_tf):
          tst.assertArraysEqual(arr_jax, arr_tf, err_msg=err_msg)
      else:
        mask_jax, mask_tf = np.isnan(first_arr_jax), np.isnan(first_arr_tf)
        tst.assertArraysEqual(
            first_arr_jax[~mask_jax], first_arr_tf[~mask_tf], err_msg=err_msg)

    return [
        custom_numeric(
            dtypes=[np.float16, dtypes.bfloat16, np.float32, np.float64],
            custom_assert=custom_assert,
            description=(
                "Produces different results when the array contains `inf` and `NaN`"
                " (they are sorted differently in TF vs. XLA)."))
    ]

  @classmethod
  def triangular_solve(cls, harness: primitive_harness.Harness):
    return [
        missing_tf_kernel(dtypes=[dtypes.bfloat16]),
        missing_tf_kernel(
            dtypes=[np.float16],
            devices=("gpu", "cpu"),
            modes=("eager", "graph")),
        custom_numeric(dtypes=np.float32, tol=5e-3)
    ]

  @classmethod
  def tridiagonal_solve(cls, harness: primitive_harness.Harness):
    return []


def custom_numeric(
    *,
    description="custom numeric comparison",
    dtypes=(),  # All
    modes=(
        "eager",
        "graph",
    ),  # By default we should not need tolerance for "compiled"
    devices=("cpu", "gpu", "tpu"),
    custom_assert=None,
    enabled=True,
    tol=None) -> Jax2TfLimitation:

  return Jax2TfLimitation(
      description,
      expect_tf_error=False,
      dtypes=dtypes,
      devices=devices,
      modes=modes,
      custom_assert=custom_assert,
      enabled=enabled,
      tol=tol)


def missing_tf_kernel(*,
                      description="op not defined for dtype",
                      dtypes,
                      modes=("eager", "graph", "compiled"),
                      devices=("cpu", "gpu", "tpu"),
                      enabled=True) -> Jax2TfLimitation:

  return Jax2TfLimitation(
      description, dtypes=dtypes, devices=devices, modes=modes, enabled=enabled)
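

# A rough sketch (hypothetical; not executed here) of how a test driver might
# consume the limitations defined above. The real flow lives in
# primitives_test.py and tf_test_util.py; `device`, `mode`, and `default_tol`
# are assumed to come from the test environment:
#
#   limitations = Jax2TfLimitation.limitations_for_harness(harness)
#   applicable = [l for l in limitations
#                 if l.filter(device=device, dtype=harness.dtype, mode=mode)]
#   tol = default_tol
#   if applicable:
#     max_tol_lim = applicable[0].get_max_tolerance_limitation(applicable)
#     if max_tol_lim is not None:
#       tol = max_tol_lim.tol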