[Mosaic] Expand vector.shape_cast no-op detection for expanding/shrinking lane shape casts

- Remove the restriction that the second-to-minor (sublane) dimension of the expanded shape must be 1 or a multiple of 8.
- Support packed types.

PiperOrigin-RevId: 637777493
commit 97f9a5e80e
parent 3fb9acf01a
Author: Tomás Longeri
Date: 2024-05-27 22:30:54 -07:00
Committed by: jax authors
2 changed files with 48 additions and 57 deletions
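As a worked illustration of what the relaxed check admits: for a packed 16-bit type (packing = 2), a cast like vector<8x2048xbf16> -> vector<128x128xbf16> can now be a no-op, with the source densely packed at offsets (0, 0) with tiling (1, 256) and the result in the native (16, 128) tiling. The sketch below is not code from this change; it assumes the usual (8, 128) vreg shape and a 32-bit native bitwidth, and the helper name is made up. It simply evaluates the new lane-shrinking condition for that example:

#include <array>
#include <cstdint>
#include <vector>

// Illustrative check (not from the patch): the inference-side condition for a
// lane-shrinking shape cast (..., m * 128 * packing) -> (..., 128), assuming
// an (8, 128) vreg shape and a 32-bit native bitwidth.
bool lane_shrinking_noop_candidate(const std::vector<int64_t> &src_shape,
                                   const std::vector<int64_t> &res_shape,
                                   int bitwidth) {
  constexpr int64_t kSublanes = 8, kLanes = 128;
  const int64_t packing = 32 / bitwidth;  // 32-bit: 1, 16-bit: 2, 8-bit: 4
  const std::array<int64_t, 2> native_tiling = {kSublanes * packing, kLanes};
  const int64_t elements_per_vreg = native_tiling[0] * native_tiling[1];
  return res_shape.back() == kLanes &&
         res_shape[res_shape.size() - 2] % native_tiling[0] == 0 &&
         src_shape.back() % elements_per_vreg == 0;
}

int main() {
  // bf16: (8, 2048) -> (128, 128) now qualifies; the old rule also required a
  // 32-bit element type.
  return lane_shrinking_noop_candidate({8, 2048}, {128, 128}, 16) ? 0 : 1;
}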


@@ -3454,6 +3454,10 @@ LogicalResult vector_shape_cast_rule(RewriteContext &ctx, Operation &op,
   using Tiling = std::array<int64_t, 2>;
   const VectorLayout &layout_in = *layouts_in.front();
   const VectorLayout &layout_out = *layouts_out.front();
+  TPU_ASSERT_EQ_OP(
+      layout_in.bitwidth(),
+      layout_out.bitwidth());  // This should be guaranteed through MLIR
+                               // verifier plus our layoutIsValidForValue check
   ImplicitLocOpBuilder builder(op.getLoc(), &op);
   auto shape_cast_op = cast<vector::ShapeCastOp>(op);
   const VectorType src_ty = shape_cast_op.getSourceVectorType();
@@ -3462,6 +3466,10 @@ LogicalResult vector_shape_cast_rule(RewriteContext &ctx, Operation &op,
   const ArrayRef<int64_t> dst_shape = dst_ty.getShape();
   const int layout_rank = layout_in.layout_rank();
   bool no_op = false;
+  const std::array<int64_t, 2> src_vreg_slice =
+      layout_in.vregSlice(ctx.target_shape);
+  const std::array<int64_t, 2> dst_vreg_slice =
+      layout_out.vregSlice(ctx.target_shape);
   // TODO(tlongeri): It looks like this could probably be simplified by using
   // VectorLayout::implicitShape()
   if (layout_in == layout_out && src_ty.getShape().take_back(layout_rank) ==
@@ -3508,33 +3516,23 @@ LogicalResult vector_shape_cast_rule(RewriteContext &ctx, Operation &op,
              layout_out.implicit_dim() == VectorLayout::ImplicitDim::kNone &&
              layout_in.offsets() == layout_out.offsets() &&
              layout_in.offsets() == LayoutOffsets{0, 0} &&
-             layout_in.tiling() == Tiling{1, ctx.target_shape[1]} &&
-             layout_out.hasNaturalTopology(ctx.target_shape) &&
-             *(dst_shape.end() - 1) != *(src_shape.end() - 1) &&
-             *(dst_shape.end() - 1) == ctx.target_shape[1] &&
-             *(dst_shape.end() - 2) % ctx.target_shape[0] == 0 &&
-             *(src_shape.end() - 1) %
-                     (ctx.target_shape[0] * ctx.target_shape[1]) ==
-                 0 &&
-             (*(src_shape.end() - 2) == 1 ||
-              *(src_shape.end() - 2) % ctx.target_shape[0] == 0)) {
-    // Shapecast (..., m * 128) -> (..., 128).
+             layout_in.tiling()[0] == 1 &&
+             layout_out.hasNativeTiling(ctx.target_shape) &&
+             *(dst_shape.end() - 1) == dst_vreg_slice[1] &&
+             *(dst_shape.end() - 2) % dst_vreg_slice[0] == 0 &&
+             *(src_shape.end() - 1) % src_vreg_slice[1] == 0) {
+    // Shapecast (..., m * 128 * packing) -> (..., 128).
     no_op = true;
   } else if (layout_in.implicit_dim() == VectorLayout::ImplicitDim::kNone &&
              layout_out.implicit_dim() == VectorLayout::ImplicitDim::kNone &&
              layout_in.offsets() == LayoutOffsets{0, 0} &&
             layout_out.offsets() == LayoutOffsets{0, 0} &&
-             layout_in.hasNaturalTopology(ctx.target_shape) &&
-             layout_out.tiling() == Tiling{1, ctx.target_shape[1]} &&
-             *(src_shape.end() - 1) != *(dst_shape.end() - 1) &&
-             *(src_shape.end() - 1) == ctx.target_shape[1] &&
-             *(src_shape.end() - 2) % ctx.target_shape[0] == 0 &&
-             *(dst_shape.end() - 1) %
-                     (ctx.target_shape[0] * ctx.target_shape[1]) ==
-                 0 &&
-             (*(dst_shape.end() - 2) == 1 ||
-              *(dst_shape.end() - 2) % ctx.target_shape[0] == 0)) {
-    // Shapecast (..., 128) -> (..., m * 128).
+             layout_in.hasNativeTiling(ctx.target_shape) &&
+             layout_out.tiling()[0] == 1 &&
+             *(src_shape.end() - 1) == src_vreg_slice[1] &&
+             *(src_shape.end() - 2) % src_vreg_slice[0] == 0 &&
+             *(dst_shape.end() - 1) % dst_vreg_slice[1] == 0) {
+    // Shapecast (..., 128) -> (..., m * 128 * packing).
     no_op = true;
   }
   FAILUREOR_ASSIGN_OR_RETURN(
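The rewrite rule above now phrases its divisibility checks in terms of VectorLayout::vregSlice instead of raw target_shape products. A minimal sketch of that arithmetic, assuming vregSlice is the (sublane, lane) window covered by one vreg (my reading of the layout helper, not code from this change): with tiling (1, 128 * packing) the slice is (1, 1024 * packing), and with the native tiling (8 * packing, 128) it is (8 * packing, 128), so the checks require both sides of the cast to consist of whole vregs.

#include <array>
#include <cstdint>

// Hypothetical helper mirroring my reading of VectorLayout::vregSlice: the
// (sublane, lane) window one vector register holds, given a tiling, the
// packing factor (32 / bitwidth), and an (8, 128) target shape.
constexpr std::array<int64_t, 2> VregSlice(std::array<int64_t, 2> tiling,
                                           int64_t packing) {
  constexpr int64_t kSublanes = 8, kLanes = 128;
  const int64_t tiles_per_vreg =
      (packing * kSublanes * kLanes) / (tiling[0] * tiling[1]);
  return {tiling[0], tiles_per_vreg * tiling[1]};
}

int main() {
  constexpr int64_t packing = 2;  // e.g. bf16
  // Densely packed layout used on the long-lane side of the cast:
  constexpr auto packed = VregSlice({1, 128 * packing}, packing);  // {1, 2048}
  // Native tiling used on the 128-lane side:
  constexpr auto native = VregSlice({8 * packing, 128}, packing);  // {16, 128}
  static_assert(packed[1] == 2048 && native[0] == 16,
                "whole-vreg divisibility checks use these slice sizes");
  return 0;
}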


@@ -1236,52 +1236,47 @@ class VectorLayoutInferer {
       setLayout(op, layout, layout);
       return success();
     }
+    const unsigned bitwidth = src_ty.getElementTypeBitWidth();
+    const auto native_tiling = nativeTiling(bitwidth);
     // Lane (un)tiling.
-    if (layout.tiling()[1] == target_shape_[1] &&
-        src_ty.getDimSize(src_ty.getRank() - 1) !=
+    if (src_ty.getDimSize(src_ty.getRank() - 1) !=
            res_shape[res_shape.size() - 1] &&
        src_ty.getDimSize(src_ty.getRank() - 1) % layout.tiling()[1] == 0 &&
        res_shape[res_shape.size() - 1] % layout.tiling()[1] == 0) {
-      // TODO(jevinjiang): support shapecast along lane with any bitwidth.
-      if (src_ty.getElementTypeBitWidth() != kNativeBitwidth) {
-        NYI("Shapecast along lane dimension when bitwidth is not 32");
-      }
-      // When we shapecast from input shape (..., m * target_shape_[1]) to
-      // output shape (..., target_shape_[1]), the reshape becomes no-op when
-      // input is densely packed with tiling (1, target_shape_[1]) and
-      // output has the native tiling.
+      const int packing = kNativeBitwidth / bitwidth;
+      const auto elements_per_vreg = native_tiling[0] * native_tiling[1];
+      // When we shapecast from input shape
+      // (..., m * target_shape_[1] * packing) to output shape
+      // (..., target_shape_[1]), the reshape becomes no-op when input is
+      // densely packed with tiling (1, target_shape_[1] * packing) and output
+      // has the native tiling.
       if (*(res_shape.end() - 1) == target_shape_[1] &&
-          *(res_shape.end() - 2) % target_shape_[0] == 0 &&
-          *(src_shape.end() - 1) % (target_shape_[0] * target_shape_[1]) ==
-              0 &&
-          (*(src_shape.end() - 2) == 1 ||
-           *(src_shape.end() - 2) % target_shape_[0] == 0)) {
-        // Inferring in_layout to have tiling (1, 128) triggers any
+          *(res_shape.end() - 2) % native_tiling[0] == 0 &&
+          *(src_shape.end() - 1) % elements_per_vreg == 0) {
+        // Inferring in_layout to have tiling (1, 128 * packing) triggers any
         // necessary relayout before shapecast.
-        setLayout(op,
-                  VectorLayout(layout.bitwidth(), {0, 0},
-                               {1, target_shape_[1]}, ImplicitDim::kNone),
-                  VectorLayout(layout.bitwidth(), {0, 0}, default_tiling_,
-                               ImplicitDim::kNone));
+        setLayout(
+            op,
+            VectorLayout(layout.bitwidth(), {0, 0},
+                         {1, target_shape_[1] * packing}, ImplicitDim::kNone),
+            VectorLayout(layout.bitwidth(), {0, 0}, native_tiling,
+                         ImplicitDim::kNone));
        return success();
      }
-      // When we shapecast from input shape (..., target_shape_[1]) to
-      // output shape (..., m * target_shape_[1]), the reshape becomes no-op
-      // when input has the native tiling and output is densely packed with
-      // tiling (1, target_shape_[1]).
+      // When we shapecast from input shape (..., target_shape_[1]) to output
+      // shape (..., m * target_shape_[1] * packing), the reshape becomes
+      // no-op when input has the native tiling and output is densely packed
+      // with tiling (1, target_shape_[1] * packing).
       if (*(src_shape.end() - 1) == target_shape_[1] &&
-          *(src_shape.end() - 2) % target_shape_[0] == 0 &&
-          *(res_shape.end() - 1) % (target_shape_[0] * target_shape_[1]) ==
-              0 &&
-          (*(res_shape.end() - 2) == 1 ||
-           *(res_shape.end() - 2) % target_shape_[0] == 0)) {
+          *(src_shape.end() - 2) % native_tiling[0] == 0 &&
+          *(res_shape.end() - 1) % elements_per_vreg == 0) {
        setLayout(op,
-                  VectorLayout(layout.bitwidth(), {0, 0}, default_tiling_,
+                  VectorLayout(layout.bitwidth(), {0, 0}, native_tiling,
                               ImplicitDim::kNone),
                  VectorLayout(layout.bitwidth(), {0, 0},
-                               {1, target_shape_[1]}, ImplicitDim::kNone));
+                               {1, target_shape_[1] * packing},
+                               ImplicitDim::kNone));
        return success();
      }
@@ -1289,8 +1284,6 @@ class VectorLayoutInferer {
      op.emitOpError("unsupported shape cast");
      return failure();
    }
-    unsigned bitwidth = src_ty.getElementTypeBitWidth();
-    auto native_tiling = nativeTiling(bitwidth);
    if (layout.tiling() != native_tiling) {
      layout = VectorLayout(bitwidth, layout.offsets(), native_tiling,
                            layout.implicit_dim());
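For the opposite, lane-expanding direction, the inferer now also accepts packed types and no longer constrains the second-to-minor dimension of the expanded shape. A small sketch with made-up names, assuming the (8, 128) vreg shape and a 32-bit native bitwidth, walks the new conditions for an int8 example that the old rule would have rejected:

#include <cstdint>

int main() {
  // Lane-expanding example for int8 (packing = 4): (256, 128) -> (4, 8192).
  constexpr int64_t packing = 32 / 8;                               // = 4
  constexpr int64_t native_sublane_tile = 8 * packing;              // = 32
  constexpr int64_t elements_per_vreg = native_sublane_tile * 128;  // = 4096
  constexpr int64_t src_minor = 128, src_second = 256;
  constexpr int64_t res_minor = 8192, res_second = 4;
  static_assert(src_minor == 128 && src_second % native_sublane_tile == 0,
                "source is whole native-tiled vregs");
  static_assert(res_minor % elements_per_vreg == 0,
                "result rows are whole densely packed vregs");
  static_assert(res_second * res_minor == src_second * src_minor,
                "element counts match");
  // res_second == 4 is neither 1 nor a multiple of 8, and the type is 8-bit:
  // the old inferer rejected both; the new rule checks neither.
  return 0;
}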