[Mosaic:TPU] For trunc, expand supported tilings, offsets and bitwidths
infer-vector-layout won't use the full generality anytime soon, but we could reuse this logic for relayouts.

PiperOrigin-RevId: 708011538
This commit is contained in:
parent d129438548
commit 307c8d3af8
@@ -1012,73 +1012,162 @@ LogicalResult trunc_op_rule_impl(RewriteContext &ctx, OpTy op,
       disassemble(builder, layout_in, source, ctx.target_shape,
                   /*use_implicit_shape=*/true));
   xla::Array<Value> output_vregs(output_vregs_shape);
-  if (layout_in.bitwidth() != 32) {
-    return op.emitOpError("Not implemented: Only 32-bit truncation supported");
-  }
-  if (layout_in.offsets() != layout_out.offsets()) {
-    return op.emitOpError(
-        "Not implemented: Change of offsets during the truncation");
-  }
+  const LayoutOffsets input_offsets = layout_in.offsets();
+  const LayoutOffsets output_offsets = layout_out.offsets();
+  const std::array<int64_t, 2> input_vreg_slice =
+      layout_in.vregSlice(ctx.target_shape);
+  const std::array<int64_t, 2> output_vreg_slice =
+      layout_out.vregSlice(ctx.target_shape);
+  const int input_sublanes_per_tile =
+      layout_in.sublanesPerTile(ctx.target_shape);
+
   if (layout_in.implicit_dim() != layout_out.implicit_dim()) {
-    return op.emitOpError("Not implemented: Change of layout during the cast");
+    return op.emitOpError(
+        "Not implemented: Truncation changes implicit dimension");
   }
-  if (layout_in.tiling() != ctx.target_shape) {
-    return op.emitOpError("Not implemented: Only (8,128) tiling supported");
+  for (const auto &[input_offset, output_offset, input_slice_size] :
+       llvm::zip_equal(input_offsets, output_offsets, input_vreg_slice)) {
+    if (!input_offset.has_value() && !output_offset.has_value()) {
+      // Replicated to replicated is okay
+    } else if (!input_offset.has_value() && output_offset.has_value()) {
+      // Replicated to non-replicated could be handled, but we don't leverage
+      // replication, so we don't expect a replicated input offset to be
+      // assigned. The materialization of replicated vregs in the vreg
+      // array should be handled by relayout.
+      return op.emitOpError(
+          "Not implemented: Replicated to non-replicated offset");
+    } else if (input_offset.has_value() && !output_offset.has_value()) {
+      return op.emitOpError(
+          "Not implemented: Truncation introduces replication");
+    } else {
+      DCHECK(input_offset.has_value() && output_offset.has_value());
+      if (*input_offset != *output_offset % input_slice_size) {
+        return op.emitOpError("Not implemented: Misaligned offsets");
+      }
+    }
   }
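The offset rule above can be read as: within each dimension, a non-replicated output offset is only reachable when the input offset equals the output offset modulo the input vreg slice size, and replication may be neither introduced nor dropped. A minimal standalone sketch with illustrative numbers (not code from this patch; offsets_compatible is a hypothetical helper):

```cpp
#include <cassert>
#include <cstdint>
#include <optional>

// Sketch of the offset rule checked by the loop above.
bool offsets_compatible(std::optional<int64_t> in_off,
                        std::optional<int64_t> out_off,
                        int64_t input_slice_size) {
  if (!in_off && !out_off) return true;   // replicated -> replicated is okay
  if (!in_off || !out_off) return false;  // replication would be introduced or dropped
  return *in_off == *out_off % input_slice_size;
}

int main() {
  // e.g. f32 input with an 8-sublane vreg slice, bf16 output with a 16-sublane slice:
  assert(offsets_compatible(1, 9, 8));              // 9 % 8 == 1, aligned
  assert(!offsets_compatible(0, 9, 8));             // "Misaligned offsets"
  assert(!offsets_compatible(std::nullopt, 0, 8));  // replicated -> non-replicated
}
```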
-  VectorType res_vreg_ty =
+  if (output_vreg_slice[0] % input_vreg_slice[0] != 0 ||
+      output_vreg_slice[1] % input_vreg_slice[1] != 0) {
+    // The output vreg slice should be a union of whole input vreg slices
+    return op.emitOpError("Not implemented: Unsupported tiling change");
+  }
+  // How many rows and columns of input vregs we are packing into one output
+  // vreg:
+  const int64_t vreg_rows = output_vreg_slice[0] / input_vreg_slice[0];
+  const int64_t vreg_cols = output_vreg_slice[1] / input_vreg_slice[1];
+
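To see what vreg_rows and vreg_cols mean in practice, here is a hedged sketch that mirrors the assumed semantics of VectorLayout::vregSlice (a vreg holds packing * target_shape[0] * target_shape[1] elements, split into tiles); the numbers are illustrative, not taken from the patch:

```cpp
#include <array>
#include <cstdint>
#include <iostream>

// Assumed semantics: a vreg slice spans tiling[0] rows and
// tiling[1] * tiles_per_vreg columns of the (implicitly shaped) value.
std::array<int64_t, 2> vreg_slice(std::array<int64_t, 2> tiling, int64_t packing,
                                  std::array<int64_t, 2> target_shape) {
  const int64_t tiles_per_vreg =
      packing * target_shape[0] * target_shape[1] / (tiling[0] * tiling[1]);
  return {tiling[0], tiling[1] * tiles_per_vreg};
}

int main() {
  const std::array<int64_t, 2> target = {8, 128};
  // f32 input with (8, 128) tiling -> vreg slice (8, 128).
  auto in = vreg_slice({8, 128}, /*packing=*/1, target);
  // i8 output with native (32, 128) tiling -> vreg slice (32, 128).
  auto out_native = vreg_slice({32, 128}, /*packing=*/4, target);
  // i8 output that keeps the (8, 128) tiling -> vreg slice (8, 512).
  auto out_small = vreg_slice({8, 128}, /*packing=*/4, target);
  std::cout << out_native[0] / in[0] << "x" << out_native[1] / in[1] << "\n";  // 4x1
  std::cout << out_small[0] / in[0] << "x" << out_small[1] / in[1] << "\n";    // 1x4
}
```

With the native i8 tiling, four input vregs from consecutive second-minor positions feed each output vreg (vreg_rows = 4); keeping the 32-bit (8, 128) tiling instead packs four input vregs side by side along the minor dimension (vreg_cols = 4).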
+  // Currently, we always pack across rows first, and then across columns.
+  // Note: Even though we combine it into a single tpu.pack_subelements op, the
+  // order of the operands is such that it is equivalent to packing across
+  // rows and then across columns.
+  // TODO(b/384274392): For some cases we want to pack across columns first, but
+  // we also need mixed compressed/interleaved packing.
+
+  // The format for packing *across* multiple rows in the vreg array (different
+  // 2nd minor index):
+  PackFormat row_pack_format = PackFormat::kCompressed;
+  if (vreg_rows != 1) {
+    // When going from (a, b) to (a * n, b) tiling, each output tile is the
+    // union of n input tiles from different vregs. The ith tile of the output
+    // vreg is formed by packing the ith tiles of the input vregs together.
+    // This can only be done when tiles are one sublane (by packing interleaved)
+    // or when they occupy the full vreg (by packing compressed).
+    // Note: Currently, we always pack across rows before packing across
+    // columns, so we just check the source tiling.
+    if (input_sublanes_per_tile == 1) {
+      row_pack_format = PackFormat::kInterleaved;
+    } else if (input_sublanes_per_tile == ctx.target_shape[0]) {
+      row_pack_format = PackFormat::kCompressed;
+    } else {
+      return op.emitOpError(
+          "Not implemented: Tiling change requires interleaving tiles that are "
+          "not one sublane or one full vreg");
+    }
+  }
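The two PackFormat choices above differ only in how data from the operand vregs is ordered in the result. A toy model of the two orderings, at sublane granularity and ignoring the subelement packing itself (illustrative only; this is not the tpu.pack_subelements lowering):

```cpp
#include <cstddef>
#include <vector>

// Each element stands for one source sublane from vreg a or vreg b.

// Compressed: the first source fills the low half of the result and the second
// source fills the high half, so tiles that occupy a full vreg stay contiguous.
// pack_compressed({a0, a1}, {b0, b1}) -> {a0, a1, b0, b1}
std::vector<int> pack_compressed(const std::vector<int>& a, const std::vector<int>& b) {
  std::vector<int> out(a);
  out.insert(out.end(), b.begin(), b.end());
  return out;
}

// Interleaved: sublanes alternate a0, b0, a1, b1, ..., so when every tile is a
// single sublane, the ith tile of each source ends up adjacent in the result.
// pack_interleaved({a0, a1}, {b0, b1}) -> {a0, b0, a1, b1}
std::vector<int> pack_interleaved(const std::vector<int>& a, const std::vector<int>& b) {
  std::vector<int> out;
  for (size_t i = 0; i < a.size(); ++i) {
    out.push_back(a[i]);
    out.push_back(b[i]);
  }
  return out;
}
```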
+  // The tiling after packing across rows:
+  const std::array<int64_t, 2> intermediate_tiling = {
+      layout_in.tiling()[0] * vreg_rows, layout_in.tiling()[1]};
+  DCHECK_EQ(intermediate_tiling[0], layout_out.tiling()[0]);
+
+  // We only support compressed packing across vreg columns, which doesn't
+  // change the tiling. Logically, it just stacks tiles horizontally.
+  if (intermediate_tiling[1] != layout_out.tiling()[1] &&
+      // For (1, x) tiling all minor dimension tilings are equivalent, although
+      // some are illegal in VectorLayout. So, even though compressed packing in
+      // general does not change the tiling, for (1, x) we can still change to
+      // other minor dimension tilings (they are equivalent).
+      intermediate_tiling[0] != 1) {
+    // This could be handled, in some cases, by using interleaved packing across
+    // vreg columns, but we never use tilings like this. An example where we
+    // could use interleaved packing is (8, 128) f32 -> (8, 256) bf16.
+    return op.emitOpError(
+        "Not implemented: Truncating to increasing minor tile size");
+  }
+  // The format for packing *across* multiple columns in the vreg array
+  // (different minor index):
+  constexpr PackFormat col_pack_format = PackFormat::kCompressed;
+
+  if (vreg_rows != 1 && vreg_cols != 1 && row_pack_format != col_pack_format) {
+    // TODO(b/384274392): We can alternate interleaved and compressed packing
+    // but how should we expose it in tpu.pack_subelements?
+    return op.emitOpError(
+        "Not implemented: Tiling change requires mixed compressed and "
+        "interleaved packing");
+  }
+  const PackFormat pack_format =
+      vreg_rows != 1 ? row_pack_format : col_pack_format;
+
+  const VectorType res_vreg_ty =
       getNativeVregType(result_ty.getElementType(), ctx.target_shape);
-  if (layout_out.tiling() == ctx.target_shape) {
-    const int packing = layout_out.packing();
-    output_vregs.Each([&](absl::Span<const int64_t> idxs, Value *v) {
-      SmallVector<Value> parts;
-      SmallVector<int64_t> idxs_local(toArrayRef(idxs));
-      if (!layout_out.offsets()[1].has_value()) {
-        idxs_local.back() = 0;
-        // Make sure we set all parts of the output vreg to make it replicated
-        parts.append(packing, input_vregs(idxs_local));
-      } else {
-        idxs_local.back() *= packing;
-        for (int64_t i = 0; i < packing; ++i) {
-          if (idxs_local.back() < input_vregs.dimensions().back()) {
-            parts.push_back(input_vregs(idxs_local));
-            ++idxs_local.back();
-          } else {
-            parts.push_back(nullptr);
-          }
-        }
-      }
-      *v = builder.create<PackSubelementsOp>(res_vreg_ty, parts,
-                                             tpu::PackFormat::kCompressed);
-    });
-  } else if (layout_out.hasNativeTiling(ctx.target_shape)) {
-    int packing = layout_out.packing();
+
+  SmallVector<int64_t> input_idx;
+  output_vregs.Each([&](absl::Span<const int64_t> output_idx, Value *v) {
     SmallVector<Value> parts;
-    parts.reserve(packing);
-    output_vregs.Each([&](absl::Span<const int64_t> idxs, Value *v) {
-      CHECK_GE(idxs.size(), 2);
-      SmallVector<int64_t> idxs_local(toArrayRef(idxs));
-      if (!layout_out.offsets()[0].has_value()) {
-        *(idxs_local.end() - 2) = 0;
-        // Make sure we set all parts of the output vreg to make it replicated
-        parts.append(packing, input_vregs(idxs_local));
+    input_idx.assign(output_idx.begin(), output_idx.end());
+    auto push_col = [&]() {
+      if (!output_offsets[0].has_value()) {
+        *(input_idx.end() - 2) = 0;
+        // Make sure we set all rows of the column to make it replicated
+        parts.append(vreg_rows, input_vregs(input_idx));
       } else {
-        *(idxs_local.end() - 2) *= packing;
-        for (int64_t i = 0; i < packing; ++i) {
-          if (*(idxs_local.end() - 2) < *(input_vregs.dimensions().end() - 2)) {
-            parts.push_back(input_vregs(idxs_local));
-            ++*(idxs_local.end() - 2);
+        const int64_t row_offset = *output_offsets[0] / input_vreg_slice[0];
+        const int64_t base_src_row =
+            *(output_idx.end() - 2) * vreg_rows - row_offset;
+        for (int64_t row = base_src_row; row < base_src_row + vreg_rows;
+             ++row) {
+          if (0 <= row && row < *(input_vregs.dimensions().end() - 2)) {
+            *(input_idx.end() - 2) = row;
+            parts.push_back(input_vregs(input_idx));
           } else {
             parts.push_back(nullptr);
           }
         }
       }
-      *v = builder.create<PackSubelementsOp>(res_vreg_ty, parts,
-                                             tpu::PackFormat::kCompressed);
-      parts.clear();
-    });
-  } else {
-    return op.emitOpError("Not implemented: unsupported output tiling");
-  }
+    };
+    if (!output_offsets[1].has_value()) {
+      *(input_idx.end() - 1) = 0;
+      // Make sure we set all column parts of the vreg to make it replicated
+      push_col();
+      for (int64_t col = 1; col < vreg_cols; ++col) {
+        for (int64_t row = 0; row < vreg_rows; ++row) {
+          parts.push_back(parts[row]);
+        }
+      }
+    } else {
+      const int64_t col_offset = *output_offsets[1] / input_vreg_slice[1];
+      const int64_t base_src_col =
+          *(output_idx.end() - 1) * vreg_cols - col_offset;
+      for (int64_t col = base_src_col; col < base_src_col + vreg_cols; ++col) {
+        if (0 <= col && col < *(input_vregs.dimensions().end() - 1)) {
+          *(input_idx.end() - 1) = col;
+          push_col();
+        } else {
+          parts.append(vreg_rows, nullptr);
+        }
+      }
+    }
+    *v = builder.create<PackSubelementsOp>(res_vreg_ty, parts, pack_format);
+  });
   op.replaceAllUsesWith(assemble(builder, result_ty, layout_out,
                                  std::move(output_vregs), ctx.target_shape,
                                  /*use_implicit_shape=*/true)
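The per-vreg gather in the new loop follows a simple index rule: output vreg row r pulls input vreg rows r * vreg_rows - row_offset through r * vreg_rows - row_offset + vreg_rows - 1, padding out-of-range positions with null operands (and the column loop does the same thing along the minor dimension). A hedged standalone sketch with made-up sizes, consistent with the earlier offset example:

```cpp
#include <cstdint>
#include <iostream>
#include <optional>
#include <string>
#include <vector>

// Which input vreg rows feed one output vreg row (assumed semantics of the
// loop above; not the patch itself).
std::vector<std::optional<int64_t>> source_rows(int64_t output_row, int64_t vreg_rows,
                                                int64_t row_offset,
                                                int64_t num_input_rows) {
  std::vector<std::optional<int64_t>> rows;
  const int64_t base_src_row = output_row * vreg_rows - row_offset;
  for (int64_t row = base_src_row; row < base_src_row + vreg_rows; ++row) {
    if (0 <= row && row < num_input_rows) {
      rows.push_back(row);           // take this input vreg
    } else {
      rows.push_back(std::nullopt);  // padded with a null ("don't care") operand
    }
  }
  return rows;
}

int main() {
  // vreg_rows = 4 (e.g. f32 (8,128) -> i8 (32,128)); an output sublane offset
  // of 9 gives row_offset = 9 / 8 = 1, so output vreg row 0 reads {pad, 0, 1, 2}.
  for (auto r : source_rows(/*output_row=*/0, /*vreg_rows=*/4, /*row_offset=*/1,
                            /*num_input_rows=*/6)) {
    std::cout << (r ? std::to_string(*r) : "pad") << " ";
  }
  std::cout << "\n";  // prints: pad 0 1 2
}
```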
@@ -6321,6 +6410,7 @@ FailureOr<std::pair<VectorLayout, xla::Array<Value>>> changeTiling(
   if (bitwidth < 32 && 32 % bitwidth == 0 && src_tiling == ctx.target_shape &&
       dst_tiling == std::array<int64_t, 2>{ctx.target_shape[0] * packing,
                                            ctx.target_shape[1]}) {
+    // TODO(tlongeri): This relayout is just ext + trunc. Refactor.
     // Note: for int4, retiling with scratch is always faster.
     if (bitwidth != 4 || !has_enough_scratch) {
       // Note: The code below does not work when src is replicated and dst is

@@ -6379,6 +6469,7 @@ FailureOr<std::pair<VectorLayout, xla::Array<Value>>> changeTiling(
   if (bitwidth < 32 && 32 % bitwidth == 0 &&
       src_tiling == std::array<int64_t, 2>{1, ctx.target_shape[1] * packing} &&
       dst_tiling == std::array<int64_t, 2>{packing, ctx.target_shape[1]}) {
+    // TODO(tlongeri): This relayout is just ext + trunc. Refactor.
     // To illustrate, consider a 2 x 16 16-bit shape laid out in vregs of
     // 4 sublanes and 2 lanes (this is convenient for to keep the example small
     // yet non-trivial) with (1, 4) tiling. We will relayout to (2, 2) tiling.
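Both guarded cases are sub-32-bit retilings, which the new TODOs note are expressible as an extension followed by a truncation (the generality added to the trunc rule above is what makes that refactor possible). As a plain restatement of the two conditions, with hypothetical helper names (sketch only):

```cpp
#include <array>
#include <cstdint>

// target_shape is (sublanes, lanes), e.g. (8, 128); packing = 32 / bitwidth.
bool is_target_to_packed_retiling(int bitwidth, std::array<int64_t, 2> src_tiling,
                                  std::array<int64_t, 2> dst_tiling,
                                  std::array<int64_t, 2> target_shape) {
  if (bitwidth >= 32 || 32 % bitwidth != 0) return false;
  const int64_t packing = 32 / bitwidth;
  return src_tiling == target_shape &&
         dst_tiling == std::array<int64_t, 2>{target_shape[0] * packing,
                                              target_shape[1]};
}

bool is_strided_to_packed_retiling(int bitwidth, std::array<int64_t, 2> src_tiling,
                                   std::array<int64_t, 2> dst_tiling,
                                   std::array<int64_t, 2> target_shape) {
  if (bitwidth >= 32 || 32 % bitwidth != 0) return false;
  const int64_t packing = 32 / bitwidth;
  return src_tiling == std::array<int64_t, 2>{1, target_shape[1] * packing} &&
         dst_tiling == std::array<int64_t, 2>{packing, target_shape[1]};
}
```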
@@ -6826,8 +6917,8 @@ LogicalResult applyLayoutOp(RewriteContext &ctx, Operation &op) {
   // TODO: b/342235360 - This check is temporary while we increase and test
   // support for offsets outside of the first tile. When support is more broad,
   // any op without support should check it within their own rule.
-  if (!isa<vector::BroadcastOp, vector::ExtractStridedSliceOp,
-           vector::ShapeCastOp>(op)) {
+  if (!isa<arith::TruncFOp, arith::TruncIOp, vector::BroadcastOp,
+           vector::ExtractStridedSliceOp, vector::ShapeCastOp>(op)) {
     for (const Layout &layout : layouts_in) {
       if (layout && layout->offsets()[1].has_value() &&
           layout->offsets()[1].value() >= layout->tiling()[1]) {

@@ -142,7 +142,8 @@ class VectorLayoutInferer {
     // TODO: b/342235360 - This check is temporary while we increase and test
     // support for offsets outside of the first tile. When support is more
     // broad, any op without support should check it within their own rule.
-    if (!isa<vector::BroadcastOp, vector::ExtractStridedSliceOp>(any_op)) {
+    if (!isa<arith::TruncIOp, arith::TruncFOp, vector::BroadcastOp,
+             vector::ExtractStridedSliceOp>(any_op)) {
       const SmallVector<Layout> layouts_in = getLayoutFromOperands(&any_op);
       for (const Layout &layout : layouts_in) {
         if (layout &&
@@ -1699,23 +1700,22 @@ class VectorLayoutInferer {
     auto dst_ty = cast<VectorType>(op->getResult(0).getType());
     auto some_layout = getLayout(op->getOperand(0));
     TPU_CHECK_OP(some_layout.has_value(), "missing vector layout");
-    if (dyn_cast<arith::TruncFOp>(op)) {
-      TPU_CHECK_OP(src_ty.getElementTypeBitWidth() == 32 &&
-                       (dst_ty.getElementTypeBitWidth() == 16 ||
-                        dst_ty.getElementTypeBitWidth() == 8),
-                   "Only 32-bit to 8-bit or 16-bit truncation supported");
-    } else {
-      TPU_CHECK_OP(src_ty.getElementTypeBitWidth() == 32,
-                   "Only 32-bit truncation supported");
+    const unsigned src_bitwidth = src_ty.getElementTypeBitWidth();
+    const unsigned dst_bitwidth = dst_ty.getElementTypeBitWidth();
+    if (isa<arith::TruncFOp>(op)) {
+      TPU_CHECK_OP(
+          src_bitwidth == 32 && (dst_bitwidth == 16 || dst_bitwidth == 8),
+          "Only 32-bit to 16-bit or 8-bit float truncation supported");
     }
     auto &layout = *some_layout;
     bool select_native = allUsersRequireNativeTiling(op->getResult(0));
-    auto src_layout = VectorLayout(32, layout.offsets(), default_tiling_,
-                                   layout.implicit_dim());
+    auto src_layout = VectorLayout(
+        src_bitwidth, layout.offsets(),
+        select_native ? nativeTiling(src_bitwidth) : layout.tiling(),
+        layout.implicit_dim());
     auto dst_layout = VectorLayout(
-        dst_ty.getElementTypeBitWidth(), layout.offsets(),
-        select_native ? nativeTiling(dst_ty.getElementTypeBitWidth())
-                      : default_tiling_,
+        dst_bitwidth, layout.offsets(),
+        select_native ? nativeTiling(dst_bitwidth) : layout.tiling(),
         layout.implicit_dim());
     setLayout(op, src_layout, dst_layout);
     return success();
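The layout choice above leans on the native tiling for each bitwidth. A hedged sketch of the assumed convention (native_tiling here is a stand-in for the inferer's nativeTiling; the scaling by packing = 32 / bitwidth is an assumption, not stated in this diff):

```cpp
#include <array>
#include <cstdint>

// Assumed convention: the native tiling for a bitwidth scales the 32-bit
// (sublane, lane) tiling by packing = 32 / bitwidth.
std::array<int64_t, 2> native_tiling(int bitwidth, std::array<int64_t, 2> target_shape) {
  const int64_t packing = 32 / bitwidth;
  return {target_shape[0] * packing, target_shape[1]};
}

// With an (8, 128) target: native_tiling(32, ...) == {8, 128},
// native_tiling(16, ...) == {16, 128}, native_tiling(8, ...) == {32, 128}.
```

So with select_native set, an f32 to bf16 truncation is assigned a (16, 128) destination tiling and f32 to i8 a (32, 128) one; otherwise both sides now keep the operand's existing tiling (layout.tiling() rather than the old default_tiling_), which the expanded trunc rule in apply-vector-layouts can pack directly.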