2020-07-10 09:57:59 -04:00
|
|
|
|
# Copyright 2020 Google LLC
|
|
|
|
|
#
|
|
|
|
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
|
|
|
# you may not use this file except in compliance with the License.
|
|
|
|
|
# You may obtain a copy of the License at
|
|
|
|
|
#
|
|
|
|
|
# https://www.apache.org/licenses/LICENSE-2.0
|
|
|
|
|
#
|
|
|
|
|
# Unless required by applicable law or agreed to in writing, software
|
|
|
|
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
|
|
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
|
|
|
# See the License for the specific language governing permissions and
|
|
|
|
|
# limitations under the License.
|
|
|
|
|
|
2020-07-10 10:32:13 -04:00
|
|
|
|
from functools import partial
|
2020-07-10 09:57:59 -04:00
|
|
|
|
import enum
|
2020-09-08 13:51:19 -07:00
|
|
|
|
from typing import Callable, Sequence, Union
|
2020-07-10 09:57:59 -04:00
|
|
|
|
|
2021-09-30 12:36:47 -07:00
|
|
|
|
from jax import core
|
2020-07-10 10:32:13 -04:00
|
|
|
|
from jax import jit
|
2020-07-10 09:57:59 -04:00
|
|
|
|
from jax import lax
|
|
|
|
|
from jax import numpy as jnp
|
|
|
|
|
import numpy as np
|
|
|
|
|
|
|
|
|
|
|
2020-09-08 13:51:19 -07:00
|
|
|
|
def _fill_lanczos_kernel(radius, x):
|
|
|
|
|
y = radius * jnp.sin(np.pi * x) * jnp.sin(np.pi * x / radius)
|
2021-05-20 21:49:41 -06:00
|
|
|
|
# out = y / (np.pi ** 2 * x ** 2) where x >1e-3, 1 otherwise
|
|
|
|
|
out = jnp.where(x > 1e-3, jnp.divide(y, jnp.where(x != 0, np.pi**2 * x**2, 1)), 1)
|
2020-09-08 13:51:19 -07:00
|
|
|
|
return jnp.where(x > radius, 0., out)
|
2020-07-10 09:57:59 -04:00
|
|
|
|
|
|
|
|
|
|
2020-09-08 13:51:19 -07:00
|
|
|
|
def _fill_keys_cubic_kernel(x):
|
2020-07-10 09:57:59 -04:00
|
|
|
|
# http://ieeexplore.ieee.org/document/1163711/
|
|
|
|
|
# R. G. Keys. Cubic convolution interpolation for digital image processing.
|
|
|
|
|
# IEEE Transactions on Acoustics, Speech, and Signal Processing,
|
|
|
|
|
# 29(6):1153–1160, 1981.
|
2020-09-08 13:51:19 -07:00
|
|
|
|
out = ((1.5 * x - 2.5) * x) * x + 1.
|
2021-07-29 09:51:41 -04:00
|
|
|
|
out = jnp.where(x >= 1., ((-0.5 * x + 2.5) * x - 4.) * x + 2., out)
|
2020-09-08 13:51:19 -07:00
|
|
|
|
return jnp.where(x >= 2., 0., out)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def _fill_triangle_kernel(x):
|
|
|
|
|
return jnp.maximum(0, 1 - jnp.abs(x))
|
|
|
|
|
|
|
|
|
|
|
2021-11-05 17:03:46 +02:00
|
|
|
|
def compute_weight_mat(input_size: core.DimSize,
|
|
|
|
|
output_size: core.DimSize, scale,
|
2020-09-08 13:51:19 -07:00
|
|
|
|
translation,
|
|
|
|
|
kernel: Callable,
|
|
|
|
|
antialias: bool):
|
2020-07-10 09:57:59 -04:00
|
|
|
|
inv_scale = 1. / scale
|
|
|
|
|
# When downsampling the kernel should be scaled since we want to low pass
|
|
|
|
|
# filter and interpolate, but when upsampling it should not be since we only
|
|
|
|
|
# want to interpolate.
|
2020-09-21 16:20:17 -07:00
|
|
|
|
kernel_scale = jnp.maximum(inv_scale, 1.) if antialias else 1.
|
2020-07-10 09:57:59 -04:00
|
|
|
|
|
2020-12-08 13:03:30 -08:00
|
|
|
|
sample_f = ((jnp.arange(output_size) + 0.5) * inv_scale -
|
2020-09-08 13:51:19 -07:00
|
|
|
|
translation * inv_scale - 0.5)
|
|
|
|
|
x = (
|
|
|
|
|
jnp.abs(sample_f[jnp.newaxis, :] -
|
|
|
|
|
jnp.arange(input_size, dtype=sample_f.dtype)[:, jnp.newaxis]) /
|
|
|
|
|
kernel_scale)
|
|
|
|
|
weights = kernel(x)
|
|
|
|
|
|
|
|
|
|
total_weight_sum = jnp.sum(weights, axis=0, keepdims=True)
|
2021-05-20 21:49:41 -06:00
|
|
|
|
weights = jnp.where(
|
|
|
|
|
jnp.abs(total_weight_sum) > 1000. * np.finfo(np.float32).eps,
|
|
|
|
|
jnp.divide(weights, jnp.where(total_weight_sum != 0, total_weight_sum, 1)),
|
|
|
|
|
0)
|
2020-09-08 13:51:19 -07:00
|
|
|
|
# Zero out weights where the sample location is completely outside the input
|
|
|
|
|
# range.
|
|
|
|
|
# Note sample_f has already had the 0.5 removed, hence the weird range below.
|
2021-11-05 17:03:46 +02:00
|
|
|
|
input_size_minus_0_5 = core.dimension_as_value(input_size) - 0.5
|
2020-09-08 13:51:19 -07:00
|
|
|
|
return jnp.where(
|
|
|
|
|
jnp.logical_and(sample_f >= -0.5,
|
2021-11-05 17:03:46 +02:00
|
|
|
|
sample_f <= input_size_minus_0_5)[jnp.newaxis, :], weights, 0)
|
2020-07-10 09:57:59 -04:00
|
|
|
|
|
|
|
|
|
|
2021-11-05 17:03:46 +02:00
|
|
|
|
def _scale_and_translate(x, output_shape: core.Shape,
                         spatial_dims: Sequence[int], scale, translation,
                         kernel, antialias: bool, precision):
  """Resamples ``x`` to ``output_shape`` along the axes in ``spatial_dims``.

  One weight matrix is built per spatial axis with compute_weight_mat, and all
  of them are applied to ``x`` in a single einsum. Axes not listed in
  ``spatial_dims`` pass through unchanged.

  Args:
    x: the (floating-point) input array.
    output_shape: target shape, same length as ``x.shape``.
    spatial_dims: axes of ``x`` to resample.
    scale: per-axis scale factors, one per entry of ``spatial_dims``.
    translation: per-axis translations, one per entry of ``spatial_dims``.
    kernel: kernel callable forwarded to compute_weight_mat.
    antialias: forwarded to compute_weight_mat.
    precision: forwarded to ``jnp.einsum``.

  Returns:
    The resampled array; ``x`` itself if ``spatial_dims`` is empty.
  """
  input_shape = x.shape
  assert len(input_shape) == len(output_shape)
  assert len(spatial_dims) == len(scale)
  assert len(spatial_dims) == len(translation)
  if len(spatial_dims) == 0:
    return x

  ndim = len(output_shape)
  in_indices = list(range(ndim))
  out_indices = list(range(ndim))
  operands = []
  for k, dim in enumerate(spatial_dims):
    weight = compute_weight_mat(input_shape[dim], output_shape[dim],
                                scale[k], translation[k],
                                kernel, antialias).astype(x.dtype)
    # Axis `dim` of x is contracted with the first axis of its weight matrix.
    # Its place in the output is taken by a fresh einsum label, ndim + k.
    new_label = ndim + k
    operands.extend((weight, [dim, new_label]))
    out_indices[dim] = new_label
  return jnp.einsum(x, in_indices, *operands, out_indices, precision=precision)
