# Source: rocm_jax/jax/_src/lax/windowed_reductions.py (878 lines, 39 KiB, Python)

# Copyright 2018 The JAX Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from collections.abc import Sequence
from functools import partial
from typing import Any, Callable, Optional, Union
import warnings
import numpy as np
from jax import tree_util
from jax._src import ad_util
from jax._src import core
from jax._src import dispatch
from jax._src import dtypes
from jax._src import util
from jax._src.core import ShapedArray, ConcreteArray
from jax._src.interpreters import ad
from jax._src.interpreters import batching
from jax._src.interpreters import mlir
from jax._src.lax import lax
from jax._src.lax import convolution
from jax._src.lax import slicing
from jax._src.lib.mlir import ir
from jax._src.lib.mlir.dialects import hlo
from jax._src.numpy.ufuncs import logaddexp
# Shadow the builtin `map`/`zip` with the `util.safe_*` variants for the rest
# of this module.
map = util.safe_map
zip = util.safe_zip
# Placeholder alias used in the annotations below; it is just `Any`.
Array = Any
def reduce_window(operand, init_value, computation: Callable,
                  window_dimensions: core.Shape, window_strides: Sequence[int],
                  padding: Union[str, Sequence[tuple[int, int]]],
                  base_dilation: Optional[Sequence[int]] = None,
                  window_dilation: Optional[Sequence[int]] = None) -> Array:
  """Wraps XLA's `ReduceWindowWithGeneralPadding
  <https://www.tensorflow.org/xla/operation_semantics#reducewindow>`_
  operator.

  Args:
    operand: an array, or a pytree of arrays, to reduce over windows.
    init_value: initial value(s) for the reduction; must have the same pytree
      structure as ``operand``.
    computation: the binary reduction function.
    window_dimensions: window size for each operand dimension.
    window_strides: window stride for each operand dimension.
    padding: either a padding-type string, which is converted to explicit
      pads with ``lax.padtype_to_pads`` using the shape of the first flattened
      operand, or a sequence of ``(low, high)`` pairs.
    base_dilation: optional dilation applied to the operand; defaults to all
      ones.
    window_dilation: optional dilation applied to the window; defaults to all
      ones.

  Returns:
    The reduced value(s). When ``computation``/``init_value`` are recognized
    by ``_get_monoid_window_reducer`` the specialized reducer's result is
    returned directly; otherwise the outputs of ``reduce_window_p`` are
    unflattened with the output tree of ``computation``.

  Raises:
    ValueError: if ``operand`` and ``init_value`` differ in tree structure or
      leaf count, if there are no operands, or if the output tree of
      ``computation`` differs from the operand tree.
  """
  flat_operands, operand_tree = tree_util.tree_flatten(operand)
  flat_init_values, init_value_tree = tree_util.tree_flatten(init_value)
  if operand_tree != init_value_tree:
    raise ValueError('Operands must have the same tree structure as '
                     f'init_values: {operand_tree} vs. {init_value_tree}')
  if not flat_operands:
    raise ValueError('reduce_window must have at least one operand.')
  if len(flat_operands) != len(flat_init_values):
    raise ValueError('Must have same total number of operands as init_values: '
                     f' {len(flat_operands)} vs. {len(flat_init_values)}')

  if isinstance(padding, str):
    # String padding is resolved against the dilated window, so the window
    # dilation has to be applied before computing the explicit pads.
    dilated_window_dims = (
        window_dimensions if window_dilation is None else
        lax._dilate_shape(window_dimensions, window_dilation))
    padding = tuple(lax.padtype_to_pads(
        flat_operands[0].shape, dilated_window_dims, window_strides, padding))
  else:
    padding = tuple(padding)

  if base_dilation is None:
    base_dilation = (1,) * len(window_dimensions)
  if window_dilation is None:
    window_dilation = (1,) * len(window_dimensions)

  monoid_reducer = _get_monoid_window_reducer(computation, flat_init_values)
  if monoid_reducer:
    return monoid_reducer(operand, window_dimensions, window_strides, padding,
                          base_dilation, window_dilation)

  # Generic path: trace `computation` to a jaxpr and bind the variadic
  # reduce_window primitive.
  flat_init_avals = map(lax._abstractify, flat_init_values)
  jaxpr, consts, out_tree = lax._variadic_reduction_jaxpr(
      computation, tuple(flat_init_avals), init_value_tree)
  if operand_tree != out_tree:
    raise ValueError(
        'reduce_window output must have the same tree structure as the operands'
        f' {operand_tree} vs. {out_tree}')
  out_flat = reduce_window_p.bind(
      *flat_operands, *flat_init_values, jaxpr=jaxpr, consts=consts,
      window_dimensions=tuple(window_dimensions),
      window_strides=tuple(window_strides), padding=padding,
      base_dilation=tuple(base_dilation),
      window_dilation=tuple(window_dilation))
  return tree_util.tree_unflatten(out_tree, out_flat)
def _get_monoid_window_reducer(monoid_op: Callable,
xs: Sequence[Array]) -> Optional[Callable]:
if len(xs) != 1:
return None
x, = xs
aval = core.get_aval(x)
if (type(aval) is ConcreteArray) and aval.shape == ():
if monoid_op is lax.add:
return aval.val == 0 and _reduce_window_sum
elif monoid_op is lax.max:
return (aval.val == lax._get_max_identity(aval.dtype)
and _reduce_window_max)
elif monoid_op is lax.min:
return (aval.val == lax._get_min_identity(aval.dtype)
and _reduce_window_min)
return None
def _reduce_window_sum(operand: Array, window_dimensions: core.Shape,
                       window_strides: Sequence[int],
                       padding: Sequence[tuple[int, int]],
                       base_dilation: Optional[Sequence[int]] = None,
                       window_dilation: Optional[Sequence[int]] = None) -> Array:
  """Binds ``reduce_window_sum_p`` on ``operand``.

  Missing dilations default to all ones (one entry per window dimension), and
  every sequence parameter is converted to a tuple before binding.
  """
  base = ((1,) * len(window_dimensions)
          if base_dilation is None else base_dilation)
  window = ((1,) * len(window_dimensions)
            if window_dilation is None else window_dilation)
  params = dict(
      window_dimensions=tuple(window_dimensions),
      window_strides=tuple(window_strides),
      padding=tuple(padding),
      base_dilation=tuple(base),
      window_dilation=tuple(window))
  return reduce_window_sum_p.bind(operand, **params)
