Merge pull request #6807 from lgeiger:reuse-jvp-ans

PiperOrigin-RevId: 375366940
This commit is contained in:
jax authors 2021-05-23 11:13:57 -07:00
commit 89d208b62b

View File

@ -2383,7 +2383,7 @@ def tan_translation_rule(x):
tan_p = standard_unop(_float | _complex, 'tan',
translation_rule=tan_translation_rule)
ad.defjvp(tan_p, lambda g, x: mul(g, _const(x, 1) + square(tan(x))))
ad.defjvp2(tan_p, lambda g, ans, x: mul(g, _const(x, 1) + square(ans)))
@partial(xla.lower_fun, multiple_results=False)
@ -2527,8 +2527,8 @@ ad.defjvp(erf_p, lambda g, x: mul(_const(x, 2. / np.sqrt(np.pi)),
mul(g, exp(neg(square(x))))))
erfc_p = standard_unop(_float, 'erfc')
ad.defjvp(erfc_p, lambda g, x: mul(_const(x, 2. / np.sqrt(np.pi)),
mul(neg(g), exp(neg(square(x))))))
ad.defjvp(erfc_p, lambda g, x: mul(_const(x, -2. / np.sqrt(np.pi)),
mul(g, exp(neg(square(x))))))
erf_inv_p = standard_unop(_float, 'erf_inv')
ad.defjvp2(erf_inv_p, lambda g, ans, x: mul(_const(x, np.sqrt(np.pi) / 2.),
@ -2595,7 +2595,7 @@ ad.defjvp2(sqrt_p, lambda g, ans, x: mul(g, div(_const(x, 0.5), ans)))
rsqrt_p = standard_unop(_float | _complex, 'rsqrt')
ad.defjvp2(rsqrt_p,
lambda g, ans, x:
mul(g, mul(_const(x, -0.5), pow(x, _const(x, -1.5)))))
mul(g, mul(_const(x, -0.5), div(ans, x))))
pow_p = standard_naryop([_float | _complex, _float | _complex], 'pow')