[Mosaic TPU][NFC] Refactor tpu matmul rule.

* Split the MXU size into an MXU contracting size and an MXU non-contracting size (a standalone sketch of the resulting group arithmetic follows below).
* Rename "tile" to "group" for MXU-shaped tiling, since "tile" is overused in Mosaic.
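
The previous rule bailed out whenever ctx.mxu_shape[0] != ctx.mxu_shape[1]; with the two sizes named separately, each per-group vreg count is derived from the size that actually bounds it. The following is a minimal, self-contained C++ sketch of that arithmetic, not the Mosaic source: the 256x128 MXU, the 8x128 vreg shape, and the packing of 1 are hypothetical values, and AlignTo stands in for llvm::alignTo.

#include <cstdint>
#include <iostream>

// Stand-in for llvm::alignTo: round n up to the nearest multiple of align.
int64_t AlignTo(int64_t n, int64_t align) {
  return (n + align - 1) / align * align;
}

int main() {
  // Hypothetical hardware parameters (not queried from a real RewriteContext).
  const int64_t mxu_shape[2] = {256, 128};   // {contracting, non-contracting}
  const int64_t target_shape[2] = {8, 128};  // vreg shape: {sublanes, lanes}
  const int64_t rhs_packing = 1;             // 32-bit RHS => no packing
  const bool transpose_rhs = false;

  const int64_t mxu_contracting_size = mxu_shape[0];
  const int64_t mxu_noncontracting_size = mxu_shape[1];

  // Transposing the RHS swaps which MXU size bounds its rows vs. its columns.
  int64_t rhs_row_size = mxu_contracting_size;
  int64_t rhs_col_size = mxu_noncontracting_size;
  if (transpose_rhs) {
    rhs_row_size = mxu_noncontracting_size;
    rhs_col_size = mxu_contracting_size;
  }

  // A "group" is the block of vregs fed to one matmul in the unrolled code.
  const int64_t lhs_col_vregs_per_group =
      mxu_contracting_size / target_shape[1];          // 256 / 128 = 2
  const int64_t rhs_row_vregs_per_group =
      rhs_row_size / (target_shape[0] * rhs_packing);  // 256 / 8   = 32
  const int64_t rhs_col_vregs_per_group =
      rhs_col_size / target_shape[1];                  // 128 / 128 = 1
  const int64_t acc_col_vregs_per_group =
      mxu_noncontracting_size / target_shape[1];       // 128 / 128 = 1

  // Pad an operand's vreg grid up to whole groups, as the rule does by
  // appending zero vregs. E.g. an LHS grid that is 3 vregs wide:
  std::cout << "LHS cols: 3 -> " << AlignTo(3, lhs_col_vregs_per_group) << "\n"
            << "RHS group: " << rhs_row_vregs_per_group << " x "
            << rhs_col_vregs_per_group << " vregs\n"
            << "ACC group cols: " << acc_col_vregs_per_group << "\n";
}

With these hypothetical values the sketch prints "LHS cols: 3 -> 4", "RHS group: 32 x 1 vregs", and "ACC group cols: 1" — a rectangular MXU the old single mxu_size code path had to reject.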

PiperOrigin-RevId: 684116306
Jevin Jiang 2024-10-09 11:44:33 -07:00 committed by jax authors
parent 53668b88eb
commit f96c5661ac


@@ -1714,12 +1714,12 @@ LogicalResult tpu_matmul_rule(RewriteContext &ctx, Operation &op,
       llvm::all_of(layouts_in, [&](const Layout &l) { return l.has_value(); }));
   TPU_ASSERT_OP(layouts_out.front().has_value());
   auto matmul_op = cast<tpu::MatmulOp>(op);
-  auto transpose_lhs = matmul_op.getTransposeLhs();
-  auto transpose_rhs = matmul_op.getTransposeRhs();
-  auto &layout_lhs = *layouts_in[0];
-  auto &layout_rhs = *layouts_in[1];
-  auto &layout_acc = *layouts_in[2];
-  auto layout_out = *layouts_out[0];
+  const auto transpose_lhs = matmul_op.getTransposeLhs();
+  const auto transpose_rhs = matmul_op.getTransposeRhs();
+  const auto &layout_lhs = *layouts_in[0];
+  const auto &layout_rhs = *layouts_in[1];
+  const auto &layout_acc = *layouts_in[2];
+  const auto &layout_out = *layouts_out[0];
   if (transpose_lhs) {
     return op.emitOpError("Not implemented: Transposed LHS");
   }
@@ -1740,7 +1740,7 @@ LogicalResult tpu_matmul_rule(RewriteContext &ctx, Operation &op,
     acc = tpu_matmul_op.getAcc();
     res = tpu_matmul_op.getResult();
   } else {
-    LOG(FATAL) << "Unexpected op type";
+    return op.emitOpError("Expected a tpu::MatmulOp");
   }
   for (const std::optional<VectorLayout> &layout_opt : layouts_in) {
@@ -1755,7 +1755,7 @@ LogicalResult tpu_matmul_rule(RewriteContext &ctx, Operation &op,
     }
   }
   if (acc.getType().getElementType().getIntOrFloatBitWidth() != 32) {
-    return op.emitOpError("Not implemented: Non-32-bit matmul result");
+    return op.emitOpError("Not implemented: Non-32-bit matmul acc");
   }
   const ArrayRef<int64_t> lhs_shape = lhs.getType().getShape();
   const ArrayRef<int64_t> rhs_shape = rhs.getType().getShape();
@@ -1882,20 +1882,32 @@ LogicalResult tpu_matmul_rule(RewriteContext &ctx, Operation &op,
   // At this point, all paddings on vregs are masked out. For now, we
   // append zero vregs to make LHS's second dim, both RHS's dims and ACC's
   // second dim to be a multiple of mxu_size.
-  if (ctx.mxu_shape[0] != ctx.mxu_shape[1]) {
-    return op.emitOpError(
-        "Not implemented: MXU contracting size and noncontracting size are "
-        "different");
+  auto mxu_contracting_size = ctx.mxu_shape[0];
+  auto mxu_noncontracting_size = ctx.mxu_shape[1];
+  auto rhs_row_size = mxu_contracting_size;
+  auto rhs_col_size = mxu_noncontracting_size;
+  if (transpose_rhs) {
+    rhs_row_size = mxu_noncontracting_size;
+    rhs_col_size = mxu_contracting_size;
   }
-  int64_t mxu_size = ctx.mxu_shape[0];
-  CHECK_EQ(mxu_size % ctx.target_shape[0], 0);
-  CHECK_EQ(mxu_size % ctx.target_shape[1], 0);
-  auto mxu_row_vregs = mxu_size / (ctx.target_shape[0] * layout_rhs.packing());
-  auto mxu_col_vregs = mxu_size / ctx.target_shape[1];
-  int64_t target_lhs_col_vregs = llvm::alignTo(lhs_vregs.dim(1), mxu_col_vregs);
-  int64_t target_rhs_row_vregs = llvm::alignTo(rhs_vregs.dim(0), mxu_row_vregs);
-  int64_t target_rhs_col_vregs = llvm::alignTo(rhs_vregs.dim(1), mxu_col_vregs);
-  int64_t target_acc_col_vregs = llvm::alignTo(acc_vregs.dim(1), mxu_col_vregs);
+  CHECK_EQ(rhs_row_size % ctx.target_shape[1], 0);
+  CHECK_EQ(rhs_col_size % ctx.target_shape[1], 0);
+  // Here, a single group corresponds to a single matmul invocation in unrolled
+  // code. The RHS group matches the MXU shape.
+  auto lhs_col_vregs_per_group = mxu_contracting_size / ctx.target_shape[1];
+  auto rhs_row_vregs_per_group =
+      rhs_row_size / (ctx.target_shape[0] * layout_rhs.packing());
+  auto rhs_col_vregs_per_group = rhs_col_size / ctx.target_shape[1];
+  auto acc_col_vregs_per_group = mxu_noncontracting_size / ctx.target_shape[1];
+  int64_t target_lhs_col_vregs =
+      llvm::alignTo(lhs_vregs.dim(1), lhs_col_vregs_per_group);
+  int64_t target_rhs_row_vregs =
+      llvm::alignTo(rhs_vregs.dim(0), rhs_row_vregs_per_group);
+  int64_t target_rhs_col_vregs =
+      llvm::alignTo(rhs_vregs.dim(1), rhs_col_vregs_per_group);
+  int64_t target_acc_col_vregs =
+      llvm::alignTo(acc_vregs.dim(1), acc_col_vregs_per_group);
   xla::Array<Value> target_lhs_vregs({lhs_vregs.dim(0), target_lhs_col_vregs},
                                      lhs_zeros_vreg);
@ -1908,10 +1920,11 @@ LogicalResult tpu_matmul_rule(RewriteContext &ctx, Operation &op,
target_acc_vregs.UpdateSlice(acc_vregs, {0, 0});
// Now we can regroup vregs from target vregs.
const auto lhs_col_ty = VectorType::get({padded_lhs_rows, mxu_size},
lhs.getType().getElementType());
const auto acc_col_ty = VectorType::get({padded_lhs_rows, mxu_size},
acc.getType().getElementType());
const auto lhs_col_ty = VectorType::get(
{padded_lhs_rows, mxu_contracting_size}, lhs.getType().getElementType());
const auto acc_col_ty =
VectorType::get({padded_lhs_rows, mxu_noncontracting_size},
acc.getType().getElementType());
const ArrayAttr lhs_layout_attr =
builder.getArrayAttr({builder.getAttr<VectorLayoutAttr>(layout_lhs)});
const ArrayAttr rhs_layout_attr =
@@ -1919,40 +1932,39 @@ LogicalResult tpu_matmul_rule(RewriteContext &ctx, Operation &op,
   const ArrayAttr acc_layout_attr =
       builder.getArrayAttr({builder.getAttr<VectorLayoutAttr>(layout_acc)});
-  int64_t nk = llvm::divideCeil(lhs_shape[1], mxu_size);
-  CHECK_EQ(nk, target_lhs_vregs.dim(1) / mxu_col_vregs);
+  int64_t nk = llvm::divideCeil(lhs_shape[1], mxu_contracting_size);
+  CHECK_EQ(nk, target_lhs_vregs.dim(1) / lhs_col_vregs_per_group);
   SmallVector<tpu::RollVectorsOp> lhs_cols(nk);
   for (int64_t i = 0; i < nk; ++i) {
     const xla::Array<Value> col_vregs = target_lhs_vregs.Slice(
-        {0, i * mxu_col_vregs},
-        {target_lhs_vregs.dim(0), (i + 1) * mxu_col_vregs});
+        {0, i * lhs_col_vregs_per_group},
+        {target_lhs_vregs.dim(0), (i + 1) * lhs_col_vregs_per_group});
     lhs_cols[i] = builder.create<tpu::RollVectorsOp>(
         op.getLoc(), lhs_col_ty, XlaArrayToFlatArrayRef(col_vregs));
     lhs_cols[i]->setAttr("out_layout", lhs_layout_attr);
   }
-  // Here, "tile" is used as in the context of the MXU shape (NOT as in the
-  // context of tiled layouts).
-  const auto rhs_tile_ty =
-      VectorType::get({mxu_size, mxu_size}, rhs.getType().getElementType());
-  const int64_t rhs_vregs_per_tile = mxu_row_vregs * mxu_col_vregs;
+  const auto rhs_group_ty = VectorType::get({rhs_row_size, rhs_col_size},
+                                            rhs.getType().getElementType());
+  const int64_t rhs_vregs_per_group =
+      rhs_row_vregs_per_group * rhs_col_vregs_per_group;
   int64_t nj;
   if (transpose_rhs) {
-    nj = llvm::divideCeil(rhs_shape[0], mxu_size);
-    CHECK_EQ(nk, llvm::divideCeil(rhs_shape[1], mxu_size));
-    CHECK_EQ(nk, target_rhs_vregs.dim(1) / mxu_col_vregs);
-    target_rhs_vregs.Reshape(
-        {nj, rhs_vregs_per_tile / mxu_col_vregs, nk, mxu_col_vregs});
+    nj = llvm::divideCeil(rhs_shape[0], rhs_row_size);
+    CHECK_EQ(nk, llvm::divideCeil(rhs_shape[1], rhs_col_size));
+    CHECK_EQ(nk, target_rhs_vregs.dim(1) / rhs_col_vregs_per_group);
+    target_rhs_vregs.Reshape({nj, rhs_vregs_per_group / rhs_col_vregs_per_group,
+                              nk, rhs_col_vregs_per_group});
     target_rhs_vregs.TransposeDimensions({2, 0, 1, 3});
-    target_rhs_vregs.Reshape({nk, nj, rhs_vregs_per_tile});
+    target_rhs_vregs.Reshape({nk, nj, rhs_vregs_per_group});
   } else {
-    nj = llvm::divideCeil(rhs_shape[1], mxu_size);
-    CHECK_EQ(nk, llvm::divideCeil(rhs_shape[0], mxu_size));
-    CHECK_EQ(nk, target_rhs_vregs.dim(0) / mxu_row_vregs);
-    target_rhs_vregs.Reshape(
-        {nk, rhs_vregs_per_tile / mxu_col_vregs, nj, mxu_col_vregs});
+    nj = llvm::divideCeil(rhs_shape[1], rhs_col_size);
+    CHECK_EQ(nk, llvm::divideCeil(rhs_shape[0], rhs_row_size));
+    CHECK_EQ(nk, target_rhs_vregs.dim(0) / rhs_row_vregs_per_group);
+    target_rhs_vregs.Reshape({nk, rhs_vregs_per_group / rhs_col_vregs_per_group,
+                              nj, rhs_col_vregs_per_group});
     target_rhs_vregs.TransposeDimensions({0, 2, 1, 3});
-    target_rhs_vregs.Reshape({nk, nj, rhs_vregs_per_tile});
+    target_rhs_vregs.Reshape({nk, nj, rhs_vregs_per_group});
   }
   const tpu::ContractPrecisionAttr precision_attr =  // May be null
@@ -1960,28 +1972,29 @@ LogicalResult tpu_matmul_rule(RewriteContext &ctx, Operation &op,
   for (int64_t j = 0; j < nj; ++j) {
     for (int64_t k = 0; k < nk; ++k) {
       // TODO(tlongeri): there should be a way to slice without copying
-      xla::Array<Value> rhs_tile =
-          target_rhs_vregs.Slice({k, j, 0}, {k + 1, j + 1, rhs_vregs_per_tile});
-      auto rhs_rolled_tile = builder.create<tpu::RollVectorsOp>(
-          op.getLoc(), rhs_tile_ty, XlaArrayToFlatArrayRef(rhs_tile));
-      rhs_rolled_tile->setAttr("out_layout", rhs_layout_attr);
+      xla::Array<Value> rhs_group = target_rhs_vregs.Slice(
+          {k, j, 0}, {k + 1, j + 1, rhs_vregs_per_group});
+      auto rhs_rolled_group = builder.create<tpu::RollVectorsOp>(
+          op.getLoc(), rhs_group_ty, XlaArrayToFlatArrayRef(rhs_group));
+      rhs_rolled_group->setAttr("out_layout", rhs_layout_attr);
       const xla::Array<Value> acc_col_vregs = target_acc_vregs.Slice(
-          {0, j * mxu_col_vregs},
-          {target_acc_vregs.dim(0), (j + 1) * mxu_col_vregs});
+          {0, j * acc_col_vregs_per_group},
+          {target_acc_vregs.dim(0), (j + 1) * acc_col_vregs_per_group});
       auto acc_col = builder.create<tpu::RollVectorsOp>(
          op.getLoc(), acc_col_ty, XlaArrayToFlatArrayRef(acc_col_vregs));
       acc_col->setAttr("out_layout", acc_layout_attr);
       auto new_acc_col = builder.create<tpu::MatmulOp>(
-          op.getLoc(), acc_col_ty, lhs_cols[k], rhs_rolled_tile, acc_col,
+          op.getLoc(), acc_col_ty, lhs_cols[k], rhs_rolled_group, acc_col,
           transpose_lhs, transpose_rhs, precision_attr);
       auto new_acc_vregs = builder.create<tpu::UnrollVectorsOp>(
           op.getLoc(),
           TypeRange(ValueRange(XlaArrayToFlatArrayRef(acc_col_vregs))),
          new_acc_col);
       new_acc_vregs->setAttr("in_layout", acc_layout_attr);
-      updateSliceFromRange(target_acc_vregs, new_acc_vregs->getResults(),
-                           {0, j * mxu_col_vregs},
-                           {target_acc_vregs.dim(0), (j + 1) * mxu_col_vregs});
+      updateSliceFromRange(
+          target_acc_vregs, new_acc_vregs->getResults(),
+          {0, j * acc_col_vregs_per_group},
+          {target_acc_vregs.dim(0), (j + 1) * acc_col_vregs_per_group});
     }
   }
   op.replaceAllUsesWith(