|
2020-07-10 09:57:59 -04:00
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class ResizeMethod(enum.Enum):
  """Resampling methods.

  resize accepts every member; scale_and_translate accepts every member
  except NEAREST. Use ``ResizeMethod.from_string`` to map a name such as
  ``"bilinear"`` to a member.
  """
  NEAREST = 0
  LINEAR = 1
  LANCZOS3 = 2
  LANCZOS5 = 3
  CUBIC = 4
  # Caution: The current resize implementation assumes that the resize kernels
  # are interpolating, i.e. for the identity warp the output equals the input.
  # This is not true for, e.g. a Gaussian kernel, so if such kernels are added
  # the implementation will need to be changed.

  @staticmethod
  def from_string(s: str):
    """Returns the ``ResizeMethod`` named by ``s``.

    Accepted names (exact, case-sensitive match): ``'nearest'``;
    ``'linear'``, ``'bilinear'``, ``'trilinear'``, ``'triangle'``;
    ``'lanczos3'``; ``'lanczos5'``; ``'cubic'``, ``'bicubic'``,
    ``'tricubic'``.

    Raises:
      ValueError: if ``s`` is not one of the accepted names.
    """
    # Built locally: a table stored as a class attribute would itself become
    # an enum member.
    aliases = (
        (('nearest',), ResizeMethod.NEAREST),
        (('linear', 'bilinear', 'trilinear', 'triangle'), ResizeMethod.LINEAR),
        (('lanczos3',), ResizeMethod.LANCZOS3),
        (('lanczos5',), ResizeMethod.LANCZOS5),
        (('cubic', 'bicubic', 'tricubic'), ResizeMethod.CUBIC),
    )
    for names, member in aliases:
      if s in names:
        return member
    raise ValueError(f'Unknown resize method "{s}"')
