[Mosaic GPU] Fix layout inference traversal to traverse ops recursively.

PiperOrigin-RevId: 706136221
This commit is contained in:
Benjamin Chetioui 2024-12-13 23:50:48 -08:00 committed by jax authors
parent 38592981c7
commit 2386838315
2 changed files with 32 additions and 4 deletions

View File

@ -89,9 +89,12 @@ def _infer_pointwise_op_layouts(
if layout is None:
# Still no layout set. We iterate on producers.
for operand in op.operands:
layout = _extract_any_layout_from_op(operand.owner)
if layout:
break
if isinstance(operand.owner, ir.Operation) or isinstance(
operand.owner, ir.OpView
):
layout = _extract_any_layout_from_op(operand.owner)
if layout:
break
if layout is None:
return None
@ -141,7 +144,7 @@ def traverse_op(
else:
ops_to_traverse = reversed(list(block))
for block_op in ops_to_traverse:
callback(block_op)
traverse_op(block_op, callback, traversal_order)
callback(op)

View File

@ -20,6 +20,8 @@ from jax._src import test_util as jtu
from jax._src.interpreters import mlir as mlir_interpreter
from jax._src.lib.mlir import ir
from jax._src.lib.mlir.dialects import arith
from jax._src.lib.mlir.dialects import func
from jax._src.lib.mlir.dialects import scf
from jax.experimental.mosaic.gpu import dialect as mgpu # pylint: disable=g-importing-member
from jax.experimental.mosaic.gpu import infer_layout # pylint: disable=g-importing-member
from jax.experimental.mosaic.gpu import splat_fragmented_layout # pylint: disable=g-importing-member
@ -111,6 +113,29 @@ class LayoutInferenceTest(parameterized.TestCase):
op.attributes["out_layouts"], [layout] * len(op.results)
)
def test_infer_layout_traverses_ops_correctly(self):
shape = (4, 8)
elt_type = ir.BF16Type.get()
add_op = None
def body(a, b):
bool_type = ir.IntegerType.get_signless(1)
cst_true = arith.constant(bool_type, ir.IntegerAttr.get(bool_type, 1))
if_op = scf.IfOp(cst_true)
with ir.InsertionPoint(if_op.then_block):
nonlocal add_op
add_op = arith.addf(a, b)
scf.yield_([])
with ir.InsertionPoint(self.module.body):
ab_type = ir.VectorType.get(shape, elt_type)
func.FuncOp.from_py_func(ab_type, ab_type)(body)
infer_layout(self.module)
self.assertIn("in_layouts", add_op.owner.attributes)
self.assertIn("out_layouts", add_op.owner.attributes)
if __name__ == "__main__":
parameterized.absltest.main(testLoader=jtu.JaxTestLoader())