[Mosaic] Allow vector.shape_cast to (un)fold the sublane dim, for as long as it remains a multiple of sublane tiling

The old guards were overly restrictive, and we can actually treat a much larger class of reshapes as no-ops.

PiperOrigin-RevId: 570049016
This commit is contained in:
Adam Paszke 2023-10-02 06:34:54 -07:00 committed by jax authors
parent 77d11e4dfd
commit c9851ac7f3
2 changed files with 11 additions and 11 deletions

View File

@ -972,6 +972,16 @@ class VectorLayoutInferer {
setLayout(op, layout, layout);
return success();
}
// Sublane (un)tiling
if (res_rank >= 2 && layout.offsets() == LayoutOffsets{0, 0} &&
layout.tiling()[1] == target_shape_[1] &&
src_ty.getDimSize(src_ty.getRank() - 1) ==
res_shape[res_shape.size() - 1] &&
src_ty.getDimSize(src_ty.getRank() - 2) % layout.tiling()[0] == 0 &&
res_shape[res_shape.size() - 2] % layout.tiling()[0] == 0) {
setLayout(op, layout, layout);
return success();
}
unsigned bitwidth = src_ty.getElementTypeBitWidth();
auto native_tiling = nativeTiling(bitwidth);
if (layout.tiling() != native_tiling) {
@ -991,16 +1001,6 @@ class VectorLayoutInferer {
ImplicitDim::kSecondMinor));
return success();
}
// Sublane (un)tiling
if (src_ty.getElementTypeBitWidth() == kNativeBitwidth &&
src_ty.getDimSize(src_ty.getRank() - 1) ==
res_shape[res_shape.size() - 1] &&
layout.offsets() == LayoutOffsets{0, 0} &&
src_ty.getDimSize(src_ty.getRank() - 2) % target_shape_[0] == 0 &&
res_shape[res_shape.size() - 2] % target_shape_[0] == 0) {
setLayout(op, layout, layout);
return success();
}
// Insert a singleton lane dimension. The old lane dimension ends up
// in the sublane dimension. Other axes can be reshaped arbitrarily.
if (src_ty.getElementTypeBitWidth() == kNativeBitwidth &&

View File

@ -2653,8 +2653,8 @@ def _vector_shape_cast_rule(ctx: RewriteContext, op: vector.ShapeCastOp, # pyli
layout_in.implicit_dim is None
and layout_out.implicit_dim is None
and layout_in.offsets == layout_out.offsets == (0, 0)
and layout_in.has_native_tiling
and layout_in.tiling == layout_out.tiling
and layout_in.tiling[-1] == TARGET_SHAPE.lanes
and dst_ty.shape[-1] == src_ty.shape[-1]
and dst_ty.shape[-2] % layout_in.tiling[-2] == 0
and src_ty.shape[-2] % layout_in.tiling[-2] == 0