Implement the select_and_gather_add translation rule via lower_fun.

This allows us to share the same implementation with the MLIR lowering.

PiperOrigin-RevId: 411639693
Peter Hawkins 2021-11-22 13:49:14 -08:00 committed by jax authors
parent 4679f455f9
commit 361b7367cc
2 changed files with 54 additions and 53 deletions
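
Note (not part of the diff): select_and_gather_add is the primitive that forward-mode differentiation of windowed max/min reductions bottoms out in, so a minimal sketch that exercises the rule end to end, using only public JAX APIs (shapes and window parameters here are arbitrary), is:

import jax
import jax.numpy as jnp
from jax import lax

def maxpool(x):
  # 2x2 windowed max with stride 2; its JVP gathers each window's tangent at the
  # argmax, which is what select_and_gather_add computes under the hood.
  return lax.reduce_window(x, -jnp.inf, lax.max, (2, 2), (2, 2), 'VALID')

x = jnp.arange(16.0).reshape(4, 4)
y, y_dot = jax.jvp(maxpool, (x,), (jnp.ones_like(x),))
print(y_dot)  # tangents of the windowed maxima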

View File

@@ -707,30 +707,23 @@ def _select_and_gather_add_translation(
   return [snd(out)]
 
 # TODO(phawkins): use this translation rule on all platforms.
-def _select_and_gather_add_translation_using_variadic_reducewindow(
-    ctx, avals_in, avals_out, tangents, operand, *, select_prim,
-    window_dimensions, window_strides, padding, base_dilation, window_dilation):
-  c = ctx.builder
-  tangents_aval, operand_aval = avals_in
-  dtype = operand_aval.dtype
-  const = lambda c, dtype, x: xops.Constant(c, np.array(x, dtype=dtype))
-
-  def reducer():
-    c = xc.XlaBuilder("select_and_gather_pair_reducer")
-    shape = xla_client.Shape.array_shape(np.dtype(dtype), ())
-    kx, vx, ky, vy = (xb.parameter(c, i, shape) for i in range(4))
-    which = (xops.Ge if select_prim is lax.ge_p else xops.Le)(kx, ky)
-    xops.Tuple(c, [xops.Select(which, kx, ky), xops.Select(which, vx, vy)])
-    return c.build()
+def _select_and_gather_add_using_variadic_reducewindow(
+    tangents, operand, *, select_prim, window_dimensions, window_strides,
+    padding, base_dilation, window_dilation):
+  def reducer(x, y):
+    kx, vx = x
+    ky, vy = y
+    which = select_prim.bind(kx, ky)
+    return (lax.select(which, kx, ky), lax.select(which, vx, vy))
 
   assert select_prim is lax.ge_p or select_prim is lax.le_p, select_prim
   init = -np.inf if select_prim is lax.ge_p else np.inf
-  out = xops.ReduceWindowWithGeneralPadding(
-      [operand, tangents], [const(c, dtype, init), const(c, dtype, 0)],
-      reducer(), window_dimensions, window_strides, base_dilation,
-      window_dilation, padding)
-  return [xops.GetTupleElement(out, 1)]
+  _, out = reduce_window(
+      (operand, tangents),
+      (np.array(init, dtype=operand.dtype), np.array(0, dtype=operand.dtype)),
+      reducer, window_dimensions, window_strides, padding, base_dilation,
+      window_dilation)
+  return out
 
 def _select_and_gather_add_jvp(
     primals, tangents, *, select_prim, window_dimensions, window_strides,
@@ -795,7 +788,8 @@ def _select_and_gather_add_batching_rule(
 
 select_and_gather_add_p = lax.standard_primitive(
     _select_and_gather_add_shape_rule, _input_dtype, 'select_and_gather_add',
-    _select_and_gather_add_translation_using_variadic_reducewindow)
+    xla.lower_fun(_select_and_gather_add_using_variadic_reducewindow,
+                  new_style=True, multiple_results=False))
 ad.primitive_jvps[select_and_gather_add_p] = _select_and_gather_add_jvp
 ad.primitive_transposes[select_and_gather_add_p] = \
     _select_and_gather_add_transpose
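
Note (not part of the diff): the rewritten rule relies on reduce_window accepting a tuple of operands reduced by a single pairwise reducer, so the tangent is carried along with whichever operand value wins the comparison. A rough standalone sketch of the same trick against the public lax.reduce_window, with a hard-coded >= comparison standing in for select_prim and illustrative window parameters:

import jax.numpy as jnp
import numpy as np
from jax import lax

def select_and_gather_add_max(tangents, operand, window_dimensions,
                              window_strides, padding):
  def reducer(x, y):
    kx, vx = x
    ky, vy = y
    which = kx >= ky  # ge_p case; le_p would flip the comparison
    return (lax.select(which, kx, ky), lax.select(which, vx, vy))

  # Reduce (value, tangent) pairs jointly; the second output is the tangent
  # gathered from each window's argmax.
  _, out = lax.reduce_window(
      (operand, tangents),
      (np.array(-np.inf, operand.dtype), np.array(0, operand.dtype)),
      reducer, window_dimensions, window_strides, padding)
  return out

x = jnp.arange(9.0).reshape(3, 3)
print(select_and_gather_add_max(10.0 * x, x, (2, 2), (1, 1), ((0, 0), (0, 0))))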

View File

@@ -889,6 +889,35 @@ def apply_primitive(prim, *args, **params):
 
 # Translation rules for lax primitives
 
+def _fallback_lowering(prim: core.Primitive, ctx: LoweringContext,
+                       avals_in, avals_out, *args, **params):
+  xla_computation = xla.primitive_subcomputation(ctx.platform, prim, *avals_in,
+                                                 **params)
+  submodule_str = xc._xla.mlir.xla_computation_to_mlir_module(xla_computation)
+  submodule = ir.Module.parse(submodule_str)
+  callee_name = None
+  for op in submodule.body.operations:
+    ctx.module.body.append(op)
+    if op.name.value == "main":
+      callee_name = ir.StringAttr(ctx.symbol_table.insert(op)).value
+      op.attributes["sym_visibility"] = ir.StringAttr.get("private")
+    else:
+      ctx.symbol_table.insert(op)
+
+  output_types = map(aval_to_ir_types, avals_out)
+  flat_output_types = util.flatten(output_types)
+  output_type = (ir.TupleType.get_tuple(flat_output_types)
+                 if prim.multiple_results else flat_output_types[0])
+  call = std.CallOp([output_type], ir.FlatSymbolRefAttr.get(callee_name),
+                    _flatten_lowering_ir_args(args)).result
+  if not prim.multiple_results:
+    return [call]
+  flat_results = [mhlo.GetTupleElementOp(typ, call, _i32_attr(i)).result
+                  for i, typ in enumerate(flat_output_types)]
+  return util.unflatten(flat_results, map(len, output_types))
+
 def _broadcast(aval_out: core.ShapedArray, avals: Sequence[core.ShapedArray],
                args: Sequence[ir.Value]) -> Sequence[ir.Value]:
   out = []
@@ -1527,6 +1556,12 @@ def _select_and_scatter_lower(
 
 translations[lax_windowed_reductions.select_and_scatter_p] = _select_and_scatter_lower
 
+translations[lax_windowed_reductions.select_and_gather_add_p] = lower_fun(
+    lax_windowed_reductions._select_and_gather_add_using_variadic_reducewindow,
+    multiple_results=False)
+platform_specific_translations["gpu"][lax_windowed_reductions.select_and_gather_add_p] = (
+    partial(_fallback_lowering, lax_windowed_reductions.select_and_gather_add_p))
+
 def _select_and_scatter_add(source, operand, *, select_prim, window_dimensions,
                             window_strides, padding, expand_padding):
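
Note (not part of the diff): one way to check which path a backend takes after this change is to dump the HLO for a computation that hits the primitive; jax.xla_computation was the usual tool at the time of this commit. A sketch, assuming a backend that uses the variadic rule (the "gpu" registration above still falls back to the legacy translation):

import jax
import jax.numpy as jnp
from jax import lax

def maxpool_jvp(x, t):
  f = lambda v: lax.reduce_window(v, -jnp.inf, lax.max, (2, 2), (2, 2), 'VALID')
  return jax.jvp(f, (x,), (t,))[1]

x = jnp.arange(16.0).reshape(4, 4)
print(jax.xla_computation(maxpool_jvp)(x, jnp.ones_like(x)).as_hlo_text())
# On backends using the variadic rule, the dump should show a reduce-window over
# the (operand, tangent) pair rather than the legacy single-operand translation.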
@@ -1541,9 +1576,9 @@ def _select_and_scatter_add(source, operand, *, select_prim, window_dimensions,
     pads = [(lo, hi, 0) for (lo, hi) in padding]
     operand = lax.pad(operand, identity(dtype), pads)
     padding = [(0, 0) for _ in padding]
-  out = lax._select_and_scatter(operand, select, window_dimensions,
-                                window_strides, padding, source,
-                                lax._zero(operand), scatter)
+  out = lax_windowed_reductions._select_and_scatter(
+      operand, select, window_dimensions, window_strides, padding, source,
+      lax._zero(operand), scatter)
   if expand_padding:
     start_indices = [lo for (lo, hi) in original_padding]
     stop_indices = [lo + d for ((lo, hi), d) in zip(original_padding,
@@ -1927,33 +1962,6 @@ def _remat_lowering(ctx, avals_in, avals_out, *args,
 
 translations[pe.remat_call_p] = _remat_lowering
 
-def _fallback_lowering(prim: core.Primitive, ctx: LoweringContext,
-                       avals_in, avals_out, *args, **params):
-  xla_computation = xla.primitive_subcomputation(ctx.platform, prim, *avals_in,
-                                                 **params)
-  submodule_str = xc._xla.mlir.xla_computation_to_mlir_module(xla_computation)
-  submodule = ir.Module.parse(submodule_str)
-  callee_name = None
-  for op in submodule.body.operations:
-    ctx.module.body.append(op)
-    if op.name.value == "main":
-      callee_name = ir.StringAttr(ctx.symbol_table.insert(op)).value
-      op.attributes["sym_visibility"] = ir.StringAttr.get("private")
-    else:
-      ctx.symbol_table.insert(op)
-
-  output_types = map(aval_to_ir_types, avals_out)
-  flat_output_types = util.flatten(output_types)
-  output_type = (ir.TupleType.get_tuple(flat_output_types)
-                 if prim.multiple_results else flat_output_types[0])
-  call = std.CallOp([output_type], ir.FlatSymbolRefAttr.get(callee_name),
-                    _flatten_lowering_ir_args(args)).result
-  if not prim.multiple_results:
-    return [call]
-  flat_results = [mhlo.GetTupleElementOp(typ, call, _i32_attr(i)).result
-                  for i, typ in enumerate(flat_output_types)]
-  return util.unflatten(flat_results, map(len, output_types))
-
 def add_fallback_lowering(prim: core.Primitive):
   translations[prim] = partial(_fallback_lowering, prim)
@@ -1990,6 +1998,5 @@ map(add_fallback_lowering, [
     lax.top_k_p,
 
     # TODO(phawkins): implement these lax ops:
-    lax_windowed_reductions.select_and_gather_add_p,
     lax.rng_bit_generator_p,
 ])