Mirror of https://github.com/ROCm/jax.git, synced 2025-04-16 03:46:06 +00:00
[MLIR] Change signature of lowering rules.
Refactoring only, no functional changes intended.

Previously the MLIR lowering rule signature was

```
def rule(ctx, avals_in, avals_out, *args, **jaxpr_params):
```

where `ctx` was a module-wide context. Change it to

```
def rule(ctx, *args, **jaxpr_params)
```

where `ctx` is a per-rule context object. The previous parameters are now available as `ctx.module_context`, `ctx.avals_in`, and `ctx.avals_out`.

This change makes it easier to add new per-rule context information without having to refactor all of the lowering rules to accept a new argument. One example is a shape environment for dynamic shapes. Another example, which motivated this work, is that I want to include the primitive name as part of the rule context.

PiperOrigin-RevId: 416698663
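The new calling convention is easiest to see in isolation. The sketch below is a minimal, self-contained illustration of the signature change described in this commit message, not JAX's actual implementation; the `ModuleContext` and `LoweringRuleContext` stand-ins carry only a couple of fields, but the shape of the change is the same: avals move off each rule's parameter list and onto a per-rule context, so new per-rule fields (a shape environment, the primitive name) can be added without touching every registered rule.

```python
# Illustrative sketch only; these are stand-ins, not JAX's real classes.
import dataclasses
from typing import Any, Callable, Sequence


@dataclasses.dataclass
class ModuleContext:
  """Stand-in for module-wide lowering state (platform, name stack, ...)."""
  platform: str
  name_stack: str = ''

  def replace(self, **kw):
    return dataclasses.replace(self, **kw)


@dataclasses.dataclass
class LoweringRuleContext:
  """Stand-in for the per-rule state handed to each lowering rule."""
  module_context: ModuleContext
  avals_in: Sequence[Any]
  avals_out: Sequence[Any]

  def replace(self, **kw):
    return dataclasses.replace(self, **kw)


# Old-style rule: avals were explicit positional parameters.
def old_rule(ctx, avals_in, avals_out, x, *, param):
  return [x]


# New-style rule: the same information now travels on the per-rule context.
def new_rule(ctx: LoweringRuleContext, x, *, param):
  aval_out, = ctx.avals_out                # was: avals_out
  platform = ctx.module_context.platform   # was: ctx.platform
  del aval_out, platform, param
  return [x]


def invoke(rule: Callable, module_ctx: ModuleContext,
           avals_in, avals_out, *args, **params):
  """How a caller like jaxpr_subcomp now builds the per-rule context."""
  rule_ctx = LoweringRuleContext(module_context=module_ctx,
                                 avals_in=avals_in, avals_out=avals_out)
  return rule(rule_ctx, *args, **params)


if __name__ == '__main__':
  mctx = ModuleContext(platform='cpu')
  print(invoke(new_rule, mctx, ['f32[3]'], ['f32[3]'], 'x-value', param=1))
```

In the diff below, the caller-side construction of the per-rule context appears in the `jaxpr_subcomp` hunk, and individual rules such as `_device_put_lowering` shrink to the new two-argument form.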
This commit is contained in:
parent 46c2839258
commit a87b21148c
@@ -738,7 +738,7 @@ ad.deflinear2(device_put_p, lambda cotangent, _, **kwargs: [cotangent])
 masking.defvectorized(device_put_p)
 batching.defvectorized(device_put_p)

-def _device_put_lowering(ctx, avals_in, avals_out, x, *, device):
+def _device_put_lowering(ctx, x, *, device):
   return [x]

@@ -640,7 +640,7 @@ def _pred_bcast_select_mhlo(
   return mhlo.SelectOp(bcast_pred, x, y).results


-def _while_lowering(ctx, avals_in, avals_out, *args, cond_jaxpr,
+def _while_lowering(ctx, *args, cond_jaxpr,
                     body_jaxpr, cond_nconsts, body_nconsts):
   pred_aval = cond_jaxpr.out_avals[0]
   batched = bool(pred_aval.shape)
@@ -651,7 +651,7 @@ def _while_lowering(ctx, avals_in, avals_out, *args, cond_jaxpr,
   # computation), we build XLA computations that handle the tuple munging before
   # generating a Call into the computations formed from the jaxprs.

-  loop_carry_types = _map(mlir.aval_to_ir_types, avals_in)
+  loop_carry_types = _map(mlir.aval_to_ir_types, ctx.avals_in)
   flat_loop_carry_types = util.flatten(loop_carry_types)
   loop_carry_tuple_type = ir.TupleType.get_tuple(flat_loop_carry_types)

@@ -668,15 +668,17 @@ def _while_lowering(ctx, avals_in, avals_out, *args, cond_jaxpr,
                  for i, input_type in enumerate(flat_loop_carry_types)]
     cond_args = util.unflatten(flat_cond_args, _map(len, loop_carry_types))
     x, _, z = util.split_list(cond_args, [cond_nconsts, body_nconsts])
-    cond_ctx = ctx.replace(
-        name_stack=xla.extend_name_stack(ctx.name_stack, 'cond'))
+    cond_ctx = ctx.module_context.replace(
+        name_stack=xla.extend_name_stack(ctx.module_context.name_stack, 'cond'))
     (pred,), = mlir.jaxpr_subcomp(
         cond_ctx, cond_jaxpr.jaxpr, _map(mlir.ir_constants, cond_jaxpr.consts),
         *(x + z))
     if batched:
+      pred_ctx = mlir.LoweringRuleContext(
+          module_context=ctx.module_context, avals_in=[pred_aval],
+          avals_out=[pred_aval.update(shape=())])
       pred, = lax._unary_reduce_lower(
-          mhlo.OrOp, lambda dtype: np.array(False, dtype), ctx, [pred_aval],
-          [pred_aval.update(shape=())], pred,
+          mhlo.OrOp, lambda dtype: np.array(False, dtype), pred_ctx, pred,
           axes=tuple(range(len(pred_aval.shape))))
     mhlo.ReturnOp([pred])

@@ -689,14 +691,15 @@ def _while_lowering(ctx, avals_in, avals_out, *args, cond_jaxpr,
                  for i, input_type in enumerate(flat_loop_carry_types)]
     body_args = util.unflatten(flat_body_args, _map(len, loop_carry_types))
     x, y, z = util.split_list(body_args, [cond_nconsts, body_nconsts])
-    body_ctx = ctx.replace(
-        name_stack=xla.extend_name_stack(ctx.name_stack, 'body'))
+    body_ctx = ctx.module_context.replace(
+        name_stack=xla.extend_name_stack(ctx.module_context.name_stack, 'body'))
     new_z = mlir.jaxpr_subcomp(
         body_ctx, body_jaxpr.jaxpr, _map(mlir.ir_constants, body_jaxpr.consts),
         *(y + z))
     if batched:
-      body_pred_ctx = ctx.replace(
-          name_stack=xla.extend_name_stack(ctx.name_stack, 'body_pred'))
+      body_pred_ctx = ctx.module_context.replace(
+          name_stack=xla.extend_name_stack(ctx.module_context.name_stack,
+                                           'body_pred'))
       (body_pred,), = mlir.jaxpr_subcomp(
           body_pred_ctx, cond_jaxpr.jaxpr,
           _map(mlir.ir_constants, cond_jaxpr.consts), *(x + z))
@@ -1318,11 +1321,11 @@ xla.register_translation(cond_p, _cond_translation_rule, initial_style=True)
 core.custom_typechecks[cond_p] = _cond_typecheck
 pe.partial_eval_jaxpr_custom_rules[cond_p] = pe.partial_eval_jaxpr_custom_rule_not_implemented

-def _cond_lowering(ctx, avals_in, avals_out, index, *args, branches, linear):
+def _cond_lowering(ctx, index, *args, branches, linear):
   del linear  # Unused.
-  arg_avals = avals_in[1:]
+  arg_avals = ctx.avals_in[1:]
   input_types = _map(mlir.aval_to_ir_types, arg_avals)
-  output_types = _map(mlir.aval_to_ir_types, avals_out)
+  output_types = _map(mlir.aval_to_ir_types, ctx.avals_out)
   flat_input_types = util.flatten(input_types)
   flat_output_types = util.flatten(output_types)
   input_tuple_type = ir.TupleType.get_tuple(flat_input_types)
@@ -1340,8 +1343,8 @@ def _cond_lowering(ctx, avals_in, avals_out, index, *args, branches, linear):
                             mlir.i32_attr(i)).result
             for i, input_type in enumerate(flat_input_types)]
     unflattened_args = util.unflatten(args, _map(len, input_types))
-    out_vals = mlir.jaxpr_subcomp(ctx, jaxpr.jaxpr, jaxpr.consts,
-                                  *unflattened_args)
+    out_vals = mlir.jaxpr_subcomp(ctx.module_context, jaxpr.jaxpr,
+                                  jaxpr.consts, *unflattened_args)
     out = mhlo.TupleOp(output_tuple_type, util.flatten(out_vals)).results
     mhlo.ReturnOp(out)

@@ -723,12 +723,12 @@ def _complex_mul(mul, x, y):
 _real_dtype = lambda dtype: np.finfo(dtype).dtype

 def _conv_general_dilated_lower(
-    ctx, avals_in, avals_out, lhs, rhs, *, window_strides, padding,
+    ctx, lhs, rhs, *, window_strides, padding,
     lhs_dilation, rhs_dilation, dimension_numbers, feature_group_count,
     batch_group_count, precision, preferred_element_type,
     expand_complex_convolutions=False, **unused_kwargs):
-  lhs_aval, rhs_aval = avals_in
-  aval_out, = avals_out
+  lhs_aval, rhs_aval = ctx.avals_in
+  aval_out, = ctx.avals_out
   assert isinstance(dimension_numbers, ConvDimensionNumbers)
   dtype = lhs_aval.dtype
   if expand_complex_convolutions and np.issubdtype(dtype, np.complexfloating):
@@ -746,7 +746,7 @@ def _conv_general_dilated_lower(
                 batch_group_count=batch_group_count, precision=precision,
                 preferred_element_type=preferred_element_type)),
         multiple_results=False)
-    return complex_conv(ctx, avals_in, avals_out, lhs, rhs)
+    return complex_conv(ctx, lhs, rhs)

   lhs_spec, rhs_spec, out_spec = dimension_numbers
   dnums = mhlo.ConvDimensionNumbers.get(
@@ -1444,9 +1444,7 @@ def broadcast_mhlo(
       out.append(arg)
   return out

-def _nary_lower_mhlo(op: Callable, ctx: mlir.LoweringContext,
-                     avals_in: Sequence[core.ShapedArray],
-                     avals_out: Sequence[core.ShapedArray],
+def _nary_lower_mhlo(op: Callable, ctx,
                      *args: Union[ir.Value, Sequence[ir.Value]],
                      explicit_type=False, **params):
   """Lowers an elementwise operator to its MHLO/CHLO equivalent.
@@ -1456,8 +1454,8 @@ def _nary_lower_mhlo(op: Callable, ctx: mlir.LoweringContext,
     provided?
   """
   del params
-  aval_out, = avals_out
-  broadcasted_args = broadcast_mhlo(aval_out, avals_in, args)
+  aval_out, = ctx.avals_out
+  broadcasted_args = broadcast_mhlo(aval_out, ctx.avals_in, args)
   if explicit_type:
     return op(mlir.aval_to_ir_type(aval_out), *broadcasted_args).results
   else:
@@ -1494,8 +1492,8 @@ def _sign_translation_rule(ctx, avals_in, avals_out, x):
 sign_p = standard_unop(_num, 'sign', translation_rule=_sign_translation_rule)
 ad.defjvp_zero(sign_p)

-def _sign_lower_mhlo(ctx, avals_in, avals_out, x):
-  x_aval, = avals_in
+def _sign_lower_mhlo(ctx, x):
+  x_aval, = ctx.avals_in
   if dtypes.issubdtype(x_aval.dtype, np.unsignedinteger):
     return mhlo.SelectOp(
         mhlo.CompareOp(
@@ -1546,14 +1544,14 @@ round_p = standard_unop(_float, 'round')
 xla.register_translation(round_p, _round_translation_rule)
 ad.defjvp_zero(round_p)

-def _round_lower(ctx, avals_in, avals_out, x, *, rounding_method):
+def _round_lower(ctx, x, *, rounding_method):
   if rounding_method is RoundingMethod.AWAY_FROM_ZERO:
     return mhlo.RoundOp(x).results
   else:
     assert rounding_method is RoundingMethod.TO_NEAREST_EVEN
     round_nearest = mlir.lower_fun(_round_to_nearest_even,
                                    multiple_results=False)
-    return round_nearest(ctx, avals_in, avals_out, x)
+    return round_nearest(ctx, x)
 mlir.register_lowering(round_p, _round_lower)

 is_finite_p = unop(_fixed_dtype(np.bool_), _float, 'is_finite')
@@ -2113,10 +2111,11 @@ ad.defjvp_zero(shift_right_logical_p)
 mlir.register_lowering(shift_right_logical_p,
                        partial(_nary_lower_mhlo, mhlo.ShiftRightLogicalOp))

-def _compare_lower_mhlo(direction: str, ctx, avals_in, avals_out, x, y):
-  x_aval, y_aval = avals_in
-  aval_out, = avals_out
-  x, y = broadcast_mhlo(aval_out.update(dtype=x_aval.dtype), avals_in, (x, y))
+def _compare_lower_mhlo(direction: str, ctx, x, y):
+  x_aval, y_aval = ctx.avals_in
+  aval_out, = ctx.avals_out
+  x, y = broadcast_mhlo(aval_out.update(dtype=x_aval.dtype), ctx.avals_in,
+                        (x, y))
   if dtypes.issubdtype(x_aval.dtype, np.inexact):
     compare_type = "FLOAT"
   elif dtypes.issubdtype(x_aval.dtype, np.signedinteger):
@@ -2222,10 +2221,9 @@ pe.forwarding_rules[convert_element_type_p] = _convert_elt_type_fwd_rule

 def _real_dtype(dtype): return np.finfo(dtype).dtype

-def _convert_element_type_lower(ctx, avals_in, avals_out, operand, *,
-                                new_dtype, weak_type):
-  aval_in, = avals_in
-  aval_out, = avals_out
+def _convert_element_type_lower(ctx, operand, *, new_dtype, weak_type):
+  aval_in, = ctx.avals_in
+  aval_out, = ctx.avals_out
   if (dtypes.issubdtype(aval_in.dtype, np.complexfloating) and
       not dtypes.issubdtype(new_dtype, np.complexfloating)):
     operand = mhlo.RealOp(operand).result
@@ -2260,9 +2258,8 @@ ad.defjvp_zero(bitcast_convert_type_p)
 batching.defvectorized(bitcast_convert_type_p)
 masking.defvectorized(bitcast_convert_type_p)

-def _bitcast_convert_type_lower(ctx, avals_in, avals_out, operand, *,
-                                new_dtype):
-  aval_out, = avals_out
+def _bitcast_convert_type_lower(ctx, operand, *, new_dtype):
+  aval_out, = ctx.avals_out
   return mhlo.BitcastConvertOp(mlir.aval_to_ir_type(aval_out), operand).results

 mlir.register_lowering(bitcast_convert_type_p, _bitcast_convert_type_lower)
@@ -2522,15 +2519,15 @@ def precision_attr(precision: PrecisionType) -> ir.ArrayAttr:
     precision = (precision, precision)
   return ir.ArrayAttr.get([ir.StringAttr.get(str(p)) for p in precision])

-def _dot_general_lower(ctx, avals_in, avals_out, lhs, rhs, *, dimension_numbers,
+def _dot_general_lower(ctx, lhs, rhs, *, dimension_numbers,
                        precision, preferred_element_type: Optional[np.dtype]):
   del preferred_element_type  # Implied by the output aval
-  lhs_aval, rhs_aval = avals_in
-  aval_out, = avals_out
+  lhs_aval, rhs_aval = ctx.avals_in
+  aval_out, = ctx.avals_out
   (lhs_contracting, rhs_contracting), (lhs_batch, rhs_batch) = dimension_numbers

   # TODO(b/195364460): Work around slow XLA/CPU implementation of float16 matmul
-  if ctx.platform == "cpu":
+  if ctx.module_context.platform == "cpu":
     if lhs_aval.dtype == np.float16:
       f32 = mlir.dtype_to_ir_type(np.dtype(np.float32))
       lhs = mhlo.ConvertOp(ir.RankedTensorType.get(lhs_aval.shape, f32),
@@ -2614,10 +2611,9 @@ ad.deflinear2(broadcast_in_dim_p, _broadcast_in_dim_transpose_rule)
 batching.primitive_batchers[broadcast_in_dim_p] = _broadcast_in_dim_batch_rule
 pe.forwarding_rules[broadcast_in_dim_p] = _broadcast_in_dim_fwd_rule

-def _broadcast_in_dim_lower(ctx, avals_in, avals_out, x, *, shape,
-                            broadcast_dimensions):
+def _broadcast_in_dim_lower(ctx, x, *, shape, broadcast_dimensions):
   del shape
-  aval_out, = avals_out
+  aval_out, = ctx.avals_out
   return mhlo.BroadcastInDimOp(
       mlir.aval_to_ir_type(aval_out), x,
       mlir.dense_int_elements(broadcast_dimensions)
@@ -2760,7 +2756,7 @@ ad.deflinear2(concatenate_p, _concatenate_transpose_rule)
 ad.primitive_transposes[concatenate_p] = _concatenate_transpose_rule
 batching.primitive_batchers[concatenate_p] = _concatenate_batch_rule

-def _concatenate_lower(ctx, avals_in, avals_out, *xs, dimension):
+def _concatenate_lower(ctx, *xs, dimension):
   return mhlo.ConcatenateOp(xs, mlir.i64_attr(dimension)).results
 mlir.register_lowering(concatenate_p, _concatenate_lower)

@@ -2852,8 +2848,8 @@ ad.deflinear2(pad_p, _pad_transpose)
 batching.primitive_batchers[pad_p] = _pad_batch_rule
 masking.masking_rules[pad_p] = _pad_masking_rule

-def _pad_lower(ctx, avals_in, avals_out, x, padding_value, *, padding_config):
-  aval_out, = avals_out
+def _pad_lower(ctx, x, padding_value, *, padding_config):
+  aval_out, = ctx.avals_out
   low, high, interior = util.unzip3(padding_config)
   return mhlo.PadOp(mlir.aval_to_ir_type(aval_out), x, padding_value,
                     mlir.dense_int_elements(low),
@@ -2910,9 +2906,9 @@ squeeze_p = standard_primitive(_squeeze_shape_rule, _squeeze_dtype_rule,
 ad.deflinear2(squeeze_p, _squeeze_transpose_rule)
 batching.primitive_batchers[squeeze_p] = _squeeze_batch_rule

-def _squeeze_lower(ctx, avals_in, avals_out, operand, *, dimensions):
+def _squeeze_lower(ctx, operand, *, dimensions):
   del dimensions  # Implied by the output aval.
-  aval_out, = avals_out
+  aval_out, = ctx.avals_out
   return mhlo.ReshapeOp(mlir.aval_to_ir_type(aval_out), operand).results

 mlir.register_lowering(squeeze_p, _squeeze_lower)
@@ -3014,9 +3010,9 @@ ad.deflinear2(reshape_p, _reshape_transpose_rule)
 batching.primitive_batchers[reshape_p] = _reshape_batch_rule
 masking.masking_rules[reshape_p] = _reshape_masking_rule

-def _reshape_lower(ctx, avals_in, avals_out, x, *, new_sizes, dimensions):
-  aval_in, = avals_in
-  aval_out, = avals_out
+def _reshape_lower(ctx, x, *, new_sizes, dimensions):
+  aval_in, = ctx.avals_in
+  aval_out, = ctx.avals_out
   if dimensions is not None:
     aval = core.ShapedArray(np.take(aval_in.shape, dimensions), aval_in.dtype)
     x = mhlo.TransposeOp(mlir.aval_to_ir_type(aval), x,
@@ -3045,7 +3041,7 @@ rev_p = standard_primitive(_rev_shape_rule, _input_dtype, 'rev')
 ad.deflinear2(rev_p, lambda t, _, dimensions: [rev(t, dimensions)])
 batching.primitive_batchers[rev_p] = _rev_batch_rule

-def _rev_lower(ctx, avals_in, avals_out, x, *, dimensions):
+def _rev_lower(ctx, x, *, dimensions):
   return mhlo.ReverseOp(x, mlir.dense_int_elements(dimensions)).results
 mlir.register_lowering(rev_p, _rev_lower)

@@ -3076,8 +3072,8 @@ ad.deflinear2(transpose_p,
 batching.primitive_batchers[transpose_p] = _transpose_batch_rule
 masking.masking_rules[transpose_p] = _transpose_masking_rule

-def _transpose_lower(ctx, avals_in, avals_out, x, *, permutation):
-  aval_out, = avals_out
+def _transpose_lower(ctx, x, *, permutation):
+  aval_out, = ctx.avals_out
   return mhlo.TransposeOp(mlir.aval_to_ir_type(aval_out), x,
                           mlir.dense_int_elements(permutation)).results
 mlir.register_lowering(transpose_p, _transpose_lower)
@@ -3338,18 +3334,17 @@ xla.register_translation(reduce_p, _reduce_translation_rule)
 batching.primitive_batchers[reduce_p] = _reduce_batch_rule
 ad.primitive_jvps[reduce_p] = _reduce_jvp_rule

-def _reduce_lower(ctx, avals_in, avals_out, *values, computation, jaxpr,
-                  consts, dimensions):
-  assert all(isinstance(x, core.ShapedArray) for x in avals_in), avals_in
+def _reduce_lower(ctx, *values, computation, jaxpr, consts, dimensions):
+  assert all(isinstance(x, core.ShapedArray) for x in ctx.avals_in), ctx.avals_in
   operands, init_values = util.split_list(values, [len(values) // 2])
-  init_value_avals = avals_in[len(values) // 2:]
-  op = mhlo.ReduceOp([mlir.aval_to_ir_type(aval) for aval in avals_out],
+  init_value_avals = ctx.avals_in[len(values) // 2:]
+  op = mhlo.ReduceOp([mlir.aval_to_ir_type(aval) for aval in ctx.avals_out],
                      operands, init_values, mlir.dense_int_elements(dimensions))
   ir_types = [mlir.aval_to_ir_type(aval) for aval in init_value_avals]
   reducer = op.regions[0].blocks.append(*(ir_types + ir_types))
   with ir.InsertionPoint(reducer):
-    ctx = ctx.replace(name_stack='')
-    out_nodes = mlir.jaxpr_subcomp(ctx, jaxpr, consts,
+    reducer_ctx = ctx.module_context.replace(name_stack='')
+    out_nodes = mlir.jaxpr_subcomp(reducer_ctx, jaxpr, consts,
                                    *([a] for a in reducer.arguments))
     mhlo.ReturnOp(util.flatten(out_nodes))
   return op.results
@@ -3566,9 +3561,8 @@ reduce_and_p = standard_primitive(_reduce_logical_shape_rule, _fixed_dtype(np.bo
 batching.defreducer(reduce_and_p)


-def _unary_reduce_lower(reducer, unit_factory, ctx, avals_in, avals_out, x, *,
-                        axes):
-  aval_out, = avals_out
+def _unary_reduce_lower(reducer, unit_factory, ctx, x, *, axes):
+  aval_out, = ctx.avals_out
   dtype = aval_out.dtype
   op = mhlo.ReduceOp([mlir.aval_to_ir_type(aval_out)], [x],
                      mlir.ir_constants(unit_factory(aval_out.dtype)),
@@ -3613,9 +3607,8 @@ reduce_precision_p = standard_primitive(
 batching.defvectorized(reduce_precision_p)
 masking.defvectorized(reduce_precision_p)

-def _reduce_precision_lower(ctx, avals_in, avals_out, operand, *, exponent_bits,
-                            mantissa_bits):
-  aval_out, = avals_out
+def _reduce_precision_lower(ctx, operand, *, exponent_bits, mantissa_bits):
+  aval_out, = ctx.avals_out
   return mhlo.ReducePrecisionOp(mlir.aval_to_ir_type(aval_out), operand,
                                 mlir.i32_attr(exponent_bits),
                                 mlir.i32_attr(mantissa_bits)).results
@@ -3754,22 +3747,24 @@ ad.primitive_jvps[sort_p] = _sort_jvp
 batching.primitive_batchers[sort_p] = _sort_batch_rule


-def _sort_lower(ctx, avals_in, avals_out, *operands, dimension, is_stable,
-                num_keys):
-  assert all(isinstance(x, core.ShapedArray) for x in avals_in), avals_in
-  sort = mhlo.SortOp([mlir.aval_to_ir_type(aval) for aval in avals_out],
+def _sort_lower(ctx, *operands, dimension, is_stable, num_keys):
+  assert all(isinstance(x, core.ShapedArray) for x in ctx.avals_in), ctx.avals_in
+  sort = mhlo.SortOp([mlir.aval_to_ir_type(aval) for aval in ctx.avals_out],
                      mlir.flatten_lowering_ir_args(operands),
                      mlir.i64_attr(dimension), ir.BoolAttr.get(is_stable))
-  scalar_avals = [aval.update(shape=()) for aval in avals_in]
+  scalar_avals = [aval.update(shape=()) for aval in ctx.avals_in]
   scalar_types = safe_map(mlir.aval_to_ir_type, scalar_avals)
   comparator = sort.comparator.blocks.append(
       *util.flatten(zip(scalar_types, scalar_types)))
   with ir.InsertionPoint(comparator):
     lower_comparator = mlir.lower_fun(partial(_sort_lt_comparator),
                                       multiple_results=False)
-    out = lower_comparator(ctx, util.flatten(zip(scalar_avals, scalar_avals)),
-                           [core.ShapedArray((), np.bool_)],
-                           *[[a] for a in comparator.arguments],
+    sub_ctx = mlir.LoweringRuleContext(
+        module_context = ctx.module_context,
+        avals_in=util.flatten(zip(scalar_avals, scalar_avals)),
+        avals_out=[core.ShapedArray((), np.bool_)])
+
+    out = lower_comparator(sub_ctx, *[[a] for a in comparator.arguments],
                            num_keys=num_keys)
     mhlo.ReturnOp(util.flatten(out))
   return sort.results
@@ -3870,8 +3865,8 @@ create_token_p.def_abstract_eval(lambda *_: abstract_token)
 xla.register_translation(create_token_p,
                          lambda ctx, *_: [xops.CreateToken(ctx.builder)])

-def _create_token_lowering(ctx, avals_in, avals_out, *operands):
-  aval_out, = avals_out
+def _create_token_lowering(ctx, *operands):
+  aval_out, = ctx.avals_out
   return mhlo.CreateTokenOp(mlir.aval_to_ir_type(aval_out)).results

 mlir.register_lowering(create_token_p, _create_token_lowering)
@@ -3897,8 +3892,8 @@ after_all_p.def_impl(partial(xla.apply_primitive, after_all_p))
 after_all_p.def_abstract_eval(_after_all_abstract_eval)
 xla.register_translation(after_all_p, _after_all_translation_rule)

-def _after_all_lowering(ctx, avals_in, avals_out, *operands):
-  aval_out, = avals_out
+def _after_all_lowering(ctx, *operands):
+  aval_out, = ctx.avals_out
   return mhlo.AfterAllOp(mlir.aval_to_ir_type(aval_out), operands).results

 mlir.register_lowering(after_all_p, _after_all_lowering)
@@ -3956,8 +3951,8 @@ infeed_p.def_abstract_eval(_infeed_abstract_eval)
 xla.register_translation(infeed_p, _infeed_translation_rule)


-def _infeed_lowering(ctx, avals_in, avals_out, token, *, shapes, partitions):
-  output_types = safe_map(mlir.aval_to_ir_types, avals_out[:-1])
+def _infeed_lowering(ctx, token, *, shapes, partitions):
+  output_types = safe_map(mlir.aval_to_ir_types, ctx.avals_out[:-1])
   flat_output_types = util.flatten(output_types)
   output_tuple_type = ir.TupleType.get_tuple(flat_output_types)
   # TODO(phawkins): verify `shapes` have a major-to-minor layout.
@@ -4020,9 +4015,9 @@ outfeed_p.def_abstract_eval(_outfeed_abstract_eval)
 xla.register_translation(outfeed_p, _outfeed_translation_rule)


-def _outfeed_lowering(ctx, avals_in, avals_out, token, *xs, partitions):
-  token_aval = avals_in[0]
-  xs_avals = avals_in[1:]
+def _outfeed_lowering(ctx, token, *xs, partitions):
+  token_aval = ctx.avals_in[0]
+  xs_avals = ctx.avals_in[1:]
   input_types = map(mlir.aval_to_ir_types, xs_avals)
   flat_input_types = util.flatten(input_types)
   input_tuple_type = ir.TupleType.get_tuple(flat_input_types)
@@ -4070,10 +4065,10 @@ rng_uniform_p.def_impl(partial(xla.apply_primitive, rng_uniform_p))
 rng_uniform_p.def_abstract_eval(_rng_uniform_abstract_eval)
 xla.register_translation(rng_uniform_p, _rng_uniform_translation_rule)

-def _rng_uniform_lowering(ctx, avals_in, avals_out, a, b, *, shape):
-  aval_out, = avals_out
+def _rng_uniform_lowering(ctx, a, b, *, shape):
+  aval_out, = ctx.avals_out
   shape, = mlir.ir_constants(np.array(aval_out.shape, np.int64),
-                             canonicalize_types=False)
+                              canonicalize_types=False)
   return mhlo.RngUniformOp(a, b, shape).results

 mlir.register_lowering(rng_uniform_p, _rng_uniform_lowering)
@@ -4177,9 +4172,9 @@ iota_p.def_impl(partial(xla.apply_primitive, iota_p))
 iota_p.def_abstract_eval(_iota_abstract_eval)
 xla.register_translation(iota_p, _iota_translation_rule)

-def _iota_lower(ctx, avals_in, avals_out, *, dtype, shape, dimension):
+def _iota_lower(ctx, *, dtype, shape, dimension):
   del dtype, shape
-  aval_out, = avals_out
+  aval_out, = ctx.avals_out
   return mhlo.IotaOp(mlir.aval_to_ir_type(aval_out),
                      mlir.i64_attr(dimension)).results
 mlir.register_lowering(iota_p, _iota_lower)

@@ -747,9 +747,8 @@ ad.deflinear2(slice_p, _slice_transpose_rule)
 batching.primitive_batchers[slice_p] = _slice_batching_rule
 masking.masking_rules[slice_p] = _slice_masking_rule

-def _slice_lower(ctx, avals_in, avals_out, x, *, start_indices,
-                 limit_indices, strides):
-  aval_out, = avals_out
+def _slice_lower(ctx, x, *, start_indices, limit_indices, strides):
+  aval_out, = ctx.avals_out
   strides = strides or [1] * len(start_indices)
   return mhlo.SliceOp(x,
                       mlir.dense_int_elements(start_indices),
@@ -848,9 +847,8 @@ ad.primitive_jvps[dynamic_slice_p] = _dynamic_slice_jvp  # TODO
 ad.primitive_transposes[dynamic_slice_p] = _dynamic_slice_transpose_rule
 batching.primitive_batchers[dynamic_slice_p] = _dynamic_slice_batching_rule

-def _dynamic_slice_lower(ctx, avals_in, avals_out, x, *start_indices,
-                         slice_sizes):
-  aval_out, = avals_out
+def _dynamic_slice_lower(ctx, x, *start_indices, slice_sizes):
+  aval_out, = ctx.avals_out
   return mhlo.DynamicSliceOp(mlir.aval_to_ir_type(aval_out), x,
                              start_indices,
                              mlir.dense_int_elements(slice_sizes)).results
@@ -947,9 +945,8 @@ ad.primitive_transposes[dynamic_update_slice_p] = \
 batching.primitive_batchers[dynamic_update_slice_p] = \
     _dynamic_update_slice_batching_rule

-def _dynamic_update_slice_lower(ctx, avals_in, avals_out, x, update,
-                                *start_indices):
-  aval_out, = avals_out
+def _dynamic_update_slice_lower(ctx, x, update, *start_indices):
+  aval_out, = ctx.avals_out
   return mhlo.DynamicUpdateSliceOp(mlir.aval_to_ir_type(aval_out), x, update,
                                    start_indices).results

@@ -1280,14 +1277,14 @@ batching.primitive_batchers[gather_p] = _gather_batching_rule


-def _gather_lower(ctx, avals_in, avals_out, operand, indices, *,
+def _gather_lower(ctx, operand, indices, *,
                   dimension_numbers, slice_sizes, unique_indices,
                   indices_are_sorted, mode, fill_value):
-  aval_out, = avals_out
+  aval_out, = ctx.avals_out
   if mode == GatherScatterMode.FILL_OR_DROP:
     gather_fill_fn = mlir.lower_fun(_gather_fill, multiple_results=False)
     return gather_fill_fn(
-        ctx, avals_in, avals_out, operand, indices,
+        ctx, operand, indices,
         dimension_numbers=dimension_numbers, slice_sizes=slice_sizes,
         unique_indices=unique_indices, indices_are_sorted=indices_are_sorted,
         fill_value=fill_value, output_shape=aval_out.shape)
@@ -1296,7 +1293,7 @@ def _gather_lower(ctx, avals_in, avals_out, operand, indices, *,
           GatherScatterMode.CLIP), mode
   dnums = mhlo.GatherDimensionNumbers.get(
       collapsed_slice_dims=list(dimension_numbers.collapsed_slice_dims),
-      index_vector_dim=len(avals_in[1].shape) - 1,
+      index_vector_dim=len(ctx.avals_in[1].shape) - 1,
       offset_dims=list(dimension_numbers.offset_dims),
       start_index_map=list(dimension_numbers.start_index_map))
   return mhlo.GatherOp(operand, indices, dnums,
@@ -1945,30 +1942,30 @@ batching.primitive_batchers[scatter_p] = (


-def _scatter_lower(ctx, avals_in, avals_out, operand, indices, updates, *,
+def _scatter_lower(ctx, operand, indices, updates, *,
                    update_jaxpr, update_consts, dimension_numbers,
                    indices_are_sorted, unique_indices, mode):
   if mode == GatherScatterMode.CLIP:
     clip_fn = mlir.lower_fun(_clamp_scatter_indices, multiple_results=False)
-    (indices,), = clip_fn(ctx, avals_in, None, operand, indices, updates,
-                          dnums=dimension_numbers)
+    (indices,), = clip_fn(ctx.replace(avals_out=None), operand, indices,
+                          updates, dnums=dimension_numbers)

-  aval_out, = avals_out
+  aval_out, = ctx.avals_out
   dnums = dimension_numbers
   scatter_dnums = mhlo.ScatterDimensionNumbers.get(
       update_window_dims=list(dnums.update_window_dims),
       inserted_window_dims=list(dnums.inserted_window_dims),
       scattered_dims_to_operand_dims=list(dnums.scatter_dims_to_operand_dims),
-      index_vector_dim=len(avals_in[1].shape) - 1)
+      index_vector_dim=len(ctx.avals_in[1].shape) - 1)
   op = mhlo.ScatterOp(mlir.aval_to_ir_type(aval_out), operand, indices, updates,
                       scatter_dnums, ir.BoolAttr.get(indices_are_sorted),
                       ir.BoolAttr.get(unique_indices))
   scalar_type = mlir.aval_to_ir_type(core.ShapedArray((), aval_out.dtype))
   update = op.update_computation.blocks.append(scalar_type, scalar_type)
   with ir.InsertionPoint(update):
-    ctx = ctx.replace(name_stack='')
+    update_ctx = ctx.module_context.replace(name_stack='')
     out_nodes = mlir.jaxpr_subcomp(
-        ctx, update_jaxpr, update_consts,
+        update_ctx, update_jaxpr, update_consts,
         (update.arguments[0],), (update.arguments[1],))
     mhlo.ReturnOp(util.flatten(out_nodes))
   return op.results
@@ -1982,12 +1979,12 @@ mlir.register_lowering(scatter_max_p, _scatter_lower)

 def _real_dtype(dtype): return np.finfo(dtype).dtype

-def _scatter_add_lower_gpu(ctx, avals_in, avals_out, operand, indices, updates,
+def _scatter_add_lower_gpu(ctx, operand, indices, updates,
                            *, update_jaxpr, update_consts, dimension_numbers,
                            indices_are_sorted, unique_indices, mode):
-  operand_aval_in, _, updates_aval_in = avals_in
+  operand_aval_in, _, updates_aval_in = ctx.avals_in
   if operand_aval_in.dtype != np.complex128:
-    return _scatter_lower(ctx, avals_in, avals_out, operand, indices, updates,
+    return _scatter_lower(ctx, operand, indices, updates,
                           update_jaxpr=update_jaxpr,
                           update_consts=update_consts,
                           dimension_numbers=dimension_numbers,
@@ -1996,16 +1993,16 @@ def _scatter_add_lower_gpu(ctx, avals_in, avals_out, operand, indices, updates,

   if mode == GatherScatterMode.CLIP:
     clip_fn = mlir.lower_fun(_clamp_scatter_indices, multiple_results=False)
-    (indices,), = clip_fn(ctx, avals_in, None, operand, indices, updates,
+    (indices,), = clip_fn(ctx, ctx.avals_in, None, operand, indices, updates,
                           dnums=dimension_numbers)

-  aval_out, = avals_out
+  aval_out, = ctx.avals_out
   dnums = dimension_numbers
   scatter_dnums = mhlo.ScatterDimensionNumbers.get(
       update_window_dims=list(dnums.update_window_dims),
       inserted_window_dims=list(dnums.inserted_window_dims),
       scattered_dims_to_operand_dims=list(dnums.scatter_dims_to_operand_dims),
-      index_vector_dim=len(avals_in[1].shape) - 1)
+      index_vector_dim=len(ctx.avals_in[1].shape) - 1)
   real_dtype = _real_dtype(aval_out.dtype)
   operand_type_part = mlir.aval_to_ir_type(
       core.ShapedArray(aval_out.shape, real_dtype))
@@ -308,14 +308,14 @@ reduce_window_p.def_abstract_eval(_reduce_window_abstract_eval_rule)
 batching.primitive_batchers[reduce_window_p] = _generic_reduce_window_batch_rule
 xla.register_translation(reduce_window_p, _reduce_window_translation_rule)

-def _generic_reduce_window_lower(ctx, avals_in, avals_out, *args, jaxpr, consts,
+def _generic_reduce_window_lower(ctx, *args, jaxpr, consts,
                                  window_dimensions, window_strides, padding,
                                  base_dilation, window_dilation):
   operands, init_values = util.split_list(args, [len(args) // 2])
-  _, init_value_avals = util.split_list(avals_in, [len(operands)])
+  _, init_value_avals = util.split_list(ctx.avals_in, [len(operands)])
   scalar_types = [mlir.aval_to_ir_type(aval) for aval in init_value_avals]
   rw = mhlo.ReduceWindowOp(
-      map(mlir.aval_to_ir_type, avals_out), operands, init_values,
+      map(mlir.aval_to_ir_type, ctx.avals_out), operands, init_values,
       mlir.dense_int_elements(window_dimensions),
       mlir.dense_int_elements(window_strides),
       mlir.dense_int_elements(base_dilation),
@@ -323,7 +323,7 @@ def _generic_reduce_window_lower(ctx, avals_in, avals_out, *args, jaxpr, consts,
       ir.DenseIntElementsAttr.get(np.asarray(padding, np.int64)))
   reducer = rw.regions[0].blocks.append(*(scalar_types + scalar_types))
   with ir.InsertionPoint(reducer):
-    out_nodes = mlir.jaxpr_subcomp(ctx, jaxpr, consts,
+    out_nodes = mlir.jaxpr_subcomp(ctx.module_context, jaxpr, consts,
                                    *([a] for a in reducer.arguments))
     mhlo.ReturnOp(util.flatten(out_nodes))
   return rw.results
@@ -488,10 +488,10 @@ batching.primitive_batchers[reduce_window_min_p] = partial(


 def _reduce_window_lower(
-    reduce_op, init_value, ctx, avals_in, avals_out, operand, *,
+    reduce_op, init_value, ctx, operand, *,
     window_dimensions, window_strides, padding, base_dilation, window_dilation):
-  aval_out, = avals_out
-  operand_aval, = avals_in
+  aval_out, = ctx.avals_out
+  operand_aval, = ctx.avals_in
   scalar_aval = operand_aval.update(shape=())
   scalar_type = mlir.aval_to_ir_type(scalar_aval)
   rw = mhlo.ReduceWindowOp(
@@ -545,11 +545,11 @@ select_and_scatter_p = lax.standard_primitive(
     _select_and_scatter_translation)

 def _select_and_scatter_lower(
-    ctx, avals_in, avals_out, operand, source, init_value, *, select_jaxpr,
+    ctx, operand, source, init_value, *, select_jaxpr,
     select_consts, scatter_jaxpr, scatter_consts, window_dimensions,
     window_strides, padding):
-  operand_aval, source_aval, init_value_aval = avals_in
-  aval_out, = avals_out
+  operand_aval, source_aval, init_value_aval = ctx.avals_in
+  aval_out, = ctx.avals_out
   scalar_aval = operand_aval.update(shape=())
   scalar_type = mlir.aval_to_ir_type(scalar_aval)
   op = mhlo.SelectAndScatterOp(
@@ -559,12 +559,14 @@ def _select_and_scatter_lower(
       ir.DenseIntElementsAttr.get(np.asarray(padding, np.int64)))
   select = op.select.blocks.append(scalar_type, scalar_type)
   with ir.InsertionPoint(select):
-    out_nodes = mlir.jaxpr_subcomp(ctx, select_jaxpr, select_consts,
+    out_nodes = mlir.jaxpr_subcomp(ctx.module_context, select_jaxpr,
+                                   select_consts,
                                    *([a] for a in select.arguments))
     mhlo.ReturnOp(util.flatten(out_nodes))
   scatter = op.scatter.blocks.append(scalar_type, scalar_type)
   with ir.InsertionPoint(scatter):
-    out_nodes = mlir.jaxpr_subcomp(ctx, scatter_jaxpr, scatter_consts,
+    out_nodes = mlir.jaxpr_subcomp(ctx.module_context, scatter_jaxpr,
+                                   scatter_consts,
                                    *([a] for a in scatter.arguments))
     mhlo.ReturnOp(util.flatten(out_nodes))
   return op.results
@@ -679,8 +681,9 @@ xla.register_translation(
     platform='gpu')


-def _select_and_scatter_add_impl(source, operand, *, select_prim, window_dimensions,
-                                 window_strides, padding, expand_padding):
+def _select_and_scatter_add_impl(source, operand, *,
+                                 select_prim, window_dimensions, window_strides,
+                                 padding, expand_padding):
   dtype = source.dtype
   select = lambda x, y: select_prim.bind(x, y)
   scatter = lax.bitwise_or if dtype == np.bool_ else lax.add
@@ -1303,14 +1303,14 @@ def _xmap_lowering_rule(*args, **kwargs):
     return _xmap_lowering_rule_replica(*args, **kwargs)
 mlir.register_lowering(xmap_p, _xmap_lowering_rule)

-def _xmap_lowering_rule_replica(ctx, avals_in, avals_out, *in_nodes,
-                                call_jaxpr, name,
-                                in_axes, out_axes, donated_invars,
-                                global_axis_sizes,
-                                spmd_in_axes, spmd_out_axes,
-                                positional_semantics,
-                                axis_resources, resource_env, backend):
-  xla.check_backend_matches(backend, ctx.platform)
+def _xmap_lowering_rule_replica(ctx, *in_nodes,
+                                call_jaxpr, name,
+                                in_axes, out_axes, donated_invars,
+                                global_axis_sizes,
+                                spmd_in_axes, spmd_out_axes,
+                                positional_semantics,
+                                axis_resources, resource_env, backend):
+  xla.check_backend_matches(backend, ctx.module_context.platform)
   # The only way for any of those two assertions to be violated is when xmap
   # is using the SPMD lowering, but then this rule shouldn't even trigger.
   assert positional_semantics == _PositionalSemantics.LOCAL
@@ -1346,30 +1346,34 @@ def _xmap_lowering_rule_replica(ctx, avals_in, avals_out, *in_nodes,
   assert not consts

   tiled_ins = (
-    mlir.lower_fun(
-        partial(_tile, in_axes=arg_in_axes, axis_sizes=local_mesh_shape),
-        multiple_results=False)(ctx, [aval], None, in_node)[0]
-    if v.aval is not core.abstract_unit else in_node
-    for v, aval, in_node, arg_in_axes
-    in zip(call_jaxpr.invars, avals_in, in_nodes, mesh_in_axes))
+    mlir.lower_fun(partial(_tile, in_axes=arg_in_axes,
+                           axis_sizes=local_mesh_shape),
+                   multiple_results=False)(
+        mlir.LoweringRuleContext(module_context=ctx.module_context,
+                                 avals_in=[aval], avals_out=None),
+        in_node)[0]
+    if v.aval is not core.abstract_unit else in_node
+    for v, aval, in_node, arg_in_axes
+    in zip(call_jaxpr.invars, ctx.avals_in, in_nodes, mesh_in_axes))

   # NOTE: We don't extend the resource env with the mesh shape, because those
   # resources are already in scope! It's the outermost xmap that introduces
   # them!
   # We in-line here rather than generating a Call HLO as in the xla_call
   # translation rule just because the extra tuple stuff is a pain.
-  sub_ctx = ctx.replace(
-      name_stack=xla.extend_name_stack(ctx.name_stack,
+  sub_ctx = ctx.module_context.replace(
+      name_stack=xla.extend_name_stack(ctx.module_context.name_stack,
                                        xla.wrap_name(name, 'xmap')))
   tiled_outs = mlir.jaxpr_subcomp(sub_ctx, vectorized_jaxpr, (), *tiled_ins)

   outs = [
     mlir.lower_fun(
         partial(_untile, out_axes=ans_out_axes, axis_sizes=local_mesh_shape,
-                platform=ctx.platform),
+                platform=ctx.module_context.platform),
         multiple_results=False)(
-            ctx, [vectorized_outvar.aval], None, tiled_out
-        )[0]
+            mlir.LoweringRuleContext(module_context=ctx.module_context,
+                                     avals_in=[vectorized_outvar.aval],
+                                     avals_out=None), tiled_out)[0]
     if v.aval is not core.abstract_unit else tiled_out
     for v, vectorized_outvar, tiled_out, ans_out_axes
     in zip(call_jaxpr.outvars, vectorized_jaxpr.outvars, tiled_outs,
@@ -1377,12 +1381,12 @@ def _xmap_lowering_rule_replica(ctx, avals_in, avals_out, *in_nodes,
   return outs


-def _xmap_lowering_rule_spmd(ctx, avals_in, avals_out, *global_in_nodes,
+def _xmap_lowering_rule_spmd(ctx, *global_in_nodes,
                              call_jaxpr, name, in_axes, out_axes,
                              donated_invars, global_axis_sizes, spmd_in_axes,
                              spmd_out_axes, positional_semantics,
                              axis_resources, resource_env, backend):
-  xla.check_backend_matches(backend, ctx.platform)
+  xla.check_backend_matches(backend, ctx.module_context.platform)
   plan = EvaluationPlan.from_axis_resources(axis_resources, resource_env, global_axis_sizes)

   resource_call_jaxpr = plan.subst_axes_with_resources(call_jaxpr)
@@ -1409,7 +1413,7 @@ def _xmap_lowering_rule_spmd(ctx, avals_in, avals_out, *global_in_nodes,
   # NOTE: We don't extend the resource env with the mesh shape, because those
   # resources are already in scope! It's the outermost xmap that introduces
   # them!
-  global_in_avals = avals_in
+  global_in_avals = ctx.avals_in
   vectorized_jaxpr, global_out_avals, consts = pe.trace_to_jaxpr_dynamic(f, global_in_avals)
   assert not consts

@@ -1422,8 +1426,8 @@ def _xmap_lowering_rule_spmd(ctx, avals_in, avals_out, *global_in_nodes,

   # We in-line here rather than generating a Call HLO as in the xla_call
   # translation rule just because the extra tuple stuff is a pain.
-  sub_ctx = ctx.replace(
-      name_stack=xla.extend_name_stack(ctx.name_stack,
+  sub_ctx = ctx.module_context.replace(
+      name_stack=xla.extend_name_stack(ctx.module_context.name_stack,
                                        xla.wrap_name(name, 'xmap')))
   global_out_nodes = mlir.jaxpr_subcomp(sub_ctx, vectorized_jaxpr, (),
                                         *sharded_global_in_nodes)
@@ -866,9 +866,9 @@ def _sharding_constraint_translation_rule(ctx, avals_in, avals_out, x_node, *,
       ctx.builder, x_node, get_aval_sharding_proto(aval, axis_resources, mesh))]
 xla.register_translation(sharding_constraint_p, _sharding_constraint_translation_rule)

-def _sharding_constraint_mhlo_lowering(ctx, avals_in, avals_out, x_node, *,
-                                       axis_resources, resource_env):
-  aval, = avals_in
+def _sharding_constraint_mhlo_lowering(ctx, x_node, *, axis_resources,
+                                       resource_env):
+  aval, = ctx.avals_in
   mesh = resource_env.physical_mesh
   return [mlir.wrap_with_sharding_op(
       x_node, get_aval_sharding_proto(aval, axis_resources, mesh))]
@@ -280,7 +280,8 @@ def _source_info_to_location(
 # Translation rules

 @dataclasses.dataclass
-class LoweringContext:
+class ModuleContext:
+  """Module-wide context information for MLIR lowering."""
   context: ir.Context
   module: ir.Module
   ip: ir.InsertionPoint
@@ -309,11 +310,19 @@ class LoweringContext:
   def replace(self, **kw): return dataclasses.replace(self, **kw)


+@dataclasses.dataclass
+class LoweringRuleContext:
+  """Per-rule context information for MLIR lowering."""
+  module_context: ModuleContext
+  avals_in: Sequence[core.AbstractValue]
+  avals_out: Any
+
+  def replace(self, **kw): return dataclasses.replace(self, **kw)
+
+
 if not MYPY:
   class LoweringRule(Protocol):
-    def __call__(self, ctx: LoweringContext,
-                 avals_in: Sequence[core.AbstractValue],
-                 avals_out: Sequence[core.AbstractValue],
+    def __call__(self, ctx: LoweringRuleContext,
                  *args: Union[ir.Value, Sequence[ir.Value]],
                  **kw) -> Sequence[Union[ir.Value, Sequence[ir.Value]]]:
       """Converts a JAX primitive invocation into MLIR."""
@@ -367,7 +376,7 @@ def lower_jaxpr_to_module(
     warnings.warn("Some donated buffers were not usable: {}".format(
         ", ".join(unused_donations)))

-  ctx = LoweringContext(platform, axis_env, name_stack)
+  ctx = ModuleContext(platform, axis_env, name_stack)
   with ctx.context, ir.Location.unknown(ctx.context):
     # Some clients expect modules to have unique names, e.g., in trace data.
     # This may or may not be a reasonable assumption.
@@ -407,7 +416,7 @@ def _set_up_aliases(avals_in, avals_out, donated_args):
   return input_output_aliases, out_donated_args

 def lower_jaxpr_to_fun(
-    ctx: LoweringContext, name: str, jaxpr: core.ClosedJaxpr, *,
+    ctx: ModuleContext, name: str, jaxpr: core.ClosedJaxpr, *,
     public: bool = False, replace_units_with_dummy: bool = False,
     replace_tokens_with_dummy: bool = False,
     replicated_args: Optional[Sequence[bool]] = None,
@@ -530,7 +539,7 @@ def lower_jaxpr_to_fun(
   return func_op


-def jaxpr_subcomp(ctx: LoweringContext, jaxpr: core.Jaxpr,
+def jaxpr_subcomp(ctx: ModuleContext, jaxpr: core.Jaxpr,
                   consts: Sequence[Sequence[ir.Value]],
                   *args: Sequence[ir.Value]) -> Sequence[Sequence[ir.Value]]:
   """Lowers a jaxpr into mHLO, inlined into an existing function.
@@ -578,8 +587,10 @@ def jaxpr_subcomp(ctx: LoweringContext, jaxpr: core.Jaxpr,
           f"MLIR translation rule for primitive '{eqn.primitive.name}' not "
           f"found for platform {ctx.platform}")

-    ans = rule(ctx, map(aval, eqn.invars), map(aval, eqn.outvars),
-               *map(_unwrap_singleton_ir_values, in_nodes),
+    rule_ctx = LoweringRuleContext(
+        module_context=ctx, avals_in=map(aval, eqn.invars),
+        avals_out=map(aval, eqn.outvars))
+    ans = rule(rule_ctx, *map(_unwrap_singleton_ir_values, in_nodes),
                **eqn.params)

     try:
@@ -605,15 +616,16 @@ def lower_fun(fun: Callable, multiple_results: bool = True) -> Callable:

   The returned function does not use `avals_out`, so callers may pass any value
   as `avals_out`."""
-  def f_lowered(ctx, avals_in, avals_out, *args, **params):
+  def f_lowered(ctx, *args, **params):
     if multiple_results:
       f = fun
     else:
       f = lambda *args, **kw: (fun(*args, **kw),)
     wrapped_fun = lu.wrap_init(f, params)
-    with core.extend_axis_env_nd(zip(ctx.axis_env.names, ctx.axis_env.sizes)):
-      jaxpr, _, consts = pe.trace_to_jaxpr_dynamic(wrapped_fun, avals_in)
-    return jaxpr_subcomp(ctx, jaxpr, _ir_consts(consts),
+    axis_env = ctx.module_context.axis_env
+    with core.extend_axis_env_nd(zip(axis_env.names, axis_env.sizes)):
+      jaxpr, _, consts = pe.trace_to_jaxpr_dynamic(wrapped_fun, ctx.avals_in)
+    return jaxpr_subcomp(ctx.module_context, jaxpr, _ir_consts(consts),
                          *map(wrap_singleton_ir_values, args))

   return f_lowered
@@ -634,19 +646,20 @@ def _call_lowering(fn_name, stack_name, call_jaxpr, backend, ctx, avals_in,
                    flatten_lowering_ir_args(args))
   return util.unflatten(call.results, map(len, output_types))

-def _xla_call_lower(ctx, avals_in, avals_out, *args,
+def _xla_call_lower(ctx, *args,
                     backend=None, name, call_jaxpr, donated_invars, inline=None,
                     device=None):
   del device, donated_invars, inline  # Ignored.
   return _call_lowering(f"jit_{name}", xla.wrap_name(name, "jit"), call_jaxpr,
-                        backend, ctx, avals_in, avals_out, *args)
+                        backend, ctx.module_context, ctx.avals_in, ctx.avals_out,
+                        *args)

 register_lowering(xla.xla_call_p, _xla_call_lower)

-def _named_call_lowering(ctx, avals_in, avals_out, *args, name, backend=None,
+def _named_call_lowering(ctx, *args, name, backend=None,
                          call_jaxpr):
-  return _call_lowering(name, name, call_jaxpr, backend, ctx, avals_in,
-                        avals_out, *args)
+  return _call_lowering(name, name, call_jaxpr, backend, ctx.module_context,
+                        ctx.avals_in, ctx.avals_out, *args)

 register_lowering(core.named_call_p, _named_call_lowering)
 register_lowering(core.call_p, partial(_named_call_lowering, name="core_call"))
@@ -658,18 +671,17 @@ def full_like_aval(value, aval: core.ShapedArray) -> ir.Value:
   return mhlo.BroadcastOp(aval_to_ir_type(aval), zero,
                           dense_int_elements(aval.shape)).result

-def zeros_like_lowering(ctx, avals_in, avals_out, x):
-  aval, = avals_in
+def zeros_like_lowering(ctx, x):
+  aval, = ctx.avals_in
   assert isinstance(aval, core.ShapedArray), aval
   return [full_like_aval(0, aval)]
 register_lowering(ad_util.zeros_like_p, zeros_like_lowering)

-def add_jaxvals_lowering(ctx, avals_in, avals_out, x, y):
+def add_jaxvals_lowering(ctx, x, y):
   return mhlo.AddOp(x, y).results
 register_lowering(ad_util.add_jaxvals_p, add_jaxvals_lowering)

-register_lowering(ad_util.stop_gradient_p,
-                  lambda ctx, avals_in, avals_out, x: [x])
+register_lowering(ad_util.stop_gradient_p, lambda ctx, x: [x])


 def _minmax_mhlo(op, cmp, x, y):
@@ -735,23 +747,24 @@ def set_sharding(op, sharding_proto: xc.OpSharding):

 # MLIR lowerings for lax primitives

-def xla_fallback_lowering(prim: core.Primitive, ctx: LoweringContext,
-                          avals_in, avals_out, *args, **params):
+def xla_fallback_lowering(prim: core.Primitive, ctx: LoweringRuleContext, *args,
+                          **params):
+  module_ctx = ctx.module_context
   xla_computation = xla.primitive_subcomputation(
-      ctx.platform, ctx.axis_env, prim, *avals_in, **params)
+      module_ctx.platform, module_ctx.axis_env, prim, *ctx.avals_in, **params)
   submodule_str = xc._xla.mlir.xla_computation_to_mlir_module(xla_computation)
   submodule = ir.Module.parse(submodule_str)
   callee_name = None
   for op in submodule.body.operations:
-    ctx.module.body.append(op)
+    module_ctx.module.body.append(op)
     if op.name.value == "main":
       op.attributes["sym_name"] = ir.StringAttr.get(f"xla_fallback_{prim.name}")
-      callee_name = ir.StringAttr(ctx.symbol_table.insert(op)).value
+      callee_name = ir.StringAttr(module_ctx.symbol_table.insert(op)).value
       op.attributes["sym_visibility"] = ir.StringAttr.get("private")
     else:
-      ctx.symbol_table.insert(op)
+      module_ctx.symbol_table.insert(op)

-  output_types = map(aval_to_ir_types, avals_out)
+  output_types = map(aval_to_ir_types, ctx.avals_out)
   flat_output_types = util.flatten(output_types)
   output_type = (ir.TupleType.get_tuple(flat_output_types)
                  if prim.multiple_results else flat_output_types[0])
@@ -1718,19 +1718,20 @@ def _mhlo_unshard(aval, axis_env, out_axis, xs, platform):
     raise TypeError(aval)


-def _pmap_lowering(ctx, avals_in, avals_out, *in_nodes, axis_name,
+def _pmap_lowering(ctx, *in_nodes, axis_name,
                    axis_size, global_axis_size, devices, name,
                    call_jaxpr, backend=None, in_axes, out_axes,
                    donated_invars, global_arg_shapes):
   del donated_invars  # Unused.
-  xla.check_backend_matches(backend, ctx.platform)
+  xla.check_backend_matches(backend, ctx.module_context.platform)
   # We in-line here rather than generating a Call HLO as in the xla_call
   # translation rule just because the extra tuple stuff is a pain.
-  if ctx.axis_env.names and devices is not None:
+  if ctx.module_context.axis_env.names and devices is not None:
     raise ValueError("Nested pmap with explicit devices argument.")
   if global_axis_size is None:
     global_axis_size = axis_size
-  new_env = xla.extend_axis_env(ctx.axis_env, axis_name, global_axis_size)
+  new_env = xla.extend_axis_env(ctx.module_context.axis_env, axis_name,
+                                global_axis_size)
   # Shard the in_nodes that are mapped
   in_avals = [v.aval for v in call_jaxpr.invars]
   in_nodes_sharded = (
@@ -1739,14 +1740,15 @@ def _pmap_lowering(ctx, avals_in, avals_out, *in_nodes, axis_name,
       for aval, in_node, in_axis in zip(in_avals, in_nodes, in_axes))

   with maybe_extend_axis_env(axis_name, global_axis_size, None):  # type: ignore
-    sub_ctx = ctx.replace(
+    sub_ctx = ctx.module_context.replace(
         axis_env=new_env,
-        name_stack=xla.extend_name_stack(ctx.name_stack,
+        name_stack=xla.extend_name_stack(ctx.module_context.name_stack,
                                          util.wrap_name(name, 'pmap')))
     sharded_outs = mlir.jaxpr_subcomp(sub_ctx, call_jaxpr, (),
                                       *in_nodes_sharded)
   out_avals = [v.aval for v in call_jaxpr.outvars]
-  outs = [_mhlo_unshard(aval, new_env, out_axis, shard, platform=ctx.platform)
+  outs = [_mhlo_unshard(aval, new_env, out_axis, shard,
+                        platform=ctx.module_context.platform)
           for aval, out_axis, shard in zip(out_avals, out_axes, sharded_outs)]
   return outs

@@ -214,7 +214,7 @@ def _sharded_jit_translation_rule(ctx, avals_in, avals_out, *in_nodes,
             xops.Call(ctx.builder, subc, list(in_nodes)))


-def _sharded_jit_lowering(ctx, avals_in, avals_out, *in_nodes,
+def _sharded_jit_lowering(ctx, *in_nodes,
                           in_parts, out_parts_thunk, nparts,
                           name, call_jaxpr, local_in_parts,
                           local_out_parts_thunk, local_nparts):
@@ -233,12 +233,12 @@ def _sharded_jit_lowering(ctx, avals_in, avals_out, *in_nodes,
     else:
       args.append(ns)

-  sub_ctx = ctx.replace(
+  sub_ctx = ctx.module_context.replace(
       name_stack=extend_name_stack(wrap_name(name, "sharded_jit")))
   fn = mlir.lower_jaxpr_to_fun(sub_ctx, f"sharded_jit_{name}",
                                core.ClosedJaxpr(call_jaxpr, ()))

-  output_types = safe_map(mlir.aval_to_ir_types, avals_out)
+  output_types = safe_map(mlir.aval_to_ir_types, ctx.avals_out)
   flat_output_types = util.flatten(output_types)
   call = std.CallOp(flat_output_types,
                     ir.FlatSymbolRefAttr.get(fn.name.value),
@@ -472,8 +472,7 @@ ad.deflinear2(sharding_constraint_p,
 xla.register_translation(sharding_constraint_p,
                          _sharding_constraint_translation_rule)

-def _sharding_constraint_lowering(ctx, avals_in, avals_out, x_node,
-                                  partitions):
+def _sharding_constraint_lowering(ctx, x_node, partitions):
   return [mlir.wrap_with_sharding_op(x_node, xla.sharding_to_proto(partitions))]

 mlir.register_lowering(sharding_constraint_p, _sharding_constraint_lowering)
@@ -170,7 +170,7 @@ def _sp_indices_translation_rule(ctx, avals_in, avals_out, data, indices):
 # because it leads to infinite recursion.
 xla.register_translation(sp_indices_p, _sp_indices_translation_rule)

-def _sp_indices_mhlo_lowering(ctx, avals_in, avals_out, data_and_indices):
+def _sp_indices_mhlo_lowering(ctx, data_and_indices):
   return [data_and_indices[1]]

 mlir.register_lowering(sp_indices_p, _sp_indices_mhlo_lowering)
@@ -192,7 +192,7 @@ def _sp_data_translation_rule(ctx, avals_in, avals_out, data, indices):
 # because it leads to infinite recursion.
 xla.register_translation(sp_data_p, _sp_data_translation_rule)

-def _sp_data_mhlo_lowering(ctx, avals_in, avals_out, data_and_indices):
+def _sp_data_mhlo_lowering(ctx, data_and_indices):
   return [data_and_indices[0]]

 mlir.register_lowering(sp_data_p, _sp_data_mhlo_lowering)
|
Loading…
x
Reference in New Issue
Block a user