mirror of https://github.com/ROCm/jax.git
synced 2025-04-19 05:16:06 +00:00
[Mosaic] Expand vector.shape_cast no-op detection for expanding/shrinking lane shape casts
- Remove restriction on sublane tiling being 1 or a multiple of 8 on the expanded shape.
- Support packed types.

PiperOrigin-RevId: 637777493
parent 3fb9acf01a
commit 97f9a5e80e
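The no-op cases this commit targets are pure layout arithmetic: a lane shape cast can be elided when each source vreg already holds one contiguous run of the flattened data and that run lands unchanged in a destination vreg. Below is a minimal standalone sketch of that arithmetic, assuming the usual (8, 128) vreg geometry; the vreg_slice helper and the bf16 shapes are illustrative assumptions, not code from this patch.

#include <array>
#include <cassert>
#include <cstdint>
#include <cstdio>

// (sublanes, lanes) of one vector register.
constexpr std::array<int64_t, 2> kTargetShape = {8, 128};

// Footprint of data covered by one vreg for a given tiling, where `packing`
// is how many values of the element type fit in a 32-bit word.
std::array<int64_t, 2> vreg_slice(std::array<int64_t, 2> tiling, int packing) {
  const int64_t vreg_capacity = packing * kTargetShape[0] * kTargetShape[1];
  const int64_t tiles_per_vreg = vreg_capacity / (tiling[0] * tiling[1]);
  return {tiling[0], tiling[1] * tiles_per_vreg};
}

int main() {
  // bf16 (packing = 2), shrinking cast (..., m * 128 * packing) -> (..., 128).
  const int packing = 2;
  const auto src_slice = vreg_slice({1, 128 * packing}, packing);  // {1, 2048}
  const auto dst_slice = vreg_slice({8 * packing, 128}, packing);  // {16, 128}
  const std::array<int64_t, 2> src_shape = {64, 2048};
  const std::array<int64_t, 2> dst_shape = {1024, 128};
  // Mirrors the shape of the new guard: dst last dim fills a vreg slice
  // exactly, dst sublane dim tiles evenly, and src last dim is a whole
  // number of vreg slices.
  const bool no_op = dst_shape[1] == dst_slice[1] &&
                     dst_shape[0] % dst_slice[0] == 0 &&
                     src_shape[1] % src_slice[1] == 0;
  std::printf("no_op = %s\n", no_op ? "true" : "false");  // no_op = true
  assert(no_op);
}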
@@ -3454,6 +3454,10 @@ LogicalResult vector_shape_cast_rule(RewriteContext &ctx, Operation &op,
   using Tiling = std::array<int64_t, 2>;
   const VectorLayout &layout_in = *layouts_in.front();
   const VectorLayout &layout_out = *layouts_out.front();
+  TPU_ASSERT_EQ_OP(
+      layout_in.bitwidth(),
+      layout_out.bitwidth());  // This should be guaranteed through MLIR
+                               // verifier plus our layoutIsValidForValue check
   ImplicitLocOpBuilder builder(op.getLoc(), &op);
   auto shape_cast_op = cast<vector::ShapeCastOp>(op);
   const VectorType src_ty = shape_cast_op.getSourceVectorType();
@@ -3462,6 +3466,10 @@ LogicalResult vector_shape_cast_rule(RewriteContext &ctx, Operation &op,
   const ArrayRef<int64_t> dst_shape = dst_ty.getShape();
   const int layout_rank = layout_in.layout_rank();
   bool no_op = false;
+  const std::array<int64_t, 2> src_vreg_slice =
+      layout_in.vregSlice(ctx.target_shape);
+  const std::array<int64_t, 2> dst_vreg_slice =
+      layout_out.vregSlice(ctx.target_shape);
   // TODO(tlongeri): It looks like this could probably be simplified by using
   // VectorLayout::implicitShape()
   if (layout_in == layout_out && src_ty.getShape().take_back(layout_rank) ==
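The two vregSlice values hoisted here feed the rewritten guards in the next hunk. For intuition, a hedged sketch, assuming vregSlice reports {tiling[0], tiling[1] * tiles_per_vreg}, the data footprint of one vreg: for a (1, 128 * packing) tiling the lane component equals the old target_shape[0] * target_shape[1] constant when packing is 1, and scales with packing otherwise, which is what lets one condition cover packed types.

#include <cstdint>
#include <cstdio>
#include <initializer_list>

// Lane extent of one vreg for a (tile_rows, tile_cols) tiling; assumes an
// (8, 128) vreg holding `packing` values per 32-bit word.
int64_t lane_slice(int64_t tile_rows, int64_t tile_cols, int packing) {
  const int64_t vreg_capacity = packing * 8 * 128;
  return tile_cols * (vreg_capacity / (tile_rows * tile_cols));
}

int main() {
  for (int packing : {1, 2, 4}) {  // 32-, 16-, and 8-bit element types
    std::printf("packing %d: tiling (1, %4d) -> lane slice %lld\n", packing,
                128 * packing,
                (long long)lane_slice(1, 128 * packing, packing));
  }
  // Prints 1024, 2048, 4096: the old 8 * 128 divisor generalized by packing.
}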
@@ -3508,33 +3516,23 @@ LogicalResult vector_shape_cast_rule(RewriteContext &ctx, Operation &op,
              layout_out.implicit_dim() == VectorLayout::ImplicitDim::kNone &&
              layout_in.offsets() == layout_out.offsets() &&
              layout_in.offsets() == LayoutOffsets{0, 0} &&
-             layout_in.tiling() == Tiling{1, ctx.target_shape[1]} &&
-             layout_out.hasNaturalTopology(ctx.target_shape) &&
-             *(dst_shape.end() - 1) != *(src_shape.end() - 1) &&
-             *(dst_shape.end() - 1) == ctx.target_shape[1] &&
-             *(dst_shape.end() - 2) % ctx.target_shape[0] == 0 &&
-             *(src_shape.end() - 1) %
-                     (ctx.target_shape[0] * ctx.target_shape[1]) ==
-                 0 &&
-             (*(src_shape.end() - 2) == 1 ||
-              *(src_shape.end() - 2) % ctx.target_shape[0] == 0)) {
-    // Shapecast (..., m * 128) -> (..., 128).
+             layout_in.tiling()[0] == 1 &&
+             layout_out.hasNativeTiling(ctx.target_shape) &&
+             *(dst_shape.end() - 1) == dst_vreg_slice[1] &&
+             *(dst_shape.end() - 2) % dst_vreg_slice[0] == 0 &&
+             *(src_shape.end() - 1) % src_vreg_slice[1] == 0) {
+    // Shapecast (..., m * 128 * packing) -> (..., 128).
     no_op = true;
   } else if (layout_in.implicit_dim() == VectorLayout::ImplicitDim::kNone &&
              layout_out.implicit_dim() == VectorLayout::ImplicitDim::kNone &&
              layout_in.offsets() == LayoutOffsets{0, 0} &&
              layout_out.offsets() == LayoutOffsets{0, 0} &&
-             layout_in.hasNaturalTopology(ctx.target_shape) &&
-             layout_out.tiling() == Tiling{1, ctx.target_shape[1]} &&
-             *(src_shape.end() - 1) != *(dst_shape.end() - 1) &&
-             *(src_shape.end() - 1) == ctx.target_shape[1] &&
-             *(src_shape.end() - 2) % ctx.target_shape[0] == 0 &&
-             *(dst_shape.end() - 1) %
-                     (ctx.target_shape[0] * ctx.target_shape[1]) ==
-                 0 &&
-             (*(dst_shape.end() - 2) == 1 ||
-              *(dst_shape.end() - 2) % ctx.target_shape[0] == 0)) {
-    // Shapecast (..., 128) -> (..., m * 128).
+             layout_in.hasNativeTiling(ctx.target_shape) &&
+             layout_out.tiling()[0] == 1 &&
+             *(src_shape.end() - 1) == src_vreg_slice[1] &&
+             *(src_shape.end() - 2) % src_vreg_slice[0] == 0 &&
+             *(dst_shape.end() - 1) % dst_vreg_slice[1] == 0) {
+    // Shapecast (..., 128) -> (..., m * 128 * packing).
     no_op = true;
   }
   FAILUREOR_ASSIGN_OR_RETURN(
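Note what disappeared from the guards above, matching the commit message: the requirement that the second-to-last dim of the expanded side be 1 or a multiple of target_shape[0]. A small sketch of one f32 case ((4, 1024) -> (32, 128), shapes made up for illustration) that the old guard rejected and the new one accepts:

#include <array>
#include <cstdint>
#include <cstdio>

int main() {
  constexpr std::array<int64_t, 2> target_shape = {8, 128};
  constexpr std::array<int64_t, 2> src_shape = {4, 1024};  // expanded side
  constexpr std::array<int64_t, 2> dst_shape = {32, 128};
  // Old guard: src sublane dim had to be 1 or a multiple of target_shape[0].
  const bool old_ok =
      src_shape[1] % (target_shape[0] * target_shape[1]) == 0 &&
      (src_shape[0] == 1 || src_shape[0] % target_shape[0] == 0);
  // New guard (f32, packing 1): src vreg slice is (1, 1024), dst slice is
  // (8, 128); no restriction on the expanded side's sublane dim remains.
  const bool new_ok = dst_shape[1] == 128 && dst_shape[0] % 8 == 0 &&
                      src_shape[1] % 1024 == 0;
  std::printf("old: %d, new: %d\n", old_ok, new_ok);  // old: 0, new: 1
}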
@@ -1236,52 +1236,47 @@ class VectorLayoutInferer {
       setLayout(op, layout, layout);
       return success();
     }
+    const unsigned bitwidth = src_ty.getElementTypeBitWidth();
+    const auto native_tiling = nativeTiling(bitwidth);
     // Lane (un)tiling.
-    if (layout.tiling()[1] == target_shape_[1] &&
-        src_ty.getDimSize(src_ty.getRank() - 1) !=
+    if (src_ty.getDimSize(src_ty.getRank() - 1) !=
             res_shape[res_shape.size() - 1] &&
         src_ty.getDimSize(src_ty.getRank() - 1) % layout.tiling()[1] == 0 &&
         res_shape[res_shape.size() - 1] % layout.tiling()[1] == 0) {
-      // TODO(jevinjiang): support shapecast along lane with any bitwidth.
-      if (src_ty.getElementTypeBitWidth() != kNativeBitwidth) {
-        NYI("Shapecast along lane dimension when bitwidth is not 32");
-      }
-
-      // When we shapecast from input shape (..., m * target_shape_[1]) to
-      // output shape (..., target_shape_[1]), the reshape becomes no-op when
-      // input is densely packed with tiling (1, target_shape_[1]) and
-      // output has the native tiling.
+      const int packing = kNativeBitwidth / bitwidth;
+      const auto elements_per_vreg = native_tiling[0] * native_tiling[1];
+      // When we shapecast from input shape
+      // (..., m * target_shape_[1] * packing) to output shape
+      // (..., target_shape_[1]), the reshape becomes no-op when input is
+      // densely packed with tiling (1, target_shape_[1] * packing) and output
+      // has the native tiling.
       if (*(res_shape.end() - 1) == target_shape_[1] &&
-          *(res_shape.end() - 2) % target_shape_[0] == 0 &&
-          *(src_shape.end() - 1) % (target_shape_[0] * target_shape_[1]) ==
-              0 &&
-          (*(src_shape.end() - 2) == 1 ||
-           *(src_shape.end() - 2) % target_shape_[0] == 0)) {
-        // Inferring in_layout to have tiling (1, 128) triggers any
+          *(res_shape.end() - 2) % native_tiling[0] == 0 &&
+          *(src_shape.end() - 1) % elements_per_vreg == 0) {
+        // Inferring in_layout to have tiling (1, 128 * packing) triggers any
         // necessary relayout before shapecast.
-        setLayout(op,
-                  VectorLayout(layout.bitwidth(), {0, 0},
-                               {1, target_shape_[1]}, ImplicitDim::kNone),
-                  VectorLayout(layout.bitwidth(), {0, 0}, default_tiling_,
-                               ImplicitDim::kNone));
+        setLayout(
+            op,
+            VectorLayout(layout.bitwidth(), {0, 0},
+                         {1, target_shape_[1] * packing}, ImplicitDim::kNone),
+            VectorLayout(layout.bitwidth(), {0, 0}, native_tiling,
+                         ImplicitDim::kNone));
         return success();
       }

-      // When we shapecast from input shape (..., target_shape_[1]) to
-      // output shape (..., m * target_shape_[1]), the reshape becomes no-op
-      // when input has the native tiling and output is densely packed with
-      // tiling (1, target_shape_[1]).
+      // When we shapecast from input shape (..., target_shape_[1]) to output
+      // shape (..., m * target_shape_[1] * packing), the reshape becomes
+      // no-op when input has the native tiling and output is densely packed
+      // with tiling (1, target_shape_[1] * packing).
       if (*(src_shape.end() - 1) == target_shape_[1] &&
-          *(src_shape.end() - 2) % target_shape_[0] == 0 &&
-          *(res_shape.end() - 1) % (target_shape_[0] * target_shape_[1]) ==
-              0 &&
-          (*(res_shape.end() - 2) == 1 ||
-           *(res_shape.end() - 2) % target_shape_[0] == 0)) {
+          *(src_shape.end() - 2) % native_tiling[0] == 0 &&
+          *(res_shape.end() - 1) % elements_per_vreg == 0) {
         setLayout(op,
-                  VectorLayout(layout.bitwidth(), {0, 0}, default_tiling_,
+                  VectorLayout(layout.bitwidth(), {0, 0}, native_tiling,
                                ImplicitDim::kNone),
                   VectorLayout(layout.bitwidth(), {0, 0},
-                               {1, target_shape_[1]}, ImplicitDim::kNone));
+                               {1, target_shape_[1] * packing},
+                               ImplicitDim::kNone));
         return success();
       }

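On the inference side, everything new is a simple function of the element bitwidth. A quick numeric check of the quantities involved, assuming kNativeBitwidth is 32 and native tiling is (8 * packing, 128); this is plain arithmetic, not a call into Mosaic:

#include <cstdio>
#include <initializer_list>

int main() {
  const int kNativeBitwidth = 32;
  for (int bitwidth : {32, 16, 8}) {
    const int packing = kNativeBitwidth / bitwidth;
    const int sublanes = 8 * packing;             // native tiling rows
    const int elements_per_vreg = sublanes * 128;  // native_tiling[0] * [1]
    std::printf(
        "bitwidth %2d: packing %d, native tiling (%2d, 128), "
        "elements_per_vreg %4d, in tiling (1, %3d)\n",
        bitwidth, packing, sublanes, elements_per_vreg, 128 * packing);
  }
}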
@@ -1289,8 +1284,6 @@ class VectorLayoutInferer {
       op.emitOpError("unsupported shape cast");
       return failure();
     }
-    unsigned bitwidth = src_ty.getElementTypeBitWidth();
-    auto native_tiling = nativeTiling(bitwidth);
     if (layout.tiling() != native_tiling) {
       layout = VectorLayout(bitwidth, layout.offsets(), native_tiling,
                             layout.implicit_dim());