rocm_jax/jax/_src/scipy/ndimage.py
2024-09-20 07:52:33 -07:00

184 lines
6.9 KiB
Python

# Copyright 2019 The JAX Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from collections.abc import Callable, Sequence
import functools
import itertools
import operator
from jax._src import api
from jax._src import util
from jax import lax
import jax.numpy as jnp
from jax._src.typing import ArrayLike, Array
from jax._src.util import safe_zip as zip
def _nonempty_prod(arrs: Sequence[Array]) -> Array:
return functools.reduce(operator.mul, arrs)
def _nonempty_sum(arrs: Sequence[Array]) -> Array:
return functools.reduce(operator.add, arrs)
def _mirror_index_fixer(index: Array, size: int) -> Array:
s = size - 1 # Half-wavelength of triangular wave
# Scaled, integer-valued version of the triangular wave |x - round(x)|
return jnp.abs((index + s) % (2 * s) - s)
def _reflect_index_fixer(index: Array, size: int) -> Array:
  """Fold integer indices back into range for ``'reflect'`` mode.

  The index is moved onto a grid of twice the resolution (``2 * index + 1``
  on an axis of length ``2 * size + 1``), mirrored there with
  ``_mirror_index_fixer``, and then mapped back to the original grid.

  Args:
    index: integer indices, possibly out of range.
    size: length of the axis being indexed.

  Returns:
    Indices of the same shape as ``index`` after reflection.
  """
  doubled_index = 2 * index + 1
  doubled_size = 2 * size + 1
  mirrored = _mirror_index_fixer(doubled_index, doubled_size)
  return jnp.floor_divide(mirrored - 1, 2)
def _constant_index_fixer(index: Array, size: int) -> Array:
  """Leave the index untouched.

  Out-of-range positions are not remapped for ``'constant'`` mode; they are
  replaced by ``cval`` through the validity mask in ``_map_coordinates``.
  """
  return index


def _nearest_index_fixer(index: Array, size: int) -> Array:
  """Clamp the index to the closest in-range position."""
  return jnp.clip(index, 0, size - 1)


def _wrap_index_fixer(index: Array, size: int) -> Array:
  """Wrap the index around periodically with period ``size``."""
  return index % size


# Maps each supported boundary ``mode`` to a function
# ``(index, size) -> index`` that turns a possibly out-of-range integer index
# into the index that is actually read from the input.
_INDEX_FIXERS: dict[str, Callable[[Array, int], Array]] = {
    'constant': _constant_index_fixer,
    'nearest': _nearest_index_fixer,
    'wrap': _wrap_index_fixer,
    'mirror': _mirror_index_fixer,
    'reflect': _reflect_index_fixer,
}
def _round_half_away_from_zero(a: Array) -> Array:
return a if jnp.issubdtype(a.dtype, jnp.integer) else lax.round(a)
def _nearest_indices_and_weights(coordinate: Array) -> list[tuple[Array, ArrayLike]]:
  """Return the interpolation nodes for nearest-neighbor (order 0) sampling.

  Args:
    coordinate: coordinates along one axis.

  Returns:
    A one-element list ``[(index, weight)]``: ``index`` is the rounded
    coordinate as ``int32`` and ``weight`` is the scalar 1 in the
    coordinate's dtype.
  """
  rounded = _round_half_away_from_zero(coordinate)
  unit_weight = coordinate.dtype.type(1)
  return [(rounded.astype(jnp.int32), unit_weight)]
