[Mosaic] apply_vector_layout: Use shape in generalizes check (in Python)

- The addition to the check in the relayout loop in `apply_layout_op` should result in skipping some no-op relayouts
- The assert in `disassemble` also needs to be updated because it won't hold now that relayout is skipped more (relayout guarantees the defining layout to be equal to the input layout)

PiperOrigin-RevId: 571066259
This commit is contained in:
Tomás Longeri 2023-10-05 10:43:42 -07:00 committed by jax authors
parent 633f68a398
commit 68c84a6c5c

View File

@ -1554,7 +1554,7 @@ def apply_layout_op(ctx: RewriteContext, op: ir.OpView):
lo = parse_layout(arr_attr[res_idx], vty)
if lo is None:
raise ValueError("vector result should have a defined layout")
if lo.generalizes(li):
if lo.generalizes(li, vty.shape):
continue
with ir.InsertionPoint(op), op.location:
new_v = relayout(
@ -3218,7 +3218,7 @@ def disassemble(layout: VectorLayout, val: ir.Value) -> np.ndarray:
arr_attr = ir.ArrayAttr(op.attributes["out_layout"])
def_layout = parse_layout(arr_attr[res_idx], vty)
assert type(def_layout) is type(layout)
assert def_layout.generalizes(layout)
assert def_layout.generalizes(layout, vty.shape)
def_layout_shape = def_layout.tile_array_shape(vty.shape)
if isinstance(op.opview, tpu.RollVectorsOp):
tile_vals = op.operands