def _reduce_window_prod(operand: Array, window_dimensions: core.Shape,
                        window_strides: Sequence[int],
                        padding: Sequence[tuple[int, int]],
                        base_dilation: Optional[Sequence[int]] = None,
                        window_dilation: Optional[Sequence[int]] = None) -> Array:
  """Windowed product via the generic ``reduce_window_p`` primitive.

  The reducer is ``lax.mul`` and the init value is the constant 1 built from
  ``operand`` with ``lax._const``. Missing dilations default to all ones.
  """
  one = lax._const(operand, 1)
  mul_jaxpr, mul_consts = lax._reduction_jaxpr(lax.mul, lax._abstractify(one))
  base = ((1,) * len(window_dimensions)
          if base_dilation is None else base_dilation)
  window = ((1,) * len(window_dimensions)
            if window_dilation is None else window_dilation)
  # reduce_window_p returns one output per operand; there is exactly one here.
  (result,) = reduce_window_p.bind(
      operand, one, jaxpr=mul_jaxpr, consts=mul_consts,
      window_dimensions=tuple(window_dimensions),
      window_strides=tuple(window_strides),
      padding=tuple(padding),
      base_dilation=tuple(base),
      window_dilation=tuple(window))
  return result
def _reduce_window_max(operand: Array, window_dimensions: core.Shape,
                       window_strides: Sequence[int],
                       padding: Sequence[tuple[int, int]],
                       base_dilation: Optional[Sequence[int]] = None,
                       window_dilation: Optional[Sequence[int]] = None) -> Array:
  """Binds ``reduce_window_max_p`` on ``operand``.

  Missing dilations default to all ones (one entry per window dimension), and
  every sequence parameter is converted to a tuple before binding.
  """
  base = ((1,) * len(window_dimensions)
          if base_dilation is None else base_dilation)
  window = ((1,) * len(window_dimensions)
            if window_dilation is None else window_dilation)
  params = dict(
      window_dimensions=tuple(window_dimensions),
      window_strides=tuple(window_strides),
      padding=tuple(padding),
      base_dilation=tuple(base),
      window_dilation=tuple(window))
  return reduce_window_max_p.bind(operand, **params)
def _reduce_window_min(operand: Array, window_dimensions: core.Shape,
                       window_strides: Sequence[int],
                       padding: Sequence[tuple[int, int]],
                       base_dilation: Optional[Sequence[int]] = None,
                       window_dilation: Optional[Sequence[int]] = None) -> Array:
  """Binds ``reduce_window_min_p`` on ``operand``.

  Missing dilations default to all ones (one entry per window dimension), and
  every sequence parameter is converted to a tuple before binding.
  """
  base = ((1,) * len(window_dimensions)
          if base_dilation is None else base_dilation)
  window = ((1,) * len(window_dimensions)
            if window_dilation is None else window_dilation)
  params = dict(
      window_dimensions=tuple(window_dimensions),
      window_strides=tuple(window_strides),
      padding=tuple(padding),
      base_dilation=tuple(base),
      window_dilation=tuple(window))
  return reduce_window_min_p.bind(operand, **params)
def _reduce_window_logaddexp(
    operand: Array, window_dimensions: core.Shape,
    window_strides: Sequence[int],
    padding: Sequence[tuple[int, int]],
    base_dilation: Optional[Sequence[int]] = None,
    window_dilation: Optional[Sequence[int]] = None) -> Array:
  """Windowed ``logaddexp`` via the generic ``reduce_window_p`` primitive.

  The reducer is ``logaddexp`` and the init value is ``-inf`` built from
  ``operand`` with ``lax._const``. Missing dilations default to all ones.
  """
  neg_inf = lax._const(operand, -np.inf)
  lae_jaxpr, lae_consts = lax._reduction_jaxpr(
      logaddexp, lax._abstractify(neg_inf))
  base = ((1,) * len(window_dimensions)
          if base_dilation is None else base_dilation)
  window = ((1,) * len(window_dimensions)
            if window_dilation is None else window_dilation)
  # reduce_window_p returns one output per operand; there is exactly one here.
  (result,) = reduce_window_p.bind(
      operand, neg_inf, jaxpr=lae_jaxpr, consts=lae_consts,
      window_dimensions=tuple(window_dimensions),
      window_strides=tuple(window_strides),
      padding=tuple(padding),
      base_dilation=tuple(base),
      window_dilation=tuple(window))
  return result
def _select_and_scatter(operand: Array, select: Callable,
                        window_dimensions: core.Shape,
                        window_strides: Sequence[int],
                        padding: Sequence[tuple[int, int]], source: Array,
                        init_value: Array, scatter: Callable) -> Array:
  """Binds ``select_and_scatter_p``.

  Both ``select`` and ``scatter`` are traced to jaxprs with
  ``lax._reduction_jaxpr`` against the abstract value of ``init_value``.
  """
  init_aval = lax._abstractify(init_value)
  sel_jaxpr, sel_consts = lax._reduction_jaxpr(select, init_aval)
  scat_jaxpr, scat_consts = lax._reduction_jaxpr(scatter, init_aval)
  return select_and_scatter_p.bind(
      operand, source, init_value,
      select_jaxpr=sel_jaxpr,
      select_consts=sel_consts,
      scatter_jaxpr=scat_jaxpr,
      scatter_consts=scat_consts,
      window_dimensions=tuple(window_dimensions),
      window_strides=tuple(window_strides),
      padding=tuple(padding))
def _select_and_scatter_add(source: Array, operand: Array,
                            select_prim: core.Primitive,
                            window_dimensions: core.Shape,
                            window_strides: Sequence[int],
                            padding: Sequence[tuple[int, int]]) -> Array:
  """Binds ``select_and_scatter_add_p`` with tuple-normalized window params."""
  params = dict(
      select_prim=select_prim,
      window_dimensions=tuple(window_dimensions),
      window_strides=tuple(window_strides),
      padding=tuple(padding))
  return select_and_scatter_add_p.bind(source, operand, **params)
def _select_and_gather_add(tangents: Array, operand: Array,
                           select_prim: core.Primitive,
                           window_dimensions: core.Shape,
                           window_strides: Sequence[int],
                           padding: Sequence[tuple[int, int]],
                           base_dilation: Sequence[int],
                           window_dilation: Sequence[int]) -> Array:
  """Extracts the tangent corresponding to the minimum or maximum element in
  each window of the `operand` array.

  Wraps XLA's `ReduceWindow
  <https://www.tensorflow.org/xla/operation_semantics#reducewindow>`_
  operator, which applies a reduction function to all elements in each window of
  the input multi-dimensional array. In this case, the input multi-dimensional
  array is built by packing each element in the `operand` array with its
  corresponding element in the `tangents` array.

  Args:
    tangents: an array
    operand: an array with the same shape as `tangents`
    select_prim: a reduction function (restricted to `ge_p` and `le_p`)
    window_dimensions: an array of integers for window dimension values
    window_strides: an array of integers for window stride values
    padding: a sequence of `(low, high)` integer pairs, one per dimension
    base_dilation: an array of integers for base dilation values
    window_dilation: an array of integers for window dilation values

  Returns:
    An array containing the elements in `tangents` corresponding to the output
    of the reduction of `operand` in each window.
  """
  params = dict(
      select_prim=select_prim,
      window_dimensions=tuple(window_dimensions),
      window_strides=tuple(window_strides),
      padding=tuple(padding),
      base_dilation=tuple(base_dilation),
      window_dilation=tuple(window_dilation))
  return select_and_gather_add_p.bind(tangents, operand, **params)
