diff --git a/jax/_src/dispatch.py b/jax/_src/dispatch.py index 52603b3ce..0d2f65fde 100644 --- a/jax/_src/dispatch.py +++ b/jax/_src/dispatch.py @@ -738,7 +738,7 @@ ad.deflinear2(device_put_p, lambda cotangent, _, **kwargs: [cotangent]) masking.defvectorized(device_put_p) batching.defvectorized(device_put_p) -def _device_put_lowering(ctx, avals_in, avals_out, x, *, device): +def _device_put_lowering(ctx, x, *, device): return [x] diff --git a/jax/_src/lax/control_flow.py b/jax/_src/lax/control_flow.py index b0fcb35ec..4a6013869 100644 --- a/jax/_src/lax/control_flow.py +++ b/jax/_src/lax/control_flow.py @@ -640,7 +640,7 @@ def _pred_bcast_select_mhlo( return mhlo.SelectOp(bcast_pred, x, y).results -def _while_lowering(ctx, avals_in, avals_out, *args, cond_jaxpr, +def _while_lowering(ctx, *args, cond_jaxpr, body_jaxpr, cond_nconsts, body_nconsts): pred_aval = cond_jaxpr.out_avals[0] batched = bool(pred_aval.shape) @@ -651,7 +651,7 @@ def _while_lowering(ctx, avals_in, avals_out, *args, cond_jaxpr, # computation), we build XLA computations that handle the tuple munging before # generating a Call into the computations formed from the jaxprs. - loop_carry_types = _map(mlir.aval_to_ir_types, avals_in) + loop_carry_types = _map(mlir.aval_to_ir_types, ctx.avals_in) flat_loop_carry_types = util.flatten(loop_carry_types) loop_carry_tuple_type = ir.TupleType.get_tuple(flat_loop_carry_types) @@ -668,15 +668,17 @@ def _while_lowering(ctx, avals_in, avals_out, *args, cond_jaxpr, for i, input_type in enumerate(flat_loop_carry_types)] cond_args = util.unflatten(flat_cond_args, _map(len, loop_carry_types)) x, _, z = util.split_list(cond_args, [cond_nconsts, body_nconsts]) - cond_ctx = ctx.replace( - name_stack=xla.extend_name_stack(ctx.name_stack, 'cond')) + cond_ctx = ctx.module_context.replace( + name_stack=xla.extend_name_stack(ctx.module_context.name_stack, 'cond')) (pred,), = mlir.jaxpr_subcomp( cond_ctx, cond_jaxpr.jaxpr, _map(mlir.ir_constants, cond_jaxpr.consts), *(x + z)) if batched: + pred_ctx = mlir.LoweringRuleContext( + module_context=ctx.module_context, avals_in=[pred_aval], + avals_out=[pred_aval.update(shape=())]) pred, = lax._unary_reduce_lower( - mhlo.OrOp, lambda dtype: np.array(False, dtype), ctx, [pred_aval], - [pred_aval.update(shape=())], pred, + mhlo.OrOp, lambda dtype: np.array(False, dtype), pred_ctx, pred, axes=tuple(range(len(pred_aval.shape)))) mhlo.ReturnOp([pred]) @@ -689,14 +691,15 @@ def _while_lowering(ctx, avals_in, avals_out, *args, cond_jaxpr, for i, input_type in enumerate(flat_loop_carry_types)] body_args = util.unflatten(flat_body_args, _map(len, loop_carry_types)) x, y, z = util.split_list(body_args, [cond_nconsts, body_nconsts]) - body_ctx = ctx.replace( - name_stack=xla.extend_name_stack(ctx.name_stack, 'body')) + body_ctx = ctx.module_context.replace( + name_stack=xla.extend_name_stack(ctx.module_context.name_stack, 'body')) new_z = mlir.jaxpr_subcomp( body_ctx, body_jaxpr.jaxpr, _map(mlir.ir_constants, body_jaxpr.consts), *(y + z)) if batched: - body_pred_ctx = ctx.replace( - name_stack=xla.extend_name_stack(ctx.name_stack, 'body_pred')) + body_pred_ctx = ctx.module_context.replace( + name_stack=xla.extend_name_stack(ctx.module_context.name_stack, + 'body_pred')) (body_pred,), = mlir.jaxpr_subcomp( body_pred_ctx, cond_jaxpr.jaxpr, _map(mlir.ir_constants, cond_jaxpr.consts), *(x + z)) @@ -1318,11 +1321,11 @@ xla.register_translation(cond_p, _cond_translation_rule, initial_style=True) core.custom_typechecks[cond_p] = _cond_typecheck 
pe.partial_eval_jaxpr_custom_rules[cond_p] = pe.partial_eval_jaxpr_custom_rule_not_implemented -def _cond_lowering(ctx, avals_in, avals_out, index, *args, branches, linear): +def _cond_lowering(ctx, index, *args, branches, linear): del linear # Unused. - arg_avals = avals_in[1:] + arg_avals = ctx.avals_in[1:] input_types = _map(mlir.aval_to_ir_types, arg_avals) - output_types = _map(mlir.aval_to_ir_types, avals_out) + output_types = _map(mlir.aval_to_ir_types, ctx.avals_out) flat_input_types = util.flatten(input_types) flat_output_types = util.flatten(output_types) input_tuple_type = ir.TupleType.get_tuple(flat_input_types) @@ -1340,8 +1343,8 @@ def _cond_lowering(ctx, avals_in, avals_out, index, *args, branches, linear): mlir.i32_attr(i)).result for i, input_type in enumerate(flat_input_types)] unflattened_args = util.unflatten(args, _map(len, input_types)) - out_vals = mlir.jaxpr_subcomp(ctx, jaxpr.jaxpr, jaxpr.consts, - *unflattened_args) + out_vals = mlir.jaxpr_subcomp(ctx.module_context, jaxpr.jaxpr, + jaxpr.consts, *unflattened_args) out = mhlo.TupleOp(output_tuple_type, util.flatten(out_vals)).results mhlo.ReturnOp(out) diff --git a/jax/_src/lax/convolution.py b/jax/_src/lax/convolution.py index a9fa58d6e..3f4332acc 100644 --- a/jax/_src/lax/convolution.py +++ b/jax/_src/lax/convolution.py @@ -723,12 +723,12 @@ def _complex_mul(mul, x, y): _real_dtype = lambda dtype: np.finfo(dtype).dtype def _conv_general_dilated_lower( - ctx, avals_in, avals_out, lhs, rhs, *, window_strides, padding, + ctx, lhs, rhs, *, window_strides, padding, lhs_dilation, rhs_dilation, dimension_numbers, feature_group_count, batch_group_count, precision, preferred_element_type, expand_complex_convolutions=False, **unused_kwargs): - lhs_aval, rhs_aval = avals_in - aval_out, = avals_out + lhs_aval, rhs_aval = ctx.avals_in + aval_out, = ctx.avals_out assert isinstance(dimension_numbers, ConvDimensionNumbers) dtype = lhs_aval.dtype if expand_complex_convolutions and np.issubdtype(dtype, np.complexfloating): @@ -746,7 +746,7 @@ def _conv_general_dilated_lower( batch_group_count=batch_group_count, precision=precision, preferred_element_type=preferred_element_type)), multiple_results=False) - return complex_conv(ctx, avals_in, avals_out, lhs, rhs) + return complex_conv(ctx, lhs, rhs) lhs_spec, rhs_spec, out_spec = dimension_numbers dnums = mhlo.ConvDimensionNumbers.get( diff --git a/jax/_src/lax/lax.py b/jax/_src/lax/lax.py index d567cd747..93b1fd9c2 100644 --- a/jax/_src/lax/lax.py +++ b/jax/_src/lax/lax.py @@ -1444,9 +1444,7 @@ def broadcast_mhlo( out.append(arg) return out -def _nary_lower_mhlo(op: Callable, ctx: mlir.LoweringContext, - avals_in: Sequence[core.ShapedArray], - avals_out: Sequence[core.ShapedArray], +def _nary_lower_mhlo(op: Callable, ctx, *args: Union[ir.Value, Sequence[ir.Value]], explicit_type=False, **params): """Lowers an elementwise operator to its MHLO/CHLO equivalent. @@ -1456,8 +1454,8 @@ def _nary_lower_mhlo(op: Callable, ctx: mlir.LoweringContext, provided? 
""" del params - aval_out, = avals_out - broadcasted_args = broadcast_mhlo(aval_out, avals_in, args) + aval_out, = ctx.avals_out + broadcasted_args = broadcast_mhlo(aval_out, ctx.avals_in, args) if explicit_type: return op(mlir.aval_to_ir_type(aval_out), *broadcasted_args).results else: @@ -1494,8 +1492,8 @@ def _sign_translation_rule(ctx, avals_in, avals_out, x): sign_p = standard_unop(_num, 'sign', translation_rule=_sign_translation_rule) ad.defjvp_zero(sign_p) -def _sign_lower_mhlo(ctx, avals_in, avals_out, x): - x_aval, = avals_in +def _sign_lower_mhlo(ctx, x): + x_aval, = ctx.avals_in if dtypes.issubdtype(x_aval.dtype, np.unsignedinteger): return mhlo.SelectOp( mhlo.CompareOp( @@ -1546,14 +1544,14 @@ round_p = standard_unop(_float, 'round') xla.register_translation(round_p, _round_translation_rule) ad.defjvp_zero(round_p) -def _round_lower(ctx, avals_in, avals_out, x, *, rounding_method): +def _round_lower(ctx, x, *, rounding_method): if rounding_method is RoundingMethod.AWAY_FROM_ZERO: return mhlo.RoundOp(x).results else: assert rounding_method is RoundingMethod.TO_NEAREST_EVEN round_nearest = mlir.lower_fun(_round_to_nearest_even, multiple_results=False) - return round_nearest(ctx, avals_in, avals_out, x) + return round_nearest(ctx, x) mlir.register_lowering(round_p, _round_lower) is_finite_p = unop(_fixed_dtype(np.bool_), _float, 'is_finite') @@ -2113,10 +2111,11 @@ ad.defjvp_zero(shift_right_logical_p) mlir.register_lowering(shift_right_logical_p, partial(_nary_lower_mhlo, mhlo.ShiftRightLogicalOp)) -def _compare_lower_mhlo(direction: str, ctx, avals_in, avals_out, x, y): - x_aval, y_aval = avals_in - aval_out, = avals_out - x, y = broadcast_mhlo(aval_out.update(dtype=x_aval.dtype), avals_in, (x, y)) +def _compare_lower_mhlo(direction: str, ctx, x, y): + x_aval, y_aval = ctx.avals_in + aval_out, = ctx.avals_out + x, y = broadcast_mhlo(aval_out.update(dtype=x_aval.dtype), ctx.avals_in, + (x, y)) if dtypes.issubdtype(x_aval.dtype, np.inexact): compare_type = "FLOAT" elif dtypes.issubdtype(x_aval.dtype, np.signedinteger): @@ -2222,10 +2221,9 @@ pe.forwarding_rules[convert_element_type_p] = _convert_elt_type_fwd_rule def _real_dtype(dtype): return np.finfo(dtype).dtype -def _convert_element_type_lower(ctx, avals_in, avals_out, operand, *, - new_dtype, weak_type): - aval_in, = avals_in - aval_out, = avals_out +def _convert_element_type_lower(ctx, operand, *, new_dtype, weak_type): + aval_in, = ctx.avals_in + aval_out, = ctx.avals_out if (dtypes.issubdtype(aval_in.dtype, np.complexfloating) and not dtypes.issubdtype(new_dtype, np.complexfloating)): operand = mhlo.RealOp(operand).result @@ -2260,9 +2258,8 @@ ad.defjvp_zero(bitcast_convert_type_p) batching.defvectorized(bitcast_convert_type_p) masking.defvectorized(bitcast_convert_type_p) -def _bitcast_convert_type_lower(ctx, avals_in, avals_out, operand, *, - new_dtype): - aval_out, = avals_out +def _bitcast_convert_type_lower(ctx, operand, *, new_dtype): + aval_out, = ctx.avals_out return mhlo.BitcastConvertOp(mlir.aval_to_ir_type(aval_out), operand).results mlir.register_lowering(bitcast_convert_type_p, _bitcast_convert_type_lower) @@ -2522,15 +2519,15 @@ def precision_attr(precision: PrecisionType) -> ir.ArrayAttr: precision = (precision, precision) return ir.ArrayAttr.get([ir.StringAttr.get(str(p)) for p in precision]) -def _dot_general_lower(ctx, avals_in, avals_out, lhs, rhs, *, dimension_numbers, +def _dot_general_lower(ctx, lhs, rhs, *, dimension_numbers, precision, preferred_element_type: Optional[np.dtype]): del 
preferred_element_type # Implied by the output aval - lhs_aval, rhs_aval = avals_in - aval_out, = avals_out + lhs_aval, rhs_aval = ctx.avals_in + aval_out, = ctx.avals_out (lhs_contracting, rhs_contracting), (lhs_batch, rhs_batch) = dimension_numbers # TODO(b/195364460): Work around slow XLA/CPU implementation of float16 matmul - if ctx.platform == "cpu": + if ctx.module_context.platform == "cpu": if lhs_aval.dtype == np.float16: f32 = mlir.dtype_to_ir_type(np.dtype(np.float32)) lhs = mhlo.ConvertOp(ir.RankedTensorType.get(lhs_aval.shape, f32), @@ -2614,10 +2611,9 @@ ad.deflinear2(broadcast_in_dim_p, _broadcast_in_dim_transpose_rule) batching.primitive_batchers[broadcast_in_dim_p] = _broadcast_in_dim_batch_rule pe.forwarding_rules[broadcast_in_dim_p] = _broadcast_in_dim_fwd_rule -def _broadcast_in_dim_lower(ctx, avals_in, avals_out, x, *, shape, - broadcast_dimensions): +def _broadcast_in_dim_lower(ctx, x, *, shape, broadcast_dimensions): del shape - aval_out, = avals_out + aval_out, = ctx.avals_out return mhlo.BroadcastInDimOp( mlir.aval_to_ir_type(aval_out), x, mlir.dense_int_elements(broadcast_dimensions) @@ -2760,7 +2756,7 @@ ad.deflinear2(concatenate_p, _concatenate_transpose_rule) ad.primitive_transposes[concatenate_p] = _concatenate_transpose_rule batching.primitive_batchers[concatenate_p] = _concatenate_batch_rule -def _concatenate_lower(ctx, avals_in, avals_out, *xs, dimension): +def _concatenate_lower(ctx, *xs, dimension): return mhlo.ConcatenateOp(xs, mlir.i64_attr(dimension)).results mlir.register_lowering(concatenate_p, _concatenate_lower) @@ -2852,8 +2848,8 @@ ad.deflinear2(pad_p, _pad_transpose) batching.primitive_batchers[pad_p] = _pad_batch_rule masking.masking_rules[pad_p] = _pad_masking_rule -def _pad_lower(ctx, avals_in, avals_out, x, padding_value, *, padding_config): - aval_out, = avals_out +def _pad_lower(ctx, x, padding_value, *, padding_config): + aval_out, = ctx.avals_out low, high, interior = util.unzip3(padding_config) return mhlo.PadOp(mlir.aval_to_ir_type(aval_out), x, padding_value, mlir.dense_int_elements(low), @@ -2910,9 +2906,9 @@ squeeze_p = standard_primitive(_squeeze_shape_rule, _squeeze_dtype_rule, ad.deflinear2(squeeze_p, _squeeze_transpose_rule) batching.primitive_batchers[squeeze_p] = _squeeze_batch_rule -def _squeeze_lower(ctx, avals_in, avals_out, operand, *, dimensions): +def _squeeze_lower(ctx, operand, *, dimensions): del dimensions # Implied by the output aval. 
- aval_out, = avals_out + aval_out, = ctx.avals_out return mhlo.ReshapeOp(mlir.aval_to_ir_type(aval_out), operand).results mlir.register_lowering(squeeze_p, _squeeze_lower) @@ -3014,9 +3010,9 @@ ad.deflinear2(reshape_p, _reshape_transpose_rule) batching.primitive_batchers[reshape_p] = _reshape_batch_rule masking.masking_rules[reshape_p] = _reshape_masking_rule -def _reshape_lower(ctx, avals_in, avals_out, x, *, new_sizes, dimensions): - aval_in, = avals_in - aval_out, = avals_out +def _reshape_lower(ctx, x, *, new_sizes, dimensions): + aval_in, = ctx.avals_in + aval_out, = ctx.avals_out if dimensions is not None: aval = core.ShapedArray(np.take(aval_in.shape, dimensions), aval_in.dtype) x = mhlo.TransposeOp(mlir.aval_to_ir_type(aval), x, @@ -3045,7 +3041,7 @@ rev_p = standard_primitive(_rev_shape_rule, _input_dtype, 'rev') ad.deflinear2(rev_p, lambda t, _, dimensions: [rev(t, dimensions)]) batching.primitive_batchers[rev_p] = _rev_batch_rule -def _rev_lower(ctx, avals_in, avals_out, x, *, dimensions): +def _rev_lower(ctx, x, *, dimensions): return mhlo.ReverseOp(x, mlir.dense_int_elements(dimensions)).results mlir.register_lowering(rev_p, _rev_lower) @@ -3076,8 +3072,8 @@ ad.deflinear2(transpose_p, batching.primitive_batchers[transpose_p] = _transpose_batch_rule masking.masking_rules[transpose_p] = _transpose_masking_rule -def _transpose_lower(ctx, avals_in, avals_out, x, *, permutation): - aval_out, = avals_out +def _transpose_lower(ctx, x, *, permutation): + aval_out, = ctx.avals_out return mhlo.TransposeOp(mlir.aval_to_ir_type(aval_out), x, mlir.dense_int_elements(permutation)).results mlir.register_lowering(transpose_p, _transpose_lower) @@ -3338,18 +3334,17 @@ xla.register_translation(reduce_p, _reduce_translation_rule) batching.primitive_batchers[reduce_p] = _reduce_batch_rule ad.primitive_jvps[reduce_p] = _reduce_jvp_rule -def _reduce_lower(ctx, avals_in, avals_out, *values, computation, jaxpr, - consts, dimensions): - assert all(isinstance(x, core.ShapedArray) for x in avals_in), avals_in +def _reduce_lower(ctx, *values, computation, jaxpr, consts, dimensions): + assert all(isinstance(x, core.ShapedArray) for x in ctx.avals_in), ctx.avals_in operands, init_values = util.split_list(values, [len(values) // 2]) - init_value_avals = avals_in[len(values) // 2:] - op = mhlo.ReduceOp([mlir.aval_to_ir_type(aval) for aval in avals_out], + init_value_avals = ctx.avals_in[len(values) // 2:] + op = mhlo.ReduceOp([mlir.aval_to_ir_type(aval) for aval in ctx.avals_out], operands, init_values, mlir.dense_int_elements(dimensions)) ir_types = [mlir.aval_to_ir_type(aval) for aval in init_value_avals] reducer = op.regions[0].blocks.append(*(ir_types + ir_types)) with ir.InsertionPoint(reducer): - ctx = ctx.replace(name_stack='') - out_nodes = mlir.jaxpr_subcomp(ctx, jaxpr, consts, + reducer_ctx = ctx.module_context.replace(name_stack='') + out_nodes = mlir.jaxpr_subcomp(reducer_ctx, jaxpr, consts, *([a] for a in reducer.arguments)) mhlo.ReturnOp(util.flatten(out_nodes)) return op.results @@ -3566,9 +3561,8 @@ reduce_and_p = standard_primitive(_reduce_logical_shape_rule, _fixed_dtype(np.bo batching.defreducer(reduce_and_p) -def _unary_reduce_lower(reducer, unit_factory, ctx, avals_in, avals_out, x, *, - axes): - aval_out, = avals_out +def _unary_reduce_lower(reducer, unit_factory, ctx, x, *, axes): + aval_out, = ctx.avals_out dtype = aval_out.dtype op = mhlo.ReduceOp([mlir.aval_to_ir_type(aval_out)], [x], mlir.ir_constants(unit_factory(aval_out.dtype)), @@ -3613,9 +3607,8 @@ reduce_precision_p = 
standard_primitive( batching.defvectorized(reduce_precision_p) masking.defvectorized(reduce_precision_p) -def _reduce_precision_lower(ctx, avals_in, avals_out, operand, *, exponent_bits, - mantissa_bits): - aval_out, = avals_out +def _reduce_precision_lower(ctx, operand, *, exponent_bits, mantissa_bits): + aval_out, = ctx.avals_out return mhlo.ReducePrecisionOp(mlir.aval_to_ir_type(aval_out), operand, mlir.i32_attr(exponent_bits), mlir.i32_attr(mantissa_bits)).results @@ -3754,22 +3747,24 @@ ad.primitive_jvps[sort_p] = _sort_jvp batching.primitive_batchers[sort_p] = _sort_batch_rule -def _sort_lower(ctx, avals_in, avals_out, *operands, dimension, is_stable, - num_keys): - assert all(isinstance(x, core.ShapedArray) for x in avals_in), avals_in - sort = mhlo.SortOp([mlir.aval_to_ir_type(aval) for aval in avals_out], +def _sort_lower(ctx, *operands, dimension, is_stable, num_keys): + assert all(isinstance(x, core.ShapedArray) for x in ctx.avals_in), ctx.avals_in + sort = mhlo.SortOp([mlir.aval_to_ir_type(aval) for aval in ctx.avals_out], mlir.flatten_lowering_ir_args(operands), mlir.i64_attr(dimension), ir.BoolAttr.get(is_stable)) - scalar_avals = [aval.update(shape=()) for aval in avals_in] + scalar_avals = [aval.update(shape=()) for aval in ctx.avals_in] scalar_types = safe_map(mlir.aval_to_ir_type, scalar_avals) comparator = sort.comparator.blocks.append( *util.flatten(zip(scalar_types, scalar_types))) with ir.InsertionPoint(comparator): lower_comparator = mlir.lower_fun(partial(_sort_lt_comparator), multiple_results=False) - out = lower_comparator(ctx, util.flatten(zip(scalar_avals, scalar_avals)), - [core.ShapedArray((), np.bool_)], - *[[a] for a in comparator.arguments], + sub_ctx = mlir.LoweringRuleContext( + module_context = ctx.module_context, + avals_in=util.flatten(zip(scalar_avals, scalar_avals)), + avals_out=[core.ShapedArray((), np.bool_)]) + + out = lower_comparator(sub_ctx, *[[a] for a in comparator.arguments], num_keys=num_keys) mhlo.ReturnOp(util.flatten(out)) return sort.results @@ -3870,8 +3865,8 @@ create_token_p.def_abstract_eval(lambda *_: abstract_token) xla.register_translation(create_token_p, lambda ctx, *_: [xops.CreateToken(ctx.builder)]) -def _create_token_lowering(ctx, avals_in, avals_out, *operands): - aval_out, = avals_out +def _create_token_lowering(ctx, *operands): + aval_out, = ctx.avals_out return mhlo.CreateTokenOp(mlir.aval_to_ir_type(aval_out)).results mlir.register_lowering(create_token_p, _create_token_lowering) @@ -3897,8 +3892,8 @@ after_all_p.def_impl(partial(xla.apply_primitive, after_all_p)) after_all_p.def_abstract_eval(_after_all_abstract_eval) xla.register_translation(after_all_p, _after_all_translation_rule) -def _after_all_lowering(ctx, avals_in, avals_out, *operands): - aval_out, = avals_out +def _after_all_lowering(ctx, *operands): + aval_out, = ctx.avals_out return mhlo.AfterAllOp(mlir.aval_to_ir_type(aval_out), operands).results mlir.register_lowering(after_all_p, _after_all_lowering) @@ -3956,8 +3951,8 @@ infeed_p.def_abstract_eval(_infeed_abstract_eval) xla.register_translation(infeed_p, _infeed_translation_rule) -def _infeed_lowering(ctx, avals_in, avals_out, token, *, shapes, partitions): - output_types = safe_map(mlir.aval_to_ir_types, avals_out[:-1]) +def _infeed_lowering(ctx, token, *, shapes, partitions): + output_types = safe_map(mlir.aval_to_ir_types, ctx.avals_out[:-1]) flat_output_types = util.flatten(output_types) output_tuple_type = ir.TupleType.get_tuple(flat_output_types) # TODO(phawkins): verify `shapes` have a 
major-to-minor layout. @@ -4020,9 +4015,9 @@ outfeed_p.def_abstract_eval(_outfeed_abstract_eval) xla.register_translation(outfeed_p, _outfeed_translation_rule) -def _outfeed_lowering(ctx, avals_in, avals_out, token, *xs, partitions): - token_aval = avals_in[0] - xs_avals = avals_in[1:] +def _outfeed_lowering(ctx, token, *xs, partitions): + token_aval = ctx.avals_in[0] + xs_avals = ctx.avals_in[1:] input_types = map(mlir.aval_to_ir_types, xs_avals) flat_input_types = util.flatten(input_types) input_tuple_type = ir.TupleType.get_tuple(flat_input_types) @@ -4070,10 +4065,10 @@ rng_uniform_p.def_impl(partial(xla.apply_primitive, rng_uniform_p)) rng_uniform_p.def_abstract_eval(_rng_uniform_abstract_eval) xla.register_translation(rng_uniform_p, _rng_uniform_translation_rule) -def _rng_uniform_lowering(ctx, avals_in, avals_out, a, b, *, shape): - aval_out, = avals_out +def _rng_uniform_lowering(ctx, a, b, *, shape): + aval_out, = ctx.avals_out shape, = mlir.ir_constants(np.array(aval_out.shape, np.int64), - canonicalize_types=False) + canonicalize_types=False) return mhlo.RngUniformOp(a, b, shape).results mlir.register_lowering(rng_uniform_p, _rng_uniform_lowering) @@ -4177,9 +4172,9 @@ iota_p.def_impl(partial(xla.apply_primitive, iota_p)) iota_p.def_abstract_eval(_iota_abstract_eval) xla.register_translation(iota_p, _iota_translation_rule) -def _iota_lower(ctx, avals_in, avals_out, *, dtype, shape, dimension): +def _iota_lower(ctx, *, dtype, shape, dimension): del dtype, shape - aval_out, = avals_out + aval_out, = ctx.avals_out return mhlo.IotaOp(mlir.aval_to_ir_type(aval_out), mlir.i64_attr(dimension)).results mlir.register_lowering(iota_p, _iota_lower) diff --git a/jax/_src/lax/slicing.py b/jax/_src/lax/slicing.py index c53987566..9a5873739 100644 --- a/jax/_src/lax/slicing.py +++ b/jax/_src/lax/slicing.py @@ -747,9 +747,8 @@ ad.deflinear2(slice_p, _slice_transpose_rule) batching.primitive_batchers[slice_p] = _slice_batching_rule masking.masking_rules[slice_p] = _slice_masking_rule -def _slice_lower(ctx, avals_in, avals_out, x, *, start_indices, - limit_indices, strides): - aval_out, = avals_out +def _slice_lower(ctx, x, *, start_indices, limit_indices, strides): + aval_out, = ctx.avals_out strides = strides or [1] * len(start_indices) return mhlo.SliceOp(x, mlir.dense_int_elements(start_indices), @@ -848,9 +847,8 @@ ad.primitive_jvps[dynamic_slice_p] = _dynamic_slice_jvp # TODO ad.primitive_transposes[dynamic_slice_p] = _dynamic_slice_transpose_rule batching.primitive_batchers[dynamic_slice_p] = _dynamic_slice_batching_rule -def _dynamic_slice_lower(ctx, avals_in, avals_out, x, *start_indices, - slice_sizes): - aval_out, = avals_out +def _dynamic_slice_lower(ctx, x, *start_indices, slice_sizes): + aval_out, = ctx.avals_out return mhlo.DynamicSliceOp(mlir.aval_to_ir_type(aval_out), x, start_indices, mlir.dense_int_elements(slice_sizes)).results @@ -947,9 +945,8 @@ ad.primitive_transposes[dynamic_update_slice_p] = \ batching.primitive_batchers[dynamic_update_slice_p] = \ _dynamic_update_slice_batching_rule -def _dynamic_update_slice_lower(ctx, avals_in, avals_out, x, update, - *start_indices): - aval_out, = avals_out +def _dynamic_update_slice_lower(ctx, x, update, *start_indices): + aval_out, = ctx.avals_out return mhlo.DynamicUpdateSliceOp(mlir.aval_to_ir_type(aval_out), x, update, start_indices).results @@ -1280,14 +1277,14 @@ batching.primitive_batchers[gather_p] = _gather_batching_rule -def _gather_lower(ctx, avals_in, avals_out, operand, indices, *, +def _gather_lower(ctx, operand, 
indices, *, dimension_numbers, slice_sizes, unique_indices, indices_are_sorted, mode, fill_value): - aval_out, = avals_out + aval_out, = ctx.avals_out if mode == GatherScatterMode.FILL_OR_DROP: gather_fill_fn = mlir.lower_fun(_gather_fill, multiple_results=False) return gather_fill_fn( - ctx, avals_in, avals_out, operand, indices, + ctx, operand, indices, dimension_numbers=dimension_numbers, slice_sizes=slice_sizes, unique_indices=unique_indices, indices_are_sorted=indices_are_sorted, fill_value=fill_value, output_shape=aval_out.shape) @@ -1296,7 +1293,7 @@ def _gather_lower(ctx, avals_in, avals_out, operand, indices, *, GatherScatterMode.CLIP), mode dnums = mhlo.GatherDimensionNumbers.get( collapsed_slice_dims=list(dimension_numbers.collapsed_slice_dims), - index_vector_dim=len(avals_in[1].shape) - 1, + index_vector_dim=len(ctx.avals_in[1].shape) - 1, offset_dims=list(dimension_numbers.offset_dims), start_index_map=list(dimension_numbers.start_index_map)) return mhlo.GatherOp(operand, indices, dnums, @@ -1945,30 +1942,30 @@ batching.primitive_batchers[scatter_p] = ( -def _scatter_lower(ctx, avals_in, avals_out, operand, indices, updates, *, +def _scatter_lower(ctx, operand, indices, updates, *, update_jaxpr, update_consts, dimension_numbers, indices_are_sorted, unique_indices, mode): if mode == GatherScatterMode.CLIP: clip_fn = mlir.lower_fun(_clamp_scatter_indices, multiple_results=False) - (indices,), = clip_fn(ctx, avals_in, None, operand, indices, updates, - dnums=dimension_numbers) + (indices,), = clip_fn(ctx.replace(avals_out=None), operand, indices, + updates, dnums=dimension_numbers) - aval_out, = avals_out + aval_out, = ctx.avals_out dnums = dimension_numbers scatter_dnums = mhlo.ScatterDimensionNumbers.get( update_window_dims=list(dnums.update_window_dims), inserted_window_dims=list(dnums.inserted_window_dims), scattered_dims_to_operand_dims=list(dnums.scatter_dims_to_operand_dims), - index_vector_dim=len(avals_in[1].shape) - 1) + index_vector_dim=len(ctx.avals_in[1].shape) - 1) op = mhlo.ScatterOp(mlir.aval_to_ir_type(aval_out), operand, indices, updates, scatter_dnums, ir.BoolAttr.get(indices_are_sorted), ir.BoolAttr.get(unique_indices)) scalar_type = mlir.aval_to_ir_type(core.ShapedArray((), aval_out.dtype)) update = op.update_computation.blocks.append(scalar_type, scalar_type) with ir.InsertionPoint(update): - ctx = ctx.replace(name_stack='') + update_ctx = ctx.module_context.replace(name_stack='') out_nodes = mlir.jaxpr_subcomp( - ctx, update_jaxpr, update_consts, + update_ctx, update_jaxpr, update_consts, (update.arguments[0],), (update.arguments[1],)) mhlo.ReturnOp(util.flatten(out_nodes)) return op.results @@ -1982,12 +1979,12 @@ mlir.register_lowering(scatter_max_p, _scatter_lower) def _real_dtype(dtype): return np.finfo(dtype).dtype -def _scatter_add_lower_gpu(ctx, avals_in, avals_out, operand, indices, updates, +def _scatter_add_lower_gpu(ctx, operand, indices, updates, *, update_jaxpr, update_consts, dimension_numbers, indices_are_sorted, unique_indices, mode): - operand_aval_in, _, updates_aval_in = avals_in + operand_aval_in, _, updates_aval_in = ctx.avals_in if operand_aval_in.dtype != np.complex128: - return _scatter_lower(ctx, avals_in, avals_out, operand, indices, updates, + return _scatter_lower(ctx, operand, indices, updates, update_jaxpr=update_jaxpr, update_consts=update_consts, dimension_numbers=dimension_numbers, @@ -1996,16 +1993,16 @@ def _scatter_add_lower_gpu(ctx, avals_in, avals_out, operand, indices, updates, if mode == GatherScatterMode.CLIP: 
clip_fn = mlir.lower_fun(_clamp_scatter_indices, multiple_results=False) - (indices,), = clip_fn(ctx, avals_in, None, operand, indices, updates, + (indices,), = clip_fn(ctx.replace(avals_out=None), operand, indices, updates, dnums=dimension_numbers) - aval_out, = avals_out + aval_out, = ctx.avals_out dnums = dimension_numbers scatter_dnums = mhlo.ScatterDimensionNumbers.get( update_window_dims=list(dnums.update_window_dims), inserted_window_dims=list(dnums.inserted_window_dims), scattered_dims_to_operand_dims=list(dnums.scatter_dims_to_operand_dims), - index_vector_dim=len(avals_in[1].shape) - 1) + index_vector_dim=len(ctx.avals_in[1].shape) - 1) real_dtype = _real_dtype(aval_out.dtype) operand_type_part = mlir.aval_to_ir_type( core.ShapedArray(aval_out.shape, real_dtype)) diff --git a/jax/_src/lax/windowed_reductions.py b/jax/_src/lax/windowed_reductions.py index e01d31db5..7fc991397 100644 --- a/jax/_src/lax/windowed_reductions.py +++ b/jax/_src/lax/windowed_reductions.py @@ -308,14 +308,14 @@ reduce_window_p.def_abstract_eval(_reduce_window_abstract_eval_rule) batching.primitive_batchers[reduce_window_p] = _generic_reduce_window_batch_rule xla.register_translation(reduce_window_p, _reduce_window_translation_rule) -def _generic_reduce_window_lower(ctx, avals_in, avals_out, *args, jaxpr, consts, +def _generic_reduce_window_lower(ctx, *args, jaxpr, consts, window_dimensions, window_strides, padding, base_dilation, window_dilation): operands, init_values = util.split_list(args, [len(args) // 2]) - _, init_value_avals = util.split_list(avals_in, [len(operands)]) + _, init_value_avals = util.split_list(ctx.avals_in, [len(operands)]) scalar_types = [mlir.aval_to_ir_type(aval) for aval in init_value_avals] rw = mhlo.ReduceWindowOp( - map(mlir.aval_to_ir_type, avals_out), operands, init_values, + map(mlir.aval_to_ir_type, ctx.avals_out), operands, init_values, mlir.dense_int_elements(window_dimensions), mlir.dense_int_elements(window_strides), mlir.dense_int_elements(base_dilation), @@ -323,7 +323,7 @@ def _generic_reduce_window_lower(ctx, avals_in, avals_out, *args, jaxpr, consts, ir.DenseIntElementsAttr.get(np.asarray(padding, np.int64))) reducer = rw.regions[0].blocks.append(*(scalar_types + scalar_types)) with ir.InsertionPoint(reducer): - out_nodes = mlir.jaxpr_subcomp(ctx, jaxpr, consts, + out_nodes = mlir.jaxpr_subcomp(ctx.module_context, jaxpr, consts, *([a] for a in reducer.arguments)) mhlo.ReturnOp(util.flatten(out_nodes)) return rw.results @@ -488,10 +488,10 @@ batching.primitive_batchers[reduce_window_min_p] = partial( def _reduce_window_lower( - reduce_op, init_value, ctx, avals_in, avals_out, operand, *, + reduce_op, init_value, ctx, operand, *, window_dimensions, window_strides, padding, base_dilation, window_dilation): - aval_out, = avals_out - operand_aval, = avals_in + aval_out, = ctx.avals_out + operand_aval, = ctx.avals_in scalar_aval = operand_aval.update(shape=()) scalar_type = mlir.aval_to_ir_type(scalar_aval) rw = mhlo.ReduceWindowOp( @@ -545,11 +545,11 @@ select_and_scatter_p = lax.standard_primitive( _select_and_scatter_translation) def _select_and_scatter_lower( - ctx, avals_in, avals_out, operand, source, init_value, *, select_jaxpr, + ctx, operand, source, init_value, *, select_jaxpr, select_consts, scatter_jaxpr, scatter_consts, window_dimensions, window_strides, padding): - operand_aval, source_aval, init_value_aval = avals_in - aval_out, = avals_out + operand_aval, source_aval, init_value_aval = ctx.avals_in + aval_out, = ctx.avals_out scalar_aval =
operand_aval.update(shape=()) scalar_type = mlir.aval_to_ir_type(scalar_aval) op = mhlo.SelectAndScatterOp( @@ -559,12 +559,14 @@ def _select_and_scatter_lower( ir.DenseIntElementsAttr.get(np.asarray(padding, np.int64))) select = op.select.blocks.append(scalar_type, scalar_type) with ir.InsertionPoint(select): - out_nodes = mlir.jaxpr_subcomp(ctx, select_jaxpr, select_consts, + out_nodes = mlir.jaxpr_subcomp(ctx.module_context, select_jaxpr, + select_consts, *([a] for a in select.arguments)) mhlo.ReturnOp(util.flatten(out_nodes)) scatter = op.scatter.blocks.append(scalar_type, scalar_type) with ir.InsertionPoint(scatter): - out_nodes = mlir.jaxpr_subcomp(ctx, scatter_jaxpr, scatter_consts, + out_nodes = mlir.jaxpr_subcomp(ctx.module_context, scatter_jaxpr, + scatter_consts, *([a] for a in scatter.arguments)) mhlo.ReturnOp(util.flatten(out_nodes)) return op.results @@ -679,8 +681,9 @@ xla.register_translation( platform='gpu') -def _select_and_scatter_add_impl(source, operand, *, select_prim, window_dimensions, - window_strides, padding, expand_padding): +def _select_and_scatter_add_impl(source, operand, *, + select_prim, window_dimensions, window_strides, + padding, expand_padding): dtype = source.dtype select = lambda x, y: select_prim.bind(x, y) scatter = lax.bitwise_or if dtype == np.bool_ else lax.add diff --git a/jax/experimental/maps.py b/jax/experimental/maps.py index 6e933e781..2b323e699 100644 --- a/jax/experimental/maps.py +++ b/jax/experimental/maps.py @@ -1303,14 +1303,14 @@ def _xmap_lowering_rule(*args, **kwargs): return _xmap_lowering_rule_replica(*args, **kwargs) mlir.register_lowering(xmap_p, _xmap_lowering_rule) -def _xmap_lowering_rule_replica(ctx, avals_in, avals_out, *in_nodes, - call_jaxpr, name, - in_axes, out_axes, donated_invars, - global_axis_sizes, - spmd_in_axes, spmd_out_axes, - positional_semantics, - axis_resources, resource_env, backend): - xla.check_backend_matches(backend, ctx.platform) +def _xmap_lowering_rule_replica(ctx, *in_nodes, + call_jaxpr, name, + in_axes, out_axes, donated_invars, + global_axis_sizes, + spmd_in_axes, spmd_out_axes, + positional_semantics, + axis_resources, resource_env, backend): + xla.check_backend_matches(backend, ctx.module_context.platform) # The only way for any of those two assertions to be violated is when xmap # is using the SPMD lowering, but then this rule shouldn't even trigger. assert positional_semantics == _PositionalSemantics.LOCAL @@ -1346,30 +1346,34 @@ def _xmap_lowering_rule_replica(ctx, avals_in, avals_out, *in_nodes, assert not consts tiled_ins = ( - mlir.lower_fun( - partial(_tile, in_axes=arg_in_axes, axis_sizes=local_mesh_shape), - multiple_results=False)(ctx, [aval], None, in_node)[0] - if v.aval is not core.abstract_unit else in_node - for v, aval, in_node, arg_in_axes - in zip(call_jaxpr.invars, avals_in, in_nodes, mesh_in_axes)) + mlir.lower_fun(partial(_tile, in_axes=arg_in_axes, + axis_sizes=local_mesh_shape), + multiple_results=False)( + mlir.LoweringRuleContext(module_context=ctx.module_context, + avals_in=[aval], avals_out=None), + in_node)[0] + if v.aval is not core.abstract_unit else in_node + for v, aval, in_node, arg_in_axes + in zip(call_jaxpr.invars, ctx.avals_in, in_nodes, mesh_in_axes)) # NOTE: We don't extend the resource env with the mesh shape, because those # resources are already in scope! It's the outermost xmap that introduces # them! # We in-line here rather than generating a Call HLO as in the xla_call # translation rule just because the extra tuple stuff is a pain. 
- sub_ctx = ctx.replace( - name_stack=xla.extend_name_stack(ctx.name_stack, + sub_ctx = ctx.module_context.replace( + name_stack=xla.extend_name_stack(ctx.module_context.name_stack, xla.wrap_name(name, 'xmap'))) tiled_outs = mlir.jaxpr_subcomp(sub_ctx, vectorized_jaxpr, (), *tiled_ins) outs = [ mlir.lower_fun( partial(_untile, out_axes=ans_out_axes, axis_sizes=local_mesh_shape, - platform=ctx.platform), + platform=ctx.module_context.platform), multiple_results=False)( - ctx, [vectorized_outvar.aval], None, tiled_out - )[0] + mlir.LoweringRuleContext(module_context=ctx.module_context, + avals_in=[vectorized_outvar.aval], + avals_out=None), tiled_out)[0] if v.aval is not core.abstract_unit else tiled_out for v, vectorized_outvar, tiled_out, ans_out_axes in zip(call_jaxpr.outvars, vectorized_jaxpr.outvars, tiled_outs, @@ -1377,12 +1381,12 @@ def _xmap_lowering_rule_replica(ctx, avals_in, avals_out, *in_nodes, return outs -def _xmap_lowering_rule_spmd(ctx, avals_in, avals_out, *global_in_nodes, +def _xmap_lowering_rule_spmd(ctx, *global_in_nodes, call_jaxpr, name, in_axes, out_axes, donated_invars, global_axis_sizes, spmd_in_axes, spmd_out_axes, positional_semantics, axis_resources, resource_env, backend): - xla.check_backend_matches(backend, ctx.platform) + xla.check_backend_matches(backend, ctx.module_context.platform) plan = EvaluationPlan.from_axis_resources(axis_resources, resource_env, global_axis_sizes) resource_call_jaxpr = plan.subst_axes_with_resources(call_jaxpr) @@ -1409,7 +1413,7 @@ def _xmap_lowering_rule_spmd(ctx, avals_in, avals_out, *global_in_nodes, # NOTE: We don't extend the resource env with the mesh shape, because those # resources are already in scope! It's the outermost xmap that introduces # them! - global_in_avals = avals_in + global_in_avals = ctx.avals_in vectorized_jaxpr, global_out_avals, consts = pe.trace_to_jaxpr_dynamic(f, global_in_avals) assert not consts @@ -1422,8 +1426,8 @@ def _xmap_lowering_rule_spmd(ctx, avals_in, avals_out, *global_in_nodes, # We in-line here rather than generating a Call HLO as in the xla_call # translation rule just because the extra tuple stuff is a pain. 
- sub_ctx = ctx.replace( - name_stack=xla.extend_name_stack(ctx.name_stack, + sub_ctx = ctx.module_context.replace( + name_stack=xla.extend_name_stack(ctx.module_context.name_stack, xla.wrap_name(name, 'xmap'))) global_out_nodes = mlir.jaxpr_subcomp(sub_ctx, vectorized_jaxpr, (), *sharded_global_in_nodes) diff --git a/jax/experimental/pjit.py b/jax/experimental/pjit.py index d89344097..e9aa5d6e9 100644 --- a/jax/experimental/pjit.py +++ b/jax/experimental/pjit.py @@ -866,9 +866,9 @@ def _sharding_constraint_translation_rule(ctx, avals_in, avals_out, x_node, *, ctx.builder, x_node, get_aval_sharding_proto(aval, axis_resources, mesh))] xla.register_translation(sharding_constraint_p, _sharding_constraint_translation_rule) -def _sharding_constraint_mhlo_lowering(ctx, avals_in, avals_out, x_node, *, - axis_resources, resource_env): - aval, = avals_in +def _sharding_constraint_mhlo_lowering(ctx, x_node, *, axis_resources, + resource_env): + aval, = ctx.avals_in mesh = resource_env.physical_mesh return [mlir.wrap_with_sharding_op( x_node, get_aval_sharding_proto(aval, axis_resources, mesh))] diff --git a/jax/interpreters/mlir.py b/jax/interpreters/mlir.py index 8cfabaffd..6f8af801a 100644 --- a/jax/interpreters/mlir.py +++ b/jax/interpreters/mlir.py @@ -280,7 +280,8 @@ def _source_info_to_location( # Translation rules @dataclasses.dataclass -class LoweringContext: +class ModuleContext: + """Module-wide context information for MLIR lowering.""" context: ir.Context module: ir.Module ip: ir.InsertionPoint @@ -309,11 +310,19 @@ class LoweringContext: def replace(self, **kw): return dataclasses.replace(self, **kw) +@dataclasses.dataclass +class LoweringRuleContext: + """Per-rule context information for MLIR lowering.""" + module_context: ModuleContext + avals_in: Sequence[core.AbstractValue] + avals_out: Any + + def replace(self, **kw): return dataclasses.replace(self, **kw) + + if not MYPY: class LoweringRule(Protocol): - def __call__(self, ctx: LoweringContext, - avals_in: Sequence[core.AbstractValue], - avals_out: Sequence[core.AbstractValue], + def __call__(self, ctx: LoweringRuleContext, *args: Union[ir.Value, Sequence[ir.Value]], **kw) -> Sequence[Union[ir.Value, Sequence[ir.Value]]]: """Converts a JAX primitive invocation into MLIR.""" @@ -367,7 +376,7 @@ def lower_jaxpr_to_module( warnings.warn("Some donated buffers were not usable: {}".format( ", ".join(unused_donations))) - ctx = LoweringContext(platform, axis_env, name_stack) + ctx = ModuleContext(platform, axis_env, name_stack) with ctx.context, ir.Location.unknown(ctx.context): # Some clients expect modules to have unique names, e.g., in trace data. # This may or may not be a reasonable assumption. @@ -407,7 +416,7 @@ def _set_up_aliases(avals_in, avals_out, donated_args): return input_output_aliases, out_donated_args def lower_jaxpr_to_fun( - ctx: LoweringContext, name: str, jaxpr: core.ClosedJaxpr, *, + ctx: ModuleContext, name: str, jaxpr: core.ClosedJaxpr, *, public: bool = False, replace_units_with_dummy: bool = False, replace_tokens_with_dummy: bool = False, replicated_args: Optional[Sequence[bool]] = None, @@ -530,7 +539,7 @@ def lower_jaxpr_to_fun( return func_op -def jaxpr_subcomp(ctx: LoweringContext, jaxpr: core.Jaxpr, +def jaxpr_subcomp(ctx: ModuleContext, jaxpr: core.Jaxpr, consts: Sequence[Sequence[ir.Value]], *args: Sequence[ir.Value]) -> Sequence[Sequence[ir.Value]]: """Lowers a jaxpr into mHLO, inlined into an existing function. 
@@ -578,8 +587,10 @@ def jaxpr_subcomp(ctx: LoweringContext, jaxpr: core.Jaxpr, f"MLIR translation rule for primitive '{eqn.primitive.name}' not " f"found for platform {ctx.platform}") - ans = rule(ctx, map(aval, eqn.invars), map(aval, eqn.outvars), - *map(_unwrap_singleton_ir_values, in_nodes), + rule_ctx = LoweringRuleContext( + module_context=ctx, avals_in=map(aval, eqn.invars), + avals_out=map(aval, eqn.outvars)) + ans = rule(rule_ctx, *map(_unwrap_singleton_ir_values, in_nodes), **eqn.params) try: @@ -605,15 +616,16 @@ def lower_fun(fun: Callable, multiple_results: bool = True) -> Callable: The returned function does not use `avals_out`, so callers may pass any value as `avals_out`.""" - def f_lowered(ctx, avals_in, avals_out, *args, **params): + def f_lowered(ctx, *args, **params): if multiple_results: f = fun else: f = lambda *args, **kw: (fun(*args, **kw),) wrapped_fun = lu.wrap_init(f, params) - with core.extend_axis_env_nd(zip(ctx.axis_env.names, ctx.axis_env.sizes)): - jaxpr, _, consts = pe.trace_to_jaxpr_dynamic(wrapped_fun, avals_in) - return jaxpr_subcomp(ctx, jaxpr, _ir_consts(consts), + axis_env = ctx.module_context.axis_env + with core.extend_axis_env_nd(zip(axis_env.names, axis_env.sizes)): + jaxpr, _, consts = pe.trace_to_jaxpr_dynamic(wrapped_fun, ctx.avals_in) + return jaxpr_subcomp(ctx.module_context, jaxpr, _ir_consts(consts), *map(wrap_singleton_ir_values, args)) return f_lowered @@ -634,19 +646,20 @@ def _call_lowering(fn_name, stack_name, call_jaxpr, backend, ctx, avals_in, flatten_lowering_ir_args(args)) return util.unflatten(call.results, map(len, output_types)) -def _xla_call_lower(ctx, avals_in, avals_out, *args, +def _xla_call_lower(ctx, *args, backend=None, name, call_jaxpr, donated_invars, inline=None, device=None): del device, donated_invars, inline # Ignored. 
return _call_lowering(f"jit_{name}", xla.wrap_name(name, "jit"), call_jaxpr, - backend, ctx, avals_in, avals_out, *args) + backend, ctx.module_context, ctx.avals_in, ctx.avals_out, + *args) register_lowering(xla.xla_call_p, _xla_call_lower) -def _named_call_lowering(ctx, avals_in, avals_out, *args, name, backend=None, +def _named_call_lowering(ctx, *args, name, backend=None, call_jaxpr): - return _call_lowering(name, name, call_jaxpr, backend, ctx, avals_in, - avals_out, *args) + return _call_lowering(name, name, call_jaxpr, backend, ctx.module_context, + ctx.avals_in, ctx.avals_out, *args) register_lowering(core.named_call_p, _named_call_lowering) register_lowering(core.call_p, partial(_named_call_lowering, name="core_call")) @@ -658,18 +671,17 @@ def full_like_aval(value, aval: core.ShapedArray) -> ir.Value: return mhlo.BroadcastOp(aval_to_ir_type(aval), zero, dense_int_elements(aval.shape)).result -def zeros_like_lowering(ctx, avals_in, avals_out, x): - aval, = avals_in +def zeros_like_lowering(ctx, x): + aval, = ctx.avals_in assert isinstance(aval, core.ShapedArray), aval return [full_like_aval(0, aval)] register_lowering(ad_util.zeros_like_p, zeros_like_lowering) -def add_jaxvals_lowering(ctx, avals_in, avals_out, x, y): +def add_jaxvals_lowering(ctx, x, y): return mhlo.AddOp(x, y).results register_lowering(ad_util.add_jaxvals_p, add_jaxvals_lowering) -register_lowering(ad_util.stop_gradient_p, - lambda ctx, avals_in, avals_out, x: [x]) +register_lowering(ad_util.stop_gradient_p, lambda ctx, x: [x]) def _minmax_mhlo(op, cmp, x, y): @@ -735,23 +747,24 @@ def set_sharding(op, sharding_proto: xc.OpSharding): # MLIR lowerings for lax primitives -def xla_fallback_lowering(prim: core.Primitive, ctx: LoweringContext, - avals_in, avals_out, *args, **params): +def xla_fallback_lowering(prim: core.Primitive, ctx: LoweringRuleContext, *args, + **params): + module_ctx = ctx.module_context xla_computation = xla.primitive_subcomputation( - ctx.platform, ctx.axis_env, prim, *avals_in, **params) + module_ctx.platform, module_ctx.axis_env, prim, *ctx.avals_in, **params) submodule_str = xc._xla.mlir.xla_computation_to_mlir_module(xla_computation) submodule = ir.Module.parse(submodule_str) callee_name = None for op in submodule.body.operations: - ctx.module.body.append(op) + module_ctx.module.body.append(op) if op.name.value == "main": op.attributes["sym_name"] = ir.StringAttr.get(f"xla_fallback_{prim.name}") - callee_name = ir.StringAttr(ctx.symbol_table.insert(op)).value + callee_name = ir.StringAttr(module_ctx.symbol_table.insert(op)).value op.attributes["sym_visibility"] = ir.StringAttr.get("private") else: - ctx.symbol_table.insert(op) + module_ctx.symbol_table.insert(op) - output_types = map(aval_to_ir_types, avals_out) + output_types = map(aval_to_ir_types, ctx.avals_out) flat_output_types = util.flatten(output_types) output_type = (ir.TupleType.get_tuple(flat_output_types) if prim.multiple_results else flat_output_types[0]) diff --git a/jax/interpreters/pxla.py b/jax/interpreters/pxla.py index 16278daca..4a959aa65 100644 --- a/jax/interpreters/pxla.py +++ b/jax/interpreters/pxla.py @@ -1718,19 +1718,20 @@ def _mhlo_unshard(aval, axis_env, out_axis, xs, platform): raise TypeError(aval) -def _pmap_lowering(ctx, avals_in, avals_out, *in_nodes, axis_name, +def _pmap_lowering(ctx, *in_nodes, axis_name, axis_size, global_axis_size, devices, name, call_jaxpr, backend=None, in_axes, out_axes, donated_invars, global_arg_shapes): del donated_invars # Unused. 
- xla.check_backend_matches(backend, ctx.platform) + xla.check_backend_matches(backend, ctx.module_context.platform) # We in-line here rather than generating a Call HLO as in the xla_call # translation rule just because the extra tuple stuff is a pain. - if ctx.axis_env.names and devices is not None: + if ctx.module_context.axis_env.names and devices is not None: raise ValueError("Nested pmap with explicit devices argument.") if global_axis_size is None: global_axis_size = axis_size - new_env = xla.extend_axis_env(ctx.axis_env, axis_name, global_axis_size) + new_env = xla.extend_axis_env(ctx.module_context.axis_env, axis_name, + global_axis_size) # Shard the in_nodes that are mapped in_avals = [v.aval for v in call_jaxpr.invars] in_nodes_sharded = ( @@ -1739,14 +1740,15 @@ def _pmap_lowering(ctx, avals_in, avals_out, *in_nodes, axis_name, for aval, in_node, in_axis in zip(in_avals, in_nodes, in_axes)) with maybe_extend_axis_env(axis_name, global_axis_size, None): # type: ignore - sub_ctx = ctx.replace( + sub_ctx = ctx.module_context.replace( axis_env=new_env, - name_stack=xla.extend_name_stack(ctx.name_stack, + name_stack=xla.extend_name_stack(ctx.module_context.name_stack, util.wrap_name(name, 'pmap'))) sharded_outs = mlir.jaxpr_subcomp(sub_ctx, call_jaxpr, (), *in_nodes_sharded) out_avals = [v.aval for v in call_jaxpr.outvars] - outs = [_mhlo_unshard(aval, new_env, out_axis, shard, platform=ctx.platform) + outs = [_mhlo_unshard(aval, new_env, out_axis, shard, + platform=ctx.module_context.platform) for aval, out_axis, shard in zip(out_avals, out_axes, sharded_outs)] return outs diff --git a/jax/interpreters/sharded_jit.py b/jax/interpreters/sharded_jit.py index 2ee1f87c3..9b9b7e56a 100644 --- a/jax/interpreters/sharded_jit.py +++ b/jax/interpreters/sharded_jit.py @@ -214,7 +214,7 @@ def _sharded_jit_translation_rule(ctx, avals_in, avals_out, *in_nodes, xops.Call(ctx.builder, subc, list(in_nodes))) -def _sharded_jit_lowering(ctx, avals_in, avals_out, *in_nodes, +def _sharded_jit_lowering(ctx, *in_nodes, in_parts, out_parts_thunk, nparts, name, call_jaxpr, local_in_parts, local_out_parts_thunk, local_nparts): @@ -233,12 +233,12 @@ def _sharded_jit_lowering(ctx, avals_in, avals_out, *in_nodes, else: args.append(ns) - sub_ctx = ctx.replace( + sub_ctx = ctx.module_context.replace( name_stack=extend_name_stack(wrap_name(name, "sharded_jit"))) fn = mlir.lower_jaxpr_to_fun(sub_ctx, f"sharded_jit_{name}", core.ClosedJaxpr(call_jaxpr, ())) - output_types = safe_map(mlir.aval_to_ir_types, avals_out) + output_types = safe_map(mlir.aval_to_ir_types, ctx.avals_out) flat_output_types = util.flatten(output_types) call = std.CallOp(flat_output_types, ir.FlatSymbolRefAttr.get(fn.name.value), @@ -472,8 +472,7 @@ ad.deflinear2(sharding_constraint_p, xla.register_translation(sharding_constraint_p, _sharding_constraint_translation_rule) -def _sharding_constraint_lowering(ctx, avals_in, avals_out, x_node, - partitions): +def _sharding_constraint_lowering(ctx, x_node, partitions): return [mlir.wrap_with_sharding_op(x_node, xla.sharding_to_proto(partitions))] mlir.register_lowering(sharding_constraint_p, _sharding_constraint_lowering) diff --git a/tests/custom_object_test.py b/tests/custom_object_test.py index af340dc2a..2c716b912 100644 --- a/tests/custom_object_test.py +++ b/tests/custom_object_test.py @@ -170,7 +170,7 @@ def _sp_indices_translation_rule(ctx, avals_in, avals_out, data, indices): # because it leads to infinite recursion. 
xla.register_translation(sp_indices_p, _sp_indices_translation_rule) -def _sp_indices_mhlo_lowering(ctx, avals_in, avals_out, data_and_indices): +def _sp_indices_mhlo_lowering(ctx, data_and_indices): return [data_and_indices[1]] mlir.register_lowering(sp_indices_p, _sp_indices_mhlo_lowering) @@ -192,7 +192,7 @@ def _sp_data_translation_rule(ctx, avals_in, avals_out, data, indices): # because it leads to infinite recursion. xla.register_translation(sp_data_p, _sp_data_translation_rule) -def _sp_data_mhlo_lowering(ctx, avals_in, avals_out, data_and_indices): +def _sp_data_mhlo_lowering(ctx, data_and_indices): return [data_and_indices[0]] mlir.register_lowering(sp_data_p, _sp_data_mhlo_lowering)
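Note on the new lowering-rule signature (illustration only, not part of the patch): the diff above splits the old mlir.LoweringContext into a module-wide ModuleContext and a per-equation LoweringRuleContext, so rules receive a single ctx carrying avals_in/avals_out instead of taking the aval lists as positional arguments. The sketch below shows that calling convention in isolation; ModuleContext and LoweringRuleContext here are simplified stand-ins for the real classes in jax/interpreters/mlir.py, which also carry the MLIR module, symbol table, axis_env, and name_stack.

# Minimal sketch of the calling convention introduced above (simplified
# stand-ins, not the real jax/interpreters/mlir.py classes).
import dataclasses
from typing import Any, Sequence

@dataclasses.dataclass
class ModuleContext:
  # Module-wide lowering state (the real class holds much more).
  platform: str
  name_stack: str = ""

  def replace(self, **kw):
    return dataclasses.replace(self, **kw)

@dataclasses.dataclass
class LoweringRuleContext:
  # Per-equation state: the avals that rules previously took as arguments.
  module_context: ModuleContext
  avals_in: Sequence[Any]
  avals_out: Any  # may be None when a helper lowering ignores it

  def replace(self, **kw):
    return dataclasses.replace(self, **kw)

# Old style: def _rule(ctx, avals_in, avals_out, x, *, param): ...
# New style: the avals come from the rule context instead.
def example_rule(ctx: LoweringRuleContext, x, *, param):
  aval_out, = ctx.avals_out                # output aval from the rule context
  platform = ctx.module_context.platform   # module-wide info via module_context
  del param, platform, aval_out            # a real rule would emit MHLO ops here
  return [x]

rule_ctx = LoweringRuleContext(ModuleContext("cpu"),
                               avals_in=["f32[3]"], avals_out=["f32[3]"])
print(example_rule(rule_ctx, "%arg0", param=1))   # -> ['%arg0']

Rules that delegate to another lowering (as several rules in the patch do with mlir.lower_fun results) now pass ctx.replace(avals_out=None) or a freshly built LoweringRuleContext rather than re-threading the aval lists, while region bodies built with mlir.jaxpr_subcomp take the ModuleContext (ctx.module_context) rather than the rule context.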