better abs jvp

This commit is contained in:
James Bradbury 2019-09-18 23:55:31 -07:00
parent d11b17a2cd
commit b39179c887

View File

@ -1636,9 +1636,14 @@ ad.primitive_jvps[conj_p] = partial(ad.linear_jvp, conj_p)
ad.primitive_transposes[conj_p] = _conj_transpose_rule
abs_p = unop(_complex_basetype, _num, 'abs')
ad.defjvp2(abs_p,
lambda g, ans, x:
div(_maybe_real(mul(g, _maybe_conj(x))), _replace_zero(ans)))
def _abs_jvp_rule(g, ans, x):
if _iscomplex(x):
return _maybe_real(mul(g, div(_maybe_conj(x),
_replace_zero(convert_element_type(ans, _dtype(x))))))
else:
return select(ge(x, _zero(x)), g, neg(g))
ad.defjvp2(abs_p, _abs_jvp_rule)
_maybe_conj = lambda x: conj(x) if _iscomplex(x) else x
_maybe_real = lambda x: real(x) if _iscomplex(x) else x