mirror of https://github.com/ROCm/jax.git
synced 2025-04-16 03:46:06 +00:00

Implement the select_and_gather_add translation rule via lower_fun.

This allows us to share the logic in the MLIR lowering.

PiperOrigin-RevId: 411639693

parent 4679f455f9
commit 361b7367cc
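
Note on the pattern: xla.lower_fun wraps an ordinary Python function written in terms of lax ops so it can serve as a translation rule. The function is traced to a jaxpr, which each backend (XLA or MLIR) then lowers with its generic machinery, so both paths share one implementation. A minimal runnable illustration of the tracing step using the public jax.make_jaxpr API; sga_like is a hypothetical stand-in body, not the rule from this commit:

    import jax
    import jax.numpy as jnp
    from jax import lax

    # Hypothetical stand-in: any function built from lax ops traces to a
    # jaxpr, which a backend's generic jaxpr lowering can then consume.
    def sga_like(operand, tangents):
      return lax.max(operand, tangents)

    print(jax.make_jaxpr(sga_like)(jnp.ones(3), jnp.ones(3)))
    # prints something like:
    # { lambda ; a:f32[3] b:f32[3]. let c:f32[3] = max a b in (c,) }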
@@ -707,30 +707,23 @@ def _select_and_gather_add_translation(
   return [snd(out)]
 
 # TODO(phawkins): use this translation rule on all platforms.
-def _select_and_gather_add_translation_using_variadic_reducewindow(
-    ctx, avals_in, avals_out, tangents, operand, *, select_prim,
-    window_dimensions, window_strides, padding, base_dilation, window_dilation):
-  c = ctx.builder
-  tangents_aval, operand_aval = avals_in
-  dtype = operand_aval.dtype
-
-  const = lambda c, dtype, x: xops.Constant(c, np.array(x, dtype=dtype))
-
-  def reducer():
-    c = xc.XlaBuilder("select_and_gather_pair_reducer")
-    shape = xla_client.Shape.array_shape(np.dtype(dtype), ())
-    kx, vx, ky, vy = (xb.parameter(c, i, shape) for i in range(4))
-    which = (xops.Ge if select_prim is lax.ge_p else xops.Le)(kx, ky)
-    xops.Tuple(c, [xops.Select(which, kx, ky), xops.Select(which, vx, vy)])
-    return c.build()
+def _select_and_gather_add_using_variadic_reducewindow(
+    tangents, operand, *, select_prim, window_dimensions, window_strides,
+    padding, base_dilation, window_dilation):
+  def reducer(x, y):
+    kx, vx = x
+    ky, vy = y
+    which = select_prim.bind(kx, ky)
+    return (lax.select(which, kx, ky), lax.select(which, vx, vy))
 
   assert select_prim is lax.ge_p or select_prim is lax.le_p, select_prim
   init = -np.inf if select_prim is lax.ge_p else np.inf
-  out = xops.ReduceWindowWithGeneralPadding(
-    [operand, tangents], [const(c, dtype, init), const(c, dtype, 0)],
-    reducer(), window_dimensions, window_strides, base_dilation,
-    window_dilation, padding)
-  return [xops.GetTupleElement(out, 1)]
+  _, out = reduce_window(
+    (operand, tangents),
+    (np.array(init, dtype=operand.dtype), np.array(0, dtype=operand.dtype)),
+    reducer, window_dimensions, window_strides, padding, base_dilation,
+    window_dilation)
+  return out
 
 def _select_and_gather_add_jvp(
     primals, tangents, *, select_prim, window_dimensions, window_strides,
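
The rewritten rule computes, for each window, the tangent value at the position where the operand attains its max (for ge_p) or min (for le_p). That behavior is observable through the public API, since the JVP of a windowed max is implemented with select_and_gather_add; a small runnable check:

    import jax
    import jax.numpy as jnp
    from jax import lax

    def max_pool_1d(x):
      # Windowed max: window 2, stride 2, no padding.
      return lax.reduce_window(x, -jnp.inf, lax.max,
                               window_dimensions=(2,), window_strides=(2,),
                               padding='VALID')

    x = jnp.array([1., 5., 3., 2.])
    t = jnp.array([10., 20., 30., 40.])
    y, y_dot = jax.jvp(max_pool_1d, (x,), (t,))
    # y == [5., 3.]; y_dot == [20., 30.]: each window's tangent is taken
    # from the argmax position, which is exactly what
    # select_and_gather_add computes.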
@@ -795,7 +788,8 @@ def _select_and_gather_add_batching_rule(
 
 select_and_gather_add_p = lax.standard_primitive(
     _select_and_gather_add_shape_rule, _input_dtype, 'select_and_gather_add',
-    _select_and_gather_add_translation_using_variadic_reducewindow)
+    xla.lower_fun(_select_and_gather_add_using_variadic_reducewindow,
+                  new_style=True, multiple_results=False))
 ad.primitive_jvps[select_and_gather_add_p] = _select_and_gather_add_jvp
 ad.primitive_transposes[select_and_gather_add_p] = \
   _select_and_gather_add_transpose
@@ -889,6 +889,35 @@ def apply_primitive(prim, *args, **params):
 
 # Translation rules for lax primitives
 
+def _fallback_lowering(prim: core.Primitive, ctx: LoweringContext,
+                       avals_in, avals_out, *args, **params):
+  xla_computation = xla.primitive_subcomputation(ctx.platform, prim, *avals_in,
+                                                 **params)
+  submodule_str = xc._xla.mlir.xla_computation_to_mlir_module(xla_computation)
+  submodule = ir.Module.parse(submodule_str)
+  callee_name = None
+  for op in submodule.body.operations:
+    ctx.module.body.append(op)
+    if op.name.value == "main":
+      callee_name = ir.StringAttr(ctx.symbol_table.insert(op)).value
+      op.attributes["sym_visibility"] = ir.StringAttr.get("private")
+    else:
+      ctx.symbol_table.insert(op)
+
+  output_types = map(aval_to_ir_types, avals_out)
+  flat_output_types = util.flatten(output_types)
+  output_type = (ir.TupleType.get_tuple(flat_output_types)
+                 if prim.multiple_results else flat_output_types[0])
+
+  call = std.CallOp([output_type], ir.FlatSymbolRefAttr.get(callee_name),
+                    _flatten_lowering_ir_args(args)).result
+  if not prim.multiple_results:
+    return [call]
+  flat_results = [mhlo.GetTupleElementOp(typ, call, _i32_attr(i)).result
+                  for i, typ in enumerate(flat_output_types)]
+  return util.unflatten(flat_results, map(len, output_types))
+
+
 def _broadcast(aval_out: core.ShapedArray, avals: Sequence[core.ShapedArray],
                args: Sequence[ir.Value]) -> Sequence[ir.Value]:
   out = []
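
In the fallback above, util.flatten and util.unflatten convert between per-result groups of IR values and one flat list, since a single aval may lower to several IR types. A standalone sketch of that pairing with hypothetical local helpers (jax.util provides the real ones):

    def flatten(xss):
      # Concatenate the groups into one flat list.
      return [x for xs in xss for x in xs]

    def unflatten(xs, lens):
      # Inverse of flatten, given each group's length.
      out, i = [], 0
      for n in lens:
        out.append(xs[i:i + n])
        i += n
      return out

    assert unflatten(flatten([[1], [2, 3]]), [1, 2]) == [[1], [2, 3]]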
@@ -1527,6 +1556,12 @@ def _select_and_scatter_lower(
 
 translations[lax_windowed_reductions.select_and_scatter_p] = _select_and_scatter_lower
 
+translations[lax_windowed_reductions.select_and_gather_add_p] = lower_fun(
+    lax_windowed_reductions._select_and_gather_add_using_variadic_reducewindow,
+    multiple_results=False)
+
+platform_specific_translations["gpu"][lax_windowed_reductions.select_and_gather_add_p] = (
+    partial(_fallback_lowering, lax_windowed_reductions.select_and_gather_add_p))
 
 def _select_and_scatter_add(source, operand, *, select_prim, window_dimensions,
                             window_strides, padding, expand_padding):
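
The GPU entry above relies on lookup precedence: a platform-specific rule shadows the generic translations table, so GPU keeps the XLA fallback while other platforms get the shared lower_fun rule. A hypothetical sketch of that precedence, illustrative only and not mlir.py's actual dispatch code:

    translations = {}
    platform_specific_translations = {"cpu": {}, "gpu": {}, "tpu": {}}

    def find_lowering(prim, platform):
      # A platform-specific registration, if present, wins over the
      # platform-independent rule.
      specific = platform_specific_translations[platform]
      return specific.get(prim, translations.get(prim))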
@@ -1541,9 +1576,9 @@ def _select_and_scatter_add(source, operand, *, select_prim, window_dimensions,
     pads = [(lo, hi, 0) for (lo, hi) in padding]
     operand = lax.pad(operand, identity(dtype), pads)
     padding = [(0, 0) for _ in padding]
-  out = lax._select_and_scatter(operand, select, window_dimensions,
-                                window_strides, padding, source,
-                                lax._zero(operand), scatter)
+  out = lax_windowed_reductions._select_and_scatter(
+      operand, select, window_dimensions, window_strides, padding, source,
+      lax._zero(operand), scatter)
   if expand_padding:
     start_indices = [lo for (lo, hi) in original_padding]
     stop_indices = [lo + d for ((lo, hi), d) in zip(original_padding,
@@ -1927,33 +1962,6 @@ def _remat_lowering(ctx, avals_in, avals_out, *args,
 
 translations[pe.remat_call_p] = _remat_lowering
 
-def _fallback_lowering(prim: core.Primitive, ctx: LoweringContext,
-                       avals_in, avals_out, *args, **params):
-  xla_computation = xla.primitive_subcomputation(ctx.platform, prim, *avals_in,
-                                                 **params)
-  submodule_str = xc._xla.mlir.xla_computation_to_mlir_module(xla_computation)
-  submodule = ir.Module.parse(submodule_str)
-  callee_name = None
-  for op in submodule.body.operations:
-    ctx.module.body.append(op)
-    if op.name.value == "main":
-      callee_name = ir.StringAttr(ctx.symbol_table.insert(op)).value
-      op.attributes["sym_visibility"] = ir.StringAttr.get("private")
-    else:
-      ctx.symbol_table.insert(op)
-
-  output_types = map(aval_to_ir_types, avals_out)
-  flat_output_types = util.flatten(output_types)
-  output_type = (ir.TupleType.get_tuple(flat_output_types)
-                 if prim.multiple_results else flat_output_types[0])
-
-  call = std.CallOp([output_type], ir.FlatSymbolRefAttr.get(callee_name),
-                    _flatten_lowering_ir_args(args)).result
-  if not prim.multiple_results:
-    return [call]
-  flat_results = [mhlo.GetTupleElementOp(typ, call, _i32_attr(i)).result
-                  for i, typ in enumerate(flat_output_types)]
-  return util.unflatten(flat_results, map(len, output_types))
-
 def add_fallback_lowering(prim: core.Primitive):
   translations[prim] = partial(_fallback_lowering, prim)
@@ -1990,6 +1998,5 @@ map(add_fallback_lowering, [
     lax.top_k_p,
 
     # TODO(phawkins): implement these lax ops:
-    lax_windowed_reductions.select_and_gather_add_p,
     lax.rng_bit_generator_p,
 ])