mirror of
https://github.com/ROCm/jax.git
synced 2025-04-16 03:46:06 +00:00
[Pallas/Mosaic GPU] Enable progressive lowering for integer addition.
The helpers `_fragmented_array_to_ir` and `_fragmented_array_from_ir` in `dialect_lowering.py` have been modified so that a fragmented array's signedness no longer appears in its IR representation. This is because signedness reflects how we make use of the value and is not an inherent property of it. The caller must therefore provide the appropriate signedness when reloading a fragmented array from IR.
PiperOrigin-RevId: 726030853
This commit is contained in:
parent
1e2a5770c9
commit
c7199fe8a5
@ -1317,9 +1317,11 @@ def _add_lowering_rule_wg(ctx: LoweringRuleContext, x, y):
|
||||
|
||||
if np.issubdtype(ctx.avals_in[0].dtype, np.floating):
|
||||
add_op = arith_dialect.addf
|
||||
elif np.issubdtype(ctx.avals_in[0].dtype, np.integer):
|
||||
add_op = arith_dialect.addi
|
||||
else:
|
||||
raise NotImplementedError(
|
||||
"Lowering of non-float addition is not implemented"
|
||||
f"Unsupported dtype {ctx.avals_in[0].dtype} in lowering of add_p"
|
||||
)
|
||||
|
||||
x = _ensure_vector(x, x_aval.dtype)
|
||||
|
@ -59,6 +59,10 @@ _lowerings: dict[str, MlirLoweringRule] = {}
|
||||
def _fragmented_array_to_ir(
|
||||
fragmented_array: fa.FragmentedArray, ty: ir.Type
|
||||
) -> ir.Value:
|
||||
"""Converts a FragmentedArray to an IR value.
|
||||
|
||||
The fragmented array's signedness is omitted from the IR representation.
|
||||
"""
|
||||
conversion_cast = builtin.UnrealizedConversionCastOp(
|
||||
[ty], fragmented_array.registers.flatten().tolist()
|
||||
)
|
||||
@ -72,16 +76,13 @@ def _fragmented_array_to_ir(
|
||||
fragmented_array.layout
|
||||
)
|
||||
|
||||
if fragmented_array.is_signed is not None:
|
||||
conversion_cast.attributes["is_signed"] = ir.BoolAttr.get(
|
||||
fragmented_array.is_signed
|
||||
)
|
||||
return conversion_cast.result
|
||||
|
||||
|
||||
# TODO(bchetioui): add code that verifies the layout is as inferred.
|
||||
def _fragmented_array_from_ir(
|
||||
fragmented_array_as_ir: ir.Value,
|
||||
is_signed: bool | None = None,
|
||||
) -> fa.FragmentedArray:
|
||||
|
||||
conversion_cast = cast(
|
||||
@ -98,7 +99,6 @@ def _fragmented_array_from_ir(
|
||||
if not isinstance(converted_outputs, list):
|
||||
converted_outputs = [converted_outputs]
|
||||
|
||||
|
||||
reverse_conversion_cast = converted_outputs[0].owner.opview
|
||||
for attribute in conversion_cast.attributes:
|
||||
attribute = cast(ir.NamedAttribute, attribute)
|
||||
@ -109,10 +109,8 @@ def _fragmented_array_from_ir(
|
||||
)
|
||||
layout = layouts.from_layout_attr(conversion_cast.attributes["layout"])
|
||||
|
||||
if ir.IntegerType.isinstance(conversion_cast.outputs[0].type):
|
||||
is_signed = bool(conversion_cast.attributes["is_signed"])
|
||||
else:
|
||||
is_signed = None
|
||||
if ir.IntegerType.isinstance(conversion_cast.outputs[0].type.element_type):
|
||||
is_signed = False if is_signed is None else is_signed
|
||||
|
||||
return fa.FragmentedArray(
|
||||
_registers=registers, _layout=layout, _is_signed=is_signed
|
||||
@ -199,7 +197,12 @@ def _vector_load_op_lowering_rule(
|
||||
f"for {vector_load_op}"
|
||||
)
|
||||
|
||||
fragmented_array = fa.FragmentedArray.load_strided(vector_load_op.base)
|
||||
element_type = vector_load_op.result.type.element_type
|
||||
is_signed = False if ir.IntegerType.isinstance(element_type) else None
|
||||
|
||||
fragmented_array = fa.FragmentedArray.load_strided(
|
||||
vector_load_op.base, is_signed=is_signed
|
||||
)
|
||||
return [_fragmented_array_to_ir(fragmented_array, vector_load_op.result.type)]
|
||||
|
||||
|
||||
@ -299,8 +302,9 @@ def _mgpu_async_store_op_lowering_rule(
|
||||
|
||||
|
||||
@_register_lowering(arith.AddFOp)
|
||||
def _arith_addf_op_lowering_rule(
|
||||
_: LoweringContext, add: arith.AddFOp
|
||||
@_register_lowering(arith.AddIOp)
|
||||
def _arith_add_op_lowering_rule(
|
||||
_: LoweringContext, add: arith.AddFOp | arith.AddIOp
|
||||
) -> Sequence[ir.Value]:
|
||||
|
||||
fragmented_array_lhs = _fragmented_array_from_ir(add.lhs)
|
||||
@ -390,6 +394,15 @@ def single_thread_predicates(module: ir.Module) -> tuple[ir.Value, ir.Value]:
|
||||
return block_predicate, warpgroup_predicate
|
||||
|
||||
|
||||
def _should_lower(op: ir.OpView) -> bool:
|
||||
"""Returns 'true' if the operation should be lowered."""
|
||||
if isinstance(op.name, ir.StringAttr):
|
||||
name = op.name.value
|
||||
else:
|
||||
name = op.name
|
||||
return name.startswith("mosaic_gpu.") or layouts.should_have_layout(op)
|
||||
|
||||
|
||||
def lower_mgpu_dialect(
|
||||
module: ir.Module, launch_context: launch_context.LaunchContext | None
|
||||
):
|
||||
@ -413,8 +426,12 @@ def lower_mgpu_dialect(
|
||||
ctx = LoweringContext(launch_context, block_predicate, warpgroup_predicate)
|
||||
|
||||
def _lower_op(op: ir.OpView):
|
||||
if op.name not in _lowerings:
|
||||
if not _should_lower(op):
|
||||
return
|
||||
|
||||
if op.name not in _lowerings:
|
||||
raise NotImplementedError(f"Missing lowering rule for {op.name}")
|
||||
|
||||
lowering_rule = _lowerings[op.name]
|
||||
|
||||
# TODO(bchetioui): make sure all layouts are set here.
|
||||
|
@ -240,6 +240,7 @@ def _infer_pointwise_op_layouts(op: ir.OpView) -> OptionalLayouts:
|
||||
|
||||
for op in (
|
||||
arith.AddFOp,
|
||||
arith.AddIOp,
|
||||
arith.MulFOp,
|
||||
vector.LoadOp,
|
||||
vector.StoreOp,
|
||||
|
@ -2085,6 +2085,7 @@ class PallasCallWarpgroupSemanticsTest(PallasTest):
|
||||
|
||||
@parameterized.named_parameters(
|
||||
("add_float", lambda x, y: x + y, np.float32),
|
||||
("add_int", lambda x, y: x + y, np.int32),
|
||||
)
|
||||
def test_binary_op_wg_semantics(self, bop, dtype):
|
||||
@functools.partial(
|
||||
|
Loading…
x
Reference in New Issue
Block a user