Merge pull request #8625 from Edenhofer:regular_grid_interpolator

PiperOrigin-RevId: 423876494
This commit is contained in:
jax authors 2022-01-24 12:02:34 -08:00
commit 16c809ce7f
7 changed files with 293 additions and 0 deletions

View File

@ -0,0 +1,13 @@
# Copyright 2022 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

30
jax/_src/third_party/scipy/LICENSE.txt vendored Normal file
View File

@ -0,0 +1,30 @@
Copyright (c) 2001-2002 Enthought, Inc. 2003-2019, SciPy Developers.
All rights reserved.
Redistribution and use in source and binary forms, with or without
modification, are permitted provided that the following conditions
are met:
1. Redistributions of source code must retain the above copyright
notice, this list of conditions and the following disclaimer.
2. Redistributions in binary form must reproduce the above
copyright notice, this list of conditions and the following
disclaimer in the documentation and/or other materials provided
with the distribution.
3. Neither the name of the copyright holder nor the names of its
contributors may be used to endorse or promote products derived
from this software without specific prior written permission.
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
"AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
(INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.

View File

View File

@ -0,0 +1,158 @@
from itertools import product
import scipy.interpolate as osp_interpolate
from jax._src.tree_util import register_pytree_node
from jax._src.numpy.lax_numpy import (_check_arraylike, _promote_dtypes_inexact,
asarray, broadcast_arrays, can_cast,
empty, nan, searchsorted, where, zeros)
from jax._src.numpy.util import _wraps
def _ndim_coords_from_arrays(points, ndim=None):
"""Convert a tuple of coordinate arrays to a (..., ndim)-shaped array."""
if isinstance(points, tuple) and len(points) == 1:
# handle argument tuple
points = points[0]
if isinstance(points, tuple):
p = broadcast_arrays(*points)
for p_other in p[1:]:
if p_other.shape != p[0].shape:
raise ValueError("coordinate arrays do not have the same shape")
points = empty(p[0].shape + (len(points),), dtype=float)
for j, item in enumerate(p):
points = points.at[..., j].set(item)
else:
_check_arraylike("_ndim_coords_from_arrays", points)
points = asarray(points) # SciPy: asanyarray(points)
if points.ndim == 1:
if ndim is None:
points = points.reshape(-1, 1)
else:
points = points.reshape(-1, ndim)
return points
@_wraps(
osp_interpolate.RegularGridInterpolator,
lax_description="""
In the JAX version, `bounds_error` defaults to and must always be `False` since no
bound error may be raised under JIT.
Furthermore, in contrast to SciPy no input validation is performed.
""")
class RegularGridInterpolator:
# Based on SciPy's implementation which in turn is originally based on an
# implementation by Johannes Buchner
def __init__(self,
points,
values,
method="linear",
bounds_error=False,
fill_value=nan):
if method not in ("linear", "nearest"):
raise ValueError(f"method {method!r} is not defined")
self.method = method
self.bounds_error = bounds_error
if self.bounds_error:
raise NotImplementedError("`bounds_error` takes no effect under JIT")
_check_arraylike("RegularGridInterpolator", values)
if len(points) > values.ndim:
ve = f"there are {len(points)} point arrays, but values has {values.ndim} dimensions"
raise ValueError(ve)
values, = _promote_dtypes_inexact(values)
if fill_value is not None:
_check_arraylike("RegularGridInterpolator", fill_value)
fill_value = asarray(fill_value)
if not can_cast(fill_value.dtype, values.dtype, casting='same_kind'):
ve = "fill_value must be either 'None' or of a type compatible with values"
raise ValueError(ve)
self.fill_value = fill_value
# TODO: assert sanity of `points` similar to SciPy but in a JIT-able way
_check_arraylike("RegularGridInterpolator", *points)
self.grid = tuple(asarray(p) for p in points)
self.values = values
@_wraps(osp_interpolate.RegularGridInterpolator.__call__, update_doc=False)
def __call__(self, xi, method=None):
method = self.method if method is None else method
if method not in ("linear", "nearest"):
raise ValueError(f"method {method!r} is not defined")
ndim = len(self.grid)
xi = _ndim_coords_from_arrays(xi, ndim=ndim)
if xi.shape[-1] != len(self.grid):
raise ValueError("the requested sample points xi have dimension"
f" {xi.shape[1]}, but this RegularGridInterpolator has"
f" dimension {ndim}")
xi_shape = xi.shape
xi = xi.reshape(-1, xi_shape[-1])
indices, norm_distances, out_of_bounds = self._find_indices(xi.T)
if method == "linear":
result = self._evaluate_linear(indices, norm_distances)
elif method == "nearest":
result = self._evaluate_nearest(indices, norm_distances)
else:
raise AssertionError("method must be bound")
if not self.bounds_error and self.fill_value is not None:
bc_shp = result.shape[:1] + (1,) * (result.ndim - 1)
result = where(out_of_bounds.reshape(bc_shp), self.fill_value, result)
return result.reshape(xi_shape[:-1] + self.values.shape[ndim:])
def _evaluate_linear(self, indices, norm_distances):
# slice for broadcasting over trailing dimensions in self.values
vslice = (slice(None),) + (None,) * (self.values.ndim - len(indices))
# find relevant values
# each i and i+1 represents a edge
edges = product(*[[i, i + 1] for i in indices])
values = asarray(0.)
for edge_indices in edges:
weight = asarray(1.)
for ei, i, yi in zip(edge_indices, indices, norm_distances):
weight *= where(ei == i, 1 - yi, yi)
values += self.values[edge_indices] * weight[vslice]
return values
def _evaluate_nearest(self, indices, norm_distances):
idx_res = [
where(yi <= .5, i, i + 1) for i, yi in zip(indices, norm_distances)
]
return self.values[tuple(idx_res)]
def _find_indices(self, xi):
# find relevant edges between which xi are situated
indices = []
# compute distance to lower edge in unity units
norm_distances = []
# check for out of bounds xi
out_of_bounds = zeros((xi.shape[1],), dtype=bool)
# iterate through dimensions
for x, g in zip(xi, self.grid):
i = searchsorted(g, x) - 1
i = where(i < 0, 0, i)
i = where(i > g.size - 2, g.size - 2, i)
indices.append(i)
norm_distances.append((x - g[i]) / (g[i + 1] - g[i]))
if not self.bounds_error:
out_of_bounds += x < g[0]
out_of_bounds += x > g[-1]
return indices, norm_distances, out_of_bounds
register_pytree_node(
RegularGridInterpolator,
lambda obj: ((obj.grid, obj.values, obj.fill_value),
(obj.method, obj.bounds_error)),
lambda aux, children: RegularGridInterpolator(
*children[:2], # type: ignore[index]
*aux,
*children[2:]), # type: ignore[index]
)

View File

@ -13,6 +13,7 @@
# limitations under the License.
# flake8: noqa: F401
from jax.scipy import interpolate as interpolate
from jax.scipy import linalg as linalg
from jax.scipy import ndimage as ndimage
from jax.scipy import signal as signal

View File

@ -0,0 +1,20 @@
# Copyright 2022 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# flake8: noqa: F401
# Already deprecate namespaces that will be removed in SciPy v2.0.0
from jax._src.third_party.scipy.interpolate import (
RegularGridInterpolator as RegularGridInterpolator)

View File

@ -0,0 +1,71 @@
# Copyright 2022 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from absl.testing import absltest, parameterized
import operator
from functools import reduce
import numpy as np
from jax._src import test_util as jtu
import scipy.interpolate as sp_interp
import jax.scipy.interpolate as jsp_interp
from jax.config import config
config.parse_flags_with_absl()
class LaxBackedScipyInterpolateTests(jtu.JaxTestCase):
"""Tests for LAX-backed scipy.interpolate implementations"""
@parameterized.named_parameters(
jtu.cases_from_list({
"testcase_name": f"_spaces={spaces}_method={method}",
"spaces": spaces,
"method": method
}
for spaces in (((0., 10., 10),), ((-15., 20., 12),
(3., 4., 24)))
for method in ("linear", "nearest")))
def testRegularGridInterpolator(self, spaces, method):
rng = jtu.rand_default(self.rng())
scipy_fun = lambda init_args, call_args: sp_interp.RegularGridInterpolator(
*init_args[:2], method, False, *init_args[2:])(*call_args)
lax_fun = lambda init_args, call_args: jsp_interp.RegularGridInterpolator(
*init_args[:2], method, False, *init_args[2:])(*call_args)
def args_maker():
points = tuple(map(lambda x: np.linspace(*x), spaces))
values = rng(reduce(operator.add, tuple(map(np.shape, points))), float)
fill_value = np.nan
init_args = (points, values, fill_value)
n_validation_points = 50
valid_points = tuple(
map(
lambda x: np.linspace(x[0] - 0.2 * (x[1] - x[0]), x[1] + 0.2 *
(x[1] - x[0]), n_validation_points),
spaces))
valid_points = np.squeeze(np.stack(valid_points, axis=1))
call_args = (valid_points,)
return init_args, call_args
self._CheckAgainstNumpy(
scipy_fun, lax_fun, args_maker, check_dtypes=False, tol=1e-4)
self._CompileAndCheck(lax_fun, args_maker, rtol={np.float64: 1e-14})
if __name__ == "__main__":
absltest.main(testLoader=jtu.JaxTestLoader())