add comment in pow_jvp_lhs about calling _safe_mul

This commit is contained in:
Matthew Johnson 2019-02-16 08:08:04 -08:00
parent 58749c0a13
commit 6a9b741ebc

View File

@ -962,6 +962,9 @@ _maybe_real = lambda x: real(x) if _iscomplex(x) else x
pow_p = standard_binop([_float | _complex, _float | _complex], 'pow')
def pow_jvp_lhs(g, x, y):
# we call _safe_mul here so that we get the behavior 0*inf = 0, since when a
# coefficient in `g` is zero we want to keep it at zero, not produce a nan.
# see https://github.com/google/jax/pull/383
jac = mul(y, pow(x, select(eq(y, _zeros(y)), _ones(y), sub(y, _ones(y)))))
return _safe_mul(_brcast(g, y), jac)