mirror of
https://github.com/ROCm/jax.git
synced 2025-04-18 04:46:06 +00:00
Merge pull request #8625 from Edenhofer:regular_grid_interpolator
PiperOrigin-RevId: 423876494
This commit is contained in:
commit
16c809ce7f
13
jax/_src/scipy/interpolate/__init__.py
Normal file
13
jax/_src/scipy/interpolate/__init__.py
Normal file
@ -0,0 +1,13 @@
|
||||
# Copyright 2022 Google LLC
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# https://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
30
jax/_src/third_party/scipy/LICENSE.txt
vendored
Normal file
30
jax/_src/third_party/scipy/LICENSE.txt
vendored
Normal file
@ -0,0 +1,30 @@
|
||||
Copyright (c) 2001-2002 Enthought, Inc. 2003-2019, SciPy Developers.
|
||||
All rights reserved.
|
||||
|
||||
Redistribution and use in source and binary forms, with or without
|
||||
modification, are permitted provided that the following conditions
|
||||
are met:
|
||||
|
||||
1. Redistributions of source code must retain the above copyright
|
||||
notice, this list of conditions and the following disclaimer.
|
||||
|
||||
2. Redistributions in binary form must reproduce the above
|
||||
copyright notice, this list of conditions and the following
|
||||
disclaimer in the documentation and/or other materials provided
|
||||
with the distribution.
|
||||
|
||||
3. Neither the name of the copyright holder nor the names of its
|
||||
contributors may be used to endorse or promote products derived
|
||||
from this software without specific prior written permission.
|
||||
|
||||
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
|
||||
"AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
|
||||
LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
|
||||
A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
|
||||
OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
|
||||
SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
|
||||
LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
|
||||
DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
|
||||
THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
|
||||
(INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
0
jax/_src/third_party/scipy/__init__.py
vendored
Normal file
0
jax/_src/third_party/scipy/__init__.py
vendored
Normal file
158
jax/_src/third_party/scipy/interpolate.py
vendored
Normal file
158
jax/_src/third_party/scipy/interpolate.py
vendored
Normal file
@ -0,0 +1,158 @@
|
||||
from itertools import product
|
||||
import scipy.interpolate as osp_interpolate
|
||||
|
||||
from jax._src.tree_util import register_pytree_node
|
||||
from jax._src.numpy.lax_numpy import (_check_arraylike, _promote_dtypes_inexact,
|
||||
asarray, broadcast_arrays, can_cast,
|
||||
empty, nan, searchsorted, where, zeros)
|
||||
from jax._src.numpy.util import _wraps
|
||||
|
||||
|
||||
def _ndim_coords_from_arrays(points, ndim=None):
|
||||
"""Convert a tuple of coordinate arrays to a (..., ndim)-shaped array."""
|
||||
if isinstance(points, tuple) and len(points) == 1:
|
||||
# handle argument tuple
|
||||
points = points[0]
|
||||
if isinstance(points, tuple):
|
||||
p = broadcast_arrays(*points)
|
||||
for p_other in p[1:]:
|
||||
if p_other.shape != p[0].shape:
|
||||
raise ValueError("coordinate arrays do not have the same shape")
|
||||
points = empty(p[0].shape + (len(points),), dtype=float)
|
||||
for j, item in enumerate(p):
|
||||
points = points.at[..., j].set(item)
|
||||
else:
|
||||
_check_arraylike("_ndim_coords_from_arrays", points)
|
||||
points = asarray(points) # SciPy: asanyarray(points)
|
||||
if points.ndim == 1:
|
||||
if ndim is None:
|
||||
points = points.reshape(-1, 1)
|
||||
else:
|
||||
points = points.reshape(-1, ndim)
|
||||
return points
|
||||
|
||||
|
||||
@_wraps(
|
||||
osp_interpolate.RegularGridInterpolator,
|
||||
lax_description="""
|
||||
In the JAX version, `bounds_error` defaults to and must always be `False` since no
|
||||
bound error may be raised under JIT.
|
||||
|
||||
Furthermore, in contrast to SciPy no input validation is performed.
|
||||
""")
|
||||
class RegularGridInterpolator:
|
||||
# Based on SciPy's implementation which in turn is originally based on an
|
||||
# implementation by Johannes Buchner
|
||||
|
||||
def __init__(self,
|
||||
points,
|
||||
values,
|
||||
method="linear",
|
||||
bounds_error=False,
|
||||
fill_value=nan):
|
||||
if method not in ("linear", "nearest"):
|
||||
raise ValueError(f"method {method!r} is not defined")
|
||||
self.method = method
|
||||
self.bounds_error = bounds_error
|
||||
if self.bounds_error:
|
||||
raise NotImplementedError("`bounds_error` takes no effect under JIT")
|
||||
|
||||
_check_arraylike("RegularGridInterpolator", values)
|
||||
if len(points) > values.ndim:
|
||||
ve = f"there are {len(points)} point arrays, but values has {values.ndim} dimensions"
|
||||
raise ValueError(ve)
|
||||
|
||||
values, = _promote_dtypes_inexact(values)
|
||||
|
||||
if fill_value is not None:
|
||||
_check_arraylike("RegularGridInterpolator", fill_value)
|
||||
fill_value = asarray(fill_value)
|
||||
if not can_cast(fill_value.dtype, values.dtype, casting='same_kind'):
|
||||
ve = "fill_value must be either 'None' or of a type compatible with values"
|
||||
raise ValueError(ve)
|
||||
self.fill_value = fill_value
|
||||
|
||||
# TODO: assert sanity of `points` similar to SciPy but in a JIT-able way
|
||||
_check_arraylike("RegularGridInterpolator", *points)
|
||||
self.grid = tuple(asarray(p) for p in points)
|
||||
self.values = values
|
||||
|
||||
@_wraps(osp_interpolate.RegularGridInterpolator.__call__, update_doc=False)
|
||||
def __call__(self, xi, method=None):
|
||||
method = self.method if method is None else method
|
||||
if method not in ("linear", "nearest"):
|
||||
raise ValueError(f"method {method!r} is not defined")
|
||||
|
||||
ndim = len(self.grid)
|
||||
xi = _ndim_coords_from_arrays(xi, ndim=ndim)
|
||||
if xi.shape[-1] != len(self.grid):
|
||||
raise ValueError("the requested sample points xi have dimension"
|
||||
f" {xi.shape[1]}, but this RegularGridInterpolator has"
|
||||
f" dimension {ndim}")
|
||||
|
||||
xi_shape = xi.shape
|
||||
xi = xi.reshape(-1, xi_shape[-1])
|
||||
|
||||
indices, norm_distances, out_of_bounds = self._find_indices(xi.T)
|
||||
if method == "linear":
|
||||
result = self._evaluate_linear(indices, norm_distances)
|
||||
elif method == "nearest":
|
||||
result = self._evaluate_nearest(indices, norm_distances)
|
||||
else:
|
||||
raise AssertionError("method must be bound")
|
||||
if not self.bounds_error and self.fill_value is not None:
|
||||
bc_shp = result.shape[:1] + (1,) * (result.ndim - 1)
|
||||
result = where(out_of_bounds.reshape(bc_shp), self.fill_value, result)
|
||||
|
||||
return result.reshape(xi_shape[:-1] + self.values.shape[ndim:])
|
||||
|
||||
def _evaluate_linear(self, indices, norm_distances):
|
||||
# slice for broadcasting over trailing dimensions in self.values
|
||||
vslice = (slice(None),) + (None,) * (self.values.ndim - len(indices))
|
||||
|
||||
# find relevant values
|
||||
# each i and i+1 represents a edge
|
||||
edges = product(*[[i, i + 1] for i in indices])
|
||||
values = asarray(0.)
|
||||
for edge_indices in edges:
|
||||
weight = asarray(1.)
|
||||
for ei, i, yi in zip(edge_indices, indices, norm_distances):
|
||||
weight *= where(ei == i, 1 - yi, yi)
|
||||
values += self.values[edge_indices] * weight[vslice]
|
||||
return values
|
||||
|
||||
def _evaluate_nearest(self, indices, norm_distances):
|
||||
idx_res = [
|
||||
where(yi <= .5, i, i + 1) for i, yi in zip(indices, norm_distances)
|
||||
]
|
||||
return self.values[tuple(idx_res)]
|
||||
|
||||
def _find_indices(self, xi):
|
||||
# find relevant edges between which xi are situated
|
||||
indices = []
|
||||
# compute distance to lower edge in unity units
|
||||
norm_distances = []
|
||||
# check for out of bounds xi
|
||||
out_of_bounds = zeros((xi.shape[1],), dtype=bool)
|
||||
# iterate through dimensions
|
||||
for x, g in zip(xi, self.grid):
|
||||
i = searchsorted(g, x) - 1
|
||||
i = where(i < 0, 0, i)
|
||||
i = where(i > g.size - 2, g.size - 2, i)
|
||||
indices.append(i)
|
||||
norm_distances.append((x - g[i]) / (g[i + 1] - g[i]))
|
||||
if not self.bounds_error:
|
||||
out_of_bounds += x < g[0]
|
||||
out_of_bounds += x > g[-1]
|
||||
return indices, norm_distances, out_of_bounds
|
||||
|
||||
|
||||
register_pytree_node(
|
||||
RegularGridInterpolator,
|
||||
lambda obj: ((obj.grid, obj.values, obj.fill_value),
|
||||
(obj.method, obj.bounds_error)),
|
||||
lambda aux, children: RegularGridInterpolator(
|
||||
*children[:2], # type: ignore[index]
|
||||
*aux,
|
||||
*children[2:]), # type: ignore[index]
|
||||
)
|
@ -13,6 +13,7 @@
|
||||
# limitations under the License.
|
||||
|
||||
# flake8: noqa: F401
|
||||
from jax.scipy import interpolate as interpolate
|
||||
from jax.scipy import linalg as linalg
|
||||
from jax.scipy import ndimage as ndimage
|
||||
from jax.scipy import signal as signal
|
||||
|
20
jax/scipy/interpolate/__init__.py
Normal file
20
jax/scipy/interpolate/__init__.py
Normal file
@ -0,0 +1,20 @@
|
||||
# Copyright 2022 Google LLC
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# https://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
# flake8: noqa: F401
|
||||
|
||||
# Already deprecate namespaces that will be removed in SciPy v2.0.0
|
||||
|
||||
from jax._src.third_party.scipy.interpolate import (
|
||||
RegularGridInterpolator as RegularGridInterpolator)
|
71
tests/scipy_interpolate_test.py
Normal file
71
tests/scipy_interpolate_test.py
Normal file
@ -0,0 +1,71 @@
|
||||
# Copyright 2022 Google LLC
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# https://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from absl.testing import absltest, parameterized
|
||||
|
||||
import operator
|
||||
from functools import reduce
|
||||
import numpy as np
|
||||
|
||||
from jax._src import test_util as jtu
|
||||
import scipy.interpolate as sp_interp
|
||||
import jax.scipy.interpolate as jsp_interp
|
||||
|
||||
from jax.config import config
|
||||
|
||||
config.parse_flags_with_absl()
|
||||
|
||||
|
||||
class LaxBackedScipyInterpolateTests(jtu.JaxTestCase):
|
||||
"""Tests for LAX-backed scipy.interpolate implementations"""
|
||||
|
||||
@parameterized.named_parameters(
|
||||
jtu.cases_from_list({
|
||||
"testcase_name": f"_spaces={spaces}_method={method}",
|
||||
"spaces": spaces,
|
||||
"method": method
|
||||
}
|
||||
for spaces in (((0., 10., 10),), ((-15., 20., 12),
|
||||
(3., 4., 24)))
|
||||
for method in ("linear", "nearest")))
|
||||
def testRegularGridInterpolator(self, spaces, method):
|
||||
rng = jtu.rand_default(self.rng())
|
||||
scipy_fun = lambda init_args, call_args: sp_interp.RegularGridInterpolator(
|
||||
*init_args[:2], method, False, *init_args[2:])(*call_args)
|
||||
lax_fun = lambda init_args, call_args: jsp_interp.RegularGridInterpolator(
|
||||
*init_args[:2], method, False, *init_args[2:])(*call_args)
|
||||
|
||||
def args_maker():
|
||||
points = tuple(map(lambda x: np.linspace(*x), spaces))
|
||||
values = rng(reduce(operator.add, tuple(map(np.shape, points))), float)
|
||||
fill_value = np.nan
|
||||
|
||||
init_args = (points, values, fill_value)
|
||||
n_validation_points = 50
|
||||
valid_points = tuple(
|
||||
map(
|
||||
lambda x: np.linspace(x[0] - 0.2 * (x[1] - x[0]), x[1] + 0.2 *
|
||||
(x[1] - x[0]), n_validation_points),
|
||||
spaces))
|
||||
valid_points = np.squeeze(np.stack(valid_points, axis=1))
|
||||
call_args = (valid_points,)
|
||||
return init_args, call_args
|
||||
|
||||
self._CheckAgainstNumpy(
|
||||
scipy_fun, lax_fun, args_maker, check_dtypes=False, tol=1e-4)
|
||||
self._CompileAndCheck(lax_fun, args_maker, rtol={np.float64: 1e-14})
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
absltest.main(testLoader=jtu.JaxTestLoader())
|
Loading…
x
Reference in New Issue
Block a user