Roll-back https://github.com/google/jax/pull/14144 due to downstream test failures
PiperOrigin-RevId: 504628432
parent b58459c4ba
commit 78599e65d1
@@ -16,7 +16,7 @@ from functools import partial
 import operator

 from jax import config
-from jax.tree_util import tree_map, tree_reduce, tree_leaves
+from jax.tree_util import tree_map, tree_reduce
 from jax._src import api
 from jax._src import dtypes as _dtypes
 from jax._src.config import flags
@@ -33,9 +33,10 @@ __all__ = ['check_grads', 'check_jvp', 'check_vjp']

 FLAGS = flags.FLAGS

-EPS = 1.0 / 2048
+EPS = 1e-4
+_fp8_enabled = xla_client._version >= 117


 def _dtype(x):
   if hasattr(x, 'dtype'):
     return x.dtype
@@ -196,20 +197,7 @@ def rand_like(rng, x):
   return result.item() if is_python_scalar(x) else result


-def numerical_jvp(f, primals, tangents, eps=None):
-  if eps is None:
-    t = _dtypes.result_type(*tree_leaves(primals))
-    # Assuming the roundoff error in the evaluation of the finite difference
-    # below is a few times eps_m*(|f_pos| + |f_neg|), where
-    # eps_m = np.finfo(t).eps, then the Pareto-optimal step size that roughly
-    # balances roundoff error and truncation error is O(eps_m^(1/3)).
-    # The constant was determined heuristically to minimize the error
-    # tolerances in the testOpGrad unit test.
-    eps = (np.finfo(t).eps ** (1.0 / 3.0)) / 8
-    # Find the nearest power of 2 for eps. This makes the multiplications
-    # and divisions by eps below lossless in floating point, and improves
-    # the accuracy of the finite difference approximation in some cases.
-    eps = 2.0 ** np.floor(np.log2(eps))
+def numerical_jvp(f, primals, tangents, eps=EPS):
   delta = scalar_mul(tangents, eps)
   f_pos = f(*add(primals, delta))
   f_neg = f(*sub(primals, delta))
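Aside on the rolled-back heuristic above: numerical_jvp builds a two-sided (central) finite difference from f_pos and f_neg, so a larger eps increases truncation error while a smaller eps amplifies roundoff, which is the trade-off the removed eps_m ** (1/3) comment describes. A self-contained sketch of both step-size policies (heuristic_eps and central_difference_jvp are local illustrative names, not JAX APIs):

    import numpy as np

    def heuristic_eps(dtype):
        # Step-size rule from the rolled-back code: O(machine_eps ** (1/3)),
        # rounded down to a power of two so scaling by eps stays lossless.
        eps = (np.finfo(dtype).eps ** (1.0 / 3.0)) / 8
        return 2.0 ** np.floor(np.log2(eps))

    def central_difference_jvp(f, x, t, eps):
        # Generic two-sided finite difference along direction t.
        return (f(x + eps * t) - f(x - eps * t)) / (2.0 * eps)

    # For float32 the heuristic lands on 2**-11 == 1.0 / 2048, the EPS value
    # removed in the earlier hunk; for float64 it gives 2**-21, far smaller
    # than the fixed EPS = 1e-4 that this rollback restores.
    print(heuristic_eps(np.float32))  # 0.00048828125
    print(heuristic_eps(np.float64))  # 4.76837158203125e-07
    print(central_difference_jvp(np.sin, 0.3, 1.0, 1e-4))  # ~cos(0.3) ~ 0.955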
@@ -227,7 +215,7 @@ def _merge_tolerance(tol, default):
   return out


-def check_jvp(f, f_jvp, args, atol=None, rtol=None, eps=None, err_msg=''):
+def check_jvp(f, f_jvp, args, atol=None, rtol=None, eps=EPS, err_msg=''):
   atol = _merge_tolerance(atol, default_gradient_tolerance)
   rtol = _merge_tolerance(rtol, default_gradient_tolerance)
   rng = np.random.RandomState(0)
@@ -246,7 +234,7 @@ def check_jvp(f, f_jvp, args, atol=None, rtol=None, eps=None, err_msg=''):
                err_msg=f'{err_msg} tangent' if err_msg else 'tangent')


-def check_vjp(f, f_vjp, args, atol=None, rtol=None, eps=None, err_msg=''):
+def check_vjp(f, f_vjp, args, atol=None, rtol=None, eps=EPS, err_msg=''):
   atol = _merge_tolerance(atol, default_gradient_tolerance)
   rtol = _merge_tolerance(rtol, default_gradient_tolerance)
   _rand_like = partial(rand_like, np.random.RandomState(0))
@@ -286,6 +274,7 @@ def check_grads(f, args, order,
     AssertionError: if gradients do not match.
   """
   args = tuple(args)
+  eps = eps or EPS

   _check_jvp = partial(check_jvp, atol=atol, rtol=rtol, eps=eps)
   _check_vjp = partial(check_vjp, atol=atol, rtol=rtol, eps=eps)
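check_jvp, check_vjp, and check_grads are the jax.test_util entry points whose eps handling changes here. A minimal usage sketch (the function, inputs, and tolerances below are arbitrary examples, not taken from this change):

    import jax.numpy as jnp
    from jax.test_util import check_grads

    def f(x):
        return jnp.sum(jnp.tanh(x) ** 2)

    x = jnp.arange(1.0, 4.0)
    # Compares autodiff JVPs/VJPs up to second order against finite
    # differences taken with the fixed step size restored by this rollback.
    check_grads(f, (x,), order=2, modes=["fwd", "rev"], atol=1e-2, rtol=1e-2)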
@@ -1645,7 +1645,7 @@ class LaxControlFlowTest(jtu.JaxTestCase):
     atol = 1e-5
     self.assertAllClose(ans, expected, check_dtypes=False, rtol=rtol, atol=atol)

-    rtol = 6e-3 if scan is not scan_with_new_checkpoint2 else 5e-2
+    rtol = 5e-3 if scan is not scan_with_new_checkpoint2 else 5e-2
     atol = 5e-2 if "tpu" in jtu.device_under_test() else 1e-3
     jtu.check_grads(partial(scan, f), (c, as_), order=2, modes=["rev"],
                     atol=atol, rtol=rtol)
@@ -62,6 +62,7 @@ def check_grads(f, args, order, atol=None, rtol=None, eps=None):
   default_tol = 1e-6 if config.x64_enabled else 1e-2
   atol = atol or default_tol
   rtol = rtol or default_tol
+  eps = eps or default_tol
   jtu.check_jvp(f, partial(jax.jvp, f), args, atol, rtol, eps)
   jtu.check_vjp(f, partial(jax.vjp, f), args, atol, rtol, eps)

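The default_tol switch above keys off whether 64-bit types are enabled. For reference, that mode is controlled by JAX's standard x64 config flag; a minimal sketch (the tolerance values simply mirror the line above):

    import jax

    # Double precision is off by default; enabling it lets tests demand much
    # tighter agreement from finite-difference gradient checks.
    jax.config.update("jax_enable_x64", True)
    default_tol = 1e-6 if jax.config.jax_enable_x64 else 1e-2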
@@ -5075,16 +5075,13 @@ class NumpyGradTests(jtu.JaxTestCase):
   @jax.numpy_dtype_promotion('standard')  # This test explicitly exercises mixed type promotion
   def testOpGrad(self, op, rng_factory, shapes, dtype, order, tol):
     rng = rng_factory(self.rng())
-    if jtu.device_under_test() == 'tpu':
-      # TODO(rmlarsen): These tolerances are dominated by the inaccurate
-      # implementation of float32 logarithms on TPUs. Remove this exception
-      # when TPU logarithms are improved.
-      tol = jtu.join_tolerance(tol, {np.float32: 5e-2, np.complex64: 5e-2})
-    else:
-      tol = jtu.join_tolerance(tol, {np.float32: 2e-3, np.float64: 1e-8,
-                                     np.complex64: 2e-3, np.complex128: 1e-8})
+    tol = jtu.join_tolerance(tol, {np.float32: 1e-1, np.float64: 1e-3,
+                                   np.complex64: 1e-1, np.complex128: 1e-3})
+    if jtu.device_under_test() == 'tpu' and op == jnp.arctanh:
+      tol = jtu.join_tolerance(tol, {np.float32: 2e-1})

     args = tuple(rng(shape, dtype) for shape in shapes)
-    check_grads(op, args, order, ['fwd', 'rev'], tol, tol)
+    check_grads(op, args, order, ["fwd", "rev"], tol, tol)

   @parameterized.parameters(itertools.chain.from_iterable(
       jtu.sample_product_testcases(
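testOpGrad states tolerances as per-dtype dicts and merges them with jtu.join_tolerance, which, as its use here suggests, keeps the looser bound for each dtype. A simplified stand-in for that merge (not the jtu implementation), assuming the per-test tol is a scalar or None:

    import numpy as np

    def join_tolerance_sketch(tol, floor):
        # Treat a missing or scalar per-test tolerance as applying to every
        # dtype and keep the larger (looser) value per dtype from the floor.
        tol = 0.0 if tol is None else tol
        return {dtype: max(tol, f) for dtype, f in floor.items()}

    # A tight per-test tol of 1e-4 is relaxed to the restored floors:
    join_tolerance_sketch(1e-4, {np.float32: 1e-1, np.float64: 1e-3})
    # -> {float32: 0.1, float64: 0.001}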
@@ -5093,10 +5090,8 @@ class NumpyGradTests(jtu.JaxTestCase):
       )
       for rec in GRAD_SPECIAL_VALUE_TEST_RECORDS))
   def testOpGradSpecialValue(self, op, special_value, order):
-    tol = None
-    if jtu.device_under_test() == 'tpu' and op == jnp.arccosh:
-      tol = 4e-3
-    check_grads(op, (special_value,), order, ['fwd', 'rev'], tol, tol)
+    check_grads(op, (special_value,), order, ["fwd", "rev"],
+                atol={np.float32: 3e-3})

   def testSincAtZero(self):
     # Some manual tests for sinc at zero, since it doesn't have well-behaved
@@ -5148,11 +5143,11 @@ class NumpyGradTests(jtu.JaxTestCase):
   def testGradLogaddexpComplex(self, shapes, dtype):
     rng = jtu.rand_default(self.rng())
     args = tuple(jnp.array(rng(shape, dtype)) for shape in shapes)
-    if jtu.device_under_test() != 'tpu' and config.jax_enable_x64:
-      tol = 1e-5
+    if jtu.device_under_test() == "tpu":
+      tol = 5e-2
     else:
-      tol = 2e-2
-    check_grads(jnp.logaddexp, args, 1, ['fwd', 'rev'], tol, tol)
+      tol = 3e-2
+    check_grads(jnp.logaddexp, args, 1, ["fwd", "rev"], tol, tol)

   @jtu.sample_product(
     shapes=filter(_shapes_are_broadcast_compatible,
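This hunk and the logaddexp2 hunk below restore the same pattern: a looser tolerance when running on TPU, whose float32 transcendentals are less accurate (see the TODO removed from testOpGrad above). A tiny illustration of the pattern (assuming jtu is the same jax._src.test_util module these tests import; the values mirror the restored code):

    import jax._src.test_util as jtu

    def logaddexp_grad_tol():
        # Allow a wider gradient-check margin on TPU than on CPU/GPU.
        return 5e-2 if jtu.device_under_test() == "tpu" else 3e-2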
@@ -5163,11 +5158,11 @@ class NumpyGradTests(jtu.JaxTestCase):
   def testGradLogaddexp2Complex(self, shapes, dtype):
     rng = jtu.rand_default(self.rng())
     args = tuple(jnp.array(rng(shape, dtype)) for shape in shapes)
-    if jtu.device_under_test() != 'tpu' and config.jax_enable_x64:
-      tol = 1e-5
+    if jtu.device_under_test() == "tpu":
+      tol = 5e-2
     else:
-      tol = 2e-2
-    check_grads(jnp.logaddexp2, args, 1, ['fwd', 'rev'], tol, tol)
+      tol = 3e-2
+    check_grads(jnp.logaddexp2, args, 1, ["fwd", "rev"], tol, tol)


 class NumpySignaturesTest(jtu.JaxTestCase):