Roll back https://github.com/google/jax/pull/14144 due to downstream test failures

PiperOrigin-RevId: 504628432
jax authors 2023-01-25 12:15:00 -08:00
parent b58459c4ba
commit 78599e65d1
4 changed files with 25 additions and 40 deletions

View File

@@ -16,7 +16,7 @@ from functools import partial
import operator
from jax import config
from jax.tree_util import tree_map, tree_reduce, tree_leaves
from jax.tree_util import tree_map, tree_reduce
from jax._src import api
from jax._src import dtypes as _dtypes
from jax._src.config import flags
@@ -33,9 +33,10 @@ __all__ = ['check_grads', 'check_jvp', 'check_vjp']
FLAGS = flags.FLAGS
EPS = 1.0 / 2048
EPS = 1e-4
_fp8_enabled = xla_client._version >= 117
def _dtype(x):
if hasattr(x, 'dtype'):
return x.dtype
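
The two EPS definitions above trade a power-of-two step (1.0 / 2048 == 2**-11) for the decimal 1e-4. A minimal sketch, separate from the diff and assuming only NumPy, of why the rolled-back change preferred a power of two: scaling by 2**-11 only shifts the floating-point exponent, so multiplying and then dividing by eps round-trips bit-exactly, while 1e-4 is not exactly representable in binary and cannot guarantee that.

import numpy as np

rng = np.random.RandomState(0)
x = rng.randn(100_000).astype(np.float32)
for eps in (np.float32(1.0 / 2048), np.float32(1e-4)):
    # Count how many values survive the scale/unscale round trip bit-exactly.
    n_exact = int(np.sum((x * eps) / eps == x))
    print(f"eps={eps}: {n_exact}/{x.size} exact round trips")
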
@@ -196,20 +197,7 @@ def rand_like(rng, x):
return result.item() if is_python_scalar(x) else result
def numerical_jvp(f, primals, tangents, eps=None):
if eps is None:
t = _dtypes.result_type(*tree_leaves(primals))
# Assuming the roundoff error in the evaluation of the finite difference
# below is a few times eps_m*(|f_pos| + |f_neg|), where
# eps_m = np.finfo(t).eps, then the Pareto-optimal step size that roughly
# balances roundoff error and truncation error is O(eps_m^(1/3)).
# The constant was determined heuristically to minimize the error
# tolerances in the testOpGrad unit test.
eps = (np.finfo(t).eps ** (1.0 / 3.0)) / 8
# Find the nearest power of 2 for eps. This makes the multiplications
# and divisions by eps below lossless in floating point, and improves
# the accuracy of the finite difference approximation in some cases.
eps = 2.0 ** np.floor(np.log2(eps))
def numerical_jvp(f, primals, tangents, eps=EPS):
delta = scalar_mul(tangents, eps)
f_pos = f(*add(primals, delta))
f_neg = f(*sub(primals, delta))
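
For reference, a short sketch (separate from the change) that reproduces the removed step-size heuristic for float32 and checks that it lands exactly on the power-of-two constant 1.0 / 2048 also being rolled back above; the dtype is hard-coded here in place of the original _dtypes.result_type(*tree_leaves(primals)).

import numpy as np

t = np.float32                              # stand-in for the primals' result type
eps = (np.finfo(t).eps ** (1.0 / 3.0)) / 8  # balance roundoff vs. truncation error
eps = 2.0 ** np.floor(np.log2(eps))         # snap down to the nearest power of 2
print(eps, eps == 1.0 / 2048)               # expected: 0.00048828125 True
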
@@ -227,7 +215,7 @@ def _merge_tolerance(tol, default):
return out
def check_jvp(f, f_jvp, args, atol=None, rtol=None, eps=None, err_msg=''):
def check_jvp(f, f_jvp, args, atol=None, rtol=None, eps=EPS, err_msg=''):
atol = _merge_tolerance(atol, default_gradient_tolerance)
rtol = _merge_tolerance(rtol, default_gradient_tolerance)
rng = np.random.RandomState(0)
@@ -246,7 +234,7 @@ def check_jvp(f, f_jvp, args, atol=None, rtol=None, eps=None, err_msg=''):
err_msg=f'{err_msg} tangent' if err_msg else 'tangent')
def check_vjp(f, f_vjp, args, atol=None, rtol=None, eps=None, err_msg=''):
def check_vjp(f, f_vjp, args, atol=None, rtol=None, eps=EPS, err_msg=''):
atol = _merge_tolerance(atol, default_gradient_tolerance)
rtol = _merge_tolerance(rtol, default_gradient_tolerance)
_rand_like = partial(rand_like, np.random.RandomState(0))
@@ -286,6 +274,7 @@ def check_grads(f, args, order,
AssertionError: if gradients do not match.
"""
args = tuple(args)
eps = eps or EPS
_check_jvp = partial(check_jvp, atol=atol, rtol=rtol, eps=eps)
_check_vjp = partial(check_vjp, atol=atol, rtol=rtol, eps=eps)

View File

@@ -1645,7 +1645,7 @@ class LaxControlFlowTest(jtu.JaxTestCase):
atol = 1e-5
self.assertAllClose(ans, expected, check_dtypes=False, rtol=rtol, atol=atol)
rtol = 6e-3 if scan is not scan_with_new_checkpoint2 else 5e-2
rtol = 5e-3 if scan is not scan_with_new_checkpoint2 else 5e-2
atol = 5e-2 if "tpu" in jtu.device_under_test() else 1e-3
jtu.check_grads(partial(scan, f), (c, as_), order=2, modes=["rev"],
atol=atol, rtol=rtol)

View File

@@ -62,6 +62,7 @@ def check_grads(f, args, order, atol=None, rtol=None, eps=None):
default_tol = 1e-6 if config.x64_enabled else 1e-2
atol = atol or default_tol
rtol = rtol or default_tol
eps = eps or default_tol
jtu.check_jvp(f, partial(jax.jvp, f), args, atol, rtol, eps)
jtu.check_vjp(f, partial(jax.vjp, f), args, atol, rtol, eps)
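
The wrapper above only fills in default tolerances and an eps before delegating to jtu.check_jvp / jtu.check_vjp. As an illustration of the kind of call involved, a hedged, self-contained sketch using the public jax.test_util.check_grads (presumably the same helpers touched in the first file of this commit); the function f and the input value are made up for the example.

import jax.numpy as jnp
from jax.test_util import check_grads as public_check_grads

def f(x):
    # A smooth elementwise function whose analytic JVP/VJP should match the
    # finite-difference estimate within the default tolerances.
    return jnp.tanh(x) * x

public_check_grads(f, (jnp.float32(0.7),), order=1, modes=['fwd', 'rev'])
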

View File

@@ -5075,16 +5075,13 @@ class NumpyGradTests(jtu.JaxTestCase):
@jax.numpy_dtype_promotion('standard') # This test explicitly exercises mixed type promotion
def testOpGrad(self, op, rng_factory, shapes, dtype, order, tol):
rng = rng_factory(self.rng())
if jtu.device_under_test() == 'tpu':
# TODO(rmlarsen): These tolerances are dominated by the inaccurate
# implementation of float32 logarithms on TPUs. Remove this exception
# when TPU logarithms are improved.
tol = jtu.join_tolerance(tol, {np.float32: 5e-2, np.complex64: 5e-2})
else:
tol = jtu.join_tolerance(tol, {np.float32: 2e-3, np.float64: 1e-8,
np.complex64: 2e-3, np.complex128: 1e-8})
tol = jtu.join_tolerance(tol, {np.float32: 1e-1, np.float64: 1e-3,
np.complex64: 1e-1, np.complex128: 1e-3})
if jtu.device_under_test() == 'tpu' and op == jnp.arctanh:
tol = jtu.join_tolerance(tol, {np.float32: 2e-1})
args = tuple(rng(shape, dtype) for shape in shapes)
check_grads(op, args, order, ['fwd', 'rev'], tol, tol)
check_grads(op, args, order, ["fwd", "rev"], tol, tol)
@parameterized.parameters(itertools.chain.from_iterable(
jtu.sample_product_testcases(
@@ -5093,10 +5090,8 @@
)
for rec in GRAD_SPECIAL_VALUE_TEST_RECORDS))
def testOpGradSpecialValue(self, op, special_value, order):
tol = None
if jtu.device_under_test() == 'tpu' and op == jnp.arccosh:
tol = 4e-3
check_grads(op, (special_value,), order, ['fwd', 'rev'], tol, tol)
check_grads(op, (special_value,), order, ["fwd", "rev"],
atol={np.float32: 3e-3})
def testSincAtZero(self):
# Some manual tests for sinc at zero, since it doesn't have well-behaved
@@ -5148,11 +5143,11 @@
def testGradLogaddexpComplex(self, shapes, dtype):
rng = jtu.rand_default(self.rng())
args = tuple(jnp.array(rng(shape, dtype)) for shape in shapes)
if jtu.device_under_test() != 'tpu' and config.jax_enable_x64:
tol = 1e-5
if jtu.device_under_test() == "tpu":
tol = 5e-2
else:
tol = 2e-2
check_grads(jnp.logaddexp, args, 1, ['fwd', 'rev'], tol, tol)
tol = 3e-2
check_grads(jnp.logaddexp, args, 1, ["fwd", "rev"], tol, tol)
@jtu.sample_product(
shapes=filter(_shapes_are_broadcast_compatible,
@@ -5163,11 +5158,11 @@
def testGradLogaddexp2Complex(self, shapes, dtype):
rng = jtu.rand_default(self.rng())
args = tuple(jnp.array(rng(shape, dtype)) for shape in shapes)
if jtu.device_under_test() != 'tpu' and config.jax_enable_x64:
tol = 1e-5
if jtu.device_under_test() == "tpu":
tol = 5e-2
else:
tol = 2e-2
check_grads(jnp.logaddexp2, args, 1, ['fwd', 'rev'], tol, tol)
tol = 3e-2
check_grads(jnp.logaddexp2, args, 1, ["fwd", "rev"], tol, tol)
class NumpySignaturesTest(jtu.JaxTestCase):