[shape_poly] Refactor arange and image_resize for shape polymorphism

Bug: 8367

Small refactoring of jax.image.resize to make it compatible with
shape polymorphism in jax2tf. In the process, also added support for
jnp.arange([dim_poly]). Note that the underlying lax.iota already
supported shape polymorphism.
George Necula 2021-11-05 17:03:46 +02:00
parent c37c62cfc3
commit 75155f5eda
5 changed files with 78 additions and 34 deletions


@@ -44,7 +44,8 @@ def _fill_triangle_kernel(x):
   return jnp.maximum(0, 1 - jnp.abs(x))
 
 
-def compute_weight_mat(input_size: int, output_size: int, scale,
+def compute_weight_mat(input_size: core.DimSize,
+                       output_size: core.DimSize, scale,
                        translation,
                        kernel: Callable,
                        antialias: bool):
@@ -70,13 +71,15 @@ def compute_weight_mat(input_size: int, output_size: int, scale,
   # Zero out weights where the sample location is completely outside the input
   # range.
   # Note sample_f has already had the 0.5 removed, hence the weird range below.
+  input_size_minus_0_5 = core.dimension_as_value(input_size) - 0.5
   return jnp.where(
       jnp.logical_and(sample_f >= -0.5,
-                      sample_f <= input_size - 0.5)[jnp.newaxis, :], weights, 0)
+                      sample_f <= input_size_minus_0_5)[jnp.newaxis, :], weights, 0)
 
 
-def _scale_and_translate(x, output_shape, spatial_dims, scale, translation,
-                         kernel, antialias, precision):
+def _scale_and_translate(x, output_shape: core.Shape,
+                         spatial_dims: Sequence[int], scale, translation,
+                         kernel, antialias: bool, precision):
   input_shape = x.shape
   assert len(input_shape) == len(output_shape)
   assert len(spatial_dims) == len(scale)
@@ -134,7 +137,7 @@ _kernels = {
 # scale and translation here are scalar elements of an np.array, what is the
 # correct type annotation?
-def scale_and_translate(image, shape: Sequence[int],
+def scale_and_translate(image, shape: core.Shape,
                         spatial_dims: Sequence[int],
                         scale, translation,
                         method: Union[str, ResizeMethod],
@@ -221,22 +224,24 @@ def scale_and_translate(image, shape: Sequence[int],
                               kernel, antialias, precision)
 
 
-def _resize_nearest(x, output_shape):
+def _resize_nearest(x, output_shape: core.Shape):
   input_shape = x.shape
   assert len(input_shape) == len(output_shape)
-  spatial_dims, = np.nonzero(np.not_equal(input_shape, output_shape))
+  spatial_dims = tuple(i for i in range(len(input_shape))
+                       if not core.symbolic_equal_dim(input_shape[i], output_shape[i]))
   for d in spatial_dims:
     m = input_shape[d]
     n = output_shape[d]
-    offsets = (np.arange(n) + 0.5) * m / n
+    offsets = (jnp.arange(n) + 0.5) * core.dimension_as_value(m) / core.dimension_as_value(n)
+    offsets = jnp.floor(offsets).astype(np.int32)
     indices = [slice(None)] * len(input_shape)
-    indices[d] = np.floor(offsets).astype(np.int32)
+    indices[d] = offsets
     x = x[tuple(indices)]
   return x
 
 
 @partial(jit, static_argnums=(1, 2, 3, 4))
-def _resize(image, shape: Sequence[int], method: Union[str, ResizeMethod],
+def _resize(image, shape: core.Shape, method: Union[str, ResizeMethod],
             antialias: bool, precision):
   if len(shape) != image.ndim:
     msg = ('shape must have length equal to the number of dimensions of x; '
@@ -254,15 +259,16 @@ def _resize(image, shape: Sequence[int], method: Union[str, ResizeMethod],
   # Skip dimensions that have scale=1 and translation=0, this is only possible
   # since all of the current resize methods (kernels) are interpolating, so the
   # output = input under an identity warp.
-  spatial_dims = tuple(np.nonzero(np.not_equal(image.shape, shape))[0])
-  scale = [1.0 if shape[d] == 0 else float(shape[d]) / image.shape[d]
+  spatial_dims = tuple(i for i in range(len(shape))
+                       if not core.symbolic_equal_dim(image.shape[i], shape[i]))
+  scale = [1.0 if core.symbolic_equal_dim(shape[d], 0) else core.dimension_as_value(shape[d]) / core.dimension_as_value(image.shape[d])
           for d in spatial_dims]
   return _scale_and_translate(image, shape, spatial_dims,
                               scale, [0.] * len(spatial_dims), kernel,
                               antialias, precision)
 
 
-def resize(image, shape: Sequence[int], method: Union[str, ResizeMethod],
+def resize(image, shape: core.Shape, method: Union[str, ResizeMethod],
            antialias: bool = True,
            precision = lax.Precision.HIGHEST):
   """Image resize.

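With this refactoring, the resize path no longer forces dimensions through np.* arithmetic, so jax.image.resize can be traced with symbolic spatial sizes. Below is a minimal sketch of what this enables via jax2tf.convert with string polymorphic_shapes; the function name and shapes are illustrative, not part of the commit, and TensorFlow is needed only for the round trip:

import jax
import tensorflow as tf
from jax.experimental import jax2tf

def upscale_2x(img):
  # img: f32[b, h, w, c]; h and w are symbolic dimensions under polymorphic
  # tracing, and the output shape is built from them without concretization.
  b, h, w, c = img.shape
  return jax.image.resize(img, (b, 2 * h, 2 * w, c), method="linear")

f_tf = tf.function(
    jax2tf.convert(upscale_2x, polymorphic_shapes=["(b, h, w, _)"]),
    input_signature=[tf.TensorSpec([None, None, None, 3], tf.float32)])
print(f_tf(tf.ones((2, 8, 8, 3))).shape)  # (2, 16, 16, 3)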

@@ -3755,20 +3755,22 @@ def identity(n, dtype=None):
 
 @_wraps(np.arange)
-def arange(start, stop=None, step=None, dtype=None):
+def arange(start: core.DimSize, stop: Optional[core.DimSize]=None,
+           step: Optional[core.DimSize]=None, dtype=None):
   lax._check_user_dtype_supported(dtype, "arange")
   require = partial(core.concrete_or_error, _np_asarray)
   msg = "It arose in jax.numpy.arange argument `{}`.".format
+  dtype = dtype or _dtype(start, *(x for x in [stop, step] if x is not None))
   if stop is None and step is None:
-    start = require(start, msg("stop"))
-    dtype = dtype or _dtype(start)
-    return lax.iota(dtype, np.ceil(start).astype(int))  # avoids materializing
+    if not core.is_dim_size(start):
+      start = require(start, msg("stop"))
+      start = np.ceil(start).astype(int)
+    return lax.iota(dtype, start)
   else:
     start = require(start, msg("start"))
     stop = None if stop is None else require(stop, msg("stop"))
     step = None if step is None else require(step, msg("step"))
-    if dtype is None:
-      dtype = _dtype(start, *(x for x in [stop, step] if x is not None))
     return array(np.arange(start, stop=stop, step=step, dtype=dtype))

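This is the change that makes jnp.arange(stop) work when stop is a symbolic dimension: the dtype is now computed up front (a dimension polynomial participates in dtype inference like a Python int, see the shape_poly change below), and a dimension-sized start is passed straight to lax.iota, which already supported shape polymorphism. A small sketch mirroring the arange_start_no_dtype test harness added further down; the function name is illustrative:

import jax.numpy as jnp
import tensorflow as tf
from jax.experimental import jax2tf

def positions(x):
  # x.shape[0] may be a dimension polynomial; with no dtype argument, the
  # result dtype is inferred as for a Python int.
  return jnp.arange(x.shape[0])

f_tf = tf.function(jax2tf.convert(positions, polymorphic_shapes=["(b,)"]),
                   input_signature=[tf.TensorSpec([None], tf.float32)])
print(f_tf(tf.ones([4])))  # [0 1 2 3]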

@@ -1376,6 +1376,14 @@ def _dim_handler_and_canonical(*dlist: DimSize) -> Tuple[DimensionHandler, Tuple
     raise ValueError(msg)
   return next(iter(special_handlers), _dimension_handler_int), tuple(canonical)
 
+def is_dim_size(v: Any) -> bool:
+  """Checks if a value is a DimSize."""
+  try:
+    handler, _ = _dim_handler_and_canonical(v)
+    return True
+  except TypeError:
+    return False
+
 def is_constant_dim(d: DimSize) -> bool:
   handler, ds = _dim_handler_and_canonical(d)
   return handler.is_constant(*ds)

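The new predicate is what arange uses to decide whether to concretize its argument: anything _dim_handler_and_canonical can canonicalize, Python ints as well as polymorphic dimension objects registered through _SPECIAL_DIMENSION_HANDLERS, counts as a dimension size. Roughly, and assuming the function is reachable as core.is_dim_size in this version:

from jax import core

core.is_dim_size(3)    # True: ints canonicalize as dimension sizes
core.is_dim_size(2.5)  # False: canonicalization raises TypeError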

@@ -41,10 +41,12 @@ from typing import Any, Dict, List, Optional, Sequence, Set, Tuple, TypeVar, Union
 
 import jax
+from jax._src.numpy import lax_numpy
+from jax._src import dtypes
 import opt_einsum
 from jax import config
 from jax import core
 
 import numpy as np
 
 DimSize = core.DimSize
@@ -448,6 +450,7 @@ class DimensionHandlerPoly(core.DimensionHandler):
     return _dim_as_value(d)
 
 core._SPECIAL_DIMENSION_HANDLERS[_DimPolynomial] = DimensionHandlerPoly()
+dtypes.python_scalar_dtypes[_DimPolynomial] = dtypes.python_scalar_dtypes[int]
 
 def _einsum_contract_path(*operands, **kwargs):
   """Like opt_einsum.contract_path, with support for DimPolynomial shapes.

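The python_scalar_dtypes registration is the companion to the arange change: lax_numpy's _dtype resolves Python scalars through this table, so a _DimPolynomial now infers the same default dtype as a Python int. A quick illustration of the shared entry (it peeks at a private module, so treat it as a sketch):

from jax._src import dtypes

# Python ints map to int64 here, canonicalized to int32 unless
# jax_enable_x64 is set; _DimPolynomial now reuses this same entry.
print(dtypes.python_scalar_dtypes[int])  # int64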

@@ -1149,6 +1149,28 @@ _POLY_SHAPE_TEST_HARNESSES = [
                 jax.grad(lambda x: jnp.sum(jnp.sum(x, axis=0, keepdims=0) + x)),
                 [RandArg((3, 4), _f32)],
                 poly_axes=[0]),
+  _make_harness("arange", "start",
+                lambda op: jnp.arange(2 * op.shape[0], dtype=_f32),
+                [RandArg((3,), _f32)],
+                poly_axes=[0],
+                enable_and_diable_xla=True),
+  _make_harness("arange", "start_no_dtype",
+                lambda op: jnp.arange(op.shape[0]),
+                [RandArg((3,), _f32)],
+                poly_axes=[0],
+                enable_and_diable_xla=True),
+  # Reduce the poly dimension
+  _make_harness("argmax", "0",
+                lambda op: lax.argmax(op, axis=0, index_dtype=np.int32),
+                [RandArg((3, 4, 5), _f32)],
+                poly_axes=[0],
+                enable_and_diable_xla=True),
+  # Reduce the non-poly dimension
+  _make_harness("argmax", "1",
+                lambda op: lax.argmax(op, axis=1, index_dtype=np.int32),
+                [RandArg((3, 4, 5), _f32)],
+                poly_axes=[0],
+                enable_and_diable_xla=True),
   [
       _make_harness("average",
                     f"axis={axis}_weights=None",
@@ -1165,18 +1187,6 @@ _POLY_SHAPE_TEST_HARNESSES = [
                     poly_axes=[0, 0])
       for axis in [None, 0, 1]
   ],
-  # Reduce the poly dimension
-  _make_harness("argmax", "0",
-                lambda op: lax.argmax(op, axis=0, index_dtype=np.int32),
-                [RandArg((3, 4, 5), _f32)],
-                poly_axes=[0],
-                enable_and_diable_xla=True),
-  # Reduce the non-poly dimension
-  _make_harness("argmax", "1",
-                lambda op: lax.argmax(op, axis=1, index_dtype=np.int32),
-                [RandArg((3, 4, 5), _f32)],
-                poly_axes=[0],
-                enable_and_diable_xla=True),
   _make_harness("broadcast_to", "",
                 lambda x: jnp.broadcast_to(x, [x.shape[0], x.shape[0], 4]),
                 [RandArg((3, 4), _f32)],
@@ -1391,8 +1401,23 @@ _POLY_SHAPE_TEST_HARNESSES = [
                 [RandArg((3, 4), _f32)],
                 poly_axes=[0], enable_and_diable_xla=True,
                 expect_error=(IndexError, "Array slice indices must have static")),
-  _make_harness("index_in_dim", "idx=pos",
-                lambda x: lax.index_in_dim(x, 0, axis=0, keepdims=False),
+  _make_harness("image_resize", "linear_0",
+                lambda x: jax.image.resize(x, (x.shape[0], 2 * x.shape[1], 2 * x.shape[2], x.shape[3]),
+                                           method="linear"),
+                [RandArg((3, 16, 32, 3), _f32)],
+                poly_axes=[(1, 2)]),
+  _make_harness("image_resize", "linear_to_fixed_dim",
+                lambda x: jax.image.resize(x, (x.shape[0], 64, 64, x.shape[3]),
+                                           method="linear"),
+                [RandArg((3, 16, 32, 3), _f32)],
+                poly_axes=[(1, 2)]),
+  _make_harness("image_resize", "nearest_0",
+                lambda x: jax.image.resize(x, (x.shape[0], 2 * x.shape[1], 2 * x.shape[2], x.shape[3]),
+                                           method="nearest"),
+                [RandArg((3, 5, 7, 3), _f32)],
+                poly_axes=[(1, 2)]),
+  _make_harness("index_in_dim", "0",
+                lambda x: lax.index_in_dim(x, -1, axis=0, keepdims=False),
                 [RandArg((3, 4), _f32)],
                 poly_axes=[0]),
   _make_harness("index_in_dim", "idx=neg",
@@ -1734,7 +1759,7 @@ class ShapePolyPrimitivesTest(tf_test_util.JaxToTfTestCase):
   # to parameterized below.
   @primitive_harness.parameterized(
       _flatten_harnesses(_POLY_SHAPE_TEST_HARNESSES),
-      #one_containing="getitem_op=poly_idx=slice-None-1"
+      #one_containing="arange_start_no_dtype"
   )
   def test_prim(self, harness: Harness):
     args = harness.dyn_args_maker(self.rng())
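Note: uncommenting the one_containing line above restricts the parameterized test to harnesses whose name contains that string, which is the usual way to iterate on a single new case, such as arange_start_no_dtype, without running the whole suite.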