mirror of
https://github.com/ROCm/jax.git
synced 2025-04-19 05:16:06 +00:00
better abs jvp
This commit is contained in:
parent
d11b17a2cd
commit
b39179c887
@ -1636,9 +1636,14 @@ ad.primitive_jvps[conj_p] = partial(ad.linear_jvp, conj_p)
|
||||
ad.primitive_transposes[conj_p] = _conj_transpose_rule
|
||||
|
||||
abs_p = unop(_complex_basetype, _num, 'abs')
|
||||
ad.defjvp2(abs_p,
|
||||
lambda g, ans, x:
|
||||
div(_maybe_real(mul(g, _maybe_conj(x))), _replace_zero(ans)))
|
||||
|
||||
def _abs_jvp_rule(g, ans, x):
|
||||
if _iscomplex(x):
|
||||
return _maybe_real(mul(g, div(_maybe_conj(x),
|
||||
_replace_zero(convert_element_type(ans, _dtype(x))))))
|
||||
else:
|
||||
return select(ge(x, _zero(x)), g, neg(g))
|
||||
ad.defjvp2(abs_p, _abs_jvp_rule)
|
||||
_maybe_conj = lambda x: conj(x) if _iscomplex(x) else x
|
||||
_maybe_real = lambda x: real(x) if _iscomplex(x) else x
|
||||
|
||||
|
Loading…
x
Reference in New Issue
Block a user