mirror of
https://github.com/ROCm/jax.git
synced 2025-04-19 05:16:06 +00:00
Merge pull request #6748 from hawkinsp:complexconv
PiperOrigin-RevId: 374259292
This commit is contained in:
commit
d39261497c
@ -760,9 +760,8 @@ def trunc(x):
|
||||
return where(lax.lt(x, lax._const(x, 0)), ceil(x), floor(x))
|
||||
|
||||
|
||||
@partial(jit, static_argnums=(2, 3, 4))
|
||||
def _conv(x, y, mode, op, precision):
|
||||
if issubdtype(_dtype(x), complexfloating) or issubdtype(_dtype(y), complexfloating):
|
||||
raise NotImplementedError(f"{op}() does not support complex inputs")
|
||||
if ndim(x) != 1 or ndim(y) != 1:
|
||||
raise ValueError(f"{op}() only support 1-dimensional inputs.")
|
||||
x, y = _promote_dtypes_inexact(x, y)
|
||||
@ -770,11 +769,14 @@ def _conv(x, y, mode, op, precision):
|
||||
raise ValueError(f"{op}: inputs cannot be empty, got shapes {x.shape} and {y.shape}.")
|
||||
|
||||
out_order = slice(None)
|
||||
if len(x) < len(y):
|
||||
x, y = y, x
|
||||
if op == "correlate":
|
||||
if op == 'correlate':
|
||||
y = conj(y)
|
||||
if len(x) < len(y):
|
||||
x, y = y, x
|
||||
out_order = slice(None, None, -1)
|
||||
if op == 'convolve':
|
||||
elif op == 'convolve':
|
||||
if len(x) < len(y):
|
||||
x, y = y, x
|
||||
y = y[::-1]
|
||||
|
||||
if mode == 'valid':
|
||||
|
@ -2221,7 +2221,7 @@ class LaxBackedNumpyTests(jtu.JaxTestCase):
|
||||
"np_op": getattr(np, op)}
|
||||
for mode in ['full', 'same', 'valid']
|
||||
for op in ['convolve', 'correlate']
|
||||
for dtype in default_dtypes
|
||||
for dtype in number_dtypes
|
||||
for xshape in one_dim_array_shapes
|
||||
for yshape in one_dim_array_shapes))
|
||||
def testConvolutions(self, xshape, yshape, dtype, mode, jnp_op, np_op):
|
||||
|
Loading…
x
Reference in New Issue
Block a user