Mirror of https://github.com/ROCm/jax.git

Commit 110634d50d: Merge pull request #1295 from google/more-lax-autodiff-tests

add special value grad tests
@@ -26,11 +26,8 @@ Operators
   abs
   add
   acos
-  acosh
   asin
-  asinh
   atan
-  atanh
   atan2
   batch_matmul
   bitcast_convert_type
@@ -18,7 +18,8 @@ from .lax import (_reduce_sum, _reduce_max, _reduce_min, _reduce_or,
                   _reduce_and, _reduce_window_sum, _reduce_window_max,
                   _reduce_window_min, _reduce_window_prod, _float, _complex,
                   _input_dtype, _const, _eq_meet, _safe_mul,
-                  _broadcasting_select, _check_user_dtype_supported)
+                  _broadcasting_select, _check_user_dtype_supported,
+                  _one, _const)
 from .lax_control_flow import *
 from .lax_fft import *
 from .lax_parallel import *
@@ -1377,41 +1377,6 @@ def cosh(x):
   # This formulation avoids overflow when e^x is inf but e^x/2 is not inf.
   return add(exp(add(log_half, x)), exp(sub(log_half, x)))
 
-def asinh(x):
-  r"""Elementwise arc hyperbolic sine: :math:`\mathrm{asinh}(x)`."""
-  # asinh(x) = log(x + sqrt(x**2 + 1))
-  result = log(add(x, sqrt(add(mul(x, x), _const(x, 1)))))
-  if onp.issubdtype(_dtype(result), onp.complexfloating):
-    return result
-  a = abs(x)
-  sqrt_max_value = onp.sqrt(onp.finfo(_dtype(x)).max)
-  return select(lt(a, _const(a, sqrt_max_value)),
-                result,
-                mul(sign(x), add(log(a), _const(a, onp.log(2.)))))
-
-def acosh(x):
-  r"""Elementwise arc hyperbolic cosine: :math:`\mathrm{acosh}(x)`."""
-  # acosh(x) = log(x + sqrt((x + 1) * (x - 1))) if x < sqrt_max_value
-  #            log(x) + log(2) otherwise
-  sqrt_max_value = onp.sqrt(onp.finfo(_dtype(x)).max)
-  result = log(add(x, mul(sqrt(add(x, _const(x, 1))),
-                          sqrt(sub(x, _const(x, 1))))))
-  if onp.issubdtype(_dtype(result), onp.complexfloating):
-    return result
-  return select(
-    lt(x, _const(x, sqrt_max_value)),
-    result,
-    add(log(x), _const(x, onp.log(2.))))
-
-
-def atanh(x):
-  r"""Elementwise arc hyperbolic tangent: :math:`\mathrm{atanh}(x)`."""
-  # atanh(x) = 0.5 * log((1 + x) / (1 - x))
-  result = mul(_const(x, 0.5), log(div(add(_const(x, 1), x),
-                                       sub(_const(x, 1), x))))
-  if onp.issubdtype(_dtype(result), onp.complexfloating):
-    return result
-  return select(le(abs(x), _one(x)), result, full_like(x, onp.nan))
 
 # Add some methods to ShapedArray that rely on lax primitives
 
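A note on the guarded select being deleted above (it reappears, essentially unchanged, on the jax.numpy side below): for real x with |x| >= sqrt(finfo(dtype).max), x * x overflows to inf, so the textbook log(x + sqrt(x**2 + 1)) evaluates to inf even though asinh(x) is finite. The fallback branch instead uses asinh(x) ~= sign(x) * (log|x| + log 2), which follows from log(x + sqrt(x**2 + 1)) ~= log(2|x|) for large |x|. A minimal NumPy-only sketch of the failure mode; this is illustrative and not part of the commit:

    import numpy as onp

    onp.seterr(over="ignore")  # big * big overflows float32; keep it quiet
    big = onp.float32(2) * onp.sqrt(onp.finfo(onp.float32).max)
    naive = onp.log(big + onp.sqrt(big * big + onp.float32(1)))  # inf
    stable = onp.log(onp.abs(big)) + onp.float32(onp.log(2.))    # ~45.75
    print(naive, stable, onp.arcsinh(big))  # stable matches onp.arcsinh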
@@ -57,8 +57,6 @@ acos = onp.arccos
 atan = onp.arctan
 sinh = onp.sinh
 cosh = onp.cosh
-asinh = onp.arcsinh
-acosh = onp.arccosh
 
 lgamma = scipy.special.gammaln
 digamma = scipy.special.digamma
@@ -41,7 +41,7 @@ import opt_einsum
 import six
 from six.moves import builtins, xrange
 
-from jax import jit, device_put
+from jax import jit, device_put, custom_transforms, defjvp
 from .. import core
 from ..abstract_arrays import UnshapedArray, ShapedArray, ConcreteArray
 from ..config import flags
@@ -327,9 +327,6 @@ arctan = _one_to_one_unop(onp.arctan, lax.atan, True)
 sinh = _one_to_one_unop(onp.sinh, lax.sinh, True)
 cosh = _one_to_one_unop(onp.cosh, lax.cosh, True)
 tanh = _one_to_one_unop(onp.tanh, lax.tanh, True)
-arcsinh = _one_to_one_unop(onp.arcsinh, lax.asinh, True)
-arccosh = _one_to_one_unop(onp.arccosh, lax.acosh, True)
-arctanh = _one_to_one_unop(onp.arctanh, lax.atanh, True)
 sqrt = _one_to_one_unop(onp.sqrt, lax.sqrt, True)
 
 
@@ -573,6 +570,47 @@ def sinc(x):
                     lax._const(x, 1), lax.div(lax.sin(pi_x), pi_x))
 
 
+@_wraps(onp.arcsinh)
+@custom_transforms
+def arcsinh(x):
+  # asinh(x) = log(x + sqrt(x**2 + 1))
+  x, = _promote_to_result_dtype(onp.arcsinh, x)
+  one = lax._const(x, 1)
+  result = lax.log(x + lax.sqrt(x * x + one))
+  if onp.issubdtype(_dtype(result), onp.complexfloating):
+    return result
+  a = abs(x)
+  sqrt_max_value = onp.sqrt(onp.finfo(_dtype(x)).max)
+  log2 = lax._const(a, onp.log(2))
+  return lax.select(a < sqrt_max_value, result, lax.sign(x) * (lax.log(a) + log2))
+defjvp(arcsinh, lambda g, ans, x: g / lax.sqrt(lax._const(x, 1) + square(x)))
+
+
+@_wraps(onp.arccosh)
+def arccosh(x):
+  # acosh(x) = log(x + sqrt((x + 1) * (x - 1))) if x < sqrt_max_value
+  #            log(x) + log(2) otherwise
+  x, = _promote_to_result_dtype(onp.arccosh, x)
+  one = lax._const(x, 1)
+  result = lax.log(x + lax.sqrt((x + one) * (x - one)))
+  if onp.issubdtype(_dtype(result), onp.complexfloating):
+    return result
+  sqrt_max_value = onp.sqrt(onp.finfo(_dtype(x)).max)
+  log2 = lax._const(x, onp.log(2))
+  return lax.select(x < sqrt_max_value, result, lax.log(x) + log2)
+
+
+@_wraps(onp.arctanh)
+def arctanh(x):
+  # atanh(x) = 0.5 * log((1 + x) / (1 - x))
+  x, = _promote_to_result_dtype(onp.arctanh, x)
+  one = lax._const(x, 1)
+  result = lax._const(x, 0.5) * lax.log((one + x) / (one - x))
+  if onp.issubdtype(_dtype(result), onp.complexfloating):
+    return result
+  return lax.select(abs(x) <= 1, result, lax.full_like(x, onp.nan))
+
+
 @_wraps(onp.transpose)
 def transpose(x, axes=None):
   axes = onp.arange(ndim(x))[::-1] if axes is None else axes
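The @custom_transforms / defjvp pair on arcsinh above is the load-bearing change for the new special-value grad tests: rather than differentiating through the overflow-guarding lax.select, whose naive derivative can produce nan or inf at large inputs, autodiff is handed the closed form d/dx asinh(x) = 1 / sqrt(1 + x**2). A quick check of that rule, written against the present-day jax API (which postdates this diff) and purely illustrative:

    import jax
    import jax.numpy as jnp

    x = 1000.
    print(jax.grad(jnp.arcsinh)(x))   # ~0.001, finite at a large input
    print(1. / jnp.sqrt(1. + x * x))  # the defjvp rule gives the same value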
@@ -36,6 +36,7 @@ from jax import api
 from jax import lax
 from jax import numpy as lnp
 from jax import test_util as jtu
+from jax.test_util import check_grads
 from jax.lib import xla_bridge
 
 from jax.config import config
@@ -1864,5 +1865,59 @@ class LaxBackedNumpyTests(jtu.JaxTestCase):
     FLAGS.jax_numpy_rank_promotion = prev_flag
 
 
+# Most grad tests are at the lax level (see lax_test.py), but we add some here
+# as needed for e.g. particular compound ops of interest.
+
+GradTestSpec = collections.namedtuple(
+    "GradTestSpec", ["op", "nargs", "order", "rng", "dtypes", "name", "tol"])
+def grad_test_spec(op, nargs, order, rng, dtypes, name=None, tol=None):
+  return GradTestSpec(op, nargs, order, rng, dtypes, name or op.__name__, tol)
+
+GRAD_TEST_RECORDS = [
+    grad_test_spec(lnp.arcsinh, nargs=1, order=2, rng=jtu.rand_positive(),
+                   dtypes=[onp.float64, onp.complex64], tol=1e-4),
+    grad_test_spec(lnp.arccosh, nargs=1, order=2, rng=jtu.rand_positive(),
+                   dtypes=[onp.float64, onp.complex64], tol=1e-4),
+    grad_test_spec(lnp.arctanh, nargs=1, order=2, rng=jtu.rand_uniform(-0.9, 0.9),
+                   dtypes=[onp.float64, onp.complex64], tol=1e-4),
+]
+
+GradSpecialValuesTestSpec = collections.namedtuple(
+    "GradSpecialValuesTestSpec", ["op", "values"])
+
+GRAD_SPECIAL_VALUE_TEST_RECORDS = [
+    GradSpecialValuesTestSpec(lnp.arcsinh, [0., 1000.]),
+    GradSpecialValuesTestSpec(lnp.arccosh, [1000.]),
+    GradSpecialValuesTestSpec(lnp.arctanh, [0.]),
+]
+
+def num_float_bits(dtype):
+  return onp.finfo(xla_bridge.canonicalize_dtype(dtype)).bits
+
+class NumpyGradTests(jtu.JaxTestCase):
+  @parameterized.named_parameters(itertools.chain.from_iterable(
+      jtu.cases_from_list(
+        {"testcase_name": jtu.format_test_name_suffix(
+            rec.name, shapes, itertools.repeat(dtype)),
+         "op": rec.op, "rng": rec.rng, "shapes": shapes, "dtype": dtype,
+         "order": rec.order, "tol": rec.tol}
+        for shapes in CombosWithReplacement(nonempty_shapes, rec.nargs)
+        for dtype in rec.dtypes)
+      for rec in GRAD_TEST_RECORDS))
+  def testOpGrad(self, op, rng, shapes, dtype, order, tol):
+    tol = 1e-1 if num_float_bits(dtype) == 32 else tol
+    args = tuple(rng(shape, dtype) for shape in shapes)
+    check_grads(op, args, order, ["fwd", "rev"], tol, tol)
+
+  @parameterized.named_parameters(itertools.chain.from_iterable(
+      jtu.cases_from_list(
+        {"testcase_name": "_{}_{}".format(rec.op.__name__, special_value),
+         "op": rec.op, "special_value": special_value}
+        for special_value in rec.values)
+      for rec in GRAD_SPECIAL_VALUE_TEST_RECORDS))
+  def testOpGradSpecialValue(self, op, special_value):
+    check_grads(op, (special_value,), 2, ["fwd", "rev"])
+
+
 if __name__ == "__main__":
   absltest.main()
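check_grads compares forward- and reverse-mode automatic derivatives against numerical differences, up to the requested order and tolerances. A standalone sketch of the underlying idea using a central difference; eps and the tolerance here are illustrative choices, not jax.test_util internals:

    import jax
    import jax.numpy as jnp

    def central_diff(f, x, eps=1e-3):
      # two-sided numerical derivative, the same idea check_grads automates
      return (f(x + eps) - f(x - eps)) / (2. * eps)

    ad = jax.grad(jnp.arctanh)(0.5)      # closed form: 1 / (1 - 0.25) = 4/3
    fd = central_diff(jnp.arctanh, 0.5)
    assert abs(ad - fd) < 1e-2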
@@ -98,8 +98,6 @@ LAX_OPS = [
     op_record(lax.atan, 1, float_dtypes, jtu.rand_small()),
     op_record(lax.sinh, 1, float_dtypes + complex_dtypes, jtu.rand_default()),
     op_record(lax.cosh, 1, float_dtypes + complex_dtypes, jtu.rand_default()),
-    op_record(lax.asinh, 1, float_dtypes + complex_dtypes, jtu.rand_positive()),
-    op_record(lax.acosh, 1, float_dtypes + complex_dtypes, jtu.rand_positive()),
 
     op_record(lax.lgamma, 1, float_dtypes, jtu.rand_positive()),
     op_record(lax.digamma, 1, float_dtypes, jtu.rand_positive()),
@@ -1399,6 +1397,18 @@ class LaxTest(jtu.JaxTestCase):
 
     api.jit(f)(1.)  # doesn't crash
 
+  def testReshapeWithUnusualShapes(self):
+    ans = lax.reshape(onp.ones((3,), onp.float32), (lax.add(1, 2), 1))
+    self.assertAllClose(ans, onp.ones((3, 1), onp.float32), check_dtypes=True)
+
+    jtu.check_raises_regexp(
+      lambda: lax.reshape(onp.ones(3,), (onp.array([3, 1]),)), TypeError,
+      "Shapes must be 1D sequences of concrete values of integer type.*")
+
+    jtu.check_raises_regexp(
+      lambda: lax.reshape(onp.ones(3,), (1.5, 2.0)), TypeError,
+      "Shapes must be 1D sequences of concrete values of integer type.*")
+
 
 class DeviceConstantTest(jtu.JaxTestCase):
   def _CheckDeviceConstant(self, make_const, expected):
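The test added above pins down what lax.reshape accepts as a shape: a 1D sequence of concrete integer values. lax.add(1, 2) qualifies because both operands are concrete, so the result is a concrete integer scalar at trace time; a NumPy array used as a single dimension, or float dimensions like 1.5, are rejected with the TypeError checked above. A minimal illustration against the present-day jax API, which is assumed here to behave the same way:

    import numpy as onp
    from jax import lax

    ans = lax.reshape(onp.ones((3,), onp.float32), (lax.add(1, 2), 1))
    print(ans.shape)  # (3, 1)
    try:
      lax.reshape(onp.ones(3), (1.5, 2.0))  # non-integer dims
    except TypeError as e:
      print("rejected:", e)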
@@ -1502,12 +1512,22 @@ LAX_GRAD_OPS = [
                    dtypes=[onp.float64, onp.complex64]),
     grad_test_spec(lax.log1p, nargs=1, order=2, rng=jtu.rand_positive(),
                    dtypes=[onp.float64, onp.complex64]),
+    grad_test_spec(lax.sinh, nargs=1, order=2, rng=jtu.rand_default(),
+                   dtypes=[onp.float64, onp.complex64], tol=1e-5),
+    grad_test_spec(lax.cosh, nargs=1, order=2, rng=jtu.rand_default(),
+                   dtypes=[onp.float64, onp.complex64], tol=1e-5),
+    grad_test_spec(lax.tanh, nargs=1, order=2, rng=jtu.rand_default(),
+                   dtypes=[onp.float64, onp.complex64], tol=1e-5),
     grad_test_spec(lax.sin, nargs=1, order=2, rng=jtu.rand_default(),
                    dtypes=[onp.float64, onp.complex64]),
     grad_test_spec(lax.cos, nargs=1, order=2, rng=jtu.rand_default(),
                    dtypes=[onp.float64, onp.complex64]),
     grad_test_spec(lax.tan, nargs=1, order=2, rng=jtu.rand_uniform(-1.3, 1.3),
                    dtypes=[onp.float64, onp.complex64], tol=1e-3),
+    grad_test_spec(lax.asin, nargs=1, order=2, rng=jtu.rand_uniform(-1., 1.),
+                   dtypes=[onp.float64], tol=1e-3),
+    grad_test_spec(lax.acos, nargs=1, order=2, rng=jtu.rand_uniform(-1., 1.),
+                   dtypes=[onp.float64], tol=1e-3),
     # TODO(proteneer): atan2 input is already a representation of a
     # complex number. Need to think harder about what this even means
     # if each input itself is a complex number.
@@ -1556,6 +1576,23 @@
     # dtypes=[onp.float64], name="MinSomeEqual"),
 ]
 
+GradSpecialValuesTestSpec = collections.namedtuple(
+    "GradSpecialValuesTestSpec", ["op", "values"])
+
+LAX_GRAD_SPECIAL_VALUE_TESTS = [
+    GradSpecialValuesTestSpec(lax.sinh, [0.]),
+    GradSpecialValuesTestSpec(lax.cosh, [0.]),
+    GradSpecialValuesTestSpec(lax.tanh, [0., 1000.]),
+    GradSpecialValuesTestSpec(lax.sin, [0., onp.pi, onp.pi/2., onp.pi/4.]),
+    GradSpecialValuesTestSpec(lax.cos, [0., onp.pi, onp.pi/2., onp.pi/4.]),
+    GradSpecialValuesTestSpec(lax.tan, [0.]),
+    GradSpecialValuesTestSpec(lax.asin, [0.]),
+    GradSpecialValuesTestSpec(lax.acos, [0.]),
+    GradSpecialValuesTestSpec(lax.atan, [0., 1000.]),
+    GradSpecialValuesTestSpec(lax.erf, [0., 10.]),
+    GradSpecialValuesTestSpec(lax.erfc, [0., 10.]),
+]
+
 
 def check_grads_bilinear(f, args, order,
                          modes=["fwd", "rev"], atol=None, rtol=None):
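These values sit on the symmetry points, saturation plateaus, and domain boundaries where naive derivative formulas tend to produce nan or inf even though the true derivative is well defined: 0. probes even/odd symmetry, 1000. probes saturation and overflow guards, and the multiples of pi probe the zeros and extrema of sin and cos. A quick hedged check against the present-day jax API, illustrative only:

    import jax
    import jax.numpy as jnp

    print(jax.grad(jnp.tanh)(1000.))  # 0.0: saturates cleanly, no nan
    print(jax.grad(jnp.sin)(jnp.pi))  # cos(pi) == -1.0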
@@ -1581,13 +1618,21 @@ class LaxAutodiffTest(jtu.JaxTestCase):
         for dtype in rec.dtypes)
       for rec in LAX_GRAD_OPS))
   def testOpGrad(self, op, rng, shapes, dtype, order, tol):
-    if jtu.device_under_test() == "tpu":
-      if op is lax.pow:
-        raise SkipTest("pow grad imprecise on tpu")
+    if jtu.device_under_test() == "tpu" and op is lax.pow:
+      raise SkipTest("pow grad imprecise on tpu")
     tol = 1e-1 if num_float_bits(dtype) == 32 else tol
     args = tuple(rng(shape, dtype) for shape in shapes)
     check_grads(op, args, order, ["fwd", "rev"], tol, tol)
 
+  @parameterized.named_parameters(itertools.chain.from_iterable(
+      jtu.cases_from_list(
+        {"testcase_name": "_{}_{}".format(rec.op.__name__, special_value),
+         "op": rec.op, "special_value": special_value}
+        for special_value in rec.values)
+      for rec in LAX_GRAD_SPECIAL_VALUE_TESTS))
+  def testOpGradSpecialValue(self, op, special_value):
+    check_grads(op, (special_value,), 2, ["fwd", "rev"])
+
   @parameterized.named_parameters(jtu.cases_from_list(
       {"testcase_name": "_from_dtype={}_to_dtype={}".format(
           jtu.dtype_str(from_dtype), jtu.dtype_str(to_dtype)),
@@ -2242,17 +2287,6 @@ class LaxAutodiffTest(jtu.JaxTestCase):
     expected = onp.array(0.0)
     self.assertAllClose(ans, expected, check_dtypes=False)
 
-  def testReshapeWithUnusualShapes(self):
-    ans = lax.reshape(onp.ones((3,), onp.float32), (lax.add(1, 2), 1))
-    self.assertAllClose(ans, onp.ones((3, 1), onp.float32), check_dtypes=True)
-
-    jtu.check_raises_regexp(
-      lambda: lax.reshape(onp.ones(3,), (onp.array([3, 1]),)), TypeError,
-      "Shapes must be 1D sequences of concrete values of integer type.*")
-
-    jtu.check_raises_regexp(
-      lambda: lax.reshape(onp.ones(3,), (1.5, 2.0)), TypeError,
-      "Shapes must be 1D sequences of concrete values of integer type.*")
 
 def all_bdims(*shapes):
   bdims = (itertools.chain([None], range(len(shape) + 1)) for shape in shapes)