[Pallas/Mosaic GPU] Enable progressive lowering for integer addition.

The helpers `_fragmented_array_to_ir` and `_fragmented_array_from_ir` in
`dialect_lowering.py` have been modified so that a fragmented array's
signedness no longer appears in its IR representation.

This is because signedness reflects how we use a value, not an inherent
property of the value itself. When reloading a fragmented array from IR,
the caller must therefore supply the appropriate signedness.

PiperOrigin-RevId: 726030853
Author: Benjamin Chetioui, 2025-02-12 06:28:42 -08:00 (committed by jax authors)
Parent: 1e2a5770c9
Commit: c7199fe8a5
4 changed files with 35 additions and 14 deletions
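To make the design concrete, here is a minimal, runnable Python sketch of the round trip described above. All names (`ToyFragmentedArray`, `to_ir`, `from_ir`) are illustrative stand-ins, not the real Mosaic GPU implementation: serialization drops signedness, and the caller re-supplies it on load.

    from dataclasses import dataclass

    @dataclass
    class ToyFragmentedArray:
      registers: list[int]
      is_signed: bool | None  # How the registers are interpreted; not stored in IR.

    def to_ir(array: ToyFragmentedArray) -> list[int]:
      # Signedness is deliberately omitted: it reflects usage, not the value.
      return list(array.registers)

    def from_ir(registers: list[int], is_signed: bool | None = None) -> ToyFragmentedArray:
      # The caller decides how to interpret the registers on reload.
      return ToyFragmentedArray(registers, is_signed)

    original = ToyFragmentedArray([1, 2, 3], is_signed=True)
    restored = from_ir(to_ir(original), is_signed=True)
    assert restored == original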


@@ -1317,9 +1317,11 @@ def _add_lowering_rule_wg(ctx: LoweringRuleContext, x, y):
   if np.issubdtype(ctx.avals_in[0].dtype, np.floating):
     add_op = arith_dialect.addf
+  elif np.issubdtype(ctx.avals_in[0].dtype, np.integer):
+    add_op = arith_dialect.addi
   else:
     raise NotImplementedError(
-        "Lowering of non-float addition is not implemented"
+        f"Unsupported dtype {ctx.avals_in[0].dtype} in lowering of add_p"
     )
   x = _ensure_vector(x, x_aval.dtype)
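As a standalone illustration of the dispatch above (hypothetical helper name; the real rule emits MLIR ops rather than strings), `np.issubdtype` classifies the dtype and selects the arith op to emit:

    import numpy as np

    def pick_add_op(dtype) -> str:
      # Mirrors the float/integer dispatch in _add_lowering_rule_wg.
      if np.issubdtype(dtype, np.floating):
        return "arith.addf"
      elif np.issubdtype(dtype, np.integer):
        return "arith.addi"
      raise NotImplementedError(f"Unsupported dtype {dtype} in lowering of add_p")

    assert pick_add_op(np.float32) == "arith.addf"
    assert pick_add_op(np.int32) == "arith.addi"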


@@ -59,6 +59,10 @@ _lowerings: dict[str, MlirLoweringRule] = {}
 def _fragmented_array_to_ir(
     fragmented_array: fa.FragmentedArray, ty: ir.Type
 ) -> ir.Value:
+  """Converts a FragmentedArray to an IR value.
+
+  The fragmented array's signedness is omitted from the IR representation.
+  """
   conversion_cast = builtin.UnrealizedConversionCastOp(
       [ty], fragmented_array.registers.flatten().tolist()
   )
@@ -72,16 +76,13 @@ def _fragmented_array_to_ir(
       fragmented_array.layout
   )
-  if fragmented_array.is_signed is not None:
-    conversion_cast.attributes["is_signed"] = ir.BoolAttr.get(
-        fragmented_array.is_signed
-    )
   return conversion_cast.result


 # TODO(bchetioui): add code that verifies the layout is as inferred.
 def _fragmented_array_from_ir(
     fragmented_array_as_ir: ir.Value,
+    is_signed: bool | None = None,
 ) -> fa.FragmentedArray:
   conversion_cast = cast(
@@ -98,7 +99,6 @@ def _fragmented_array_from_ir(
   if not isinstance(converted_outputs, list):
     converted_outputs = [converted_outputs]
-
   reverse_conversion_cast = converted_outputs[0].owner.opview
   for attribute in conversion_cast.attributes:
     attribute = cast(ir.NamedAttribute, attribute)
@@ -109,10 +109,8 @@ def _fragmented_array_from_ir(
   )
   layout = layouts.from_layout_attr(conversion_cast.attributes["layout"])
-  if ir.IntegerType.isinstance(conversion_cast.outputs[0].type):
-    is_signed = bool(conversion_cast.attributes["is_signed"])
-  else:
-    is_signed = None
+  if ir.IntegerType.isinstance(conversion_cast.outputs[0].type.element_type):
+    is_signed = False if is_signed is None else is_signed
   return fa.FragmentedArray(
       _registers=registers, _layout=layout, _is_signed=is_signed
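The defaulting rule above can be stated in isolation: integer element types with no caller-provided signedness default to unsigned, an explicit choice is kept, and non-integer types pass the value through unchanged. A standalone sketch (hypothetical helper name):

    def resolve_signedness(is_integer_type: bool, is_signed: bool | None) -> bool | None:
      # Integers default to unsigned when the caller did not specify;
      # non-integer types keep whatever was passed (normally None).
      if is_integer_type:
        return False if is_signed is None else is_signed
      return is_signed

    assert resolve_signedness(True, None) is False
    assert resolve_signedness(True, True) is True
    assert resolve_signedness(False, None) is None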
@@ -199,7 +197,12 @@ def _vector_load_op_lowering_rule(
         f"for {vector_load_op}"
     )
-  fragmented_array = fa.FragmentedArray.load_strided(vector_load_op.base)
+  element_type = vector_load_op.result.type.element_type
+  is_signed = False if ir.IntegerType.isinstance(element_type) else None
+  fragmented_array = fa.FragmentedArray.load_strided(
+      vector_load_op.base, is_signed=is_signed
+  )
   return [_fragmented_array_to_ir(fragmented_array, vector_load_op.result.type)]
@@ -299,8 +302,9 @@ def _mgpu_async_store_op_lowering_rule(
 @_register_lowering(arith.AddFOp)
-def _arith_addf_op_lowering_rule(
-    _: LoweringContext, add: arith.AddFOp
+@_register_lowering(arith.AddIOp)
+def _arith_add_op_lowering_rule(
+    _: LoweringContext, add: arith.AddFOp | arith.AddIOp
 ) -> Sequence[ir.Value]:
   fragmented_array_lhs = _fragmented_array_from_ir(add.lhs)
@@ -390,6 +394,15 @@ def single_thread_predicates(module: ir.Module) -> tuple[ir.Value, ir.Value]:
   return block_predicate, warpgroup_predicate


+def _should_lower(op: ir.OpView) -> bool:
+  """Returns 'true' if the operation should be lowered."""
+  if isinstance(op.name, ir.StringAttr):
+    name = op.name.value
+  else:
+    name = op.name
+  return name.startswith("mosaic_gpu.") or layouts.should_have_layout(op)
+
+
 def lower_mgpu_dialect(
     module: ir.Module, launch_context: launch_context.LaunchContext | None
 ):
@@ -413,8 +426,12 @@ def lower_mgpu_dialect(
   ctx = LoweringContext(launch_context, block_predicate, warpgroup_predicate)

   def _lower_op(op: ir.OpView):
-    if op.name not in _lowerings:
+    if not _should_lower(op):
       return
+
+    if op.name not in _lowerings:
+      raise NotImplementedError(f"Missing lowering rule for {op.name}")
+
     lowering_rule = _lowerings[op.name]
     # TODO(bchetioui): make sure all layouts are set here.
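Together, the two hunks above split lowering into a cheap filter (`_should_lower`) and a strict dispatch that now fails loudly when a rule is missing. A minimal, self-contained model of that control flow (toy registry and op type, not the real `dialect_lowering.py` code):

    from dataclasses import dataclass

    @dataclass
    class ToyOp:
      name: str

    _toy_lowerings = {"mosaic_gpu.async_load": lambda op: f"lowered {op.name}"}

    def should_lower(op: ToyOp) -> bool:
      # Stand-in for the real predicate, which also checks for layouts.
      return op.name.startswith("mosaic_gpu.")

    def lower_op(op: ToyOp):
      if not should_lower(op):
        return None  # Leave the op untouched.
      if op.name not in _toy_lowerings:
        raise NotImplementedError(f"Missing lowering rule for {op.name}")
      return _toy_lowerings[op.name](op)

    assert lower_op(ToyOp("arith.constant")) is None
    assert lower_op(ToyOp("mosaic_gpu.async_load")) == "lowered mosaic_gpu.async_load"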


@@ -240,6 +240,7 @@ def _infer_pointwise_op_layouts(op: ir.OpView) -> OptionalLayouts:
   for op in (
       arith.AddFOp,
+      arith.AddIOp,
       arith.MulFOp,
       vector.LoadOp,
       vector.StoreOp,


@@ -2085,6 +2085,7 @@ class PallasCallWarpgroupSemanticsTest(PallasTest):
   @parameterized.named_parameters(
       ("add_float", lambda x, y: x + y, np.float32),
+      ("add_int", lambda x, y: x + y, np.int32),
   )
   def test_binary_op_wg_semantics(self, bop, dtype):
     @functools.partial(
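For context, an integer addition can reach this lowering through an ordinary Pallas kernel. A minimal sketch using the public Pallas API (the warpgroup-semantics parameters exercised by the test above are omitted, and shapes are illustrative):

    import jax
    import jax.numpy as jnp
    from jax.experimental import pallas as pl

    def add_kernel(x_ref, y_ref, o_ref):
      # Element-wise integer add; lowers to arith.addi under the new rule.
      o_ref[...] = x_ref[...] + y_ref[...]

    x = jnp.arange(128, dtype=jnp.int32)
    y = jnp.ones(128, dtype=jnp.int32)
    out = pl.pallas_call(
        add_kernel,
        out_shape=jax.ShapeDtypeStruct(x.shape, x.dtype),
    )(x, y)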