Make lax.reduce_window variadic.

This mirrors the variadic support in lax.reduce(), where the operands and init_values may be pytrees. The new API is a strict superset of the current one, so existing users should not need to change anything.

Variadic lax.reduce_window() is currently supported only on CPU and TPU, not on GPU.
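
For example (a minimal sketch mirroring the new test added in this change; the array contents and window parameters are illustrative only):

import numpy as np
import jax.numpy as jnp
from jax import lax

x = jnp.arange(24, dtype=np.float32).reshape(4, 6)
y = jnp.ones((4, 6), np.float32)

def reducer(xs, ys):
  # xs and ys are tuples with one scalar per operand; the return value must
  # have the same tree structure as the operands.
  x1, x2 = xs
  y1, y2 = ys
  return (x1 + y1, lax.max(x2, y2))

init_values = (np.array(0, np.float32), np.array(-np.inf, np.float32))
window_sums, window_maxes = lax.reduce_window(
    (x, y), init_values, reducer,
    window_dimensions=(2, 2), window_strides=(1, 1), padding='VALID')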

PiperOrigin-RevId: 411632993
Peter Hawkins 2021-11-22 13:20:55 -08:00 committed by jax authors
parent ad6ce74d67
commit 5415306257
5 changed files with 206 additions and 96 deletions


@@ -1424,8 +1424,8 @@ def argmax(operand: Array, axis: int,
return argmax_p.bind(operand, axes=(axis,),
index_dtype=dtypes.canonicalize_dtype(index_dtype))
def reduce(operands: Array, init_values: Array, computation: Callable,
dimensions: Sequence[int]) -> Array:
def reduce(operands, init_values, computation: Callable,
dimensions: Sequence[int]):
"""Wraps XLA's `Reduce
<https://www.tensorflow.org/xla/operation_semantics#reduce>`_
operator.
@@ -1483,7 +1483,8 @@ def _variadic_reduction_jaxpr(computation, flat_avals, aval_tree):
jaxpr, _, consts = pe.trace_to_jaxpr_dynamic(flat_comp, tuple(flat_in_avals))
return jaxpr, tuple(consts), out_tree()
def _get_monoid_reducer(monoid_op: Callable, xs: Array) -> Optional[Callable]:
def _get_monoid_reducer(monoid_op: Callable,
xs: Sequence[Array]) -> Optional[Callable]:
if len(xs) != 1:
return None
x, = xs
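
For comparison, the existing variadic lax.reduce() that this change mirrors can be called roughly as follows (a sketch only; the values and names are illustrative, not part of this change):

import numpy as np
import jax.numpy as jnp
from jax import lax

x = jnp.arange(6, dtype=np.float32).reshape(2, 3)
y = jnp.ones((2, 3), np.float32)

def comp(accs, elts):
  # Both arguments and the return value share the operands' tree structure.
  a1, a2 = accs
  e1, e2 = elts
  return (a1 + e1, lax.max(a2, e2))

total, largest = lax.reduce((x, y), (np.float32(0), np.float32(-np.inf)),
                            comp, dimensions=(0, 1))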


@@ -24,6 +24,7 @@ from jax.interpreters import xla
from jax import core
from jax.core import (ShapedArray, ConcreteArray)
from jax import tree_util
from jax._src import ad_util
from jax._src import dtypes
@@ -36,7 +37,10 @@ from jax._src.lax.lax import (
)
from jax._src.lib import xla_bridge
from jax._src.lib import xla_client
import jax._src.util as util
map = util.safe_map
zip = util.safe_zip
xb = xla_bridge
xc = xla_client
@@ -45,7 +49,7 @@ xops = xla_client.ops
Array = Any
def reduce_window(operand: Array, init_value: Array, computation: Callable,
def reduce_window(operand, init_value, computation: Callable,
window_dimensions: core.Shape, window_strides: Sequence[int],
padding: Union[str, Sequence[Tuple[int, int]]],
base_dilation: Optional[Sequence[int]] = None,
@@ -54,31 +58,52 @@ def reduce_window(operand: Array, init_value: Array, computation: Callable,
<https://www.tensorflow.org/xla/operation_semantics#reducewindow>`_
operator.
"""
flat_operands, operand_tree = tree_util.tree_flatten(operand)
flat_init_values, init_value_tree = tree_util.tree_flatten(init_value)
if operand_tree != init_value_tree:
raise ValueError('Operands must have the same tree structure as '
f'init_values: {operand_tree} vs. {init_value_tree}')
if len(flat_operands) == 0:
raise ValueError('reduce_window must have at least one operand.')
if len(flat_operands) != len(flat_init_values):
raise ValueError('Must have same total number of operands as init_values: '
f' {len(flat_operands)} vs. {len(flat_init_values)}')
if isinstance(padding, str):
dilated_window_dims = (window_dimensions if window_dilation is None else
_dilate_shape(window_dimensions, window_dilation))
padding = tuple(lax.padtype_to_pads(operand.shape, dilated_window_dims,
window_strides, padding))
padding = tuple(lax.padtype_to_pads(
flat_operands[0].shape, dilated_window_dims, window_strides, padding))
else:
padding = tuple(padding)
if base_dilation is None:
base_dilation = (1,) * len(window_dimensions)
if window_dilation is None:
window_dilation = (1,) * len(window_dimensions)
monoid_reducer = _get_monoid_window_reducer(computation, init_value)
monoid_reducer = _get_monoid_window_reducer(computation, flat_init_values)
if monoid_reducer:
return monoid_reducer(operand, window_dimensions, window_strides, padding,
base_dilation, window_dilation)
else:
jaxpr, consts = lax._reduction_jaxpr(computation, lax._abstractify(init_value))
return reduce_window_p.bind(
operand, init_value, jaxpr=jaxpr, consts=consts,
flat_init_avals = map(lax._abstractify, flat_init_values)
jaxpr, consts, out_tree = lax._variadic_reduction_jaxpr(
computation, tuple(flat_init_avals), init_value_tree)
if operand_tree != out_tree:
raise ValueError(
'reduce_window output must have the same tree structure as the operands'
f' {operand_tree} vs. {out_tree}')
out_flat = reduce_window_p.bind(
*(flat_operands + flat_init_values), jaxpr=jaxpr, consts=consts,
window_dimensions=tuple(window_dimensions),
window_strides=tuple(window_strides), padding=padding,
base_dilation=tuple(base_dilation),
window_dilation=tuple(window_dilation))
return tree_util.tree_unflatten(out_tree, out_flat)
def _get_monoid_window_reducer(monoid_op: Callable, x: Array) -> Optional[Callable]:
def _get_monoid_window_reducer(monoid_op: Callable,
xs: Sequence[Array]) -> Optional[Callable]:
if len(xs) != 1:
return None
x, = xs
aval = core.get_aval(x)
if (type(aval) is ConcreteArray) and aval.shape == ():
if monoid_op is lax.add:
@@ -115,12 +140,13 @@ def _reduce_window_prod(operand: Array, window_dimensions: core.Shape,
base_dilation = (1,) * len(window_dimensions)
if window_dilation is None:
window_dilation = (1,) * len(window_dimensions)
return reduce_window_p.bind(
out, = reduce_window_p.bind(
operand, init_value, jaxpr=jaxpr, consts=consts,
window_dimensions=tuple(window_dimensions),
window_strides=tuple(window_strides), padding=tuple(padding),
base_dilation=tuple(base_dilation),
window_dilation=tuple(window_dilation))
return out
def _reduce_window_max(operand: Array, window_dimensions: core.Shape,
window_strides: Sequence[int],
@@ -189,10 +215,10 @@ def _select_and_gather_add(tangents: Array, operand: Array,
Wraps XLA's `ReduceWindow
<https://www.tensorflow.org/xla/operation_semantics#reducewindow>`_
operator, which applies a reduction function to all elements in each window of the
input multi-dimensional array. In this case, the input multi-dimensional array is
built by packing each element in the `operand` array with its corresponding
element in the `tangents` array.
operator, which applies a reduction function to all elements in each window of
the input multi-dimensional array. In this case, the input multi-dimensional
array is built by packing each element in the `operand` array with its
corresponding element in the `tangents` array.
Args:
tangents: an array
@@ -204,8 +230,8 @@ def _select_and_gather_add(tangents: Array, operand: Array,
window_dilation: an array of integers for window dilation values
Returns:
An array containing the elements in `tangents` corresponding to the output of the
reduction of `operand` fin each window.
An array containing the elements in `tangents` corresponding to the output
of the reduction of `operand` fin each window.
"""
return select_and_gather_add_p.bind(
tangents, operand, select_prim=select_prim,
@@ -215,56 +241,70 @@ def _select_and_gather_add(tangents: Array, operand: Array,
window_dilation=tuple(window_dilation))
def _reduce_window_shape_rule(operand, init_value, *, jaxpr, consts,
window_dimensions, window_strides, padding,
base_dilation, window_dilation):
if operand.dtype != init_value.dtype:
msg = ("reduce_window got inconsistent dtypes for operand and init_value: "
" got operand dtype {} and init_value dtype {}.")
raise TypeError(msg.format(operand.dtype, init_value.dtype))
if init_value.shape != ():
msg = ("reduce_window expected init_value to be a scalar but init_value "
"has shape {}.")
raise TypeError(msg.format(init_value.shape))
return _common_reduce_window_shape_rule(
operand, window_dimensions, window_strides, padding, base_dilation,
window_dilation)
def _reduce_window_abstract_eval_rule(
*avals, jaxpr, consts, window_dimensions, window_strides, padding,
base_dilation, window_dilation):
operand_avals, init_val_avals = util.split_list(avals, [len(avals) // 2])
if any(o.dtype != iv.dtype for o, iv in zip(operand_avals, init_val_avals)):
msg = ("reduce_window got inconsistent dtypes for operands and init_values:"
" got operand dtypes {} and init_value dtypes {}.")
raise TypeError(msg.format([o.dtype for o in operand_avals],
[iv.dtype for iv in init_val_avals]))
if any(len(v.shape) != 0 for v in init_val_avals):
msg = ("reduce_window expected init_values to be scalars but init_values "
"have shapes {}.")
raise TypeError(msg.format([v.shape for v in init_val_avals]))
out_shape = _common_reduce_window_shape_rule(
operand_avals[0], window_dimensions, window_strides, padding,
base_dilation, window_dilation)
return tuple(ShapedArray(out_shape, op.dtype) for op in operand_avals)
def _reduce_window_translation_rule(ctx, avals_in, avals_out, operand,
init_value, *, jaxpr, consts,
window_dimensions, window_strides, padding,
base_dilation, window_dilation):
xla_computation = lax._reduction_computation(ctx, jaxpr, consts, init_value)
return [xops.ReduceWindowWithGeneralPadding(
operand, init_value, xla_computation, window_dimensions,
window_strides, base_dilation, window_dilation, padding)]
def _reduce_window_translation_rule(ctx, avals_in, avals_out, *args, jaxpr,
consts, window_dimensions, window_strides,
padding, base_dilation, window_dilation):
operands, init_values = util.split_list(args, [len(args) // 2])
xla_computation = lax._reduction_computation(ctx, jaxpr, consts, init_values,
singleton=False)
return xla.xla_destructure(ctx.builder, xops.ReduceWindowWithGeneralPadding(
operands, init_values, xla_computation, window_dimensions,
window_strides, base_dilation, window_dilation, padding))
def _generic_reduce_window_batch_rule(
batched_args, batch_dims, *, jaxpr, consts, window_dimensions,
window_strides, padding, base_dilation, window_dilation):
num_operands = len(batched_args) // 2
operands, init_values = util.split_list(batched_args, [num_operands])
operand_bdims, init_value_bdims = util.split_list(batch_dims, [num_operands])
operand, init = batched_args
bdim, init_bdim = batch_dims
if init_bdim is not None:
if any(init_bdim is not None for init_bdim in init_value_bdims):
raise NotImplementedError("reduce_window batching is not implemented for "
"initial values")
def reduce_window(x, window_dimensions, window_strides, padding, base_dilation,
window_dilation):
return reduce_window_p.bind(
x, init, jaxpr=jaxpr, consts=consts, window_dimensions=window_dimensions,
window_strides=window_strides, padding=padding, base_dilation=base_dilation,
size = next(x.shape[ax] for x, ax in zip(operands, operand_bdims)
if ax is not None)
operands = [batching.bdim_at_front(arg, bdim, size)
for arg, bdim in zip(operands, operand_bdims)]
window_dimensions = (1,) + window_dimensions
window_strides = (1,) + window_strides
padding = ((0, 0),) + padding
base_dilation = (1,) + base_dilation
window_dilation = (1,) + window_dilation
outs = reduce_window_p.bind(
*(operands + init_values), jaxpr=jaxpr, consts=consts,
window_dimensions=window_dimensions, window_strides=window_strides,
padding=padding, base_dilation=base_dilation,
window_dilation=window_dilation)
return _reduce_window_batch_rule(
reduce_window, (operand,), (bdim,), window_dimensions=window_dimensions,
window_strides=window_strides, padding=padding, base_dilation=base_dilation,
window_dilation=window_dilation)
return outs, (0,) * num_operands
reduce_window_p = lax.standard_primitive(
_reduce_window_shape_rule, _input_dtype, 'reduce_window',
_reduce_window_translation_rule)
reduce_window_p = core.Primitive('reduce_window')
reduce_window_p.multiple_results = True
reduce_window_p.def_impl(partial(xla.apply_primitive, reduce_window_p))
reduce_window_p.def_abstract_eval(_reduce_window_abstract_eval_rule)
batching.primitive_batchers[reduce_window_p] = _generic_reduce_window_batch_rule
xla.register_translation(reduce_window_p, _reduce_window_translation_rule)
def _reduce_window_sum_shape_rule(operand, *, window_dimensions, window_strides,
padding, base_dilation, window_dilation):
@@ -272,8 +312,8 @@ def _reduce_window_sum_shape_rule(operand, *, window_dimensions, window_strides,
msg = "operand to reduce_window_sum must have a number dtype, got {}"
raise TypeError(msg.format(np.dtype(operand.dtype).name))
return _common_reduce_window_shape_rule(operand, window_dimensions,
window_strides, padding, base_dilation,
window_dilation)
window_strides, padding,
base_dilation, window_dilation)
def _reduce_window_sum_translation_rule(ctx, avals_in, avals_out, operand, *,
window_dimensions, window_strides,
@@ -436,8 +476,10 @@ def _select_and_scatter_translation(
ctx, avals_in, avals_out, operand, source, init_value, *, select_jaxpr,
select_consts, scatter_jaxpr, scatter_consts, window_dimensions,
window_strides, padding):
select = lax._reduction_computation(ctx, select_jaxpr, select_consts, init_value)
scatter = lax._reduction_computation(ctx, scatter_jaxpr, scatter_consts, init_value)
select = lax._reduction_computation(ctx, select_jaxpr, select_consts,
init_value)
scatter = lax._reduction_computation(ctx, scatter_jaxpr, scatter_consts,
init_value)
return [xops.SelectAndScatterWithGeneralPadding(
operand, select, window_dimensions, window_strides, padding, source,
init_value, scatter)]
@@ -462,7 +504,8 @@ def _select_and_scatter_add_translation(
select = xla.primitive_subcomputation(
ctx.platform, select_prim, scalar, scalar)
scatter = xla.primitive_subcomputation(
ctx.platform, lax.or_p if dtype == np.bool_ else lax.add_p, scalar, scalar)
ctx.platform, lax.or_p if dtype == np.bool_ else lax.add_p, scalar,
scalar)
zero = xla.pyval_to_ir_constant(c, np.array(0, dtype))
# TODO(b/161704903): remove this workaround when XLA:CPU bug is fixed.
expand_padding = (expand_padding and
@@ -599,11 +642,13 @@ def _select_and_gather_add_translation(
# Unpacks the first element of a tuple.
def fst(c, t):
st = xops.ShiftRightLogical(t, const(c, double_word_dtype, nbits))
return xops.BitcastConvertType(xops.ConvertElementType(st, word_type), etype)
return xops.BitcastConvertType(xops.ConvertElementType(st, word_type),
etype)
# Unpacks the second element of a tuple.
def snd(t):
return xops.BitcastConvertType(xops.ConvertElementType(t, word_type), etype)
return xops.BitcastConvertType(xops.ConvertElementType(t, word_type),
etype)
else:
# The double-word trick above only works if we have a sufficiently large
@@ -638,8 +683,8 @@ def _select_and_gather_add_translation(
# Unpacks the second element of a tuple.
def snd(t):
return xops.BitcastConvertType(xops.ShiftLeft(t, const(c, word_dtype, r_nbits)),
etype)
return xops.BitcastConvertType(
xops.ShiftLeft(t, const(c, word_dtype, r_nbits)), etype)
def reducer():
c = xc.XlaBuilder("select_and_gather_pair_reducer")


@@ -1741,14 +1741,14 @@ def _common_reduce_window(operand, init_val, reducer, window_dimensions,
return out
def _reduce_window(operand, init_value, *, jaxpr, consts, window_dimensions,
def _reduce_window(*args, jaxpr, consts, window_dimensions,
window_strides, padding, base_dilation, window_dilation,
_in_avals, _out_aval):
"""TensorFlow implementation of reduce_window.
Args:
operand: N dimensional array containing elements of type T
init_value: starting value of the reduction
operands: N dimensional arrays containing elements of type T
init_values: starting values of the reduction
jaxpr: the jaxpr corresponding to the reduction function
consts: the constants associated with jaxpr.
window_dimensions: array of integers for window dimension values
@@ -1761,15 +1761,20 @@ def _reduce_window(operand, init_value, *, jaxpr, consts, window_dimensions,
The reduced operand.
"""
assert len(consts) == 0, "Reduction computation cannot have constants"
operands, init_values = util.split_list(args, [len(args) // 2])
if len(operands) != 1:
raise NotImplementedError("jax2tf does not support variadic reduce_window")
def reducer(arg1: TfVal, arg2: TfVal) -> TfVal:
closed_jaxpr = core.ClosedJaxpr(jaxpr, consts)
res, = _interpret_jaxpr(closed_jaxpr, arg1, arg2, extra_name_stack=None)
return res
return _common_reduce_window(operand, init_value, reducer, window_dimensions,
window_strides, padding, base_dilation,
window_dilation, _in_avals, _out_aval)
return (_common_reduce_window(operands[0], init_values[0], reducer,
window_dimensions, window_strides, padding,
base_dilation, window_dilation, _in_avals,
_out_aval[0]),)


@@ -36,6 +36,7 @@ from jax._src import dispatch
from jax._src import dtypes
from jax._src.lax import lax
from jax._src.lax import control_flow
from jax._src.lax import windowed_reductions as lax_windowed_reductions
from jax._src.lib.mlir import ir
from jax._src.lib.mlir.dialects import builtin
from jax._src.lib.mlir.dialects import chlo
@@ -1450,27 +1451,26 @@ translations[lax.argmax_p] = lower_fun(
multiple_results=False)
def _generic_reduce_window_lower(ctx, avals_in, avals_out, operand, init_value,
*, jaxpr, consts, window_dimensions,
window_strides, padding, base_dilation,
window_dilation):
aval_out, = avals_out
operand_aval, scalar_aval = avals_in
scalar_type = aval_to_ir_type(scalar_aval)
def _generic_reduce_window_lower(ctx, avals_in, avals_out, *args, jaxpr, consts,
window_dimensions, window_strides, padding,
base_dilation, window_dilation):
operands, init_values = util.split_list(args, [len(args) // 2])
_, init_value_avals = util.split_list(avals_in, [len(operands)])
scalar_types = [aval_to_ir_type(aval) for aval in init_value_avals]
rw = mhlo.ReduceWindowOp(
aval_to_ir_types(aval_out), [operand], [init_value],
map(aval_to_ir_type, avals_out), operands, init_values,
_dense_int_elements(window_dimensions),
_dense_int_elements(window_strides), _dense_int_elements(base_dilation),
_dense_int_elements(window_dilation),
ir.DenseIntElementsAttr.get(np.asarray(padding, np.int64)))
reducer = rw.regions[0].blocks.append(scalar_type, scalar_type)
reducer = rw.regions[0].blocks.append(*(scalar_types + scalar_types))
with ir.InsertionPoint(reducer):
out_nodes = jaxpr_subcomp(ctx, jaxpr, consts,
*([a] for a in reducer.arguments))
mhlo.ReturnOp(util.flatten(out_nodes))
return rw.results
translations[lax.reduce_window_p] = _generic_reduce_window_lower
translations[lax_windowed_reductions.reduce_window_p] = _generic_reduce_window_lower
def _reduce_window_lower(
@@ -1492,11 +1492,11 @@ def _reduce_window_lower(
mhlo.ReturnOp(reduce_op(*reducer.arguments))
return rw.results
translations[lax.reduce_window_sum_p] = partial(
translations[lax_windowed_reductions.reduce_window_sum_p] = partial(
_reduce_window_lower, mhlo.AddOp, lambda _: 0)
translations[lax.reduce_window_min_p] = partial(
translations[lax_windowed_reductions.reduce_window_min_p] = partial(
_reduce_window_lower, mhlo.MinOp, lax._get_min_identity)
translations[lax.reduce_window_max_p] = partial(
translations[lax_windowed_reductions.reduce_window_max_p] = partial(
_reduce_window_lower, mhlo.MaxOp, lax._get_max_identity)
@@ -1525,7 +1525,7 @@ def _select_and_scatter_lower(
mhlo.ReturnOp(util.flatten(out_nodes))
return op.results
translations[lax.select_and_scatter_p] = _select_and_scatter_lower
translations[lax_windowed_reductions.select_and_scatter_p] = _select_and_scatter_lower
def _select_and_scatter_add(source, operand, *, select_prim, window_dimensions,
@@ -1551,13 +1551,13 @@ def _select_and_scatter_add(source, operand, *, select_prim, window_dimensions,
out = lax.slice(out, start_indices, stop_indices)
return out
translations[lax.select_and_scatter_add_p] = lower_fun(
translations[lax_windowed_reductions.select_and_scatter_add_p] = lower_fun(
partial(_select_and_scatter_add, expand_padding=False),
multiple_results=False)
platform_specific_translations['cpu'][lax.select_and_scatter_add_p] = lower_fun(
platform_specific_translations['cpu'][lax_windowed_reductions.select_and_scatter_add_p] = lower_fun(
partial(_select_and_scatter_add, expand_padding=True),
multiple_results=False)
platform_specific_translations['gpu'][lax.select_and_scatter_add_p] = lower_fun(
platform_specific_translations['gpu'][lax_windowed_reductions.select_and_scatter_add_p] = lower_fun(
partial(_select_and_scatter_add, expand_padding=True),
multiple_results=False)
@@ -1826,10 +1826,14 @@ def _add_cumulative_reduce(prim, reducer, tpu_reduce_window_fn):
partial(control_flow._cumred_tpu_translation_rule, tpu_reduce_window_fn),
multiple_results=False)
_add_cumulative_reduce(control_flow.cumsum_p, lax.add, lax._reduce_window_sum)
_add_cumulative_reduce(control_flow.cumprod_p, lax.mul, lax._reduce_window_prod)
_add_cumulative_reduce(control_flow.cummin_p, lax.min, lax._reduce_window_min)
_add_cumulative_reduce(control_flow.cummax_p, lax.max, lax._reduce_window_max)
_add_cumulative_reduce(control_flow.cumsum_p, lax.add,
lax_windowed_reductions._reduce_window_sum)
_add_cumulative_reduce(control_flow.cumprod_p, lax.mul,
lax_windowed_reductions._reduce_window_prod)
_add_cumulative_reduce(control_flow.cummin_p, lax.min,
lax_windowed_reductions._reduce_window_min)
_add_cumulative_reduce(control_flow.cummax_p, lax.max,
lax_windowed_reductions._reduce_window_max)
translations[custom_derivatives.custom_jvp_call_jaxpr_p] = lower_fun(
custom_derivatives._custom_jvp_call_jaxpr_impl, multiple_results=True)
@@ -1986,6 +1990,6 @@ map(add_fallback_lowering, [
lax.top_k_p,
# TODO(phawkins): implement these lax ops:
lax.select_and_gather_add_p,
lax_windowed_reductions.select_and_gather_add_p,
lax.rng_bit_generator_p,
])


@@ -1696,6 +1696,62 @@ class LaxTest(jtu.JaxTestCase):
args_maker = lambda: [rng(shape, dtype)]
self._CompileAndCheck(fun, args_maker)
@parameterized.named_parameters(jtu.cases_from_list(
{"testcase_name": ("_shape={}_dims={}_strides={}_padding={}"
"_basedilation={}_windowdilation={}")
.format(jtu.format_shape_dtype_string(shape, dtype),
dims, strides, padding, base_dilation, window_dilation),
"dtype": dtype, "shape": shape,
"dims": dims, "strides": strides, "padding": padding,
"base_dilation": base_dilation, "window_dilation": window_dilation}
for dtype in [np.float32]
for shape, dims, strides, padding, base_dilation, window_dilation in (
itertools.chain(
itertools.product(
[(4, 6)],
[(2, 1), (1, 2)],
[(1, 1), (2, 1), (1, 2)],
["VALID", "SAME", [(0, 3), (1, 2)]],
[(1, 1), (2, 3)],
[(1, 1), (1, 2)]),
itertools.product(
[(3, 2, 4, 6)], [(1, 1, 2, 1), (2, 1, 2, 1)],
[(1, 2, 2, 1), (1, 1, 1, 1)],
["VALID", "SAME", [(0, 1), (1, 0), (2, 3), (0, 2)]],
[(1, 1, 1, 1), (2, 1, 3, 2)],
[(1, 1, 1, 1), (1, 2, 2, 1)])))))
# TODO(b/183233858): variadic reduce-window is not implemented on XLA:GPU
@jtu.skip_on_devices("gpu")
def testReduceWindowVariadic(self, dtype, shape, dims, strides, padding,
base_dilation, window_dilation):
if (jtu.device_under_test() == "tpu" and
any(d != 1 for d in window_dilation)):
raise SkipTest("TPU support missing for arbitrary window dilation.")
rng = jtu.rand_small(self.rng())
init_values = (np.asarray(0, dtype=dtype), np.array(-np.inf, dtype=dtype))
def reducer(xs, ys):
x1, x2 = xs
y1, y2 = ys
return (x1 + y1, lax.max(x2, y2))
def fun(*operands):
return lax.reduce_window(operands, init_values, reducer, dims, strides,
padding, base_dilation, window_dilation)
def reference_fun(*operands):
return [
lax_reference.reduce_window(operand, init_val, op, dims, strides,
padding, base_dilation)
for operand, init_val, op in zip(operands, init_values,
[np.add, np.maximum])]
args_maker = lambda: [rng(shape, dtype), rng(shape, dtype)]
self._CompileAndCheck(fun, args_maker)
if all(d == 1 for d in window_dilation):
self._CheckAgainstNumpy(reference_fun, fun, args_maker)
def testReduceWindowFailures(self):
def empty_window_test():
return lax.reduce_window(np.ones((1,)), 0., lax.add, padding='VALID',
@@ -1711,9 +1767,8 @@ class LaxTest(jtu.JaxTestCase):
with self.assertRaisesRegex(
ValueError,
"Invalid return type from reduction function: <class 'list'>\n"
"Reduction functions should only return an array.\n"
"Full return value: .*"):
"reduce_window output must have the same tree structure as the "
"operands.*"):
return lax.reduce_window(
np.ones((1,)), 0., lambda x, y: [x + y],
padding='VALID', window_dimensions=(1,), window_strides=(1,))
@@ -2363,8 +2418,8 @@ class LaxTest(jtu.JaxTestCase):
, "window_dilation": (1, 1)
}
msg = (r"reduce_window expected init_value to be a scalar but init_value "
r"has shape \(1,\).")
msg = (r"reduce_window expected init_values to be scalars but init_values "
r"have shapes \[\(1,\)\].")
with self.assertRaisesRegex(TypeError, msg):
lax.reduce_window(**args)