mirror of
https://github.com/ROCm/jax.git
synced 2025-04-14 19:06:07 +00:00
184 lines
6.9 KiB
Python
184 lines
6.9 KiB
Python
# Copyright 2019 The JAX Authors.
|
|
#
|
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
# you may not use this file except in compliance with the License.
|
|
# You may obtain a copy of the License at
|
|
#
|
|
# https://www.apache.org/licenses/LICENSE-2.0
|
|
#
|
|
# Unless required by applicable law or agreed to in writing, software
|
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
# See the License for the specific language governing permissions and
|
|
# limitations under the License.
|
|
|
|
from collections.abc import Callable, Sequence
|
|
import functools
|
|
import itertools
|
|
import operator
|
|
|
|
from jax._src import api
|
|
from jax._src import util
|
|
from jax import lax
|
|
import jax.numpy as jnp
|
|
from jax._src.typing import ArrayLike, Array
|
|
from jax._src.util import safe_zip as zip
|
|
|
|
|
|
def _nonempty_prod(arrs: Sequence[Array]) -> Array:
  """Multiply together all elements of a non-empty sequence.

  No initial value is supplied to the fold, so the result keeps the type and
  dtype produced by the elements themselves; an empty sequence raises
  ``TypeError`` (from ``functools.reduce``).
  """
  return functools.reduce(lambda acc, term: acc * term, arrs)
|
|
|
|
def _nonempty_sum(arrs: Sequence[Array]) -> Array:
  """Add together all elements of a non-empty sequence.

  No initial value is supplied to the fold, so the result keeps the type and
  dtype produced by the elements themselves; an empty sequence raises
  ``TypeError`` (from ``functools.reduce``).
  """
  return functools.reduce(lambda acc, term: acc + term, arrs)
|
|
|
|
def _mirror_index_fixer(index: Array, size: int) -> Array:
  """Fold arbitrary integer indices into ``[0, size - 1]`` by mirroring.

  The mapping reflects about the centre of the first and last elements, so
  the edge elements are not repeated (for ``size == 4`` the sequence of
  fixed indices is ``... 2 1 0 1 2 3 2 1 0 ...``).

  Args:
    index: integer array of (possibly out-of-range) indices.
    size: static length of the axis being indexed.

  Returns:
    An array with the same shape as ``index`` holding the mirrored indices.
  """
  if size == 1:
    # A one-element axis has a zero-length period; the general formula below
    # would compute ``% 0``. Every index maps onto the single element.
    return jnp.zeros_like(index)
  s = size - 1  # Half-wavelength of triangular wave
  # Scaled, integer-valued version of the triangular wave |x - round(x)|
  return jnp.abs((index + s) % (2 * s) - s)
|
|
|
|
def _reflect_index_fixer(index: Array, size: int) -> Array:
  """Fold arbitrary integer indices into range with edge-repeating symmetry.

  Implemented by reusing ``_mirror_index_fixer`` on a doubled grid: index
  ``i`` is sent to ``2*i + 1`` on an axis of length ``2*size + 1``, mirrored
  there, and mapped back with ``(m - 1) // 2``.

  Args:
    index: integer array of (possibly out-of-range) indices.
    size: static length of the axis being indexed.
  """
  doubled_index = 2 * index + 1
  doubled_size = 2 * size + 1
  mirrored = _mirror_index_fixer(doubled_index, doubled_size)
  return jnp.floor_divide(mirrored - 1, 2)
|
|
|
|
# Boundary handling table: maps each supported ``mode`` name to a function
# ``(index, size) -> index`` that turns a raw integer index along an axis of
# length ``size`` into the index that is actually read from the input.
_INDEX_FIXERS: dict[str, Callable[[Array, int], Array]] = {
    # Indices are left untouched; out-of-range reads are masked by the caller.
    'constant': lambda idx, n: idx,
    # Clamp to the first / last element.
    'nearest': lambda idx, n: jnp.clip(idx, 0, n - 1),
    # Periodic extension with period ``n``.
    'wrap': lambda idx, n: idx % n,
    'mirror': _mirror_index_fixer,
    'reflect': _reflect_index_fixer,
}
|
|
|
|
|
|
def _round_half_away_from_zero(a: Array) -> Array:
  """Round ``a`` with ``lax.round``; integer arrays are returned as-is.

  Integer-typed inputs need no rounding, so the same array object is handed
  back without going through ``lax.round``.
  """
  if jnp.issubdtype(a.dtype, jnp.integer):
    return a
  return lax.round(a)
