mirror of
https://github.com/ROCm/jax.git
synced 2025-04-16 11:56:07 +00:00
[Mosaic:TPU] Add relayout for adding minor implicit dim and relax some offset restrictions on similar shape cast
This factors out some logic from the apply-vector-layout shape cast rule where we insert a minor dimension, relaxes some offset restrictions on it, and uses it for the relayout. PiperOrigin-RevId: 702993092
This commit is contained in:
parent
101168740e
commit
8163e74e45
@ -342,8 +342,13 @@ def TPU_RepeatOp : TPU_Op<"repeat", [Pure]> {
|
||||
}
|
||||
|
||||
def TPU_BroadcastInSublanesOp : TPU_Op<"broadcast_in_sublanes", [Pure]> {
|
||||
let description = [{
|
||||
For each sublane `i`, broadcasts the value in lane `lane + i` along the entire
|
||||
sublane. If `lane + i` is not in [0, lane_count), then the value in sublane `i`
|
||||
is not defined (can be anything).
|
||||
}];
|
||||
let arguments = (ins
|
||||
AnyVectorOfNonZeroRank:$source, // All sublanes should be equal.
|
||||
TPU_Vreg:$source, // All sublanes should be equal.
|
||||
I32Attr:$lane // Coordinates of the first element to take.
|
||||
);
|
||||
// Output shape should be the same, except for position dim which contains
|
||||
|
@ -655,6 +655,105 @@ FailureOr<SmallVector<Layout>> getInLayouts(
|
||||
return in_layouts;
|
||||
}
|
||||
|
||||
// Insert a minor dimension to the implicit shape. The original minor dimension
|
||||
// becomes the new second minor dimension, laid out across sublanes.
|
||||
//
|
||||
// The returned vreg array uses the original tiling and the offsets specified in
|
||||
// new_offsets to hold the value with the new implicit shape.
|
||||
//
|
||||
// Args:
|
||||
// vregs: The vreg array with *implicit* array shape.
|
||||
// ishape: The implicit shape of the represented value.
|
||||
// layout: The layout used for the represented value. The implicit
|
||||
// dimension is ignored, since this function operates directly at
|
||||
// the level of the implicit shape.
|
||||
// new_offsets: The offsets to use for the layout of the returned vreg array.
|
||||
FailureOr<xla::Array<Value>> insertImplicitMinorDimension(
|
||||
RewriteContext &ctx, OpBuilder &builder, const Location loc,
|
||||
const xla::Array<Value> &vregs, const ArrayRef<int64_t> ishape,
|
||||
const VectorLayout &layout, const LayoutOffsets new_offsets) {
|
||||
if (layout.bitwidth() != 32 || !layout.hasNativeTiling(ctx.target_shape)) {
|
||||
return emitError(loc, "Not implemented: Unsupported bitwidth or tiling");
|
||||
}
|
||||
if (layout.offsets()[1].has_value()) {
|
||||
if (!new_offsets[0]) {
|
||||
// TODO(tlongeri): This can only be valid if the dim size is 1.
|
||||
return emitError(loc, "Not implemented: Replication mismatch");
|
||||
}
|
||||
if (*new_offsets[0] != *layout.offsets()[1] % ctx.target_shape[0] &&
|
||||
*layout.offsets()[1] + *(ishape.end() - 1) > ctx.target_shape[1]) {
|
||||
// This requires blending data from different vregs.
|
||||
return emitError(loc,
|
||||
"Not implemented: Misaligned offsets and shape does not "
|
||||
"fit in one vreg");
|
||||
}
|
||||
}
|
||||
// new_layout is only to get the new vreg array shape, the implicit dim is
|
||||
// irrelevant (since we already have the implicit shape):
|
||||
const VectorLayout new_layout(layout.bitwidth(), new_offsets, layout.tiling(),
|
||||
VectorLayout::ImplicitDim::kNone);
|
||||
SmallVector<int64_t> new_ishape(ishape);
|
||||
new_ishape.push_back(1);
|
||||
xla::Array<Value> new_vregs(new_layout.tileArrayShape(
|
||||
/*src_is_implicit=*/true, /*res_is_implicit=*/true, std::move(new_ishape),
|
||||
ctx.target_shape));
|
||||
// Preallocate an indices vector to avoid repeated allocations:
|
||||
SmallVector<int64_t> idxs;
|
||||
new_vregs.Each([&](const absl::Span<const int64_t> dst_idx,
|
||||
Value *const dst_vreg) {
|
||||
// Indices of the new vreg in the new vreg array:
|
||||
const int64_t new_2nd_minor_idx = *(dst_idx.end() - 2);
|
||||
const int64_t new_3rd_minor_idx = *(dst_idx.end() - 3);
|
||||
idxs.assign(dst_idx.begin(), dst_idx.end());
|
||||
if (!layout.offsets()[0].has_value() && new_3rd_minor_idx != 0) {
|
||||
// All vregs along that dimension are the same
|
||||
*(idxs.end() - 3) = 0;
|
||||
*dst_vreg = new_vregs(idxs);
|
||||
} else if (!layout.offsets()[1].has_value() && new_2nd_minor_idx != 0) {
|
||||
// All vregs along that dimension are the same
|
||||
*(idxs.end() - 2) = 0;
|
||||
*dst_vreg = new_vregs(idxs);
|
||||
} else {
|
||||
// dst_vreg will hold slice [row_idx, col_idx:(col_idx + target_shape[0])]
|
||||
// of the after-offsets source shape
|
||||
const int64_t row_idx =
|
||||
layout.offsets()[0] ? new_3rd_minor_idx + *layout.offsets()[0] : 0;
|
||||
const int64_t col_idx = layout.offsets()[1]
|
||||
? new_2nd_minor_idx * ctx.target_shape[0] +
|
||||
*layout.offsets()[1] - *new_offsets[0]
|
||||
: 0;
|
||||
|
||||
idxs.pop_back();
|
||||
*(idxs.end() - 2) = row_idx / ctx.target_shape[0];
|
||||
*(idxs.end() - 1) = col_idx / ctx.target_shape[1];
|
||||
Value src_vreg = vregs(idxs);
|
||||
// TODO(tlongeri): We can sometimes skip operations when dst_vreg will
|
||||
// hold a single non-padding element (first or last) and we don't need
|
||||
// replication in the output.
|
||||
if (layout.offsets()[0].has_value()) {
|
||||
// [ . . . . . . . . ] [ . . . . a b c d ]
|
||||
// [ . . . . a b c d ] => [ . . . . a b c d ]
|
||||
// [ . . . . . . . . ] [ . . . . a b c d ]
|
||||
// [ . . . . . . . . ] [ . . . . a b c d ]
|
||||
src_vreg = broadcastSublane(
|
||||
builder, src_vreg,
|
||||
/*sublane_idx=*/row_idx % ctx.target_shape[0], ctx.target_shape);
|
||||
}
|
||||
if (layout.offsets()[1].has_value()) {
|
||||
// [ . . . . a b c d ] [ a a a a a a a a ]
|
||||
// [ . . . . a b c d ] => [ b b b b b b b b ]
|
||||
// [ . . . . a b c d ] [ c c c c c c c c ]
|
||||
// [ . . . . a b c d ] [ d d d d d d d d ]
|
||||
src_vreg = builder.create<BroadcastInSublanesOp>(
|
||||
loc, src_vreg.getType(), src_vreg,
|
||||
/*lane=*/col_idx % ctx.target_shape[1]);
|
||||
}
|
||||
*dst_vreg = src_vreg;
|
||||
}
|
||||
});
|
||||
return new_vregs;
|
||||
}
|
||||
|
||||
LogicalResult elementwise_op_rule(RewriteContext &ctx, Operation &op,
|
||||
const ArrayRef<Layout> layouts_in,
|
||||
const ArrayRef<Layout> layouts_out) {
|
||||
@ -4155,54 +4254,16 @@ LogicalResult vector_shape_cast_rule(RewriteContext &ctx, Operation &op,
|
||||
layout_in.bitwidth() == 32 &&
|
||||
layout_in.hasNativeTiling(ctx.target_shape) &&
|
||||
layout_in.tiling() == layout_out.tiling() &&
|
||||
layout_in.offsets()[0].value_or(0) == 0 &&
|
||||
layout_in.offsets()[1] == 0 && layout_out.offsets()[0] == 0
|
||||
// layout_out.offsets[1] can be anything, as we produce a
|
||||
// replicated result
|
||||
) {
|
||||
// First, insert the new singleton lane dimension.
|
||||
SmallVector<int64_t> s = layout_in.implicitShape(src_shape);
|
||||
s.push_back(1);
|
||||
xla::Array<Value> dst_vregs_local(layout_out.tileArrayShape(
|
||||
/*src_is_implicit=*/true, /*res_is_implicit=*/true, std::move(s),
|
||||
ctx.target_shape));
|
||||
TPU_ASSERT_EQ_OP(dst_vregs_local.dimensions().back(),
|
||||
1); // We're inserting a singleton dimension
|
||||
dst_vregs_local.Each(
|
||||
[&](const absl::Span<const int64_t> dst_idx, Value *const dst_vreg) {
|
||||
const int64_t col_idx = *(dst_idx.end() - 2);
|
||||
const int64_t row_idx = *(dst_idx.end() - 3);
|
||||
auto [sublanes_in_lane, rem] =
|
||||
std::div(ctx.target_shape[1], ctx.target_shape[0]);
|
||||
CHECK_EQ(rem, 0);
|
||||
if (!layout_in.offsets()[0].has_value() && row_idx != 0) {
|
||||
return; // All vregs along that dimension are the same.
|
||||
}
|
||||
SmallVector<int64_t> src_idx(toArrayRef(dst_idx));
|
||||
src_idx.pop_back();
|
||||
*(src_idx.end() - 2) /= ctx.target_shape[0];
|
||||
*(src_idx.end() - 1) /= sublanes_in_lane;
|
||||
Value col_vreg = src_vregs(src_idx);
|
||||
// BroadcastInSublanesOp requires the sublanes to be replicated.
|
||||
if (layout_in.offsets()[0].has_value()) {
|
||||
const int32_t sublane = row_idx % ctx.target_shape[0];
|
||||
col_vreg = broadcastSublane(builder, col_vreg, sublane,
|
||||
ctx.target_shape);
|
||||
}
|
||||
*dst_vreg = builder.create<BroadcastInSublanesOp>(
|
||||
col_vreg.getType(), col_vreg,
|
||||
/*lane=*/(col_idx % sublanes_in_lane) * ctx.target_shape[0]);
|
||||
});
|
||||
if (!layout_in.offsets()[0].has_value()) {
|
||||
// Broadcast the sublane vregs.
|
||||
// TODO(tlongeri): This could be done more efficiently
|
||||
dst_vregs_local.Each([&](const absl::Span<const int64_t> dst_idx,
|
||||
Value *const dst_vreg) {
|
||||
SmallVector<int64_t> first_row_idx(toArrayRef(dst_idx));
|
||||
*(first_row_idx.end() - 3) = 0;
|
||||
*dst_vreg = dst_vregs_local(first_row_idx);
|
||||
});
|
||||
}
|
||||
(!layout_in.offsets()[1].has_value() ||
|
||||
*layout_in.offsets()[1] % ctx.target_shape[0] ==
|
||||
layout_out.offsets()[0] ||
|
||||
*layout_in.offsets()[1] + src_tiled_dims[1] <=
|
||||
ctx.target_shape[1])) {
|
||||
FAILUREOR_ASSIGN_OR_RETURN(
|
||||
xla::Array<Value> dst_vregs_local,
|
||||
insertImplicitMinorDimension(ctx, builder, op.getLoc(), src_vregs,
|
||||
layout_in.implicitShape(src_shape),
|
||||
layout_in, layout_out.offsets()));
|
||||
// Now, reshape the major axes of the vreg array.
|
||||
dst_vregs_local.Reshape(
|
||||
layout_out.tileArrayImplicitShape(dst_shape, ctx.target_shape));
|
||||
@ -6370,6 +6431,26 @@ FailureOr<std::pair<VectorLayout, xla::Array<Value>>> changeImplicitDim(
|
||||
});
|
||||
return std::make_pair(dst, new_vregs);
|
||||
}
|
||||
if (src.implicit_dim() == VectorLayout::ImplicitDim::kNone &&
|
||||
dst_implicit_dim == VectorLayout::ImplicitDim::kMinor &&
|
||||
src.bitwidth() == 32 && src.hasNativeTiling(ctx.target_shape)) {
|
||||
// TODO(tlongeri): Make insertImplicitMinorDimension more flexible about
|
||||
// offsets, then we can pass dst_offset_hints directly.
|
||||
const LayoutOffset dst_2nd_minor_offset =
|
||||
!src.offsets()[1] || *src.offsets()[1] + *(vty.getShape().end() - 1) <=
|
||||
ctx.target_shape[1]
|
||||
? dst_offset_hints[0]
|
||||
: LayoutOffset(*src.offsets()[1] % ctx.target_shape[0]);
|
||||
VectorLayout dst(src.bitwidth(),
|
||||
{dst_2nd_minor_offset, dst_offset_hints[1]}, src.tiling(),
|
||||
VectorLayout::ImplicitDim::kMinor);
|
||||
FAILUREOR_ASSIGN_OR_RETURN(
|
||||
xla::Array<Value> dst_vregs,
|
||||
insertImplicitMinorDimension(ctx, builder, loc, vregs,
|
||||
src.implicitShape(vty.getShape()), src,
|
||||
dst.offsets()));
|
||||
return std::make_pair(dst, std::move(dst_vregs));
|
||||
}
|
||||
return emitError(loc,
|
||||
"Not implemented: Unsupported implicit dim change: from ")
|
||||
<< src << " to " << dst_implicit_dim;
|
||||
|
Loading…
x
Reference in New Issue
Block a user