mirror of
https://github.com/ROCm/jax.git
synced 2025-04-16 11:56:07 +00:00
implement jax.numpy.interp (#3949)
This commit is contained in:
parent
d4d7323a57
commit
d6f131aaee
@ -149,6 +149,7 @@ Not every function in NumPy is implemented; contributions are welcome!
|
||||
in1d
|
||||
indices
|
||||
inner
|
||||
interp
|
||||
isclose
|
||||
iscomplex
|
||||
isfinite
|
||||
|
@ -37,7 +37,8 @@ from .lax_numpy import (
|
||||
fmod, frexp, full, full_like, function, gcd, geomspace, gradient, greater,
|
||||
greater_equal, hamming, hanning, heaviside, histogram, histogram_bin_edges,
|
||||
hsplit, hstack, hypot, identity, iinfo, imag,
|
||||
indices, inexact, in1d, inf, inner, int16, int32, int64, int8, int_, integer, intersect1d,
|
||||
indices, inexact, in1d, inf, inner, int16, int32, int64, int8, int_, integer,
|
||||
interp, intersect1d,
|
||||
isclose, iscomplex, iscomplexobj, isfinite, isin, isinf, isnan, isneginf,
|
||||
isposinf, isreal, isrealobj, isscalar, issubdtype, issubsctype, iterable,
|
||||
ix_, kaiser, kron, lcm, ldexp, left_shift, less, less_equal, lexsort, linspace,
|
||||
|
@ -1266,6 +1266,32 @@ else:
|
||||
def _maybe_numpy_1_13_isclose_behavior(a, out):
|
||||
return out
|
||||
|
||||
@_wraps(np.interp)
|
||||
def interp(x, xp, fp, left=None, right=None, period=None):
|
||||
if shape(xp) != shape(fp) or ndim(xp) != 1:
|
||||
raise ValueError("xp and fp must be one-dimensional arrays of equal size")
|
||||
x, xp, fp = map(asarray, _promote_dtypes_inexact(x, xp, fp))
|
||||
if period is not None:
|
||||
if period == 0:
|
||||
raise ValueError(f"period must be a non-zero value; got {period}")
|
||||
period = abs(period)
|
||||
x = x % period
|
||||
xp = xp % period
|
||||
xp, fp = lax.sort_key_val(xp, fp)
|
||||
xp = concatenate([xp[-1:] - period, xp, xp[:1] + period])
|
||||
fp = concatenate([fp[-1:], fp, fp[:1]])
|
||||
|
||||
i = clip(searchsorted(xp, x, side='right'), 1, len(xp) - 1)
|
||||
df = fp[i] - fp[i - 1]
|
||||
dx = xp[i] - xp[i - 1]
|
||||
delta = x - xp[i - 1]
|
||||
f = where((dx == 0) | (x == xp[i]), fp[i], fp[i - 1] + delta * (df / dx))
|
||||
|
||||
if period is None:
|
||||
f = where(x < xp[0], fp[0] if left is None else left, f)
|
||||
f = where(x > xp[-1], fp[-1] if right is None else right, f)
|
||||
return f
|
||||
|
||||
|
||||
@_wraps(np.in1d, lax_description="""
|
||||
In the JAX version, the `assume_unique` argument is not referenced.
|
||||
|
@ -1780,6 +1780,32 @@ class LaxBackedNumpyTests(jtu.JaxTestCase):
|
||||
self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker)
|
||||
self._CompileAndCheck(jnp_fun, args_maker)
|
||||
|
||||
@parameterized.named_parameters(jtu.cases_from_list(
|
||||
{"testcase_name": "_{}_period={}_left={}_right={}".format(
|
||||
jtu.format_shape_dtype_string(shape, dtype), period, left, right),
|
||||
"shape": shape, "dtype": dtype,
|
||||
"period": period, "left": left, "right": right}
|
||||
for shape in nonempty_shapes
|
||||
for period in [None, 0.59]
|
||||
for left in [None, 0]
|
||||
for right in [None, 1]
|
||||
for dtype in default_dtypes
|
||||
# following types lack precision for meaningful tests
|
||||
if dtype not in [np.int8, np.int16, np.float16, jnp.bfloat16]
|
||||
))
|
||||
def testInterp(self, shape, dtype, period, left, right):
|
||||
rng = jtu.rand_default(self.rng(), scale=10)
|
||||
kwds = dict(period=period, left=left, right=right)
|
||||
np_fun = partial(np.interp, **kwds)
|
||||
jnp_fun = partial(jnp.interp, **kwds)
|
||||
args_maker = lambda: [rng(shape, dtype), np.sort(rng((20,), dtype)), np.linspace(0, 1, 20)]
|
||||
|
||||
# skip numpy comparison for integer types with period specified, because numpy
|
||||
# uses an unstable sort and so results differ for duplicate values.
|
||||
if not (period and np.issubdtype(dtype, np.integer)):
|
||||
self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker, tol={np.float32: 2E-4})
|
||||
self._CompileAndCheck(jnp_fun, args_maker)
|
||||
|
||||
@parameterized.named_parameters(jtu.cases_from_list(
|
||||
{"testcase_name": "_x1={}_x2={}_x1_rng={}".format(
|
||||
jtu.format_shape_dtype_string(x1_shape, x1_dtype),
|
||||
|
Loading…
x
Reference in New Issue
Block a user