|
|
|
|
|
|
def _nearest_indices_and_weights(coordinate: Array) -> list[tuple[Array, ArrayLike]]:
  """Return the single interpolation node used for ``order=0``.

  Args:
    coordinate: array of coordinates along one axis.

  Returns:
    A one-element list ``[(index, weight)]`` where ``index`` is the rounded
    coordinate as ``int32`` and ``weight`` is the scalar ``1`` in the
    coordinate's dtype.
  """
  nearest = _round_half_away_from_zero(coordinate).astype(jnp.int32)
  unit_weight = coordinate.dtype.type(1)
  return [(nearest, unit_weight)]
|
|
|
|
|
|
def _linear_indices_and_weights(coordinate: Array) -> list[tuple[Array, ArrayLike]]:
  """Return the two interpolation nodes used for ``order=1``.

  Args:
    coordinate: array of coordinates along one axis.

  Returns:
    ``[(lower, 1 - frac), (lower + 1, frac)]`` where ``lower`` is
    ``floor(coordinate)`` as ``int32`` and ``frac`` is the fractional
    distance of the coordinate above ``lower``.
  """
  floored = jnp.floor(coordinate)
  frac = coordinate - floored
  base = floored.astype(jnp.int32)
  return [(base, 1 - frac), (base + 1, frac)]
|
|
|
|
|
|
@functools.partial(api.jit, static_argnums=(2, 3, 4))
def _map_coordinates(input: ArrayLike, coordinates: Sequence[ArrayLike],
                     order: int, mode: str, cval: ArrayLike) -> Array:
  """Jitted implementation of ``map_coordinates``.

  ``order``, ``mode`` and ``cval`` are static arguments of the jitted
  function (positions 2, 3 and 4).

  Args:
    input: array to sample from.
    coordinates: one coordinate array per axis of ``input``.
    order: ``0`` for nearest-neighbor, ``1`` for linear interpolation.
    mode: boundary mode; must be a key of ``_INDEX_FIXERS``.
    cval: fill value used for out-of-range reads when ``mode='constant'``;
      it is converted to the dtype of ``input``.

  Returns:
    The interpolated values, cast to the dtype of ``input`` (integer inputs
    are rounded before the cast).

  Raises:
    ValueError: if ``len(coordinates) != input.ndim``.
    NotImplementedError: if ``mode`` is unsupported or ``order`` is neither
      0 nor 1.
  """
  input_arr = jnp.asarray(input)
  coordinate_arrs = [jnp.asarray(c) for c in coordinates]
  cval = jnp.asarray(cval, input_arr.dtype)

  if len(coordinates) != input_arr.ndim:
    raise ValueError('coordinates must be a sequence of length input.ndim, but '
                     '{} != {}'.format(len(coordinates), input_arr.ndim))

  if mode not in _INDEX_FIXERS:
    raise NotImplementedError(
        'jax.scipy.ndimage.map_coordinates does not yet support mode {}. '
        'Currently supported modes are {}.'.format(mode, set(_INDEX_FIXERS)))
  index_fixer = _INDEX_FIXERS[mode]

  if mode == 'constant':
    # Only 'constant' needs a mask: reads outside the array become ``cval``.
    def in_bounds(idx, dim):
      return (0 <= idx) & (idx < dim)
  else:
    # The literal ``True`` (not an array) lets the gather loop below detect
    # that no masking is required.
    def in_bounds(idx, dim):
      return True

  if order not in (0, 1):
    raise NotImplementedError(
        'jax.scipy.ndimage.map_coordinates currently requires order<=1')
  node_fun = (_linear_indices_and_weights if order == 1
              else _nearest_indices_and_weights)

  # For every axis: list of (fixed index, validity, weight) interpolation
  # nodes.
  per_axis_nodes = [
      [(index_fixer(idx, dim), in_bounds(idx, dim), wgt)
       for idx, wgt in node_fun(coord)]
      for coord, dim in zip(coordinate_arrs, input_arr.shape)
  ]

  # Sum the weighted contributions over the Cartesian product of the
  # per-axis nodes.
  terms = []
  for combo in itertools.product(*per_axis_nodes):
    indices, validities, weights = util.unzip3(combo)
    gathered = input_arr[indices]
    if not all(flag is True for flag in validities):
      # At least one axis carries a real mask: replace invalid reads by cval.
      in_range = functools.reduce(operator.and_, validities)
      gathered = jnp.where(in_range, gathered, cval)
    terms.append(_nonempty_prod(weights) * gathered)  # type: ignore
  total = _nonempty_sum(terms)
  if jnp.issubdtype(input_arr.dtype, jnp.integer):
    total = _round_half_away_from_zero(total)
  return total.astype(input_arr.dtype)
