[Mosaic:TPU] Vreg-slice-aligned offset changes with scratch retiling

PiperOrigin-RevId: 709133729
Author: Tomás Longeri, 2024-12-23 13:04:33 -08:00 (committed by jax authors)
Commit: 3c79b98cd9 (parent: c57b49c606)
2 changed files with 165 additions and 69 deletions


@@ -5823,7 +5823,8 @@ LogicalResult retileToLargeTileWithScratch(
RewriteContext &ctx, OpBuilder &builder, const Location loc,
xla::Array<Value> &dst_tiles, const std::array<int64_t, 2> &dst_tile,
const xla::Array<Value> &src_tiles, const std::array<int64_t, 2> &src_tile,
TypedValue<MemRefType> scratch_ref) {
TypedValue<MemRefType> scratch_ref, const int64_t store_vreg_delay,
const int64_t load_vreg_skips) {
if (dst_tile[0] % src_tile[0] != 0) {
return failure();
}
@@ -5927,8 +5928,8 @@ LogicalResult retileToLargeTileWithScratch(
SmallVector<int64_t, 4> src_idx(rank);
dst_tiles.Each([&](absl::Span<const int64_t> dst_idx, Value *dst_vreg) {
int64_t dst_row_idx = *(dst_idx.end() - 2);
int64_t dst_col_idx = *(dst_idx.end() - 1);
int64_t vreg_idx_in_group = dst_col_idx % vregs_per_group;
int64_t dst_col_idx_with_skips = *(dst_idx.end() - 1) + load_vreg_skips;
int64_t vreg_idx_in_group = dst_col_idx_with_skips % vregs_per_group;
int64_t load_offset = sublanes_per_group * stored_group_cnt +
vreg_idx_in_group * sl_per_vreg * stride;
delayed_loads.push_back(
@@ -5938,16 +5939,20 @@ LogicalResult retileToLargeTileWithScratch(
// the vregs from current group and now we need to store corresponding
// group of src vregs before actually emitting the loads.
if (vreg_idx_in_group == vregs_per_group - 1 ||
dst_col_idx == dst_tiles.dimensions().back() - 1) {
auto src_row_idx = dst_row_idx * vregs_per_group;
auto src_col_idx = dst_col_idx / vregs_per_group;
dst_idx.back() == dst_tiles.dimensions().back() - 1) {
auto base_src_row_idx = dst_row_idx * vregs_per_group - store_vreg_delay;
auto src_col_idx = dst_col_idx_with_skips / vregs_per_group;
std::copy(dst_idx.begin(), dst_idx.end(), src_idx.begin());
for (int vi = 0; vi < vregs_per_group; ++vi) {
if (src_row_idx + vi >= src_tiles.dim(rank - 2) ||
const int64_t src_row_idx = base_src_row_idx + vi;
if (src_row_idx < 0) {
continue;
}
if (src_row_idx >= src_tiles.dim(rank - 2) ||
src_col_idx >= src_tiles.dim(rank - 1)) {
break;
}
*(src_idx.end() - 2) = src_row_idx + vi;
*(src_idx.end() - 2) = src_row_idx;
*(src_idx.end() - 1) = src_col_idx;
Value src_vreg = src_tiles(src_idx);
src_vreg =
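The new indexing above is easier to follow with concrete numbers. The following standalone C++ sketch (not Mosaic code; vregs_per_group = 4 and load_vreg_skips = 1 are assumed values for a 32-bit (2, 128) -> (8, 128) retiling) maps each destination vreg column to its source column and its slot within the group, mirroring the dst_col_idx_with_skips arithmetic in this hunk.

// Standalone sketch: column-to-group mapping with load_vreg_skips applied.
// All values are illustrative; only the arithmetic mirrors the hunk above.
#include <cstdint>
#include <cstdio>

int main() {
  const int64_t vregs_per_group = 4;  // e.g. (2,128) -> (8,128), 32-bit
  const int64_t load_vreg_skips = 1;  // first slot of the first group holds only padding
  const int64_t num_dst_cols = 6;     // toy vreg-array width
  for (int64_t dst_col_idx = 0; dst_col_idx < num_dst_cols; ++dst_col_idx) {
    const int64_t dst_col_idx_with_skips = dst_col_idx + load_vreg_skips;
    const int64_t vreg_idx_in_group = dst_col_idx_with_skips % vregs_per_group;
    const int64_t src_col_idx = dst_col_idx_with_skips / vregs_per_group;
    std::printf("dst col %lld -> src col %lld, slot %lld in group\n",
                (long long)dst_col_idx, (long long)src_col_idx,
                (long long)vreg_idx_in_group);
  }
  return 0;
}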
@@ -5976,7 +5981,8 @@ LogicalResult retileToSmallTileWithScratch(
RewriteContext &ctx, OpBuilder &builder, const Location loc,
xla::Array<Value> &dst_tiles, const std::array<int64_t, 2> &dst_tile,
const xla::Array<Value> &src_tiles, const std::array<int64_t, 2> &src_tile,
TypedValue<MemRefType> scratch_ref) {
TypedValue<MemRefType> scratch_ref, const int64_t store_vreg_delay,
const int64_t load_vreg_skips) {
if (src_tile[0] % dst_tile[0] != 0) {
return failure();
}
@@ -6103,8 +6109,8 @@ LogicalResult retileToSmallTileWithScratch(
SmallVector<int64_t, 4> dst_idx(rank);
src_tiles.Each([&](absl::Span<const int64_t> src_idx, Value src_vreg) {
int64_t src_row_idx = *(src_idx.end() - 2);
int64_t src_col_idx = *(src_idx.end() - 1);
int64_t vreg_idx_in_group = src_col_idx % vregs_per_group;
int64_t src_col_idx_with_delays = *(src_idx.end() - 1) + store_vreg_delay;
int64_t vreg_idx_in_group = src_col_idx_with_delays % vregs_per_group;
src_vreg = builder.create<tpu::BitcastVregOp>(loc, temp_vreg_ty, src_vreg);
if (use_shuffled_load) {
Value store_offset = mlirIndexConst(
@@ -6126,16 +6132,20 @@ LogicalResult retileToSmallTileWithScratch(
// vregs' row, this indicates we have stored all the vregs needed to
// assemble a new group of dst vreg.
if (vreg_idx_in_group == vregs_per_group - 1 ||
src_col_idx == src_tiles.dimensions().back() - 1) {
auto dst_row_idx = src_row_idx * vregs_per_group;
auto dst_col_idx = src_col_idx / vregs_per_group;
src_idx.back() == src_tiles.dimensions().back() - 1) {
auto base_dst_row_idx = src_row_idx * vregs_per_group - load_vreg_skips;
auto dst_col_idx = src_col_idx_with_delays / vregs_per_group;
std::copy(src_idx.begin(), src_idx.end(), dst_idx.begin());
for (int vi = 0; vi < vregs_per_group; ++vi) {
if (dst_row_idx + vi >= dst_tiles.dim(rank - 2) ||
const int64_t dst_row_idx = base_dst_row_idx + vi;
if (dst_row_idx < 0) {
continue;
}
if (dst_row_idx >= dst_tiles.dim(rank - 2) ||
dst_col_idx >= dst_tiles.dim(rank - 1)) {
break;
}
*(dst_idx.end() - 2) = dst_row_idx + vi;
*(dst_idx.end() - 2) = dst_row_idx;
*(dst_idx.end() - 1) = dst_col_idx;
Value *dst_vreg = &dst_tiles(dst_idx);
int64_t load_offset =
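On the small-tile side, the mirrored arithmetic decides which destination rows a completed group is loaded into: slots that land before the start of the vreg array (because of load_vreg_skips) are skipped, and slots past its end stop the loop, exactly as in the bounds checks above. A minimal standalone sketch with assumed values:

// Standalone sketch of the dst-row bookkeeping for a completed group.
// vregs_per_group, load_vreg_skips and dst_rows are illustrative values.
#include <cstdint>
#include <cstdio>

int main() {
  const int64_t vregs_per_group = 4;  // e.g. (8,128) -> (2,128), 32-bit
  const int64_t load_vreg_skips = 2;  // assumed offset change, in dst vreg slices
  const int64_t dst_rows = 9;         // toy dst vreg-array height
  const int64_t src_row_idx = 0;      // first row of source vregs
  const int64_t base_dst_row_idx =
      src_row_idx * vregs_per_group - load_vreg_skips;
  for (int64_t vi = 0; vi < vregs_per_group; ++vi) {
    const int64_t dst_row_idx = base_dst_row_idx + vi;
    if (dst_row_idx < 0) {
      std::printf("slot %lld: skipped (that slot of the group is padding)\n",
                  (long long)vi);
      continue;
    }
    if (dst_row_idx >= dst_rows) {
      break;
    }
    std::printf("slot %lld: loads into dst row %lld\n", (long long)vi,
                (long long)dst_row_idx);
  }
  return 0;
}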
@@ -6160,18 +6170,70 @@ LogicalResult retileToSmallTileWithScratch(
// go/mosaic-retiling-in-scratch is the full internal documentation that
// includes more details about the TPU generations.
LogicalResult retileWithScratch(RewriteContext &ctx, OpBuilder &builder,
const Location loc,
xla::Array<Value> &dst_tiles,
const std::array<int64_t, 2> &dst_tiling,
const xla::Array<Value> &src_tiles,
const std::array<int64_t, 2> &src_tiling,
int packing) {
// Arguments:
// - shape: The non-implicit shape of the operand
// - dst_tiling: The desired result tiling
// - dst_offsets_hint: Hints for the result offsets. They may be used or
// ignored. See comments in the body of the function for
// more details.
// - src_vregs: The source vregs to retile.
// - src: The source layout
// Returns a pair holding the result layout (potentially using the hints) and
// the retiled vregs.
// TODO(tlongeri): Clean up the function parameters/signatures. We are passing
// in more information than strictly needed.
FailureOr<std::pair<VectorLayout, xla::Array<Value>>> retileWithScratch(
RewriteContext &ctx, OpBuilder &builder, const Location loc,
const ArrayRef<int64_t> shape, const std::array<int64_t, 2> dst_tiling,
const LayoutOffsets dst_offsets_hint, const xla::Array<Value> &src_vregs,
const VectorLayout &src) {
const int bitwidth = src.bitwidth();
const int packing = src.packing();
const std::array<int64_t, 2> src_tiling = src.tiling();
if (!(src_tiling[1] == ctx.target_shape[1] &&
dst_tiling[1] == ctx.target_shape[1] && src_tiling[0] % packing == 0 &&
dst_tiling[0] % packing == 0)) {
return failure();
}
const std::array<int64_t, 2> src_vreg_slice =
VectorLayout::vregSlice(ctx.target_shape, bitwidth, src_tiling);
const std::array<int64_t, 2> dst_vreg_slice =
VectorLayout::vregSlice(ctx.target_shape, bitwidth, dst_tiling);
// TODO(b/368088671): When sublane tiling changes, we should be able to
// preserve some replications from the source layout. But we need to
// make sure they are implemented efficiently and well-tested. For now, we
// just simply use 0 for the replicated offset after retiling.
const LayoutOffsets src_offsets = {src.offsets()[0].value_or(0),
src.offsets()[1].value_or(0)};
// The provided offset hints are used only if they align with the source
// offsets, else we default to the smallest possible aligned offsets.
LayoutOffsets dst_offsets = {*src_offsets[0] % dst_vreg_slice[0],
*src_offsets[1] % dst_vreg_slice[1]};
// On a given dimension, either the source vreg slice size divides the dest
// vreg slice size, or vice versa (depending on the dimension and whether it's
// small-to-large or large-to-small retiling). Offset changes are supported
// as long as they are aligned modulo the smaller of the two sizes.
const std::array<int64_t, 2> alignment = {
std::min(src_vreg_slice[0], dst_vreg_slice[0]),
std::min(src_vreg_slice[1], dst_vreg_slice[1])};
if (dst_offsets_hint[0].has_value() &&
(*dst_offsets_hint[0] - *src_offsets[0]) % alignment[0] == 0) {
CHECK_LT(*dst_offsets_hint[0], dst_vreg_slice[0]);
dst_offsets[0] = *dst_offsets_hint[0];
}
if (dst_offsets_hint[1].has_value() &&
(*dst_offsets_hint[1] - *src_offsets[1]) % alignment[1] == 0) {
CHECK_LT(*dst_offsets_hint[1], dst_vreg_slice[1]);
dst_offsets[1] = *dst_offsets_hint[1];
}
// The offsets of the source in units of the destination vreg slice:
const std::array<int64_t, 2> src_offsets_in_dst_vreg_slices = {
*src_offsets[0] / dst_vreg_slice[0], *src_offsets[1] / dst_vreg_slice[1]};
// The offsets of the destination in units of the source vreg slice:
const std::array<int64_t, 2> dst_offsets_in_src_vreg_slices = {
*dst_offsets[0] / src_vreg_slice[0], *dst_offsets[1] / src_vreg_slice[1]};
// Try to get i32 vector scratch space. Because we will bitcast vregs to
// i32 vregs before using scratch for retiling. Through this way we can
// handle packed types as well.
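The alignment rule described in the comments above can be checked with a short standalone program. The vregSlice helper below is a simplified re-derivation (tiles per vreg times the tile shape), not the Mosaic implementation, and the offset hint {4, nullopt} is an assumed input; for the 32-bit (2, 128) -> (8, 128) case it reproduces the {0, 128} -> {4, 0} offset change used as the worked example further down in this commit.

// Standalone sketch of the offset-hint acceptance logic. The vregSlice
// formula is a simplified re-derivation and the inputs are assumptions;
// only the alignment rule mirrors the hunk above.
#include <algorithm>
#include <array>
#include <cstdint>
#include <cstdio>
#include <optional>

using Slice = std::array<int64_t, 2>;

// Data covered by one vreg: the tile shape, times the number of tiles an
// (8, 128) x 32-bit vreg can hold at the given bitwidth.
Slice vregSlice(Slice target, int bitwidth, Slice tiling) {
  const int64_t packing = 32 / bitwidth;
  const int64_t tiles_per_vreg =
      target[0] * target[1] * packing / (tiling[0] * tiling[1]);
  return {tiling[0], tiling[1] * tiles_per_vreg};
}

int main() {
  const Slice target = {8, 128};
  const int bitwidth = 32;
  const Slice src_slice = vregSlice(target, bitwidth, {2, 128});  // {2, 512}
  const Slice dst_slice = vregSlice(target, bitwidth, {8, 128});  // {8, 128}
  const Slice src_offsets = {0, 128};
  // Default: the smallest offsets congruent with the source offsets.
  Slice dst_offsets = {src_offsets[0] % dst_slice[0],
                       src_offsets[1] % dst_slice[1]};
  // A hint is honored only if it differs from the source offset by a
  // multiple of the per-dimension alignment (the smaller vreg slice).
  const Slice alignment = {std::min(src_slice[0], dst_slice[0]),
                           std::min(src_slice[1], dst_slice[1])};
  const std::array<std::optional<int64_t>, 2> hint = {4, std::nullopt};
  for (int d = 0; d < 2; ++d) {
    if (hint[d].has_value() &&
        (*hint[d] - src_offsets[d]) % alignment[d] == 0) {
      dst_offsets[d] = *hint[d];
    }
  }
  std::printf("src vreg slice (%lld, %lld), dst vreg slice (%lld, %lld)\n",
              (long long)src_slice[0], (long long)src_slice[1],
              (long long)dst_slice[0], (long long)dst_slice[1]);
  std::printf("chosen dst offsets {%lld, %lld}\n", (long long)dst_offsets[0],
              (long long)dst_offsets[1]);  // {4, 0}
  return 0;
}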
@@ -6186,24 +6248,57 @@ LogicalResult retileWithScratch(RewriteContext &ctx, OpBuilder &builder,
dst_tiling[1]};
std::array<int64_t, 2> vi32_src_tiling = {src_tiling[0] / packing,
src_tiling[1]};
const VectorLayout dst(bitwidth, dst_offsets, dst_tiling, src.implicit_dim());
TPU_ASSERT_LOC(loc, dst.isValid(ctx.target_shape));
xla::Array<Value> dst_vregs(
dst.tileArrayImplicitShape(shape, ctx.target_shape));
// When differences in offsets exist, the source vregs may stored at an offset
// position in their group. For example, the 1st vreg in a row/column may be
// stored as if it was the 3rd, so that the parts corresponding to the 1st and
// 2nd in the destination are filled with padding. Likewise, loads to
// destination vregs may be skipped, when they would load only padding.
// store_vreg_delay is the position offset for stores, and load_vreg_skips is
// the position offset for loads.
//
// For example, suppose we are going from 32-bit {0, 128}(2, 128) to
// {4, 0}(8, 128). We form groups of 4 vregs that represent an (8, 512) slice
// of the padded implicit shape. For the given offsets, for the first group,
// the data is in (4:8, 128:512). But the first and second sources (stored
// vregs) of the group form the slices of data (0:2, 0:512) and (2:4, 0:512),
// which should be all padding. Likewise, the first dest vreg slice (which we
// load from) holds the data from slice (0:8, 0:128), which is all padding.
// We never load or store to slices that should contain only padding.
if (src_tiling[0] > dst_tiling[0]) {
return retileToSmallTileWithScratch(ctx, builder, loc, dst_tiles,
vi32_dst_tiling, src_tiles,
vi32_src_tiling, ref);
DCHECK_EQ(src_offsets_in_dst_vreg_slices[1], 0);
DCHECK_EQ(dst_offsets_in_src_vreg_slices[0], 0);
const int64_t store_vreg_delay = dst_offsets_in_src_vreg_slices[1];
const int64_t load_vreg_skips = src_offsets_in_dst_vreg_slices[0];
if (failed(retileToSmallTileWithScratch(
ctx, builder, loc, dst_vregs, vi32_dst_tiling, src_vregs,
vi32_src_tiling, ref, store_vreg_delay, load_vreg_skips))) {
return failure();
}
}
if (src_tiling[0] < dst_tiling[0]) {
return retileToLargeTileWithScratch(ctx, builder, loc, dst_tiles,
vi32_dst_tiling, src_tiles,
vi32_src_tiling, ref);
DCHECK_EQ(src_offsets_in_dst_vreg_slices[0], 0);
DCHECK_EQ(dst_offsets_in_src_vreg_slices[1], 0);
const int64_t store_vreg_delay = dst_offsets_in_src_vreg_slices[0];
const int64_t load_vreg_skips = src_offsets_in_dst_vreg_slices[1];
if (failed(retileToLargeTileWithScratch(
ctx, builder, loc, dst_vregs, vi32_dst_tiling, src_vregs,
vi32_src_tiling, ref, store_vreg_delay, load_vreg_skips))) {
return failure();
}
}
dst_tiles = std::move(src_tiles);
return success();
return std::make_pair(dst, dst_vregs);
}
FailureOr<std::pair<VectorLayout, xla::Array<Value>>> changeTiling(
RewriteContext &ctx, OpBuilder &builder, const Location loc, VectorType vty,
const VectorLayout src, xla::Array<Value> vregs,
const std::array<int64_t, 2> dst_tiling, bool try_replicate_rows) {
const std::array<int64_t, 2> dst_tiling,
const LayoutOffsets dst_offsets_hint) {
bool has_enough_scratch = ctx.max_sublanes_in_scratch >=
ctx.target_shape[0] * (ctx.target_shape[0] + 1);
const auto &target_shape = ctx.target_shape;
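The example in the comment above reduces to two integer divisions. This standalone sketch (vreg-slice sizes written out by hand, same simplified model as the earlier sketch) computes store_vreg_delay and load_vreg_skips for the 32-bit {0, 128}(2, 128) -> {4, 0}(8, 128) retiling; it prints 2 and 1, matching the two all-padding source vregs and the single all-padding destination vreg described above.

// Worked arithmetic for the small-to-large case discussed above.
// The concrete slice sizes and offsets come from the comment; only the
// two divisions mirror the code in this hunk.
#include <array>
#include <cstdint>
#include <cstdio>

int main() {
  const std::array<int64_t, 2> src_vreg_slice = {2, 512};  // (2,128) tiling, 32-bit
  const std::array<int64_t, 2> dst_vreg_slice = {8, 128};  // (8,128) tiling, 32-bit
  const std::array<int64_t, 2> src_offsets = {0, 128};
  const std::array<int64_t, 2> dst_offsets = {4, 0};
  // Offsets expressed in whole vreg slices of the other layout.
  const int64_t store_vreg_delay = dst_offsets[0] / src_vreg_slice[0];  // 2
  const int64_t load_vreg_skips = src_offsets[1] / dst_vreg_slice[1];   // 1
  std::printf("store_vreg_delay = %lld, load_vreg_skips = %lld\n",
              (long long)store_vreg_delay, (long long)load_vreg_skips);
  return 0;
}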
@@ -6219,6 +6314,12 @@ FailureOr<std::pair<VectorLayout, xla::Array<Value>>> changeTiling(
const int8_t bitwidth = src.bitwidth();
const std::array<int64_t, 2> dst_vreg_slice =
VectorLayout::vregSlice(ctx.target_shape, bitwidth, dst_tiling);
// TODO(tlongeri): Using canonical vs non-canonical offsets can change the
// value of try_replicate rows, and it breaks some tests. It doesn't make
// sense that we have different behavior for equivalent layouts, though. We
// need better logic for picking the relayout strategy.
const bool try_replicate_rows =
src.offsets()[0].has_value() && !dst_offsets_hint[0].has_value();
// Fully replicated offsets are handled efficiently elsewhere (in relayout)
CHECK(src.offsets()[0].has_value() || src.offsets()[1].has_value());
@@ -6290,15 +6391,10 @@ FailureOr<std::pair<VectorLayout, xla::Array<Value>>> changeTiling(
});
return std::pair(dst, std::move(retiled));
}
VectorLayout dst(src.bitwidth(), src.offsets(), dst_tiling,
src.implicit_dim());
if (!dst.isValid(target_shape)) {
return emitError(loc, "Not implemented: invalid offsets in tiling target");
}
auto dst_tiles_shape =
dst.tileArrayImplicitShape(vty.getShape(), target_shape);
// (8,128) -> (8 * packing,128) tiling change for packed type.
if (bitwidth < 32 && 32 % bitwidth == 0 && src_tiling == ctx.target_shape &&
if (src_offsets[0].value_or(0) < dst_vreg_slice[0] &&
src_offsets[1].value_or(0) < dst_vreg_slice[1] && bitwidth < 32 &&
32 % bitwidth == 0 && src_tiling == ctx.target_shape &&
dst_tiling == std::array<int64_t, 2>{ctx.target_shape[0] * packing,
ctx.target_shape[1]}) {
// TODO(tlongeri): This relayout is just ext + trunc. Refactor.
@@ -6308,8 +6404,10 @@ FailureOr<std::pair<VectorLayout, xla::Array<Value>>> changeTiling(
// not, since it relies on the src vreg array shape to know how many tiles
// to pack in dst, and vreg array shapes with materialized offsets are
// unfortunately not equal to vreg array shapes with replicated offsets.
CHECK(dst.offsets() == src_offsets);
xla::Array<Value> retiled(dst_tiles_shape);
VectorLayout dst(src.bitwidth(), src.offsets(), dst_tiling,
src.implicit_dim());
xla::Array<Value> retiled(
dst.tileArrayImplicitShape(vty.getShape(), target_shape));
VectorType vreg_x32 =
vty.getElementType().isSignlessInteger()
? VectorType::get(target_shape, builder.getI32Type())
@@ -6357,7 +6455,9 @@ FailureOr<std::pair<VectorLayout, xla::Array<Value>>> changeTiling(
// interesting if the next step is a retile, since we can also
// match corresponding elements without shifting. It's just that
// the tiles are not adjacent (no contiguous vreg slice).
if (bitwidth < 32 && 32 % bitwidth == 0 &&
if (src_offsets[0].value_or(0) < dst_vreg_slice[0] &&
src_offsets[1].value_or(0) < dst_vreg_slice[1] && bitwidth < 32 &&
32 % bitwidth == 0 &&
src_tiling == std::array<int64_t, 2>{1, ctx.target_shape[1] * packing} &&
dst_tiling == std::array<int64_t, 2>{packing, ctx.target_shape[1]}) {
// TODO(tlongeri): This relayout is just ext + trunc. Refactor.
@@ -6406,8 +6506,10 @@ FailureOr<std::pair<VectorLayout, xla::Array<Value>>> changeTiling(
// not, since it relies on the src vreg array shape to know how many tiles
// to pack in dst, and vreg array shapes with materialized offsets are
// unfortunately not equal to vreg array shapes with replicated offsets.
CHECK(dst.offsets() == src.offsets());
xla::Array<Value> retiled(dst_tiles_shape);
VectorLayout dst(src.bitwidth(), src.offsets(), dst_tiling,
src.implicit_dim());
xla::Array<Value> retiled(
dst.tileArrayImplicitShape(vty.getShape(), target_shape));
const VectorType vreg_x32 =
vty.getElementType().isSignlessInteger()
? VectorType::get(target_shape, builder.getI32Type())
@@ -6444,24 +6546,25 @@ FailureOr<std::pair<VectorLayout, xla::Array<Value>>> changeTiling(
return std::pair(dst, std::move(retiled));
}
if (src_tiling[1] == target_shape[1] && dst_tiling[1] == target_shape[1]) {
// TODO(b/368088671): When sublane tiling changes, we should be able to
// preserve some replications from the source layout. But we need to
// make sure they are implemented efficiently and well-tested. For now, we
// just simply use 0 for the replicated offset after retiling.
dst = VectorLayout(
bitwidth, {src.offsets()[0].value_or(0), src.offsets()[1].value_or(0)},
dst_tiling, dst.implicit_dim());
// All clauses in the and expression are based on performance benchmarking.
bool use_alu = !has_enough_scratch ||
(ctx.hardware_generation >= 5 && src_tiling[0] != packing &&
dst_tiling[0] != packing);
if (use_alu) {
if (src_tiling[0] > dst_tiling[0]) {
return std::pair(
dst, retileToReducedSublanes(builder, vty.getShape(), src, vregs,
dst, target_shape));
if (src_tiling[0] > dst_tiling[0] &&
// retileToReducedSublanes does not support offset changes
src.offsets()[0].value_or(0) < dst_vreg_slice[0] &&
src.offsets()[1].value_or(0) < dst_vreg_slice[1]) {
VectorLayout dst(src.bitwidth(), src.offsets(), dst_tiling,
src.implicit_dim());
return std::pair(dst, retileToReducedSublanes(
builder, vty.getShape(), src, vregs,
VectorLayout(bitwidth,
{src.offsets()[0].value_or(0),
src.offsets()[1].value_or(0)},
dst_tiling, dst.implicit_dim()),
target_shape));
} else if (!has_enough_scratch) {
// TODO(b/357538782): Implement retileToIncreasedSublanes with ALU ops.
return emitError(
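For sublane-tiling changes, the hunk above picks between the ALU path and the scratch path. The sketch below (the enum, the function, and the sample inputs are illustrative, not part of Mosaic) encodes the same decision: the ALU path only when it reduces the sublane tile and the offsets stay within the destination vreg slice, a failure when scratch is insufficient and the ALU path cannot help, and scratch-based retiling otherwise.

// Standalone sketch of the strategy choice for sublane-tiling changes.
// The thresholds mirror the diff; everything else is illustrative.
#include <array>
#include <cstdint>
#include <cstdio>

enum class Retile { kAluReducedSublanes, kScratch, kUnsupported };

Retile pickStrategy(bool has_enough_scratch, int hardware_generation,
                    std::array<int64_t, 2> src_tiling,
                    std::array<int64_t, 2> dst_tiling, int64_t packing,
                    std::array<int64_t, 2> src_offsets,
                    std::array<int64_t, 2> dst_vreg_slice) {
  const bool use_alu = !has_enough_scratch ||
                       (hardware_generation >= 5 && src_tiling[0] != packing &&
                        dst_tiling[0] != packing);
  if (use_alu) {
    // The ALU path only reduces the sublane tile and cannot change offsets.
    if (src_tiling[0] > dst_tiling[0] && src_offsets[0] < dst_vreg_slice[0] &&
        src_offsets[1] < dst_vreg_slice[1]) {
      return Retile::kAluReducedSublanes;
    }
    if (!has_enough_scratch) {
      return Retile::kUnsupported;  // cannot grow the sublane tile without scratch
    }
  }
  return Retile::kScratch;
}

int main() {
  // 32-bit (8,128) -> (2,128) on a generation-5 chip with scratch available.
  const Retile r = pickStrategy(/*has_enough_scratch=*/true,
                                /*hardware_generation=*/5, {8, 128}, {2, 128},
                                /*packing=*/1, /*src_offsets=*/{0, 0},
                                /*dst_vreg_slice=*/{2, 512});
  std::printf("strategy = %d\n", static_cast<int>(r));
  return 0;
}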
@@ -6469,15 +6572,12 @@ FailureOr<std::pair<VectorLayout, xla::Array<Value>>> changeTiling(
"Not implemented: retiling to increase sublane tiling with ALU");
}
}
xla::Array<Value> retiled(dst_tiles_shape);
if (failed(retileWithScratch(ctx, builder, loc, retiled, dst_tiling, vregs,
src_tiling, packing))) {
return failure();
}
return std::pair(dst, std::move(retiled));
return retileWithScratch(ctx, builder, loc, vty.getShape(), dst_tiling,
dst_offsets_hint, vregs, src);
}
return emitError(loc, "Not implemented: Unsupported tiling change for ")
<< vty << ": from " << src << " to " << dst;
<< vty << ": from " << src << " to (" << dst_tiling[0] << ", "
<< dst_tiling[1] << ") tiling";
}
FailureOr<std::pair<VectorLayout, xla::Array<Value>>> changeImplicitDim(
@@ -6737,9 +6837,7 @@ FailureOr<TypedValue<VectorType>> relayout(RewriteContext &ctx,
FAILUREOR_ASSIGN_OR_RETURN(
std::tie(src, src_tiles),
changeTiling(ctx, builder, v.getLoc(), vty, src, std::move(src_tiles),
dst.tiling(),
dst.offsets()[0] == std::nullopt &&
src.offsets()[0] != std::nullopt));
dst.tiling(), dst.offsets()));
FAILUREOR_ASSIGN_OR_RETURN(
std::tie(src, src_tiles),


@@ -2555,9 +2555,7 @@ class MiscellaneousTest(PallasBaseTest):
np.testing.assert_array_equal(out, np.reshape(x, (8, 128)))
@only_passes_in_interpret()
def test_retiling2(self):
"""b/348040767"""
x = np.arange(1 * 8 * 1024, dtype=jnp.bfloat16).reshape(1, 8, 1024)
def kernel(x_ref, out_ref):