[Mosaic:TPU] For trunc, expand supported tilings, offsets and bitwidths

infer-vector-layout won't use the full generality anytime soon, but we could reuse this logic for relayouts

PiperOrigin-RevId: 708011538
Tomás Longeri 2024-12-19 13:30:12 -08:00 committed by jax authors
parent d129438548
commit 307c8d3af8
2 changed files with 163 additions and 72 deletions
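The sketch below is an editorial illustration, not part of the commit: a simplified stand-in for VectorLayout::vregSlice (assuming the usual (8, 128) target shape and tilesPerVreg = target[0] * target[1] * packing / (tiling[0] * tiling[1])) showing how the new trunc rule derives vreg_rows and vreg_cols, i.e. how many input vregs get packed into each output vreg.

#include <array>
#include <cstdint>
#include <cstdio>

constexpr std::array<int64_t, 2> kTargetShape = {8, 128};

struct Layout {
  int64_t bitwidth;
  std::array<int64_t, 2> tiling;
};

// Simplified vregSlice: sublane rows covered by one tile, lanes covered by
// all tiles of one vreg.
std::array<int64_t, 2> VregSlice(const Layout &l) {
  const int64_t packing = 32 / l.bitwidth;
  const int64_t tiles_per_vreg = kTargetShape[0] * kTargetShape[1] * packing /
                                 (l.tiling[0] * l.tiling[1]);
  return {l.tiling[0], tiles_per_vreg * l.tiling[1]};
}

int main() {
  // Example: f32 with (8, 128) tiling truncated to bf16 with native (16, 128)
  // tiling. Each output vreg is built from 2 rows x 1 column of input vregs,
  // and since each f32 tile fills a whole vreg the rows pack compressed.
  const Layout in = {32, {8, 128}};
  const Layout out = {16, {16, 128}};
  const auto in_slice = VregSlice(in);
  const auto out_slice = VregSlice(out);
  std::printf("vreg_rows = %lld, vreg_cols = %lld\n",
              static_cast<long long>(out_slice[0] / in_slice[0]),
              static_cast<long long>(out_slice[1] / in_slice[1]));
  return 0;
}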

jaxlib/mosaic/dialect/tpu/transforms/apply_vector_layout.cc

@@ -1012,73 +1012,162 @@ LogicalResult trunc_op_rule_impl(RewriteContext &ctx, OpTy op,
disassemble(builder, layout_in, source, ctx.target_shape,
/*use_implicit_shape=*/true));
xla::Array<Value> output_vregs(output_vregs_shape);
if (layout_in.bitwidth() != 32) {
return op.emitOpError("Not implemented: Only 32-bit truncation supported");
}
if (layout_in.offsets() != layout_out.offsets()) {
return op.emitOpError(
"Not implemented: Change of offsets during the truncation");
}
const LayoutOffsets input_offsets = layout_in.offsets();
const LayoutOffsets output_offsets = layout_out.offsets();
const std::array<int64_t, 2> input_vreg_slice =
layout_in.vregSlice(ctx.target_shape);
const std::array<int64_t, 2> output_vreg_slice =
layout_out.vregSlice(ctx.target_shape);
const int input_sublanes_per_tile =
layout_in.sublanesPerTile(ctx.target_shape);
if (layout_in.implicit_dim() != layout_out.implicit_dim()) {
return op.emitOpError("Not implemented: Change of layout during the cast");
return op.emitOpError(
"Not implemented: Truncation changes implicit dimension");
}
if (layout_in.tiling() != ctx.target_shape) {
return op.emitOpError("Not implemented: Only (8,128) tiling supported");
for (const auto &[input_offset, output_offset, input_slice_size] :
llvm::zip_equal(input_offsets, output_offsets, input_vreg_slice)) {
if (!input_offset.has_value() && !output_offset.has_value()) {
// Replicated to replicated is okay
} else if (!input_offset.has_value() && output_offset.has_value()) {
// Replicated to non-replicated could be handled, but we don't leverage
// replication, so we don't expect a replicated input offset to be
// assigned. The materialization of replicated vregs in the vreg
// array should be handled by relayout.
return op.emitOpError(
"Not implemented: Replicated to non-replicated offset");
} else if (input_offset.has_value() && !output_offset.has_value()) {
return op.emitOpError(
"Not implemented: Truncation introduces replication");
} else {
DCHECK(input_offset.has_value() && output_offset.has_value());
if (*input_offset != *output_offset % input_slice_size) {
return op.emitOpError("Not implemented: Misaligned offsets");
}
}
}
VectorType res_vreg_ty =
if (output_vreg_slice[0] % input_vreg_slice[0] != 0 ||
output_vreg_slice[1] % input_vreg_slice[1] != 0) {
// The output vreg slice should be a union of whole input vreg slices
return op.emitOpError("Not implemented: Unsupported tiling change");
}
// How many rows and columns of input vregs we are packing into one output
// vreg:
const int64_t vreg_rows = output_vreg_slice[0] / input_vreg_slice[0];
const int64_t vreg_cols = output_vreg_slice[1] / input_vreg_slice[1];
// Currently, we always pack across rows first, and then across columns.
// Note: Even though we combine it into a single tpu.pack_subelements op, the
// order of the operands is such that it is equivalent to packing across
// rows and then across columns.
// TODO(b/384274392): For some cases we want to pack across columns first, but
// we also need mixed compressed/interleaved packing.
// The format for packing *across* multiple rows in the vreg array (different
// 2nd minor index):
PackFormat row_pack_format = PackFormat::kCompressed;
if (vreg_rows != 1) {
// When going from (a, b) to (a * n, b) tiling, each output tile is the
// union of n input tiles from different vregs. The ith tile of the output
// vreg is formed by packing the ith tiles of the input vregs together.
// This can only be done when tiles are one sublane (by packing interleaved)
// or when they occupy the full vreg (by packing compressed).
// Note: Currently, we always pack across rows before packing across
// columns, so we just check the source tiling.
if (input_sublanes_per_tile == 1) {
row_pack_format = PackFormat::kInterleaved;
} else if (input_sublanes_per_tile == ctx.target_shape[0]) {
row_pack_format = PackFormat::kCompressed;
} else {
return op.emitOpError(
"Not implemented: Tiling change requires interleaving tiles that are "
"not one sublane or one full vreg");
}
}
// The tiling after packing across rows:
const std::array<int64_t, 2> intermediate_tiling = {
layout_in.tiling()[0] * vreg_rows, layout_in.tiling()[1]};
DCHECK_EQ(intermediate_tiling[0], layout_out.tiling()[0]);
// We only support compressed packing across vreg columns, which doesn't
// change the tiling. Logically, it just stacks tiles horizontally.
if (intermediate_tiling[1] != layout_out.tiling()[1] &&
// For (1, x) tiling all minor dimension tilings are equivalent, although
// some are illegal in VectorLayout. So, even though compressed packing in
// general does not change the tiling, for (1, x) we can still change to
// other minor dimension tilings (they are equivalent).
intermediate_tiling[0] != 1) {
// This could be handled, in some cases, by using interleaved packing across
// vreg columns, but we never use tilings like this. An example where we
// could use interleaved packing is (8, 128) f32 -> (8, 256) bf16.
return op.emitOpError(
"Not implemented: Truncating to increasing minor tile size");
}
// The format for packing *across* multiple columns in the vreg array
// (different minor index):
constexpr PackFormat col_pack_format = PackFormat::kCompressed;
if (vreg_rows != 1 && vreg_cols != 1 && row_pack_format != col_pack_format) {
// TODO(b/384274392): We can alternate interleaved and compressed packing
// but how should we expose it in tpu.pack_subelements?
return op.emitOpError(
"Not implemented: Tiling change requires mixed compressed and "
"interleaved packing");
}
const PackFormat pack_format =
vreg_rows != 1 ? row_pack_format : col_pack_format;
const VectorType res_vreg_ty =
getNativeVregType(result_ty.getElementType(), ctx.target_shape);
if (layout_out.tiling() == ctx.target_shape) {
const int packing = layout_out.packing();
output_vregs.Each([&](absl::Span<const int64_t> idxs, Value *v) {
SmallVector<Value> parts;
SmallVector<int64_t> idxs_local(toArrayRef(idxs));
if (!layout_out.offsets()[1].has_value()) {
idxs_local.back() = 0;
// Make sure we set all parts of the output vreg to make it replicated
parts.append(packing, input_vregs(idxs_local));
} else {
idxs_local.back() *= packing;
for (int64_t i = 0; i < packing; ++i) {
if (idxs_local.back() < input_vregs.dimensions().back()) {
parts.push_back(input_vregs(idxs_local));
++idxs_local.back();
} else {
parts.push_back(nullptr);
}
}
}
*v = builder.create<PackSubelementsOp>(res_vreg_ty, parts,
tpu::PackFormat::kCompressed);
});
} else if (layout_out.hasNativeTiling(ctx.target_shape)) {
int packing = layout_out.packing();
SmallVector<int64_t> input_idx;
output_vregs.Each([&](absl::Span<const int64_t> output_idx, Value *v) {
SmallVector<Value> parts;
parts.reserve(packing);
output_vregs.Each([&](absl::Span<const int64_t> idxs, Value *v) {
CHECK_GE(idxs.size(), 2);
SmallVector<int64_t> idxs_local(toArrayRef(idxs));
if (!layout_out.offsets()[0].has_value()) {
*(idxs_local.end() - 2) = 0;
// Make sure we set all parts of the output vreg to make it replicated
parts.append(packing, input_vregs(idxs_local));
input_idx.assign(output_idx.begin(), output_idx.end());
auto push_col = [&]() {
if (!output_offsets[0].has_value()) {
*(input_idx.end() - 2) = 0;
// Make sure we set all rows of the column to make it replicated
parts.append(vreg_rows, input_vregs(input_idx));
} else {
*(idxs_local.end() - 2) *= packing;
for (int64_t i = 0; i < packing; ++i) {
if (*(idxs_local.end() - 2) < *(input_vregs.dimensions().end() - 2)) {
parts.push_back(input_vregs(idxs_local));
++*(idxs_local.end() - 2);
const int64_t row_offset = *output_offsets[0] / input_vreg_slice[0];
const int64_t base_src_row =
*(output_idx.end() - 2) * vreg_rows - row_offset;
for (int64_t row = base_src_row; row < base_src_row + vreg_rows;
++row) {
if (0 <= row && row < *(input_vregs.dimensions().end() - 2)) {
*(input_idx.end() - 2) = row;
parts.push_back(input_vregs(input_idx));
} else {
parts.push_back(nullptr);
}
}
}
*v = builder.create<PackSubelementsOp>(res_vreg_ty, parts,
tpu::PackFormat::kCompressed);
parts.clear();
});
} else {
return op.emitOpError("Not implemented: unsupported output tiling");
}
};
if (!output_offsets[1].has_value()) {
*(input_idx.end() - 1) = 0;
// Make sure we set all column parts of the vreg to make it replicated
push_col();
for (int64_t col = 1; col < vreg_cols; ++col) {
for (int64_t row = 0; row < vreg_rows; ++row) {
parts.push_back(parts[row]);
}
}
} else {
const int64_t col_offset = *output_offsets[1] / input_vreg_slice[1];
const int64_t base_src_col =
*(output_idx.end() - 1) * vreg_cols - col_offset;
for (int64_t col = base_src_col; col < base_src_col + vreg_cols; ++col) {
if (0 <= col && col < *(input_vregs.dimensions().end() - 1)) {
*(input_idx.end() - 1) = col;
push_col();
} else {
parts.append(vreg_rows, nullptr);
}
}
}
*v = builder.create<PackSubelementsOp>(res_vreg_ty, parts, pack_format);
});
op.replaceAllUsesWith(assemble(builder, result_ty, layout_out,
std::move(output_vregs), ctx.target_shape,
/*use_implicit_shape=*/true)
@@ -6321,6 +6410,7 @@ FailureOr<std::pair<VectorLayout, xla::Array<Value>>> changeTiling(
if (bitwidth < 32 && 32 % bitwidth == 0 && src_tiling == ctx.target_shape &&
dst_tiling == std::array<int64_t, 2>{ctx.target_shape[0] * packing,
ctx.target_shape[1]}) {
// TODO(tlongeri): This relayout is just ext + trunc. Refactor.
// Note: for int4, retiling with scratch is always faster.
if (bitwidth != 4 || !has_enough_scratch) {
// Note: The code below does not work when src is replicated and dst is
@@ -6379,6 +6469,7 @@ FailureOr<std::pair<VectorLayout, xla::Array<Value>>> changeTiling(
if (bitwidth < 32 && 32 % bitwidth == 0 &&
src_tiling == std::array<int64_t, 2>{1, ctx.target_shape[1] * packing} &&
dst_tiling == std::array<int64_t, 2>{packing, ctx.target_shape[1]}) {
// TODO(tlongeri): This relayout is just ext + trunc. Refactor.
// To illustrate, consider a 2 x 16 16-bit shape laid out in vregs of
4 sublanes and 2 lanes (this is convenient to keep the example small
// yet non-trivial) with (1, 4) tiling. We will relayout to (2, 2) tiling.
@@ -6826,8 +6917,8 @@ LogicalResult applyLayoutOp(RewriteContext &ctx, Operation &op) {
// TODO: b/342235360 - This check is temporary while we increase and test
// support for offsets outside of the first tile. When support is more broad,
// any op without support should check it within their own rule.
if (!isa<vector::BroadcastOp, vector::ExtractStridedSliceOp,
vector::ShapeCastOp>(op)) {
if (!isa<arith::TruncFOp, arith::TruncIOp, vector::BroadcastOp,
vector::ExtractStridedSliceOp, vector::ShapeCastOp>(op)) {
for (const Layout &layout : layouts_in) {
if (layout && layout->offsets()[1].has_value() &&
layout->offsets()[1].value() >= layout->tiling()[1]) {

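As a quick editorial worked example of the offset check introduced above (not commit content): an output offset is compatible with an input offset when they agree modulo the input vreg slice along that dimension.

#include <cstdint>
#include <cstdio>

int main() {
  // Sublane dimension of a 32-bit input with (8, 128) tiling: slice size 8.
  const int64_t input_slice_size = 8;
  const int64_t input_offset = 3;
  const int64_t output_offset = 11;  // 11 % 8 == 3, so this pair is accepted
  const bool aligned = input_offset == output_offset % input_slice_size;
  std::printf("aligned: %s\n", aligned ? "yes" : "no");
  return 0;
}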
jaxlib/mosaic/dialect/tpu/transforms/infer_vector_layout.cc

@@ -142,7 +142,8 @@ class VectorLayoutInferer {
// TODO: b/342235360 - This check is temporary while we increase and test
// support for offsets outside of the first tile. When support is more
// broad, any op without support should check it within their own rule.
if (!isa<vector::BroadcastOp, vector::ExtractStridedSliceOp>(any_op)) {
if (!isa<arith::TruncIOp, arith::TruncFOp, vector::BroadcastOp,
vector::ExtractStridedSliceOp>(any_op)) {
const SmallVector<Layout> layouts_in = getLayoutFromOperands(&any_op);
for (const Layout &layout : layouts_in) {
if (layout &&
@@ -1699,23 +1700,22 @@ class VectorLayoutInferer {
auto dst_ty = cast<VectorType>(op->getResult(0).getType());
auto some_layout = getLayout(op->getOperand(0));
TPU_CHECK_OP(some_layout.has_value(), "missing vector layout");
if (dyn_cast<arith::TruncFOp>(op)) {
TPU_CHECK_OP(src_ty.getElementTypeBitWidth() == 32 &&
(dst_ty.getElementTypeBitWidth() == 16 ||
dst_ty.getElementTypeBitWidth() == 8),
"Only 32-bit to 8-bit or 16-bit truncation supported");
} else {
TPU_CHECK_OP(src_ty.getElementTypeBitWidth() == 32,
"Only 32-bit truncation supported");
const unsigned src_bitwidth = src_ty.getElementTypeBitWidth();
const unsigned dst_bitwidth = dst_ty.getElementTypeBitWidth();
if (isa<arith::TruncFOp>(op)) {
TPU_CHECK_OP(
src_bitwidth == 32 && (dst_bitwidth == 16 || dst_bitwidth == 8),
"Only 32-bit to 16-bit or 8-bit float truncation supported");
}
auto &layout = *some_layout;
bool select_native = allUsersRequireNativeTiling(op->getResult(0));
auto src_layout = VectorLayout(32, layout.offsets(), default_tiling_,
layout.implicit_dim());
auto src_layout = VectorLayout(
src_bitwidth, layout.offsets(),
select_native ? nativeTiling(src_bitwidth) : layout.tiling(),
layout.implicit_dim());
auto dst_layout = VectorLayout(
dst_ty.getElementTypeBitWidth(), layout.offsets(),
select_native ? nativeTiling(dst_ty.getElementTypeBitWidth())
: default_tiling_,
dst_bitwidth, layout.offsets(),
select_native ? nativeTiling(dst_bitwidth) : layout.tiling(),
layout.implicit_dim());
setLayout(op, src_layout, dst_layout);
return success();
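A rough editorial sketch of the inference change above, under the assumption that nativeTiling(bitwidth) is (8 * 32 / bitwidth, 128) for the usual (8, 128) target shape: truncation now keeps the operand's tiling unless every user requires native tiling, and integer truncation is no longer limited to 32-bit sources (e.g. a bf16 -> int8 arith.trunci is now accepted).

#include <array>
#include <cstdint>
#include <cstdio>

// Assumed native tiling for the usual (8, 128) target shape.
std::array<int64_t, 2> NativeTiling(int64_t bitwidth) {
  return {8 * (32 / bitwidth), 128};
}

int main() {
  // Hypothetical bf16 -> int8 truncation whose users all want native tiling.
  const bool select_native = true;
  const int64_t src_bitwidth = 16;
  const int64_t dst_bitwidth = 8;
  const std::array<int64_t, 2> operand_tiling = {16, 128};
  const auto src_tiling =
      select_native ? NativeTiling(src_bitwidth) : operand_tiling;
  const auto dst_tiling =
      select_native ? NativeTiling(dst_bitwidth) : operand_tiling;
  std::printf("src tiling (%lld, %lld), dst tiling (%lld, %lld)\n",
              static_cast<long long>(src_tiling[0]),
              static_cast<long long>(src_tiling[1]),
              static_cast<long long>(dst_tiling[0]),
              static_cast<long long>(dst_tiling[1]));
  return 0;
}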