mirror of
https://github.com/ROCm/jax.git
synced 2025-04-19 05:16:06 +00:00
[Mosaic] apply_vector_layout: Use shape in generalizes check (in Python)
- The addition to the check in the relayout loop in `apply_layout_op` should result in skipping some no-op relayouts - The assert in `disassemble` also needs to be updated because it won't hold now that relayout is skipped more (relayout guarantees the defining layout to be equal to the input layout) PiperOrigin-RevId: 571066259
This commit is contained in:
parent
633f68a398
commit
68c84a6c5c
@ -1554,7 +1554,7 @@ def apply_layout_op(ctx: RewriteContext, op: ir.OpView):
|
||||
lo = parse_layout(arr_attr[res_idx], vty)
|
||||
if lo is None:
|
||||
raise ValueError("vector result should have a defined layout")
|
||||
if lo.generalizes(li):
|
||||
if lo.generalizes(li, vty.shape):
|
||||
continue
|
||||
with ir.InsertionPoint(op), op.location:
|
||||
new_v = relayout(
|
||||
@ -3218,7 +3218,7 @@ def disassemble(layout: VectorLayout, val: ir.Value) -> np.ndarray:
|
||||
arr_attr = ir.ArrayAttr(op.attributes["out_layout"])
|
||||
def_layout = parse_layout(arr_attr[res_idx], vty)
|
||||
assert type(def_layout) is type(layout)
|
||||
assert def_layout.generalizes(layout)
|
||||
assert def_layout.generalizes(layout, vty.shape)
|
||||
def_layout_shape = def_layout.tile_array_shape(vty.shape)
|
||||
if isinstance(op.opview, tpu.RollVectorsOp):
|
||||
tile_vals = op.operands
|
||||
|
Loading…
x
Reference in New Issue
Block a user