[jax2tf] lax.reduce_window (enable_xla=False): bug fixes and improvements.

* Fixes https://github.com/google/jax/issues/11804: we only supported `lax.reduce_window` on operands without batch and channel dimensions, which was wrong. Operands with batch and channel dimensions are valid, and in fact what most users pass (this case is not spelled out in the [operational semantics for XLA::ReduceWindow](https://www.tensorflow.org/xla/operation_semantics#reducewindow)). I have fixed this and reworked a number of test cases to cover batch and channel dimensions.
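
  For illustration, a minimal sketch of the now-supported case (the shapes and values here are mine, not from the commit): a max reduction over an NHWC operand, where the batch and channel dimensions use window size 1.

  ```python
  import jax.numpy as jnp
  from jax import lax

  # Hypothetical NHWC input: batch=2, 4x4 spatial, 3 channels.
  x = jnp.arange(2 * 4 * 4 * 3, dtype=jnp.float32).reshape(2, 4, 4, 3)

  # 2x2 max pooling over the spatial dimensions only.
  y = lax.reduce_window(x, -jnp.inf, lax.max,
                        window_dimensions=(1, 2, 2, 1),
                        window_strides=(1, 2, 2, 1),
                        padding="VALID")
  print(y.shape)  # (2, 2, 2, 3)
  ```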

* Also, @sdenton4 provided a failing example in a Colab that uses polymorphic dimensions. I've added it as a test case to make sure it now works.
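
  Roughly, a standalone sketch of that scenario (my own shapes; it mirrors the new polymorphic `reduce_window` `add` harness added to shape_poly_test.py below, which uses the reshape trick to get a 16*b dimension):

  ```python
  import numpy as np
  import tensorflow as tf
  import jax.numpy as jnp
  from jax import lax
  from jax.experimental import jax2tf

  def f(x):  # x: f32[1, b, 16]; the reshape turns it into f32[1, 16*b, 1]
    x = jnp.reshape(x, (1, -1, 1))
    return lax.reduce_window(x, 0., lax.add, (1, 4, 1), (1, 2, 1), "SAME")

  # Convert with a symbolic dimension `b` and without falling back to XLA ops.
  f_tf = jax2tf.convert(f, polymorphic_shapes=["(1, b, 16)"], enable_xla=False)
  f_tf = tf.function(f_tf, autograph=False,
                     input_signature=[tf.TensorSpec([1, None, 16], tf.float32)])
  print(f_tf(np.ones((1, 128, 16), np.float32)).shape)  # (1, 1024, 1)
  ```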

* Adds support for explicit padding, reusing the existing padding logic from convolutions.
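
  On the JAX side, explicit padding is a (low, high) pair per dimension rather than the `"VALID"`/`"SAME"` strings. A small made-up example of the kind of call that is now convertible with `enable_xla=False`:

  ```python
  import jax.numpy as jnp
  from jax import lax

  x = jnp.ones((1, 8, 8, 1), jnp.float32)  # NHWC

  # Explicit (low, high) padding for each of the four dimensions.
  y = lax.reduce_window(x, -jnp.inf, lax.max,
                        window_dimensions=(1, 3, 3, 1),
                        window_strides=(1, 2, 2, 1),
                        padding=((0, 0), (1, 1), (1, 1), (0, 0)))
  print(y.shape)  # (1, 4, 4, 1)
  ```

  With `enable_xla=False` the converter now handles this case by padding the operand with `tf.pad` and then running a `VALID` pooling, the same approach already used for convolutions.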

* Fixes https://github.com/google/jax/issues/11874: we were not handling SAME padding for `lax.add` correctly, because we used `tf.nn.avg_pool`, which excludes padded values from its averages, so scaling the result back up to a window sum over-counts near the edges (see the issue for more details). I resolved this by padding the operand manually, and added some additional tests for it.
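
  A small numeric illustration of the problem (my own example, calling `tf.nn.avg_pool1d` directly rather than going through the converter):

  ```python
  import numpy as np
  import tensorflow as tf

  x = np.ones((1, 4, 1), np.float32)  # NWC: a single row of ones

  # A SAME-padded window sum with window 2 and stride 1 should be [2, 2, 2, 1]:
  # the last window covers one real value plus one padded position.
  avg = tf.nn.avg_pool1d(x, ksize=2, strides=1, padding="SAME")

  # tf.nn.avg_pool averages only over the non-padded values, so multiplying by
  # the window size over-counts at the padded edge.
  print((avg * 2).numpy().ravel())  # [2. 2. 2. 2.], not [2. 2. 2. 1.]
  ```

  The fix pads the operand manually and switches to `VALID` pooling, so the edge windows sum over explicit zeros instead of being rescaled averages.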

* Ensures we call `eval_shape` on shapes containing dimension polynomials before passing them to TF ops.

* Fixes https://github.com/google/jax/issues/11929#issuecomment-1216261697: we weren't running any of the `shape_poly_test.py` tests with `enable_xla=False`.

PiperOrigin-RevId: 467879449
Marc van Zee, 2022-08-16 03:01:17 -07:00 (committed by jax authors)
parent da168a100a
commit df5f3c556c
7 changed files with 214 additions and 87 deletions

View File

@@ -150,20 +150,29 @@ function `lax.reduce_window_p` with the following conditions:
We provide partial support for all these ops, with the following limitations:
* `computation` should be one of `lax.min`, `lax.max`, or `lax.add`.
* For `lax.min` and `lax.max`, dtypes `np.bool`, `np.uint32`, `np.uint64`,
`np.complex64`, and `np.complex128` are not supported.
* Additionally, for `lax.min`, dtypes `np.uint8` and `np.uint16` are not
supported.
* For `lax.add`, only dtypes `np.float16`, `np.float32`, and `np.float64` are
supported.
* We support at most 2 spatial dimension.
* Base dilations other than `(1,) * len(operand)` are not supported.
* `padding` should either be `VALID` or `SAME`.
* Using `lax.add` on TPU may give very large deviations. This is due to the way
the conversion is implemented (first take the average over the window and then
multiply by window size). This gives large deviations on TPU due to the fact
that it uses `bfloat16` for computations.
* `computation` should be one of `lax.min`, `lax.max`, or `lax.add`.
* For `lax.min` and `lax.max`, dtypes `np.bool`, `np.uint32`, `np.uint64`,
`np.complex64`, and `np.complex128` are not supported.
* Additionally, for `lax.min`, dtypes `np.uint8` and `np.uint16` are not
supported.
* For `lax.add`, only dtypes `np.float16`, `np.float32`, and `np.float64` are
supported.
* We support at most 2 spatial dimensions.
* Base dilations other than `(1,) * len(operand)` are not supported.
* `padding` should either be `VALID` or `SAME`.
* We compute `lax.reduce_window_sum_p` by calling `tf.nn.avg_pool` (through
`tf.nn.pool`), and then multiplying the result by
`np.prod(window_dimensions)`. If you are using an NN library that implements
`avg_pool` using `lax.reduce_window` (such as Flax's
[pooling.py](https://github.com/google/flax/blob/main/flax/linen/pooling.py)),
this is usually implemented by dividing the result by
`np.prod(window_dimensions)`. So when converting this function, the
resulting computation for `avg_pool` is `(tf.nn.avg_pool(xs) *
np.prod(window)) / np.prod(window)`. This is redundant and can be optimized.
* Using `lax.add` on TPU may give very large deviations. This is due to the
way the conversion is implemented (we first take the average over the window
and then multiply by the window size), which loses precision on TPU because
computations there use `bfloat16`.
We implement all reductions using the TensorFlow function
[tf.nn.pool](https://www.tensorflow.org/api_docs/python/tf/nn/pool).

View File

@@ -120,10 +120,11 @@ def pads_to_padtype(in_shape, window_shape, window_strides, padding) -> str:
def _pad_spatial_dims(x, x_shape, padding):
"""Pads `x` using `padding`, which specifies padding for the spatial dimensions."""
# Add empty padding for batch and feature dimensions.
no_pad = ((0, 0),)
padding = tuple(padding)
padding = no_pad + padding + no_pad
if len(padding) == len(x_shape) - 2:
# If necessary, add empty padding for batch and feature dimensions.
no_pad = ((0, 0),)
padding = no_pad + padding + no_pad
x = tf.pad(x, padding)
assert len(x.shape) == len(padding)
x_shape = tuple(p0 + xs + p1 for xs, (p0, p1) in zip(x_shape, padding))
@@ -517,11 +518,9 @@ tf_impl_no_xla[lax.argmin_p] = partial(_argminmax, True)
tf_impl_no_xla[lax.argmax_p] = partial(_argminmax, False)
def _reduce_monoid(operand, window_dimensions, window_strides, padding,
base_dilation, window_dilation, computation_name,
_in_avals: Sequence[core.ShapedArray],
_out_aval: core.ShapedArray):
dtype = operand.dtype
def _validate_reduce_window_inputs(operand_shape, computation_name, dtype,
window_dimensions, window_strides,
base_dilation, window_dilation):
if computation_name not in ["min", "max", "add"]:
raise _reduce_error("Reduction function should be either min, max, or add.")
if computation_name in ["min", "max"] and dtype in [
@@ -540,35 +539,117 @@ def _reduce_monoid(operand, window_dimensions, window_strides, padding,
raise _reduce_error("Add pooling does not support operands of type "
f"{dtype}")
# In presence of shape polymorphism, operand.shape may contain None. The
# actual dimension polynomial shapes are in _in_avals.
operand_shape = _in_avals[0].shape
if (len(operand_shape) != len(window_dimensions) != len(window_strides) !=
len(window_dilation)):
raise _reduce_error("Input shapes, window dimensions, window stride "
"dimensions, and window dilation dimensions should "
"match.")
has_only_spatial_dims = True
if len(operand_shape) > 4:
raise _reduce_error("Only 1D or 2D input are supported.")
if len(operand_shape) > 2:
# operand_shape = (batch, spatial_dims, ..., channel).
has_only_spatial_dims = False
for name, value in [("window_dimensions", window_dimensions),
("window_strides", window_strides),
("window_dilation", window_dilation)]:
if value[0] != value[-1] != 1:
raise _reduce_error("Only 1D or 2D input are supported, expected "
f"{name}=(1, spatial_dims, ..., 1), but got "
f"{value}")
if list(base_dilation) != [1] * len(operand_shape):
# TODO(marcvanzee): Add support for base dilations. We can do this using
# a scatter on operand.
raise _reduce_error("Unimplemented support for base dilation.")
return has_only_spatial_dims
def _padding_reduce_window(operand, operand_shape, computation_name,
window_dimensions, window_strides, padding):
padding_type = pads_to_padtype(operand_shape, window_dimensions,
window_strides, padding)
if padding_type == "EXPLICIT":
# TODO(marcvanzee): Add support for explicit padding. This can be done
# similarly like we did for convolutions.
raise _reduce_error("Only 'VALID' and 'SAME' padding are currently "
"supported.")
def tf_pool(op, pooling_type):
# Add batch and channel dimensions, these are expected by TF.
op = tf.reshape(op, (1,) + operand_shape + (1,))
op = tf.nn.pool(
input=op,
# https://github.com/google/jax/issues/11874.
needs_manual_padding = (
padding_type == "SAME" and computation_name == "add" and
window_dimensions != [1] * len(operand_shape))
if needs_manual_padding or padding_type == "EXPLICIT":
operand, operand_shape = _pad_spatial_dims(operand, operand_shape, padding)
padding_type = "VALID"
return operand, operand_shape, padding_type
def _reshape_reduce_window(operand, operand_shape, window_dimensions,
window_strides, window_dilation, *,
has_only_spatial_dims):
# Reshape inputs so they are accepted by tf.nn.pool, which expects batch and
# channel dimensions for operand but not for any of the other inputs.
if has_only_spatial_dims: # len(operand_shape) <= 2
# Call eval_shape on a shape that may contain polynomials, otherwise TF does
# not know what to do with polynomials in the shape.
operand_shape = jax2tf._eval_shape(operand_shape)
# Add batch and channel dimensions to operand.
operand = tf.reshape(operand, (1,) + operand_shape + (1,))
else:
# This branch assumes operand.shape = (batch, spatial_dims, ..., channel),
# and dimensions, strides, dilation are all (1, spatial_values, ..., 1).
# Input validation for this is done in _validate_reduce_window_inputs.
window_dimensions = window_dimensions[1:-1]
window_strides = window_strides[1:-1]
window_dilation = window_dilation[1:-1]
return operand, window_dimensions, window_strides, window_dilation
def _reduce_monoid(operand, window_dimensions, window_strides, padding,
base_dilation, window_dilation, computation_name,
_in_avals: Sequence[core.ShapedArray],
_out_aval: core.ShapedArray):
dtype = operand.dtype
# In presence of shape polymorphism, operand.shape may contain None. The
# actual dimension polynomial shapes are in _in_avals.
operand_shape = _in_avals[0].shape
# TODO(marcvanzee): Put reduce_window arguments into dataclass, similar to
# Gather, to simplify function calls.
has_only_spatial_dims = _validate_reduce_window_inputs(
operand_shape, computation_name, dtype, window_dimensions, window_strides,
base_dilation, window_dilation)
operand, operand_shape, padding_type = _padding_reduce_window(
operand, operand_shape, computation_name, window_dimensions,
window_strides, padding)
operand, window_dimensions, window_strides, dilations = _reshape_reduce_window(
operand,
operand_shape,
window_dimensions,
window_strides,
window_dilation,
has_only_spatial_dims=has_only_spatial_dims)
def tf_pool(inputs, pooling_type):
result = tf.nn.pool(
inputs,
window_shape=window_dimensions,
pooling_type=pooling_type,
padding=padding_type,
strides=window_strides,
dilations=window_dilation)
op = tf.reshape(op, jax2tf._aval_to_tf_shape(_out_aval))
return op
dilations=dilations)
if has_only_spatial_dims:
# If the input only had spatial dimensions we need to contract the batch
# and channel dimensions before returning the output.
result = tf.squeeze(result, [0, -1])
jax2tf._assert_matching_abstract_shape(result, _out_aval.shape)
return result
negate = lambda x: tf.multiply(x, tf.constant(-1, dtype))
if computation_name == "max":
@@ -577,8 +658,8 @@ def _reduce_monoid(operand, window_dimensions, window_strides, padding,
return negate(tf_pool(negate(operand), "MAX"))
elif computation_name == "add":
# TODO(marcvanzee): This may give very large deviations on TPU when using
# floats as inputs. We should think of a different implementation if users
# run into this often.
# floats as inputs. Alternatively, we could implement this using a
# convolution with an all-1's kernel.
return tf.multiply(tf_pool(operand, "AVG"), np.prod(window_dimensions))

View File

@@ -124,25 +124,20 @@ class Jax2TfLimitation(primitive_harness.Limitation):
# We keep here the explicit set of groups for which we don't have limitations
harness_groups_no_limitations = {
"abs", "add", "add_any", "and", "atan2",
"bitcast_convert_type", "broadcast", "broadcast_in_dim", "cbrt", "ceil",
"clamp", "concatenate", "cos", "cosh", "complex", "conj",
"convert_element_type",
"cummax", "cummin", "device_put", "dynamic_slice",
"dynamic_update_slice", "exp", "eq", "floor", "gather", "ge", "gt",
"imag",
"iota", "is_finite", "le", "lt", "log", "mul", "ne", "neg", "not",
"or", "pad", "population_count",
"abs", "add", "add_any", "and", "atan2", "bitcast_convert_type",
"broadcast", "broadcast_in_dim", "cbrt", "ceil", "clamp", "concatenate",
"cos", "cosh", "complex", "conj", "convert_element_type", "cummax",
"cummin", "device_put", "dynamic_slice", "dynamic_update_slice", "exp",
"eq", "floor", "gather", "ge", "gt", "imag", "iota", "is_finite", "le",
"lt", "log", "mul", "ne", "neg", "not", "or", "pad", "population_count",
"random_categorical", "random_split", "random_uniform", "random_randint",
"reduce",
"reduce_and", "reduce_prod", "reduce_or", "reduce_sum",
"reduce_window_mul", "reduce_window_min",
"reduce_window_max",
"real", "reshape", "rev", "rsqrt", "scatter_max", "scatter_min",
"select_n", "select_and_scatter_add",
"shift_left", "shift_right_logical", "shift_right_arithmetic", "sign",
"sin", "sinh", "slice", "sqrt", "squeeze", "stop_gradient", "sub",
"tie_in", "transpose", "xor", "zeros_like"
"reduce", "reduce_and", "reduce_prod", "reduce_or", "reduce_sum",
"reduce_window_mul", "reduce_window_min", "reduce_window_max", "real",
"reshape", "rev", "rsqrt", "scatter_max", "scatter_min", "select_n",
"select_and_scatter_add", "shift_left", "shift_right_logical",
"shift_right_arithmetic", "sign", "sin", "sinh", "slice", "sqrt",
"squeeze", "stop_gradient", "sub", "tie_in", "transpose", "xor",
"zeros_like"
}
@classmethod
@@ -910,6 +905,15 @@ class Jax2TfLimitation(primitive_harness.Limitation):
@classmethod
def reduce_window_add(cls, harness: primitive_harness.Harness):
return [
Jax2TfLimitation(
"Small deviations on GPU for large inputs and enable_xla=False",
dtypes=[np.float32],
devices="gpu",
modes=("eager", "graph", "compiled"),
expect_tf_error=False,
skip_comparison=False,
enabled=not harness.params["enable_xla"],
tol=3e-5),
Jax2TfLimitation(
"Large deviations on TPU for enable_xla=False",
dtypes=[np.float16, np.float32],

View File

@@ -828,7 +828,7 @@ for prim in [lax.argmin_p, lax.argmax_p]:
for index_dtype in jtu.dtypes.all_integer + jtu.dtypes.all_unsigned:
_make_argminmax_harness(prim, "index_dtype", index_dtype=index_dtype)
# Some special cases, with equal elements and NaN
# Some special cases, with equal elements and NaN
for name, operand in [
("nan_0", np.array([np.nan, np.nan, 2., -2., -np.nan, -np.nan], np.float32)),
("nan_1", np.array([np.nan, -np.nan, 2., -2.], np.float32)),
@@ -2488,19 +2488,22 @@ _make_reduce_window_harness("base_dilation", base_dilation=(1, 2),
requires_xla=True)
# Validate window_dilation
_make_reduce_window_harness("window_dilation", window_dilation=(1, 2))
# Validate squeezing behavior and dimensions in tf.nn.max_pool
for shape, window_dimensions in [
((2,), (2,)), # 1 spatial dimension, left and right squeeze
((2, 1), (2, 1)), # 1 spatial dimension, left squeeze
((1, 2), (1, 2)), # 1 spatial dimension, right squeeze
((1, 2, 1), (1, 2, 1)), # 1 spatial dimension no squeeze
((2, 4), (2, 2)), # 2 spatial dimensions, left and right squeeze
((2, 4, 3), (2, 2, 2)), # 3 spatial dimensions, left and right squeeze
((1, 4, 3, 2, 1), (1, 2, 2, 2, 1)) # 3 spatial dimensions, no squeeze
# Validate batch and channel dimensions behavior. lax.reduce_window accepts
# inputs that either have or do not have batch and channel dimensions.
# N=batch, DHW=spatial, C=channel.
# Without XLA only supports 1D/2D reductions.
for shape, window_dimensions, requires_xla in [
((2,), (2,), False), # W
((2, 1), (2, 1), False), # WC
((1, 2), (1, 2), False), # NW
((1, 2, 1), (1, 2, 1), False), # NWC
((2, 4), (2, 2), False), # HW
((1, 2, 4, 1), (1, 2, 2, 1), False), # NHWC
((2, 4, 3), (2, 2, 2), True), # DHW
((1, 4, 3, 2, 1), (1, 2, 2, 2, 1), True) # NDHWC
]:
requires_xla = len(shape) > 2 # Without XLA only supports 1D/2D reductions.
_make_reduce_window_harness(
"squeeze_dim",
"batch_channel_dims",
computation=lax.max,
shape=shape,
dtype=np.float32,
@@ -2512,17 +2515,31 @@ for shape, window_dimensions in [
window_dimensions=window_dimensions,
requires_xla=requires_xla)
# This corresponds to SAME padding.
_make_reduce_window_harness(
"same_padding",
shape=(112, 112),
init_value=-np.inf,
computation=lax.max,
window_dimensions=(3, 3),
window_strides=(2, 2),
padding="SAME")
for computation, id_value in [(lax.max, _get_max_identity(np.float32)),
(lax.min, _get_min_identity(np.float32)),
(lax.add, 0.)]:
_make_reduce_window_harness(
"same_padding",
shape=(112, 112),
init_value=id_value,
computation=computation,
window_dimensions=(3, 3),
window_strides=(2, 2),
padding="SAME")
# A few additional test cases for manual padding, which is applied when calling
# reduce_window with lax.add, SAME padding and window_dimensions != (1, 1, ...).
for window_dimensions, window_strides in [((2, 2), (1, 1)), ((3, 3), (2, 2)),
((13, 13), (5, 6))]:
_make_reduce_window_harness(
"manual_padding",
shape=(12, 12),
init_value=0.,
computation=lax.add,
window_dimensions=window_dimensions,
window_strides=window_strides,
padding="SAME")
# b/240647139
_make_reduce_window_harness(
"init_value_1d",
shape=(1, 16000),

View File

@@ -102,7 +102,7 @@ class JaxPrimitiveTest(tf_test_util.JaxToTfTestCase):
@primitive_harness.parameterized(
primitive_harness.all_harnesses,
include_jax_unimpl=False,
#one_containing="reduce_window_max",
#one_containing="reduce_window_add_same_padding",
)
@jtu.ignore_warning(
category=UserWarning, message="Using reduced precision for gradient.*")

View File

@@ -1163,7 +1163,8 @@ def _make_harness(group_name: str, name: str,
match the expected exception string.
enable_and_disable_xla=True means that we generate two harnesses,
one with enable_xla=False.
one with enable_xla=False and one with enable_xla=True. Otherwise we create
only one harness with enable_xla=True.
"""
if enable_and_disable_xla:
return [
@@ -1189,7 +1190,8 @@ dtype=np.float32,
dtype=np.float32,
poly_axes=poly_axes, check_result=check_result,
skip_jax_run=skip_jax_run, expect_error=expect_error,
tol=tol)
tol=tol,
**params)
_f32 = np.float32
@@ -1637,13 +1639,26 @@ _POLY_SHAPE_TEST_HARNESSES = [
lambda x: lax.reduce_window(x, np.array(1., _f32), lax.min,
(2, 2), (1, 1), "VALID"),
[RandArg((3, 8), _f32)],
poly_axes=[0]),
poly_axes=[0],
enable_and_disable_xla=True),
_make_harness("reduce_window", "add",
# x.shape = (b, 8)
lambda x: lax.reduce_window(x, 0, lax.add, (2, 2), (1, 1),
"VALID"),
[RandArg((3, 8), _f32)],
poly_axes=[0]),
poly_axes=[0],
enable_and_disable_xla=True),
# https://github.com/google/jax/issues/11804
# Use the reshape trick to simulate a polymorphic dimension of 16*b.
# (See test "conv_general_dilated.1d_1" above for more details.)
_make_harness("reduce_window", "add",
# x.shape = (1, 16*b, 1)
lambda x: lax.reduce_window(
jnp.reshape(x, (1, -1, 1)),
0., lax.add, (1, 4, 1), (1, 2, 1), "SAME"),
[RandArg((1, 128, 16), _f32)],
poly_axes=[1],
enable_and_disable_xla=True),
# TODO(necula): not yet supported, but also unlikely to come up.
# _make_harness("random_uniform", "odd",
# lambda key, a: jax.random.uniform(key, (2 * a.shape[0] + 1, a.shape[1]),
@@ -1897,7 +1912,7 @@ class ShapePolyPrimitivesTest(tf_test_util.JaxToTfTestCase):
# to parameterized below.
@primitive_harness.parameterized(
_flatten_harnesses(_POLY_SHAPE_TEST_HARNESSES),
#one_containing="reduce_window_add",
#one_containing="reduce_window_add_noxla_poly_axes=[1]",
)
def test_prim(self, harness: Harness):
_test_one_harness(self, harness)

View File

@@ -392,7 +392,8 @@ class JaxToTfTestCase(jtu.JaxTestCase):
"""
f_tf = jax2tf.convert(f_jax, polymorphic_shapes=polymorphic_shapes,
enable_xla=enable_xla)
f_tf_func = tf.function(f_tf, autograph=False, input_signature=input_signature)
f_tf_func = tf.function(
f_tf, autograph=False, input_signature=input_signature)
concrete_f_tf = f_tf_func.get_concrete_function(*input_signature)
if expected_output_signature:
# Strangely, output_shapes can be a single shape for a function with a