This commit is contained in:
chenyee 2019-10-28 22:37:01 +08:00
parent 7cbd58b6c6
commit 6839f28c6a

View File

@ -1713,17 +1713,17 @@ ad.defjvp2(rsqrt_p,
pow_p = standard_binop([_float | _complex, _float | _complex], 'pow')
def _pow_jvp_lhs(g, x, y):
def _pow_jvp_lhs(g, ans, x, y):
# we call _safe_mul here so that we get the behavior 0*inf = 0, since when a
# coefficient in `g` is zero we want to keep it at zero, not produce a nan.
# see https://github.com/google/jax/pull/383
jac = mul(y, pow(x, select(eq(y, _zeros(y)), _ones(y), sub(y, _ones(y)))))
return _safe_mul(_brcast(g, y), jac)
def _pow_jvp_rhs(g, x, y):
return mul(_brcast(g, x), mul(log(_replace_zero(x)), pow(x, y)))
def _pow_jvp_rhs(g, ans, x, y):
return mul(_brcast(g, x), mul(log(_replace_zero(x)), ans))
ad.defjvp(pow_p, _pow_jvp_lhs, _pow_jvp_rhs)
ad.defjvp2(pow_p, _pow_jvp_lhs, _pow_jvp_rhs)
_replace_zero = lambda x: select(eq(x, _const(x, 0)), _ones(x), x)
not_p = standard_unop(_int | _bool, 'not')