# Copyright 2020 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import warnings

import numpy as np
import scipy.signal as osp_signal

from jax import lax
from jax._src.numpy import lax_numpy as jnp
from jax._src.numpy import linalg
from jax._src.numpy.lax_numpy import _promote_dtypes_inexact
from jax._src.numpy.util import _wraps

# Note: we do not re-use the code from jax.numpy.convolve here, because the
# handling of padding differs slightly between the two implementations
# (particularly for mode='same').
def _convolve_nd(in1, in2, mode, *, precision):
  if mode not in ["full", "same", "valid"]:
    raise ValueError("mode must be one of ['full', 'same', 'valid']")
  if in1.ndim != in2.ndim:
    raise ValueError("in1 and in2 must have the same number of dimensions")
  if in1.size == 0 or in2.size == 0:
    raise ValueError("zero-size arrays not supported in convolutions, "
                     f"got shapes {in1.shape} and {in2.shape}.")
  in1, in2 = _promote_dtypes_inexact(in1, in2)

  no_swap = all(s1 >= s2 for s1, s2 in zip(in1.shape, in2.shape))
  swap = all(s1 <= s2 for s1, s2 in zip(in1.shape, in2.shape))
  if not (no_swap or swap):
    raise ValueError("One input must be smaller than the other in every dimension.")

  # Arrange for in2 to be the smaller array, remembering the original in2
  # shape, which the 'same'-mode padding computation below depends on.
  shape_o = in2.shape
  if swap:
    in1, in2 = in2, in1
  shape = in2.shape
  # lax.conv_general_dilated computes a cross-correlation, so reverse the
  # kernel along every axis to obtain a true convolution.
  in2 = in2[tuple(slice(None, None, -1) for _ in shape)]
  if mode == 'valid':
    padding = [(0, 0) for _ in shape]
  elif mode == 'same':
    # Pad so that the output has the shape of the larger input, matching
    # scipy's centering convention for mode='same'.
    padding = [(s - 1 - (s_o - 1) // 2, s - s_o + (s_o - 1) // 2)
               for (s, s_o) in zip(shape, shape_o)]
  else:  # mode == 'full'
    padding = [(s - 1, s - 1) for s in shape]

  strides = tuple(1 for _ in shape)
  # Add the singleton batch and feature dimensions that the lax convolution
  # expects, then strip them from the result.
  result = lax.conv_general_dilated(in1[None, None], in2[None, None], strides,
                                    padding, precision=precision)
  return result[0, 0]

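# A worked instance of the 'same' padding in _convolve_nd (illustrative,
# assuming the no-swap case): for in1.shape == (5,) and in2.shape == (3,),
# shape == shape_o == (3,), so padding == [(1, 1)] and the output length is
# 5 + 1 + 1 - 3 + 1 == 5, the size of in1, as scipy's mode='same' requires.
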
@_wraps(osp_signal.convolve)
def convolve(in1, in2, mode='full', method='auto',
             precision=None):
  if method != 'auto':
    warnings.warn("convolve() ignores method argument")
  return _convolve_nd(in1, in2, mode, precision=precision)

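# Example usage (a minimal sketch; the expected values follow from the
# standard definition of discrete convolution):
#
#   convolve(jnp.array([1., 2., 3.]), jnp.array([1., 1.]))           # [1., 3., 5., 3.]
#   convolve(jnp.array([1., 2., 3.]), jnp.array([1., 1.]), 'valid')  # [3., 5.]
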
@_wraps(osp_signal.convolve2d)
def convolve2d(in1, in2, mode='full', boundary='fill', fillvalue=0,
               precision=None):
  if boundary != 'fill' or fillvalue != 0:
    raise NotImplementedError("convolve2d() only supports boundary='fill', fillvalue=0")
  if jnp.ndim(in1) != 2 or jnp.ndim(in2) != 2:
    raise ValueError("convolve2d() only supports 2-dimensional inputs.")
  return _convolve_nd(in1, in2, mode, precision=precision)

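# Example usage (a sketch): output shapes for a 3x3 input and a 2x2 kernel.
#
#   img = jnp.arange(9.).reshape(3, 3)
#   kern = jnp.ones((2, 2))
#   convolve2d(img, kern, mode='full').shape    # (4, 4): 3 + 2 - 1 per axis
#   convolve2d(img, kern, mode='valid').shape   # (2, 2): 3 - 2 + 1 per axis
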
@_wraps(osp_signal.correlate)
def correlate(in1, in2, mode='full', method='auto',
              precision=None):
  if method != 'auto':
    warnings.warn("correlate() ignores method argument")
  # Correlation is convolution with the reversed, conjugated kernel.
  return _convolve_nd(in1, jnp.flip(in2.conj()), mode, precision=precision)

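# Example usage (a sketch): correlation slides in2 over in1 without reversal.
#
#   correlate(jnp.array([1., 2., 3.]), jnp.array([1., 2.]))   # [2., 5., 8., 3.]
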
@_wraps(osp_signal.correlate2d)
def correlate2d(in1, in2, mode='full', boundary='fill', fillvalue=0,
                precision=None):
  if boundary != 'fill' or fillvalue != 0:
    raise NotImplementedError("correlate2d() only supports boundary='fill', fillvalue=0")
  if jnp.ndim(in1) != 2 or jnp.ndim(in2) != 2:
    raise ValueError("correlate2d() only supports 2-dimensional inputs.")

  swap = all(s1 <= s2 for s1, s2 in zip(in1.shape, in2.shape))
  same_shape = all(s1 == s2 for s1, s2 in zip(in1.shape, in2.shape))

  # scipy's correlate2d orients its output relative to in1 even when in2 is
  # the larger array, so the reductions to _convolve_nd differ by mode.
  if mode == "same":
    in1, in2 = in1[::-1, ::-1], in2.conj()
    result = _convolve_nd(in1, in2, mode, precision=precision)[::-1, ::-1]
  elif mode == "valid":
    if swap and not same_shape:
      in1, in2 = in2[::-1, ::-1], in1.conj()
      result = _convolve_nd(in1, in2, mode, precision=precision)
    else:
      in1, in2 = in1[::-1, ::-1], in2.conj()
      result = _convolve_nd(in1, in2, mode, precision=precision)[::-1, ::-1]
  else:
    if swap:
      in1, in2 = in2[::-1, ::-1], in1.conj()
      result = _convolve_nd(in1, in2, mode, precision=precision).conj()
    else:
      in1, in2 = in1[::-1, ::-1], in2.conj()
      result = _convolve_nd(in1, in2, mode, precision=precision)[::-1, ::-1]
  return result

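# Example usage (a sketch): like convolve2d, but the kernel is not reversed
# (and is conjugated for complex inputs).
#
#   img = jnp.arange(9.).reshape(3, 3)
#   kern = jnp.ones((2, 2))
#   correlate2d(img, kern, mode='same').shape   # (3, 3): same shape as in1
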
@_wraps(osp_signal.detrend)
def detrend(data, axis=-1, type='linear', bp=0, overwrite_data=None):
  if overwrite_data is not None:
    raise NotImplementedError("overwrite_data argument not implemented.")
  if type not in ['constant', 'linear']:
    raise ValueError("Trend type must be 'linear' or 'constant'.")
  data, = _promote_dtypes_inexact(jnp.asarray(data))
  if type == 'constant':
    return data - data.mean(axis, keepdims=True)
  else:
    N = data.shape[axis]
    # bp is static, so we use np operations to avoid pushing to device.
    bp = np.sort(np.unique(np.r_[0, bp, N]))
    if bp[0] < 0 or bp[-1] > N:
      raise ValueError("Breakpoints must be non-negative and less than "
                       "length of data along given axis.")
    data = jnp.moveaxis(data, axis, 0)
    shape = data.shape
    data = data.reshape(N, -1)
    # Within each breakpoint segment, fit an affine trend by least squares
    # and subtract it from the data.
    for m in range(len(bp) - 1):
      Npts = bp[m + 1] - bp[m]
      A = jnp.vstack([
        jnp.ones(Npts, dtype=data.dtype),
        jnp.arange(1, Npts + 1, dtype=data.dtype) / Npts
      ]).T
      sl = slice(bp[m], bp[m + 1])
      coef, *_ = linalg.lstsq(A, data[sl])
      data = data.at[sl].add(-jnp.matmul(A, coef, precision=lax.Precision.HIGHEST))
    return jnp.moveaxis(data.reshape(shape), 0, axis)

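# Example usage (a sketch; residuals are zero only up to floating-point error):
#
#   t = jnp.arange(5.)
#   detrend(2. * t + 1.)                    # ~[0., 0., 0., 0., 0.]
#   detrend(2. * t + 1., type='constant')   # the same line, mean-centered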