implement jax.numpy.interp (#3949)

This commit is contained in:
Jake Vanderplas 2020-08-04 12:39:04 -07:00 committed by GitHub
parent d4d7323a57
commit d6f131aaee
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 55 additions and 1 deletions

View File

@ -149,6 +149,7 @@ Not every function in NumPy is implemented; contributions are welcome!
in1d
indices
inner
interp
isclose
iscomplex
isfinite

View File

@ -37,7 +37,8 @@ from .lax_numpy import (
fmod, frexp, full, full_like, function, gcd, geomspace, gradient, greater,
greater_equal, hamming, hanning, heaviside, histogram, histogram_bin_edges,
hsplit, hstack, hypot, identity, iinfo, imag,
indices, inexact, in1d, inf, inner, int16, int32, int64, int8, int_, integer, intersect1d,
indices, inexact, in1d, inf, inner, int16, int32, int64, int8, int_, integer,
interp, intersect1d,
isclose, iscomplex, iscomplexobj, isfinite, isin, isinf, isnan, isneginf,
isposinf, isreal, isrealobj, isscalar, issubdtype, issubsctype, iterable,
ix_, kaiser, kron, lcm, ldexp, left_shift, less, less_equal, lexsort, linspace,

View File

@ -1266,6 +1266,32 @@ else:
def _maybe_numpy_1_13_isclose_behavior(a, out):
return out
@_wraps(np.interp)
def interp(x, xp, fp, left=None, right=None, period=None):
if shape(xp) != shape(fp) or ndim(xp) != 1:
raise ValueError("xp and fp must be one-dimensional arrays of equal size")
x, xp, fp = map(asarray, _promote_dtypes_inexact(x, xp, fp))
if period is not None:
if period == 0:
raise ValueError(f"period must be a non-zero value; got {period}")
period = abs(period)
x = x % period
xp = xp % period
xp, fp = lax.sort_key_val(xp, fp)
xp = concatenate([xp[-1:] - period, xp, xp[:1] + period])
fp = concatenate([fp[-1:], fp, fp[:1]])
i = clip(searchsorted(xp, x, side='right'), 1, len(xp) - 1)
df = fp[i] - fp[i - 1]
dx = xp[i] - xp[i - 1]
delta = x - xp[i - 1]
f = where((dx == 0) | (x == xp[i]), fp[i], fp[i - 1] + delta * (df / dx))
if period is None:
f = where(x < xp[0], fp[0] if left is None else left, f)
f = where(x > xp[-1], fp[-1] if right is None else right, f)
return f
@_wraps(np.in1d, lax_description="""
In the JAX version, the `assume_unique` argument is not referenced.

View File

@ -1780,6 +1780,32 @@ class LaxBackedNumpyTests(jtu.JaxTestCase):
self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker)
self._CompileAndCheck(jnp_fun, args_maker)
@parameterized.named_parameters(jtu.cases_from_list(
{"testcase_name": "_{}_period={}_left={}_right={}".format(
jtu.format_shape_dtype_string(shape, dtype), period, left, right),
"shape": shape, "dtype": dtype,
"period": period, "left": left, "right": right}
for shape in nonempty_shapes
for period in [None, 0.59]
for left in [None, 0]
for right in [None, 1]
for dtype in default_dtypes
# following types lack precision for meaningful tests
if dtype not in [np.int8, np.int16, np.float16, jnp.bfloat16]
))
def testInterp(self, shape, dtype, period, left, right):
rng = jtu.rand_default(self.rng(), scale=10)
kwds = dict(period=period, left=left, right=right)
np_fun = partial(np.interp, **kwds)
jnp_fun = partial(jnp.interp, **kwds)
args_maker = lambda: [rng(shape, dtype), np.sort(rng((20,), dtype)), np.linspace(0, 1, 20)]
# skip numpy comparison for integer types with period specified, because numpy
# uses an unstable sort and so results differ for duplicate values.
if not (period and np.issubdtype(dtype, np.integer)):
self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker, tol={np.float32: 2E-4})
self._CompileAndCheck(jnp_fun, args_maker)
@parameterized.named_parameters(jtu.cases_from_list(
{"testcase_name": "_x1={}_x2={}_x1_rng={}".format(
jtu.format_shape_dtype_string(x1_shape, x1_dtype),