Merge pull request #16487 from jakevdp:convolve-dtype

PiperOrigin-RevId: 542929304
This commit is contained in:
jax authors 2023-06-23 12:32:36 -07:00
commit 01a16f5914
2 changed files with 58 additions and 15 deletions

View File

@ -349,11 +349,24 @@ def trunc(x: ArrayLike) -> Array:
return where(lax.lt(x, _lax_const(x, 0)), ufuncs.ceil(x), ufuncs.floor(x))
@partial(jit, static_argnums=(2, 3, 4))
def _conv(x: Array, y: Array, mode: str, op: str, precision: PrecisionLike) -> Array:
_PREFERRED_ELEMENT_TYPE_DESCRIPTION = """
preferred_element_type : dtype, optional
If specified, accumulate results and return a result of the given data type.
If not specified, the function instead follows the numpy convention of always
accumulating results and returning an inexact dtype.
"""
@partial(jit, static_argnames=['mode', 'op', 'precision', 'preferred_element_type'])
def _conv(x: Array, y: Array, mode: str, op: str, precision: PrecisionLike,
preferred_element_type: Optional[DTypeLike] = None) -> Array:
if ndim(x) != 1 or ndim(y) != 1:
raise ValueError(f"{op}() only support 1-dimensional inputs.")
x, y = util.promote_dtypes_inexact(x, y)
if preferred_element_type is None:
# if unspecified, promote to inexact following NumPy's default for convolutions.
x, y = util.promote_dtypes_inexact(x, y)
else:
# otherwise cast to same type but otherwise preserve input dtypes
x, y = util.promote_dtypes(x, y)
if len(x) == 0 or len(y) == 0:
raise ValueError(f"{op}: inputs cannot be empty, got shapes {x.shape} and {y.shape}.")
@ -378,24 +391,31 @@ def _conv(x: Array, y: Array, mode: str, op: str, precision: PrecisionLike) -> A
raise ValueError("mode must be one of ['full', 'same', 'valid']")
result = lax.conv_general_dilated(x[None, None, :], y[None, None, :], (1,),
padding, precision=precision)
padding, precision=precision,
preferred_element_type=preferred_element_type)
return result[0, 0, out_order]
@util._wraps(np.convolve, lax_description=_PRECISION_DOC)
@partial(jit, static_argnames=('mode', 'precision'))
@util._wraps(np.convolve, lax_description=_PRECISION_DOC,
extra_params=_PREFERRED_ELEMENT_TYPE_DESCRIPTION)
@partial(jit, static_argnames=('mode', 'precision', 'preferred_element_type'))
def convolve(a: ArrayLike, v: ArrayLike, mode: str = 'full', *,
precision: PrecisionLike = None) -> Array:
precision: PrecisionLike = None,
preferred_element_type: Optional[dtype] = None) -> Array:
util.check_arraylike("convolve", a, v)
return _conv(asarray(a), asarray(v), mode, 'convolve', precision)
return _conv(asarray(a), asarray(v), mode=mode, op='convolve',
precision=precision, preferred_element_type=preferred_element_type)
@util._wraps(np.correlate, lax_description=_PRECISION_DOC)
@partial(jit, static_argnames=('mode', 'precision'))
@util._wraps(np.correlate, lax_description=_PRECISION_DOC,
extra_params=_PREFERRED_ELEMENT_TYPE_DESCRIPTION)
@partial(jit, static_argnames=('mode', 'precision', 'preferred_element_type'))
def correlate(a: ArrayLike, v: ArrayLike, mode: str = 'valid', *,
precision: PrecisionLike = None) -> Array:
precision: PrecisionLike = None,
preferred_element_type: Optional[dtype] = None) -> Array:
util.check_arraylike("correlate", a, v)
return _conv(asarray(a), asarray(v), mode, 'correlate', precision)
return _conv(asarray(a), asarray(v), mode=mode, op='correlate',
precision=precision, preferred_element_type=preferred_element_type)
@util._wraps(np.histogram_bin_edges)

View File

@ -1894,12 +1894,35 @@ class LaxBackedNumpyTests(jtu.JaxTestCase):
rng = jtu.rand_default(self.rng())
args_maker = lambda: [rng(xshape, dtype), rng(yshape, dtype)]
precision = lax.Precision.HIGHEST if jtu.device_under_test() == "tpu" else None
np_fun = partial(np_op, mode=mode)
jnp_fun = partial(jnp_op, mode=mode, precision=precision)
def np_fun(x, y):
return np_op(x, y, mode=mode).astype(dtypes.to_inexact_dtype(dtype))
tol = {np.float16: 2e-1, np.float32: 1e-2, np.float64: 1e-14,
np.complex128: 1e-14}
self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker, check_dtypes=False,
tol=tol)
self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker, check_dtypes=True, tol=tol)
self._CompileAndCheck(jnp_fun, args_maker)
@jtu.sample_product(
mode=['full', 'same', 'valid'],
op=['convolve', 'correlate'],
dtype=number_dtypes,
xshape=one_dim_array_shapes,
yshape=one_dim_array_shapes,
)
@jtu.skip_on_devices("gpu", "tpu", "rocm") # backends don't support all dtypes.
def testConvolutionsPreferredElementType(self, xshape, yshape, dtype, mode, op):
jnp_op = getattr(jnp, op)
np_op = getattr(np, op)
rng = jtu.rand_default(self.rng())
args_maker = lambda: [rng(xshape, dtype), rng(yshape, dtype)]
precision = lax.Precision.HIGHEST if jtu.device_under_test() == "tpu" else None
jnp_fun = partial(jnp_op, mode=mode, precision=precision,
preferred_element_type=dtype)
def np_fun(x, y):
return np_op(x, y, mode=mode).astype(dtype)
tol = {np.float16: 2e-1, np.float32: 1e-2, np.float64: 1e-14,
np.complex128: 1e-14}
self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker, check_dtypes=True, tol=tol)
self._CompileAndCheck(jnp_fun, args_maker)
@jtu.sample_product(