[XLA:Mosaic] Support shrinking lane dim to 128 in shapecast.

PiperOrigin-RevId: 574965883
This commit is contained in:
Jevin Jiang 2023-10-19 12:25:12 -07:00 committed by jax authors
parent 741b71fe85
commit bb30d3ee9f
3 changed files with 61 additions and 1 deletions

View File

@ -2261,6 +2261,7 @@ LogicalResult vector_shape_cast_rule(RewriteContext &ctx, Operation &op,
if (!layouts_out.front().has_value()) {
return op.emitOpError("Expected non-null output layout");
}
using Tiling = std::array<int64_t, 2>;
const VectorLayout &layout_in = *layouts_in.front();
const VectorLayout &layout_out = *layouts_out.front();
ImplicitLocOpBuilder builder(op.getLoc(), &op);
@ -2313,6 +2314,17 @@ LogicalResult vector_shape_cast_rule(RewriteContext &ctx, Operation &op,
*(dst_shape.end() - 2) % layout_in.tiling()[0] == 0 &&
*(src_shape.end() - 2) % layout_in.tiling()[0] == 0) {
no_op = true;
} else if (layout_in.implicit_dim() == VectorLayout::ImplicitDim::kNone &&
layout_out.implicit_dim() == VectorLayout::ImplicitDim::kNone &&
layout_in.offsets() == layout_out.offsets() &&
layout_in.offsets() == LayoutOffsets{0, 0} &&
layout_in.tiling() == Tiling{1, ctx.target_shape[1]} &&
layout_out.hasNaturalTopology(ctx.target_shape) &&
*(dst_shape.end() - 1) != *(src_shape.end() - 1) &&
*(dst_shape.end() - 1) == ctx.target_shape[1] &&
*(dst_shape.end() - 2) % layout_out.tiling()[0] == 0 &&
*(src_shape.end() - 1) % layout_in.tiling()[1] == 0) {
no_op = true;
}
FAILUREOR_ASSIGN_OR_RETURN(
xla::Array<Value> src_vregs,

View File

@ -995,7 +995,7 @@ class VectorLayoutInferer {
setLayout(op, layout, layout);
return success();
}
// Sublane (un)tiling
// Sublane (un)tiling.
if (res_rank >= 2 && layout.offsets() == LayoutOffsets{0, 0} &&
layout.tiling()[1] == target_shape_[1] &&
src_ty.getDimSize(src_ty.getRank() - 1) ==
@ -1005,6 +1005,42 @@ class VectorLayoutInferer {
setLayout(op, layout, layout);
return success();
}
// Lane (un)tiling.
if (res_rank >= 2 && layout.offsets() == LayoutOffsets{0, 0} &&
layout.tiling()[1] == target_shape_[1] &&
src_ty.getDimSize(src_ty.getRank() - 1) !=
res_shape[res_shape.size() - 1] &&
src_ty.getDimSize(src_ty.getRank() - 1) % layout.tiling()[1] == 0 &&
res_shape[res_shape.size() - 1] % layout.tiling()[1] == 0) {
// TODO(jevinjiang): support shapecast along lane with any bitwidth.
if (src_ty.getElementTypeBitWidth() != kNativeBitwidth) {
NYI("Shapecast along lane dimension when bitwidth is not 32");
}
// Inferring in_layout to have tiling (1, 128) triggers any necessary
// relayout before shapecast.
setInLayout(op,
{VectorLayout(layout.bitwidth(), layout.offsets(),
{1, target_shape_[1]}, ImplicitDim::kNone)});
// If the input has tiling (1, target_shape_[1]) and the last two dims
// of result shape are [n * target_shape_[0], target_shape_[1]], the
// reshape becomes a no-op if only we change the tiling to match
// target_shape_. For example, reshaping to layouts like the following
// will not require any data movement:
// 8x8x128 (1, 128)
// 4x16x128 (1, 128)
// 8x8x128 (8, 128)
// ....
if (res_shape[res_shape.size() - 1] == target_shape_[1] &&
res_shape[res_shape.size() - 2] % target_shape_[0] == 0) {
setOutLayout(op, VectorLayout(layout.bitwidth(), layout.offsets(),
default_tiling_, ImplicitDim::kNone));
return success();
}
// TODO(b/299253805): support shapecast along lane for other cases.
op.emitOpError("unsupported shape cast");
return failure();
}
unsigned bitwidth = src_ty.getElementTypeBitWidth();
auto native_tiling = nativeTiling(bitwidth);
if (layout.tiling() != native_tiling) {

View File

@ -2748,6 +2748,18 @@ def _vector_shape_cast_rule(ctx: RewriteContext, op: vector.ShapeCastOp, # pyli
and src_ty.shape[-2] % layout_in.tiling[-2] == 0
):
no_op = True
elif (
layout_in.implicit_dim is None
and layout_out.implicit_dim is None
and layout_out.offsets == layout_in.offsets == (0, 0)
and layout_in.tiling == (1, TARGET_SHAPE.lanes)
and layout_out.has_natural_topology
and dst_ty.shape[-1] != src_ty.shape[-1]
and dst_ty.shape[-1] == TARGET_SHAPE.lanes
and dst_ty.shape[-2] % layout_out.tiling[-2] == 0
and src_ty.shape[-1] % layout_in.tiling[-1] == 0
): # n x (m * 128) -> n x m x 128.
no_op = True
src_vregs = disassemble(layout_in, op.source)
if no_op: