[MLIR] Change signature of lowering rules.

Refactoring only, no functional changes intended.

Previously the MLIR lowering rule signature was

```
def rule(ctx, avals_in, avals_out, *args, **jaxpr_params):
```

where `ctx` was a module-wide context.

Change it to

```
def rule(ctx, *args, **jaxpr_params):
```

where `ctx` is a per-rule context object. The previous parameters are now available as `ctx.module_context`, `ctx.avals_in`, and `ctx.avals_out`.
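
For instance, a rule that previously destructured `avals_out` from its argument list now reads it off the context. This is a minimal sketch of the `_transpose_lower` rule as updated in this change, assuming the usual `mlir`/`mhlo` helpers are in scope:

```
def _transpose_lower(ctx, x, *, permutation):
  aval_out, = ctx.avals_out  # was: aval_out, = avals_out
  return mhlo.TransposeOp(mlir.aval_to_ir_type(aval_out), x,
                          mlir.dense_int_elements(permutation)).results
```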

This change makes it easier to add new per-rule context information without having to refactor all of the lowering rules to accept a new argument. One example is a shape environment for dynamic shapes. Another example, which motivated this work, is that I want to include the primitive name as part of the rule context.
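
Concretely, the per-rule context is a small dataclass (a sketch of the class added in this change; imports and the `ModuleContext` definition are elided). The commented-out field is hypothetical and only illustrates the kind of per-rule information that could later be added without touching any rule signatures:

```
@dataclasses.dataclass
class LoweringRuleContext:
  """Per-rule context information for MLIR lowering."""
  module_context: ModuleContext
  avals_in: Sequence[core.AbstractValue]
  avals_out: Any
  # Hypothetical future addition (not part of this change):
  # primitive_name: Optional[str] = None

  def replace(self, **kw): return dataclasses.replace(self, **kw)
```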

PiperOrigin-RevId: 416698663
Author: Peter Hawkins, 2021-12-15 19:06:26 -08:00 (committed by jax authors)
Parent: 46c2839258
Commit: a87b21148c
12 changed files with 220 additions and 204 deletions

View File

@ -738,7 +738,7 @@ ad.deflinear2(device_put_p, lambda cotangent, _, **kwargs: [cotangent])
masking.defvectorized(device_put_p)
batching.defvectorized(device_put_p)
def _device_put_lowering(ctx, avals_in, avals_out, x, *, device):
def _device_put_lowering(ctx, x, *, device):
return [x]

View File

@ -640,7 +640,7 @@ def _pred_bcast_select_mhlo(
return mhlo.SelectOp(bcast_pred, x, y).results
def _while_lowering(ctx, avals_in, avals_out, *args, cond_jaxpr,
def _while_lowering(ctx, *args, cond_jaxpr,
body_jaxpr, cond_nconsts, body_nconsts):
pred_aval = cond_jaxpr.out_avals[0]
batched = bool(pred_aval.shape)
@ -651,7 +651,7 @@ def _while_lowering(ctx, avals_in, avals_out, *args, cond_jaxpr,
# computation), we build XLA computations that handle the tuple munging before
# generating a Call into the computations formed from the jaxprs.
loop_carry_types = _map(mlir.aval_to_ir_types, avals_in)
loop_carry_types = _map(mlir.aval_to_ir_types, ctx.avals_in)
flat_loop_carry_types = util.flatten(loop_carry_types)
loop_carry_tuple_type = ir.TupleType.get_tuple(flat_loop_carry_types)
@ -668,15 +668,17 @@ def _while_lowering(ctx, avals_in, avals_out, *args, cond_jaxpr,
for i, input_type in enumerate(flat_loop_carry_types)]
cond_args = util.unflatten(flat_cond_args, _map(len, loop_carry_types))
x, _, z = util.split_list(cond_args, [cond_nconsts, body_nconsts])
cond_ctx = ctx.replace(
name_stack=xla.extend_name_stack(ctx.name_stack, 'cond'))
cond_ctx = ctx.module_context.replace(
name_stack=xla.extend_name_stack(ctx.module_context.name_stack, 'cond'))
(pred,), = mlir.jaxpr_subcomp(
cond_ctx, cond_jaxpr.jaxpr, _map(mlir.ir_constants, cond_jaxpr.consts),
*(x + z))
if batched:
pred_ctx = mlir.LoweringRuleContext(
module_context=ctx.module_context, avals_in=[pred_aval],
avals_out=[pred_aval.update(shape=())])
pred, = lax._unary_reduce_lower(
mhlo.OrOp, lambda dtype: np.array(False, dtype), ctx, [pred_aval],
[pred_aval.update(shape=())], pred,
mhlo.OrOp, lambda dtype: np.array(False, dtype), pred_ctx, pred,
axes=tuple(range(len(pred_aval.shape))))
mhlo.ReturnOp([pred])
@ -689,14 +691,15 @@ def _while_lowering(ctx, avals_in, avals_out, *args, cond_jaxpr,
for i, input_type in enumerate(flat_loop_carry_types)]
body_args = util.unflatten(flat_body_args, _map(len, loop_carry_types))
x, y, z = util.split_list(body_args, [cond_nconsts, body_nconsts])
body_ctx = ctx.replace(
name_stack=xla.extend_name_stack(ctx.name_stack, 'body'))
body_ctx = ctx.module_context.replace(
name_stack=xla.extend_name_stack(ctx.module_context.name_stack, 'body'))
new_z = mlir.jaxpr_subcomp(
body_ctx, body_jaxpr.jaxpr, _map(mlir.ir_constants, body_jaxpr.consts),
*(y + z))
if batched:
body_pred_ctx = ctx.replace(
name_stack=xla.extend_name_stack(ctx.name_stack, 'body_pred'))
body_pred_ctx = ctx.module_context.replace(
name_stack=xla.extend_name_stack(ctx.module_context.name_stack,
'body_pred'))
(body_pred,), = mlir.jaxpr_subcomp(
body_pred_ctx, cond_jaxpr.jaxpr,
_map(mlir.ir_constants, cond_jaxpr.consts), *(x + z))
@ -1318,11 +1321,11 @@ xla.register_translation(cond_p, _cond_translation_rule, initial_style=True)
core.custom_typechecks[cond_p] = _cond_typecheck
pe.partial_eval_jaxpr_custom_rules[cond_p] = pe.partial_eval_jaxpr_custom_rule_not_implemented
def _cond_lowering(ctx, avals_in, avals_out, index, *args, branches, linear):
def _cond_lowering(ctx, index, *args, branches, linear):
del linear # Unused.
arg_avals = avals_in[1:]
arg_avals = ctx.avals_in[1:]
input_types = _map(mlir.aval_to_ir_types, arg_avals)
output_types = _map(mlir.aval_to_ir_types, avals_out)
output_types = _map(mlir.aval_to_ir_types, ctx.avals_out)
flat_input_types = util.flatten(input_types)
flat_output_types = util.flatten(output_types)
input_tuple_type = ir.TupleType.get_tuple(flat_input_types)
@ -1340,8 +1343,8 @@ def _cond_lowering(ctx, avals_in, avals_out, index, *args, branches, linear):
mlir.i32_attr(i)).result
for i, input_type in enumerate(flat_input_types)]
unflattened_args = util.unflatten(args, _map(len, input_types))
out_vals = mlir.jaxpr_subcomp(ctx, jaxpr.jaxpr, jaxpr.consts,
*unflattened_args)
out_vals = mlir.jaxpr_subcomp(ctx.module_context, jaxpr.jaxpr,
jaxpr.consts, *unflattened_args)
out = mhlo.TupleOp(output_tuple_type, util.flatten(out_vals)).results
mhlo.ReturnOp(out)

View File

@ -723,12 +723,12 @@ def _complex_mul(mul, x, y):
_real_dtype = lambda dtype: np.finfo(dtype).dtype
def _conv_general_dilated_lower(
ctx, avals_in, avals_out, lhs, rhs, *, window_strides, padding,
ctx, lhs, rhs, *, window_strides, padding,
lhs_dilation, rhs_dilation, dimension_numbers, feature_group_count,
batch_group_count, precision, preferred_element_type,
expand_complex_convolutions=False, **unused_kwargs):
lhs_aval, rhs_aval = avals_in
aval_out, = avals_out
lhs_aval, rhs_aval = ctx.avals_in
aval_out, = ctx.avals_out
assert isinstance(dimension_numbers, ConvDimensionNumbers)
dtype = lhs_aval.dtype
if expand_complex_convolutions and np.issubdtype(dtype, np.complexfloating):
@ -746,7 +746,7 @@ def _conv_general_dilated_lower(
batch_group_count=batch_group_count, precision=precision,
preferred_element_type=preferred_element_type)),
multiple_results=False)
return complex_conv(ctx, avals_in, avals_out, lhs, rhs)
return complex_conv(ctx, lhs, rhs)
lhs_spec, rhs_spec, out_spec = dimension_numbers
dnums = mhlo.ConvDimensionNumbers.get(

View File

@ -1444,9 +1444,7 @@ def broadcast_mhlo(
out.append(arg)
return out
def _nary_lower_mhlo(op: Callable, ctx: mlir.LoweringContext,
avals_in: Sequence[core.ShapedArray],
avals_out: Sequence[core.ShapedArray],
def _nary_lower_mhlo(op: Callable, ctx,
*args: Union[ir.Value, Sequence[ir.Value]],
explicit_type=False, **params):
"""Lowers an elementwise operator to its MHLO/CHLO equivalent.
@ -1456,8 +1454,8 @@ def _nary_lower_mhlo(op: Callable, ctx: mlir.LoweringContext,
provided?
"""
del params
aval_out, = avals_out
broadcasted_args = broadcast_mhlo(aval_out, avals_in, args)
aval_out, = ctx.avals_out
broadcasted_args = broadcast_mhlo(aval_out, ctx.avals_in, args)
if explicit_type:
return op(mlir.aval_to_ir_type(aval_out), *broadcasted_args).results
else:
@ -1494,8 +1492,8 @@ def _sign_translation_rule(ctx, avals_in, avals_out, x):
sign_p = standard_unop(_num, 'sign', translation_rule=_sign_translation_rule)
ad.defjvp_zero(sign_p)
def _sign_lower_mhlo(ctx, avals_in, avals_out, x):
x_aval, = avals_in
def _sign_lower_mhlo(ctx, x):
x_aval, = ctx.avals_in
if dtypes.issubdtype(x_aval.dtype, np.unsignedinteger):
return mhlo.SelectOp(
mhlo.CompareOp(
@ -1546,14 +1544,14 @@ round_p = standard_unop(_float, 'round')
xla.register_translation(round_p, _round_translation_rule)
ad.defjvp_zero(round_p)
def _round_lower(ctx, avals_in, avals_out, x, *, rounding_method):
def _round_lower(ctx, x, *, rounding_method):
if rounding_method is RoundingMethod.AWAY_FROM_ZERO:
return mhlo.RoundOp(x).results
else:
assert rounding_method is RoundingMethod.TO_NEAREST_EVEN
round_nearest = mlir.lower_fun(_round_to_nearest_even,
multiple_results=False)
return round_nearest(ctx, avals_in, avals_out, x)
return round_nearest(ctx, x)
mlir.register_lowering(round_p, _round_lower)
is_finite_p = unop(_fixed_dtype(np.bool_), _float, 'is_finite')
@ -2113,10 +2111,11 @@ ad.defjvp_zero(shift_right_logical_p)
mlir.register_lowering(shift_right_logical_p,
partial(_nary_lower_mhlo, mhlo.ShiftRightLogicalOp))
def _compare_lower_mhlo(direction: str, ctx, avals_in, avals_out, x, y):
x_aval, y_aval = avals_in
aval_out, = avals_out
x, y = broadcast_mhlo(aval_out.update(dtype=x_aval.dtype), avals_in, (x, y))
def _compare_lower_mhlo(direction: str, ctx, x, y):
x_aval, y_aval = ctx.avals_in
aval_out, = ctx.avals_out
x, y = broadcast_mhlo(aval_out.update(dtype=x_aval.dtype), ctx.avals_in,
(x, y))
if dtypes.issubdtype(x_aval.dtype, np.inexact):
compare_type = "FLOAT"
elif dtypes.issubdtype(x_aval.dtype, np.signedinteger):
@ -2222,10 +2221,9 @@ pe.forwarding_rules[convert_element_type_p] = _convert_elt_type_fwd_rule
def _real_dtype(dtype): return np.finfo(dtype).dtype
def _convert_element_type_lower(ctx, avals_in, avals_out, operand, *,
new_dtype, weak_type):
aval_in, = avals_in
aval_out, = avals_out
def _convert_element_type_lower(ctx, operand, *, new_dtype, weak_type):
aval_in, = ctx.avals_in
aval_out, = ctx.avals_out
if (dtypes.issubdtype(aval_in.dtype, np.complexfloating) and
not dtypes.issubdtype(new_dtype, np.complexfloating)):
operand = mhlo.RealOp(operand).result
@ -2260,9 +2258,8 @@ ad.defjvp_zero(bitcast_convert_type_p)
batching.defvectorized(bitcast_convert_type_p)
masking.defvectorized(bitcast_convert_type_p)
def _bitcast_convert_type_lower(ctx, avals_in, avals_out, operand, *,
new_dtype):
aval_out, = avals_out
def _bitcast_convert_type_lower(ctx, operand, *, new_dtype):
aval_out, = ctx.avals_out
return mhlo.BitcastConvertOp(mlir.aval_to_ir_type(aval_out), operand).results
mlir.register_lowering(bitcast_convert_type_p, _bitcast_convert_type_lower)
@ -2522,15 +2519,15 @@ def precision_attr(precision: PrecisionType) -> ir.ArrayAttr:
precision = (precision, precision)
return ir.ArrayAttr.get([ir.StringAttr.get(str(p)) for p in precision])
def _dot_general_lower(ctx, avals_in, avals_out, lhs, rhs, *, dimension_numbers,
def _dot_general_lower(ctx, lhs, rhs, *, dimension_numbers,
precision, preferred_element_type: Optional[np.dtype]):
del preferred_element_type # Implied by the output aval
lhs_aval, rhs_aval = avals_in
aval_out, = avals_out
lhs_aval, rhs_aval = ctx.avals_in
aval_out, = ctx.avals_out
(lhs_contracting, rhs_contracting), (lhs_batch, rhs_batch) = dimension_numbers
# TODO(b/195364460): Work around slow XLA/CPU implementation of float16 matmul
if ctx.platform == "cpu":
if ctx.module_context.platform == "cpu":
if lhs_aval.dtype == np.float16:
f32 = mlir.dtype_to_ir_type(np.dtype(np.float32))
lhs = mhlo.ConvertOp(ir.RankedTensorType.get(lhs_aval.shape, f32),
@ -2614,10 +2611,9 @@ ad.deflinear2(broadcast_in_dim_p, _broadcast_in_dim_transpose_rule)
batching.primitive_batchers[broadcast_in_dim_p] = _broadcast_in_dim_batch_rule
pe.forwarding_rules[broadcast_in_dim_p] = _broadcast_in_dim_fwd_rule
def _broadcast_in_dim_lower(ctx, avals_in, avals_out, x, *, shape,
broadcast_dimensions):
def _broadcast_in_dim_lower(ctx, x, *, shape, broadcast_dimensions):
del shape
aval_out, = avals_out
aval_out, = ctx.avals_out
return mhlo.BroadcastInDimOp(
mlir.aval_to_ir_type(aval_out), x,
mlir.dense_int_elements(broadcast_dimensions)
@ -2760,7 +2756,7 @@ ad.deflinear2(concatenate_p, _concatenate_transpose_rule)
ad.primitive_transposes[concatenate_p] = _concatenate_transpose_rule
batching.primitive_batchers[concatenate_p] = _concatenate_batch_rule
def _concatenate_lower(ctx, avals_in, avals_out, *xs, dimension):
def _concatenate_lower(ctx, *xs, dimension):
return mhlo.ConcatenateOp(xs, mlir.i64_attr(dimension)).results
mlir.register_lowering(concatenate_p, _concatenate_lower)
@ -2852,8 +2848,8 @@ ad.deflinear2(pad_p, _pad_transpose)
batching.primitive_batchers[pad_p] = _pad_batch_rule
masking.masking_rules[pad_p] = _pad_masking_rule
def _pad_lower(ctx, avals_in, avals_out, x, padding_value, *, padding_config):
aval_out, = avals_out
def _pad_lower(ctx, x, padding_value, *, padding_config):
aval_out, = ctx.avals_out
low, high, interior = util.unzip3(padding_config)
return mhlo.PadOp(mlir.aval_to_ir_type(aval_out), x, padding_value,
mlir.dense_int_elements(low),
@ -2910,9 +2906,9 @@ squeeze_p = standard_primitive(_squeeze_shape_rule, _squeeze_dtype_rule,
ad.deflinear2(squeeze_p, _squeeze_transpose_rule)
batching.primitive_batchers[squeeze_p] = _squeeze_batch_rule
def _squeeze_lower(ctx, avals_in, avals_out, operand, *, dimensions):
def _squeeze_lower(ctx, operand, *, dimensions):
del dimensions # Implied by the output aval.
aval_out, = avals_out
aval_out, = ctx.avals_out
return mhlo.ReshapeOp(mlir.aval_to_ir_type(aval_out), operand).results
mlir.register_lowering(squeeze_p, _squeeze_lower)
@ -3014,9 +3010,9 @@ ad.deflinear2(reshape_p, _reshape_transpose_rule)
batching.primitive_batchers[reshape_p] = _reshape_batch_rule
masking.masking_rules[reshape_p] = _reshape_masking_rule
def _reshape_lower(ctx, avals_in, avals_out, x, *, new_sizes, dimensions):
aval_in, = avals_in
aval_out, = avals_out
def _reshape_lower(ctx, x, *, new_sizes, dimensions):
aval_in, = ctx.avals_in
aval_out, = ctx.avals_out
if dimensions is not None:
aval = core.ShapedArray(np.take(aval_in.shape, dimensions), aval_in.dtype)
x = mhlo.TransposeOp(mlir.aval_to_ir_type(aval), x,
@ -3045,7 +3041,7 @@ rev_p = standard_primitive(_rev_shape_rule, _input_dtype, 'rev')
ad.deflinear2(rev_p, lambda t, _, dimensions: [rev(t, dimensions)])
batching.primitive_batchers[rev_p] = _rev_batch_rule
def _rev_lower(ctx, avals_in, avals_out, x, *, dimensions):
def _rev_lower(ctx, x, *, dimensions):
return mhlo.ReverseOp(x, mlir.dense_int_elements(dimensions)).results
mlir.register_lowering(rev_p, _rev_lower)
@ -3076,8 +3072,8 @@ ad.deflinear2(transpose_p,
batching.primitive_batchers[transpose_p] = _transpose_batch_rule
masking.masking_rules[transpose_p] = _transpose_masking_rule
def _transpose_lower(ctx, avals_in, avals_out, x, *, permutation):
aval_out, = avals_out
def _transpose_lower(ctx, x, *, permutation):
aval_out, = ctx.avals_out
return mhlo.TransposeOp(mlir.aval_to_ir_type(aval_out), x,
mlir.dense_int_elements(permutation)).results
mlir.register_lowering(transpose_p, _transpose_lower)
@ -3338,18 +3334,17 @@ xla.register_translation(reduce_p, _reduce_translation_rule)
batching.primitive_batchers[reduce_p] = _reduce_batch_rule
ad.primitive_jvps[reduce_p] = _reduce_jvp_rule
def _reduce_lower(ctx, avals_in, avals_out, *values, computation, jaxpr,
consts, dimensions):
assert all(isinstance(x, core.ShapedArray) for x in avals_in), avals_in
def _reduce_lower(ctx, *values, computation, jaxpr, consts, dimensions):
assert all(isinstance(x, core.ShapedArray) for x in ctx.avals_in), ctx.avals_in
operands, init_values = util.split_list(values, [len(values) // 2])
init_value_avals = avals_in[len(values) // 2:]
op = mhlo.ReduceOp([mlir.aval_to_ir_type(aval) for aval in avals_out],
init_value_avals = ctx.avals_in[len(values) // 2:]
op = mhlo.ReduceOp([mlir.aval_to_ir_type(aval) for aval in ctx.avals_out],
operands, init_values, mlir.dense_int_elements(dimensions))
ir_types = [mlir.aval_to_ir_type(aval) for aval in init_value_avals]
reducer = op.regions[0].blocks.append(*(ir_types + ir_types))
with ir.InsertionPoint(reducer):
ctx = ctx.replace(name_stack='')
out_nodes = mlir.jaxpr_subcomp(ctx, jaxpr, consts,
reducer_ctx = ctx.module_context.replace(name_stack='')
out_nodes = mlir.jaxpr_subcomp(reducer_ctx, jaxpr, consts,
*([a] for a in reducer.arguments))
mhlo.ReturnOp(util.flatten(out_nodes))
return op.results
@ -3566,9 +3561,8 @@ reduce_and_p = standard_primitive(_reduce_logical_shape_rule, _fixed_dtype(np.bo
batching.defreducer(reduce_and_p)
def _unary_reduce_lower(reducer, unit_factory, ctx, avals_in, avals_out, x, *,
axes):
aval_out, = avals_out
def _unary_reduce_lower(reducer, unit_factory, ctx, x, *, axes):
aval_out, = ctx.avals_out
dtype = aval_out.dtype
op = mhlo.ReduceOp([mlir.aval_to_ir_type(aval_out)], [x],
mlir.ir_constants(unit_factory(aval_out.dtype)),
@ -3613,9 +3607,8 @@ reduce_precision_p = standard_primitive(
batching.defvectorized(reduce_precision_p)
masking.defvectorized(reduce_precision_p)
def _reduce_precision_lower(ctx, avals_in, avals_out, operand, *, exponent_bits,
mantissa_bits):
aval_out, = avals_out
def _reduce_precision_lower(ctx, operand, *, exponent_bits, mantissa_bits):
aval_out, = ctx.avals_out
return mhlo.ReducePrecisionOp(mlir.aval_to_ir_type(aval_out), operand,
mlir.i32_attr(exponent_bits),
mlir.i32_attr(mantissa_bits)).results
@ -3754,22 +3747,24 @@ ad.primitive_jvps[sort_p] = _sort_jvp
batching.primitive_batchers[sort_p] = _sort_batch_rule
def _sort_lower(ctx, avals_in, avals_out, *operands, dimension, is_stable,
num_keys):
assert all(isinstance(x, core.ShapedArray) for x in avals_in), avals_in
sort = mhlo.SortOp([mlir.aval_to_ir_type(aval) for aval in avals_out],
def _sort_lower(ctx, *operands, dimension, is_stable, num_keys):
assert all(isinstance(x, core.ShapedArray) for x in ctx.avals_in), ctx.avals_in
sort = mhlo.SortOp([mlir.aval_to_ir_type(aval) for aval in ctx.avals_out],
mlir.flatten_lowering_ir_args(operands),
mlir.i64_attr(dimension), ir.BoolAttr.get(is_stable))
scalar_avals = [aval.update(shape=()) for aval in avals_in]
scalar_avals = [aval.update(shape=()) for aval in ctx.avals_in]
scalar_types = safe_map(mlir.aval_to_ir_type, scalar_avals)
comparator = sort.comparator.blocks.append(
*util.flatten(zip(scalar_types, scalar_types)))
with ir.InsertionPoint(comparator):
lower_comparator = mlir.lower_fun(partial(_sort_lt_comparator),
multiple_results=False)
out = lower_comparator(ctx, util.flatten(zip(scalar_avals, scalar_avals)),
[core.ShapedArray((), np.bool_)],
*[[a] for a in comparator.arguments],
sub_ctx = mlir.LoweringRuleContext(
module_context = ctx.module_context,
avals_in=util.flatten(zip(scalar_avals, scalar_avals)),
avals_out=[core.ShapedArray((), np.bool_)])
out = lower_comparator(sub_ctx, *[[a] for a in comparator.arguments],
num_keys=num_keys)
mhlo.ReturnOp(util.flatten(out))
return sort.results
@ -3870,8 +3865,8 @@ create_token_p.def_abstract_eval(lambda *_: abstract_token)
xla.register_translation(create_token_p,
lambda ctx, *_: [xops.CreateToken(ctx.builder)])
def _create_token_lowering(ctx, avals_in, avals_out, *operands):
aval_out, = avals_out
def _create_token_lowering(ctx, *operands):
aval_out, = ctx.avals_out
return mhlo.CreateTokenOp(mlir.aval_to_ir_type(aval_out)).results
mlir.register_lowering(create_token_p, _create_token_lowering)
@ -3897,8 +3892,8 @@ after_all_p.def_impl(partial(xla.apply_primitive, after_all_p))
after_all_p.def_abstract_eval(_after_all_abstract_eval)
xla.register_translation(after_all_p, _after_all_translation_rule)
def _after_all_lowering(ctx, avals_in, avals_out, *operands):
aval_out, = avals_out
def _after_all_lowering(ctx, *operands):
aval_out, = ctx.avals_out
return mhlo.AfterAllOp(mlir.aval_to_ir_type(aval_out), operands).results
mlir.register_lowering(after_all_p, _after_all_lowering)
@ -3956,8 +3951,8 @@ infeed_p.def_abstract_eval(_infeed_abstract_eval)
xla.register_translation(infeed_p, _infeed_translation_rule)
def _infeed_lowering(ctx, avals_in, avals_out, token, *, shapes, partitions):
output_types = safe_map(mlir.aval_to_ir_types, avals_out[:-1])
def _infeed_lowering(ctx, token, *, shapes, partitions):
output_types = safe_map(mlir.aval_to_ir_types, ctx.avals_out[:-1])
flat_output_types = util.flatten(output_types)
output_tuple_type = ir.TupleType.get_tuple(flat_output_types)
# TODO(phawkins): verify `shapes` have a major-to-minor layout.
@ -4020,9 +4015,9 @@ outfeed_p.def_abstract_eval(_outfeed_abstract_eval)
xla.register_translation(outfeed_p, _outfeed_translation_rule)
def _outfeed_lowering(ctx, avals_in, avals_out, token, *xs, partitions):
token_aval = avals_in[0]
xs_avals = avals_in[1:]
def _outfeed_lowering(ctx, token, *xs, partitions):
token_aval = ctx.avals_in[0]
xs_avals = ctx.avals_in[1:]
input_types = map(mlir.aval_to_ir_types, xs_avals)
flat_input_types = util.flatten(input_types)
input_tuple_type = ir.TupleType.get_tuple(flat_input_types)
@ -4070,10 +4065,10 @@ rng_uniform_p.def_impl(partial(xla.apply_primitive, rng_uniform_p))
rng_uniform_p.def_abstract_eval(_rng_uniform_abstract_eval)
xla.register_translation(rng_uniform_p, _rng_uniform_translation_rule)
def _rng_uniform_lowering(ctx, avals_in, avals_out, a, b, *, shape):
aval_out, = avals_out
def _rng_uniform_lowering(ctx, a, b, *, shape):
aval_out, = ctx.avals_out
shape, = mlir.ir_constants(np.array(aval_out.shape, np.int64),
canonicalize_types=False)
canonicalize_types=False)
return mhlo.RngUniformOp(a, b, shape).results
mlir.register_lowering(rng_uniform_p, _rng_uniform_lowering)
@ -4177,9 +4172,9 @@ iota_p.def_impl(partial(xla.apply_primitive, iota_p))
iota_p.def_abstract_eval(_iota_abstract_eval)
xla.register_translation(iota_p, _iota_translation_rule)
def _iota_lower(ctx, avals_in, avals_out, *, dtype, shape, dimension):
def _iota_lower(ctx, *, dtype, shape, dimension):
del dtype, shape
aval_out, = avals_out
aval_out, = ctx.avals_out
return mhlo.IotaOp(mlir.aval_to_ir_type(aval_out),
mlir.i64_attr(dimension)).results
mlir.register_lowering(iota_p, _iota_lower)

View File

@ -747,9 +747,8 @@ ad.deflinear2(slice_p, _slice_transpose_rule)
batching.primitive_batchers[slice_p] = _slice_batching_rule
masking.masking_rules[slice_p] = _slice_masking_rule
def _slice_lower(ctx, avals_in, avals_out, x, *, start_indices,
limit_indices, strides):
aval_out, = avals_out
def _slice_lower(ctx, x, *, start_indices, limit_indices, strides):
aval_out, = ctx.avals_out
strides = strides or [1] * len(start_indices)
return mhlo.SliceOp(x,
mlir.dense_int_elements(start_indices),
@ -848,9 +847,8 @@ ad.primitive_jvps[dynamic_slice_p] = _dynamic_slice_jvp # TODO
ad.primitive_transposes[dynamic_slice_p] = _dynamic_slice_transpose_rule
batching.primitive_batchers[dynamic_slice_p] = _dynamic_slice_batching_rule
def _dynamic_slice_lower(ctx, avals_in, avals_out, x, *start_indices,
slice_sizes):
aval_out, = avals_out
def _dynamic_slice_lower(ctx, x, *start_indices, slice_sizes):
aval_out, = ctx.avals_out
return mhlo.DynamicSliceOp(mlir.aval_to_ir_type(aval_out), x,
start_indices,
mlir.dense_int_elements(slice_sizes)).results
@ -947,9 +945,8 @@ ad.primitive_transposes[dynamic_update_slice_p] = \
batching.primitive_batchers[dynamic_update_slice_p] = \
_dynamic_update_slice_batching_rule
def _dynamic_update_slice_lower(ctx, avals_in, avals_out, x, update,
*start_indices):
aval_out, = avals_out
def _dynamic_update_slice_lower(ctx, x, update, *start_indices):
aval_out, = ctx.avals_out
return mhlo.DynamicUpdateSliceOp(mlir.aval_to_ir_type(aval_out), x, update,
start_indices).results
@ -1280,14 +1277,14 @@ batching.primitive_batchers[gather_p] = _gather_batching_rule
def _gather_lower(ctx, avals_in, avals_out, operand, indices, *,
def _gather_lower(ctx, operand, indices, *,
dimension_numbers, slice_sizes, unique_indices,
indices_are_sorted, mode, fill_value):
aval_out, = avals_out
aval_out, = ctx.avals_out
if mode == GatherScatterMode.FILL_OR_DROP:
gather_fill_fn = mlir.lower_fun(_gather_fill, multiple_results=False)
return gather_fill_fn(
ctx, avals_in, avals_out, operand, indices,
ctx, operand, indices,
dimension_numbers=dimension_numbers, slice_sizes=slice_sizes,
unique_indices=unique_indices, indices_are_sorted=indices_are_sorted,
fill_value=fill_value, output_shape=aval_out.shape)
@ -1296,7 +1293,7 @@ def _gather_lower(ctx, avals_in, avals_out, operand, indices, *,
GatherScatterMode.CLIP), mode
dnums = mhlo.GatherDimensionNumbers.get(
collapsed_slice_dims=list(dimension_numbers.collapsed_slice_dims),
index_vector_dim=len(avals_in[1].shape) - 1,
index_vector_dim=len(ctx.avals_in[1].shape) - 1,
offset_dims=list(dimension_numbers.offset_dims),
start_index_map=list(dimension_numbers.start_index_map))
return mhlo.GatherOp(operand, indices, dnums,
@ -1945,30 +1942,30 @@ batching.primitive_batchers[scatter_p] = (
def _scatter_lower(ctx, avals_in, avals_out, operand, indices, updates, *,
def _scatter_lower(ctx, operand, indices, updates, *,
update_jaxpr, update_consts, dimension_numbers,
indices_are_sorted, unique_indices, mode):
if mode == GatherScatterMode.CLIP:
clip_fn = mlir.lower_fun(_clamp_scatter_indices, multiple_results=False)
(indices,), = clip_fn(ctx, avals_in, None, operand, indices, updates,
dnums=dimension_numbers)
(indices,), = clip_fn(ctx.replace(avals_out=None), operand, indices,
updates, dnums=dimension_numbers)
aval_out, = avals_out
aval_out, = ctx.avals_out
dnums = dimension_numbers
scatter_dnums = mhlo.ScatterDimensionNumbers.get(
update_window_dims=list(dnums.update_window_dims),
inserted_window_dims=list(dnums.inserted_window_dims),
scattered_dims_to_operand_dims=list(dnums.scatter_dims_to_operand_dims),
index_vector_dim=len(avals_in[1].shape) - 1)
index_vector_dim=len(ctx.avals_in[1].shape) - 1)
op = mhlo.ScatterOp(mlir.aval_to_ir_type(aval_out), operand, indices, updates,
scatter_dnums, ir.BoolAttr.get(indices_are_sorted),
ir.BoolAttr.get(unique_indices))
scalar_type = mlir.aval_to_ir_type(core.ShapedArray((), aval_out.dtype))
update = op.update_computation.blocks.append(scalar_type, scalar_type)
with ir.InsertionPoint(update):
ctx = ctx.replace(name_stack='')
update_ctx = ctx.module_context.replace(name_stack='')
out_nodes = mlir.jaxpr_subcomp(
ctx, update_jaxpr, update_consts,
update_ctx, update_jaxpr, update_consts,
(update.arguments[0],), (update.arguments[1],))
mhlo.ReturnOp(util.flatten(out_nodes))
return op.results
@ -1982,12 +1979,12 @@ mlir.register_lowering(scatter_max_p, _scatter_lower)
def _real_dtype(dtype): return np.finfo(dtype).dtype
def _scatter_add_lower_gpu(ctx, avals_in, avals_out, operand, indices, updates,
def _scatter_add_lower_gpu(ctx, operand, indices, updates,
*, update_jaxpr, update_consts, dimension_numbers,
indices_are_sorted, unique_indices, mode):
operand_aval_in, _, updates_aval_in = avals_in
operand_aval_in, _, updates_aval_in = ctx.avals_in
if operand_aval_in.dtype != np.complex128:
return _scatter_lower(ctx, avals_in, avals_out, operand, indices, updates,
return _scatter_lower(ctx, operand, indices, updates,
update_jaxpr=update_jaxpr,
update_consts=update_consts,
dimension_numbers=dimension_numbers,
@ -1996,16 +1993,16 @@ def _scatter_add_lower_gpu(ctx, avals_in, avals_out, operand, indices, updates,
if mode == GatherScatterMode.CLIP:
clip_fn = mlir.lower_fun(_clamp_scatter_indices, multiple_results=False)
(indices,), = clip_fn(ctx, avals_in, None, operand, indices, updates,
(indices,), = clip_fn(ctx, ctx.avals_in, None, operand, indices, updates,
dnums=dimension_numbers)
aval_out, = avals_out
aval_out, = ctx.avals_out
dnums = dimension_numbers
scatter_dnums = mhlo.ScatterDimensionNumbers.get(
update_window_dims=list(dnums.update_window_dims),
inserted_window_dims=list(dnums.inserted_window_dims),
scattered_dims_to_operand_dims=list(dnums.scatter_dims_to_operand_dims),
index_vector_dim=len(avals_in[1].shape) - 1)
index_vector_dim=len(ctx.avals_in[1].shape) - 1)
real_dtype = _real_dtype(aval_out.dtype)
operand_type_part = mlir.aval_to_ir_type(
core.ShapedArray(aval_out.shape, real_dtype))

View File

@ -308,14 +308,14 @@ reduce_window_p.def_abstract_eval(_reduce_window_abstract_eval_rule)
batching.primitive_batchers[reduce_window_p] = _generic_reduce_window_batch_rule
xla.register_translation(reduce_window_p, _reduce_window_translation_rule)
def _generic_reduce_window_lower(ctx, avals_in, avals_out, *args, jaxpr, consts,
def _generic_reduce_window_lower(ctx, *args, jaxpr, consts,
window_dimensions, window_strides, padding,
base_dilation, window_dilation):
operands, init_values = util.split_list(args, [len(args) // 2])
_, init_value_avals = util.split_list(avals_in, [len(operands)])
_, init_value_avals = util.split_list(ctx.avals_in, [len(operands)])
scalar_types = [mlir.aval_to_ir_type(aval) for aval in init_value_avals]
rw = mhlo.ReduceWindowOp(
map(mlir.aval_to_ir_type, avals_out), operands, init_values,
map(mlir.aval_to_ir_type, ctx.avals_out), operands, init_values,
mlir.dense_int_elements(window_dimensions),
mlir.dense_int_elements(window_strides),
mlir.dense_int_elements(base_dilation),
@ -323,7 +323,7 @@ def _generic_reduce_window_lower(ctx, avals_in, avals_out, *args, jaxpr, consts,
ir.DenseIntElementsAttr.get(np.asarray(padding, np.int64)))
reducer = rw.regions[0].blocks.append(*(scalar_types + scalar_types))
with ir.InsertionPoint(reducer):
out_nodes = mlir.jaxpr_subcomp(ctx, jaxpr, consts,
out_nodes = mlir.jaxpr_subcomp(ctx.module_context, jaxpr, consts,
*([a] for a in reducer.arguments))
mhlo.ReturnOp(util.flatten(out_nodes))
return rw.results
@ -488,10 +488,10 @@ batching.primitive_batchers[reduce_window_min_p] = partial(
def _reduce_window_lower(
reduce_op, init_value, ctx, avals_in, avals_out, operand, *,
reduce_op, init_value, ctx, operand, *,
window_dimensions, window_strides, padding, base_dilation, window_dilation):
aval_out, = avals_out
operand_aval, = avals_in
aval_out, = ctx.avals_out
operand_aval, = ctx.avals_in
scalar_aval = operand_aval.update(shape=())
scalar_type = mlir.aval_to_ir_type(scalar_aval)
rw = mhlo.ReduceWindowOp(
@ -545,11 +545,11 @@ select_and_scatter_p = lax.standard_primitive(
_select_and_scatter_translation)
def _select_and_scatter_lower(
ctx, avals_in, avals_out, operand, source, init_value, *, select_jaxpr,
ctx, operand, source, init_value, *, select_jaxpr,
select_consts, scatter_jaxpr, scatter_consts, window_dimensions,
window_strides, padding):
operand_aval, source_aval, init_value_aval = avals_in
aval_out, = avals_out
operand_aval, source_aval, init_value_aval = ctx.avals_in
aval_out, = ctx.avals_out
scalar_aval = operand_aval.update(shape=())
scalar_type = mlir.aval_to_ir_type(scalar_aval)
op = mhlo.SelectAndScatterOp(
@ -559,12 +559,14 @@ def _select_and_scatter_lower(
ir.DenseIntElementsAttr.get(np.asarray(padding, np.int64)))
select = op.select.blocks.append(scalar_type, scalar_type)
with ir.InsertionPoint(select):
out_nodes = mlir.jaxpr_subcomp(ctx, select_jaxpr, select_consts,
out_nodes = mlir.jaxpr_subcomp(ctx.module_context, select_jaxpr,
select_consts,
*([a] for a in select.arguments))
mhlo.ReturnOp(util.flatten(out_nodes))
scatter = op.scatter.blocks.append(scalar_type, scalar_type)
with ir.InsertionPoint(scatter):
out_nodes = mlir.jaxpr_subcomp(ctx, scatter_jaxpr, scatter_consts,
out_nodes = mlir.jaxpr_subcomp(ctx.module_context, scatter_jaxpr,
scatter_consts,
*([a] for a in scatter.arguments))
mhlo.ReturnOp(util.flatten(out_nodes))
return op.results
@ -679,8 +681,9 @@ xla.register_translation(
platform='gpu')
def _select_and_scatter_add_impl(source, operand, *, select_prim, window_dimensions,
window_strides, padding, expand_padding):
def _select_and_scatter_add_impl(source, operand, *,
select_prim, window_dimensions, window_strides,
padding, expand_padding):
dtype = source.dtype
select = lambda x, y: select_prim.bind(x, y)
scatter = lax.bitwise_or if dtype == np.bool_ else lax.add

View File

@ -1303,14 +1303,14 @@ def _xmap_lowering_rule(*args, **kwargs):
return _xmap_lowering_rule_replica(*args, **kwargs)
mlir.register_lowering(xmap_p, _xmap_lowering_rule)
def _xmap_lowering_rule_replica(ctx, avals_in, avals_out, *in_nodes,
call_jaxpr, name,
in_axes, out_axes, donated_invars,
global_axis_sizes,
spmd_in_axes, spmd_out_axes,
positional_semantics,
axis_resources, resource_env, backend):
xla.check_backend_matches(backend, ctx.platform)
def _xmap_lowering_rule_replica(ctx, *in_nodes,
call_jaxpr, name,
in_axes, out_axes, donated_invars,
global_axis_sizes,
spmd_in_axes, spmd_out_axes,
positional_semantics,
axis_resources, resource_env, backend):
xla.check_backend_matches(backend, ctx.module_context.platform)
# The only way for any of those two assertions to be violated is when xmap
# is using the SPMD lowering, but then this rule shouldn't even trigger.
assert positional_semantics == _PositionalSemantics.LOCAL
@ -1346,30 +1346,34 @@ def _xmap_lowering_rule_replica(ctx, avals_in, avals_out, *in_nodes,
assert not consts
tiled_ins = (
mlir.lower_fun(
partial(_tile, in_axes=arg_in_axes, axis_sizes=local_mesh_shape),
multiple_results=False)(ctx, [aval], None, in_node)[0]
if v.aval is not core.abstract_unit else in_node
for v, aval, in_node, arg_in_axes
in zip(call_jaxpr.invars, avals_in, in_nodes, mesh_in_axes))
mlir.lower_fun(partial(_tile, in_axes=arg_in_axes,
axis_sizes=local_mesh_shape),
multiple_results=False)(
mlir.LoweringRuleContext(module_context=ctx.module_context,
avals_in=[aval], avals_out=None),
in_node)[0]
if v.aval is not core.abstract_unit else in_node
for v, aval, in_node, arg_in_axes
in zip(call_jaxpr.invars, ctx.avals_in, in_nodes, mesh_in_axes))
# NOTE: We don't extend the resource env with the mesh shape, because those
# resources are already in scope! It's the outermost xmap that introduces
# them!
# We in-line here rather than generating a Call HLO as in the xla_call
# translation rule just because the extra tuple stuff is a pain.
sub_ctx = ctx.replace(
name_stack=xla.extend_name_stack(ctx.name_stack,
sub_ctx = ctx.module_context.replace(
name_stack=xla.extend_name_stack(ctx.module_context.name_stack,
xla.wrap_name(name, 'xmap')))
tiled_outs = mlir.jaxpr_subcomp(sub_ctx, vectorized_jaxpr, (), *tiled_ins)
outs = [
mlir.lower_fun(
partial(_untile, out_axes=ans_out_axes, axis_sizes=local_mesh_shape,
platform=ctx.platform),
platform=ctx.module_context.platform),
multiple_results=False)(
ctx, [vectorized_outvar.aval], None, tiled_out
)[0]
mlir.LoweringRuleContext(module_context=ctx.module_context,
avals_in=[vectorized_outvar.aval],
avals_out=None), tiled_out)[0]
if v.aval is not core.abstract_unit else tiled_out
for v, vectorized_outvar, tiled_out, ans_out_axes
in zip(call_jaxpr.outvars, vectorized_jaxpr.outvars, tiled_outs,
@ -1377,12 +1381,12 @@ def _xmap_lowering_rule_replica(ctx, avals_in, avals_out, *in_nodes,
return outs
def _xmap_lowering_rule_spmd(ctx, avals_in, avals_out, *global_in_nodes,
def _xmap_lowering_rule_spmd(ctx, *global_in_nodes,
call_jaxpr, name, in_axes, out_axes,
donated_invars, global_axis_sizes, spmd_in_axes,
spmd_out_axes, positional_semantics,
axis_resources, resource_env, backend):
xla.check_backend_matches(backend, ctx.platform)
xla.check_backend_matches(backend, ctx.module_context.platform)
plan = EvaluationPlan.from_axis_resources(axis_resources, resource_env, global_axis_sizes)
resource_call_jaxpr = plan.subst_axes_with_resources(call_jaxpr)
@ -1409,7 +1413,7 @@ def _xmap_lowering_rule_spmd(ctx, avals_in, avals_out, *global_in_nodes,
# NOTE: We don't extend the resource env with the mesh shape, because those
# resources are already in scope! It's the outermost xmap that introduces
# them!
global_in_avals = avals_in
global_in_avals = ctx.avals_in
vectorized_jaxpr, global_out_avals, consts = pe.trace_to_jaxpr_dynamic(f, global_in_avals)
assert not consts
@ -1422,8 +1426,8 @@ def _xmap_lowering_rule_spmd(ctx, avals_in, avals_out, *global_in_nodes,
# We in-line here rather than generating a Call HLO as in the xla_call
# translation rule just because the extra tuple stuff is a pain.
sub_ctx = ctx.replace(
name_stack=xla.extend_name_stack(ctx.name_stack,
sub_ctx = ctx.module_context.replace(
name_stack=xla.extend_name_stack(ctx.module_context.name_stack,
xla.wrap_name(name, 'xmap')))
global_out_nodes = mlir.jaxpr_subcomp(sub_ctx, vectorized_jaxpr, (),
*sharded_global_in_nodes)

View File

@ -866,9 +866,9 @@ def _sharding_constraint_translation_rule(ctx, avals_in, avals_out, x_node, *,
ctx.builder, x_node, get_aval_sharding_proto(aval, axis_resources, mesh))]
xla.register_translation(sharding_constraint_p, _sharding_constraint_translation_rule)
def _sharding_constraint_mhlo_lowering(ctx, avals_in, avals_out, x_node, *,
axis_resources, resource_env):
aval, = avals_in
def _sharding_constraint_mhlo_lowering(ctx, x_node, *, axis_resources,
resource_env):
aval, = ctx.avals_in
mesh = resource_env.physical_mesh
return [mlir.wrap_with_sharding_op(
x_node, get_aval_sharding_proto(aval, axis_resources, mesh))]

View File

@ -280,7 +280,8 @@ def _source_info_to_location(
# Translation rules
@dataclasses.dataclass
class LoweringContext:
class ModuleContext:
"""Module-wide context information for MLIR lowering."""
context: ir.Context
module: ir.Module
ip: ir.InsertionPoint
@ -309,11 +310,19 @@ class LoweringContext:
def replace(self, **kw): return dataclasses.replace(self, **kw)
@dataclasses.dataclass
class LoweringRuleContext:
"""Per-rule context information for MLIR lowering."""
module_context: ModuleContext
avals_in: Sequence[core.AbstractValue]
avals_out: Any
def replace(self, **kw): return dataclasses.replace(self, **kw)
if not MYPY:
class LoweringRule(Protocol):
def __call__(self, ctx: LoweringContext,
avals_in: Sequence[core.AbstractValue],
avals_out: Sequence[core.AbstractValue],
def __call__(self, ctx: LoweringRuleContext,
*args: Union[ir.Value, Sequence[ir.Value]],
**kw) -> Sequence[Union[ir.Value, Sequence[ir.Value]]]:
"""Converts a JAX primitive invocation into MLIR."""
@ -367,7 +376,7 @@ def lower_jaxpr_to_module(
warnings.warn("Some donated buffers were not usable: {}".format(
", ".join(unused_donations)))
ctx = LoweringContext(platform, axis_env, name_stack)
ctx = ModuleContext(platform, axis_env, name_stack)
with ctx.context, ir.Location.unknown(ctx.context):
# Some clients expect modules to have unique names, e.g., in trace data.
# This may or may not be a reasonable assumption.
@ -407,7 +416,7 @@ def _set_up_aliases(avals_in, avals_out, donated_args):
return input_output_aliases, out_donated_args
def lower_jaxpr_to_fun(
ctx: LoweringContext, name: str, jaxpr: core.ClosedJaxpr, *,
ctx: ModuleContext, name: str, jaxpr: core.ClosedJaxpr, *,
public: bool = False, replace_units_with_dummy: bool = False,
replace_tokens_with_dummy: bool = False,
replicated_args: Optional[Sequence[bool]] = None,
@ -530,7 +539,7 @@ def lower_jaxpr_to_fun(
return func_op
def jaxpr_subcomp(ctx: LoweringContext, jaxpr: core.Jaxpr,
def jaxpr_subcomp(ctx: ModuleContext, jaxpr: core.Jaxpr,
consts: Sequence[Sequence[ir.Value]],
*args: Sequence[ir.Value]) -> Sequence[Sequence[ir.Value]]:
"""Lowers a jaxpr into mHLO, inlined into an existing function.
@ -578,8 +587,10 @@ def jaxpr_subcomp(ctx: LoweringContext, jaxpr: core.Jaxpr,
f"MLIR translation rule for primitive '{eqn.primitive.name}' not "
f"found for platform {ctx.platform}")
ans = rule(ctx, map(aval, eqn.invars), map(aval, eqn.outvars),
*map(_unwrap_singleton_ir_values, in_nodes),
rule_ctx = LoweringRuleContext(
module_context=ctx, avals_in=map(aval, eqn.invars),
avals_out=map(aval, eqn.outvars))
ans = rule(rule_ctx, *map(_unwrap_singleton_ir_values, in_nodes),
**eqn.params)
try:
@ -605,15 +616,16 @@ def lower_fun(fun: Callable, multiple_results: bool = True) -> Callable:
The returned function does not use `avals_out`, so callers may pass any value
as `avals_out`."""
def f_lowered(ctx, avals_in, avals_out, *args, **params):
def f_lowered(ctx, *args, **params):
if multiple_results:
f = fun
else:
f = lambda *args, **kw: (fun(*args, **kw),)
wrapped_fun = lu.wrap_init(f, params)
with core.extend_axis_env_nd(zip(ctx.axis_env.names, ctx.axis_env.sizes)):
jaxpr, _, consts = pe.trace_to_jaxpr_dynamic(wrapped_fun, avals_in)
return jaxpr_subcomp(ctx, jaxpr, _ir_consts(consts),
axis_env = ctx.module_context.axis_env
with core.extend_axis_env_nd(zip(axis_env.names, axis_env.sizes)):
jaxpr, _, consts = pe.trace_to_jaxpr_dynamic(wrapped_fun, ctx.avals_in)
return jaxpr_subcomp(ctx.module_context, jaxpr, _ir_consts(consts),
*map(wrap_singleton_ir_values, args))
return f_lowered
@ -634,19 +646,20 @@ def _call_lowering(fn_name, stack_name, call_jaxpr, backend, ctx, avals_in,
flatten_lowering_ir_args(args))
return util.unflatten(call.results, map(len, output_types))
def _xla_call_lower(ctx, avals_in, avals_out, *args,
def _xla_call_lower(ctx, *args,
backend=None, name, call_jaxpr, donated_invars, inline=None,
device=None):
del device, donated_invars, inline # Ignored.
return _call_lowering(f"jit_{name}", xla.wrap_name(name, "jit"), call_jaxpr,
backend, ctx, avals_in, avals_out, *args)
backend, ctx.module_context, ctx.avals_in, ctx.avals_out,
*args)
register_lowering(xla.xla_call_p, _xla_call_lower)
def _named_call_lowering(ctx, avals_in, avals_out, *args, name, backend=None,
def _named_call_lowering(ctx, *args, name, backend=None,
call_jaxpr):
return _call_lowering(name, name, call_jaxpr, backend, ctx, avals_in,
avals_out, *args)
return _call_lowering(name, name, call_jaxpr, backend, ctx.module_context,
ctx.avals_in, ctx.avals_out, *args)
register_lowering(core.named_call_p, _named_call_lowering)
register_lowering(core.call_p, partial(_named_call_lowering, name="core_call"))
@ -658,18 +671,17 @@ def full_like_aval(value, aval: core.ShapedArray) -> ir.Value:
return mhlo.BroadcastOp(aval_to_ir_type(aval), zero,
dense_int_elements(aval.shape)).result
def zeros_like_lowering(ctx, avals_in, avals_out, x):
aval, = avals_in
def zeros_like_lowering(ctx, x):
aval, = ctx.avals_in
assert isinstance(aval, core.ShapedArray), aval
return [full_like_aval(0, aval)]
register_lowering(ad_util.zeros_like_p, zeros_like_lowering)
def add_jaxvals_lowering(ctx, avals_in, avals_out, x, y):
def add_jaxvals_lowering(ctx, x, y):
return mhlo.AddOp(x, y).results
register_lowering(ad_util.add_jaxvals_p, add_jaxvals_lowering)
register_lowering(ad_util.stop_gradient_p,
lambda ctx, avals_in, avals_out, x: [x])
register_lowering(ad_util.stop_gradient_p, lambda ctx, x: [x])
def _minmax_mhlo(op, cmp, x, y):
@ -735,23 +747,24 @@ def set_sharding(op, sharding_proto: xc.OpSharding):
# MLIR lowerings for lax primitives
def xla_fallback_lowering(prim: core.Primitive, ctx: LoweringContext,
avals_in, avals_out, *args, **params):
def xla_fallback_lowering(prim: core.Primitive, ctx: LoweringRuleContext, *args,
**params):
module_ctx = ctx.module_context
xla_computation = xla.primitive_subcomputation(
ctx.platform, ctx.axis_env, prim, *avals_in, **params)
module_ctx.platform, module_ctx.axis_env, prim, *ctx.avals_in, **params)
submodule_str = xc._xla.mlir.xla_computation_to_mlir_module(xla_computation)
submodule = ir.Module.parse(submodule_str)
callee_name = None
for op in submodule.body.operations:
ctx.module.body.append(op)
module_ctx.module.body.append(op)
if op.name.value == "main":
op.attributes["sym_name"] = ir.StringAttr.get(f"xla_fallback_{prim.name}")
callee_name = ir.StringAttr(ctx.symbol_table.insert(op)).value
callee_name = ir.StringAttr(module_ctx.symbol_table.insert(op)).value
op.attributes["sym_visibility"] = ir.StringAttr.get("private")
else:
ctx.symbol_table.insert(op)
module_ctx.symbol_table.insert(op)
output_types = map(aval_to_ir_types, avals_out)
output_types = map(aval_to_ir_types, ctx.avals_out)
flat_output_types = util.flatten(output_types)
output_type = (ir.TupleType.get_tuple(flat_output_types)
if prim.multiple_results else flat_output_types[0])

View File

@ -1718,19 +1718,20 @@ def _mhlo_unshard(aval, axis_env, out_axis, xs, platform):
raise TypeError(aval)
def _pmap_lowering(ctx, avals_in, avals_out, *in_nodes, axis_name,
def _pmap_lowering(ctx, *in_nodes, axis_name,
axis_size, global_axis_size, devices, name,
call_jaxpr, backend=None, in_axes, out_axes,
donated_invars, global_arg_shapes):
del donated_invars # Unused.
xla.check_backend_matches(backend, ctx.platform)
xla.check_backend_matches(backend, ctx.module_context.platform)
# We in-line here rather than generating a Call HLO as in the xla_call
# translation rule just because the extra tuple stuff is a pain.
if ctx.axis_env.names and devices is not None:
if ctx.module_context.axis_env.names and devices is not None:
raise ValueError("Nested pmap with explicit devices argument.")
if global_axis_size is None:
global_axis_size = axis_size
new_env = xla.extend_axis_env(ctx.axis_env, axis_name, global_axis_size)
new_env = xla.extend_axis_env(ctx.module_context.axis_env, axis_name,
global_axis_size)
# Shard the in_nodes that are mapped
in_avals = [v.aval for v in call_jaxpr.invars]
in_nodes_sharded = (
@ -1739,14 +1740,15 @@ def _pmap_lowering(ctx, avals_in, avals_out, *in_nodes, axis_name,
for aval, in_node, in_axis in zip(in_avals, in_nodes, in_axes))
with maybe_extend_axis_env(axis_name, global_axis_size, None): # type: ignore
sub_ctx = ctx.replace(
sub_ctx = ctx.module_context.replace(
axis_env=new_env,
name_stack=xla.extend_name_stack(ctx.name_stack,
name_stack=xla.extend_name_stack(ctx.module_context.name_stack,
util.wrap_name(name, 'pmap')))
sharded_outs = mlir.jaxpr_subcomp(sub_ctx, call_jaxpr, (),
*in_nodes_sharded)
out_avals = [v.aval for v in call_jaxpr.outvars]
outs = [_mhlo_unshard(aval, new_env, out_axis, shard, platform=ctx.platform)
outs = [_mhlo_unshard(aval, new_env, out_axis, shard,
platform=ctx.module_context.platform)
for aval, out_axis, shard in zip(out_avals, out_axes, sharded_outs)]
return outs

View File

@ -214,7 +214,7 @@ def _sharded_jit_translation_rule(ctx, avals_in, avals_out, *in_nodes,
xops.Call(ctx.builder, subc, list(in_nodes)))
def _sharded_jit_lowering(ctx, avals_in, avals_out, *in_nodes,
def _sharded_jit_lowering(ctx, *in_nodes,
in_parts, out_parts_thunk, nparts,
name, call_jaxpr, local_in_parts,
local_out_parts_thunk, local_nparts):
@ -233,12 +233,12 @@ def _sharded_jit_lowering(ctx, avals_in, avals_out, *in_nodes,
else:
args.append(ns)
sub_ctx = ctx.replace(
sub_ctx = ctx.module_context.replace(
name_stack=extend_name_stack(wrap_name(name, "sharded_jit")))
fn = mlir.lower_jaxpr_to_fun(sub_ctx, f"sharded_jit_{name}",
core.ClosedJaxpr(call_jaxpr, ()))
output_types = safe_map(mlir.aval_to_ir_types, avals_out)
output_types = safe_map(mlir.aval_to_ir_types, ctx.avals_out)
flat_output_types = util.flatten(output_types)
call = std.CallOp(flat_output_types,
ir.FlatSymbolRefAttr.get(fn.name.value),
@ -472,8 +472,7 @@ ad.deflinear2(sharding_constraint_p,
xla.register_translation(sharding_constraint_p,
_sharding_constraint_translation_rule)
def _sharding_constraint_lowering(ctx, avals_in, avals_out, x_node,
partitions):
def _sharding_constraint_lowering(ctx, x_node, partitions):
return [mlir.wrap_with_sharding_op(x_node, xla.sharding_to_proto(partitions))]
mlir.register_lowering(sharding_constraint_p, _sharding_constraint_lowering)

View File

@ -170,7 +170,7 @@ def _sp_indices_translation_rule(ctx, avals_in, avals_out, data, indices):
# because it leads to infinite recursion.
xla.register_translation(sp_indices_p, _sp_indices_translation_rule)
def _sp_indices_mhlo_lowering(ctx, avals_in, avals_out, data_and_indices):
def _sp_indices_mhlo_lowering(ctx, data_and_indices):
return [data_and_indices[1]]
mlir.register_lowering(sp_indices_p, _sp_indices_mhlo_lowering)
@ -192,7 +192,7 @@ def _sp_data_translation_rule(ctx, avals_in, avals_out, data, indices):
# because it leads to infinite recursion.
xla.register_translation(sp_data_p, _sp_data_translation_rule)
def _sp_data_mhlo_lowering(ctx, avals_in, avals_out, data_and_indices):
def _sp_data_mhlo_lowering(ctx, data_and_indices):
return [data_and_indices[0]]
mlir.register_lowering(sp_data_p, _sp_data_mhlo_lowering)