mirror of
https://github.com/ROCm/jax.git
synced 2025-04-19 05:16:06 +00:00
[Mosaic GPU] Fix layout inference traversal to traverse ops recursively.
PiperOrigin-RevId: 706136221
This commit is contained in:
parent
38592981c7
commit
2386838315
@ -89,9 +89,12 @@ def _infer_pointwise_op_layouts(
|
||||
if layout is None:
|
||||
# Still no layout set. We iterate on producers.
|
||||
for operand in op.operands:
|
||||
layout = _extract_any_layout_from_op(operand.owner)
|
||||
if layout:
|
||||
break
|
||||
if isinstance(operand.owner, ir.Operation) or isinstance(
|
||||
operand.owner, ir.OpView
|
||||
):
|
||||
layout = _extract_any_layout_from_op(operand.owner)
|
||||
if layout:
|
||||
break
|
||||
|
||||
if layout is None:
|
||||
return None
|
||||
@ -141,7 +144,7 @@ def traverse_op(
|
||||
else:
|
||||
ops_to_traverse = reversed(list(block))
|
||||
for block_op in ops_to_traverse:
|
||||
callback(block_op)
|
||||
traverse_op(block_op, callback, traversal_order)
|
||||
callback(op)
|
||||
|
||||
|
||||
|
@ -20,6 +20,8 @@ from jax._src import test_util as jtu
|
||||
from jax._src.interpreters import mlir as mlir_interpreter
|
||||
from jax._src.lib.mlir import ir
|
||||
from jax._src.lib.mlir.dialects import arith
|
||||
from jax._src.lib.mlir.dialects import func
|
||||
from jax._src.lib.mlir.dialects import scf
|
||||
from jax.experimental.mosaic.gpu import dialect as mgpu # pylint: disable=g-importing-member
|
||||
from jax.experimental.mosaic.gpu import infer_layout # pylint: disable=g-importing-member
|
||||
from jax.experimental.mosaic.gpu import splat_fragmented_layout # pylint: disable=g-importing-member
|
||||
@ -111,6 +113,29 @@ class LayoutInferenceTest(parameterized.TestCase):
|
||||
op.attributes["out_layouts"], [layout] * len(op.results)
|
||||
)
|
||||
|
||||
def test_infer_layout_traverses_ops_correctly(self):
|
||||
shape = (4, 8)
|
||||
elt_type = ir.BF16Type.get()
|
||||
add_op = None
|
||||
|
||||
def body(a, b):
|
||||
bool_type = ir.IntegerType.get_signless(1)
|
||||
cst_true = arith.constant(bool_type, ir.IntegerAttr.get(bool_type, 1))
|
||||
if_op = scf.IfOp(cst_true)
|
||||
with ir.InsertionPoint(if_op.then_block):
|
||||
nonlocal add_op
|
||||
add_op = arith.addf(a, b)
|
||||
scf.yield_([])
|
||||
|
||||
with ir.InsertionPoint(self.module.body):
|
||||
ab_type = ir.VectorType.get(shape, elt_type)
|
||||
func.FuncOp.from_py_func(ab_type, ab_type)(body)
|
||||
|
||||
infer_layout(self.module)
|
||||
|
||||
self.assertIn("in_layouts", add_op.owner.attributes)
|
||||
self.assertIn("out_layouts", add_op.owner.attributes)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
parameterized.absltest.main(testLoader=jtu.JaxTestLoader())
|
||||
|
Loading…
x
Reference in New Issue
Block a user