mirror of
https://github.com/ROCm/jax.git
synced 2025-04-19 05:16:06 +00:00
Fix issue #1576
This commit is contained in:
parent
7cbd58b6c6
commit
6839f28c6a
@ -1713,17 +1713,17 @@ ad.defjvp2(rsqrt_p,
|
||||
|
||||
pow_p = standard_binop([_float | _complex, _float | _complex], 'pow')
|
||||
|
||||
def _pow_jvp_lhs(g, x, y):
|
||||
def _pow_jvp_lhs(g, ans, x, y):
|
||||
# we call _safe_mul here so that we get the behavior 0*inf = 0, since when a
|
||||
# coefficient in `g` is zero we want to keep it at zero, not produce a nan.
|
||||
# see https://github.com/google/jax/pull/383
|
||||
jac = mul(y, pow(x, select(eq(y, _zeros(y)), _ones(y), sub(y, _ones(y)))))
|
||||
return _safe_mul(_brcast(g, y), jac)
|
||||
|
||||
def _pow_jvp_rhs(g, x, y):
|
||||
return mul(_brcast(g, x), mul(log(_replace_zero(x)), pow(x, y)))
|
||||
def _pow_jvp_rhs(g, ans, x, y):
|
||||
return mul(_brcast(g, x), mul(log(_replace_zero(x)), ans))
|
||||
|
||||
ad.defjvp(pow_p, _pow_jvp_lhs, _pow_jvp_rhs)
|
||||
ad.defjvp2(pow_p, _pow_jvp_lhs, _pow_jvp_rhs)
|
||||
_replace_zero = lambda x: select(eq(x, _const(x, 0)), _ones(x), x)
|
||||
|
||||
not_p = standard_unop(_int | _bool, 'not')
|
||||
|
Loading…
x
Reference in New Issue
Block a user