mirror of
https://github.com/ROCm/jax.git
synced 2025-04-19 05:16:06 +00:00
fix lax.imag jvp and enable test, fixes #979
This commit is contained in:
parent
527fe14838
commit
93841df822
@ -1608,7 +1608,7 @@ real_p = unop(_complex_basetype, _complex, 'real')
|
||||
ad.deflinear(real_p, lambda t: [complex(t, onp.zeros((), _dtype(t)))])
|
||||
|
||||
imag_p = unop(_complex_basetype, _complex, 'imag')
|
||||
ad.deflinear(imag_p, lambda t: [complex(onp.zeros((), _dtype(t)), t)])
|
||||
ad.defjvp(imag_p, lambda g, _: real(mul(_const(g, -1j), g)))
|
||||
|
||||
_complex_dtype = lambda dtype, *args: (onp.zeros((), dtype) + onp.zeros((), onp.complex64)).dtype
|
||||
complex_p = binop(_complex_dtype, [_complex_elem_types, _complex_elem_types],
|
||||
|
@ -1523,8 +1523,8 @@ LAX_GRAD_OPS = [
|
||||
|
||||
grad_test_spec(lax.real, nargs=1, order=2, rng=jtu.rand_default(),
|
||||
dtypes=[onp.complex64]),
|
||||
# grad_test_spec(lax.imag, nargs=1, order=2, rng=jtu.rand_default(),
|
||||
# dtypes=[onp.complex64]), # TODO(mattjj): enable
|
||||
grad_test_spec(lax.imag, nargs=1, order=2, rng=jtu.rand_default(),
|
||||
dtypes=[onp.complex64]),
|
||||
# grad_test_spec(lax.complex, nargs=2, order=2, rng=jtu.rand_default(),
|
||||
# dtypes=[onp.float32]), # TODO(mattjj): enable
|
||||
grad_test_spec(lax.conj, nargs=1, order=2, rng=jtu.rand_default(),
|
||||
|
Loading…
x
Reference in New Issue
Block a user