mirror of
https://github.com/ROCm/jax.git
synced 2025-04-19 05:16:06 +00:00
Support extrapolation in jnp.interp
Fixes https://github.com/google/jax/issues/14858
This commit is contained in:
parent
c2b15a1eb8
commit
4009005f0c
@ -940,11 +940,9 @@ def isclose(a: ArrayLike, b: ArrayLike, rtol: ArrayLike = 1e-05, atol: ArrayLike
|
||||
return lax.eq(a, b)
|
||||
|
||||
|
||||
@util._wraps(np.interp)
|
||||
@jit
|
||||
def interp(x: ArrayLike, xp: ArrayLike, fp: ArrayLike,
|
||||
left: Optional[ArrayLike] = None,
|
||||
right: Optional[ArrayLike] = None,
|
||||
def _interp(x: ArrayLike, xp: ArrayLike, fp: ArrayLike,
|
||||
left: Union[ArrayLike, str, None] = None,
|
||||
right: Union[ArrayLike, str, None] = None,
|
||||
period: Optional[ArrayLike] = None) -> Array:
|
||||
util.check_arraylike("interp", x, xp, fp)
|
||||
if shape(xp) != shape(fp) or ndim(xp) != 1:
|
||||
@ -953,6 +951,21 @@ def interp(x: ArrayLike, xp: ArrayLike, fp: ArrayLike,
|
||||
fp_arr, = util.promote_dtypes_inexact(fp)
|
||||
del x, xp, fp
|
||||
|
||||
if isinstance(left, str):
|
||||
if left != 'extrapolate':
|
||||
raise ValueError("the only valid string value of `left` is "
|
||||
f"'extrapolate', but got: {left!r}")
|
||||
extrapolate_left = True
|
||||
else:
|
||||
extrapolate_left = False
|
||||
if isinstance(right, str):
|
||||
if right != 'extrapolate':
|
||||
raise ValueError("the only valid string value of `right` is "
|
||||
f"'extrapolate', but got: {right!r}")
|
||||
extrapolate_right = True
|
||||
else:
|
||||
extrapolate_right = False
|
||||
|
||||
if dtypes.issubdtype(x_arr.dtype, np.complexfloating):
|
||||
raise ValueError("jnp.interp: complex x values not supported.")
|
||||
|
||||
@ -975,15 +988,40 @@ def interp(x: ArrayLike, xp: ArrayLike, fp: ArrayLike,
|
||||
dx0 = lax.abs(dx) <= epsilon # Prevent NaN gradients when `dx` is small.
|
||||
f = where(dx0, fp_arr[i - 1], fp_arr[i - 1] + (delta / where(dx0, 1, dx)) * df)
|
||||
|
||||
left_arr: ArrayLike = fp_arr[0] if left is None else left
|
||||
right_arr: ArrayLike = fp_arr[-1] if right is None else right
|
||||
if not extrapolate_left:
|
||||
assert not isinstance(left, str)
|
||||
left_arr: ArrayLike = fp_arr[0] if left is None else left
|
||||
if period is None:
|
||||
f = where(x_arr < xp_arr[0], left_arr, f)
|
||||
if not extrapolate_right:
|
||||
assert not isinstance(right, str)
|
||||
right_arr: ArrayLike = fp_arr[-1] if right is None else right
|
||||
if period is None:
|
||||
f = where(x_arr > xp_arr[-1], right_arr, f)
|
||||
|
||||
if period is None:
|
||||
f = where(x_arr < xp_arr[0], left_arr, f)
|
||||
f = where(x_arr > xp_arr[-1], right_arr, f)
|
||||
return f
|
||||
|
||||
|
||||
@util._wraps(np.interp,
|
||||
lax_description=_dedent("""
|
||||
In addition to constant interpolation supported by NumPy, jnp.interp also
|
||||
supports left='extrapolate' and right='extrpolate' to indicate linear
|
||||
extrpolation instead."""))
|
||||
def interp(x: ArrayLike, xp: ArrayLike, fp: ArrayLike,
|
||||
left: Union[ArrayLike, str, None] = None,
|
||||
right: Union[ArrayLike, str, None] = None,
|
||||
period: Optional[ArrayLike] = None) -> Array:
|
||||
static_argnames = []
|
||||
if isinstance(left, str) or left is None:
|
||||
static_argnames.append('left')
|
||||
if isinstance(right, str) or right is None:
|
||||
static_argnames.append('right')
|
||||
if period is None:
|
||||
static_argnames.append('period')
|
||||
jitted_interp = jit(_interp, static_argnames=static_argnames)
|
||||
return jitted_interp(x, xp, fp, left, right, period)
|
||||
|
||||
|
||||
@overload
|
||||
def where(condition: ArrayLike, x: Literal[None] = None, y: Literal[None] = None, *,
|
||||
size: Optional[int] = None,
|
||||
|
@ -2176,6 +2176,45 @@ class LaxBackedNumpyTests(jtu.JaxTestCase):
|
||||
rtol=3e-3, atol=1e-3)
|
||||
self._CompileAndCheck(jnp_fun, args_maker)
|
||||
|
||||
@jtu.sample_product([
|
||||
dict(x=0.5, left='extrapolate', expected=5),
|
||||
dict(x=1.5, left='extrapolate', expected=15),
|
||||
dict(x=3.5, left='extrapolate', expected=30),
|
||||
dict(x=3.9, right='extrapolate', expected=39),
|
||||
])
|
||||
def testInterpExtrapoate(self, x, expected, **kwargs):
|
||||
xp = jnp.array([1.0, 2.0, 3.0])
|
||||
fp = jnp.array([10.0, 20.0, 30.0])
|
||||
actual = jnp.interp(x, xp, fp, **kwargs)
|
||||
self.assertAlmostEqual(actual, expected)
|
||||
|
||||
def testInterpErrors(self):
|
||||
with self.assertRaisesWithLiteralMatch(
|
||||
ValueError,
|
||||
'xp and fp must be one-dimensional arrays of equal size'
|
||||
):
|
||||
jnp.interp(0.0, jnp.arange(2.0), jnp.arange(3.0))
|
||||
with self.assertRaisesWithLiteralMatch(
|
||||
ValueError,
|
||||
"the only valid string value of `left` is 'extrapolate', but got: 'interpolate'"
|
||||
):
|
||||
jnp.interp(0.0, jnp.arange(3.0), jnp.arange(3.0), left='interpolate')
|
||||
with self.assertRaisesWithLiteralMatch(
|
||||
ValueError,
|
||||
"the only valid string value of `right` is 'extrapolate', but got: 'interpolate'"
|
||||
):
|
||||
jnp.interp(0.0, jnp.arange(3.0), jnp.arange(3.0), right='interpolate')
|
||||
with self.assertRaisesWithLiteralMatch(
|
||||
ValueError,
|
||||
"jnp.interp: complex x values not supported."
|
||||
):
|
||||
jnp.interp(1j, 1j * np.arange(3.0), np.arange(3.0))
|
||||
with self.assertRaisesRegex(
|
||||
ValueError,
|
||||
"period must be a scalar; got"
|
||||
):
|
||||
jnp.interp(0.0, jnp.arange(3.0), jnp.arange(3.0), period=np.array([1.0]))
|
||||
|
||||
@jtu.sample_product(
|
||||
period=[None, 0.59],
|
||||
left=[None, 0],
|
||||
|
Loading…
x
Reference in New Issue
Block a user