From 75155f5eda4b298cbb9ac28a9ff0a2bd0176c24b Mon Sep 17 00:00:00 2001 From: George Necula Date: Fri, 5 Nov 2021 17:03:46 +0200 Subject: [PATCH] [shape_poly] Refactor arange and image_resize for shape polymorphism Bug: 8367 Small refactoring to jax.image.resize to make it compatible with shape polymorphism in jax2tf. In the process, also added support for jnp.arange([dim_poly]). Note that the underlying lax.iota already supported shape polymorphism. --- jax/_src/image/scale.py | 32 ++++++----- jax/_src/numpy/lax_numpy.py | 14 +++-- jax/core.py | 8 +++ jax/experimental/jax2tf/shape_poly.py | 3 + .../jax2tf/tests/shape_poly_test.py | 55 ++++++++++++++----- 5 files changed, 78 insertions(+), 34 deletions(-) diff --git a/jax/_src/image/scale.py b/jax/_src/image/scale.py index fa86e4b1c..491c51287 100644 --- a/jax/_src/image/scale.py +++ b/jax/_src/image/scale.py @@ -44,7 +44,8 @@ def _fill_triangle_kernel(x): return jnp.maximum(0, 1 - jnp.abs(x)) -def compute_weight_mat(input_size: int, output_size: int, scale, +def compute_weight_mat(input_size: core.DimSize, + output_size: core.DimSize, scale, translation, kernel: Callable, antialias: bool): @@ -70,13 +71,15 @@ def compute_weight_mat(input_size: int, output_size: int, scale, # Zero out weights where the sample location is completely outside the input # range. # Note sample_f has already had the 0.5 removed, hence the weird range below. 
+ input_size_minus_0_5 = core.dimension_as_value(input_size) - 0.5 return jnp.where( jnp.logical_and(sample_f >= -0.5, - sample_f <= input_size - 0.5)[jnp.newaxis, :], weights, 0) + sample_f <= input_size_minus_0_5)[jnp.newaxis, :], weights, 0) -def _scale_and_translate(x, output_shape, spatial_dims, scale, translation, - kernel, antialias, precision): +def _scale_and_translate(x, output_shape: core.Shape, + spatial_dims: Sequence[int], scale, translation, + kernel, antialias: bool, precision): input_shape = x.shape assert len(input_shape) == len(output_shape) assert len(spatial_dims) == len(scale) @@ -134,7 +137,7 @@ _kernels = { # scale and translation here are scalar elements of an np.array, what is the # correct type annotation? -def scale_and_translate(image, shape: Sequence[int], +def scale_and_translate(image, shape: core.Shape, spatial_dims: Sequence[int], scale, translation, method: Union[str, ResizeMethod], @@ -221,22 +224,24 @@ def scale_and_translate(image, shape: Sequence[int], kernel, antialias, precision) -def _resize_nearest(x, output_shape): +def _resize_nearest(x, output_shape: core.Shape): input_shape = x.shape assert len(input_shape) == len(output_shape) - spatial_dims, = np.nonzero(np.not_equal(input_shape, output_shape)) + spatial_dims = tuple(i for i in range(len(input_shape)) + if not core.symbolic_equal_dim(input_shape[i], output_shape[i])) for d in spatial_dims: m = input_shape[d] n = output_shape[d] - offsets = (np.arange(n) + 0.5) * m / n + offsets = (jnp.arange(n) + 0.5) * core.dimension_as_value(m) / core.dimension_as_value(n) + offsets = jnp.floor(offsets).astype(np.int32) indices = [slice(None)] * len(input_shape) - indices[d] = np.floor(offsets).astype(np.int32) + indices[d] = offsets x = x[tuple(indices)] return x @partial(jit, static_argnums=(1, 2, 3, 4)) -def _resize(image, shape: Sequence[int], method: Union[str, ResizeMethod], +def _resize(image, shape: core.Shape, method: Union[str, ResizeMethod], antialias: bool, precision): 
if len(shape) != image.ndim: msg = ('shape must have length equal to the number of dimensions of x; ' @@ -254,15 +259,16 @@ def _resize(image, shape: Sequence[int], method: Union[str, ResizeMethod], # Skip dimensions that have scale=1 and translation=0, this is only possible # since all of the current resize methods (kernels) are interpolating, so the # output = input under an identity warp. - spatial_dims = tuple(np.nonzero(np.not_equal(image.shape, shape))[0]) - scale = [1.0 if shape[d] == 0 else float(shape[d]) / image.shape[d] + spatial_dims = tuple(i for i in range(len(shape)) + if not core.symbolic_equal_dim(image.shape[i], shape[i])) + scale = [1.0 if core.symbolic_equal_dim(shape[d], 0) else core.dimension_as_value(shape[d]) / core.dimension_as_value(image.shape[d]) for d in spatial_dims] return _scale_and_translate(image, shape, spatial_dims, scale, [0.] * len(spatial_dims), kernel, antialias, precision) -def resize(image, shape: Sequence[int], method: Union[str, ResizeMethod], +def resize(image, shape: core.Shape, method: Union[str, ResizeMethod], antialias: bool = True, precision = lax.Precision.HIGHEST): """Image resize. 
diff --git a/jax/_src/numpy/lax_numpy.py b/jax/_src/numpy/lax_numpy.py index 8d080b734..105e9b9db 100644 --- a/jax/_src/numpy/lax_numpy.py +++ b/jax/_src/numpy/lax_numpy.py @@ -3755,20 +3755,22 @@ def identity(n, dtype=None): @_wraps(np.arange) -def arange(start, stop=None, step=None, dtype=None): +def arange(start: core.DimSize, stop: Optional[core.DimSize]=None, + step: Optional[core.DimSize]=None, dtype=None): lax._check_user_dtype_supported(dtype, "arange") require = partial(core.concrete_or_error, _np_asarray) msg = "It arose in jax.numpy.arange argument `{}`.".format + dtype = dtype or _dtype(start, *(x for x in [stop, step] if x is not None)) if stop is None and step is None: - start = require(start, msg("stop")) - dtype = dtype or _dtype(start) - return lax.iota(dtype, np.ceil(start).astype(int)) # avoids materializing + if not core.is_dim_size(start): + start = require(start, msg("stop")) + start = np.ceil(start).astype(int) + + return lax.iota(dtype, start) else: start = require(start, msg("start")) stop = None if stop is None else require(stop, msg("stop")) step = None if step is None else require(step, msg("step")) - if dtype is None: - dtype = _dtype(start, *(x for x in [stop, step] if x is not None)) return array(np.arange(start, stop=stop, step=step, dtype=dtype)) diff --git a/jax/core.py b/jax/core.py index e6264d004..c143d02b6 100644 --- a/jax/core.py +++ b/jax/core.py @@ -1376,6 +1376,14 @@ def _dim_handler_and_canonical(*dlist: DimSize) -> Tuple[DimensionHandler, Tuple raise ValueError(msg) return next(iter(special_handlers), _dimension_handler_int), tuple(canonical) +def is_dim_size(v: Any) -> bool: + """Checks if a value is a DimSize.""" + try: + handler, _ = _dim_handler_and_canonical(v) + return True + except TypeError: + return False + def is_constant_dim(d: DimSize) -> bool: handler, ds = _dim_handler_and_canonical(d) return handler.is_constant(*ds) diff --git a/jax/experimental/jax2tf/shape_poly.py b/jax/experimental/jax2tf/shape_poly.py 
index 5cde46667..8e74a17ec 100644 --- a/jax/experimental/jax2tf/shape_poly.py +++ b/jax/experimental/jax2tf/shape_poly.py @@ -41,10 +41,12 @@ from typing import Any, Dict, List, Optional, Sequence, Set, Tuple, TypeVar, Uni import jax from jax._src.numpy import lax_numpy +from jax._src import dtypes import opt_einsum from jax import config from jax import core + import numpy as np DimSize = core.DimSize @@ -448,6 +450,7 @@ class DimensionHandlerPoly(core.DimensionHandler): return _dim_as_value(d) core._SPECIAL_DIMENSION_HANDLERS[_DimPolynomial] = DimensionHandlerPoly() +dtypes.python_scalar_dtypes[_DimPolynomial] = dtypes.python_scalar_dtypes[int] def _einsum_contract_path(*operands, **kwargs): """Like opt_einsum.contract_path, with support for DimPolynomial shapes. diff --git a/jax/experimental/jax2tf/tests/shape_poly_test.py b/jax/experimental/jax2tf/tests/shape_poly_test.py index 14d8dc329..0c184a567 100644 --- a/jax/experimental/jax2tf/tests/shape_poly_test.py +++ b/jax/experimental/jax2tf/tests/shape_poly_test.py @@ -1149,6 +1149,28 @@ _POLY_SHAPE_TEST_HARNESSES = [ jax.grad(lambda x: jnp.sum(jnp.sum(x, axis=0, keepdims=0) + x)), [RandArg((3, 4), _f32)], poly_axes=[0]), + _make_harness("arange", "start", + lambda op: jnp.arange(2 * op.shape[0], dtype=_f32), + [RandArg((3,), _f32)], + poly_axes=[0], + enable_and_diable_xla=True), + _make_harness("arange", "start_no_dtype", + lambda op: jnp.arange(op.shape[0]), + [RandArg((3,), _f32)], + poly_axes=[0], + enable_and_diable_xla=True), + # Reduce the poly dimension + _make_harness("argmax", "0", + lambda op: lax.argmax(op, axis=0, index_dtype=np.int32), + [RandArg((3, 4, 5), _f32)], + poly_axes=[0], + enable_and_diable_xla=True), + # Reduce the non-poly dimension + _make_harness("argmax", "1", + lambda op: lax.argmax(op, axis=1, index_dtype=np.int32), + [RandArg((3, 4, 5), _f32)], + poly_axes=[0], + enable_and_diable_xla=True), [ _make_harness("average", f"axis={axis}_weights=None", @@ -1165,18 +1187,6 @@ 
_POLY_SHAPE_TEST_HARNESSES = [ poly_axes=[0, 0]) for axis in [None, 0, 1] ], - # Reduce the poly dimension - _make_harness("argmax", "0", - lambda op: lax.argmax(op, axis=0, index_dtype=np.int32), - [RandArg((3, 4, 5), _f32)], - poly_axes=[0], - enable_and_diable_xla=True), - # Reduce the non-poly dimension - _make_harness("argmax", "1", - lambda op: lax.argmax(op, axis=1, index_dtype=np.int32), - [RandArg((3, 4, 5), _f32)], - poly_axes=[0], - enable_and_diable_xla=True), _make_harness("broadcast_to", "", lambda x: jnp.broadcast_to(x, [x.shape[0], x.shape[0], 4]), [RandArg((3, 4), _f32)], @@ -1391,8 +1401,23 @@ _POLY_SHAPE_TEST_HARNESSES = [ [RandArg((3, 4), _f32)], poly_axes=[0], enable_and_diable_xla=True, expect_error=(IndexError, "Array slice indices must have static")), - _make_harness("index_in_dim", "idx=pos", - lambda x: lax.index_in_dim(x, 0, axis=0, keepdims=False), + _make_harness("image_resize", "linear_0", + lambda x: jax.image.resize(x, (x.shape[0], 2 * x.shape[1], 2 * x.shape[2], x.shape[3]), + method="linear"), + [RandArg((3, 16, 32, 3), _f32)], + poly_axes=[(1, 2)]), + _make_harness("image_resize", "linear_to_fixed_dim", + lambda x: jax.image.resize(x, (x.shape[0], 64, 64, x.shape[3]), + method="linear"), + [RandArg((3, 16, 32, 3), _f32)], + poly_axes=[(1, 2)]), + _make_harness("image_resize", "nearest_0", + lambda x: jax.image.resize(x, (x.shape[0], 2 * x.shape[1], 2 * x.shape[2], x.shape[3]), + method="nearest"), + [RandArg((3, 5, 7, 3), _f32)], + poly_axes=[(1, 2)]), + _make_harness("index_in_dim", "0", + lambda x: lax.index_in_dim(x, -1, axis=0, keepdims=False), [RandArg((3, 4), _f32)], poly_axes=[0]), _make_harness("index_in_dim", "idx=neg", @@ -1734,7 +1759,7 @@ class ShapePolyPrimitivesTest(tf_test_util.JaxToTfTestCase): # to parameterized below. 
@primitive_harness.parameterized( _flatten_harnesses(_POLY_SHAPE_TEST_HARNESSES), - #one_containing="getitem_op=poly_idx=slice-None-1" + #one_containing="arange_start_no_dtype" ) def test_prim(self, harness: Harness): args = harness.dyn_args_maker(self.rng())