Make lax.reduce_window variadic.
This is similar to the support in lax.reduce(), where the operands and init_values become pytrees. This is a strict superset of the current API, so users should not need updates.

Variadic lax.reduce_window() is only supported on CPU and TPU at the moment, not GPU.

PiperOrigin-RevId: 411632993
commit 5415306257
parent ad6ce74d67
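As an illustration of the variadic API described above, here is a minimal usage sketch modeled on the test added in this change (the arrays, shapes, and variable names below are illustrative only, not part of the commit): operands and init_values are passed as matching pytrees, the reducer consumes and returns tuples of the same structure, and the result has the operands' tree structure.

import numpy as np
from jax import lax

x = np.arange(12, dtype=np.float32).reshape(3, 4)
y = np.ones((3, 4), np.float32)

def reducer(xs, ys):
  # Element-wise reducer over pairs: a running sum for the first operand,
  # a running maximum for the second.
  x1, x2 = xs
  y1, y2 = ys
  return (x1 + y1, lax.max(x2, y2))

# Scalar init values, one per operand, with the same tree structure as the operands.
init_values = (np.float32(0), np.float32(-np.inf))

# The output is a tuple mirroring the operand tree. Per the note above, the
# variadic form currently runs on CPU and TPU only.
sums, maxes = lax.reduce_window(
    (x, y), init_values, reducer,
    window_dimensions=(2, 2), window_strides=(1, 1), padding='VALID')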
@@ -1424,8 +1424,8 @@ def argmax(operand: Array, axis: int,
return argmax_p.bind(operand, axes=(axis,),
index_dtype=dtypes.canonicalize_dtype(index_dtype))

def reduce(operands: Array, init_values: Array, computation: Callable,
dimensions: Sequence[int]) -> Array:
def reduce(operands, init_values, computation: Callable,
dimensions: Sequence[int]):
"""Wraps XLA's `Reduce
<https://www.tensorflow.org/xla/operation_semantics#reduce>`_
operator.
@@ -1483,7 +1483,8 @@ def _variadic_reduction_jaxpr(computation, flat_avals, aval_tree):
jaxpr, _, consts = pe.trace_to_jaxpr_dynamic(flat_comp, tuple(flat_in_avals))
return jaxpr, tuple(consts), out_tree()

def _get_monoid_reducer(monoid_op: Callable, xs: Array) -> Optional[Callable]:
def _get_monoid_reducer(monoid_op: Callable,
xs: Sequence[Array]) -> Optional[Callable]:
if len(xs) != 1:
return None
x, = xs

@@ -24,6 +24,7 @@ from jax.interpreters import xla

from jax import core
from jax.core import (ShapedArray, ConcreteArray)
from jax import tree_util

from jax._src import ad_util
from jax._src import dtypes
@@ -36,7 +37,10 @@ from jax._src.lax.lax import (
)
from jax._src.lib import xla_bridge
from jax._src.lib import xla_client
import jax._src.util as util

map = util.safe_map
zip = util.safe_zip

xb = xla_bridge
xc = xla_client
@@ -45,7 +49,7 @@ xops = xla_client.ops
Array = Any


def reduce_window(operand: Array, init_value: Array, computation: Callable,
def reduce_window(operand, init_value, computation: Callable,
window_dimensions: core.Shape, window_strides: Sequence[int],
padding: Union[str, Sequence[Tuple[int, int]]],
base_dilation: Optional[Sequence[int]] = None,
@@ -54,31 +58,52 @@ def reduce_window(operand: Array, init_value: Array, computation: Callable,
<https://www.tensorflow.org/xla/operation_semantics#reducewindow>`_
operator.
"""
flat_operands, operand_tree = tree_util.tree_flatten(operand)
flat_init_values, init_value_tree = tree_util.tree_flatten(init_value)
if operand_tree != init_value_tree:
raise ValueError('Operands must have the same tree structure as '
f'init_values: {operand_tree} vs. {init_value_tree}')
if len(flat_operands) == 0:
raise ValueError('reduce_window must have at least one operand.')
if len(flat_operands) != len(flat_init_values):
raise ValueError('Must have same total number of operands as init_values: '
f' {len(flat_operands)} vs. {len(flat_init_values)}')
if isinstance(padding, str):
dilated_window_dims = (window_dimensions if window_dilation is None else
_dilate_shape(window_dimensions, window_dilation))
padding = tuple(lax.padtype_to_pads(operand.shape, dilated_window_dims,
window_strides, padding))
padding = tuple(lax.padtype_to_pads(
flat_operands[0].shape, dilated_window_dims, window_strides, padding))
else:
padding = tuple(padding)
if base_dilation is None:
base_dilation = (1,) * len(window_dimensions)
if window_dilation is None:
window_dilation = (1,) * len(window_dimensions)
monoid_reducer = _get_monoid_window_reducer(computation, init_value)
monoid_reducer = _get_monoid_window_reducer(computation, flat_init_values)
if monoid_reducer:
return monoid_reducer(operand, window_dimensions, window_strides, padding,
base_dilation, window_dilation)
else:
jaxpr, consts = lax._reduction_jaxpr(computation, lax._abstractify(init_value))
return reduce_window_p.bind(
operand, init_value, jaxpr=jaxpr, consts=consts,
flat_init_avals = map(lax._abstractify, flat_init_values)
jaxpr, consts, out_tree = lax._variadic_reduction_jaxpr(
computation, tuple(flat_init_avals), init_value_tree)
if operand_tree != out_tree:
raise ValueError(
'reduce_window output must have the same tree structure as the operands'
f' {operand_tree} vs. {out_tree}')
out_flat = reduce_window_p.bind(
*(flat_operands + flat_init_values), jaxpr=jaxpr, consts=consts,
window_dimensions=tuple(window_dimensions),
window_strides=tuple(window_strides), padding=padding,
base_dilation=tuple(base_dilation),
window_dilation=tuple(window_dilation))
return tree_util.tree_unflatten(out_tree, out_flat)

def _get_monoid_window_reducer(monoid_op: Callable, x: Array) -> Optional[Callable]:
def _get_monoid_window_reducer(monoid_op: Callable,
xs: Sequence[Array]) -> Optional[Callable]:
if len(xs) != 1:
return None
x, = xs
aval = core.get_aval(x)
if (type(aval) is ConcreteArray) and aval.shape == ():
if monoid_op is lax.add:
@@ -115,12 +140,13 @@ def _reduce_window_prod(operand: Array, window_dimensions: core.Shape,
base_dilation = (1,) * len(window_dimensions)
if window_dilation is None:
window_dilation = (1,) * len(window_dimensions)
return reduce_window_p.bind(
out, = reduce_window_p.bind(
operand, init_value, jaxpr=jaxpr, consts=consts,
window_dimensions=tuple(window_dimensions),
window_strides=tuple(window_strides), padding=tuple(padding),
base_dilation=tuple(base_dilation),
window_dilation=tuple(window_dilation))
return out

def _reduce_window_max(operand: Array, window_dimensions: core.Shape,
window_strides: Sequence[int],
@@ -189,10 +215,10 @@ def _select_and_gather_add(tangents: Array, operand: Array,

Wraps XLA's `ReduceWindow
<https://www.tensorflow.org/xla/operation_semantics#reducewindow>`_
operator, which applies a reduction function to all elements in each window of the
input multi-dimensional array. In this case, the input multi-dimensional array is
built by packing each element in the `operand` array with its corresponding
element in the `tangents` array.
operator, which applies a reduction function to all elements in each window of
the input multi-dimensional array. In this case, the input multi-dimensional
array is built by packing each element in the `operand` array with its
corresponding element in the `tangents` array.

Args:
tangents: an array
@@ -204,8 +230,8 @@ def _select_and_gather_add(tangents: Array, operand: Array,
window_dilation: an array of integers for window dilation values

Returns:
An array containing the elements in `tangents` corresponding to the output of the
reduction of `operand` fin each window.
An array containing the elements in `tangents` corresponding to the output
of the reduction of `operand` fin each window.
"""
return select_and_gather_add_p.bind(
tangents, operand, select_prim=select_prim,
@@ -215,56 +241,70 @@ def _select_and_gather_add(tangents: Array, operand: Array,
window_dilation=tuple(window_dilation))


def _reduce_window_shape_rule(operand, init_value, *, jaxpr, consts,
window_dimensions, window_strides, padding,
base_dilation, window_dilation):
if operand.dtype != init_value.dtype:
msg = ("reduce_window got inconsistent dtypes for operand and init_value: "
" got operand dtype {} and init_value dtype {}.")
raise TypeError(msg.format(operand.dtype, init_value.dtype))
if init_value.shape != ():
msg = ("reduce_window expected init_value to be a scalar but init_value "
"has shape {}.")
raise TypeError(msg.format(init_value.shape))
return _common_reduce_window_shape_rule(
operand, window_dimensions, window_strides, padding, base_dilation,
window_dilation)
def _reduce_window_abstract_eval_rule(
*avals, jaxpr, consts, window_dimensions, window_strides, padding,
base_dilation, window_dilation):
operand_avals, init_val_avals = util.split_list(avals, [len(avals) // 2])
if any(o.dtype != iv.dtype for o, iv in zip(operand_avals, init_val_avals)):
msg = ("reduce_window got inconsistent dtypes for operands and init_values:"
" got operand dtypes {} and init_value dtypes {}.")
raise TypeError(msg.format([o.dtype for o in operand_avals],
[iv.dtype for iv in init_val_avals]))
if any(len(v.shape) != 0 for v in init_val_avals):
msg = ("reduce_window expected init_values to be scalars but init_values "
"have shapes {}.")
raise TypeError(msg.format([v.shape for v in init_val_avals]))
out_shape = _common_reduce_window_shape_rule(
operand_avals[0], window_dimensions, window_strides, padding,
base_dilation, window_dilation)
return tuple(ShapedArray(out_shape, op.dtype) for op in operand_avals)

def _reduce_window_translation_rule(ctx, avals_in, avals_out, operand,
init_value, *, jaxpr, consts,
window_dimensions, window_strides, padding,
base_dilation, window_dilation):
xla_computation = lax._reduction_computation(ctx, jaxpr, consts, init_value)
return [xops.ReduceWindowWithGeneralPadding(
operand, init_value, xla_computation, window_dimensions,
window_strides, base_dilation, window_dilation, padding)]
def _reduce_window_translation_rule(ctx, avals_in, avals_out, *args, jaxpr,
consts, window_dimensions, window_strides,
padding, base_dilation, window_dilation):
operands, init_values = util.split_list(args, [len(args) // 2])
xla_computation = lax._reduction_computation(ctx, jaxpr, consts, init_values,
singleton=False)
return xla.xla_destructure(ctx.builder, xops.ReduceWindowWithGeneralPadding(
operands, init_values, xla_computation, window_dimensions,
window_strides, base_dilation, window_dilation, padding))

def _generic_reduce_window_batch_rule(
batched_args, batch_dims, *, jaxpr, consts, window_dimensions,
window_strides, padding, base_dilation, window_dilation):
num_operands = len(batched_args) // 2
operands, init_values = util.split_list(batched_args, [num_operands])
operand_bdims, init_value_bdims = util.split_list(batch_dims, [num_operands])

operand, init = batched_args
bdim, init_bdim = batch_dims
if init_bdim is not None:
if any(init_bdim is not None for init_bdim in init_value_bdims):
raise NotImplementedError("reduce_window batching is not implemented for "
"initial values")

def reduce_window(x, window_dimensions, window_strides, padding, base_dilation,
window_dilation):
return reduce_window_p.bind(
x, init, jaxpr=jaxpr, consts=consts, window_dimensions=window_dimensions,
window_strides=window_strides, padding=padding, base_dilation=base_dilation,
size = next(x.shape[ax] for x, ax in zip(operands, operand_bdims)
if ax is not None)
operands = [batching.bdim_at_front(arg, bdim, size)
for arg, bdim in zip(operands, operand_bdims)]
window_dimensions = (1,) + window_dimensions
window_strides = (1,) + window_strides
padding = ((0, 0),) + padding
base_dilation = (1,) + base_dilation
window_dilation = (1,) + window_dilation
outs = reduce_window_p.bind(
*(operands + init_values), jaxpr=jaxpr, consts=consts,
window_dimensions=window_dimensions, window_strides=window_strides,
padding=padding, base_dilation=base_dilation,
window_dilation=window_dilation)
return _reduce_window_batch_rule(
reduce_window, (operand,), (bdim,), window_dimensions=window_dimensions,
window_strides=window_strides, padding=padding, base_dilation=base_dilation,
window_dilation=window_dilation)
return outs, (0,) * num_operands


reduce_window_p = lax.standard_primitive(
_reduce_window_shape_rule, _input_dtype, 'reduce_window',
_reduce_window_translation_rule)
reduce_window_p = core.Primitive('reduce_window')
reduce_window_p.multiple_results = True
reduce_window_p.def_impl(partial(xla.apply_primitive, reduce_window_p))
reduce_window_p.def_abstract_eval(_reduce_window_abstract_eval_rule)
batching.primitive_batchers[reduce_window_p] = _generic_reduce_window_batch_rule

xla.register_translation(reduce_window_p, _reduce_window_translation_rule)

def _reduce_window_sum_shape_rule(operand, *, window_dimensions, window_strides,
padding, base_dilation, window_dilation):
@@ -272,8 +312,8 @@ def _reduce_window_sum_shape_rule(operand, *, window_dimensions, window_strides,
msg = "operand to reduce_window_sum must have a number dtype, got {}"
raise TypeError(msg.format(np.dtype(operand.dtype).name))
return _common_reduce_window_shape_rule(operand, window_dimensions,
window_strides, padding, base_dilation,
window_dilation)
window_strides, padding,
base_dilation, window_dilation)

def _reduce_window_sum_translation_rule(ctx, avals_in, avals_out, operand, *,
window_dimensions, window_strides,
@@ -436,8 +476,10 @@ def _select_and_scatter_translation(
ctx, avals_in, avals_out, operand, source, init_value, *, select_jaxpr,
select_consts, scatter_jaxpr, scatter_consts, window_dimensions,
window_strides, padding):
select = lax._reduction_computation(ctx, select_jaxpr, select_consts, init_value)
scatter = lax._reduction_computation(ctx, scatter_jaxpr, scatter_consts, init_value)
select = lax._reduction_computation(ctx, select_jaxpr, select_consts,
init_value)
scatter = lax._reduction_computation(ctx, scatter_jaxpr, scatter_consts,
init_value)
return [xops.SelectAndScatterWithGeneralPadding(
operand, select, window_dimensions, window_strides, padding, source,
init_value, scatter)]
@@ -462,7 +504,8 @@ def _select_and_scatter_add_translation(
select = xla.primitive_subcomputation(
ctx.platform, select_prim, scalar, scalar)
scatter = xla.primitive_subcomputation(
ctx.platform, lax.or_p if dtype == np.bool_ else lax.add_p, scalar, scalar)
ctx.platform, lax.or_p if dtype == np.bool_ else lax.add_p, scalar,
scalar)
zero = xla.pyval_to_ir_constant(c, np.array(0, dtype))
# TODO(b/161704903): remove this workaround when XLA:CPU bug is fixed.
expand_padding = (expand_padding and
@@ -599,11 +642,13 @@ def _select_and_gather_add_translation(
# Unpacks the first element of a tuple.
def fst(c, t):
st = xops.ShiftRightLogical(t, const(c, double_word_dtype, nbits))
return xops.BitcastConvertType(xops.ConvertElementType(st, word_type), etype)
return xops.BitcastConvertType(xops.ConvertElementType(st, word_type),
etype)

# Unpacks the second element of a tuple.
def snd(t):
return xops.BitcastConvertType(xops.ConvertElementType(t, word_type), etype)
return xops.BitcastConvertType(xops.ConvertElementType(t, word_type),
etype)

else:
# The double-word trick above only works if we have a sufficiently large
@@ -638,8 +683,8 @@ def _select_and_gather_add_translation(

# Unpacks the second element of a tuple.
def snd(t):
return xops.BitcastConvertType(xops.ShiftLeft(t, const(c, word_dtype, r_nbits)),
etype)
return xops.BitcastConvertType(
xops.ShiftLeft(t, const(c, word_dtype, r_nbits)), etype)

def reducer():
c = xc.XlaBuilder("select_and_gather_pair_reducer")

@@ -1741,14 +1741,14 @@ def _common_reduce_window(operand, init_val, reducer, window_dimensions,
return out


def _reduce_window(operand, init_value, *, jaxpr, consts, window_dimensions,
def _reduce_window(*args, jaxpr, consts, window_dimensions,
window_strides, padding, base_dilation, window_dilation,
_in_avals, _out_aval):
"""TensorFlow implementation of reduce_window.

Args:
operand: N dimensional array containing elements of type T
init_value: starting value of the reduction
operands: N dimensional arrays containing elements of type T
init_values: starting values of the reduction
jaxpr: the jaxpr corresponding to the reduction function
consts: the constants associated with jaxpr.
window_dimensions: array of integers for window dimension values
@@ -1761,15 +1761,20 @@ def _reduce_window(operand, init_value, *, jaxpr, consts, window_dimensions,
The reduced operand.
"""
assert len(consts) == 0, "Reduction computation cannot have constants"
operands, init_values = util.split_list(args, [len(args) // 2])

if len(operands) != 1:
raise NotImplementedError("jax2tf does not support variadic reduce_window")

def reducer(arg1: TfVal, arg2: TfVal) -> TfVal:
closed_jaxpr = core.ClosedJaxpr(jaxpr, consts)
res, = _interpret_jaxpr(closed_jaxpr, arg1, arg2, extra_name_stack=None)
return res

return _common_reduce_window(operand, init_value, reducer, window_dimensions,
window_strides, padding, base_dilation,
window_dilation, _in_avals, _out_aval)
return (_common_reduce_window(operands[0], init_values[0], reducer,
window_dimensions, window_strides, padding,
base_dilation, window_dilation, _in_avals,
_out_aval[0]),)


@@ -36,6 +36,7 @@ from jax._src import dispatch
from jax._src import dtypes
from jax._src.lax import lax
from jax._src.lax import control_flow
from jax._src.lax import windowed_reductions as lax_windowed_reductions
from jax._src.lib.mlir import ir
from jax._src.lib.mlir.dialects import builtin
from jax._src.lib.mlir.dialects import chlo
@@ -1450,27 +1451,26 @@ translations[lax.argmax_p] = lower_fun(
multiple_results=False)


def _generic_reduce_window_lower(ctx, avals_in, avals_out, operand, init_value,
*, jaxpr, consts, window_dimensions,
window_strides, padding, base_dilation,
window_dilation):
aval_out, = avals_out
operand_aval, scalar_aval = avals_in
scalar_type = aval_to_ir_type(scalar_aval)
def _generic_reduce_window_lower(ctx, avals_in, avals_out, *args, jaxpr, consts,
window_dimensions, window_strides, padding,
base_dilation, window_dilation):
operands, init_values = util.split_list(args, [len(args) // 2])
_, init_value_avals = util.split_list(avals_in, [len(operands)])
scalar_types = [aval_to_ir_type(aval) for aval in init_value_avals]
rw = mhlo.ReduceWindowOp(
aval_to_ir_types(aval_out), [operand], [init_value],
map(aval_to_ir_type, avals_out), operands, init_values,
_dense_int_elements(window_dimensions),
_dense_int_elements(window_strides), _dense_int_elements(base_dilation),
_dense_int_elements(window_dilation),
ir.DenseIntElementsAttr.get(np.asarray(padding, np.int64)))
reducer = rw.regions[0].blocks.append(scalar_type, scalar_type)
reducer = rw.regions[0].blocks.append(*(scalar_types + scalar_types))
with ir.InsertionPoint(reducer):
out_nodes = jaxpr_subcomp(ctx, jaxpr, consts,
*([a] for a in reducer.arguments))
mhlo.ReturnOp(util.flatten(out_nodes))
return rw.results

translations[lax.reduce_window_p] = _generic_reduce_window_lower
translations[lax_windowed_reductions.reduce_window_p] = _generic_reduce_window_lower


def _reduce_window_lower(
@@ -1492,11 +1492,11 @@ def _reduce_window_lower(
mhlo.ReturnOp(reduce_op(*reducer.arguments))
return rw.results

translations[lax.reduce_window_sum_p] = partial(
translations[lax_windowed_reductions.reduce_window_sum_p] = partial(
_reduce_window_lower, mhlo.AddOp, lambda _: 0)
translations[lax.reduce_window_min_p] = partial(
translations[lax_windowed_reductions.reduce_window_min_p] = partial(
_reduce_window_lower, mhlo.MinOp, lax._get_min_identity)
translations[lax.reduce_window_max_p] = partial(
translations[lax_windowed_reductions.reduce_window_max_p] = partial(
_reduce_window_lower, mhlo.MaxOp, lax._get_max_identity)


@@ -1525,7 +1525,7 @@ def _select_and_scatter_lower(
mhlo.ReturnOp(util.flatten(out_nodes))
return op.results

translations[lax.select_and_scatter_p] = _select_and_scatter_lower
translations[lax_windowed_reductions.select_and_scatter_p] = _select_and_scatter_lower


def _select_and_scatter_add(source, operand, *, select_prim, window_dimensions,
@@ -1551,13 +1551,13 @@ def _select_and_scatter_add(source, operand, *, select_prim, window_dimensions,
out = lax.slice(out, start_indices, stop_indices)
return out

translations[lax.select_and_scatter_add_p] = lower_fun(
translations[lax_windowed_reductions.select_and_scatter_add_p] = lower_fun(
partial(_select_and_scatter_add, expand_padding=False),
multiple_results=False)
platform_specific_translations['cpu'][lax.select_and_scatter_add_p] = lower_fun(
platform_specific_translations['cpu'][lax_windowed_reductions.select_and_scatter_add_p] = lower_fun(
partial(_select_and_scatter_add, expand_padding=True),
multiple_results=False)
platform_specific_translations['gpu'][lax.select_and_scatter_add_p] = lower_fun(
platform_specific_translations['gpu'][lax_windowed_reductions.select_and_scatter_add_p] = lower_fun(
partial(_select_and_scatter_add, expand_padding=True),
multiple_results=False)

@@ -1826,10 +1826,14 @@ def _add_cumulative_reduce(prim, reducer, tpu_reduce_window_fn):
partial(control_flow._cumred_tpu_translation_rule, tpu_reduce_window_fn),
multiple_results=False)

_add_cumulative_reduce(control_flow.cumsum_p, lax.add, lax._reduce_window_sum)
_add_cumulative_reduce(control_flow.cumprod_p, lax.mul, lax._reduce_window_prod)
_add_cumulative_reduce(control_flow.cummin_p, lax.min, lax._reduce_window_min)
_add_cumulative_reduce(control_flow.cummax_p, lax.max, lax._reduce_window_max)
_add_cumulative_reduce(control_flow.cumsum_p, lax.add,
lax_windowed_reductions._reduce_window_sum)
_add_cumulative_reduce(control_flow.cumprod_p, lax.mul,
lax_windowed_reductions._reduce_window_prod)
_add_cumulative_reduce(control_flow.cummin_p, lax.min,
lax_windowed_reductions._reduce_window_min)
_add_cumulative_reduce(control_flow.cummax_p, lax.max,
lax_windowed_reductions._reduce_window_max)

translations[custom_derivatives.custom_jvp_call_jaxpr_p] = lower_fun(
custom_derivatives._custom_jvp_call_jaxpr_impl, multiple_results=True)
@@ -1986,6 +1990,6 @@ map(add_fallback_lowering, [
lax.top_k_p,

# TODO(phawkins): implement these lax ops:
lax.select_and_gather_add_p,
lax_windowed_reductions.select_and_gather_add_p,
lax.rng_bit_generator_p,
])

@@ -1696,6 +1696,62 @@ class LaxTest(jtu.JaxTestCase):
args_maker = lambda: [rng(shape, dtype)]
self._CompileAndCheck(fun, args_maker)

@parameterized.named_parameters(jtu.cases_from_list(
{"testcase_name": ("_shape={}_dims={}_strides={}_padding={}"
"_basedilation={}_windowdilation={}")
.format(jtu.format_shape_dtype_string(shape, dtype),
dims, strides, padding, base_dilation, window_dilation),
"dtype": dtype, "shape": shape,
"dims": dims, "strides": strides, "padding": padding,
"base_dilation": base_dilation, "window_dilation": window_dilation}
for dtype in [np.float32]
for shape, dims, strides, padding, base_dilation, window_dilation in (
itertools.chain(
itertools.product(
[(4, 6)],
[(2, 1), (1, 2)],
[(1, 1), (2, 1), (1, 2)],
["VALID", "SAME", [(0, 3), (1, 2)]],
[(1, 1), (2, 3)],
[(1, 1), (1, 2)]),
itertools.product(
[(3, 2, 4, 6)], [(1, 1, 2, 1), (2, 1, 2, 1)],
[(1, 2, 2, 1), (1, 1, 1, 1)],
["VALID", "SAME", [(0, 1), (1, 0), (2, 3), (0, 2)]],
[(1, 1, 1, 1), (2, 1, 3, 2)],
[(1, 1, 1, 1), (1, 2, 2, 1)])))))
# TODO(b/183233858): variadic reduce-window is not implemented on XLA:GPU
@jtu.skip_on_devices("gpu")
def testReduceWindowVariadic(self, dtype, shape, dims, strides, padding,
base_dilation, window_dilation):
if (jtu.device_under_test() == "tpu" and
any(d != 1 for d in window_dilation)):
raise SkipTest("TPU support missing for arbitrary window dilation.")
rng = jtu.rand_small(self.rng())
init_values = (np.asarray(0, dtype=dtype), np.array(-np.inf, dtype=dtype))

def reducer(xs, ys):
x1, x2 = xs
y1, y2 = ys
return (x1 + y1, lax.max(x2, y2))

def fun(*operands):
return lax.reduce_window(operands, init_values, reducer, dims, strides,
padding, base_dilation, window_dilation)

def reference_fun(*operands):
return [
lax_reference.reduce_window(operand, init_val, op, dims, strides,
padding, base_dilation)
for operand, init_val, op in zip(operands, init_values,
[np.add, np.maximum])]

args_maker = lambda: [rng(shape, dtype), rng(shape, dtype)]
self._CompileAndCheck(fun, args_maker)
if all(d == 1 for d in window_dilation):
self._CheckAgainstNumpy(reference_fun, fun, args_maker)


def testReduceWindowFailures(self):
def empty_window_test():
return lax.reduce_window(np.ones((1,)), 0., lax.add, padding='VALID',
@@ -1711,9 +1767,8 @@ class LaxTest(jtu.JaxTestCase):

with self.assertRaisesRegex(
ValueError,
"Invalid return type from reduction function: <class 'list'>\n"
"Reduction functions should only return an array.\n"
"Full return value: .*"):
"reduce_window output must have the same tree structure as the "
"operands.*"):
return lax.reduce_window(
np.ones((1,)), 0., lambda x, y: [x + y],
padding='VALID', window_dimensions=(1,), window_strides=(1,))
@@ -2363,8 +2418,8 @@ class LaxTest(jtu.JaxTestCase):
, "window_dilation": (1, 1)
}

msg = (r"reduce_window expected init_value to be a scalar but init_value "
r"has shape \(1,\).")
msg = (r"reduce_window expected init_values to be scalars but init_values "
r"have shapes \[\(1,\)\].")
with self.assertRaisesRegex(TypeError, msg):
lax.reduce_window(**args)