[NFC][Mosaic TPU] Separate out retiling from relayout

PiperOrigin-RevId: 658335679
Adam Paszke 2024-08-01 03:08:44 -07:00 committed by jax authors
parent 0734345279
commit 0307438c3d
3 changed files with 227 additions and 199 deletions


@ -713,7 +713,8 @@ llvm::hash_code hash_value(const VectorLayout& layout) {
return llvm::hash_value(layout.as_tuple());
}
std::ostream &operator<<(std::ostream &os, VectorLayout::ImplicitDim dim) {
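// Shared printing helper so ImplicitDim can be written to both std::ostream
// and mlir::Diagnostic without duplicating the switch below.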
template <typename Stream>
Stream& printImplicitDim(Stream& os, VectorLayout::ImplicitDim dim) {
switch (dim) {
case VectorLayout::ImplicitDim::kNone:
os << "none";
@ -728,6 +729,15 @@ std::ostream &operator<<(std::ostream &os, VectorLayout::ImplicitDim dim) {
return os;
}
std::ostream& operator<<(std::ostream& os, VectorLayout::ImplicitDim dim) {
return printImplicitDim(os, dim);
}
mlir::Diagnostic& operator<<(mlir::Diagnostic& diag,
VectorLayout::ImplicitDim dim) {
return printImplicitDim(diag, dim);
}
std::optional<Layout> parseLayout(mlir::AsmParser& parser) {
std::string layout_str;
if (failed(parser.parseString(&layout_str))) {


@ -513,6 +513,8 @@ llvm::raw_ostream &operator<<(llvm::raw_ostream &os, const Layout &v);
llvm::hash_code hash_value(const VectorLayout &layout);
mlir::Diagnostic &operator<<(mlir::Diagnostic &diag, const Layout &v);
std::ostream &operator<<(std::ostream &os, VectorLayout::ImplicitDim dim);
mlir::Diagnostic &operator<<(mlir::Diagnostic &diag,
VectorLayout::ImplicitDim dim);
std::optional<Layout> parseLayout(mlir::AsmParser &parser);


@ -5166,6 +5166,212 @@ FailureOr<std::pair<VectorLayout, xla::Array<Value>>> changeOffsets(
return std::make_pair(dst, std::move(vregs));
}
// TODO(b/265133506): Generalize retiling.
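// Changes the tiling of a vreg array from src.tiling() to dst_tiling,
// returning the adjusted layout together with the retiled vregs. When
// try_replicate_rows is set, the (1, 128) -> (8, 128) case with a single vreg
// row may broadcast that row across sublanes and return a layout with a
// replicated (std::nullopt) sublane offset.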
FailureOr<std::pair<VectorLayout, xla::Array<Value>>> changeTiling(
OpBuilder &builder, const std::array<int64_t, 2> target_shape,
const Location loc, VectorType vty, const VectorLayout src,
xla::Array<Value> vregs, const std::array<int64_t, 2> dst_tiling,
bool try_replicate_rows) {
if (src.tiling() == dst_tiling) {
return std::pair(src, std::move(vregs));
}
const int packing = src.packing();
const int8_t bitwidth = src.bitwidth();
VectorLayout dst(src.bitwidth(), src.offsets(), dst_tiling,
src.implicit_dim());
if (!dst.isValid(target_shape)) {
return emitError(loc, "Not implemented: invalid offsets in tiling target");
}
// Handle retiling from (packing, 128) to (8 * packing, 128).
if (src.offsets() == LayoutOffsets{0, 0} &&
src.tiling() == std::array<int64_t, 2>{packing, 128} &&
dst_tiling == std::array<int64_t, 2>{8 * packing, 128}) {
bool replicate_sublanes = try_replicate_rows && packing == 1 &&
*(vregs.dimensions().end() - 2) == 1;
xla::Array<Value> retiled(
dst.tileArrayImplicitShape(vty.getShape(), target_shape));
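// Each destination vreg collects sublane (dst column % target_shape[0]) from
// up to target_shape[0] consecutive source vregs into consecutive destination
// sublanes, or broadcasts that sublane when replicating rows.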
retiled.Each([&](absl::Span<const int64_t> idx, Value *tile) {
SmallVector<int64_t> src_idx(idx.begin(), idx.end());
*(src_idx.end() - 2) *= target_shape[0];
*(src_idx.end() - 1) /= target_shape[0];
const int64_t src_sl_idx = *(idx.end() - 1) % target_shape[0];
if (replicate_sublanes) {
CHECK_EQ(src.getImplicitTiledDims(vty.getShape(), 1)[0], 1);
*tile =
broadcastSublane(builder, vregs(src_idx), src_sl_idx, target_shape);
} else {
for (int dst_sl_idx = 0;
dst_sl_idx < target_shape[0] &&
*(src_idx.end() - 2) < *(vregs.dimensions().end() - 2);
++dst_sl_idx, ++*(src_idx.end() - 2)) {
*tile = copy_one_sublane(builder, vregs(src_idx), src_sl_idx, *tile,
dst_sl_idx, target_shape);
}
}
});
// If we replicated sublanes above, mark the sublane offset as replicated in
// the destination layout.
if (replicate_sublanes) {
dst = VectorLayout(bitwidth, {std::nullopt, dst.offsets()[1]}, dst_tiling,
dst.implicit_dim());
}
return std::pair(dst, std::move(retiled));
}
// Handle retiling from (m, 128) to (8, 128) for 32-bit data
// where m < 8 and m is a power of 2.
// TODO(b/306692696): Handle any vregs.dimensions().
if (bitwidth == 32 && src.offsets() == LayoutOffsets{0, 0} &&
target_shape[0] % src.tiling()[0] == 0 &&
src.tiling()[1] == target_shape[1] && dst.tiling() == target_shape &&
*(vregs.dimensions().end() - 2) == 1) {
xla::Array<Value> retiled(
dst.tileArrayImplicitShape(vty.getShape(), target_shape));
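// Each (8, 128) destination vreg corresponds to one (m, 128) tile of a source
// vreg; when that tile does not start at sublane 0, a sublane gather rotates
// its m sublanes to the top of the vreg, cycling them to fill all sublanes.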
retiled.Each([&](const absl::Span<const int64_t> idx,
Value *const new_src_tile) {
const int64_t tiles_per_vreg = src.tilesPerVreg(target_shape);
const int64_t dst_col = idx.back();
const int64_t src_col = dst_col / tiles_per_vreg;
const int64_t start_slane_idx =
src.tiling()[0] * (dst_col % tiles_per_vreg);
SmallVector<int64_t> src_idx(toArrayRef(idx));
src_idx.back() = src_col;
Value src_tile = vregs(src_idx);
if (start_slane_idx) {
SmallVector<int32_t> slane_idxs;
slane_idxs.reserve(target_shape[0]);
for (int i = 0; i < target_shape[0]; ++i) {
slane_idxs.push_back(start_slane_idx + (i % src.tiling()[0]));
}
const DenseI32ArrayAttr gather_indices =
builder.getDenseI32ArrayAttr(slane_idxs);
*new_src_tile = builder.create<tpu::GatherOp>(loc, src_tile.getType(),
src_tile, gather_indices,
/*dimension=*/0);
} else {
*new_src_tile = src_tile;
}
});
return std::pair(dst, std::move(retiled));
}
// (8,128) -> (8 * packing,128) tiling change for packed type.
if (bitwidth < 32 && 32 % bitwidth == 0 &&
src.tiling() == std::array<int64_t, 2>{8, 128} &&
dst.tiling() == std::array<int64_t, 2>{8 * dst.packing(), 128}) {
xla::Array<Value> retiled(
dst.tileArrayImplicitShape(vty.getShape(), target_shape));
int vty_packing = dst.packing();
VectorType vreg_x32 =
vty.getElementType().isSignlessInteger()
? VectorType::get(target_shape, builder.getI32Type())
: VectorType::get(target_shape, builder.getF32Type());
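// Each destination vreg re-packs subelement part (dst column % packing) of up
// to packing consecutive source vregs: unpack each to 32 bits, then pack them
// back together in compressed format.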
retiled.Each([&](absl::Span<const int64_t> idx, Value *tile) {
const int vreg_part = idx.back() % vty_packing;
SmallVector<Value, 8> parts;
parts.reserve(vty_packing);
SmallVector<int64_t> src_idx(idx.begin(), idx.end());
src_idx[src_idx.size() - 2] *= vty_packing;
src_idx[src_idx.size() - 1] /= vty_packing;
for (int i = 0; i < vty_packing; ++i) {
parts.push_back(builder.create<tpu::UnpackSubelementsOp>(
loc, vreg_x32, vregs(src_idx), vreg_part));
if (src_idx[src_idx.size() - 2] <
vregs.dim(vregs.num_dimensions() - 2) - 1) {
++src_idx[src_idx.size() - 2];
}
}
*tile = builder.create<tpu::PackSubelementsOp>(
loc, vregs.begin()->getType(), parts, tpu::PackFormat::kCompressed);
});
return std::pair(dst, std::move(retiled));
}
// Handle retiling from (1, 128 * packing) to (packing, 128) for
// packed data.
// We do compressed unpacking followed by interleaved packing.
// TODO(tlongeri): This can be used as a first step before using
// a generalized retiling where we only move sublanes around
// (without packing/unpacking).
// TODO(tlongeri): Interleaved unpacking followed by interleaved
// packing (but with different pairings) might also be
// interesting if the next step is a retile, since we can also
// match corresponding elements without shifting. It's just that
// the tiles are not adjacent (no contiguous vreg slice).
if (bitwidth < 32 && 32 % bitwidth == 0 &&
src.tiling() == std::array<int64_t, 2>{1, 128 * packing} &&
dst.tiling() == std::array<int64_t, 2>{packing, 128}) {
// To illustrate, consider a 2 x 16 16-bit shape laid out in vregs of
// 4 sublanes and 2 lanes (this is convenient to keep the example small
// yet non-trivial) with (1, 4) tiling. We will relayout to (2, 2) tiling.
//
// The vreg slice is 1 x 16, that is, the vreg contains the data for a
// 1 x 16 window of the logical shape.
//
// [a b c d e f g h i j k l m n o p] -> vreg 1
// [A B C D E F G H I J K L M N O P] -> vreg 2
//
// Note: we support multiple vregs per row of the logical shape, but we use
// one here just to keep the example small.
//
// When we do a compressed unpack, the resulting vregs effectively have a
// tiling of (1, 2) and cover a vreg slice of 1 x 8 logical elements.
//
// [a b c d e f g h] -> vreg 1, part 1 [i j k l m n o p] -> vreg 1, part 2
// [A B C D E F G H] -> vreg 2, part 1 [I J K L M N O P] -> vreg 2, part 2
//
// It is clear that if we combine vreg 1, part 1 and vreg 2, part 1 we get data
// that covers a 2 x 8 vreg slice. Note, however, that we will have to mind
// the internal ordering of the vreg.
//
// [a b c d e f g h [i j k l m n o p
// A B C D E F G H] -> new vreg 1 I J K L M N O P] -> new vreg 2
//
// To see if we can get the right internal ordering that we need for (2, 2)
// tiling, let's break new vreg 1 into (1, 2) rows, which correspond to
// sublanes when unpacked and half-sublanes when packed.
//
// [(a b) (c d) (e f) (g h)
// (A B) (C D) (E F) (G H)]
//
// The sublane order for the vreg parts is [(a b) (c d) ...] for vreg 1,
// part 1 and [(A B) (C D) ...] for vreg 2, part 1.
//
// The desired half-sublane order, for packed (2, 2) tiling, is
// [(a b) (A B) (c d) (C D) ...]. That is, traverse down each column before
// moving to the next one. This is exactly an interleaving of the sublanes
// of the vreg parts.
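// In terms of vreg indices: destination vreg (row, col) interleaves part
// (col % packing) of source vregs (row * packing, col / packing) through
// (row * packing + packing - 1, col / packing), reusing the last valid row
// for any padding.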
xla::Array<Value> retiled(
dst.tileArrayImplicitShape(vty.getShape(), target_shape));
const VectorType vreg_x32 =
vty.getElementType().isSignlessInteger()
? VectorType::get(target_shape, builder.getI32Type())
: VectorType::get(target_shape, builder.getF32Type());
retiled.Each([&](absl::Span<const int64_t> idx, Value *tile) {
SmallVector<Value> parts;
parts.reserve(packing);
SmallVector<int64_t> src_idx(toArrayRef(idx));
*(src_idx.end() - 2) *= packing;
const int64_t vreg_part = *(src_idx.end() - 1) % packing;
*(src_idx.end() - 1) /= packing;
for (int i = 0; i < packing; ++i) {
parts.push_back(builder.create<tpu::UnpackSubelementsOp>(
loc, vreg_x32, vregs(src_idx), vreg_part));
if (*(src_idx.end() - 2) < *(vregs.dimensions().end() - 2) - 1) {
++*(src_idx.end() - 2);
} // The rest is padding, so just reuse the part from the last valid source
// vreg (rather than an arbitrary vreg, to avoid adding an extra dependency).
}
*tile = builder.create<tpu::PackSubelementsOp>(
loc, vregs.begin()->getType(), parts, tpu::PackFormat::kInterleaved);
});
return std::pair(dst, std::move(retiled));
}
if (isSupportedReducedSublanesRetile(src, dst, target_shape)) {
return std::pair(dst, retileToReducedSublanes(builder, vty.getShape(), src,
vregs, dst, target_shape));
}
return emitError(loc, "Not implemented: Unsupported tiling change for ")
<< vty << ": from " << src << " to tiling (" << dst_tiling[0] << ", "
<< dst_tiling[1] << ")";
}
// TODO(apaszke): Test this function properly
FailureOr<TypedValue<VectorType>> relayout(RewriteContext &ctx,
OpBuilder &builder,
@ -5174,7 +5380,6 @@ FailureOr<TypedValue<VectorType>> relayout(RewriteContext &ctx,
VectorLayout dst) {
const auto target_shape = ctx.target_shape;
const int8_t bitwidth = src.bitwidth();
const int packing = src.packing();
if (bitwidth != dst.bitwidth()) {
return emitError(v.getLoc(), "Can't change bitwidth during a relayout");
}
@ -5237,6 +5442,7 @@ FailureOr<TypedValue<VectorType>> relayout(RewriteContext &ctx,
FAILUREOR_ASSIGN_OR_RETURN(
xla::Array<Value> src_tiles,
disassemble(builder, src, v, target_shape, /*use_implicit_shape=*/true));
// Two easy cases: layouts are equivalent, or the source is replicated.
if (src.generalizes(dst, vty.getShape(), target_shape)) {
return assemble(builder, vty, dst, std::move(src_tiles), target_shape,
/*use_implicit_shape=*/true)
@ -5255,200 +5461,12 @@ FailureOr<TypedValue<VectorType>> relayout(RewriteContext &ctx,
.getResult();
}
// Handle retiling from (packing, 128) to (8 * packing, 128).
if (src.implicit_dim() == dst.implicit_dim() &&
src.offsets() == LayoutOffsets{0, 0} &&
(dst.offsets()[0] == 0 ||
(packing == 1 && dst.offsets()[0] == std::nullopt &&
*(src_tiles.dimensions().end() - 2) == 1)) &&
dst.offsets()[1] == 0 &&
src.tiling() == std::array<int64_t, 2>{packing, 128} &&
dst.tiling() == std::array<int64_t, 2>{8 * packing, 128}) {
xla::Array<Value> src_tiles_retiled(
dst.tileArrayImplicitShape(vty.getShape(), target_shape));
src_tiles_retiled.Each([&](absl::Span<const int64_t> idx, Value *tile) {
SmallVector<int64_t> src_idx(idx.begin(), idx.end());
*(src_idx.end() - 2) *= target_shape[0];
*(src_idx.end() - 1) /= target_shape[0];
const int64_t src_sl_idx = *(idx.end() - 1) % target_shape[0];
if (!dst.offsets()[0]) {
CHECK_EQ(src.getImplicitTiledDims(vty.getShape(), 1)[0], 1);
*tile = broadcastSublane(builder, src_tiles(src_idx), src_sl_idx,
target_shape);
} else {
for (int dst_sl_idx = 0;
dst_sl_idx < target_shape[0] &&
*(src_idx.end() - 2) < *(src_tiles.dimensions().end() - 2);
++dst_sl_idx, ++*(src_idx.end() - 2)) {
*tile = copy_one_sublane(builder, src_tiles(src_idx), src_sl_idx,
*tile, dst_sl_idx, target_shape);
}
}
});
src = dst;
src_tiles = std::move(src_tiles_retiled);
} else if ( // Handle retiling from (m, 128) to (8, 128) for 32-bit data
// where m < 8 and m is a power of 2.
// TODO(b/306692696) Generalize relayout from tiling (m, 128) to
// (8, 128) for any src_tiles.dimensions().
src.implicit_dim() == dst.implicit_dim() && src.bitwidth() == 32 &&
src.offsets() == LayoutOffsets{0, 0} &&
dst.offsets() == LayoutOffsets{0, 0} &&
target_shape[0] % src.tiling()[0] == 0 &&
src.tiling()[1] == target_shape[1] && dst.tiling() == target_shape &&
*(src_tiles.dimensions().end() - 2) == 1) {
xla::Array<Value> src_tiles_retiled(
dst.tileArrayImplicitShape(vty.getShape(), target_shape));
src_tiles_retiled.Each(
[&](const absl::Span<const int64_t> idx, Value *const new_src_tile) {
const int64_t tiles_per_vreg = src.tilesPerVreg(target_shape);
const int64_t dst_col = idx.back();
const int64_t src_col = dst_col / tiles_per_vreg;
const int64_t start_slane_idx =
src.tiling()[0] * (dst_col % tiles_per_vreg);
SmallVector<int64_t> src_idx(toArrayRef(idx));
src_idx.back() = src_col;
Value src_tile = src_tiles(src_idx);
if (start_slane_idx) {
SmallVector<int32_t> slane_idxs;
slane_idxs.reserve(target_shape[0]);
for (int i = 0; i < target_shape[0]; ++i) {
slane_idxs.push_back(start_slane_idx + (i % src.tiling()[0]));
}
const DenseI32ArrayAttr gather_indices =
builder.getDenseI32ArrayAttr(slane_idxs);
*new_src_tile = builder.create<tpu::GatherOp>(
v.getLoc(), src_tile.getType(), src_tile, gather_indices,
/*dimension=*/0);
} else {
*new_src_tile = src_tile;
}
});
src = dst;
src_tiles = std::move(src_tiles_retiled);
} else if ( // TODO(b/265133506): Generalize retiling.
// (8,128) -> (8 * packing,128) tiling change for packed type.
src.implicit_dim() == dst.implicit_dim() && bitwidth < 32 &&
32 % bitwidth == 0 && src.offsets() == dst.offsets() &&
src.tiling() == std::array<int64_t, 2>{8, 128} &&
dst.tiling() == std::array<int64_t, 2>{8 * dst.packing(), 128}) {
xla::Array<Value> src_tiles_retiled(
dst.tileArrayImplicitShape(vty.getShape(), target_shape));
int vty_packing = dst.packing();
VectorType vreg_x32 =
vty.getElementType().isSignlessInteger()
? VectorType::get(target_shape, builder.getI32Type())
: VectorType::get(target_shape, builder.getF32Type());
src_tiles_retiled.Each([&](absl::Span<const int64_t> idx, Value *tile) {
const int vreg_part = idx.back() % vty_packing;
SmallVector<Value, 8> parts;
parts.reserve(vty_packing);
SmallVector<int64_t> src_idx(idx.begin(), idx.end());
src_idx[src_idx.size() - 2] *= vty_packing;
src_idx[src_idx.size() - 1] /= vty_packing;
for (int i = 0; i < vty_packing; ++i) {
parts.push_back(builder.create<tpu::UnpackSubelementsOp>(
v.getLoc(), vreg_x32, src_tiles(src_idx), vreg_part));
if (src_idx[src_idx.size() - 2] <
src_tiles.dim(src_tiles.num_dimensions() - 2) - 1) {
++src_idx[src_idx.size() - 2];
}
}
*tile = builder.create<tpu::PackSubelementsOp>(
v.getLoc(), src_tiles.begin()->getType(), parts,
tpu::PackFormat::kCompressed);
});
src = dst;
src_tiles = std::move(src_tiles_retiled);
} else if ( // Handle retiling from (1, 128 * packing) to (packing, 128) for
// packed data.
// We do compressed unpacking followed by interleaved packing.
// TODO(tlongeri): This can be used as a first step before using
// a generalized retiling where we only move sublanes around
// (without packing/unpacking).
// TODO(tlongeri): Interleaved unpacking followed by interleaved
// packing (but with different pairings) might also be
// interesting if the next step is a retile, since we can also
// match corresponding elements without shifting. It's just that
// the tiles are not adjacent (no contiguous vreg slice).
src.implicit_dim() == dst.implicit_dim() && bitwidth < 32 &&
32 % bitwidth == 0 && src.offsets() == dst.offsets() &&
src.tiling() == std::array<int64_t, 2>{1, 128 * packing} &&
dst.tiling() == std::array<int64_t, 2>{packing, 128}) {
// To illustrate, consider a 2 x 16 16-bit shape laid out in vregs of
// 4 sublanes and 2 lanes (this is convenient to keep the example small
// yet non-trivial) with (1, 4) tiling. We will relayout to (2, 2) tiling.
//
// The vreg slice is 1 x 16, that is, the vreg contains the data for a
// 1 x 16 window of the logical shape.
//
// [a b c d e f g h i j k l m n o p] -> vreg 1
// [A B C D E F G H I J K L M N O P] -> vreg 2
//
// Note: we support multiple vregs per row of the logical shape, but we use
// one here just to keep the example small.
//
// When we do a compressed unpack, the resulting vregs effectively have a
// tiling of (1, 2) and cover a vreg slice of 1 x 8 logical elements.
//
// [a b c d e f g h] -> vreg 1, part 1 [i j k l m n o p] -> vreg 1, part 2
// [A B C D E F G H] -> vreg 2, part 1 [I J K L M N O P] -> vreg 2, part 2
//
// It is clear that if we combine vreg 1, part 1 and vreg 2, part 1 we get data
// that covers a 2 x 8 vreg slice. Note, however, that we will have to mind
// the internal ordering of the vreg.
//
// [a b c d e f g h [i j k l m n o p
// A B C D E F G H] -> new vreg 1 I J K L M N O P] -> new vreg 2
//
// To see if we can get the right internal ordering that we need for (2, 2)
// tiling, let's break new vreg 1 into (1, 2) rows, which correspond to
// sublanes when unpacked and half-sublanes when packed.
//
// [(a b) (c d) (e f) (g h)
// (A B) (C D) (E F) (G H)]
//
// The sublane order for the vreg parts is [(a b) (c d) ...] for vreg 1,
// part 1 and [(A B) (C D) ...] for vreg 2, part 1.
//
// The desired half-sublane order, for packed (2, 2) tiling, is
// [(a b) (A B) (c d) (C D) ...]. That is, traverse down each column before
// moving to the next one. This is exactly an interleaving of the sublanes
// of the vreg parts.
xla::Array<Value> src_tiles_retiled(
dst.tileArrayImplicitShape(vty.getShape(), target_shape));
const VectorType vreg_x32 =
vty.getElementType().isSignlessInteger()
? VectorType::get(target_shape, builder.getI32Type())
: VectorType::get(target_shape, builder.getF32Type());
src_tiles_retiled.Each([&](absl::Span<const int64_t> idx, Value *tile) {
SmallVector<Value> parts;
parts.reserve(packing);
SmallVector<int64_t> src_idx(toArrayRef(idx));
*(src_idx.end() - 2) *= packing;
const int64_t vreg_part = *(src_idx.end() - 1) % packing;
*(src_idx.end() - 1) /= packing;
for (int i = 0; i < packing; ++i) {
parts.push_back(builder.create<tpu::UnpackSubelementsOp>(
v.getLoc(), vreg_x32, src_tiles(src_idx), vreg_part));
if (*(src_idx.end() - 2) < *(src_tiles.dimensions().end() - 2) - 1) {
++*(src_idx.end() - 2);
} // The rest is padding, so just reuse the part from the last valid source
// vreg (rather than an arbitrary vreg, to avoid adding an extra dependency).
}
*tile = builder.create<tpu::PackSubelementsOp>(
v.getLoc(), src_tiles.begin()->getType(), parts,
tpu::PackFormat::kInterleaved);
});
src = dst;
src_tiles = std::move(src_tiles_retiled);
} else if (isSupportedReducedSublanesRetile(src, dst, target_shape)) {
src_tiles = retileToReducedSublanes(builder, vty.getShape(), src, src_tiles,
dst, target_shape);
src = dst;
} else if (src.tiling() != dst.tiling()) {
return not_implemented();
}
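// Rows are only replicated when the destination wants a replicated sublane
// offset that the source does not already have.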
FAILUREOR_ASSIGN_OR_RETURN(
std::tie(src, src_tiles),
changeTiling(builder, ctx.target_shape, v.getLoc(), vty, src,
std::move(src_tiles), dst.tiling(),
dst.offsets()[0] == std::nullopt &&
src.offsets()[0] != std::nullopt));
// Remove second minor implicit dim, for values that have (8, 128) tiling.
if (src.implicit_dim() == VectorLayout::ImplicitDim::kSecondMinor &&
@ -5478,9 +5496,7 @@ FailureOr<TypedValue<VectorType>> relayout(RewriteContext &ctx,
});
src = dst;
src_tiles = std::move(src_tiles_retiled);
}
if (src.implicit_dim() != dst.implicit_dim()) {
} else if (src.implicit_dim() != dst.implicit_dim()) {
return not_implemented();
}