Support extrapolation in jnp.interp

Fixes https://github.com/google/jax/issues/14858
This commit is contained in:
Stephan Hoyer 2023-03-19 17:21:32 -07:00
parent c2b15a1eb8
commit 4009005f0c
2 changed files with 87 additions and 10 deletions

View File

@ -940,11 +940,9 @@ def isclose(a: ArrayLike, b: ArrayLike, rtol: ArrayLike = 1e-05, atol: ArrayLike
return lax.eq(a, b)
@util._wraps(np.interp)
@jit
def interp(x: ArrayLike, xp: ArrayLike, fp: ArrayLike,
left: Optional[ArrayLike] = None,
right: Optional[ArrayLike] = None,
def _interp(x: ArrayLike, xp: ArrayLike, fp: ArrayLike,
left: Union[ArrayLike, str, None] = None,
right: Union[ArrayLike, str, None] = None,
period: Optional[ArrayLike] = None) -> Array:
util.check_arraylike("interp", x, xp, fp)
if shape(xp) != shape(fp) or ndim(xp) != 1:
@ -953,6 +951,21 @@ def interp(x: ArrayLike, xp: ArrayLike, fp: ArrayLike,
fp_arr, = util.promote_dtypes_inexact(fp)
del x, xp, fp
if isinstance(left, str):
if left != 'extrapolate':
raise ValueError("the only valid string value of `left` is "
f"'extrapolate', but got: {left!r}")
extrapolate_left = True
else:
extrapolate_left = False
if isinstance(right, str):
if right != 'extrapolate':
raise ValueError("the only valid string value of `right` is "
f"'extrapolate', but got: {right!r}")
extrapolate_right = True
else:
extrapolate_right = False
if dtypes.issubdtype(x_arr.dtype, np.complexfloating):
raise ValueError("jnp.interp: complex x values not supported.")
@ -975,15 +988,40 @@ def interp(x: ArrayLike, xp: ArrayLike, fp: ArrayLike,
dx0 = lax.abs(dx) <= epsilon # Prevent NaN gradients when `dx` is small.
f = where(dx0, fp_arr[i - 1], fp_arr[i - 1] + (delta / where(dx0, 1, dx)) * df)
left_arr: ArrayLike = fp_arr[0] if left is None else left
right_arr: ArrayLike = fp_arr[-1] if right is None else right
if not extrapolate_left:
assert not isinstance(left, str)
left_arr: ArrayLike = fp_arr[0] if left is None else left
if period is None:
f = where(x_arr < xp_arr[0], left_arr, f)
if not extrapolate_right:
assert not isinstance(right, str)
right_arr: ArrayLike = fp_arr[-1] if right is None else right
if period is None:
f = where(x_arr > xp_arr[-1], right_arr, f)
if period is None:
f = where(x_arr < xp_arr[0], left_arr, f)
f = where(x_arr > xp_arr[-1], right_arr, f)
return f
@util._wraps(np.interp,
lax_description=_dedent("""
In addition to constant interpolation supported by NumPy, jnp.interp also
supports left='extrapolate' and right='extrpolate' to indicate linear
extrpolation instead."""))
def interp(x: ArrayLike, xp: ArrayLike, fp: ArrayLike,
left: Union[ArrayLike, str, None] = None,
right: Union[ArrayLike, str, None] = None,
period: Optional[ArrayLike] = None) -> Array:
static_argnames = []
if isinstance(left, str) or left is None:
static_argnames.append('left')
if isinstance(right, str) or right is None:
static_argnames.append('right')
if period is None:
static_argnames.append('period')
jitted_interp = jit(_interp, static_argnames=static_argnames)
return jitted_interp(x, xp, fp, left, right, period)
@overload
def where(condition: ArrayLike, x: Literal[None] = None, y: Literal[None] = None, *,
size: Optional[int] = None,

View File

@ -2176,6 +2176,45 @@ class LaxBackedNumpyTests(jtu.JaxTestCase):
rtol=3e-3, atol=1e-3)
self._CompileAndCheck(jnp_fun, args_maker)
@jtu.sample_product([
dict(x=0.5, left='extrapolate', expected=5),
dict(x=1.5, left='extrapolate', expected=15),
dict(x=3.5, left='extrapolate', expected=30),
dict(x=3.9, right='extrapolate', expected=39),
])
def testInterpExtrapoate(self, x, expected, **kwargs):
xp = jnp.array([1.0, 2.0, 3.0])
fp = jnp.array([10.0, 20.0, 30.0])
actual = jnp.interp(x, xp, fp, **kwargs)
self.assertAlmostEqual(actual, expected)
def testInterpErrors(self):
with self.assertRaisesWithLiteralMatch(
ValueError,
'xp and fp must be one-dimensional arrays of equal size'
):
jnp.interp(0.0, jnp.arange(2.0), jnp.arange(3.0))
with self.assertRaisesWithLiteralMatch(
ValueError,
"the only valid string value of `left` is 'extrapolate', but got: 'interpolate'"
):
jnp.interp(0.0, jnp.arange(3.0), jnp.arange(3.0), left='interpolate')
with self.assertRaisesWithLiteralMatch(
ValueError,
"the only valid string value of `right` is 'extrapolate', but got: 'interpolate'"
):
jnp.interp(0.0, jnp.arange(3.0), jnp.arange(3.0), right='interpolate')
with self.assertRaisesWithLiteralMatch(
ValueError,
"jnp.interp: complex x values not supported."
):
jnp.interp(1j, 1j * np.arange(3.0), np.arange(3.0))
with self.assertRaisesRegex(
ValueError,
"period must be a scalar; got"
):
jnp.interp(0.0, jnp.arange(3.0), jnp.arange(3.0), period=np.array([1.0]))
@jtu.sample_product(
period=[None, 0.59],
left=[None, 0],