[Mosaic:TPU] Add relayout for adding minor implicit dim and relax some offset restrictions on similar shape cast

This factors the minor-dimension-insertion logic out of the apply-vector-layout shape cast rule, relaxes some of its offset restrictions, and reuses it for the relayout that adds a minor implicit dim.

PiperOrigin-RevId: 702993092
Tomás Longeri 2024-12-04 23:12:52 -08:00 committed by jax authors
parent 101168740e
commit 8163e74e45
2 changed files with 135 additions and 49 deletions


@@ -342,8 +342,13 @@ def TPU_RepeatOp : TPU_Op<"repeat", [Pure]> {
}
def TPU_BroadcastInSublanesOp : TPU_Op<"broadcast_in_sublanes", [Pure]> {
let description = [{
For each sublane `i`, broadcasts the value in lane `lane + i` along the entire
sublane. If `lane + i` is not in [0, lane_count), then the value in sublane `i`
is not defined (can be anything).
}];
let arguments = (ins
- AnyVectorOfNonZeroRank:$source, // All sublanes should be equal.
+ TPU_Vreg:$source, // All sublanes should be equal.
I32Attr:$lane // Coordinates of the first element to take.
);
// Output shape should be the same, except for position dim which contains

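To make the op's semantics concrete, here is a minimal scalar model of broadcast_in_sublanes (a sketch: the vreg geometry constants, the function name, and the C++ form are illustrative assumptions, not part of the Mosaic code):

#include <array>
#include <cstdint>

// Assumed vreg geometry for illustration, e.g. 8 sublanes x 128 lanes.
constexpr int kSublanes = 8;
constexpr int kLanes = 128;
using Vreg = std::array<std::array<int32_t, kLanes>, kSublanes>;

// All sublanes of `src` are assumed equal, as the op requires. Sublane i of
// the result is filled with the value in lane `lane + i`; sublanes for which
// `lane + i` falls outside [0, kLanes) are undefined (left zeroed here).
Vreg BroadcastInSublanes(const Vreg& src, int lane) {
  Vreg out{};
  for (int i = 0; i < kSublanes; ++i) {
    if (lane + i < 0 || lane + i >= kLanes) continue;
    for (int j = 0; j < kLanes; ++j) {
      out[i][j] = src[0][lane + i];
    }
  }
  return out;
}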

@@ -655,6 +655,105 @@ FailureOr<SmallVector<Layout>> getInLayouts(
return in_layouts;
}
// Insert a minor dimension into the implicit shape. The original minor
// dimension becomes the new second minor dimension, laid out across sublanes.
//
// The returned vreg array uses the original tiling and the offsets specified in
// new_offsets to hold the value with the new implicit shape.
//
// Args:
// vregs: The vreg array with *implicit* array shape.
// ishape: The implicit shape of the represented value.
// layout: The layout used for the represented value. The implicit
// dimension is ignored, since this function operates directly at
// the level of the implicit shape.
// new_offsets: The offsets to use for the layout of the returned vreg array.
FailureOr<xla::Array<Value>> insertImplicitMinorDimension(
RewriteContext &ctx, OpBuilder &builder, const Location loc,
const xla::Array<Value> &vregs, const ArrayRef<int64_t> ishape,
const VectorLayout &layout, const LayoutOffsets new_offsets) {
if (layout.bitwidth() != 32 || !layout.hasNativeTiling(ctx.target_shape)) {
return emitError(loc, "Not implemented: Unsupported bitwidth or tiling");
}
if (layout.offsets()[1].has_value()) {
if (!new_offsets[0]) {
// TODO(tlongeri): This can only be valid if the dim size is 1.
return emitError(loc, "Not implemented: Replication mismatch");
}
if (*new_offsets[0] != *layout.offsets()[1] % ctx.target_shape[0] &&
*layout.offsets()[1] + *(ishape.end() - 1) > ctx.target_shape[1]) {
// This requires blending data from different vregs.
return emitError(loc,
"Not implemented: Misaligned offsets and shape does not "
"fit in one vreg");
}
}
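// Worked example of the check above (numbers assumed, target_shape = {8, 128}):
// - offsets()[1] = 5, minor dim size = 100: lanes [5, 105) fit in a single
//   vreg, so any new sublane offset can be produced by re-broadcasting from
//   that one source vreg.
// - offsets()[1] = 5, minor dim size = 200: the data spans two vregs along
//   lanes, so only the aligned offset 5 % 8 = 5 is supported; any other
//   offset would have to blend lanes from different source vregs.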
// new_layout is only used to get the new vreg array shape; the implicit dim
// is irrelevant, since we already have the implicit shape:
const VectorLayout new_layout(layout.bitwidth(), new_offsets, layout.tiling(),
VectorLayout::ImplicitDim::kNone);
SmallVector<int64_t> new_ishape(ishape);
new_ishape.push_back(1);
xla::Array<Value> new_vregs(new_layout.tileArrayShape(
/*src_is_implicit=*/true, /*res_is_implicit=*/true, std::move(new_ishape),
ctx.target_shape));
// Preallocate an indices vector to avoid repeated allocations:
SmallVector<int64_t> idxs;
new_vregs.Each([&](const absl::Span<const int64_t> dst_idx,
Value *const dst_vreg) {
// Indices of the new vreg in the new vreg array:
const int64_t new_2nd_minor_idx = *(dst_idx.end() - 2);
const int64_t new_3rd_minor_idx = *(dst_idx.end() - 3);
idxs.assign(dst_idx.begin(), dst_idx.end());
if (!layout.offsets()[0].has_value() && new_3rd_minor_idx != 0) {
// All vregs along that dimension are the same
*(idxs.end() - 3) = 0;
*dst_vreg = new_vregs(idxs);
} else if (!layout.offsets()[1].has_value() && new_2nd_minor_idx != 0) {
// All vregs along that dimension are the same
*(idxs.end() - 2) = 0;
*dst_vreg = new_vregs(idxs);
} else {
// dst_vreg will hold slice [row_idx, col_idx:(col_idx + target_shape[0])]
// of the after-offsets source shape
const int64_t row_idx =
layout.offsets()[0] ? new_3rd_minor_idx + *layout.offsets()[0] : 0;
const int64_t col_idx = layout.offsets()[1]
? new_2nd_minor_idx * ctx.target_shape[0] +
*layout.offsets()[1] - *new_offsets[0]
: 0;
idxs.pop_back();
*(idxs.end() - 2) = row_idx / ctx.target_shape[0];
*(idxs.end() - 1) = col_idx / ctx.target_shape[1];
Value src_vreg = vregs(idxs);
// TODO(tlongeri): We can sometimes skip operations when dst_vreg will
// hold a single non-padding element (first or last) and we don't need
// replication in the output.
if (layout.offsets()[0].has_value()) {
// [ . . . . . . . . ] [ . . . . a b c d ]
// [ . . . . a b c d ] => [ . . . . a b c d ]
// [ . . . . . . . . ] [ . . . . a b c d ]
// [ . . . . . . . . ] [ . . . . a b c d ]
src_vreg = broadcastSublane(
builder, src_vreg,
/*sublane_idx=*/row_idx % ctx.target_shape[0], ctx.target_shape);
}
if (layout.offsets()[1].has_value()) {
// [ . . . . a b c d ] [ a a a a a a a a ]
// [ . . . . a b c d ] => [ b b b b b b b b ]
// [ . . . . a b c d ] [ c c c c c c c c ]
// [ . . . . a b c d ] [ d d d d d d d d ]
src_vreg = builder.create<BroadcastInSublanesOp>(
loc, src_vreg.getType(), src_vreg,
/*lane=*/col_idx % ctx.target_shape[1]);
}
*dst_vreg = src_vreg;
}
});
return new_vregs;
}
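// Worked example of the mapping above (shapes assumed): with
// target_shape = {8, 128}, ishape = [3, 100] and all offsets zero, the new
// implicit shape [3, 100, 1] tiles into a [3, 13, 1] vreg array. The dst
// vreg at [r, c, 0] reads src vreg [r / 8, c * 8 / 128] = [0, 0], broadcasts
// source sublane r % 8 across all sublanes, then broadcast_in_sublanes with
// lane = c * 8 % 128 leaves source element [r, c * 8 + i] in sublane i: the
// original minor dimension now runs across sublanes, as documented above.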
LogicalResult elementwise_op_rule(RewriteContext &ctx, Operation &op,
const ArrayRef<Layout> layouts_in,
const ArrayRef<Layout> layouts_out) {
@@ -4155,54 +4254,16 @@ LogicalResult vector_shape_cast_rule(RewriteContext &ctx, Operation &op,
layout_in.bitwidth() == 32 &&
layout_in.hasNativeTiling(ctx.target_shape) &&
layout_in.tiling() == layout_out.tiling() &&
- layout_in.offsets()[0].value_or(0) == 0 &&
- layout_in.offsets()[1] == 0 && layout_out.offsets()[0] == 0
- // layout_out.offsets[1] can be anything, as we produce a
- // replicated result
- ) {
- // First, insert the new singleton lane dimension.
- SmallVector<int64_t> s = layout_in.implicitShape(src_shape);
- s.push_back(1);
- xla::Array<Value> dst_vregs_local(layout_out.tileArrayShape(
- /*src_is_implicit=*/true, /*res_is_implicit=*/true, std::move(s),
- ctx.target_shape));
- TPU_ASSERT_EQ_OP(dst_vregs_local.dimensions().back(),
- 1); // We're inserting a singleton dimension
- dst_vregs_local.Each(
- [&](const absl::Span<const int64_t> dst_idx, Value *const dst_vreg) {
- const int64_t col_idx = *(dst_idx.end() - 2);
- const int64_t row_idx = *(dst_idx.end() - 3);
- auto [sublanes_in_lane, rem] =
- std::div(ctx.target_shape[1], ctx.target_shape[0]);
- CHECK_EQ(rem, 0);
- if (!layout_in.offsets()[0].has_value() && row_idx != 0) {
- return; // All vregs along that dimension are the same.
- }
- SmallVector<int64_t> src_idx(toArrayRef(dst_idx));
- src_idx.pop_back();
- *(src_idx.end() - 2) /= ctx.target_shape[0];
- *(src_idx.end() - 1) /= sublanes_in_lane;
- Value col_vreg = src_vregs(src_idx);
- // BroadcastInSublanesOp requires the sublanes to be replicated.
- if (layout_in.offsets()[0].has_value()) {
- const int32_t sublane = row_idx % ctx.target_shape[0];
- col_vreg = broadcastSublane(builder, col_vreg, sublane,
- ctx.target_shape);
- }
- *dst_vreg = builder.create<BroadcastInSublanesOp>(
- col_vreg.getType(), col_vreg,
- /*lane=*/(col_idx % sublanes_in_lane) * ctx.target_shape[0]);
- });
- if (!layout_in.offsets()[0].has_value()) {
- // Broadcast the sublane vregs.
- // TODO(tlongeri): This could be done more efficiently
- dst_vregs_local.Each([&](const absl::Span<const int64_t> dst_idx,
- Value *const dst_vreg) {
- SmallVector<int64_t> first_row_idx(toArrayRef(dst_idx));
- *(first_row_idx.end() - 3) = 0;
- *dst_vreg = dst_vregs_local(first_row_idx);
- });
- }
+ (!layout_in.offsets()[1].has_value() ||
+ *layout_in.offsets()[1] % ctx.target_shape[0] ==
+ layout_out.offsets()[0] ||
+ *layout_in.offsets()[1] + src_tiled_dims[1] <=
+ ctx.target_shape[1])) {
+ FAILUREOR_ASSIGN_OR_RETURN(
+ xla::Array<Value> dst_vregs_local,
+ insertImplicitMinorDimension(ctx, builder, op.getLoc(), src_vregs,
+ layout_in.implicitShape(src_shape),
+ layout_in, layout_out.offsets()));
// Now, reshape the major axes of the vreg array.
dst_vregs_local.Reshape(
layout_out.tileArrayImplicitShape(dst_shape, ctx.target_shape));
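The relaxed offset precondition in this hunk can be read as a standalone predicate. Below is a sketch under assumed types; the free function MinorDimInsertionOffsetsOk and its parameter names are illustrative, not part of the codebase:

#include <array>
#include <cstdint>
#include <optional>

// in_lane_offset models layout_in.offsets()[1], out_sublane_offset models
// layout_out.offsets()[0], minor_size models src_tiled_dims[1], and
// target_shape is {sublane_count, lane_count}.
bool MinorDimInsertionOffsetsOk(std::optional<int64_t> in_lane_offset,
                                std::optional<int64_t> out_sublane_offset,
                                int64_t minor_size,
                                const std::array<int64_t, 2>& target_shape) {
  // Replicated lanes: any destination offset can be produced.
  if (!in_lane_offset.has_value()) return true;
  // Aligned: each source lane lands on the matching destination sublane.
  if (out_sublane_offset == *in_lane_offset % target_shape[0]) return true;
  // Fits in one vreg: the row can be re-broadcast to any destination offset.
  return *in_lane_offset + minor_size <= target_shape[1];
}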
@@ -6370,6 +6431,26 @@ FailureOr<std::pair<VectorLayout, xla::Array<Value>>> changeImplicitDim(
});
return std::make_pair(dst, new_vregs);
}
if (src.implicit_dim() == VectorLayout::ImplicitDim::kNone &&
dst_implicit_dim == VectorLayout::ImplicitDim::kMinor &&
src.bitwidth() == 32 && src.hasNativeTiling(ctx.target_shape)) {
// TODO(tlongeri): Make insertImplicitMinorDimension more flexible about
// offsets, then we can pass dst_offset_hints directly.
const LayoutOffset dst_2nd_minor_offset =
!src.offsets()[1] || *src.offsets()[1] + *(vty.getShape().end() - 1) <=
ctx.target_shape[1]
? dst_offset_hints[0]
: LayoutOffset(*src.offsets()[1] % ctx.target_shape[0]);
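// Worked example of the offset choice above (numbers assumed,
// target_shape = {8, 128}): for src.offsets()[1] = 9 and a minor dim of
// size 100, 9 + 100 <= 128, so the data fits in one vreg along lanes and
// dst_offset_hints[0] can be honored as-is. For size 200 it does not fit,
// so the offset is forced to the aligned value 9 % 8 = 1, which
// insertImplicitMinorDimension accepts without blending.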
VectorLayout dst(src.bitwidth(),
{dst_2nd_minor_offset, dst_offset_hints[1]}, src.tiling(),
VectorLayout::ImplicitDim::kMinor);
FAILUREOR_ASSIGN_OR_RETURN(
xla::Array<Value> dst_vregs,
insertImplicitMinorDimension(ctx, builder, loc, vregs,
src.implicitShape(vty.getShape()), src,
dst.offsets()));
return std::make_pair(dst, std::move(dst_vregs));
}
return emitError(loc,
"Not implemented: Unsupported implicit dim change: from ")
<< src << " to " << dst_implicit_dim;