mirror of https://github.com/ROCm/jax.git (synced 2025-04-19 05:16:06 +00:00)

[NFC][Mosaic TPU] Separate out retiling from relayout

PiperOrigin-RevId: 658335679

This commit is contained in:
parent 0734345279
commit 0307438c3d
@@ -713,7 +713,8 @@ llvm::hash_code hash_value(const VectorLayout& layout) {
  return llvm::hash_value(layout.as_tuple());
}

std::ostream &operator<<(std::ostream &os, VectorLayout::ImplicitDim dim) {
template <typename Stream>
Stream& printImplicitDim(Stream& os, VectorLayout::ImplicitDim dim) {
  switch (dim) {
    case VectorLayout::ImplicitDim::kNone:
      os << "none";
@@ -728,6 +729,15 @@ std::ostream &operator<<(std::ostream &os, VectorLayout::ImplicitDim dim) {
  return os;
}

std::ostream& operator<<(std::ostream& os, VectorLayout::ImplicitDim dim) {
  return printImplicitDim(os, dim);
}

mlir::Diagnostic& operator<<(mlir::Diagnostic& diag,
                             VectorLayout::ImplicitDim dim) {
  return printImplicitDim(diag, dim);
}

std::optional<Layout> parseLayout(mlir::AsmParser& parser) {
  std::string layout_str;
  if (failed(parser.parseString(&layout_str))) {
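For illustration, a minimal standalone sketch of the pattern introduced above: one templated printer serves any stream-like sink, and the two operator<< overloads stay as thin forwarders. The Diag type and the printed strings for the non-kNone cases below are hypothetical stand-ins (Diag models mlir::Diagnostic) just to keep the sketch self-contained and compilable.

#include <iostream>
#include <string>

enum class ImplicitDim { kNone, kMinor, kSecondMinor };

// Hypothetical stand-in for mlir::Diagnostic: any sink with operator<<(const char*).
struct Diag {
  std::string text;
  Diag& operator<<(const char* s) { text += s; return *this; }
};

// One templated printer serves every sink type.
template <typename Stream>
Stream& printImplicitDim(Stream& os, ImplicitDim dim) {
  switch (dim) {
    case ImplicitDim::kNone: os << "none"; break;
    case ImplicitDim::kMinor: os << "minor"; break;
    case ImplicitDim::kSecondMinor: os << "second minor"; break;
  }
  return os;
}

// Thin forwarders, mirroring the two overloads added in the diff above.
std::ostream& operator<<(std::ostream& os, ImplicitDim dim) {
  return printImplicitDim(os, dim);
}
Diag& operator<<(Diag& diag, ImplicitDim dim) {
  return printImplicitDim(diag, dim);
}

int main() {
  std::cout << ImplicitDim::kSecondMinor << "\n";  // prints "second minor"
  Diag d;
  d << ImplicitDim::kNone;
  std::cout << d.text << "\n";                     // prints "none"
}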
@@ -513,6 +513,8 @@ llvm::raw_ostream &operator<<(llvm::raw_ostream &os, const Layout &v);
llvm::hash_code hash_value(const VectorLayout &layout);
mlir::Diagnostic &operator<<(mlir::Diagnostic &diag, const Layout &v);
std::ostream &operator<<(std::ostream &os, VectorLayout::ImplicitDim dim);
mlir::Diagnostic &operator<<(mlir::Diagnostic &diag,
                             VectorLayout::ImplicitDim dim);

std::optional<Layout> parseLayout(mlir::AsmParser &parser);

@@ -5166,6 +5166,212 @@ FailureOr<std::pair<VectorLayout, xla::Array<Value>>> changeOffsets(
  return std::make_pair(dst, std::move(vregs));
}

// TODO(b/265133506): Generalize retiling.
FailureOr<std::pair<VectorLayout, xla::Array<Value>>> changeTiling(
    OpBuilder &builder, const std::array<int64_t, 2> target_shape,
    const Location loc, VectorType vty, const VectorLayout src,
    xla::Array<Value> vregs, const std::array<int64_t, 2> dst_tiling,
    bool try_replicate_rows) {
  if (src.tiling() == dst_tiling) {
    return std::pair(src, std::move(vregs));
  }
  const int packing = src.packing();
  const int8_t bitwidth = src.bitwidth();
  VectorLayout dst(src.bitwidth(), src.offsets(), dst_tiling,
                   src.implicit_dim());
  if (!dst.isValid(target_shape)) {
    return emitError(loc, "Not implemented: invalid offsets in tiling target");
  }
  // Handle retiling from (packing, 128) to (8 * packing, 128).
  if (src.offsets() == LayoutOffsets{0, 0} &&
      src.tiling() == std::array<int64_t, 2>{packing, 128} &&
      dst_tiling == std::array<int64_t, 2>{8 * packing, 128}) {
    bool replicate_sublanes = try_replicate_rows && packing == 1 &&
                              *(vregs.dimensions().end() - 2) == 1;
    xla::Array<Value> retiled(
        dst.tileArrayImplicitShape(vty.getShape(), target_shape));
    retiled.Each([&](absl::Span<const int64_t> idx, Value *tile) {
      SmallVector<int64_t> src_idx(idx.begin(), idx.end());
      *(src_idx.end() - 2) *= target_shape[0];
      *(src_idx.end() - 1) /= target_shape[0];
      const int64_t src_sl_idx = *(idx.end() - 1) % target_shape[0];
      if (replicate_sublanes) {
        CHECK_EQ(src.getImplicitTiledDims(vty.getShape(), 1)[0], 1);
        *tile =
            broadcastSublane(builder, vregs(src_idx), src_sl_idx, target_shape);
      } else {
        for (int dst_sl_idx = 0;
             dst_sl_idx < target_shape[0] &&
             *(src_idx.end() - 2) < *(vregs.dimensions().end() - 2);
             ++dst_sl_idx, ++*(src_idx.end() - 2)) {
          *tile = copy_one_sublane(builder, vregs(src_idx), src_sl_idx, *tile,
                                   dst_sl_idx, target_shape);
        }
      }
    });
    // We have successfully replicated sublanes.
    if (replicate_sublanes) {
      dst = VectorLayout(bitwidth, {std::nullopt, dst.offsets()[1]}, dst_tiling,
                         dst.implicit_dim());
    }
    return std::pair(dst, std::move(retiled));
  }
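For illustration, a minimal standalone sketch of the index arithmetic in the branch above, assuming packing == 1, an (8, 128) native vreg, and src offsets (0, 0): destination vreg (row, col) takes its sublane s from source vreg (8 * row + s, col / 8), sublane col % 8. The grid sizes below are made up for the example.

#include <cstdint>
#include <iostream>

int main() {
  // Hypothetical vreg grids: src has (1, 128) tiling, dst has (8, 128) tiling.
  constexpr int64_t kSublanes = 8;                 // target_shape[0]
  constexpr int64_t src_rows = 16, src_cols = 2;   // made-up source vreg grid
  constexpr int64_t dst_rows = 2, dst_cols = 16;   // resulting destination grid

  for (int64_t row = 0; row < dst_rows; ++row) {
    for (int64_t col = 0; col < dst_cols; ++col) {
      const int64_t src_col = col / kSublanes;     // *(src_idx.end() - 1) /= 8
      const int64_t src_sl = col % kSublanes;      // src_sl_idx
      for (int64_t s = 0; s < kSublanes; ++s) {
        const int64_t src_row = kSublanes * row + s;  // walk *(src_idx.end() - 2)
        if (src_row >= src_rows) break;               // rest of the vreg is padding
        std::cout << "dst(" << row << "," << col << ") sublane " << s
                  << " <- src(" << src_row << "," << src_col << ") sublane "
                  << src_sl << "\n";
      }
    }
  }
}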
  // Handle retiling from (m, 128) to (8, 128) for 32-bit data
  // where m < 8 and m is a power of 2.
  // TODO(b/306692696): Handle any vregs.dimensions().
  if (bitwidth == 32 && src.offsets() == LayoutOffsets{0, 0} &&
      target_shape[0] % src.tiling()[0] == 0 &&
      src.tiling()[1] == target_shape[1] && dst.tiling() == target_shape &&
      *(vregs.dimensions().end() - 2) == 1) {
    xla::Array<Value> retiled(
        dst.tileArrayImplicitShape(vty.getShape(), target_shape));
    retiled.Each([&](const absl::Span<const int64_t> idx,
                     Value *const new_src_tile) {
      const int64_t tiles_per_vreg = src.tilesPerVreg(target_shape);
      const int64_t dst_col = idx.back();
      const int64_t src_col = dst_col / tiles_per_vreg;
      const int64_t start_slane_idx =
          src.tiling()[0] * (dst_col % tiles_per_vreg);
      SmallVector<int64_t> src_idx(toArrayRef(idx));
      src_idx.back() = src_col;
      Value src_tile = vregs(src_idx);
      if (start_slane_idx) {
        SmallVector<int32_t> slane_idxs;
        slane_idxs.reserve(target_shape[0]);
        for (int i = 0; i < target_shape[0]; ++i) {
          slane_idxs.push_back(start_slane_idx + (i % src.tiling()[0]));
        }
        const DenseI32ArrayAttr gather_indices =
            builder.getDenseI32ArrayAttr(slane_idxs);
        *new_src_tile = builder.create<tpu::GatherOp>(loc, src_tile.getType(),
                                                      src_tile, gather_indices,
                                                      /*dimension=*/0);
      } else {
        *new_src_tile = src_tile;
      }
    });
    return std::pair(dst, std::move(retiled));
  }
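For illustration, a standalone sketch of the sublane-gather pattern this branch computes, assuming m = 2 (src tiling (2, 128)), an 8-sublane vreg, and a made-up grid width: each destination column reads source column dst_col / tiles_per_vreg and gathers a repeating sublane pattern starting at m * (dst_col % tiles_per_vreg).

#include <cstdint>
#include <iostream>
#include <vector>

int main() {
  constexpr int64_t kSublanes = 8;   // target_shape[0]
  constexpr int64_t m = 2;           // src.tiling()[0], a power of 2 below 8
  constexpr int64_t tiles_per_vreg = kSublanes / m;
  constexpr int64_t dst_cols = 8;    // made-up destination vreg-grid width

  for (int64_t dst_col = 0; dst_col < dst_cols; ++dst_col) {
    const int64_t src_col = dst_col / tiles_per_vreg;
    const int64_t start_slane_idx = m * (dst_col % tiles_per_vreg);
    std::vector<int32_t> slane_idxs;
    for (int i = 0; i < kSublanes; ++i) {
      slane_idxs.push_back(static_cast<int32_t>(start_slane_idx + (i % m)));
    }
    std::cout << "dst col " << dst_col << " <- src col " << src_col
              << ", gather sublanes [";
    for (size_t i = 0; i < slane_idxs.size(); ++i) {
      std::cout << slane_idxs[i] << (i + 1 < slane_idxs.size() ? " " : "");
    }
    std::cout << "]\n";   // e.g. dst col 3 -> src col 0, [6 7 6 7 6 7 6 7]
  }
}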
  // (8,128) -> (8 * packing,128) tiling change for packed type.
  if (bitwidth < 32 && 32 % bitwidth == 0 &&
      src.tiling() == std::array<int64_t, 2>{8, 128} &&
      dst.tiling() == std::array<int64_t, 2>{8 * dst.packing(), 128}) {
    xla::Array<Value> retiled(
        dst.tileArrayImplicitShape(vty.getShape(), target_shape));
    int vty_packing = dst.packing();
    VectorType vreg_x32 =
        vty.getElementType().isSignlessInteger()
            ? VectorType::get(target_shape, builder.getI32Type())
            : VectorType::get(target_shape, builder.getF32Type());
    retiled.Each([&](absl::Span<const int64_t> idx, Value *tile) {
      const int vreg_part = idx.back() % vty_packing;
      SmallVector<Value, 8> parts;
      parts.reserve(vty_packing);
      SmallVector<int64_t> src_idx(idx.begin(), idx.end());
      src_idx[src_idx.size() - 2] *= vty_packing;
      src_idx[src_idx.size() - 1] /= vty_packing;
      for (int i = 0; i < vty_packing; ++i) {
        parts.push_back(builder.create<tpu::UnpackSubelementsOp>(
            loc, vreg_x32, vregs(src_idx), vreg_part));
        if (src_idx[src_idx.size() - 2] <
            vregs.dim(vregs.num_dimensions() - 2) - 1) {
          ++src_idx[src_idx.size() - 2];
        }
      }
      *tile = builder.create<tpu::PackSubelementsOp>(
          loc, vregs.begin()->getType(), parts, tpu::PackFormat::kCompressed);
    });
    return std::pair(dst, std::move(retiled));
  }
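For illustration, a standalone sketch of which source vregs and sub-element parts feed each destination vreg in this branch, assuming bf16 (packing == 2) and a small made-up vreg grid; the real code performs the combination with tpu::UnpackSubelementsOp followed by tpu::PackSubelementsOp in compressed format.

#include <cstdint>
#include <iostream>

int main() {
  constexpr int packing = 2;                     // e.g. bf16 packed into 32-bit words
  constexpr int64_t src_rows = 4, src_cols = 2;  // made-up (8, 128)-tiled vreg grid
  constexpr int64_t dst_rows = src_rows / packing, dst_cols = src_cols * packing;

  for (int64_t row = 0; row < dst_rows; ++row) {
    for (int64_t col = 0; col < dst_cols; ++col) {
      const int vreg_part = static_cast<int>(col % packing);
      const int64_t src_col = col / packing;
      int64_t src_row = row * packing;
      std::cout << "dst(" << row << "," << col << ") packs:";
      for (int i = 0; i < packing; ++i) {
        std::cout << " part " << vreg_part << " of src(" << src_row << ","
                  << src_col << ")";
        if (src_row < src_rows - 1) ++src_row;  // clamp at the padded edge
      }
      std::cout << "\n";
    }
  }
}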
  // Handle retiling from (1, 128 * packing) to (packing, 128) for
  // packed data.
  // We do compressed unpacking followed by interleaved packing.
  // TODO(tlongeri): This can be used as a first step before using
  // a generalized retiling where we only move sublanes around
  // (without packing/unpacking).
  // TODO(tlongeri): Interleaved unpacking followed by interleaved
  // packing (but with different pairings) might also be
  // interesting if the next step is a retile, since we can also
  // match corresponding elements without shifting. It's just that
  // the tiles are not adjacent (no contiguous vreg slice).
  if (bitwidth < 32 && 32 % bitwidth == 0 &&
      src.tiling() == std::array<int64_t, 2>{1, 128 * packing} &&
      dst.tiling() == std::array<int64_t, 2>{packing, 128}) {
    // To illustrate, consider a 2 x 16 16-bit shape laid out in vregs of
    // 4 sublanes and 2 lanes (this is convenient to keep the example small
    // yet non-trivial) with (1, 4) tiling. We will relayout to (2, 2) tiling.
    //
    // The vreg slice is 1 x 16, that is, the vreg contains the data for a
    // 1 x 16 window of the logical shape.
    //
    // [a b c d e f g h i j k l m n o p] -> vreg 1
    // [A B C D E F G H I J K L M N O P] -> vreg 2
    //
    // Note: we support multiple vregs per row of the logical shape, but we use
    // one here just to keep the example small.
    //
    // When we do a compressed unpack, the resulting vregs effectively have a
    // tiling of (1, 2) and cover a vreg slice of 1 x 8 logical elements.
    //
    // [a b c d e f g h] -> vreg 1, part 1   [i j k l m n o p] -> vreg 1, part 2
    // [A B C D E F G H] -> vreg 2, part 1   [I J K L M N O P] -> vreg 2, part 2
    //
    // It is clear that if we combine vreg 1, part 1 and vreg 2, part 1 we get
    // data that covers a 2 x 8 vreg slice. Note, however, that we will have to
    // mind the internal ordering of the vreg.
    //
    // [a b c d e f g h                     [i j k l m n o p
    //  A B C D E F G H] -> new vreg 1       I J K L M N O P] -> new vreg 2
    //
    // To see if we can get the right internal ordering that we need for (2, 2)
    // tiling, let's break new vreg 1 into (1, 2) rows, which correspond to
    // sublanes when unpacked and half-sublanes when packed.
    //
    // [(a b) (c d) (e f) (g h)
    //  (A B) (C D) (E F) (G H)]
    //
    // The sublane order for the vreg parts is [(a b) (c d) ...] for vreg 1,
    // part 1 and [(A B) (C D) ...] for vreg 2, part 1.
    //
    // The desired half-sublane order, for packed (2, 2) tiling, is
    // [(a b) (A B) (c d) (C D) ...]. That is, traverse down each column before
    // moving to the next one. This is exactly an interleaving of the sublanes
    // of the vreg parts.
    xla::Array<Value> retiled(
        dst.tileArrayImplicitShape(vty.getShape(), target_shape));
    const VectorType vreg_x32 =
        vty.getElementType().isSignlessInteger()
            ? VectorType::get(target_shape, builder.getI32Type())
            : VectorType::get(target_shape, builder.getF32Type());
    retiled.Each([&](absl::Span<const int64_t> idx, Value *tile) {
      SmallVector<Value> parts;
      parts.reserve(packing);
      SmallVector<int64_t> src_idx(toArrayRef(idx));
      *(src_idx.end() - 2) *= packing;
      const int64_t vreg_part = *(src_idx.end() - 1) % packing;
      *(src_idx.end() - 1) /= packing;
      for (int i = 0; i < packing; ++i) {
        parts.push_back(builder.create<tpu::UnpackSubelementsOp>(
            loc, vreg_x32, vregs(src_idx), vreg_part));
        if (*(src_idx.end() - 2) < *(vregs.dimensions().end() - 2) - 1) {
          ++*(src_idx.end() - 2);
        }  // The rest is padding, so just pick any of the input parts (but not
           // an arbitrary vreg so we don't add an extra dependency).
      }
      *tile = builder.create<tpu::PackSubelementsOp>(
          loc, vregs.begin()->getType(), parts, tpu::PackFormat::kInterleaved);
    });
    return std::pair(dst, std::move(retiled));
  }
  if (isSupportedReducedSublanesRetile(src, dst, target_shape)) {
    return std::pair(dst, retileToReducedSublanes(builder, vty.getShape(), src,
                                                  vregs, dst, target_shape));
  }
  return emitError(loc, "Not implemented: Unsupported tiling change for ")
         << vty << ": from " << src << " to tiling (" << dst_tiling[0] << ", "
         << dst_tiling[1] << ")";
}

// TODO(apaszke): Test this function properly
FailureOr<TypedValue<VectorType>> relayout(RewriteContext &ctx,
                                           OpBuilder &builder,
@@ -5174,7 +5380,6 @@ FailureOr<TypedValue<VectorType>> relayout(RewriteContext &ctx,
                                           VectorLayout dst) {
  const auto target_shape = ctx.target_shape;
  const int8_t bitwidth = src.bitwidth();
  const int packing = src.packing();
  if (bitwidth != dst.bitwidth()) {
    return emitError(v.getLoc(), "Can't change bitwidth during a relayout");
  }
@@ -5237,6 +5442,7 @@ FailureOr<TypedValue<VectorType>> relayout(RewriteContext &ctx,
  FAILUREOR_ASSIGN_OR_RETURN(
      xla::Array<Value> src_tiles,
      disassemble(builder, src, v, target_shape, /*use_implicit_shape=*/true));
  // Two easy cases: layouts are equivalent, or the source is replicated.
  if (src.generalizes(dst, vty.getShape(), target_shape)) {
    return assemble(builder, vty, dst, std::move(src_tiles), target_shape,
                    /*use_implicit_shape=*/true)
@@ -5255,200 +5461,12 @@ FailureOr<TypedValue<VectorType>> relayout(RewriteContext &ctx,
        .getResult();
  }

  // Handle retiling from (packing, 128) to (8 * packing, 128).
  if (src.implicit_dim() == dst.implicit_dim() &&
      src.offsets() == LayoutOffsets{0, 0} &&
      (dst.offsets()[0] == 0 ||
       (packing == 1 && dst.offsets()[0] == std::nullopt &&
        *(src_tiles.dimensions().end() - 2) == 1)) &&
      dst.offsets()[1] == 0 &&
      src.tiling() == std::array<int64_t, 2>{packing, 128} &&
      dst.tiling() == std::array<int64_t, 2>{8 * packing, 128}) {
    xla::Array<Value> src_tiles_retiled(
        dst.tileArrayImplicitShape(vty.getShape(), target_shape));
    src_tiles_retiled.Each([&](absl::Span<const int64_t> idx, Value *tile) {
      SmallVector<int64_t> src_idx(idx.begin(), idx.end());
      *(src_idx.end() - 2) *= target_shape[0];
      *(src_idx.end() - 1) /= target_shape[0];
      const int64_t src_sl_idx = *(idx.end() - 1) % target_shape[0];
      if (!dst.offsets()[0]) {
        CHECK_EQ(src.getImplicitTiledDims(vty.getShape(), 1)[0], 1);
        *tile = broadcastSublane(builder, src_tiles(src_idx), src_sl_idx,
                                 target_shape);
      } else {
        for (int dst_sl_idx = 0;
             dst_sl_idx < target_shape[0] &&
             *(src_idx.end() - 2) < *(src_tiles.dimensions().end() - 2);
             ++dst_sl_idx, ++*(src_idx.end() - 2)) {
          *tile = copy_one_sublane(builder, src_tiles(src_idx), src_sl_idx,
                                   *tile, dst_sl_idx, target_shape);
        }
      }
    });
    src = dst;
    src_tiles = std::move(src_tiles_retiled);
  } else if (  // Handle retiling from (m, 128) to (8, 128) for 32-bit data
               // where m < 8 and m is a power of 2.
               // TODO(b/306692696) Generalize relayout from tiling (m, 128) to
               // (8, 128) for any src_tiles.dimensions().
      src.implicit_dim() == dst.implicit_dim() && src.bitwidth() == 32 &&
      src.offsets() == LayoutOffsets{0, 0} &&
      dst.offsets() == LayoutOffsets{0, 0} &&
      target_shape[0] % src.tiling()[0] == 0 &&
      src.tiling()[1] == target_shape[1] && dst.tiling() == target_shape &&
      *(src_tiles.dimensions().end() - 2) == 1) {
    xla::Array<Value> src_tiles_retiled(
        dst.tileArrayImplicitShape(vty.getShape(), target_shape));
    src_tiles_retiled.Each(
        [&](const absl::Span<const int64_t> idx, Value *const new_src_tile) {
          const int64_t tiles_per_vreg = src.tilesPerVreg(target_shape);
          const int64_t dst_col = idx.back();
          const int64_t src_col = dst_col / tiles_per_vreg;
          const int64_t start_slane_idx =
              src.tiling()[0] * (dst_col % tiles_per_vreg);
          SmallVector<int64_t> src_idx(toArrayRef(idx));
          src_idx.back() = src_col;
          Value src_tile = src_tiles(src_idx);
          if (start_slane_idx) {
            SmallVector<int32_t> slane_idxs;
            slane_idxs.reserve(target_shape[0]);
            for (int i = 0; i < target_shape[0]; ++i) {
              slane_idxs.push_back(start_slane_idx + (i % src.tiling()[0]));
            }
            const DenseI32ArrayAttr gather_indices =
                builder.getDenseI32ArrayAttr(slane_idxs);
            *new_src_tile = builder.create<tpu::GatherOp>(
                v.getLoc(), src_tile.getType(), src_tile, gather_indices,
                /*dimension=*/0);
          } else {
            *new_src_tile = src_tile;
          }
        });
    src = dst;
    src_tiles = std::move(src_tiles_retiled);
  } else if (  // TODO(b/265133506): Generalize retiling.
               // (8,128) -> (8 * packing,128) tiling change for packed type.
      src.implicit_dim() == dst.implicit_dim() && bitwidth < 32 &&
      32 % bitwidth == 0 && src.offsets() == dst.offsets() &&
      src.tiling() == std::array<int64_t, 2>{8, 128} &&
      dst.tiling() == std::array<int64_t, 2>{8 * dst.packing(), 128}) {
    xla::Array<Value> src_tiles_retiled(
        dst.tileArrayImplicitShape(vty.getShape(), target_shape));
    int vty_packing = dst.packing();
    VectorType vreg_x32 =
        vty.getElementType().isSignlessInteger()
            ? VectorType::get(target_shape, builder.getI32Type())
            : VectorType::get(target_shape, builder.getF32Type());
    src_tiles_retiled.Each([&](absl::Span<const int64_t> idx, Value *tile) {
      const int vreg_part = idx.back() % vty_packing;
      SmallVector<Value, 8> parts;
      parts.reserve(vty_packing);
      SmallVector<int64_t> src_idx(idx.begin(), idx.end());
      src_idx[src_idx.size() - 2] *= vty_packing;
      src_idx[src_idx.size() - 1] /= vty_packing;
      for (int i = 0; i < vty_packing; ++i) {
        parts.push_back(builder.create<tpu::UnpackSubelementsOp>(
            v.getLoc(), vreg_x32, src_tiles(src_idx), vreg_part));
        if (src_idx[src_idx.size() - 2] <
            src_tiles.dim(src_tiles.num_dimensions() - 2) - 1) {
          ++src_idx[src_idx.size() - 2];
        }
      }
      *tile = builder.create<tpu::PackSubelementsOp>(
          v.getLoc(), src_tiles.begin()->getType(), parts,
          tpu::PackFormat::kCompressed);
    });
    src = dst;
    src_tiles = std::move(src_tiles_retiled);
  } else if (  // Handle retiling from (1, 128 * packing) to (packing, 128) for
               // packed data.
               // We do compressed unpacking followed by interleaved packing.
               // TODO(tlongeri): This can be used as a first step before using
               // a generalized retiling where we only move sublanes around
               // (without packing/unpacking).
               // TODO(tlongeri): Interleaved unpacking followed by interleaved
               // packing (but with different pairings) might also be
               // interesting if the next step is a retile, since we can also
               // match corresponding elements without shifting. It's just that
               // the tiles are not adjacent (no contiguous vreg slice).
      src.implicit_dim() == dst.implicit_dim() && bitwidth < 32 &&
      32 % bitwidth == 0 && src.offsets() == dst.offsets() &&
      src.tiling() == std::array<int64_t, 2>{1, 128 * packing} &&
      dst.tiling() == std::array<int64_t, 2>{packing, 128}) {
    // To illustrate, consider a 2 x 16 16-bit shape laid out in vregs of
    // 4 sublanes and 2 lanes (this is convenient to keep the example small
    // yet non-trivial) with (1, 4) tiling. We will relayout to (2, 2) tiling.
    //
    // The vreg slice is 1 x 16, that is, the vreg contains the data for a
    // 1 x 16 window of the logical shape.
    //
    // [a b c d e f g h i j k l m n o p] -> vreg 1
    // [A B C D E F G H I J K L M N O P] -> vreg 2
    //
    // Note: we support multiple vregs per row of the logical shape, but we use
    // one here just to keep the example small.
    //
    // When we do a compressed unpack, the resulting vregs effectively have a
    // tiling of (1, 2) and cover a vreg slice of 1 x 8 logical elements.
    //
    // [a b c d e f g h] -> vreg 1, part 1   [i j k l m n o p] -> vreg 1, part 2
    // [A B C D E F G H] -> vreg 2, part 1   [I J K L M N O P] -> vreg 2, part 2
    //
    // It is clear that if we combine vreg 1, part 1 and vreg 2, part 1 we get
    // data that covers a 2 x 8 vreg slice. Note, however, that we will have to
    // mind the internal ordering of the vreg.
    //
    // [a b c d e f g h                     [i j k l m n o p
    //  A B C D E F G H] -> new vreg 1       I J K L M N O P] -> new vreg 2
    //
    // To see if we can get the right internal ordering that we need for (2, 2)
    // tiling, let's break new vreg 1 into (1, 2) rows, which correspond to
    // sublanes when unpacked and half-sublanes when packed.
    //
    // [(a b) (c d) (e f) (g h)
    //  (A B) (C D) (E F) (G H)]
    //
    // The sublane order for the vreg parts is [(a b) (c d) ...] for vreg 1,
    // part 1 and [(A B) (C D) ...] for vreg 2, part 1.
    //
    // The desired half-sublane order, for packed (2, 2) tiling, is
    // [(a b) (A B) (c d) (C D) ...]. That is, traverse down each column before
    // moving to the next one. This is exactly an interleaving of the sublanes
    // of the vreg parts.
    xla::Array<Value> src_tiles_retiled(
        dst.tileArrayImplicitShape(vty.getShape(), target_shape));
    const VectorType vreg_x32 =
        vty.getElementType().isSignlessInteger()
            ? VectorType::get(target_shape, builder.getI32Type())
            : VectorType::get(target_shape, builder.getF32Type());
    src_tiles_retiled.Each([&](absl::Span<const int64_t> idx, Value *tile) {
      SmallVector<Value> parts;
      parts.reserve(packing);
      SmallVector<int64_t> src_idx(toArrayRef(idx));
      *(src_idx.end() - 2) *= packing;
      const int64_t vreg_part = *(src_idx.end() - 1) % packing;
      *(src_idx.end() - 1) /= packing;
      for (int i = 0; i < packing; ++i) {
        parts.push_back(builder.create<tpu::UnpackSubelementsOp>(
            v.getLoc(), vreg_x32, src_tiles(src_idx), vreg_part));
        if (*(src_idx.end() - 2) < *(src_tiles.dimensions().end() - 2) - 1) {
          ++*(src_idx.end() - 2);
        }  // The rest is padding, so just pick any of the input parts (but not
           // an arbitrary vreg so we don't add an extra dependency).
      }
      *tile = builder.create<tpu::PackSubelementsOp>(
          v.getLoc(), src_tiles.begin()->getType(), parts,
          tpu::PackFormat::kInterleaved);
    });
    src = dst;
    src_tiles = std::move(src_tiles_retiled);
  } else if (isSupportedReducedSublanesRetile(src, dst, target_shape)) {
    src_tiles = retileToReducedSublanes(builder, vty.getShape(), src, src_tiles,
                                        dst, target_shape);
    src = dst;
  } else if (src.tiling() != dst.tiling()) {
    return not_implemented();
  }
  FAILUREOR_ASSIGN_OR_RETURN(
      std::tie(src, src_tiles),
      changeTiling(builder, ctx.target_shape, v.getLoc(), vty, src,
                   std::move(src_tiles), dst.tiling(),
                   dst.offsets()[0] == std::nullopt &&
                       src.offsets()[0] != std::nullopt));

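For illustration, a standalone sketch of the final argument passed to changeTiling here (try_replicate_rows): replication is requested only when the destination's sublane offset is replicated (std::nullopt) while the source's is not. The Offsets alias below is a simplified stand-in for LayoutOffsets.

#include <array>
#include <cstdint>
#include <iostream>
#include <optional>

// Simplified stand-in for LayoutOffsets: {sublane offset, lane offset},
// where std::nullopt means "replicated along that dimension".
using Offsets = std::array<std::optional<int64_t>, 2>;

bool tryReplicateRows(const Offsets& src, const Offsets& dst) {
  return dst[0] == std::nullopt && src[0] != std::nullopt;
}

int main() {
  std::cout << std::boolalpha;
  std::cout << tryReplicateRows({{0, 0}}, {{std::nullopt, 0}}) << "\n";  // true
  std::cout << tryReplicateRows({{std::nullopt, 0}}, {{std::nullopt, 0}})
            << "\n";  // false: source is already replicated
  std::cout << tryReplicateRows({{0, 0}}, {{0, 0}}) << "\n";  // false
}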
  // Remove second minor implicit dim, for values that have (8, 128) tiling.
  if (src.implicit_dim() == VectorLayout::ImplicitDim::kSecondMinor &&
@@ -5478,9 +5496,7 @@ FailureOr<TypedValue<VectorType>> relayout(RewriteContext &ctx,
    });
    src = dst;
    src_tiles = std::move(src_tiles_retiled);
  }

  if (src.implicit_dim() != dst.implicit_dim()) {
  } else if (src.implicit_dim() != dst.implicit_dim()) {
    return not_implemented();
  }
