mirror of
https://github.com/ROCm/jax.git
synced 2025-04-19 05:16:06 +00:00
Use more numerically stable formulation of tanh gradient.
This commit is contained in:
parent
41c2e9d447
commit
9e68d9114e
@ -1437,7 +1437,7 @@ log1p_p = standard_unop(_float | _complex, 'log1p')
|
||||
ad.defjvp(log1p_p, lambda g, x: div(g, add(x, _one(x))))
|
||||
|
||||
tanh_p = standard_unop(_float | _complex, 'tanh')
|
||||
ad.defjvp(tanh_p, lambda g, x: div(g, pow(cosh(x), _two(x))))
|
||||
ad.defjvp2(tanh_p, lambda g, ans, x: mul(g, sub(_one(x), mul(ans, ans))))
|
||||
|
||||
sin_p = standard_unop(_float | _complex, 'sin')
|
||||
ad.defjvp(sin_p, lambda g, x: mul(g, cos(x)))
|
||||
|
@ -1544,6 +1544,14 @@ class LaxBackedNumpyTests(jtu.JaxTestCase):
|
||||
def testIssue746(self):
|
||||
lnp.arange(12).reshape(3, 4) # doesn't crash
|
||||
|
||||
def testIssue764(self):
|
||||
x = lnp.linspace(190, 200, 4)
|
||||
f = api.grad(lambda x: lnp.sum(lnp.tanh(x)))
|
||||
# Expected values computed with autograd in float64 precision.
|
||||
expected = onp.array([3.71669453e-165, 4.72999108e-168, 6.01954653e-171,
|
||||
7.66067839e-174], onp.float64)
|
||||
self.assertAllClose(f(x), expected, check_dtypes=False)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
absltest.main()
|
||||
|
Loading…
x
Reference in New Issue
Block a user