[Mosaic TPU] Use the large-to-compact 2nd-minor retiling for conversions going both ways

This specific retiling is its own inverse, and it is faster than the alternatives.

PiperOrigin-RevId: 716360070
Adam Paszke 2025-01-16 13:34:39 -08:00 committed by jax authors
parent aa9cea0a55
commit bd22bfef71
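
For intuition, here is a small standalone sketch (hypothetical, not part of the commit; names like gatheredParts and SourcePart are illustrative only) that mirrors the parts_dim / moving_dim index arithmetic introduced in the diff below. For each destination vreg it lists the `packing` source vregs, and the packed part taken from each, that the retiling combines. The same routine covers both the (8,128) -> (8 * packing,128) and the (8 * packing,128) -> (8,128) direction; only the choice of parts/moving dimension flips.

// Hypothetical sketch of the vreg-index remapping, assuming a 2D vreg array
// indexed by its last two dims (row = 2nd minor, col = minormost).
#include <cstdint>
#include <iostream>
#include <vector>

struct SourcePart {
  int64_t row, col;  // source vreg index
  int64_t part;      // which packed part of that vreg is used
};

// small_to_large == true corresponds to src_tiling[0] < dst_tiling[0],
// i.e. (8,128) -> (8 * packing,128); false is the reverse direction.
std::vector<SourcePart> gatheredParts(int64_t dst_row, int64_t dst_col,
                                      int64_t src_rows, int64_t src_cols,
                                      int64_t packing, bool small_to_large) {
  int64_t idx[2] = {dst_row, dst_col};
  const int parts = small_to_large ? 1 : 0;   // dim that selects the part
  const int moving = small_to_large ? 0 : 1;  // dim we gather `packing` vregs from
  const int64_t part = idx[parts] % packing;
  idx[parts] /= packing;
  idx[moving] *= packing;
  const int64_t bound = moving == 0 ? src_rows : src_cols;
  std::vector<SourcePart> out;
  for (int64_t i = 0; i < packing; ++i) {
    if (idx[moving] < bound) {
      out.push_back({idx[0], idx[1], part});
      ++idx[moving];
    }  // out-of-bounds sources correspond to the nullptr parts in the patch
  }
  return out;
}

int main() {
  // 16-bit data (packing = 2), (8,128) -> (16,128): destination vreg (1, 3)
  // combines part 1 of source vregs (2, 1) and (3, 1).
  for (const SourcePart &p :
       gatheredParts(/*dst_row=*/1, /*dst_col=*/3, /*src_rows=*/4,
                     /*src_cols=*/2, /*packing=*/2, /*small_to_large=*/true)) {
    std::cout << "src vreg (" << p.row << ", " << p.col << "), part " << p.part
              << "\n";
  }
}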

@@ -6256,12 +6256,17 @@ FailureOr<std::pair<VectorLayout, xla::Array<Value>>> changeTiling(
     });
     return std::pair(dst, std::move(retiled));
   }
-  // (8,128) -> (8 * packing,128) tiling change for packed type.
-  if (src_offsets[0].value_or(0) < dst_vreg_slice[0] &&
+  // (8,128) <-> (8 * packing,128) tiling change for packed type.
+  if (ctx.hardware_generation >= 4 &&
+      src_offsets[0].value_or(0) < dst_vreg_slice[0] &&
       src_offsets[1].value_or(0) < dst_vreg_slice[1] && bitwidth < 32 &&
-      32 % bitwidth == 0 && src_tiling == ctx.target_shape &&
-      dst_tiling == std::array<int64_t, 2>{ctx.target_shape[0] * packing,
-                                           ctx.target_shape[1]}) {
+      32 % bitwidth == 0 &&
+      ((src_tiling == ctx.target_shape &&
+        dst_tiling == std::array<int64_t, 2>{ctx.target_shape[0] * packing,
+                                             ctx.target_shape[1]}) ||
+       (dst_tiling == ctx.target_shape &&
+        src_tiling == std::array<int64_t, 2>{ctx.target_shape[0] * packing,
+                                             ctx.target_shape[1]}))) {
     // TODO(tlongeri): This relayout is just ext + trunc. Refactor.
     // Note: for int4, retiling with scratch is always faster.
     if (bitwidth != 4 || !has_enough_scratch) {
@@ -6277,26 +6282,32 @@ FailureOr<std::pair<VectorLayout, xla::Array<Value>>> changeTiling(
           vty.getElementType().isSignlessInteger()
               ? VectorType::get(target_shape, builder.getI32Type())
               : VectorType::get(target_shape, builder.getF32Type());
+      // For each output vreg we collect `packing` registers from the moving dim
+      // (sublanes or lanes), while using the other vreg dim to determine which
+      // part of each register to use (the parts dim).
+      const int parts_dim = src_tiling[0] < dst_tiling[0] ? 1 : 2;
+      const int moving_dim = src_tiling[0] < dst_tiling[0] ? 2 : 1;
       retiled.Each([&](absl::Span<const int64_t> idx, Value *tile) {
-        const int vreg_part = idx.back() % packing;
+        const int vreg_part = *(idx.end() - parts_dim) % packing;
         SmallVector<Value, 8> parts;
         parts.reserve(packing);
         SmallVector<int64_t> src_idx(idx.begin(), idx.end());
-        *(src_idx.end() - 1) /= packing;
-        if (!dst.offsets()[0].has_value()) {
-          *(src_idx.end() - 2) = 0;
+        *(src_idx.end() - parts_dim) /= packing;
+        if (!dst.offsets()[2 - moving_dim].has_value()) {
+          *(src_idx.end() - moving_dim) = 0;
          // Make sure we set all parts of the output vreg to make it replicated
          parts.append(packing, builder.create<tpu::UnpackSubelementsOp>(
                                    loc, vreg_x32, vregs(src_idx), vreg_part,
                                    tpu::PackFormat::kCompressed));
        } else {
-          *(src_idx.end() - 2) *= packing;
+          *(src_idx.end() - moving_dim) *= packing;
          for (int i = 0; i < packing; ++i) {
-            if (*(src_idx.end() - 2) < *(vregs.dimensions().end() - 2)) {
+            if (*(src_idx.end() - moving_dim) <
+                *(vregs.dimensions().end() - moving_dim)) {
              parts.push_back(builder.create<tpu::UnpackSubelementsOp>(
                  loc, vreg_x32, vregs(src_idx), vreg_part,
                  tpu::PackFormat::kCompressed));
-              ++*(src_idx.end() - 2);
+              ++*(src_idx.end() - moving_dim);
            } else {
              parts.push_back(nullptr);
            }