def _linear_indices_and_weights(coordinate: Array) -> list[tuple[Array, ArrayLike]]:
lower = jnp.floor(coordinate)
upper_weight = coordinate - lower
lower_weight = 1 - upper_weight
index = lower.astype(jnp.int32)
return [(index, lower_weight), (index + 1, upper_weight)]
@functools.partial(api.jit, static_argnums=(2, 3, 4))
def _map_coordinates(input: ArrayLike, coordinates: Sequence[ArrayLike],
                     order: int, mode: str, cval: ArrayLike) -> Array:
  """Jitted implementation behind ``map_coordinates``.

  ``order``, ``mode`` and ``cval`` are static arguments of the jitted
  function (see ``static_argnums``).

  Args:
    input: array to sample from.
    coordinates: one coordinate array per axis of ``input``.
    order: interpolation order; 0 (nearest) or 1 (linear).
    mode: boundary mode; must be a key of ``_INDEX_FIXERS``.
    cval: fill value for out-of-bounds samples when ``mode == 'constant'``.

  Returns:
    The interpolated values, cast to the dtype of ``input``. For integer
    inputs the interpolated result is rounded before the cast.

  Raises:
    ValueError: if ``len(coordinates) != input.ndim``.
    NotImplementedError: if ``mode`` or ``order`` is not supported.
  """
  input_arr = jnp.asarray(input)
  coordinate_arrs = [jnp.asarray(c) for c in coordinates]
  cval = jnp.asarray(cval, input_arr.dtype)

  if len(coordinates) != input_arr.ndim:
    raise ValueError('coordinates must be a sequence of length input.ndim, but '
                     '{} != {}'.format(len(coordinates), input_arr.ndim))

  if mode not in _INDEX_FIXERS:
    raise NotImplementedError(
        'jax.scipy.ndimage.map_coordinates does not yet support mode {}. '
        'Currently supported modes are {}.'.format(mode, set(_INDEX_FIXERS)))
  index_fixer = _INDEX_FIXERS[mode]

  # Only 'constant' mode needs a per-sample mask. Every other mode remaps
  # indices into range, which is signalled by the literal ``True`` so that the
  # unmasked fast path below can be taken.
  masks_out_of_bounds = mode == 'constant'

  def is_valid(index, size):
    if not masks_out_of_bounds:
      return True
    return (0 <= index) & (index < size)

  if order not in (0, 1):
    raise NotImplementedError(
        'jax.scipy.ndimage.map_coordinates currently requires order<=1')
  interp_fun = (_nearest_indices_and_weights if order == 0
                else _linear_indices_and_weights)

  # For each axis: the list of (fixed_index, validity, weight) nodes that
  # contribute to the 1-D interpolation along that axis.
  per_axis_nodes = [
      [(index_fixer(node_index, size), is_valid(node_index, size), node_weight)
       for node_index, node_weight in interp_fun(coordinate)]
      for coordinate, size in zip(coordinate_arrs, input_arr.shape)
  ]

  # The N-D interpolation is the sum, over every combination of one node per
  # axis, of the product of the per-axis weights times the sampled value.
  terms = []
  for combo in itertools.product(*per_axis_nodes):
    indices, validities, weights = util.unzip3(combo)
    sampled = input_arr[indices]
    if not all(valid is True for valid in validities):
      # At least one axis carries a real mask: use cval wherever any axis
      # is out of bounds.
      in_bounds = functools.reduce(operator.and_, validities)
      sampled = jnp.where(in_bounds, sampled, cval)
    terms.append(_nonempty_prod(weights) * sampled)  # type: ignore

  result = _nonempty_sum(terms)
  if jnp.issubdtype(input_arr.dtype, jnp.integer):
    result = _round_half_away_from_zero(result)
  return result.astype(input_arr.dtype)
"""
Only nearest-neighbor (``order=0``) and linear (``order=1``) interpolation, and the
modes ``'constant'``, ``'nearest'``, ``'wrap'``, ``'mirror'`` and ``'reflect'``, are currently supported.
"""
def map_coordinates(
    input: ArrayLike, coordinates: Sequence[ArrayLike], order: int,
    mode: str = 'constant', cval: ArrayLike = 0.0,
):
  """Sample an array at arbitrary coordinates using interpolation.

  JAX implementation of :func:`scipy.ndimage.map_coordinates`.

  For every position described by ``coordinates``, the value of ``input`` at
  that (possibly fractional) position is interpolated and returned.

  Args:
    input: N-dimensional array from which values are interpolated.
    coordinates: length-N sequence of arrays, one per axis of ``input``,
      giving the coordinates at which to evaluate the interpolated values.
    order: The order of interpolation. JAX supports the following:

      * 0: Nearest-neighbor
      * 1: Linear

    mode: How points outside the boundaries of the input are filled. JAX
      supports one of ``('constant', 'nearest', 'mirror', 'wrap', 'reflect')``.
      Default is 'constant'. Note that ``'wrap'`` in JAX behaves as
      ``'grid-wrap'`` in SciPy, and ``'constant'`` in JAX behaves as
      ``'grid-constant'`` in SciPy. The discrepancy stems from a former bug in
      those SciPy modes (https://github.com/scipy/scipy/issues/2640): JAX
      fixed it first by changing the behavior of the existing modes, whereas
      SciPy later fixed it by adding modes with new names and kept the
      existing ones unchanged for backwards compatibility.
    cval: Value used for points outside the boundaries of the input if
      ``mode='constant'``. Default is 0.0.

  Returns:
    The interpolated values at the specified coordinates.

  Raises:
    ValueError: if ``len(coordinates)`` differs from ``input.ndim``.
    NotImplementedError: if ``mode`` is not one of the supported modes or
      ``order`` is not 0 or 1.

  Examples:
    >>> input = jnp.arange(12.0).reshape(3, 4)
    >>> input
    Array([[ 0.,  1.,  2.,  3.],
           [ 4.,  5.,  6.,  7.],
           [ 8.,  9., 10., 11.]], dtype=float32)
    >>> coordinates = [jnp.array([0.5, 1.5]),
    ...                jnp.array([1.5, 2.5])]
    >>> jax.scipy.ndimage.map_coordinates(input, coordinates, order=1)
    Array([3.5, 8.5], dtype=float32)

  Note:
    Interpolation near boundaries differs from the SciPy function, because JAX
    fixed an outstanding bug; see https://github.com/jax-ml/jax/issues/11097.
    This function interprets the ``mode`` argument as documented by SciPy, but
    not as implemented by SciPy.
  """
  return _map_coordinates(input, coordinates, order, mode, cval)