fix lax.imag jvp and enable test, fixes #979

This commit is contained in:
Matthew Johnson 2019-07-05 14:32:04 -07:00
parent 527fe14838
commit 93841df822
2 changed files with 3 additions and 3 deletions

View File

@ -1608,7 +1608,7 @@ real_p = unop(_complex_basetype, _complex, 'real')
ad.deflinear(real_p, lambda t: [complex(t, onp.zeros((), _dtype(t)))])
imag_p = unop(_complex_basetype, _complex, 'imag')
ad.deflinear(imag_p, lambda t: [complex(onp.zeros((), _dtype(t)), t)])
ad.defjvp(imag_p, lambda g, _: real(mul(_const(g, -1j), g)))
_complex_dtype = lambda dtype, *args: (onp.zeros((), dtype) + onp.zeros((), onp.complex64)).dtype
complex_p = binop(_complex_dtype, [_complex_elem_types, _complex_elem_types],

View File

@ -1523,8 +1523,8 @@ LAX_GRAD_OPS = [
grad_test_spec(lax.real, nargs=1, order=2, rng=jtu.rand_default(),
dtypes=[onp.complex64]),
# grad_test_spec(lax.imag, nargs=1, order=2, rng=jtu.rand_default(),
# dtypes=[onp.complex64]), # TODO(mattjj): enable
grad_test_spec(lax.imag, nargs=1, order=2, rng=jtu.rand_default(),
dtypes=[onp.complex64]),
# grad_test_spec(lax.complex, nargs=2, order=2, rng=jtu.rand_default(),
# dtypes=[onp.float32]), # TODO(mattjj): enable
grad_test_spec(lax.conj, nargs=1, order=2, rng=jtu.rand_default(),