Generalize reduce-window padding to support (lo, hi) pairs. (#3728)

* Generalize reduce-window padding to support (lo, hi) pairs, as XLA does.

This turns out to simplify the code slightly, too.

* Fix select_and_gather_add batching rule and test.

* Fix documentation text to refer to ReduceWindowWithGeneralPadding.
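
As a minimal sketch of the generalized API (illustrative values; only the
public functions touched in this diff are assumed):

    import jax.numpy as jnp
    from jax import lax

    x = jnp.arange(6.0)

    # String padding still works; it is canonicalized internally.
    y1 = lax.reduce_window(x, 0., lax.add, window_dimensions=(3,),
                           window_strides=(1,), padding='SAME')

    # Explicit (lo, hi) pairs, one per dimension, are now accepted directly.
    pairs = lax.padtype_to_pads(x.shape, (3,), (1,), 'SAME')  # [(1, 1)]
    y2 = lax.reduce_window(x, 0., lax.add, (3,), (1,), pairs)
    # y1 and y2 are identical.
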
Peter Hawkins 2020-07-13 09:49:52 -04:00 committed by GitHub
parent a9da06ce75
commit 71253ac4c1
4 changed files with 44 additions and 80 deletions


@@ -791,9 +791,6 @@ def _reduce_window(jax_f, reducer, init_val, operand, window_dimensions,
   # TODO(tomhennigan): tf2xla should have a shape inference function.
   out_shape = _reduce_window_shape(jax_f, operand, window_dimensions,
                                    window_strides, padding)
-  padding = lax.padtype_to_pads(_get_shape_from_tensor_or_array(operand),
-                                window_dimensions,
-                                window_strides, padding)
   a = tf.constant(0, operand.dtype)
   reducer_fn = reducer.get_concrete_function(a, a)
   out = tfxla.reduce_window(operand, tf.constant(init_val, operand.dtype),


@@ -174,8 +174,10 @@ def _pooling_layer(reducer, init_val, rescaler=None):
       strides = strides[:i] + (1,) + strides[i:]

     def init_fun(rng, input_shape):
+      padding_vals = lax.padtype_to_pads(input_shape, window_shape,
+                                         strides, padding)
       out_shape = lax.reduce_window_shape_tuple(input_shape, window_shape,
-                                                strides, padding)
+                                                strides, padding_vals)
       return out_shape, ()
     def apply_fun(params, inputs, **kwargs):
       out = lax.reduce_window(inputs, init_val, reducer, window_shape,
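
Illustrative use of the patched pooling layer, assuming the standard stax
MaxPool constructor built on _pooling_layer (shapes are made up):

    from jax import random
    from jax.experimental import stax

    init_fun, apply_fun = stax.MaxPool((2, 2), strides=(2, 2), padding='SAME')
    # padding stays a string at the user level; init_fun now converts it to
    # (lo, hi) pairs via lax.padtype_to_pads before computing the out shape.
    out_shape, params = init_fun(random.PRNGKey(0), (1, 8, 8, 3))
    # out_shape == (1, 4, 4, 3)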


@@ -1115,11 +1115,14 @@ def _reduce_and(operand: Array, axes: Sequence[int]) -> Array:

 def reduce_window(operand: Array, init_value: Array, computation: Callable,
                   window_dimensions: Shape, window_strides: Sequence[int],
-                  padding: str) -> Array:
-  """Wraps XLA's `ReduceWindow
+                  padding: Union[str, Sequence[Tuple[int, int]]]) -> Array:
+  """Wraps XLA's `ReduceWindowWithGeneralPadding
   <https://www.tensorflow.org/xla/operation_semantics#reducewindow>`_
   operator.
   """
+  if isinstance(padding, str):
+    padding = padtype_to_pads(operand.shape, window_dimensions,
+                              window_strides, padding)
   monoid_reducer = _get_monoid_window_reducer(computation, init_value)
   if monoid_reducer:
     return monoid_reducer(operand, window_dimensions, window_strides, padding)
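
With pairs accepted directly, callers can now express paddings the string
form cannot, e.g. asymmetric padding (an illustrative sketch):

    import jax.numpy as jnp
    from jax import lax

    x = jnp.arange(4.0)
    # Pad one element on the low side only: the padded input is
    # [-inf, 0, 1, 2, 3], so windows of size 2 yield [0., 1., 2., 3.].
    out = lax.reduce_window(x, -jnp.inf, lax.max, window_dimensions=(2,),
                            window_strides=(1,), padding=((1, 0),))
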
@@ -1142,63 +1145,67 @@ def _get_monoid_window_reducer(monoid_op: Callable, x: Array) -> Optional[Callab
   return None

 def _reduce_window_sum(operand: Array, window_dimensions: Shape,
-                       window_strides: Sequence[int], padding: str) -> Array:
+                       window_strides: Sequence[int],
+                       padding: Sequence[Tuple[int, int]]) -> Array:
   return reduce_window_sum_p.bind(
       operand, window_dimensions=tuple(window_dimensions),
-      window_strides=tuple(window_strides), padding=padding)
+      window_strides=tuple(window_strides), padding=tuple(padding))

 def _reduce_window_prod(operand: Array, window_dimensions: Shape,
-                        window_strides: Sequence[int], padding: str) -> Array:
+                        window_strides: Sequence[int],
+                        padding: Sequence[Tuple[int, int]]) -> Array:
   init_value = _const(operand, 1)
   jaxpr, consts = _reduction_jaxpr(mul, _abstractify(init_value))
   return reduce_window_p.bind(
       operand, init_value, jaxpr=jaxpr, consts=consts,
       window_dimensions=tuple(window_dimensions),
-      window_strides=tuple(window_strides), padding=padding)
+      window_strides=tuple(window_strides), padding=tuple(padding))

 def _reduce_window_max(operand: Array, window_dimensions: Shape,
-                       window_strides: Sequence[int], padding: str) -> Array:
+                       window_strides: Sequence[int],
+                       padding: Sequence[Tuple[int, int]]) -> Array:
   return reduce_window_max_p.bind(
       operand, window_dimensions=tuple(window_dimensions),
-      window_strides=tuple(window_strides), padding=padding)
+      window_strides=tuple(window_strides), padding=tuple(padding))

 def _reduce_window_min(operand: Array, window_dimensions: Shape,
-                       window_strides: Sequence[int], padding: str) -> Array:
+                       window_strides: Sequence[int],
+                       padding: Sequence[Tuple[int, int]]) -> Array:
   return reduce_window_min_p.bind(
       operand, window_dimensions=tuple(window_dimensions),
-      window_strides=tuple(window_strides), padding=padding)
+      window_strides=tuple(window_strides), padding=tuple(padding))

 def _select_and_scatter(operand: Array, select: Callable,
                         window_dimensions: Shape, window_strides: Sequence[int],
-                        padding: str, source: Array, init_value: Array,
-                        scatter: Callable) -> Array:
+                        padding: Sequence[Tuple[int, int]], source: Array,
+                        init_value: Array, scatter: Callable) -> Array:
   select_jaxpr, select_consts = _reduction_jaxpr(select, _abstractify(init_value))
   scatter_jaxpr, scatter_consts = _reduction_jaxpr(scatter, _abstractify(init_value))
   return select_and_scatter_p.bind(
       operand, source, init_value, select_jaxpr=select_jaxpr,
       select_consts=select_consts, scatter_jaxpr=scatter_jaxpr,
       scatter_consts=scatter_consts, window_dimensions=tuple(window_dimensions),
-      window_strides=tuple(window_strides), padding=padding)
+      window_strides=tuple(window_strides), padding=tuple(padding))

 def _select_and_scatter_add(source: Array, operand: Array,
                             select_prim: core.Primitive,
                             window_dimensions: Shape,
                             window_strides: Sequence[int],
-                            padding: str) -> Array:
+                            padding: Sequence[Tuple[int, int]]) -> Array:
   return select_and_scatter_add_p.bind(
       source, operand, select_prim=select_prim,
       window_dimensions=tuple(window_dimensions),
-      window_strides=tuple(window_strides), padding=padding)
+      window_strides=tuple(window_strides), padding=tuple(padding))

 def _select_and_gather_add(tangents: Array, operand: Array,
                            select_prim: core.Primitive,
                            window_dimensions: Shape,
                            window_strides: Sequence[int],
-                           padding: str) -> Array:
+                           padding: Sequence[Tuple[int, int]]) -> Array:
   return select_and_gather_add_p.bind(
       tangents, operand, select_prim=select_prim,
       window_dimensions=tuple(window_dimensions),
-      window_strides=tuple(window_strides), padding=padding)
+      window_strides=tuple(window_strides), padding=tuple(padding))

 def cumsum(operand: Array, axis: int) -> Array:
   """Computes a cumulative sum along `axis`."""
@@ -4450,12 +4457,9 @@ def _reduce_window_shape_rule(operand, init_value, *, jaxpr, consts,
 def _reduce_window_translation_rule(c, operand, init_value, *, jaxpr, consts,
                                     window_dimensions, window_strides, padding):
   xla_computation = _reduction_computation(c, jaxpr, consts, init_value)
-  pads = xc.window_padding_type_to_pad_values(
-      padding, c.get_shape(operand).dimensions(), window_dimensions,
-      window_strides)
   return xops.ReduceWindowWithGeneralPadding(
       operand, init_value, xla_computation, window_dimensions,
-      window_strides, (), (), pads)
+      window_strides, (), (), padding)

 def _generic_reduce_window_batch_rule(
   batched_args, batch_dims, *, jaxpr, consts, window_dimensions,
@@ -4494,29 +4498,24 @@ def _reduce_window_sum_translation_rule(c, operand, *, window_dimensions,
                                         window_strides, padding):
   dtype = c.get_shape(operand).numpy_dtype()
   scalar = ShapedArray((), dtype)
-  pads = xc.window_padding_type_to_pad_values(
-      padding, c.get_shape(operand).dimensions(), window_dimensions,
-      window_strides)
   return xops.ReduceWindowWithGeneralPadding(
       operand, xb.constant(c, onp.array(0, dtype)),
       xla.primitive_subcomputation(add_p, scalar, scalar), window_dimensions,
-      window_strides, (), (), pads)
+      window_strides, (), (), padding)

 def _reduce_window_sum_transpose_rule(cotangent, operand, *, window_dimensions,
                                       window_strides, padding):
   assert ad.is_undefined_primal(operand)
   input_shape = operand.aval.shape
-  in_pads = padtype_to_pads(input_shape, window_dimensions, window_strides,
-                            padding)
   ones = [1] * len(input_shape)
   pads = _conv_general_vjp_lhs_padding(
-      input_shape, window_dimensions, window_strides, cotangent.shape, in_pads,
+      input_shape, window_dimensions, window_strides, cotangent.shape, padding,
       ones, ones)
   padding_config = [(lo, hi, stride - 1)
                     for (lo, hi), stride in zip(pads, window_strides)]
   pad_cotangent = pad(cotangent, _zero(cotangent), padding_config)
   result = _reduce_window_sum(pad_cotangent, window_dimensions, ones,
-                              xla_client.PaddingType.VALID)
+                              [(0, 0)] * len(input_shape))
   assert result.shape == input_shape
   return [result]
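
A quick numeric check of the transpose rule above (a sketch using public
APIs, not part of the diff):

    import jax
    import jax.numpy as jnp
    from jax import lax

    def sumpool(x):
      return lax.reduce_window(x, 0., lax.add, (2,), (1,), [(0, 0)])

    x = jnp.arange(4.0)
    _, vjp = jax.vjp(sumpool, x)
    # Each input element feeds every window that covers it:
    # vjp(jnp.ones(3)) == ([1., 2., 2., 1.],)
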
@@ -4529,10 +4528,8 @@ def _reduce_window_batch_rule(reduce_window, batched_args, bdims, *,
     window_dimensions = \
         window_dimensions[:bdim] + (1,) + window_dimensions[bdim:]
     window_strides = window_strides[:bdim] + (1,) + window_strides[bdim:]
-  operand = reduce_window(
-      operand, window_dimensions, window_strides, padding)
+    padding = padding[:bdim] + ((0, 0),) + padding[bdim:]
+  operand = reduce_window(operand, window_dimensions, window_strides, padding)
   return operand, bdim

 reduce_window_sum_p = standard_primitive(
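
The rule above gives the batch dimension a size-1 window, stride 1, and
(0, 0) padding, so vmap over a pooled function acts row-wise (illustrative):

    import jax
    import jax.numpy as jnp
    from jax import lax

    def maxpool(x):
      return lax.reduce_window(x, -jnp.inf, lax.max, (2,), (2,), [(0, 0)])

    xs = jnp.arange(8.0).reshape(2, 4)   # batch of two length-4 rows
    ys = jax.vmap(maxpool)(xs)           # shape (2, 2), rows pooled separately
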
@@ -4546,13 +4543,10 @@ def _reduce_window_chooser_translation_rule(
     prim, identity, c, operand, *, window_dimensions, window_strides, padding):
   dtype = c.get_shape(operand).numpy_dtype()
   scalar = ShapedArray((), dtype)
-  pads = xc.window_padding_type_to_pad_values(
-      padding, c.get_shape(operand).dimensions(), window_dimensions,
-      window_strides)
   return xops.ReduceWindowWithGeneralPadding(
       operand, xb.constant(c, identity(dtype)),
       xla.primitive_subcomputation(prim, scalar, scalar), window_dimensions,
-      window_strides, (), (), pads)
+      window_strides, (), (), padding)

 def _reduce_window_chooser_jvp_rule(prim, g, operand, *, window_dimensions,
                                     window_strides, padding):
@@ -4580,8 +4574,7 @@ def _common_reduce_window_shape_rule(operand, window_dimensions,

 def reduce_window_shape_tuple(operand_shape, window_dimensions, window_strides,
                               padding):
-  pads = padtype_to_pads(operand_shape, window_dimensions, window_strides, padding)
-  operand_padded = onp.add(operand_shape, onp.add(*zip(*pads)))
+  operand_padded = onp.add(operand_shape, onp.add(*zip(*padding)))
   t = onp.floor_divide(
       onp.subtract(operand_padded, window_dimensions), window_strides) + 1
   return tuple(t)
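
A worked instance of the shape rule above (illustrative numbers):

    import numpy as onp

    operand_shape, window_dimensions, window_strides = (10,), (3,), (2,)
    padding = ((1, 1),)
    operand_padded = onp.add(operand_shape, onp.add(*zip(*padding)))   # [12]
    t = onp.floor_divide(
        onp.subtract(operand_padded, window_dimensions), window_strides) + 1
    # t == [5]: floor((10 + 1 + 1 - 3) / 2) + 1
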
@@ -4624,11 +4617,8 @@ def _select_and_scatter_translation(
     scatter_consts, window_dimensions, window_strides, padding):
   select = _reduction_computation(c, select_jaxpr, select_consts, init_value)
   scatter = _reduction_computation(c, scatter_jaxpr, scatter_consts, init_value)
-  pads = xc.window_padding_type_to_pad_values(
-      padding, c.get_shape(operand).dimensions(), window_dimensions,
-      window_strides)
   return xops.SelectAndScatterWithGeneralPadding(
-      operand, select, window_dimensions, window_strides, pads, source,
+      operand, select, window_dimensions, window_strides, padding, source,
       init_value, scatter)

 select_and_scatter_p = standard_primitive(
@@ -4649,11 +4639,8 @@ def _select_and_scatter_add_translation(
   select = xla.primitive_subcomputation(select_prim, scalar, scalar)
   scatter = xla.primitive_subcomputation(add_p, scalar, scalar)
   zero = xb.constant(c, onp.array(0, dtype))
-  pads = xc.window_padding_type_to_pad_values(
-      padding, c.get_shape(operand).dimensions(), window_dimensions,
-      window_strides)
   return xops.SelectAndScatterWithGeneralPadding(
-      operand, select, window_dimensions, window_strides, pads, source, zero,
+      operand, select, window_dimensions, window_strides, padding, source, zero,
       scatter)

 def _select_and_scatter_add_jvp(
@@ -4834,12 +4821,9 @@ def _select_and_gather_add_translation(
   assert select_prim is ge_p or select_prim is le_p, select_prim
   init = -onp.inf if select_prim is ge_p else onp.inf
-  pads = xc.window_padding_type_to_pad_values(
-      padding, c.get_shape(operand).dimensions(), window_dimensions,
-      window_strides)
   out = xops.ReduceWindowWithGeneralPadding(
       pack(operand, tangents), pack(const(c, dtype, init), const(c, dtype, 0)),
-      reducer(), window_dimensions, window_strides, (), (), pads)
+      reducer(), window_dimensions, window_strides, (), (), padding)
   return snd(out)

 def _select_and_gather_add_jvp(
@@ -4878,6 +4862,7 @@ def _select_and_gather_add_batching_rule(
   x = batching.bdim_at_front(x, x_bdim, size)
   window_dimensions = (1,) + window_dimensions
   window_strides = (1,) + window_strides
+  padding = ((0, 0),) + padding
   out = _select_and_gather_add(t, x, select_prim, window_dimensions,
                                window_strides, padding)
   return (out, 0)
@@ -4979,36 +4964,15 @@ def _cumred_tpu_translation_rule(window_reduce: Callable, x, *,
                                  axis: int):
   # On TPU, an implementation using reduce_window is handled specially by the
   # compiler and is efficient. On other backends, it is O(n^2).
-  if window_reduce is _reduce_window_max:
-    if onp.issubdtype(x.dtype, onp.integer):
-      unit = onp.iinfo(x.dtype).min
-    elif onp.issubdtype(x.dtype, onp.bool_):
-      unit = False
-    else:  # inexact
-      unit = -onp.inf
-  elif window_reduce is _reduce_window_min:
-    if onp.issubdtype(x.dtype, onp.integer):
-      unit = onp.iinfo(x.dtype).max
-    elif onp.issubdtype(x.dtype, onp.bool_):
-      unit = True
-    else:  # inexact
-      unit = onp.inf
-  elif window_reduce is _reduce_window_sum:
-    unit = 0
-  elif window_reduce is _reduce_window_prod:
-    unit = 1
-  else:
-    raise ValueError("Unknown type of reducer, get {}".format(window_reduce))
   n = x.shape[axis]
   if n == 0:
     return x
-  padding = [(0, 0, 0)] * x.ndim
-  padding[axis] = (n - 1, 0, 0)
-  x = pad(x, _const(x, unit), padding)
+  padding = [(0, 0)] * x.ndim
+  padding[axis] = (n - 1, 0)
   strides = [1] * x.ndim
   window_dims = [1] * x.ndim
   window_dims[axis] = n
-  return window_reduce(x, window_dims, strides, xla_client.PaddingType.VALID)
+  return window_reduce(x, window_dims, strides, padding)

 def _cumred_batch_rule(prim, batched_args, batch_dims, *, axis: int):
   operand, = batched_args
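
The windowed cumulative reduction above can be replayed with public APIs:
padding n - 1 elements on the low side makes output i reduce exactly
x[0..i] (a sketch; values are illustrative):

    import jax.numpy as jnp
    from jax import lax

    x = jnp.array([1., 2., 3.])
    n = x.shape[0]
    out = lax.reduce_window(x, 0., lax.add, window_dimensions=(n,),
                            window_strides=(1,), padding=((n - 1, 0),))
    # out == [1., 3., 6.] == jnp.cumsum(x)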


@@ -566,8 +566,9 @@ class LaxVmapTest(jtu.JaxTestCase):
           [(1, 2, 2, 1), (1, 1, 1, 1)]))
     def fun(operand, tangents):
+      pads = lax.padtype_to_pads(operand.shape, dims, strides, padding)
       return lax._select_and_gather_add(operand, tangents, lax.ge_p, dims,
-                                        strides, padding)
+                                        strides, pads)

     for shape, dims, strides in all_configs:
       for bdims in all_bdims(shape, shape):