mirror of
https://github.com/ROCm/jax.git
synced 2025-04-19 05:16:06 +00:00
Merge pull request #6807 from lgeiger:reuse-jvp-ans
PiperOrigin-RevId: 375366940
This commit is contained in:
commit
89d208b62b
@ -2383,7 +2383,7 @@ def tan_translation_rule(x):
|
||||
|
||||
tan_p = standard_unop(_float | _complex, 'tan',
|
||||
translation_rule=tan_translation_rule)
|
||||
ad.defjvp(tan_p, lambda g, x: mul(g, _const(x, 1) + square(tan(x))))
|
||||
ad.defjvp2(tan_p, lambda g, ans, x: mul(g, _const(x, 1) + square(ans)))
|
||||
|
||||
|
||||
@partial(xla.lower_fun, multiple_results=False)
|
||||
@ -2527,8 +2527,8 @@ ad.defjvp(erf_p, lambda g, x: mul(_const(x, 2. / np.sqrt(np.pi)),
|
||||
mul(g, exp(neg(square(x))))))
|
||||
|
||||
erfc_p = standard_unop(_float, 'erfc')
|
||||
ad.defjvp(erfc_p, lambda g, x: mul(_const(x, 2. / np.sqrt(np.pi)),
|
||||
mul(neg(g), exp(neg(square(x))))))
|
||||
ad.defjvp(erfc_p, lambda g, x: mul(_const(x, -2. / np.sqrt(np.pi)),
|
||||
mul(g, exp(neg(square(x))))))
|
||||
|
||||
erf_inv_p = standard_unop(_float, 'erf_inv')
|
||||
ad.defjvp2(erf_inv_p, lambda g, ans, x: mul(_const(x, np.sqrt(np.pi) / 2.),
|
||||
@ -2595,7 +2595,7 @@ ad.defjvp2(sqrt_p, lambda g, ans, x: mul(g, div(_const(x, 0.5), ans)))
|
||||
rsqrt_p = standard_unop(_float | _complex, 'rsqrt')
|
||||
ad.defjvp2(rsqrt_p,
|
||||
lambda g, ans, x:
|
||||
mul(g, mul(_const(x, -0.5), pow(x, _const(x, -1.5)))))
|
||||
mul(g, mul(_const(x, -0.5), div(ans, x))))
|
||||
|
||||
pow_p = standard_naryop([_float | _complex, _float | _complex], 'pow')
|
||||
|
||||
|
Loading…
x
Reference in New Issue
Block a user