Mirror of https://github.com/ROCm/jax.git, synced 2025-04-19 05:16:06 +00:00
Merge pull request #16487 from jakevdp:convolve-dtype
PiperOrigin-RevId: 542929304
This commit is contained in: commit 01a16f5914
@@ -349,11 +349,24 @@ def trunc(x: ArrayLike) -> Array:
   return where(lax.lt(x, _lax_const(x, 0)), ufuncs.ceil(x), ufuncs.floor(x))
 
 
-@partial(jit, static_argnums=(2, 3, 4))
-def _conv(x: Array, y: Array, mode: str, op: str, precision: PrecisionLike) -> Array:
+_PREFERRED_ELEMENT_TYPE_DESCRIPTION = """
+preferred_element_type : dtype, optional
+    If specified, accumulate results and return a result of the given data type.
+    If not specified, the function instead follows the numpy convention of always
+    accumulating results and returning an inexact dtype.
+"""
+
+@partial(jit, static_argnames=['mode', 'op', 'precision', 'preferred_element_type'])
+def _conv(x: Array, y: Array, mode: str, op: str, precision: PrecisionLike,
+          preferred_element_type: Optional[DTypeLike] = None) -> Array:
   if ndim(x) != 1 or ndim(y) != 1:
     raise ValueError(f"{op}() only supports 1-dimensional inputs.")
-  x, y = util.promote_dtypes_inexact(x, y)
+  if preferred_element_type is None:
+    # if unspecified, promote to inexact following NumPy's default for convolutions.
+    x, y = util.promote_dtypes_inexact(x, y)
+  else:
+    # otherwise cast to a common type, preserving the input dtypes
+    x, y = util.promote_dtypes(x, y)
   if len(x) == 0 or len(y) == 0:
     raise ValueError(f"{op}: inputs cannot be empty, got shapes {x.shape} and {y.shape}.")
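To see what the two promotion branches mean in practice, here is a minimal sketch (not part of the diff; it assumes a jax/jaxlib build that includes this change, and the default float width shown depends on the x64 flag):

import jax.numpy as jnp

x = jnp.array([1, 2, 3])   # integer inputs
y = jnp.array([0, 1, 0])

# Default branch: promote_dtypes_inexact, so the result is floating.
print(jnp.convolve(x, y, mode='same').dtype)   # e.g. float32

# preferred_element_type branch: promote_dtypes keeps the integer inputs,
# and the convolution accumulates in and returns the requested dtype.
print(jnp.convolve(x, y, mode='same', preferred_element_type=jnp.int32).dtype)  # int32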
@@ -378,24 +391,31 @@ def _conv(x: Array, y: Array, mode: str, op: str, precision: PrecisionLike) -> Array:
     raise ValueError("mode must be one of ['full', 'same', 'valid']")
 
   result = lax.conv_general_dilated(x[None, None, :], y[None, None, :], (1,),
-                                    padding, precision=precision)
+                                    padding, precision=precision,
+                                    preferred_element_type=preferred_element_type)
   return result[0, 0, out_order]
 
 
-@util._wraps(np.convolve, lax_description=_PRECISION_DOC)
-@partial(jit, static_argnames=('mode', 'precision'))
+@util._wraps(np.convolve, lax_description=_PRECISION_DOC,
+             extra_params=_PREFERRED_ELEMENT_TYPE_DESCRIPTION)
+@partial(jit, static_argnames=('mode', 'precision', 'preferred_element_type'))
 def convolve(a: ArrayLike, v: ArrayLike, mode: str = 'full', *,
-             precision: PrecisionLike = None) -> Array:
+             precision: PrecisionLike = None,
+             preferred_element_type: Optional[dtype] = None) -> Array:
   util.check_arraylike("convolve", a, v)
-  return _conv(asarray(a), asarray(v), mode, 'convolve', precision)
+  return _conv(asarray(a), asarray(v), mode=mode, op='convolve',
+               precision=precision, preferred_element_type=preferred_element_type)
 
 
-@util._wraps(np.correlate, lax_description=_PRECISION_DOC)
-@partial(jit, static_argnames=('mode', 'precision'))
+@util._wraps(np.correlate, lax_description=_PRECISION_DOC,
+             extra_params=_PREFERRED_ELEMENT_TYPE_DESCRIPTION)
+@partial(jit, static_argnames=('mode', 'precision', 'preferred_element_type'))
 def correlate(a: ArrayLike, v: ArrayLike, mode: str = 'valid', *,
-              precision: PrecisionLike = None) -> Array:
+              precision: PrecisionLike = None,
+              preferred_element_type: Optional[dtype] = None) -> Array:
   util.check_arraylike("correlate", a, v)
-  return _conv(asarray(a), asarray(v), mode, 'correlate', precision)
+  return _conv(asarray(a), asarray(v), mode=mode, op='correlate',
+               precision=precision, preferred_element_type=preferred_element_type)
 
 
 @util._wraps(np.histogram_bin_edges)
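A practical use of the new parameter, again as an illustrative sketch rather than code from the commit: low-precision inputs can be accumulated and returned in a wider type, as the _PREFERRED_ELEMENT_TYPE_DESCRIPTION text above describes, while the default path leaves already-inexact inputs alone:

import jax.numpy as jnp

a = jnp.linspace(0, 1, 8, dtype=jnp.float16)
v = jnp.ones(3, dtype=jnp.float16)

# float16 is already inexact, so the default path keeps it.
print(jnp.correlate(a, v, mode='valid').dtype)   # float16

# Request float32 accumulation and a float32 result instead.
print(jnp.correlate(a, v, mode='valid', preferred_element_type=jnp.float32).dtype)  # float32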
@@ -1894,12 +1894,35 @@ class LaxBackedNumpyTests(jtu.JaxTestCase):
     rng = jtu.rand_default(self.rng())
     args_maker = lambda: [rng(xshape, dtype), rng(yshape, dtype)]
     precision = lax.Precision.HIGHEST if jtu.device_under_test() == "tpu" else None
-    np_fun = partial(np_op, mode=mode)
     jnp_fun = partial(jnp_op, mode=mode, precision=precision)
+    def np_fun(x, y):
+      return np_op(x, y, mode=mode).astype(dtypes.to_inexact_dtype(dtype))
     tol = {np.float16: 2e-1, np.float32: 1e-2, np.float64: 1e-14,
            np.complex128: 1e-14}
-    self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker, check_dtypes=False,
-                            tol=tol)
+    self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker, check_dtypes=True, tol=tol)
     self._CompileAndCheck(jnp_fun, args_maker)
 
+  @jtu.sample_product(
+    mode=['full', 'same', 'valid'],
+    op=['convolve', 'correlate'],
+    dtype=number_dtypes,
+    xshape=one_dim_array_shapes,
+    yshape=one_dim_array_shapes,
+  )
+  @jtu.skip_on_devices("gpu", "tpu", "rocm")  # backends don't support all dtypes.
+  def testConvolutionsPreferredElementType(self, xshape, yshape, dtype, mode, op):
+    jnp_op = getattr(jnp, op)
+    np_op = getattr(np, op)
+    rng = jtu.rand_default(self.rng())
+    args_maker = lambda: [rng(xshape, dtype), rng(yshape, dtype)]
+    precision = lax.Precision.HIGHEST if jtu.device_under_test() == "tpu" else None
+    jnp_fun = partial(jnp_op, mode=mode, precision=precision,
+                      preferred_element_type=dtype)
+    def np_fun(x, y):
+      return np_op(x, y, mode=mode).astype(dtype)
+    tol = {np.float16: 2e-1, np.float32: 1e-2, np.float64: 1e-14,
+           np.complex128: 1e-14}
+    self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker, check_dtypes=True, tol=tol)
+    self._CompileAndCheck(jnp_fun, args_maker)
+
   @jtu.sample_product(
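The assertion in the new test can be reproduced standalone; this sketch (not from the commit) mirrors its np_fun comparison, checking that preferred_element_type=dtype makes the JAX result match NumPy's result cast to that dtype:

import numpy as np
import jax.numpy as jnp

dtype = np.int32
x = np.arange(4, dtype=dtype)
y = np.array([1, 2, 1], dtype=dtype)

jax_out = jnp.convolve(x, y, mode='full', preferred_element_type=dtype)
np_out = np.convolve(x, y, mode='full').astype(dtype)   # the test's np_fun pattern

# Integer convolutions are exact, so the match is bitwise.
np.testing.assert_array_equal(np.asarray(jax_out), np_out)
print(jax_out.dtype)   # int32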