|
|
|
|
|
|
2020-09-08 13:51:19 -07:00
|
|
|
|
# Maps each kernel-based ResizeMethod to the function evaluating its kernel.
# NEAREST has no entry: _resize special-cases it, and scale_and_translate
# rejects it.
_kernels = {
    ResizeMethod.LINEAR: _fill_triangle_kernel,
    ResizeMethod.LANCZOS3: partial(_fill_lanczos_kernel, 3.),
    ResizeMethod.LANCZOS5: partial(_fill_lanczos_kernel, 5.),
    ResizeMethod.CUBIC: _fill_keys_cubic_kernel,
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# TODO: `scale` and `translation` below are length-K arrays whose scalar
# elements are passed on to compute_weight_mat; what is the correct type
# annotation for them?
|
2021-11-05 17:03:46 +02:00
|
|
|
|
def scale_and_translate(image, shape: core.Shape,
                        spatial_dims: Sequence[int],
                        scale, translation,
                        method: Union[str, ResizeMethod],
                        antialias: bool = True,
                        precision=lax.Precision.HIGHEST):
  """Apply a scale and translation to an image.

  Generates a new image of shape ``shape`` by resampling from the input image
  using the sampling method corresponding to ``method``. For 2D images, this
  operation transforms a location in the input images, (x, y), to a location
  in the output image according to::

    (x * scale[1] + translation[1], y * scale[0] + translation[0])

  (Note the _inverse_ warp is used to generate the sample locations.)
  Assumes half-centered pixels, i.e. the pixel at integer location row, col
  has coordinates y, x = row + 0.5, col + 0.5.
  Similarly for other input image dimensions.

  If an output location (pixel) maps to an input sample location that is
  outside the input boundaries, the value for the output location is set to
  zero.

  The ``method`` argument expects one of the following resize methods:

  ``ResizeMethod.LINEAR``, ``"linear"``, ``"bilinear"``, ``"trilinear"``,
    ``"triangle"`` `Linear interpolation`_. If ``antialias`` is ``True``, uses
    a triangular filter when downsampling.

  ``ResizeMethod.CUBIC``, ``"cubic"``, ``"bicubic"``, ``"tricubic"``
    `Cubic interpolation`_, using the Keys cubic kernel.

  ``ResizeMethod.LANCZOS3``, ``"lanczos3"``
    `Lanczos resampling`_, using a kernel of radius 3.

  ``ResizeMethod.LANCZOS5``, ``"lanczos5"``
    `Lanczos resampling`_, using a kernel of radius 5.

  .. _Linear interpolation: https://en.wikipedia.org/wiki/Bilinear_interpolation
  .. _Cubic interpolation: https://en.wikipedia.org/wiki/Bicubic_interpolation
  .. _Lanczos resampling: https://en.wikipedia.org/wiki/Lanczos_resampling

  Args:
    image: a JAX array.
    shape: the output shape, as a sequence of integers with length equal to the
      number of dimensions of `image`.
    spatial_dims: A length K tuple specifying the spatial dimensions that the
      passed scale and translation should be applied to.
    scale: A length K array (or any array-like, e.g. a tuple of floats)
      containing the scale to apply to each of the ``spatial_dims``.
    translation: A length K array (or any array-like) containing the
      translation to apply to each of the ``spatial_dims``.
    method: the resizing method to use; either a ``ResizeMethod`` instance or a
      string. Available methods are: LINEAR, LANCZOS3, LANCZOS5, CUBIC.
      NEAREST is not supported by this function.
    antialias: Should an antialiasing filter be used when downsampling? Defaults
      to ``True``. Has no effect when upsampling.
    precision: the ``precision`` passed to :func:`jax.numpy.einsum` when the
      per-dimension weight matrices are applied. Defaults to
      ``lax.Precision.HIGHEST``.

  Returns:
    The scaled and translated image. Images with a non-floating-point dtype
    are first converted to ``jnp.result_type(image, jnp.float32)``.

  Raises:
    ValueError: if ``len(shape) != image.ndim``, if ``method`` is an unknown
      string, or if ``method`` is NEAREST.
  """
  shape = core.canonicalize_shape(shape)
  if len(shape) != image.ndim:
    msg = ('shape must have length equal to the number of dimensions of x; '
           f' {shape} vs {image.shape}')
    raise ValueError(msg)
  if isinstance(method, str):
    method = ResizeMethod.from_string(method)
  if method == ResizeMethod.NEAREST:
    # Nearest neighbor is currently special-cased for straight resize, so skip
    # for now.
    raise ValueError('Nearest neighbor resampling is not currently supported '
                     'for scale_and_translate.')
  assert isinstance(method, ResizeMethod)

  kernel = _kernels[method]
  # Normalize array-likes (e.g. tuples of Python floats) so that `.dtype`,
  # `len()` and indexing below and in _scale_and_translate all work.
  scale = jnp.asarray(scale)
  translation = jnp.asarray(translation)
  if not jnp.issubdtype(image.dtype, jnp.inexact):
    image = lax.convert_element_type(image, jnp.result_type(image, jnp.float32))
  if not jnp.issubdtype(scale.dtype, jnp.inexact):
    scale = lax.convert_element_type(scale, jnp.result_type(scale, jnp.float32))
  if not jnp.issubdtype(translation.dtype, jnp.inexact):
    translation = lax.convert_element_type(
        translation, jnp.result_type(translation, jnp.float32))
  return _scale_and_translate(image, shape, spatial_dims, scale, translation,
                              kernel, antialias, precision)
|
2020-09-08 13:51:19 -07:00
|
|
|
|
|
|
|
|
|
|
2021-11-05 17:03:46 +02:00
|
|
|
|
def _resize_nearest(x, output_shape: core.Shape):
|
2020-09-08 13:51:19 -07:00
|
|
|
|
input_shape = x.shape
|
|
|
|
|
assert len(input_shape) == len(output_shape)
|
2021-11-05 17:03:46 +02:00
|
|
|
|
spatial_dims = tuple(i for i in range(len(input_shape))
|
|
|
|
|
if not core.symbolic_equal_dim(input_shape[i], output_shape[i]))
|
2020-09-08 13:51:19 -07:00
|
|
|
|
for d in spatial_dims:
|
|
|
|
|
m = input_shape[d]
|
|
|
|
|
n = output_shape[d]
|
2021-11-05 17:03:46 +02:00
|
|
|
|
offsets = (jnp.arange(n) + 0.5) * core.dimension_as_value(m) / core.dimension_as_value(n)
|
2021-11-18 10:23:53 +02:00
|
|
|
|
# TODO(b/206898375): this computation produces the wrong result on
|
|
|
|
|
# CPU and GPU when using float64. Use float32 until the bug is fixed.
|
|
|
|
|
offsets = jnp.floor(offsets.astype(np.float32)).astype(np.int32)
|
2020-09-08 13:51:19 -07:00
|
|
|
|
indices = [slice(None)] * len(input_shape)
|
2021-11-05 17:03:46 +02:00
|
|
|
|
indices[d] = offsets
|
2020-09-08 13:51:19 -07:00
|
|
|
|
x = x[tuple(indices)]
|
|
|
|
|
return x
|
2020-07-10 09:57:59 -04:00
|
|
|
|
|
|
|
|
|
|
2020-07-12 14:00:10 -04:00
|
|
|
|
@partial(jit, static_argnums=(1, 2, 3, 4))
def _resize(image, shape: core.Shape, method: Union[str, ResizeMethod],
            antialias: bool, precision):
  """Jitted implementation of :func:`resize`.

  ``shape``, ``method``, ``antialias`` and ``precision`` are static arguments.
  NEAREST is dispatched to _resize_nearest. Every other method goes through
  _scale_and_translate with a per-axis scale of ``out / in`` and zero
  translation.
  """
  if len(shape) != image.ndim:
    msg = ('shape must have length equal to the number of dimensions of x; '
           f' {shape} vs {image.shape}')
    raise ValueError(msg)
  if isinstance(method, str):
    method = ResizeMethod.from_string(method)
  if method == ResizeMethod.NEAREST:
    return _resize_nearest(image, shape)
  assert isinstance(method, ResizeMethod)
  kernel = _kernels[method]

  if not jnp.issubdtype(image.dtype, jnp.inexact):
    image = lax.convert_element_type(image, jnp.result_type(image, jnp.float32))

  # Only axes whose size changes are resampled. Leaving the others out is
  # valid because every current kernel is interpolating: an identity warp
  # (scale 1, translation 0) reproduces the input exactly.
  spatial_dims = tuple(
      d for d in range(len(shape))
      if not core.symbolic_equal_dim(image.shape[d], shape[d]))
  scale = []
  for d in spatial_dims:
    if core.symbolic_equal_dim(shape[d], 0):
      scale.append(1.0)
    else:
      scale.append(core.dimension_as_value(shape[d]) /
                   core.dimension_as_value(image.shape[d]))
  translation = [0.] * len(spatial_dims)
  return _scale_and_translate(image, shape, spatial_dims, scale, translation,
                              kernel, antialias, precision)
|
2020-07-10 10:32:13 -04:00
|
|
|
|
|
2020-09-08 13:51:19 -07:00
|
|
|
|
|
2021-11-05 17:03:46 +02:00
|
|
|
|
def resize(image, shape: core.Shape, method: Union[str, ResizeMethod],
           antialias: bool = True,
           precision=lax.Precision.HIGHEST):
  """Image resize.

  The ``method`` argument expects one of the following resize methods:

  ``ResizeMethod.NEAREST``, ``"nearest"``
    `Nearest neighbor interpolation`_. The values of ``antialias`` and
    ``precision`` are ignored.

  ``ResizeMethod.LINEAR``, ``"linear"``, ``"bilinear"``, ``"trilinear"``, ``"triangle"``
    `Linear interpolation`_. If ``antialias`` is ``True``, uses a triangular
    filter when downsampling.

  ``ResizeMethod.CUBIC``, ``"cubic"``, ``"bicubic"``, ``"tricubic"``
    `Cubic interpolation`_, using the Keys cubic kernel.

  ``ResizeMethod.LANCZOS3``, ``"lanczos3"``
    `Lanczos resampling`_, using a kernel of radius 3.

  ``ResizeMethod.LANCZOS5``, ``"lanczos5"``
    `Lanczos resampling`_, using a kernel of radius 5.

  .. _Nearest neighbor interpolation: https://en.wikipedia.org/wiki/Nearest-neighbor_interpolation
  .. _Linear interpolation: https://en.wikipedia.org/wiki/Bilinear_interpolation
  .. _Cubic interpolation: https://en.wikipedia.org/wiki/Bicubic_interpolation
  .. _Lanczos resampling: https://en.wikipedia.org/wiki/Lanczos_resampling

  Args:
    image: a JAX array.
    shape: the output shape, as a sequence of integers with length equal to
      the number of dimensions of `image`. Note that :func:`resize` does not
      distinguish spatial dimensions from batch or channel dimensions, so this
      includes all dimensions of the image. To represent a batch or a channel
      dimension, simply leave that element of the shape unchanged.
    method: the resizing method to use; either a ``ResizeMethod`` instance or a
      string. Available methods are: NEAREST, LINEAR, LANCZOS3, LANCZOS5,
      CUBIC.
    antialias: should an antialiasing filter be used when downsampling? Defaults
      to ``True``. Has no effect when upsampling, and is ignored for NEAREST.
    precision: the ``precision`` passed to :func:`jax.numpy.einsum` when the
      interpolation weights are applied. Defaults to
      ``lax.Precision.HIGHEST``. Ignored for NEAREST.

  Returns:
    The resized image. For every method except NEAREST, images with a
    non-floating-point dtype are first converted to
    ``jnp.result_type(image, jnp.float32)``.

  Raises:
    ValueError: if ``len(shape) != image.ndim`` or ``method`` is an unknown
      string.
  """
  # Canonicalize here so that `shape` is a hashable tuple before it reaches
  # the jitted _resize, where it is a static argument.
  return _resize(image, core.canonicalize_shape(shape), method, antialias,
                 precision)
|