Merge pull request #1295 from google/more-lax-autodiff-tests

add special value grad tests
Matthew Johnson 2019-08-31 23:05:51 -07:00 committed by GitHub
commit 110634d50d
7 changed files with 149 additions and 61 deletions
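
In outline: the lax-level asinh/acosh/atanh definitions move into jax.numpy as arcsinh/arccosh/arctanh (with a hand-written JVP for arcsinh), and both test suites gain gradient checks at special input values. A minimal sketch of what such a check does, using the public jax.test_util.check_grads helper, which compares forward- and reverse-mode autodiff against finite differences (the call mirrors the new tests below):

import jax.numpy as jnp
from jax.test_util import check_grads

# Check 1st- and 2nd-order derivatives of arcsinh at two special inputs,
# in both forward and reverse mode, against numerical differences.
for x in [0., 1000.]:
    check_grads(jnp.arcsinh, (x,), order=2, modes=["fwd", "rev"])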

View File

@@ -26,11 +26,8 @@ Operators
     abs
     add
     acos
-    acosh
     asin
-    asinh
     atan
-    atanh
     atan2
     batch_matmul
     bitcast_convert_type

View File

@@ -18,7 +18,8 @@ from .lax import (_reduce_sum, _reduce_max, _reduce_min, _reduce_or,
                   _reduce_and, _reduce_window_sum, _reduce_window_max,
                   _reduce_window_min, _reduce_window_prod, _float, _complex,
                   _input_dtype, _const, _eq_meet, _safe_mul,
-                  _broadcasting_select, _check_user_dtype_supported)
+                  _broadcasting_select, _check_user_dtype_supported,
+                  _one, _const)
 from .lax_control_flow import *
 from .lax_fft import *
 from .lax_parallel import *

View File

@@ -1377,41 +1377,6 @@ def cosh(x):
   # This formulation avoids overflow when e^x is inf but e^x/2 is not inf.
   return add(exp(add(log_half, x)), exp(sub(log_half, x)))
 
-def asinh(x):
-  r"""Elementwise arc hyperbolic sine: :math:`\mathrm{asinh}(x)`."""
-  # asinh(x) = log(x + sqrt(x**2 + 1))
-  result = log(add(x, sqrt(add(mul(x, x), _const(x, 1)))))
-  if onp.issubdtype(_dtype(result), onp.complexfloating):
-    return result
-  a = abs(x)
-  sqrt_max_value = onp.sqrt(onp.finfo(_dtype(x)).max)
-  return select(lt(a, _const(a, sqrt_max_value)),
-                result,
-                mul(sign(x), add(log(a), _const(a, onp.log(2.)))))
-
-def acosh(x):
-  r"""Elementwise arc hyperbolic cosine: :math:`\mathrm{acosh}(x)`."""
-  # acosh(x) = log(x + sqrt((x + 1) * (x - 1)))  if x < sqrt_max_value
-  #            log(x) + log(2)                   otherwise
-  sqrt_max_value = onp.sqrt(onp.finfo(_dtype(x)).max)
-  result = log(add(x, mul(sqrt(add(x, _const(x, 1))),
-                          sqrt(sub(x, _const(x, 1))))))
-  if onp.issubdtype(_dtype(result), onp.complexfloating):
-    return result
-  return select(
-    lt(x, _const(x, sqrt_max_value)),
-    result,
-    add(log(x), _const(x, onp.log(2.))))
-
-def atanh(x):
-  r"""Elementwise arc hyperbolic tangent: :math:`\mathrm{atanh}(x)`."""
-  # atanh(x) = 0.5 * log((1 + x) / (1 - x))
-  result = mul(_const(x, 0.5), log(div(add(_const(x, 1), x),
-                                       sub(_const(x, 1), x))))
-  if onp.issubdtype(_dtype(result), onp.complexfloating):
-    return result
-  return select(le(abs(x), _one(x)), result, full_like(x, onp.nan))
-
 # Add some methods to ShapedArray that rely on lax primitives
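
These deleted lax definitions document the numerical trick that reappears in lax_numpy below: once |x| exceeds sqrt(finfo.max), x*x overflows, so asinh falls back to sign(x) * (log|x| + log 2). A quick plain-numpy illustration of why that branch exists (a sketch, not part of the diff):

import numpy as np

x = np.float64(1e200)                      # x * x overflows to inf here
naive = np.log(x + np.sqrt(x * x + 1.0))   # inf, due to the overflowed square
stable = np.sign(x) * (np.log(np.abs(x)) + np.log(2.0))
print(naive, stable, np.arcsinh(x))        # inf, then two values near 461.21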

View File

@@ -57,8 +57,6 @@ acos = onp.arccos
 atan = onp.arctan
 sinh = onp.sinh
 cosh = onp.cosh
-asinh = onp.arcsinh
-acosh = onp.arccosh
 lgamma = scipy.special.gammaln
 digamma = scipy.special.digamma

View File

@@ -41,7 +41,7 @@ import opt_einsum
 import six
 from six.moves import builtins, xrange
 
-from jax import jit, device_put
+from jax import jit, device_put, custom_transforms, defjvp
 from .. import core
 from ..abstract_arrays import UnshapedArray, ShapedArray, ConcreteArray
 from ..config import flags
@@ -327,9 +327,6 @@ arctan = _one_to_one_unop(onp.arctan, lax.atan, True)
 sinh = _one_to_one_unop(onp.sinh, lax.sinh, True)
 cosh = _one_to_one_unop(onp.cosh, lax.cosh, True)
 tanh = _one_to_one_unop(onp.tanh, lax.tanh, True)
-arcsinh = _one_to_one_unop(onp.arcsinh, lax.asinh, True)
-arccosh = _one_to_one_unop(onp.arccosh, lax.acosh, True)
-arctanh = _one_to_one_unop(onp.arctanh, lax.atanh, True)
 sqrt = _one_to_one_unop(onp.sqrt, lax.sqrt, True)
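
These three bindings are removed because arcsinh, arccosh, and arctanh no longer forward to lax primitives; they get hand-written definitions below. For context, a _one_to_one_unop-style wrapper essentially pairs a numpy function (for naming and docs) with a lax implementation, roughly as in this illustrative sketch (a stand-in, not the real private helper, which also handles dtype promotion, hence the trailing True flags):

import numpy as onp
from jax import lax

def one_to_one_unop(numpy_fn, lax_fn):
  # Illustrative stand-in for lax_numpy's private helper.
  def op(x):
    return lax_fn(x)
  op.__name__ = numpy_fn.__name__
  return op

arctan = one_to_one_unop(onp.arctan, lax.atan)
print(arctan(1.0))  # ~0.7853982, i.e. pi/4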
@@ -573,6 +570,47 @@ def sinc(x):
                     lax._const(x, 1), lax.div(lax.sin(pi_x), pi_x))
 
 
+@_wraps(onp.arcsinh)
+@custom_transforms
+def arcsinh(x):
+  # asinh(x) = log(x + sqrt(x**2 + 1))
+  x, = _promote_to_result_dtype(onp.arcsinh, x)
+  one = lax._const(x, 1)
+  result = lax.log(x + lax.sqrt(x * x + one))
+  if onp.issubdtype(_dtype(result), onp.complexfloating):
+    return result
+  a = abs(x)
+  sqrt_max_value = onp.sqrt(onp.finfo(_dtype(x)).max)
+  log2 = lax._const(a, onp.log(2))
+  return lax.select(a < sqrt_max_value, result, lax.sign(x) * (lax.log(a) + log2))
+defjvp(arcsinh, lambda g, ans, x: g / lax.sqrt(lax._const(x, 1) + square(x)))
+
+
+@_wraps(onp.arccosh)
+def arccosh(x):
+  # acosh(x) = log(x + sqrt((x + 1) * (x - 1)))  if x < sqrt_max_value
+  #            log(x) + log(2)                   otherwise
+  x, = _promote_to_result_dtype(onp.arccosh, x)
+  one = lax._const(x, 1)
+  result = lax.log(x + lax.sqrt((x + one) * (x - one)))
+  if onp.issubdtype(_dtype(result), onp.complexfloating):
+    return result
+  sqrt_max_value = onp.sqrt(onp.finfo(_dtype(x)).max)
+  log2 = lax._const(x, onp.log(2))
+  return lax.select(x < sqrt_max_value, result, lax.log(x) + log2)
+
+
+@_wraps(onp.arctanh)
+def arctanh(x):
+  # atanh(x) = 0.5 * log((1 + x) / (1 - x))
+  x, = _promote_to_result_dtype(onp.arctanh, x)
+  one = lax._const(x, 1)
+  result = lax._const(x, 0.5) * lax.log((one + x) / (one - x))
+  if onp.issubdtype(_dtype(result), onp.complexfloating):
+    return result
+  return lax.select(abs(x) <= 1, result, lax.full_like(x, onp.nan))
+
+
 @_wraps(onp.transpose)
 def transpose(x, axes=None):
   axes = onp.arange(ndim(x))[::-1] if axes is None else axes
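
The defjvp call above installs the closed form d/dx asinh(x) = 1 / sqrt(1 + x**2) instead of letting autodiff differentiate through log(x + sqrt(x*x + 1)); that closed form is finite and smooth at the special values tested below. A plain-numpy check of the formula against a central difference (a sketch, independent of the diff):

import numpy as np

x = np.linspace(-5.0, 5.0, 11)
h = 1e-6
numeric = (np.arcsinh(x + h) - np.arcsinh(x - h)) / (2.0 * h)
closed_form = 1.0 / np.sqrt(1.0 + x * x)
print(np.max(np.abs(numeric - closed_form)))  # ~1e-10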

View File

@@ -36,6 +36,7 @@ from jax import api
 from jax import lax
 from jax import numpy as lnp
 from jax import test_util as jtu
+from jax.test_util import check_grads
 from jax.lib import xla_bridge
 from jax.config import config
@@ -1864,5 +1865,59 @@ class LaxBackedNumpyTests(jtu.JaxTestCase):
       FLAGS.jax_numpy_rank_promotion = prev_flag
 
 
+# Most grad tests are at the lax level (see lax_test.py), but we add some here
+# as needed for e.g. particular compound ops of interest.
+
+GradTestSpec = collections.namedtuple(
+    "GradTestSpec", ["op", "nargs", "order", "rng", "dtypes", "name", "tol"])
+def grad_test_spec(op, nargs, order, rng, dtypes, name=None, tol=None):
+  return GradTestSpec(op, nargs, order, rng, dtypes, name or op.__name__, tol)
+
+GRAD_TEST_RECORDS = [
+    grad_test_spec(lnp.arcsinh, nargs=1, order=2, rng=jtu.rand_positive(),
+                   dtypes=[onp.float64, onp.complex64], tol=1e-4),
+    grad_test_spec(lnp.arccosh, nargs=1, order=2, rng=jtu.rand_positive(),
+                   dtypes=[onp.float64, onp.complex64], tol=1e-4),
+    grad_test_spec(lnp.arctanh, nargs=1, order=2, rng=jtu.rand_uniform(-0.9, 0.9),
+                   dtypes=[onp.float64, onp.complex64], tol=1e-4),
+]
+
+GradSpecialValuesTestSpec = collections.namedtuple(
+    "GradSpecialValuesTestSpec", ["op", "values"])
+
+GRAD_SPECIAL_VALUE_TEST_RECORDS = [
+    GradSpecialValuesTestSpec(lnp.arcsinh, [0., 1000.]),
+    GradSpecialValuesTestSpec(lnp.arccosh, [1000.]),
+    GradSpecialValuesTestSpec(lnp.arctanh, [0.]),
+]
+
+def num_float_bits(dtype):
+  return onp.finfo(xla_bridge.canonicalize_dtype(dtype)).bits
+
+class NumpyGradTests(jtu.JaxTestCase):
+
+  @parameterized.named_parameters(itertools.chain.from_iterable(
+      jtu.cases_from_list(
+          {"testcase_name": jtu.format_test_name_suffix(
+              rec.name, shapes, itertools.repeat(dtype)),
+           "op": rec.op, "rng": rec.rng, "shapes": shapes, "dtype": dtype,
+           "order": rec.order, "tol": rec.tol}
+          for shapes in CombosWithReplacement(nonempty_shapes, rec.nargs)
+          for dtype in rec.dtypes)
+      for rec in GRAD_TEST_RECORDS))
+  def testOpGrad(self, op, rng, shapes, dtype, order, tol):
+    tol = 1e-1 if num_float_bits(dtype) == 32 else tol
+    args = tuple(rng(shape, dtype) for shape in shapes)
+    check_grads(op, args, order, ["fwd", "rev"], tol, tol)
+
+  @parameterized.named_parameters(itertools.chain.from_iterable(
+      jtu.cases_from_list(
+          {"testcase_name": "_{}_{}".format(rec.op.__name__, special_value),
+           "op": rec.op, "special_value": special_value}
+          for special_value in rec.values)
+      for rec in GRAD_SPECIAL_VALUE_TEST_RECORDS))
+  def testOpGradSpecialValue(self, op, special_value):
+    check_grads(op, (special_value,), 2, ["fwd", "rev"])
+
+
 if __name__ == "__main__":
   absltest.main()
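
Why gradients get checked at hand-picked points at all: a composition can be finite in value yet yield nan or inf derivatives at isolated inputs, which random-input testing rarely hits. A small illustration with jax.grad (a sketch, not taken from the test file):

import jax
import jax.numpy as jnp

f = lambda x: jnp.sqrt(x * x + 1.0)  # smooth everywhere; grad fine at 0.
g = lambda x: jnp.sqrt(x * x)        # equals |x|; autodiff yields nan at 0.

print(jax.grad(f)(0.0))  # 0.0
print(jax.grad(g)(0.0))  # nan: d/du sqrt(u) -> inf at u=0, times du/dx = 0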

View File

@@ -98,8 +98,6 @@ LAX_OPS = [
     op_record(lax.atan, 1, float_dtypes, jtu.rand_small()),
     op_record(lax.sinh, 1, float_dtypes + complex_dtypes, jtu.rand_default()),
     op_record(lax.cosh, 1, float_dtypes + complex_dtypes, jtu.rand_default()),
-    op_record(lax.asinh, 1, float_dtypes + complex_dtypes, jtu.rand_positive()),
-    op_record(lax.acosh, 1, float_dtypes + complex_dtypes, jtu.rand_positive()),
     op_record(lax.lgamma, 1, float_dtypes, jtu.rand_positive()),
     op_record(lax.digamma, 1, float_dtypes, jtu.rand_positive()),
@@ -1399,6 +1397,18 @@ class LaxTest(jtu.JaxTestCase):
     api.jit(f)(1.)  # doesn't crash
 
+  def testReshapeWithUnusualShapes(self):
+    ans = lax.reshape(onp.ones((3,), onp.float32), (lax.add(1, 2), 1))
+    self.assertAllClose(ans, onp.ones((3, 1), onp.float32), check_dtypes=True)
+
+    jtu.check_raises_regexp(
+        lambda: lax.reshape(onp.ones(3,), (onp.array([3, 1]),)), TypeError,
+        "Shapes must be 1D sequences of concrete values of integer type.*")
+
+    jtu.check_raises_regexp(
+        lambda: lax.reshape(onp.ones(3,), (1.5, 2.0)), TypeError,
+        "Shapes must be 1D sequences of concrete values of integer type.*")
+
 class DeviceConstantTest(jtu.JaxTestCase):
 
   def _CheckDeviceConstant(self, make_const, expected):
@@ -1502,12 +1512,22 @@ LAX_GRAD_OPS = [
                    dtypes=[onp.float64, onp.complex64]),
     grad_test_spec(lax.log1p, nargs=1, order=2, rng=jtu.rand_positive(),
                    dtypes=[onp.float64, onp.complex64]),
+    grad_test_spec(lax.sinh, nargs=1, order=2, rng=jtu.rand_default(),
+                   dtypes=[onp.float64, onp.complex64], tol=1e-5),
+    grad_test_spec(lax.cosh, nargs=1, order=2, rng=jtu.rand_default(),
+                   dtypes=[onp.float64, onp.complex64], tol=1e-5),
     grad_test_spec(lax.tanh, nargs=1, order=2, rng=jtu.rand_default(),
                    dtypes=[onp.float64, onp.complex64], tol=1e-5),
     grad_test_spec(lax.sin, nargs=1, order=2, rng=jtu.rand_default(),
                    dtypes=[onp.float64, onp.complex64]),
     grad_test_spec(lax.cos, nargs=1, order=2, rng=jtu.rand_default(),
                    dtypes=[onp.float64, onp.complex64]),
+    grad_test_spec(lax.tan, nargs=1, order=2, rng=jtu.rand_uniform(-1.3, 1.3),
+                   dtypes=[onp.float64, onp.complex64], tol=1e-3),
+    grad_test_spec(lax.asin, nargs=1, order=2, rng=jtu.rand_uniform(-1., 1.),
+                   dtypes=[onp.float64], tol=1e-3),
+    grad_test_spec(lax.acos, nargs=1, order=2, rng=jtu.rand_uniform(-1., 1.),
+                   dtypes=[onp.float64], tol=1e-3),
     # TODO(proteneer): atan2 input is already a representation of a
     # complex number. Need to think harder about what this even means
     # if each input itself is a complex number.
@@ -1556,6 +1576,23 @@ LAX_GRAD_OPS = [
     # dtypes=[onp.float64], name="MinSomeEqual"),
 ]
 
+GradSpecialValuesTestSpec = collections.namedtuple(
+    "GradSpecialValuesTestSpec", ["op", "values"])
+
+LAX_GRAD_SPECIAL_VALUE_TESTS = [
+    GradSpecialValuesTestSpec(lax.sinh, [0.]),
+    GradSpecialValuesTestSpec(lax.cosh, [0.]),
+    GradSpecialValuesTestSpec(lax.tanh, [0., 1000.]),
+    GradSpecialValuesTestSpec(lax.sin, [0., onp.pi, onp.pi/2., onp.pi/4.]),
+    GradSpecialValuesTestSpec(lax.cos, [0., onp.pi, onp.pi/2., onp.pi/4.]),
+    GradSpecialValuesTestSpec(lax.tan, [0.]),
+    GradSpecialValuesTestSpec(lax.asin, [0.]),
+    GradSpecialValuesTestSpec(lax.acos, [0.]),
+    GradSpecialValuesTestSpec(lax.atan, [0., 1000.]),
+    GradSpecialValuesTestSpec(lax.erf, [0., 10.]),
+    GradSpecialValuesTestSpec(lax.erfc, [0., 10.]),
+]
+
 
 def check_grads_bilinear(f, args, order,
                          modes=["fwd", "rev"], atol=None, rtol=None):
@@ -1581,13 +1618,21 @@ class LaxAutodiffTest(jtu.JaxTestCase):
           for dtype in rec.dtypes)
       for rec in LAX_GRAD_OPS))
   def testOpGrad(self, op, rng, shapes, dtype, order, tol):
-    if jtu.device_under_test() == "tpu":
-      if op is lax.pow:
-        raise SkipTest("pow grad imprecise on tpu")
+    if jtu.device_under_test() == "tpu" and op is lax.pow:
+      raise SkipTest("pow grad imprecise on tpu")
     tol = 1e-1 if num_float_bits(dtype) == 32 else tol
     args = tuple(rng(shape, dtype) for shape in shapes)
     check_grads(op, args, order, ["fwd", "rev"], tol, tol)
 
+  @parameterized.named_parameters(itertools.chain.from_iterable(
+      jtu.cases_from_list(
+          {"testcase_name": "_{}_{}".format(rec.op.__name__, special_value),
+           "op": rec.op, "special_value": special_value}
+          for special_value in rec.values)
+      for rec in LAX_GRAD_SPECIAL_VALUE_TESTS))
+  def testOpGradSpecialValue(self, op, special_value):
+    check_grads(op, (special_value,), 2, ["fwd", "rev"])
+
   @parameterized.named_parameters(jtu.cases_from_list(
       {"testcase_name": "_from_dtype={}_to_dtype={}".format(
           jtu.dtype_str(from_dtype), jtu.dtype_str(to_dtype)),
@@ -2242,17 +2287,6 @@ class LaxAutodiffTest(jtu.JaxTestCase):
     expected = onp.array(0.0)
     self.assertAllClose(ans, expected, check_dtypes=False)
 
-  def testReshapeWithUnusualShapes(self):
-    ans = lax.reshape(onp.ones((3,), onp.float32), (lax.add(1, 2), 1))
-    self.assertAllClose(ans, onp.ones((3, 1), onp.float32), check_dtypes=True)
-
-    jtu.check_raises_regexp(
-        lambda: lax.reshape(onp.ones(3,), (onp.array([3, 1]),)), TypeError,
-        "Shapes must be 1D sequences of concrete values of integer type.*")
-
-    jtu.check_raises_regexp(
-        lambda: lax.reshape(onp.ones(3,), (1.5, 2.0)), TypeError,
-        "Shapes must be 1D sequences of concrete values of integer type.*")
 
 def all_bdims(*shapes):
   bdims = (itertools.chain([None], range(len(shape) + 1)) for shape in shapes)