def _reduce_window_abstract_eval_rule(
*avals, jaxpr, consts, window_dimensions, window_strides, padding,
base_dilation, window_dilation):
operand_avals, init_val_avals = util.split_list(avals, [len(avals) // 2])
if any(o.dtype != iv.dtype for o, iv in zip(operand_avals, init_val_avals)):
msg = ("reduce_window got inconsistent dtypes for operands and init_values:"
" got operand dtypes {} and init_value dtypes {}.")
raise TypeError(msg.format([o.dtype for o in operand_avals],
[iv.dtype for iv in init_val_avals]))
if any(len(v.shape) != 0 for v in init_val_avals):
msg = ("reduce_window expected init_values to be scalars but init_values "
"have shapes {}.")
raise TypeError(msg.format([v.shape for v in init_val_avals]))
out_shape = _common_reduce_window_shape_rule(
operand_avals[0], window_dimensions, window_strides, padding,
base_dilation, window_dilation)
return tuple(ShapedArray(out_shape, op.dtype) for op in operand_avals)
def _generic_reduce_window_batch_rule(
batched_args, batch_dims, *, jaxpr, consts, window_dimensions,
window_strides, padding, base_dilation, window_dilation):
num_operands = len(batched_args) // 2
operands, init_values = util.split_list(batched_args, [num_operands])
operand_bdims, init_value_bdims = util.split_list(batch_dims, [num_operands])
if any(init_bdim is not None for init_bdim in init_value_bdims):
raise NotImplementedError("reduce_window batching is not implemented for "
"initial values")
size = next(x.shape[ax] for x, ax in zip(operands, operand_bdims)
if ax is not None)
operands = [batching.bdim_at_front(arg, bdim, size)
for arg, bdim in zip(operands, operand_bdims)]
window_dimensions = (1,) + window_dimensions
window_strides = (1,) + window_strides
padding = ((0, 0),) + padding
base_dilation = (1,) + base_dilation
window_dilation = (1,) + window_dilation
outs = reduce_window_p.bind(
*(operands + init_values), jaxpr=jaxpr, consts=consts,
window_dimensions=window_dimensions, window_strides=window_strides,
padding=padding, base_dilation=base_dilation,
window_dilation=window_dilation)
return outs, (0,) * num_operands
# Generic windowed-reduction primitive. It is bound with operands followed by
# init values and yields multiple results (see the abstract-eval rule, which
# returns one aval per operand).
reduce_window_p = core.Primitive('reduce_window')
reduce_window_p.multiple_results = True
reduce_window_p.def_impl(partial(dispatch.apply_primitive, reduce_window_p))
reduce_window_p.def_abstract_eval(_reduce_window_abstract_eval_rule)
batching.primitive_batchers[reduce_window_p] = _generic_reduce_window_batch_rule
def _generic_reduce_window_lower(ctx, *args, jaxpr, consts,
                                 window_dimensions, window_strides, padding,
                                 base_dilation, window_dilation):
  """MLIR lowering rule for the generic ``reduce_window_p`` primitive.

  ``args`` holds the operand values followed by the same number of init
  values. The reducer region is produced by lowering ``jaxpr`` with
  ``mlir.jaxpr_subcomp``.

  Raises:
    NotImplementedError: (when the reducer body is built) if ``jaxpr`` has
      effects.
  """
  operands, init_values = util.split_list(args, [len(args) // 2])
  _, init_value_avals = util.split_list(ctx.avals_in, [len(operands)])

  def reducer_body(reducer: ir.Block) -> Sequence[ir.Value]:
    if jaxpr.effects:
      raise NotImplementedError('Cannot lower effectful `reduce_window`.')
    # Each block argument is wrapped in a one-element list for jaxpr_subcomp.
    out_nodes, _ = mlir.jaxpr_subcomp(
        ctx.module_context, jaxpr,
        mlir.TokenSet(), consts, *([a] for a in reducer.arguments),
        dim_var_values=ctx.dim_var_values)
    return util.flatten(out_nodes)

  return mlir.reduce_window(
      ctx,
      reducer_name="generic_reduce_window_reducer",
      reducer_body=reducer_body,
      operands=operands,
      init_values=init_values, init_values_avals=init_value_avals,
      out_avals=ctx.avals_out,
      window_dimensions=window_dimensions, window_strides=window_strides,
      base_dilation=base_dilation, window_dilation=window_dilation,
      padding=padding)
# Lowering for the generic reduce_window primitive.
mlir.register_lowering(reduce_window_p, _generic_reduce_window_lower)
def _reduce_window_sum_shape_rule(operand, *, window_dimensions, window_strides,
padding, base_dilation, window_dilation):
if not dtypes.issubdtype(operand.dtype, np.number):
msg = "operand to reduce_window_sum must have a number dtype, got {}"
raise TypeError(msg.format(np.dtype(operand.dtype).name))
return _common_reduce_window_shape_rule(operand, window_dimensions,
window_strides, padding,
base_dilation, window_dilation)
def _reduce_window_sum_transpose_rule(cotangent, operand, *, window_dimensions,
                                      window_strides, padding, base_dilation,
                                      window_dilation):
  """Transpose rule for ``reduce_window_sum_p``.

  The cotangent is zero-padded (edge pads from
  ``convolution._conv_general_vjp_lhs_padding``, plus ``stride - 1`` interior
  zeros per dimension) and then window-summed again. Returns a one-element
  list holding the cotangent for ``operand``.
  """
  assert ad.is_undefined_primal(operand)
  in_shape = operand.aval.shape
  rank = len(in_shape)
  edge_pads = convolution._conv_general_vjp_lhs_padding(
      in_shape, window_dimensions, window_strides, cotangent.shape, padding,
      base_dilation, window_dilation)
  pad_config = [(lo, hi, stride - 1)
                for (lo, hi), stride in zip(edge_pads, window_strides)]
  padded_ct = lax.pad(cotangent, lax._zero(cotangent), pad_config)
  # Note: the forward base dilation is passed as the stride of this second
  # reduction, while its own base dilation is all ones.
  result = _reduce_window_sum(
      padded_ct, window_dimensions,
      window_strides=base_dilation,
      padding=[(0, 0)] * rank,
      base_dilation=[1] * rank,
      window_dilation=window_dilation)
  assert result.shape == in_shape, (result.shape, in_shape)
  return [result]
def _reduce_window_batch_rule(reduce_window, batched_args, bdims, *,
window_dimensions, window_strides, padding,
base_dilation, window_dilation):
operand, = batched_args
bdim, = bdims
if bdim is not None:
window_dimensions = \
window_dimensions[:bdim] + (1,) + window_dimensions[bdim:]
window_strides = window_strides[:bdim] + (1,) + window_strides[bdim:]
padding = padding[:bdim] + ((0, 0),) + padding[bdim:]
base_dilation = base_dilation[:bdim] + (1,) + base_dilation[bdim:]
window_dilation = window_dilation[:bdim] + (1,) + window_dilation[bdim:]
operand = reduce_window(operand, window_dimensions, window_strides, padding,
base_dilation, window_dilation)
return operand, bdim
# Windowed-sum primitive: registered as linear (transpose rule only) and
# batched through the shared single-operand batching rule.
reduce_window_sum_p = lax.standard_primitive(
_reduce_window_sum_shape_rule, lax._input_dtype, 'reduce_window_sum')
ad.deflinear2(reduce_window_sum_p, _reduce_window_sum_transpose_rule)
batching.primitive_batchers[reduce_window_sum_p] = partial(
_reduce_window_batch_rule, _reduce_window_sum)
def _reduce_window_chooser_jvp_rule(prim, g, operand, *, window_dimensions,
                                    window_strides, padding, base_dilation,
                                    window_dilation):
  """JVP rule for windowed max/min.

  The tangent ``g`` is gathered with ``_select_and_gather_add``, selecting
  with ``ge_p`` for ``max_p`` and ``le_p`` for ``min_p``.
  """
  assert prim is lax.max_p or prim is lax.min_p
  if prim is lax.max_p:
    select_prim = lax.ge_p
  else:
    select_prim = lax.le_p
  return _select_and_gather_add(
      g, operand, select_prim, window_dimensions, window_strides, padding,
      base_dilation, window_dilation)
def _common_reduce_window_shape_rule(operand, window_dimensions,
window_strides, padding, base_dilation,
window_dilation):
lax._check_shapelike("reduce_window", "window_dimensions", window_dimensions,
non_zero_shape=True)
lax._check_shapelike("reduce_window", "window_strides", window_strides,
non_zero_shape=True)
lax._check_shapelike("reduce_window", "base_dilation", base_dilation)
lax._check_shapelike("reduce_window", "window_dilation", window_dilation)
if operand.ndim != len(window_dimensions):
msg = ("reduce_window got the wrong number of window_dimensions for "
"operand: got operand shape {} with window_dimensions {}.")
raise TypeError(msg.format(operand.shape, window_dimensions))
if len(window_strides) != len(window_dimensions):
msg = ("reduce_window got inconsistent window_strides and "
"window_dimensions: got window_strides {} and window_dimensions {}.")
raise TypeError(msg.format(window_strides, window_dimensions))
if len(base_dilation) != len(window_dimensions):
msg = ("reduce_window got inconsistent base_dilation and "
"window_dimensions: got base_dilation {} and window_dimensions {}.")
raise TypeError(msg.format(base_dilation, window_dimensions))
if len(window_dilation) != len(window_dimensions):
msg = ("reduce_window got inconsistent window_dilation and "
"window_dimensions: got window_dilation {} and window_dimensions "
"{}.")
raise TypeError(msg.format(window_dilation, window_dimensions))
return reduce_window_shape_tuple(operand.shape, window_dimensions,
window_strides, padding, base_dilation,
window_dilation)
def reduce_window_shape_tuple(operand_shape, window_dimensions, window_strides,
                              padding, base_dilation=None,
                              window_dilation=None):
  """Computes the output shape of a windowed reduction.

  The operand shape and window are first dilated (when the corresponding
  dilation is given), each operand dimension is grown by its low and high
  padding, and ``core.stride_dim`` is applied per dimension.

  Returns:
    A tuple with one output size per dimension.
  """
  effective_operand = operand_shape
  if base_dilation is not None:
    effective_operand = lax._dilate_shape(operand_shape, base_dilation)

  effective_window = window_dimensions
  if window_dilation is not None:
    effective_window = lax._dilate_shape(window_dimensions, window_dilation)

  padded_dims = []
  for dim, (pad_lo, pad_hi) in zip(effective_operand, padding):
    padded_dims.append(dim + pad_lo + pad_hi)

  out_dims = map(core.stride_dim, tuple(padded_dims), effective_window,
                 window_strides)
  return tuple(out_dims)
# Windowed-max primitive: shares the common shape rule, differentiates through
# the chooser JVP rule (bound to max_p), and batches through the shared
# single-operand batching rule.
reduce_window_max_p = lax.standard_primitive(
_common_reduce_window_shape_rule, lax._input_dtype, 'reduce_window_max')
ad.defjvp(reduce_window_max_p, partial(_reduce_window_chooser_jvp_rule,
lax.max_p))
batching.primitive_batchers[reduce_window_max_p] = partial(
_reduce_window_batch_rule, _reduce_window_max)
# Windowed-min primitive: shares the common shape rule, differentiates through
# the chooser JVP rule (bound to min_p), and batches through the shared
# single-operand batching rule.
reduce_window_min_p = lax.standard_primitive(
    _common_reduce_window_shape_rule, lax._input_dtype, 'reduce_window_min')
ad.defjvp(reduce_window_min_p, partial(_reduce_window_chooser_jvp_rule,
                                       lax.min_p))
_reduce_window_min_batch_rule = partial(_reduce_window_batch_rule,
                                        _reduce_window_min)
# Register the named rule instead of building a second, identical partial.
batching.primitive_batchers[reduce_window_min_p] = _reduce_window_min_batch_rule
def _reduce_window_lower(
    reduce_op,
    init_value, ctx, operand, *,
    window_dimensions, window_strides, padding, base_dilation,
    window_dilation):
  """Shared MLIR lowering for the single-operand windowed reductions.

  Args:
    reduce_op: callable applied to the reducer block's arguments.
    init_value: callable mapping a dtype to the reduction's init value.
    ctx: the lowering rule context.
    operand: the operand IR value.
  """
  (operand_aval,) = ctx.avals_in
  scalar_aval = operand_aval.update(shape=())

  def reducer_body(reducer):
    return reduce_op(*reducer.arguments)

  init_const = mlir.full_like_aval(
      ctx, init_value(scalar_aval.dtype), scalar_aval)
  return mlir.reduce_window(
      ctx,
      reducer_name=f"reduce_window_{scalar_aval.dtype}_reducer",
      reducer_body=reducer_body,
      operands=[operand],
      init_values=[init_const],
      init_values_avals=[scalar_aval],
      out_avals=ctx.avals_out,
      window_dimensions=window_dimensions,
      window_strides=window_strides,
      base_dilation=base_dilation,
      window_dilation=window_dilation,
      padding=padding)
# Lowerings for the specialized reductions: each pairs a reducer op with a
# function from dtype to the init value (0 for sum, the min/max identity
# otherwise).
mlir.register_lowering(reduce_window_sum_p, partial(
_reduce_window_lower, hlo.AddOp, lambda _: 0))
mlir.register_lowering(reduce_window_min_p, partial(
_reduce_window_lower, mlir.min_hlo, lax._get_min_identity))
mlir.register_lowering(reduce_window_max_p, partial(
_reduce_window_lower, mlir.max_hlo, lax._get_max_identity))
def _select_and_scatter_shape_rule(
operand, source, init_value, *, select_jaxpr, select_consts, scatter_jaxpr,
scatter_consts, window_dimensions, window_strides, padding):
lax._check_shapelike("select_and_scatter", "window_dimensions",
window_dimensions)
lax._check_shapelike("select_and_scatter", "window_strides", window_strides)
if len(window_dimensions) != len(window_strides):
msg = ("select_and_scatter got inconsistent window_strides and "
"window_dimensions: got window_strides {} and window_dimensions {}.")
raise TypeError(msg.format(window_strides, window_dimensions))
return operand.shape
# select_and_scatter primitive; only shape/dtype rules are attached here.
select_and_scatter_p = lax.standard_primitive(
_select_and_scatter_shape_rule, lax._input_dtype, 'select_and_scatter')
def _select_and_scatter_lower(
    ctx, operand, source, init_value, *, select_jaxpr,
    select_consts, scatter_jaxpr, scatter_consts, window_dimensions,
    window_strides, padding):
  """MLIR lowering rule for ``select_and_scatter_p``.

  Emits an ``hlo.SelectAndScatterOp`` and fills its ``select`` and ``scatter``
  regions by lowering the corresponding jaxprs; each region takes two scalar
  arguments of the operand's element type.

  Raises:
    NotImplementedError: if ``select_jaxpr`` or ``scatter_jaxpr`` has effects.
  """
  operand_aval, source_aval, init_value_aval = ctx.avals_in
  aval_out, = ctx.avals_out
  scalar_aval = operand_aval.update(shape=())
  scalar_type = mlir.aval_to_ir_type(scalar_aval)
  op = hlo.SelectAndScatterOp(
      mlir.aval_to_ir_type(aval_out),
      operand,
      source,
      init_value,
      window_dimensions=mlir.dense_int_elements(window_dimensions),
      window_strides=mlir.dense_int_elements(window_strides),
      padding=ir.DenseIntElementsAttr.get(np.asarray(padding, np.int64),
                                          shape=(len(padding), 2)))

  select = op.select.blocks.append(scalar_type, scalar_type)
  with ir.InsertionPoint(select):
    if select_jaxpr.effects:
      raise NotImplementedError('Cannot lower effectful `select`.')
    out_nodes, _ = mlir.jaxpr_subcomp(ctx.module_context, select_jaxpr,
                                      mlir.TokenSet(), select_consts,
                                      *([a] for a in select.arguments),
                                      dim_var_values=ctx.dim_var_values)
    hlo.ReturnOp(util.flatten(out_nodes))

  scatter = op.scatter.blocks.append(scalar_type, scalar_type)
  with ir.InsertionPoint(scatter):
    if scatter_jaxpr.effects:
      raise NotImplementedError('Cannot lower effectful `scatter`.')
    out_nodes, _ = mlir.jaxpr_subcomp(ctx.module_context, scatter_jaxpr,
                                      mlir.TokenSet(), scatter_consts,
                                      *([a] for a in scatter.arguments),
                                      dim_var_values=ctx.dim_var_values)
    hlo.ReturnOp(util.flatten(out_nodes))
  return op.results
# Lowering for the select_and_scatter primitive.
mlir.register_lowering(select_and_scatter_p, _select_and_scatter_lower)
def _select_and_scatter_add_shape_rule(
source, operand, *, select_prim, window_dimensions, window_strides,
padding):
return operand.shape
def _select_and_scatter_add_jvp(
    primals, tangents, *, select_prim, window_dimensions, window_strides,
    padding):
  """JVP rule for ``select_and_scatter_add_p``.

  Only the tangent of ``source`` contributes; the operand tangent is
  discarded. A symbolic-zero source tangent yields a symbolic-zero output
  tangent.

  Returns:
    A ``(primal_out, tangent_out)`` pair.
  """
  source, operand = primals
  source_tangent, _ = tangents

  def scatter_onto_operand(values):
    return _select_and_scatter_add(
        values, operand, select_prim, window_dimensions, window_strides,
        padding)

  primal_out = scatter_onto_operand(source)
  if type(source_tangent) is ad_util.Zero:
    return primal_out, ad_util.Zero.from_value(primal_out)
  return primal_out, scatter_onto_operand(source_tangent)
def _select_and_scatter_add_transpose(
    t, source, operand, *, select_prim, window_dimensions, window_strides,
    padding):
  """Transpose rule for ``select_and_scatter_add_p`` with respect to ``source``.

  Returns a two-element list: the cotangent for ``source`` (gathered with
  ``_select_and_gather_add`` using unit dilations) and ``None`` for
  ``operand``.
  """
  assert ad.is_undefined_primal(source) and not ad.is_undefined_primal(operand)
  if type(t) is ad_util.Zero:
    source_ct = ad_util.Zero(source.aval)
  else:
    unit_dilation = (1,) * len(window_dimensions)
    source_ct = _select_and_gather_add(
        t, operand, select_prim, window_dimensions, window_strides, padding,
        unit_dilation, unit_dilation)
  return [source_ct, None]
def _select_and_scatter_add_batch_rule(
    batched_args, batch_dims, *, select_prim, window_dimensions, window_strides,
    padding):
  """Batching rule for ``select_and_scatter_add``.

  Both arguments are brought to a layout with the batch dimension leading, and
  the window parameters are extended with a trivial entry (window size 1,
  stride 1, no padding) for that new leading dimension.

  Returns:
    ``(out, 0)``: the batched result and the position of its batch dimension.
  """
  source, operand = batched_args
  s_bdim, o_bdim = batch_dims
  # The batch size is read from the first argument that is actually batched.
  size = next(arg.shape[axis] for arg, axis in zip(batched_args, batch_dims)
              if axis is not None)
  source = batching.bdim_at_front(source, s_bdim, size)
  operand = batching.bdim_at_front(operand, o_bdim, size)
  out = _select_and_scatter_add(
      source, operand, select_prim,
      (1,) + window_dimensions,
      (1,) + window_strides,
      ((0, 0),) + padding)
  return out, 0
# Create the `select_and_scatter_add` primitive via `lax.standard_primitive`,
# passing its shape rule, `lax._input_dtype` and the primitive name.
select_and_scatter_add_p = lax.standard_primitive(
_select_and_scatter_add_shape_rule, lax._input_dtype,
'select_and_scatter_add')
# Attach the transpose, JVP and batching rules for the primitive.
ad.primitive_transposes[select_and_scatter_add_p] = \
_select_and_scatter_add_transpose
ad.primitive_jvps[select_and_scatter_add_p] = _select_and_scatter_add_jvp
batching.primitive_batchers[select_and_scatter_add_p] = \
_select_and_scatter_add_batch_rule
def _select_and_scatter_add_impl(source, operand, *,
                                 select_prim, window_dimensions, window_strides,
                                 padding, expand_padding):
  """Implements ``select_and_scatter_add`` in terms of ``_select_and_scatter``.

  Selection uses ``select_prim``; scattered values are combined with
  ``lax.bitwise_or`` for boolean ``source`` and ``lax.add`` otherwise.

  When ``expand_padding`` is true the padding is applied explicitly: the
  operand is padded up front with the identity value chosen for
  ``select_prim``, the scatter runs with zero padding, and the result is
  sliced back to the original operand shape.
  """
  dtype = source.dtype
  select = lambda x, y: select_prim.bind(x, y)
  if dtype == np.bool_:
    scatter = lax.bitwise_or
  else:
    scatter = lax.add

  if not expand_padding:
    return _select_and_scatter(
        operand, select, window_dimensions, window_strides, padding, source,
        lax._zero(operand), scatter)

  unpadded_shape = operand.shape
  if select_prim is lax.ge_p:
    identity = lax._get_max_identity
  else:
    identity = lax._get_min_identity
  # Explicit edge padding (no interior padding) with the identity value.
  pad_config = [(lo, hi, 0) for (lo, hi) in padding]
  padded_operand = lax.pad(operand, identity(dtype), pad_config)
  no_padding = [(0, 0) for _ in padding]
  padded_out = _select_and_scatter(
      padded_operand, select, window_dimensions, window_strides, no_padding,
      source, lax._zero(padded_operand), scatter)
  # Cut the low-side padding off again and keep the original extent.
  start_indices = [lo for (lo, _) in padding]
  stop_indices = [lo + dim for ((lo, _), dim) in zip(padding, unpadded_shape)]
  return slicing.slice(padded_out, start_indices, stop_indices)
# Lowering registered without a `platform` argument: uses
# `_select_and_scatter_add_impl` with `expand_padding=False`.
mlir.register_lowering(select_and_scatter_add_p, mlir.lower_fun(
partial(_select_and_scatter_add_impl, expand_padding=False),
multiple_results=False))
# The 'cpu' and 'gpu' registrations below use the same implementation with
# `expand_padding=True`.
# TODO(b/161704903): workaround for XLA/CPU crash.
mlir.register_lowering(select_and_scatter_add_p, mlir.lower_fun(
partial(_select_and_scatter_add_impl, expand_padding=True),
multiple_results=False), platform='cpu')
# TODO(b/182390722): workaround for XLA/GPU crash.
mlir.register_lowering(select_and_scatter_add_p, mlir.lower_fun(
partial(_select_and_scatter_add_impl, expand_padding=True),
multiple_results=False), platform='gpu')
def _select_and_gather_add_shape_rule(
    tangents, operand, *, select_prim, window_dimensions, window_strides,
    padding, base_dilation, window_dilation):
  """Shape rule for ``select_and_gather_add``.

  ``tangents`` and ``operand`` must have identical shapes; the output shape is
  then whatever ``_common_reduce_window_shape_rule`` computes for ``operand``
  with the given window parameters.

  Raises:
    TypeError: if the shapes of ``tangents`` and ``operand`` differ.
  """
  tangents_shape, operand_shape = tangents.shape, operand.shape
  if tangents_shape != operand_shape:
    raise TypeError(
        ("select_and_gather_add tangents and operand shapes must match, "
         "got {} and {}.").format(tangents_shape, operand_shape))
  return _common_reduce_window_shape_rule(
      operand, window_dimensions, window_strides, padding, base_dilation,
      window_dilation)
def _select_and_gather_add_lowering(
    ctx: mlir.LoweringRuleContext,
    tangents, operand, *, select_prim,
    window_dimensions, window_strides, padding, base_dilation, window_dilation,
    max_bits=64):
  """Lowers ``select_and_gather_add`` to a single-operand reduce-window.

  Each ``(operand, tangents)`` element pair is packed into one unsigned
  integer, with the operand bits in the upper half and the tangent bits in the
  lower half. A reduce-window over the packed values keeps, per window, the
  packed element whose upper half wins the ``select_prim`` comparison ("GE" or
  "LE"); the lower half of the winner is then unpacked as the result.

  If a double-width integer would exceed ``max_bits``, both values are instead
  reduced in precision so that each fits in half a word, and a warning is
  emitted.

  Args:
    ctx: lowering rule context; ``avals_in`` and ``avals_out`` are read from it.
    tangents: IR value that is gathered (the second packed component).
    operand: IR value that is compared (the first packed component).
    select_prim: must be ``lax.ge_p`` or ``lax.le_p`` (asserted).
    window_dimensions, window_strides, padding, base_dilation,
      window_dilation: forwarded unchanged to ``mlir.reduce_window``.
    max_bits: widest integer width allowed for the packed representation.

  Returns:
    A one-element list with the unpacked second component of the reduction.
  """
  _, operand_aval, = ctx.avals_in
  out_aval, = ctx.avals_out
  assert isinstance(operand_aval, core.ShapedArray), operand_aval
  dtype = operand_aval.dtype
  etype = mlir.dtype_to_ir_type(dtype)
  # Bit width of the operand dtype, obtained through `finfo`.
  nbits = dtypes.finfo(dtype).bits

  assert nbits <= max_bits
  # True when two full-width values fit side by side within `max_bits`.
  double_word_reduction = nbits * 2 <= max_bits

  # Scalar IR constant of the given dtype.
  const = lambda dtype, x: mlir.ir_constant(np.array(x, dtype=dtype))

  # Scalar constant `x` broadcast to the shape and dtype of `aval_out`.
  def _broadcast_scalar_const(x, aval_out):
    return mlir.broadcast_in_dim(ctx, const(aval_out.dtype, x),
                                 aval_out,
                                 broadcast_dimensions=())

  if double_word_reduction:
    # TODO(b/73062247): XLA doesn't yet implement ReduceWindow on tuples, so
    # we implement a pair-wise ReduceWindow by packing two k-bit values into
    # 2k-bit unsigned integer using bit tricks.
    word_dtype = lax._UINT_DTYPES[nbits]
    double_word_dtype = lax._UINT_DTYPES[nbits * 2]
    word_type = mlir.dtype_to_ir_type(word_dtype)  # type: ignore

    # Packs two values into a double_word_type.
    def pack(a, b, ab_aval):
      word_type_ab_aval = ab_aval.update(dtype=word_dtype)
      double_word_type_ab_aval = ab_aval.update(dtype=double_word_dtype)
      # Reinterpret both values as unsigned words, widen them, then place `a`
      # in the upper `nbits` and `b` in the lower `nbits`.
      a = hlo.BitcastConvertOp(mlir.aval_to_ir_type(word_type_ab_aval), a)
      b = hlo.BitcastConvertOp(mlir.aval_to_ir_type(word_type_ab_aval), b)
      a = hlo.ConvertOp(mlir.aval_to_ir_type(double_word_type_ab_aval), a)
      b = hlo.ConvertOp(mlir.aval_to_ir_type(double_word_type_ab_aval), b)
      a = hlo.ShiftLeftOp(a,
                          _broadcast_scalar_const(nbits, double_word_type_ab_aval))
      return hlo.OrOp(a, b)

    # Unpacks the first element of a double_word_type.
    def fst(t):
      # Only scalar inputs are supported here (empty shape asserted).
      assert not ir.RankedTensorType(t.type).shape
      st = hlo.ShiftRightLogicalOp(t, const(double_word_dtype, nbits))
      return hlo.BitcastConvertOp(
          ir.RankedTensorType.get([], etype),
          hlo.ConvertOp(ir.RankedTensorType.get([], word_type), st)).result

    # Unpacks the second element of a double_word_type.
    def snd(t, t_aval):
      return hlo.BitcastConvertOp(
          mlir.aval_to_ir_type(t_aval.update(dtype=dtype)),
          hlo.ConvertOp(mlir.aval_to_ir_type(t_aval.update(dtype=word_dtype)), t)).result

  else:
    # The double-word trick above only works if we have a sufficiently large
    # type. As an alternative, we can pack two half words into a single word,
    # at the cost of precision.
    # TODO(b/73062247): add support for tuple reductions and remove this case.
    warnings.warn("Using reduced precision for gradient of reduce-window "
                  "min/max operator to work around missing XLA support for "
                  "pair-reductions. This is likely from a second or "
                  "higher derivative of a max-pooling operation.")
    r_nbits = nbits // 2
    # Drop/round the bottom mantissa bits.
    nexp = dtypes.finfo(dtype).nexp
    # Mantissa bits left in a half word after the exponent bits and one more
    # bit (presumably the sign bit) are accounted for.
    nmant = r_nbits - nexp - 1

    # In this branch the packed value is a single word of the original width.
    double_word_dtype = word_dtype = lax._UINT_DTYPES[nbits]

    # Packs two values into a double_word_type.
    def pack(a, b, ab_aval):
      word_type_ab_aval = ab_aval.update(dtype=word_dtype)
      a = hlo.ReducePrecisionOp(a, exponent_bits=mlir.i32_attr(nexp),
                                mantissa_bits=mlir.i32_attr(nmant))
      b = hlo.ReducePrecisionOp(b, exponent_bits=mlir.i32_attr(nexp),
                                mantissa_bits=mlir.i32_attr(nmant))
      a = hlo.BitcastConvertOp(mlir.aval_to_ir_type(word_type_ab_aval), a)
      b = hlo.BitcastConvertOp(mlir.aval_to_ir_type(word_type_ab_aval), b)
      # Move `b` into the lower half word; `a` stays in the upper half.
      b = hlo.ShiftRightLogicalOp(
          b, _broadcast_scalar_const(r_nbits, word_type_ab_aval))
      return hlo.OrOp(a, b)

    # Unpacks the first element of a double_word_type.
    def fst(t):
      assert not ir.RankedTensorType(t.type).shape
      # Mask that keeps only the upper `r_nbits` bits of the word.
      st = hlo.AndOp(t, const(word_dtype, ((1 << r_nbits) - 1) << r_nbits))
      return hlo.BitcastConvertOp(ir.RankedTensorType.get([], etype),
                                  st).result

    # Unpacks the second element of a double_word_type.
    def snd(t, t_aval):
      return hlo.BitcastConvertOp(
          mlir.aval_to_ir_type(t_aval.update(dtype=dtype)),
          hlo.ShiftLeftOp(t, _broadcast_scalar_const(r_nbits, t_aval.update(dtype=word_dtype)))
      ).result

  assert select_prim is lax.ge_p or select_prim is lax.le_p, select_prim
  # Initial comparison key: -inf when selecting with GE, +inf with LE. It is
  # packed together with a zero second component below.
  init = -np.inf if select_prim is lax.ge_p else np.inf
  double_word_out_aval = out_aval.update(dtype=double_word_dtype)

  # Reducer: keep whichever packed value has the winning first component.
  def reducer_body(reducer: ir.Block) -> Sequence[ir.Value]:
    x, y = reducer.arguments
    assert select_prim is lax.ge_p or select_prim is lax.le_p
    cmp_op = "GE" if select_prim is lax.ge_p else "LE"
    out = hlo.SelectOp(mlir.compare_hlo(fst(x), fst(y), cmp_op), x, y)
    return out

  res, = mlir.reduce_window(ctx,
      reducer_name="reduce_window_select_and_gather_add",
      reducer_body=reducer_body,
      operands=[pack(operand, tangents, operand_aval)],
      init_values=[pack(const(dtype, init), const(dtype, 0), core.ShapedArray((), dtype))],
      init_values_avals=[core.ShapedArray((), double_word_dtype)],
      out_avals=[double_word_out_aval],
      window_dimensions=window_dimensions,
      window_strides=window_strides,
      base_dilation=base_dilation,
      window_dilation=window_dilation,
      padding=padding)
  # Only the second packed component (the gathered tangent) is returned.
  return [snd(res, double_word_out_aval)]
# TODO(phawkins): use this translation rule on all platforms.
def _select_and_gather_add_using_variadic_reducewindow(
    tangents, operand, *, select_prim, window_dimensions, window_strides,
    padding, base_dilation, window_dilation):
  """Implements ``select_and_gather_add`` with a two-operand ``reduce_window``.

  The reduction runs over ``(operand, tangents)`` pairs. For each pair of
  candidates, ``select_prim`` is applied to the two operand values and the
  pair it selects is kept; the tangent component of the surviving pair is the
  result.

  ``select_prim`` must be ``lax.ge_p`` or ``lax.le_p`` (asserted).
  """
  assert select_prim is lax.ge_p or select_prim is lax.le_p, select_prim

  def reducer(x, y):
    key_x, val_x = x
    key_y, val_y = y
    take_x = select_prim.bind(key_x, key_y)
    return (lax.select(take_x, key_x, key_y),
            lax.select(take_x, val_x, val_y))

  # Initial key: -inf when selecting with ge, +inf with le; initial value 0.
  if select_prim is lax.ge_p:
    key_init = -np.inf
  else:
    key_init = np.inf
  init_values = (np.array(key_init, dtype=operand.dtype),
                 np.array(0, dtype=operand.dtype))
  reduced = reduce_window(
      (operand, tangents), init_values, reducer, window_dimensions,
      window_strides, padding, base_dilation, window_dilation)
  _, gathered = reduced
  return gathered
def _select_and_gather_add_jvp(
    primals, tangents, *, select_prim, window_dimensions, window_strides,
    padding, base_dilation, window_dilation):
  """JVP rule for ``select_and_gather_add``.

  The primal output is ``select_and_gather_add`` applied to ``source``. The
  tangent output applies the same operation to the tangent of ``source`` (with
  the same ``operand``), or is a symbolic zero when that tangent is a symbolic
  zero. The tangent of ``operand`` is discarded.

  Returns:
    A ``(primal_out, tangent_out)`` pair.
  """
  source, operand = primals
  source_dot, operand_dot = tangents
  del operand_dot  # The operand's tangent does not contribute.

  def gather_add(src):
    return _select_and_gather_add(
        src, operand, select_prim, window_dimensions, window_strides,
        padding, base_dilation, window_dilation)

  primal_out = gather_add(source)
  if type(source_dot) is ad_util.Zero:
    return primal_out, ad_util.Zero.from_value(primal_out)
  return primal_out, gather_add(source_dot)
def _select_and_gather_add_transpose(
    t, tangents, operand, *, select_prim, window_dimensions, window_strides,
    padding, base_dilation, window_dilation):
  """Transpose rule for ``select_and_gather_add`` with respect to ``tangents``.

  The cotangent is computed with ``_select_and_scatter_add``. Base dilation is
  handled by padding ``operand`` with interior identity elements before the
  scatter and taking a strided slice of the result afterwards.

  Returns:
    ``[tangents_cotangent, None]``.

  Raises:
    NotImplementedError: if any window dilation factor differs from 1.
  """
  assert select_prim in (lax.le_p, lax.ge_p)
  assert (ad.is_undefined_primal(tangents) and
          not ad.is_undefined_primal(operand))
  if any(d != 1 for d in window_dilation):
    raise NotImplementedError(
        ("VJP not implemented for select_and_gather (MaxPool) with window "
         "dilation, got window_dilation={}.").format(window_dilation))
  if type(t) is ad_util.Zero:
    return [ad_util.Zero(tangents.aval), None]

  has_base_dilation = any(d != 1 for d in base_dilation)
  if not has_base_dilation:
    result = _select_and_scatter_add(t, operand, select_prim,
                                     window_dimensions, window_strides,
                                     padding)
    return [result, None]

  if select_prim is lax.ge_p:
    select_identity = lax._get_max_identity
  else:
    select_identity = lax._get_min_identity
  # Insert d - 1 identity elements between neighbouring operand entries.
  interior_pads = tuple((0, 0, d - 1) for d in base_dilation)
  dilated_operand = lax.pad(operand, select_identity(operand.dtype),
                            interior_pads)
  dilated_result = _select_and_scatter_add(
      t, dilated_operand, select_prim, window_dimensions, window_strides,
      padding)
  # Keep only every d-th entry along each dimension.
  result = slicing.slice(dilated_result, (0,) * len(dilated_result.shape),
                         dilated_result.shape, base_dilation)
  return [result, None]
def _select_and_gather_add_batching_rule(
    batched_args, batch_dims, *, select_prim, window_dimensions, window_strides,
    padding, base_dilation, window_dilation):
  """Batching rule for ``select_and_gather_add``.

  Both arguments are brought to a layout with the batch dimension leading, and
  every window parameter is extended with a trivial entry (window size 1,
  stride 1, no padding, dilation 1) for that new leading dimension.

  Returns:
    ``(out, 0)``: the batched result and the position of its batch dimension.
  """
  t, x = batched_args
  t_bdim, x_bdim = batch_dims
  # The batch size is read from the first argument that is actually batched.
  size = next(arg.shape[axis] for arg, axis in zip(batched_args, batch_dims)
              if axis is not None)
  t = batching.bdim_at_front(t, t_bdim, size)
  x = batching.bdim_at_front(x, x_bdim, size)
  out = _select_and_gather_add(
      t, x, select_prim,
      (1,) + window_dimensions,
      (1,) + window_strides,
      ((0, 0),) + padding,
      (1,) + base_dilation,
      (1,) + window_dilation)
  return (out, 0)
# Create the `select_and_gather_add` primitive via `lax.standard_primitive`,
# passing its shape rule, `lax._input_dtype` and the primitive name.
select_and_gather_add_p = lax.standard_primitive(
_select_and_gather_add_shape_rule, lax._input_dtype,
'select_and_gather_add')
# Attach the JVP, transpose and batching rules for the primitive.
ad.primitive_jvps[select_and_gather_add_p] = _select_and_gather_add_jvp
ad.primitive_transposes[select_and_gather_add_p] = \
_select_and_gather_add_transpose
batching.primitive_batchers[select_and_gather_add_p] = \
_select_and_gather_add_batching_rule
# Lowering registered without a `platform` argument: the variadic
# reduce-window implementation.
mlir.register_lowering(select_and_gather_add_p, mlir.lower_fun(
_select_and_gather_add_using_variadic_reducewindow,
multiple_results=False))
# The 'gpu' platform instead uses the bit-packing lowering
# `_select_and_gather_add_lowering`.
# TODO(b/183233858): use variadic reducewindow on GPU, when implemented.
mlir.register_lowering(
select_and_gather_add_p,
_select_and_gather_add_lowering,
platform="gpu")