|
|
|
|
|
|
"""
|
|
Only nearest neighbor (``order=0``) and linear interpolation (``order=1``),
with modes ``'constant'``, ``'nearest'``, ``'wrap'``, ``'mirror'`` and
``'reflect'``, are currently supported.
|
|
|
|
"""
|
|
|
|
def map_coordinates(
    input: ArrayLike, coordinates: Sequence[ArrayLike], order: int,
    mode: str = 'constant', cval: ArrayLike = 0.0,
) -> Array:
  """
  Map the input array to new coordinates using interpolation.

  JAX implementation of :func:`scipy.ndimage.map_coordinates`

  Given an input array and a set of coordinates, this function returns the
  interpolated values of the input array at those coordinates.

  Args:
    input: N-dimensional input array from which values are interpolated.
    coordinates: length-N sequence of arrays specifying the coordinates
      at which to evaluate the interpolated values, one array per axis of
      ``input``.
    order: The order of interpolation. JAX supports the following:

      * 0: Nearest-neighbor
      * 1: Linear

    mode: Points outside the boundaries of the input are filled according to the given mode.
      JAX supports one of ``('constant', 'nearest', 'mirror', 'wrap', 'reflect')``. Note the
      ``'wrap'`` mode in JAX behaves as ``'grid-wrap'`` mode in SciPy, and ``'constant'``
      mode in JAX behaves as ``'grid-constant'`` mode in SciPy. This discrepancy was caused
      by a former bug in those modes in SciPy (https://github.com/scipy/scipy/issues/2640),
      which was first fixed in JAX by changing the behavior of the existing modes, and later
      on fixed in SciPy, by adding modes with new names, rather than fixing the existing
      ones, for backwards compatibility reasons. Default is 'constant'.
    cval: Value used for points outside the boundaries of the input if ``mode='constant'``.
      Default is 0.0.

  Returns:
    The interpolated values at the specified coordinates, cast to the dtype
    of ``input``. For integer inputs the interpolated result is rounded
    before the cast.

  Raises:
    ValueError: if ``len(coordinates)`` differs from ``input.ndim``.
    NotImplementedError: if ``mode`` is not one of the supported modes, or
      ``order`` is neither 0 nor 1.

  Examples:
    >>> input = jnp.arange(12.0).reshape(3, 4)
    >>> input
    Array([[ 0.,  1.,  2.,  3.],
           [ 4.,  5.,  6.,  7.],
           [ 8.,  9., 10., 11.]], dtype=float32)
    >>> coordinates = [jnp.array([0.5, 1.5]),
    ...                jnp.array([1.5, 2.5])]
    >>> jax.scipy.ndimage.map_coordinates(input, coordinates, order=1)
    Array([3.5, 8.5], dtype=float32)

  Note:
    Interpolation near boundaries differs from the scipy function, because JAX
    fixed an outstanding bug; see https://github.com/jax-ml/jax/issues/11097.
    This function interprets the ``mode`` argument as documented by SciPy, but
    not as implemented by SciPy.

    ``order``, ``mode`` and ``cval`` are forwarded positionally because the
    jitted implementation declares them as static arguments by position.
  """
  return _map_coordinates(input, coordinates, order, mode, cval)
|