mirror of
https://github.com/ROCm/jax.git
synced 2025-04-19 05:16:06 +00:00
add comment in pow_jvp_lhs about calling _safe_mul
This commit is contained in:
parent
58749c0a13
commit
6a9b741ebc
@ -962,6 +962,9 @@ _maybe_real = lambda x: real(x) if _iscomplex(x) else x
|
||||
pow_p = standard_binop([_float | _complex, _float | _complex], 'pow')
|
||||
|
||||
def pow_jvp_lhs(g, x, y):
|
||||
# we call _safe_mul here so that we get the behavior 0*inf = 0, since when a
|
||||
# coefficient in `g` is zero we want to keep it at zero, not produce a nan.
|
||||
# see https://github.com/google/jax/pull/383
|
||||
jac = mul(y, pow(x, select(eq(y, _zeros(y)), _ones(y), sub(y, _ones(y)))))
|
||||
return _safe_mul(_brcast(g, y), jac)
|
||||
|
||||
|
Loading…
x
Reference in New Issue
Block a user