add special value grad tests, sinh failing

This commit is contained in:
Matthew Johnson 2019-08-31 21:23:39 -07:00
parent 1147eac3b3
commit 4225857166

View File

@ -1399,6 +1399,18 @@ class LaxTest(jtu.JaxTestCase):
api.jit(f)(1.) # doesn't crash
def testReshapeWithUnusualShapes(self):
ans = lax.reshape(onp.ones((3,), onp.float32), (lax.add(1, 2), 1))
self.assertAllClose(ans, onp.ones((3, 1), onp.float32), check_dtypes=True)
jtu.check_raises_regexp(
lambda: lax.reshape(onp.ones(3,), (onp.array([3, 1]),)), TypeError,
"Shapes must be 1D sequences of concrete values of integer type.*")
jtu.check_raises_regexp(
lambda: lax.reshape(onp.ones(3,), (1.5, 2.0)), TypeError,
"Shapes must be 1D sequences of concrete values of integer type.*")
class DeviceConstantTest(jtu.JaxTestCase):
def _CheckDeviceConstant(self, make_const, expected):
@ -1502,12 +1514,28 @@ LAX_GRAD_OPS = [
dtypes=[onp.float64, onp.complex64]),
grad_test_spec(lax.log1p, nargs=1, order=2, rng=jtu.rand_positive(),
dtypes=[onp.float64, onp.complex64]),
grad_test_spec(lax.sinh, nargs=1, order=2, rng=jtu.rand_default(),
dtypes=[onp.float64, onp.complex64], tol=1e-5),
grad_test_spec(lax.cosh, nargs=1, order=2, rng=jtu.rand_default(),
dtypes=[onp.float64, onp.complex64], tol=1e-5),
grad_test_spec(lax.tanh, nargs=1, order=2, rng=jtu.rand_default(),
dtypes=[onp.float64, onp.complex64], tol=1e-5),
grad_test_spec(lax.asinh, nargs=1, order=2, rng=jtu.rand_positive(),
dtypes=[onp.float64, onp.complex64], tol=1e-5),
grad_test_spec(lax.acosh, nargs=1, order=2, rng=jtu.rand_positive(),
dtypes=[onp.float64, onp.complex64], tol=1e-5),
grad_test_spec(lax.atanh, nargs=1, order=2, rng=jtu.rand_uniform(-0.9, 0.9),
dtypes=[onp.float64, onp.complex64], tol=1e-5),
grad_test_spec(lax.sin, nargs=1, order=2, rng=jtu.rand_default(),
dtypes=[onp.float64, onp.complex64]),
grad_test_spec(lax.cos, nargs=1, order=2, rng=jtu.rand_default(),
dtypes=[onp.float64, onp.complex64]),
grad_test_spec(lax.tan, nargs=1, order=2, rng=jtu.rand_uniform(-1.3, 1.3),
dtypes=[onp.float64, onp.complex64]),
grad_test_spec(lax.asin, nargs=1, order=2, rng=jtu.rand_uniform(-1., 1.),
dtypes=[onp.float64]),
grad_test_spec(lax.acos, nargs=1, order=2, rng=jtu.rand_uniform(-1., 1.),
dtypes=[onp.float64]),
# TODO(proteneer): atan2 input is already a representation of a
# complex number. Need to think harder about what this even means
# if each input itself is a complex number.
@ -1556,6 +1584,26 @@ LAX_GRAD_OPS = [
# dtypes=[onp.float64], name="MinSomeEqual"),
]
GradSpecialValuesTestSpec = collections.namedtuple(
"GradSpecialValuesTestSpec", ["op", "values"])
LAX_GRAD_SPECIAL_VALUE_TESTS = [
GradSpecialValuesTestSpec(lax.sinh, [0.]),
GradSpecialValuesTestSpec(lax.cosh, [0.]),
GradSpecialValuesTestSpec(lax.tanh, [0., 1000.]),
GradSpecialValuesTestSpec(lax.asinh, [0., 1000.]),
GradSpecialValuesTestSpec(lax.acosh, [1000.]),
GradSpecialValuesTestSpec(lax.atanh, [0.]),
GradSpecialValuesTestSpec(lax.sin, [0., onp.pi, onp.pi/2., onp.pi/4.]),
GradSpecialValuesTestSpec(lax.cos, [0., onp.pi, onp.pi/2., onp.pi/4.]),
GradSpecialValuesTestSpec(lax.tan, [0.]),
GradSpecialValuesTestSpec(lax.asin, [0.]),
GradSpecialValuesTestSpec(lax.acos, [0.]),
GradSpecialValuesTestSpec(lax.atan, [0., 1000.]),
GradSpecialValuesTestSpec(lax.erf, [0., 10.]),
GradSpecialValuesTestSpec(lax.erfc, [0., 10.]),
]
def check_grads_bilinear(f, args, order,
modes=["fwd", "rev"], atol=None, rtol=None):
@ -1581,13 +1629,21 @@ class LaxAutodiffTest(jtu.JaxTestCase):
for dtype in rec.dtypes)
for rec in LAX_GRAD_OPS))
def testOpGrad(self, op, rng, shapes, dtype, order, tol):
if jtu.device_under_test() == "tpu":
if op is lax.pow:
raise SkipTest("pow grad imprecise on tpu")
if jtu.device_under_test() == "tpu" and op is lax.pow:
raise SkipTest("pow grad imprecise on tpu")
tol = 1e-1 if num_float_bits(dtype) == 32 else tol
args = tuple(rng(shape, dtype) for shape in shapes)
check_grads(op, args, order, ["fwd", "rev"], tol, tol)
@parameterized.named_parameters(itertools.chain.from_iterable(
jtu.cases_from_list(
{"testcase_name": "_{}_{}".format(rec.op.__name__, special_value),
"op": rec.op, "special_value": special_value}
for special_value in rec.values)
for rec in LAX_GRAD_SPECIAL_VALUE_TESTS))
def testOpGradSpecialValue(self, op, special_value):
check_grads(op, (special_value,), 2, ["fwd", "rev"])
@parameterized.named_parameters(jtu.cases_from_list(
{"testcase_name": "_from_dtype={}_to_dtype={}".format(
jtu.dtype_str(from_dtype), jtu.dtype_str(to_dtype)),
@ -2242,17 +2298,6 @@ class LaxAutodiffTest(jtu.JaxTestCase):
expected = onp.array(0.0)
self.assertAllClose(ans, expected, check_dtypes=False)
def testReshapeWithUnusualShapes(self):
ans = lax.reshape(onp.ones((3,), onp.float32), (lax.add(1, 2), 1))
self.assertAllClose(ans, onp.ones((3, 1), onp.float32), check_dtypes=True)
jtu.check_raises_regexp(
lambda: lax.reshape(onp.ones(3,), (onp.array([3, 1]),)), TypeError,
"Shapes must be 1D sequences of concrete values of integer type.*")
jtu.check_raises_regexp(
lambda: lax.reshape(onp.ones(3,), (1.5, 2.0)), TypeError,
"Shapes must be 1D sequences of concrete values of integer type.*")
def all_bdims(*shapes):
bdims = (itertools.chain([None], range(len(shape) + 1)) for shape in shapes)