# Copyright 2018 The JAX Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from collections.abc import Sequence
from functools import partial
from typing import Any, Callable, Optional, Union
import warnings

import numpy as np

from jax import tree_util

from jax._src import ad_util
from jax._src import core
from jax._src import dispatch
from jax._src import dtypes
from jax._src import util
from jax._src.core import ShapedArray, ConcreteArray
from jax._src.interpreters import ad
from jax._src.interpreters import batching
from jax._src.interpreters import mlir
from jax._src.lax import lax
from jax._src.lax import convolution
from jax._src.lax import slicing
from jax._src.lib.mlir import ir
from jax._src.lib.mlir.dialects import hlo
from jax._src.numpy.ufuncs import logaddexp

map = util.safe_map
zip = util.safe_zip

Array = Any

def reduce_window(operand, init_value, computation: Callable,
                  window_dimensions: core.Shape, window_strides: Sequence[int],
                  padding: Union[str, Sequence[tuple[int, int]]],
                  base_dilation: Optional[Sequence[int]] = None,
                  window_dilation: Optional[Sequence[int]] = None) -> Array:
  """Wraps XLA's `ReduceWindowWithGeneralPadding
  <https://www.tensorflow.org/xla/operation_semantics#reducewindow>`_
  operator.
  """
  flat_operands, operand_tree = tree_util.tree_flatten(operand)
  flat_init_values, init_value_tree = tree_util.tree_flatten(init_value)
  if operand_tree != init_value_tree:
    raise ValueError('Operands must have the same tree structure as '
                     f'init_values: {operand_tree} vs. {init_value_tree}')
  if len(flat_operands) == 0:
    raise ValueError('reduce_window must have at least one operand.')
  if len(flat_operands) != len(flat_init_values):
    raise ValueError('Must have same total number of operands as init_values: '
                     f' {len(flat_operands)} vs. {len(flat_init_values)}')
  if isinstance(padding, str):
    dilated_window_dims = (
        window_dimensions if window_dilation is None else
        lax._dilate_shape(window_dimensions, window_dilation))
    padding = tuple(lax.padtype_to_pads(
        flat_operands[0].shape, dilated_window_dims, window_strides, padding))
  else:
    padding = tuple(padding)
  if base_dilation is None:
    base_dilation = (1,) * len(window_dimensions)
  if window_dilation is None:
    window_dilation = (1,) * len(window_dimensions)
  monoid_reducer = _get_monoid_window_reducer(computation, flat_init_values)
  if monoid_reducer:
    return monoid_reducer(operand, window_dimensions, window_strides, padding,
                          base_dilation, window_dilation)
  else:
    flat_init_avals = map(lax._abstractify, flat_init_values)
    jaxpr, consts, out_tree = lax._variadic_reduction_jaxpr(
        computation, tuple(flat_init_avals), init_value_tree)
    if operand_tree != out_tree:
      raise ValueError(
        'reduce_window output must have the same tree structure as the operands'
        f' {operand_tree} vs. {out_tree}')
    out_flat = reduce_window_p.bind(
        *flat_operands, *flat_init_values, jaxpr=jaxpr, consts=consts,
        window_dimensions=tuple(window_dimensions),
        window_strides=tuple(window_strides), padding=padding,
        base_dilation=tuple(base_dilation),
        window_dilation=tuple(window_dilation))
    return tree_util.tree_unflatten(out_tree, out_flat)

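# A minimal illustration of the API above (assuming `x` is a floating-point
# array, e.g. of shape (4, 4)): a 2x2 max pooling with stride 2 and no padding
# can be written as
#   y = reduce_window(x, -np.inf, lax.max,
#                     window_dimensions=(2, 2), window_strides=(2, 2),
#                     padding='VALID')
# Because -np.inf is the identity of `max` for floating-point dtypes, such a
# call is typically routed to the specialized `_reduce_window_max` path by
# `_get_monoid_window_reducer` below.
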
def _get_monoid_window_reducer(monoid_op: Callable,
                               xs: Sequence[Array]) -> Optional[Callable]:
  if len(xs) != 1:
    return None
  x, = xs
  aval = core.get_aval(x)
  if (type(aval) is ConcreteArray) and aval.shape == ():
    if monoid_op is lax.add:
      return aval.val == 0 and _reduce_window_sum
    elif monoid_op is lax.max:
      return (aval.val == lax._get_max_identity(aval.dtype)
              and _reduce_window_max)
    elif monoid_op is lax.min:
      return (aval.val == lax._get_min_identity(aval.dtype)
              and _reduce_window_min)
  return None

def _reduce_window_sum(operand: Array, window_dimensions: core.Shape,
                       window_strides: Sequence[int],
                       padding: Sequence[tuple[int, int]],
                       base_dilation: Optional[Sequence[int]] = None,
                       window_dilation: Optional[Sequence[int]] = None) -> Array:
  if base_dilation is None:
    base_dilation = (1,) * len(window_dimensions)
  if window_dilation is None:
    window_dilation = (1,) * len(window_dimensions)
  return reduce_window_sum_p.bind(
      operand, window_dimensions=tuple(window_dimensions),
      window_strides=tuple(window_strides), padding=tuple(padding),
      base_dilation=tuple(base_dilation),
      window_dilation=tuple(window_dilation))

def _reduce_window_prod(operand: Array, window_dimensions: core.Shape,
                        window_strides: Sequence[int],
                        padding: Sequence[tuple[int, int]],
                        base_dilation: Optional[Sequence[int]] = None,
                        window_dilation: Optional[Sequence[int]] = None) -> Array:
  init_value = lax._const(operand, 1)
  jaxpr, consts = lax._reduction_jaxpr(lax.mul, lax._abstractify(init_value))
  if base_dilation is None:
    base_dilation = (1,) * len(window_dimensions)
  if window_dilation is None:
    window_dilation = (1,) * len(window_dimensions)
  out, = reduce_window_p.bind(
      operand, init_value, jaxpr=jaxpr, consts=consts,
      window_dimensions=tuple(window_dimensions),
      window_strides=tuple(window_strides), padding=tuple(padding),
      base_dilation=tuple(base_dilation),
      window_dilation=tuple(window_dilation))
  return out

def _reduce_window_max(operand: Array, window_dimensions: core.Shape,
                       window_strides: Sequence[int],
                       padding: Sequence[tuple[int, int]],
                       base_dilation: Optional[Sequence[int]] = None,
                       window_dilation: Optional[Sequence[int]] = None) -> Array:
  if base_dilation is None:
    base_dilation = (1,) * len(window_dimensions)
  if window_dilation is None:
    window_dilation = (1,) * len(window_dimensions)
  return reduce_window_max_p.bind(
      operand, window_dimensions=tuple(window_dimensions),
      window_strides=tuple(window_strides), padding=tuple(padding),
      base_dilation=tuple(base_dilation),
      window_dilation=tuple(window_dilation))

def _reduce_window_min(operand: Array, window_dimensions: core.Shape,
                       window_strides: Sequence[int],
                       padding: Sequence[tuple[int, int]],
                       base_dilation: Optional[Sequence[int]] = None,
                       window_dilation: Optional[Sequence[int]] = None) -> Array:
  if base_dilation is None:
    base_dilation = (1,) * len(window_dimensions)
  if window_dilation is None:
    window_dilation = (1,) * len(window_dimensions)
  return reduce_window_min_p.bind(
      operand, window_dimensions=tuple(window_dimensions),
      window_strides=tuple(window_strides), padding=tuple(padding),
      base_dilation=tuple(base_dilation),
      window_dilation=tuple(window_dilation))

def _reduce_window_logaddexp(
    operand: Array, window_dimensions: core.Shape,
    window_strides: Sequence[int],
    padding: Sequence[tuple[int, int]],
    base_dilation: Optional[Sequence[int]] = None,
    window_dilation: Optional[Sequence[int]] = None) -> Array:
  init_value = lax._const(operand, -np.inf)
  jaxpr, consts = lax._reduction_jaxpr(logaddexp, lax._abstractify(init_value))
  if base_dilation is None:
    base_dilation = (1,) * len(window_dimensions)
  if window_dilation is None:
    window_dilation = (1,) * len(window_dimensions)
  out, = reduce_window_p.bind(
      operand, init_value, jaxpr=jaxpr, consts=consts,
      window_dimensions=tuple(window_dimensions),
      window_strides=tuple(window_strides), padding=tuple(padding),
      base_dilation=tuple(base_dilation),
      window_dilation=tuple(window_dilation))
  return out

def _select_and_scatter(operand: Array, select: Callable,
                        window_dimensions: core.Shape,
                        window_strides: Sequence[int],
                        padding: Sequence[tuple[int, int]], source: Array,
                        init_value: Array, scatter: Callable) -> Array:
  select_jaxpr, select_consts = lax._reduction_jaxpr(
      select, lax._abstractify(init_value))
  scatter_jaxpr, scatter_consts = lax._reduction_jaxpr(
      scatter, lax._abstractify(init_value))
  return select_and_scatter_p.bind(
      operand, source, init_value, select_jaxpr=select_jaxpr,
      select_consts=select_consts, scatter_jaxpr=scatter_jaxpr,
      scatter_consts=scatter_consts, window_dimensions=tuple(window_dimensions),
      window_strides=tuple(window_strides), padding=tuple(padding))

def _select_and_scatter_add(source: Array, operand: Array,
                            select_prim: core.Primitive,
                            window_dimensions: core.Shape,
                            window_strides: Sequence[int],
                            padding: Sequence[tuple[int, int]]) -> Array:
  return select_and_scatter_add_p.bind(
      source, operand, select_prim=select_prim,
      window_dimensions=tuple(window_dimensions),
      window_strides=tuple(window_strides), padding=tuple(padding))

def _select_and_gather_add(tangents: Array, operand: Array,
                           select_prim: core.Primitive,
                           window_dimensions: core.Shape,
                           window_strides: Sequence[int],
                           padding: Sequence[tuple[int, int]],
                           base_dilation: Sequence[int],
                           window_dilation: Sequence[int]) -> Array:
  """Extracts the tangent corresponding to the minimum or maximum element in
  each window of the `operand` array.

  Wraps XLA's `ReduceWindow
  <https://www.tensorflow.org/xla/operation_semantics#reducewindow>`_
  operator, which applies a reduction function to all elements in each window
  of the input multi-dimensional array. In this case, the input
  multi-dimensional array is built by packing each element in the `operand`
  array with its corresponding element in the `tangents` array.

  Args:
    tangents: an array
    operand: an array with the same shape as `tangents`
    select_prim: a reduction function (restricted to `ge_p` and `le_p`)
    window_dimensions: an array of integers for window dimension values
    window_strides: an array of integers for window stride values
    padding: a sequence of `(low, high)` padding pairs, one per window dimension
    base_dilation: an array of integers for base dilation values
    window_dilation: an array of integers for window dilation values

  Returns:
    An array containing the elements in `tangents` corresponding to the output
    of the reduction of `operand` in each window.
  """
  return select_and_gather_add_p.bind(
      tangents, operand, select_prim=select_prim,
      window_dimensions=tuple(window_dimensions),
      window_strides=tuple(window_strides), padding=tuple(padding),
      base_dilation=tuple(base_dilation),
      window_dilation=tuple(window_dilation))

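# A small worked example of the semantics documented above (illustrative
# values): for operand = [1., 3., 2.] and tangents = [10., 20., 30.], with a
# window of size 2, stride 1, no padding, and select_prim=lax.ge_p, the
# windowed maximum of operand falls at index 1 in both windows, so the result
# gathers the matching tangents and is [20., 20.].
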
def _reduce_window_abstract_eval_rule(
    *avals, jaxpr, consts, window_dimensions, window_strides, padding,
    base_dilation, window_dilation):
  operand_avals, init_val_avals = util.split_list(avals, [len(avals) // 2])
  if any(o.dtype != iv.dtype for o, iv in zip(operand_avals, init_val_avals)):
    msg = ("reduce_window got inconsistent dtypes for operands and init_values:"
           " got operand dtypes {} and init_value dtypes {}.")
    raise TypeError(msg.format([o.dtype for o in operand_avals],
                               [iv.dtype for iv in init_val_avals]))
  if any(len(v.shape) != 0 for v in init_val_avals):
    msg = ("reduce_window expected init_values to be scalars but init_values "
           "have shapes {}.")
    raise TypeError(msg.format([v.shape for v in init_val_avals]))
  out_shape = _common_reduce_window_shape_rule(
      operand_avals[0], window_dimensions, window_strides, padding,
      base_dilation, window_dilation)
  return tuple(ShapedArray(out_shape, op.dtype) for op in operand_avals)

def _generic_reduce_window_batch_rule(
    batched_args, batch_dims, *, jaxpr, consts, window_dimensions,
    window_strides, padding, base_dilation, window_dilation):
  num_operands = len(batched_args) // 2
  operands, init_values = util.split_list(batched_args, [num_operands])
  operand_bdims, init_value_bdims = util.split_list(batch_dims, [num_operands])

  if any(init_bdim is not None for init_bdim in init_value_bdims):
    raise NotImplementedError("reduce_window batching is not implemented for "
                              "initial values")

  size = next(x.shape[ax] for x, ax in zip(operands, operand_bdims)
              if ax is not None)
  operands = [batching.bdim_at_front(arg, bdim, size)
              for arg, bdim in zip(operands, operand_bdims)]
  window_dimensions = (1,) + window_dimensions
  window_strides = (1,) + window_strides
  padding = ((0, 0),) + padding
  base_dilation = (1,) + base_dilation
  window_dilation = (1,) + window_dilation
  outs = reduce_window_p.bind(
      *(operands + init_values), jaxpr=jaxpr, consts=consts,
      window_dimensions=window_dimensions, window_strides=window_strides,
      padding=padding, base_dilation=base_dilation,
      window_dilation=window_dilation)
  return outs, (0,) * num_operands

reduce_window_p = core.Primitive('reduce_window')
reduce_window_p.multiple_results = True
reduce_window_p.def_impl(partial(dispatch.apply_primitive, reduce_window_p))
reduce_window_p.def_abstract_eval(_reduce_window_abstract_eval_rule)
batching.primitive_batchers[reduce_window_p] = _generic_reduce_window_batch_rule

def _generic_reduce_window_lower(ctx, *args, jaxpr, consts,
                                 window_dimensions, window_strides, padding,
                                 base_dilation, window_dilation):
  operands, init_values = util.split_list(args, [len(args) // 2])
  _, init_value_avals = util.split_list(ctx.avals_in, [len(operands)])

  def reducer_body(reducer: ir.Block) -> Sequence[ir.Value]:
    if jaxpr.effects:
      raise NotImplementedError('Cannot lower effectful `reduce_window`.')
    out_nodes, _ = mlir.jaxpr_subcomp(ctx.module_context, jaxpr,
        mlir.TokenSet(), consts, *([a] for a in reducer.arguments),
        dim_var_values=ctx.dim_var_values)
    return util.flatten(out_nodes)

  return mlir.reduce_window(
      ctx,
      reducer_name="generic_reduce_window_reducer",
      reducer_body=reducer_body,
      operands=operands,
      init_values=init_values, init_values_avals=init_value_avals,
      out_avals=ctx.avals_out,
      window_dimensions=window_dimensions, window_strides=window_strides,
      base_dilation=base_dilation, window_dilation=window_dilation,
      padding=padding)

mlir.register_lowering(reduce_window_p, _generic_reduce_window_lower)

def _reduce_window_sum_shape_rule(operand, *, window_dimensions, window_strides,
                                  padding, base_dilation, window_dilation):
  if not dtypes.issubdtype(operand.dtype, np.number):
    msg = "operand to reduce_window_sum must have a number dtype, got {}"
    raise TypeError(msg.format(np.dtype(operand.dtype).name))
  return _common_reduce_window_shape_rule(operand, window_dimensions,
                                          window_strides, padding,
                                          base_dilation, window_dilation)

def _reduce_window_sum_transpose_rule(cotangent, operand, *, window_dimensions,
                                      window_strides, padding, base_dilation,
                                      window_dilation):
  assert ad.is_undefined_primal(operand)
  input_shape = operand.aval.shape
  pads = convolution._conv_general_vjp_lhs_padding(
      input_shape, window_dimensions, window_strides, cotangent.shape, padding,
      base_dilation, window_dilation)
  ones = [1] * len(input_shape)
  padding_config = [(lo, hi, stride - 1)
                    for (lo, hi), stride in zip(pads, window_strides)]
  pad_cotangent = lax.pad(cotangent, lax._zero(cotangent), padding_config)
  result = _reduce_window_sum(pad_cotangent, window_dimensions, base_dilation,
                              [(0, 0)] * len(input_shape),
                              base_dilation=ones,
                              window_dilation=window_dilation)
  assert result.shape == input_shape, (result.shape, input_shape)
  return [result]

def _reduce_window_batch_rule(reduce_window, batched_args, bdims, *,
                              window_dimensions, window_strides, padding,
                              base_dilation, window_dilation):
  operand, = batched_args
  bdim, = bdims

  if bdim is not None:
    window_dimensions = \
        window_dimensions[:bdim] + (1,) + window_dimensions[bdim:]
    window_strides = window_strides[:bdim] + (1,) + window_strides[bdim:]
    padding = padding[:bdim] + ((0, 0),) + padding[bdim:]
    base_dilation = base_dilation[:bdim] + (1,) + base_dilation[bdim:]
    window_dilation = window_dilation[:bdim] + (1,) + window_dilation[bdim:]

  operand = reduce_window(operand, window_dimensions, window_strides, padding,
                          base_dilation, window_dilation)
  return operand, bdim

reduce_window_sum_p = lax.standard_primitive(
    _reduce_window_sum_shape_rule, lax._input_dtype, 'reduce_window_sum')
ad.deflinear2(reduce_window_sum_p, _reduce_window_sum_transpose_rule)
batching.primitive_batchers[reduce_window_sum_p] = partial(
    _reduce_window_batch_rule, _reduce_window_sum)

def _reduce_window_chooser_jvp_rule(prim, g, operand, *, window_dimensions,
                                    window_strides, padding, base_dilation,
                                    window_dilation):
  assert prim is lax.max_p or prim is lax.min_p
  select_prim = lax.ge_p if prim is lax.max_p else lax.le_p
  return _select_and_gather_add(g, operand, select_prim, window_dimensions,
                                window_strides, padding, base_dilation,
                                window_dilation)

def _common_reduce_window_shape_rule(operand, window_dimensions,
                                     window_strides, padding, base_dilation,
                                     window_dilation):
  lax._check_shapelike("reduce_window", "window_dimensions", window_dimensions,
                       non_zero_shape=True)
  lax._check_shapelike("reduce_window", "window_strides", window_strides,
                       non_zero_shape=True)
  lax._check_shapelike("reduce_window", "base_dilation", base_dilation)
  lax._check_shapelike("reduce_window", "window_dilation", window_dilation)
  if operand.ndim != len(window_dimensions):
    msg = ("reduce_window got the wrong number of window_dimensions for "
           "operand: got operand shape {} with window_dimensions {}.")
    raise TypeError(msg.format(operand.shape, window_dimensions))
  if len(window_strides) != len(window_dimensions):
    msg = ("reduce_window got inconsistent window_strides and "
           "window_dimensions: got window_strides {} and window_dimensions {}.")
    raise TypeError(msg.format(window_strides, window_dimensions))
  if len(base_dilation) != len(window_dimensions):
    msg = ("reduce_window got inconsistent base_dilation and "
           "window_dimensions: got base_dilation {} and window_dimensions {}.")
    raise TypeError(msg.format(base_dilation, window_dimensions))
  if len(window_dilation) != len(window_dimensions):
    msg = ("reduce_window got inconsistent window_dilation and "
           "window_dimensions: got window_dilation {} and window_dimensions "
           "{}.")
    raise TypeError(msg.format(window_dilation, window_dimensions))

  return reduce_window_shape_tuple(operand.shape, window_dimensions,
                                   window_strides, padding, base_dilation,
                                   window_dilation)

def reduce_window_shape_tuple(operand_shape, window_dimensions, window_strides,
                              padding, base_dilation=None,
                              window_dilation=None):
  if base_dilation is not None:
    operand_shape = lax._dilate_shape(operand_shape, base_dilation)
  if window_dilation is not None:
    window_dimensions = lax._dilate_shape(window_dimensions, window_dilation)
  operand_padded = tuple(d + pl + ph for d, (pl, ph) in zip(operand_shape, padding))
  return tuple(map(core.stride_dim, operand_padded, window_dimensions, window_strides))

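# In effect, each output dimension has the usual strided-window size
#   floor((padded_size - dilated_window_size) / stride) + 1;
# for example, an operand dimension of 4 with padding (0, 0), window 2, and
# stride 2 yields floor((4 - 2) / 2) + 1 = 2 output elements.
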
reduce_window_max_p = lax.standard_primitive(
    _common_reduce_window_shape_rule, lax._input_dtype, 'reduce_window_max')
ad.defjvp(reduce_window_max_p, partial(_reduce_window_chooser_jvp_rule,
                                       lax.max_p))
batching.primitive_batchers[reduce_window_max_p] = partial(
    _reduce_window_batch_rule, _reduce_window_max)

reduce_window_min_p = lax.standard_primitive(
    _common_reduce_window_shape_rule, lax._input_dtype, 'reduce_window_min')
ad.defjvp(reduce_window_min_p, partial(_reduce_window_chooser_jvp_rule,
                                       lax.min_p))

_reduce_window_min_batch_rule = partial(_reduce_window_batch_rule,
                                        _reduce_window_min)
batching.primitive_batchers[reduce_window_min_p] = partial(
    _reduce_window_batch_rule, _reduce_window_min)

def _reduce_window_lower(
    reduce_op,
    init_value, ctx, operand, *,
    window_dimensions, window_strides, padding, base_dilation,
    window_dilation):

  operand_aval, = ctx.avals_in
  scalar_aval = operand_aval.update(shape=())

  return mlir.reduce_window(ctx,
      reducer_name=f"reduce_window_{scalar_aval.dtype}_reducer",
      reducer_body=lambda reducer: reduce_op(*reducer.arguments),
      operands=[operand],
      init_values=[mlir.full_like_aval(ctx, init_value(scalar_aval.dtype),
                                       scalar_aval)],
      init_values_avals=[scalar_aval],
      out_avals=ctx.avals_out,
      window_dimensions=window_dimensions,
      window_strides=window_strides, base_dilation=base_dilation,
      window_dilation=window_dilation, padding=padding)

mlir.register_lowering(reduce_window_sum_p, partial(
    _reduce_window_lower, hlo.AddOp, lambda _: 0))
mlir.register_lowering(reduce_window_min_p, partial(
    _reduce_window_lower, mlir.min_hlo, lax._get_min_identity))
mlir.register_lowering(reduce_window_max_p, partial(
    _reduce_window_lower, mlir.max_hlo, lax._get_max_identity))

def _select_and_scatter_shape_rule(
    operand, source, init_value, *, select_jaxpr, select_consts, scatter_jaxpr,
    scatter_consts, window_dimensions, window_strides, padding):
  lax._check_shapelike("select_and_scatter", "window_dimensions",
                       window_dimensions)
  lax._check_shapelike("select_and_scatter", "window_strides", window_strides)
  if len(window_dimensions) != len(window_strides):
    msg = ("select_and_scatter got inconsistent window_strides and "
           "window_dimensions: got window_strides {} and window_dimensions {}.")
    raise TypeError(msg.format(window_strides, window_dimensions))
  return operand.shape

select_and_scatter_p = lax.standard_primitive(
    _select_and_scatter_shape_rule, lax._input_dtype, 'select_and_scatter')

def _select_and_scatter_lower(
    ctx, operand, source, init_value, *, select_jaxpr,
    select_consts, scatter_jaxpr, scatter_consts, window_dimensions,
    window_strides, padding):
  operand_aval, source_aval, init_value_aval = ctx.avals_in
  aval_out, = ctx.avals_out
  scalar_aval = operand_aval.update(shape=())
  scalar_type = mlir.aval_to_ir_type(scalar_aval)
  op = hlo.SelectAndScatterOp(
      mlir.aval_to_ir_type(aval_out),
      operand,
      source,
      init_value,
      window_dimensions=mlir.dense_int_elements(window_dimensions),
      window_strides=mlir.dense_int_elements(window_strides),
      padding=ir.DenseIntElementsAttr.get(np.asarray(padding, np.int64),
                                          shape=(len(padding), 2)))
  select = op.select.blocks.append(scalar_type, scalar_type)
  with ir.InsertionPoint(select):
    if select_jaxpr.effects:
      raise NotImplementedError('Cannot lower effectful `select`.')
    out_nodes, _ = mlir.jaxpr_subcomp(ctx.module_context, select_jaxpr,
                                      mlir.TokenSet(), select_consts,
                                      *([a] for a in select.arguments),
                                      dim_var_values=ctx.dim_var_values)
    hlo.ReturnOp(util.flatten(out_nodes))
  scatter = op.scatter.blocks.append(scalar_type, scalar_type)
  with ir.InsertionPoint(scatter):
    if scatter_jaxpr.effects:
      raise NotImplementedError('Cannot lower effectful `scatter`.')
    out_nodes, _ = mlir.jaxpr_subcomp(ctx.module_context, scatter_jaxpr,
                                      mlir.TokenSet(), scatter_consts,
                                      *([a] for a in scatter.arguments),
                                      dim_var_values=ctx.dim_var_values)
    hlo.ReturnOp(util.flatten(out_nodes))
  return op.results

mlir.register_lowering(select_and_scatter_p, _select_and_scatter_lower)

def _select_and_scatter_add_shape_rule(
    source, operand, *, select_prim, window_dimensions, window_strides,
    padding):
  return operand.shape

def _select_and_scatter_add_jvp(
    primals, tangents, *, select_prim, window_dimensions, window_strides,
    padding):
  source, operand = primals
  g_source, g_operand = tangents
  val_out = _select_and_scatter_add(
      source, operand, select_prim, window_dimensions, window_strides,
      padding)
  del g_operand
  if type(g_source) is ad_util.Zero:
    tangent_out = ad_util.Zero.from_value(val_out)
  else:
    tangent_out = _select_and_scatter_add(
        g_source, operand, select_prim, window_dimensions,
        window_strides, padding)
  return val_out, tangent_out

def _select_and_scatter_add_transpose(
    t, source, operand, *, select_prim, window_dimensions, window_strides,
    padding):
  assert ad.is_undefined_primal(source) and not ad.is_undefined_primal(operand)
  if type(t) is ad_util.Zero:
    return [ad_util.Zero(source.aval), None]
  ones = (1,) * len(window_dimensions)
  source_t = _select_and_gather_add(t, operand, select_prim, window_dimensions,
                                    window_strides, padding, ones, ones)
  return [source_t, None]

def _select_and_scatter_add_batch_rule(
    batched_args, batch_dims, *, select_prim, window_dimensions, window_strides,
    padding):
  source, operand = batched_args
  s_bdim, o_bdim = batch_dims
  size = next(a.shape[bdim] for a, bdim in zip(batched_args, batch_dims)
              if bdim is not None)
  source = batching.bdim_at_front(source, s_bdim, size)
  operand = batching.bdim_at_front(operand, o_bdim, size)

  window_dimensions = (1,) + window_dimensions
  window_strides = (1,) + window_strides
  padding = ((0, 0),) + padding
  out = _select_and_scatter_add(source, operand, select_prim, window_dimensions,
                                window_strides, padding)
  return out, 0

select_and_scatter_add_p = lax.standard_primitive(
    _select_and_scatter_add_shape_rule, lax._input_dtype,
    'select_and_scatter_add')

ad.primitive_transposes[select_and_scatter_add_p] = \
    _select_and_scatter_add_transpose
ad.primitive_jvps[select_and_scatter_add_p] = _select_and_scatter_add_jvp
batching.primitive_batchers[select_and_scatter_add_p] = \
    _select_and_scatter_add_batch_rule

def _select_and_scatter_add_impl(source, operand, *,
                                 select_prim, window_dimensions, window_strides,
                                 padding, expand_padding):
  dtype = source.dtype
  select = lambda x, y: select_prim.bind(x, y)
  scatter = lax.bitwise_or if dtype == np.bool_ else lax.add
  if expand_padding:
    operand_shape = operand.shape
    original_padding = padding
    identity = (lax._get_max_identity if select_prim is lax.ge_p
                else lax._get_min_identity)
    pads = [(lo, hi, 0) for (lo, hi) in padding]
    operand = lax.pad(operand, identity(dtype), pads)
    padding = [(0, 0) for _ in padding]
  out = _select_and_scatter(
      operand, select, window_dimensions, window_strides, padding, source,
      lax._zero(operand), scatter)
  if expand_padding:
    start_indices = [lo for (lo, hi) in original_padding]
    stop_indices = [lo + d for ((lo, hi), d) in zip(original_padding,
                                                    operand_shape)]
    out = slicing.slice(out, start_indices, stop_indices)
  return out

mlir.register_lowering(select_and_scatter_add_p, mlir.lower_fun(
    partial(_select_and_scatter_add_impl, expand_padding=False),
    multiple_results=False))
# TODO(b/161704903): workaround for XLA/CPU crash.
mlir.register_lowering(select_and_scatter_add_p, mlir.lower_fun(
    partial(_select_and_scatter_add_impl, expand_padding=True),
    multiple_results=False), platform='cpu')
# TODO(b/182390722): workaround for XLA/GPU crash.
mlir.register_lowering(select_and_scatter_add_p, mlir.lower_fun(
    partial(_select_and_scatter_add_impl, expand_padding=True),
    multiple_results=False), platform='gpu')

def _select_and_gather_add_shape_rule(
    tangents, operand, *, select_prim, window_dimensions, window_strides,
    padding, base_dilation, window_dilation):
  if tangents.shape != operand.shape:
    msg = ("select_and_gather_add tangents and operand shapes must match, "
           "got {} and {}.")
    raise TypeError(msg.format(tangents.shape, operand.shape))
  return _common_reduce_window_shape_rule(
      operand, window_dimensions, window_strides, padding, base_dilation,
      window_dilation)

def _select_and_gather_add_lowering(
    ctx: mlir.LoweringRuleContext,
    tangents, operand, *, select_prim,
    window_dimensions, window_strides, padding, base_dilation, window_dilation,
    max_bits=64):
  _, operand_aval, = ctx.avals_in
  out_aval, = ctx.avals_out
  assert isinstance(operand_aval, core.ShapedArray), operand_aval
  dtype = operand_aval.dtype
  etype = mlir.dtype_to_ir_type(dtype)
  nbits = dtypes.finfo(dtype).bits

  assert nbits <= max_bits
  double_word_reduction = nbits * 2 <= max_bits

  const = lambda dtype, x: mlir.ir_constant(np.array(x, dtype=dtype))

  def _broadcast_scalar_const(x, aval_out):
    return mlir.broadcast_in_dim(ctx, const(aval_out.dtype, x),
                                 aval_out,
                                 broadcast_dimensions=())

  if double_word_reduction:
    # TODO(b/73062247): XLA doesn't yet implement ReduceWindow on tuples, so
    # we implement a pair-wise ReduceWindow by packing two k-bit values into
    # 2k-bit unsigned integer using bit tricks.
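    # For example, for f32 inputs (nbits = 32) the operand and tangent are
    # each bitcast to uint32, widened to uint64, and packed as
    # (operand_bits << 32) | tangent_bits, so a single uint64 ReduceWindow can
    # compare on the operand bits while carrying the tangent bits along.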
    word_dtype = lax._UINT_DTYPES[nbits]
    double_word_dtype = lax._UINT_DTYPES[nbits * 2]
    word_type = mlir.dtype_to_ir_type(word_dtype)  # type: ignore
    # Packs two values into a double_word_type.
    def pack(a, b, ab_aval):
      word_type_ab_aval = ab_aval.update(dtype=word_dtype)
      double_word_type_ab_aval = ab_aval.update(dtype=double_word_dtype)
      a = hlo.BitcastConvertOp(mlir.aval_to_ir_type(word_type_ab_aval), a)
      b = hlo.BitcastConvertOp(mlir.aval_to_ir_type(word_type_ab_aval), b)
      a = hlo.ConvertOp(mlir.aval_to_ir_type(double_word_type_ab_aval), a)
      b = hlo.ConvertOp(mlir.aval_to_ir_type(double_word_type_ab_aval), b)
      a = hlo.ShiftLeftOp(a,
                          _broadcast_scalar_const(nbits, double_word_type_ab_aval))
      return hlo.OrOp(a, b)

    # Unpacks the first element of a double_word_type.
    def fst(t):
      assert not ir.RankedTensorType(t.type).shape
      st = hlo.ShiftRightLogicalOp(t, const(double_word_dtype, nbits))
      return hlo.BitcastConvertOp(
          ir.RankedTensorType.get([], etype),
          hlo.ConvertOp(ir.RankedTensorType.get([], word_type), st)).result

    # Unpacks the second element of a double_word_type.
    def snd(t, t_aval):
      return hlo.BitcastConvertOp(
          mlir.aval_to_ir_type(t_aval.update(dtype=dtype)),
          hlo.ConvertOp(mlir.aval_to_ir_type(t_aval.update(dtype=word_dtype)), t)).result

  else:
    # The double-word trick above only works if we have a sufficiently large
    # type. As an alternative, we can pack two half words into a single word,
    # at the cost of precision.
    # TODO(b/73062247): add support for tuple reductions and remove this case.
    warnings.warn("Using reduced precision for gradient of reduce-window "
                  "min/max operator to work around missing XLA support for "
                  "pair-reductions. This is likely from a second or "
                  "higher derivative of a max-pooling operation.")
    r_nbits = nbits // 2
    # Drop/round the bottom mantissa bits.
    nexp = dtypes.finfo(dtype).nexp
    nmant = r_nbits - nexp - 1
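    # For example, for f32 inputs (nbits = 32) each value keeps r_nbits = 16
    # bits: 1 sign bit, nexp = 8 exponent bits, and nmant = 7 mantissa bits,
    # and the two 16-bit patterns are packed into a single uint32 below.
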
    double_word_dtype = word_dtype = lax._UINT_DTYPES[nbits]

    # Packs two values into a double_word_type.
    def pack(a, b, ab_aval):
      word_type_ab_aval = ab_aval.update(dtype=word_dtype)
      a = hlo.ReducePrecisionOp(a, exponent_bits=mlir.i32_attr(nexp),
                                mantissa_bits=mlir.i32_attr(nmant))
      b = hlo.ReducePrecisionOp(b, exponent_bits=mlir.i32_attr(nexp),
                                mantissa_bits=mlir.i32_attr(nmant))
      a = hlo.BitcastConvertOp(mlir.aval_to_ir_type(word_type_ab_aval), a)
      b = hlo.BitcastConvertOp(mlir.aval_to_ir_type(word_type_ab_aval), b)
      b = hlo.ShiftRightLogicalOp(
          b, _broadcast_scalar_const(r_nbits, word_type_ab_aval))
      return hlo.OrOp(a, b)

    # Unpacks the first element of a double_word_type.
    def fst(t):
      assert not ir.RankedTensorType(t.type).shape
      st = hlo.AndOp(t, const(word_dtype, ((1 << r_nbits) - 1) << r_nbits))
      return hlo.BitcastConvertOp(ir.RankedTensorType.get([], etype),
                                  st).result

    # Unpacks the second element of a double_word_type.
    def snd(t, t_aval):
      return hlo.BitcastConvertOp(
          mlir.aval_to_ir_type(t_aval.update(dtype=dtype)),
          hlo.ShiftLeftOp(t, _broadcast_scalar_const(r_nbits, t_aval.update(dtype=word_dtype)))
      ).result

  assert select_prim is lax.ge_p or select_prim is lax.le_p, select_prim
  init = -np.inf if select_prim is lax.ge_p else np.inf
  double_word_out_aval = out_aval.update(dtype=double_word_dtype)

  def reducer_body(reducer: ir.Block) -> Sequence[ir.Value]:
    x, y = reducer.arguments
    assert select_prim is lax.ge_p or select_prim is lax.le_p
    cmp_op = "GE" if select_prim is lax.ge_p else "LE"
    out = hlo.SelectOp(mlir.compare_hlo(fst(x), fst(y), cmp_op), x, y)
    return out

  res, = mlir.reduce_window(ctx,
      reducer_name="reduce_window_select_and_gather_add",
      reducer_body=reducer_body,
      operands=[pack(operand, tangents, operand_aval)],
      init_values=[pack(const(dtype, init), const(dtype, 0), core.ShapedArray((), dtype))],
      init_values_avals=[core.ShapedArray((), double_word_dtype)],
      out_avals=[double_word_out_aval],
      window_dimensions=window_dimensions,
      window_strides=window_strides,
      base_dilation=base_dilation,
      window_dilation=window_dilation,
      padding=padding)
  return [snd(res, double_word_out_aval)]

# TODO(phawkins): use this translation rule on all platforms.
def _select_and_gather_add_using_variadic_reducewindow(
    tangents, operand, *, select_prim, window_dimensions, window_strides,
    padding, base_dilation, window_dilation):
  def reducer(x, y):
    kx, vx = x
    ky, vy = y
    which = select_prim.bind(kx, ky)
    return (lax.select(which, kx, ky), lax.select(which, vx, vy))

  assert select_prim is lax.ge_p or select_prim is lax.le_p, select_prim
  init = -np.inf if select_prim is lax.ge_p else np.inf
  _, out = reduce_window(
      (operand, tangents),
      (np.array(init, dtype=operand.dtype), np.array(0, dtype=operand.dtype)),
      reducer, window_dimensions, window_strides, padding, base_dilation,
      window_dilation)
  return out

def _select_and_gather_add_jvp(
    primals, tangents, *, select_prim, window_dimensions, window_strides,
    padding, base_dilation, window_dilation):
  source, operand = primals
  g_source, g_operand = tangents
  val_out = _select_and_gather_add(
      source, operand, select_prim, window_dimensions, window_strides,
      padding, base_dilation, window_dilation)
  del g_operand
  if type(g_source) is ad_util.Zero:
    tangent_out = ad_util.Zero.from_value(val_out)
  else:
    tangent_out = _select_and_gather_add(
        g_source, operand, select_prim, window_dimensions,
        window_strides, padding, base_dilation, window_dilation)
  return val_out, tangent_out

def _select_and_gather_add_transpose(
    t, tangents, operand, *, select_prim, window_dimensions, window_strides,
    padding, base_dilation, window_dilation):
  assert select_prim in (lax.le_p, lax.ge_p)
  assert (ad.is_undefined_primal(tangents) and
          not ad.is_undefined_primal(operand))
  if any(d != 1 for d in window_dilation):
    msg = ("VJP not implemented for select_and_gather (MaxPool) with window "
           "dilation, got window_dilation={}.")
    raise NotImplementedError(msg.format(window_dilation))
  if type(t) is ad_util.Zero:
    return [ad_util.Zero(tangents.aval), None]
  has_base_dilation = any(d != 1 for d in base_dilation)
  if has_base_dilation:
    select_identity = (lax._get_max_identity if select_prim is lax.ge_p
                       else lax._get_min_identity)
    operand = lax.pad(operand, select_identity(operand.dtype),
                      tuple((0, 0, d - 1) for d in base_dilation))
  result = _select_and_scatter_add(t, operand, select_prim, window_dimensions,
                                   window_strides, padding)
  if has_base_dilation:
    result = slicing.slice(result, (0,) * len(result.shape), result.shape,
                           base_dilation)
  return [result, None]

def _select_and_gather_add_batching_rule(
    batched_args, batch_dims, *, select_prim, window_dimensions, window_strides,
    padding, base_dilation, window_dilation):
  t, x = batched_args
  t_bdim, x_bdim = batch_dims
  size = next(a.shape[bdim] for a, bdim in zip(batched_args, batch_dims)
              if bdim is not None)
  t = batching.bdim_at_front(t, t_bdim, size)
  x = batching.bdim_at_front(x, x_bdim, size)
  window_dimensions = (1,) + window_dimensions
  window_strides = (1,) + window_strides
  padding = ((0, 0),) + padding
  base_dilation = (1,) + base_dilation
  window_dilation = (1,) + window_dilation
  out = _select_and_gather_add(t, x, select_prim, window_dimensions,
                               window_strides, padding, base_dilation,
                               window_dilation)
  return (out, 0)

select_and_gather_add_p = lax.standard_primitive(
    _select_and_gather_add_shape_rule, lax._input_dtype,
    'select_and_gather_add')
ad.primitive_jvps[select_and_gather_add_p] = _select_and_gather_add_jvp
ad.primitive_transposes[select_and_gather_add_p] = \
    _select_and_gather_add_transpose
batching.primitive_batchers[select_and_gather_add_p] = \
    _select_and_gather_add_batching_rule

mlir.register_lowering(select_and_gather_add_p, mlir.lower_fun(
    _select_and_gather_add_using_variadic_reducewindow,
    multiple_results=False))

# TODO(b/183233858): use variadic reducewindow on GPU, when implemented.
mlir.register_lowering(
    select_and_gather_add_p,
    _select_and_gather_add_lowering,
    platform="gpu")