Use a more numerically stable formulation of the tanh gradient.

Peter Hawkins 2019-05-24 11:07:08 -04:00
parent 41c2e9d447
commit 9e68d9114e
2 changed files with 9 additions and 1 deletion


@@ -1437,7 +1437,7 @@ log1p_p = standard_unop(_float | _complex, 'log1p')
 ad.defjvp(log1p_p, lambda g, x: div(g, add(x, _one(x))))
 
 tanh_p = standard_unop(_float | _complex, 'tanh')
-ad.defjvp(tanh_p, lambda g, x: div(g, pow(cosh(x), _two(x))))
+ad.defjvp2(tanh_p, lambda g, ans, x: mul(g, sub(_one(x), mul(ans, ans))))
 
 sin_p = standard_unop(_float | _complex, 'sin')
 ad.defjvp(sin_p, lambda g, x: mul(g, cos(x)))
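
The change replaces the JVP rule d/dx tanh(x) = 1/cosh(x)^2 with the algebraically equivalent 1 - tanh(x)^2, expressed via defjvp2 so the rule can reuse the primal output ans = tanh(x). The motivation is that cosh(x)**2 overflows to inf for moderately large |x| (roughly |x| > 45 in float32), so the old rule pushes the gradient through an inf intermediate, whereas tanh saturates at +/-1 and keeps every intermediate finite. A minimal sketch in plain NumPy (illustrative only, not part of the diff):

import numpy as np

x = np.float32(190.0)   # same magnitude as the values in testIssue764 below

with np.errstate(over="ignore"):
    cosh_sq = np.cosh(x) ** 2              # cosh(190) overflows float32 -> inf
    old_rule = np.float32(1.0) / cosh_sq   # gradient routed through an inf intermediate

ans = np.tanh(x)                           # saturates at 1.0, always finite
new_rule = np.float32(1.0) - ans * ans     # no overflow anywhere on this path

print(cosh_sq, old_rule, new_rule)         # inf 0.0 0.0

For inputs this large both formulations underflow to 0 in float32; the gain is that the new rule never produces an inf/nan that could propagate through a larger backward pass.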


@@ -1544,6 +1544,14 @@ class LaxBackedNumpyTests(jtu.JaxTestCase):
   def testIssue746(self):
     lnp.arange(12).reshape(3, 4)  # doesn't crash
 
+  def testIssue764(self):
+    x = lnp.linspace(190, 200, 4)
+    f = api.grad(lambda x: lnp.sum(lnp.tanh(x)))
+    # Expected values computed with autograd in float64 precision.
+    expected = onp.array([3.71669453e-165, 4.72999108e-168, 6.01954653e-171,
+                          7.66067839e-174], onp.float64)
+    self.assertAllClose(f(x), expected, check_dtypes=False)
+
 
 if __name__ == "__main__":
   absltest.main()
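
For reference, ad.defjvp2 (used in the first hunk) differs from ad.defjvp in that the rule also receives the primitive's primal output ans, which is what allows the gradient to be written as g * (1 - ans * ans) without recomputing tanh(x). Below is a hypothetical spot check of the new behaviour under a current JAX install, assuming the test's api.grad and lnp correspond to the public jax.grad and jax.numpy; everything outside the test values themselves is an assumption, not part of the commit:

import jax
import jax.numpy as jnp

# Same inputs as testIssue764; in default float32 the gradients are tiny/zero
# but must come out finite rather than nan or inf.
x = jnp.linspace(190.0, 200.0, 4)
grads = jax.grad(lambda x: jnp.sum(jnp.tanh(x)))(x)

print(grads)                                # e.g. [0. 0. 0. 0.] in float32
print(bool(jnp.all(jnp.isfinite(grads))))   # True -- no overflow from cosh(x)**2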