Merge pull request #6748 from hawkinsp:complexconv

PiperOrigin-RevId: 374259292
This commit is contained in:
jax authors 2021-05-17 12:51:35 -07:00
commit d39261497c
2 changed files with 9 additions and 7 deletions

View File

@ -760,9 +760,8 @@ def trunc(x):
return where(lax.lt(x, lax._const(x, 0)), ceil(x), floor(x))
@partial(jit, static_argnums=(2, 3, 4))
def _conv(x, y, mode, op, precision):
if issubdtype(_dtype(x), complexfloating) or issubdtype(_dtype(y), complexfloating):
raise NotImplementedError(f"{op}() does not support complex inputs")
if ndim(x) != 1 or ndim(y) != 1:
raise ValueError(f"{op}() only support 1-dimensional inputs.")
x, y = _promote_dtypes_inexact(x, y)
@ -770,11 +769,14 @@ def _conv(x, y, mode, op, precision):
raise ValueError(f"{op}: inputs cannot be empty, got shapes {x.shape} and {y.shape}.")
out_order = slice(None)
if len(x) < len(y):
x, y = y, x
if op == "correlate":
if op == 'correlate':
y = conj(y)
if len(x) < len(y):
x, y = y, x
out_order = slice(None, None, -1)
if op == 'convolve':
elif op == 'convolve':
if len(x) < len(y):
x, y = y, x
y = y[::-1]
if mode == 'valid':

View File

@ -2221,7 +2221,7 @@ class LaxBackedNumpyTests(jtu.JaxTestCase):
"np_op": getattr(np, op)}
for mode in ['full', 'same', 'valid']
for op in ['convolve', 'correlate']
for dtype in default_dtypes
for dtype in number_dtypes
for xshape in one_dim_array_shapes
for yshape in one_dim_array_shapes))
def testConvolutions(self, xshape, yshape, dtype, mode, jnp_op, np_op):