mirror of
https://github.com/ROCm/jax.git
synced 2025-04-15 19:36:06 +00:00
[shape_poly] Refactor arange and image_resize for shape polymorphism
Bug: 8367 Small refactoring to jax.image.resize to make it compatible with shape polymorphismin jax2tf. In the process added also support for jnp.arange([dim_poly]). Note that the underlying lax.iota already supported shape polymorphism.
This commit is contained in:
parent
c37c62cfc3
commit
75155f5eda
@ -44,7 +44,8 @@ def _fill_triangle_kernel(x):
|
||||
return jnp.maximum(0, 1 - jnp.abs(x))
|
||||
|
||||
|
||||
def compute_weight_mat(input_size: int, output_size: int, scale,
|
||||
def compute_weight_mat(input_size: core.DimSize,
|
||||
output_size: core.DimSize, scale,
|
||||
translation,
|
||||
kernel: Callable,
|
||||
antialias: bool):
|
||||
@ -70,13 +71,15 @@ def compute_weight_mat(input_size: int, output_size: int, scale,
|
||||
# Zero out weights where the sample location is completely outside the input
|
||||
# range.
|
||||
# Note sample_f has already had the 0.5 removed, hence the weird range below.
|
||||
input_size_minus_0_5 = core.dimension_as_value(input_size) - 0.5
|
||||
return jnp.where(
|
||||
jnp.logical_and(sample_f >= -0.5,
|
||||
sample_f <= input_size - 0.5)[jnp.newaxis, :], weights, 0)
|
||||
sample_f <= input_size_minus_0_5)[jnp.newaxis, :], weights, 0)
|
||||
|
||||
|
||||
def _scale_and_translate(x, output_shape, spatial_dims, scale, translation,
|
||||
kernel, antialias, precision):
|
||||
def _scale_and_translate(x, output_shape: core.Shape,
|
||||
spatial_dims: Sequence[int], scale, translation,
|
||||
kernel, antialias: bool, precision):
|
||||
input_shape = x.shape
|
||||
assert len(input_shape) == len(output_shape)
|
||||
assert len(spatial_dims) == len(scale)
|
||||
@ -134,7 +137,7 @@ _kernels = {
|
||||
|
||||
# scale and translation here are scalar elements of an np.array, what is the
|
||||
# correct type annotation?
|
||||
def scale_and_translate(image, shape: Sequence[int],
|
||||
def scale_and_translate(image, shape: core.Shape,
|
||||
spatial_dims: Sequence[int],
|
||||
scale, translation,
|
||||
method: Union[str, ResizeMethod],
|
||||
@ -221,22 +224,24 @@ def scale_and_translate(image, shape: Sequence[int],
|
||||
kernel, antialias, precision)
|
||||
|
||||
|
||||
def _resize_nearest(x, output_shape):
|
||||
def _resize_nearest(x, output_shape: core.Shape):
|
||||
input_shape = x.shape
|
||||
assert len(input_shape) == len(output_shape)
|
||||
spatial_dims, = np.nonzero(np.not_equal(input_shape, output_shape))
|
||||
spatial_dims = tuple(i for i in range(len(input_shape))
|
||||
if not core.symbolic_equal_dim(input_shape[i], output_shape[i]))
|
||||
for d in spatial_dims:
|
||||
m = input_shape[d]
|
||||
n = output_shape[d]
|
||||
offsets = (np.arange(n) + 0.5) * m / n
|
||||
offsets = (jnp.arange(n) + 0.5) * core.dimension_as_value(m) / core.dimension_as_value(n)
|
||||
offsets = jnp.floor(offsets).astype(np.int32)
|
||||
indices = [slice(None)] * len(input_shape)
|
||||
indices[d] = np.floor(offsets).astype(np.int32)
|
||||
indices[d] = offsets
|
||||
x = x[tuple(indices)]
|
||||
return x
|
||||
|
||||
|
||||
@partial(jit, static_argnums=(1, 2, 3, 4))
|
||||
def _resize(image, shape: Sequence[int], method: Union[str, ResizeMethod],
|
||||
def _resize(image, shape: core.Shape, method: Union[str, ResizeMethod],
|
||||
antialias: bool, precision):
|
||||
if len(shape) != image.ndim:
|
||||
msg = ('shape must have length equal to the number of dimensions of x; '
|
||||
@ -254,15 +259,16 @@ def _resize(image, shape: Sequence[int], method: Union[str, ResizeMethod],
|
||||
# Skip dimensions that have scale=1 and translation=0, this is only possible
|
||||
# since all of the current resize methods (kernels) are interpolating, so the
|
||||
# output = input under an identity warp.
|
||||
spatial_dims = tuple(np.nonzero(np.not_equal(image.shape, shape))[0])
|
||||
scale = [1.0 if shape[d] == 0 else float(shape[d]) / image.shape[d]
|
||||
spatial_dims = tuple(i for i in range(len(shape))
|
||||
if not core.symbolic_equal_dim(image.shape[i], shape[i]))
|
||||
scale = [1.0 if core.symbolic_equal_dim(shape[d], 0) else core.dimension_as_value(shape[d]) / core.dimension_as_value(image.shape[d])
|
||||
for d in spatial_dims]
|
||||
return _scale_and_translate(image, shape, spatial_dims,
|
||||
scale, [0.] * len(spatial_dims), kernel,
|
||||
antialias, precision)
|
||||
|
||||
|
||||
def resize(image, shape: Sequence[int], method: Union[str, ResizeMethod],
|
||||
def resize(image, shape: core.Shape, method: Union[str, ResizeMethod],
|
||||
antialias: bool = True,
|
||||
precision = lax.Precision.HIGHEST):
|
||||
"""Image resize.
|
||||
|
@ -3755,20 +3755,22 @@ def identity(n, dtype=None):
|
||||
|
||||
|
||||
@_wraps(np.arange)
|
||||
def arange(start, stop=None, step=None, dtype=None):
|
||||
def arange(start: core.DimSize, stop: Optional[core.DimSize]=None,
|
||||
step: Optional[core.DimSize]=None, dtype=None):
|
||||
lax._check_user_dtype_supported(dtype, "arange")
|
||||
require = partial(core.concrete_or_error, _np_asarray)
|
||||
msg = "It arose in jax.numpy.arange argument `{}`.".format
|
||||
dtype = dtype or _dtype(start, *(x for x in [stop, step] if x is not None))
|
||||
if stop is None and step is None:
|
||||
start = require(start, msg("stop"))
|
||||
dtype = dtype or _dtype(start)
|
||||
return lax.iota(dtype, np.ceil(start).astype(int)) # avoids materializing
|
||||
if not core.is_dim_size(start):
|
||||
start = require(start, msg("stop"))
|
||||
start = np.ceil(start).astype(int)
|
||||
|
||||
return lax.iota(dtype, start)
|
||||
else:
|
||||
start = require(start, msg("start"))
|
||||
stop = None if stop is None else require(stop, msg("stop"))
|
||||
step = None if step is None else require(step, msg("step"))
|
||||
if dtype is None:
|
||||
dtype = _dtype(start, *(x for x in [stop, step] if x is not None))
|
||||
return array(np.arange(start, stop=stop, step=step, dtype=dtype))
|
||||
|
||||
|
||||
|
@ -1376,6 +1376,14 @@ def _dim_handler_and_canonical(*dlist: DimSize) -> Tuple[DimensionHandler, Tuple
|
||||
raise ValueError(msg)
|
||||
return next(iter(special_handlers), _dimension_handler_int), tuple(canonical)
|
||||
|
||||
def is_dim_size(v: Any) -> bool:
|
||||
"""Checks if a value is a DimSize."""
|
||||
try:
|
||||
handler, _ = _dim_handler_and_canonical(v)
|
||||
return True
|
||||
except TypeError:
|
||||
return False
|
||||
|
||||
def is_constant_dim(d: DimSize) -> bool:
|
||||
handler, ds = _dim_handler_and_canonical(d)
|
||||
return handler.is_constant(*ds)
|
||||
|
@ -41,10 +41,12 @@ from typing import Any, Dict, List, Optional, Sequence, Set, Tuple, TypeVar, Uni
|
||||
|
||||
import jax
|
||||
from jax._src.numpy import lax_numpy
|
||||
from jax._src import dtypes
|
||||
import opt_einsum
|
||||
from jax import config
|
||||
from jax import core
|
||||
|
||||
|
||||
import numpy as np
|
||||
|
||||
DimSize = core.DimSize
|
||||
@ -448,6 +450,7 @@ class DimensionHandlerPoly(core.DimensionHandler):
|
||||
return _dim_as_value(d)
|
||||
|
||||
core._SPECIAL_DIMENSION_HANDLERS[_DimPolynomial] = DimensionHandlerPoly()
|
||||
dtypes.python_scalar_dtypes[_DimPolynomial] = dtypes.python_scalar_dtypes[int]
|
||||
|
||||
def _einsum_contract_path(*operands, **kwargs):
|
||||
"""Like opt_einsum.contract_path, with support for DimPolynomial shapes.
|
||||
|
@ -1149,6 +1149,28 @@ _POLY_SHAPE_TEST_HARNESSES = [
|
||||
jax.grad(lambda x: jnp.sum(jnp.sum(x, axis=0, keepdims=0) + x)),
|
||||
[RandArg((3, 4), _f32)],
|
||||
poly_axes=[0]),
|
||||
_make_harness("arange", "start",
|
||||
lambda op: jnp.arange(2 * op.shape[0], dtype=_f32),
|
||||
[RandArg((3,), _f32)],
|
||||
poly_axes=[0],
|
||||
enable_and_diable_xla=True),
|
||||
_make_harness("arange", "start_no_dtype",
|
||||
lambda op: jnp.arange(op.shape[0]),
|
||||
[RandArg((3,), _f32)],
|
||||
poly_axes=[0],
|
||||
enable_and_diable_xla=True),
|
||||
# Reduce the poly dimension
|
||||
_make_harness("argmax", "0",
|
||||
lambda op: lax.argmax(op, axis=0, index_dtype=np.int32),
|
||||
[RandArg((3, 4, 5), _f32)],
|
||||
poly_axes=[0],
|
||||
enable_and_diable_xla=True),
|
||||
# Reduce the non-poly dimension
|
||||
_make_harness("argmax", "1",
|
||||
lambda op: lax.argmax(op, axis=1, index_dtype=np.int32),
|
||||
[RandArg((3, 4, 5), _f32)],
|
||||
poly_axes=[0],
|
||||
enable_and_diable_xla=True),
|
||||
[
|
||||
_make_harness("average",
|
||||
f"axis={axis}_weights=None",
|
||||
@ -1165,18 +1187,6 @@ _POLY_SHAPE_TEST_HARNESSES = [
|
||||
poly_axes=[0, 0])
|
||||
for axis in [None, 0, 1]
|
||||
],
|
||||
# Reduce the poly dimension
|
||||
_make_harness("argmax", "0",
|
||||
lambda op: lax.argmax(op, axis=0, index_dtype=np.int32),
|
||||
[RandArg((3, 4, 5), _f32)],
|
||||
poly_axes=[0],
|
||||
enable_and_diable_xla=True),
|
||||
# Reduce the non-poly dimension
|
||||
_make_harness("argmax", "1",
|
||||
lambda op: lax.argmax(op, axis=1, index_dtype=np.int32),
|
||||
[RandArg((3, 4, 5), _f32)],
|
||||
poly_axes=[0],
|
||||
enable_and_diable_xla=True),
|
||||
_make_harness("broadcast_to", "",
|
||||
lambda x: jnp.broadcast_to(x, [x.shape[0], x.shape[0], 4]),
|
||||
[RandArg((3, 4), _f32)],
|
||||
@ -1391,8 +1401,23 @@ _POLY_SHAPE_TEST_HARNESSES = [
|
||||
[RandArg((3, 4), _f32)],
|
||||
poly_axes=[0], enable_and_diable_xla=True,
|
||||
expect_error=(IndexError, "Array slice indices must have static")),
|
||||
_make_harness("index_in_dim", "idx=pos",
|
||||
lambda x: lax.index_in_dim(x, 0, axis=0, keepdims=False),
|
||||
_make_harness("image_resize", "linear_0",
|
||||
lambda x: jax.image.resize(x, (x.shape[0], 2 * x.shape[1], 2 * x.shape[2], x.shape[3]),
|
||||
method="linear"),
|
||||
[RandArg((3, 16, 32, 3), _f32)],
|
||||
poly_axes=[(1, 2)]),
|
||||
_make_harness("image_resize", "linear_to_fixed_dim",
|
||||
lambda x: jax.image.resize(x, (x.shape[0], 64, 64, x.shape[3]),
|
||||
method="linear"),
|
||||
[RandArg((3, 16, 32, 3), _f32)],
|
||||
poly_axes=[(1, 2)]),
|
||||
_make_harness("image_resize", "nearest_0",
|
||||
lambda x: jax.image.resize(x, (x.shape[0], 2 * x.shape[1], 2 * x.shape[2], x.shape[3]),
|
||||
method="nearest"),
|
||||
[RandArg((3, 5, 7, 3), _f32)],
|
||||
poly_axes=[(1, 2)]),
|
||||
_make_harness("index_in_dim", "0",
|
||||
lambda x: lax.index_in_dim(x, -1, axis=0, keepdims=False),
|
||||
[RandArg((3, 4), _f32)],
|
||||
poly_axes=[0]),
|
||||
_make_harness("index_in_dim", "idx=neg",
|
||||
@ -1734,7 +1759,7 @@ class ShapePolyPrimitivesTest(tf_test_util.JaxToTfTestCase):
|
||||
# to parameterized below.
|
||||
@primitive_harness.parameterized(
|
||||
_flatten_harnesses(_POLY_SHAPE_TEST_HARNESSES),
|
||||
#one_containing="getitem_op=poly_idx=slice-None-1"
|
||||
#one_containing="arange_start_no_dtype"
|
||||
)
|
||||
def test_prim(self, harness: Harness):
|
||||
args = harness.dyn_args_maker(self.rng())
|
||||
|
Loading…
x
Reference in New Issue
Block a user