mirror of
https://github.com/ROCm/jax.git
synced 2025-04-19 05:16:06 +00:00
[Mosaic TPU][NFC] Refactor tpu matmul rule.
* Split the MXU size into an MXU contracting size and an MXU non-contracting size.
* Rename "tile" to "group" for MXU-shaped tiling, since "tile" is overused in Mosaic.

PiperOrigin-RevId: 684116306
This commit is contained in:
parent
53668b88eb
commit
f96c5661ac
@ -1714,12 +1714,12 @@ LogicalResult tpu_matmul_rule(RewriteContext &ctx, Operation &op,
|
||||
llvm::all_of(layouts_in, [&](const Layout &l) { return l.has_value(); }));
|
||||
TPU_ASSERT_OP(layouts_out.front().has_value());
|
||||
auto matmul_op = cast<tpu::MatmulOp>(op);
|
||||
auto transpose_lhs = matmul_op.getTransposeLhs();
|
||||
auto transpose_rhs = matmul_op.getTransposeRhs();
|
||||
auto &layout_lhs = *layouts_in[0];
|
||||
auto &layout_rhs = *layouts_in[1];
|
||||
auto &layout_acc = *layouts_in[2];
|
||||
auto layout_out = *layouts_out[0];
|
||||
const auto transpose_lhs = matmul_op.getTransposeLhs();
|
||||
const auto transpose_rhs = matmul_op.getTransposeRhs();
|
||||
const auto &layout_lhs = *layouts_in[0];
|
||||
const auto &layout_rhs = *layouts_in[1];
|
||||
const auto &layout_acc = *layouts_in[2];
|
||||
const auto &layout_out = *layouts_out[0];
|
||||
if (transpose_lhs) {
|
||||
return op.emitOpError("Not implemented: Transposed LHS");
|
||||
}
|
||||
@ -1740,7 +1740,7 @@ LogicalResult tpu_matmul_rule(RewriteContext &ctx, Operation &op,
|
||||
acc = tpu_matmul_op.getAcc();
|
||||
res = tpu_matmul_op.getResult();
|
||||
} else {
|
||||
LOG(FATAL) << "Unexpected op type";
|
||||
return op.emitOpError("Expected a tpu::MatmulOp");
|
||||
}
|
||||
|
||||
for (const std::optional<VectorLayout> &layout_opt : layouts_in) {
|
||||
@ -1755,7 +1755,7 @@ LogicalResult tpu_matmul_rule(RewriteContext &ctx, Operation &op,
|
||||
}
|
||||
}
|
||||
if (acc.getType().getElementType().getIntOrFloatBitWidth() != 32) {
|
||||
return op.emitOpError("Not implemented: Non-32-bit matmul result");
|
||||
return op.emitOpError("Not implemented: Non-32-bit matmul acc");
|
||||
}
|
||||
const ArrayRef<int64_t> lhs_shape = lhs.getType().getShape();
|
||||
const ArrayRef<int64_t> rhs_shape = rhs.getType().getShape();
|
||||
@ -1882,20 +1882,32 @@ LogicalResult tpu_matmul_rule(RewriteContext &ctx, Operation &op,
|
||||
// At this point, all paddings on vregs are masked out. For now, we
|
||||
// append zero vregs to make LHS's second dim, both RHS's dims and ACC's
|
||||
// second dim to be a multiple of mxu_size.
|
||||
if (ctx.mxu_shape[0] != ctx.mxu_shape[1]) {
|
||||
return op.emitOpError(
|
||||
"Not implemented: MXU contracting size and noncontracting size are "
|
||||
"different");
|
||||
auto mxu_contracting_size = ctx.mxu_shape[0];
|
||||
auto mxu_noncontracting_size = ctx.mxu_shape[1];
|
||||
auto rhs_row_size = mxu_contracting_size;
|
||||
auto rhs_col_size = mxu_noncontracting_size;
|
||||
if (transpose_rhs) {
|
||||
rhs_row_size = mxu_noncontracting_size;
|
||||
rhs_col_size = mxu_contracting_size;
|
||||
}
|
||||
int64_t mxu_size = ctx.mxu_shape[0];
|
||||
CHECK_EQ(mxu_size % ctx.target_shape[0], 0);
|
||||
CHECK_EQ(mxu_size % ctx.target_shape[1], 0);
|
||||
auto mxu_row_vregs = mxu_size / (ctx.target_shape[0] * layout_rhs.packing());
|
||||
auto mxu_col_vregs = mxu_size / ctx.target_shape[1];
|
||||
int64_t target_lhs_col_vregs = llvm::alignTo(lhs_vregs.dim(1), mxu_col_vregs);
|
||||
int64_t target_rhs_row_vregs = llvm::alignTo(rhs_vregs.dim(0), mxu_row_vregs);
|
||||
int64_t target_rhs_col_vregs = llvm::alignTo(rhs_vregs.dim(1), mxu_col_vregs);
|
||||
int64_t target_acc_col_vregs = llvm::alignTo(acc_vregs.dim(1), mxu_col_vregs);
|
||||
CHECK_EQ(rhs_row_size % ctx.target_shape[1], 0);
|
||||
CHECK_EQ(rhs_col_size % ctx.target_shape[1], 0);
|
||||
|
||||
// Here, a single group corresponds to a single matmul invocation in unrolled
|
||||
// code. The RHS group matches the MXU shape.
|
||||
auto lhs_col_vregs_per_group = mxu_contracting_size / ctx.target_shape[1];
|
||||
auto rhs_row_vregs_per_group =
|
||||
rhs_row_size / (ctx.target_shape[0] * layout_rhs.packing());
|
||||
auto rhs_col_vregs_per_group = rhs_col_size / ctx.target_shape[1];
|
||||
auto acc_col_vregs_per_group = mxu_noncontracting_size / ctx.target_shape[1];
|
||||
int64_t target_lhs_col_vregs =
|
||||
llvm::alignTo(lhs_vregs.dim(1), lhs_col_vregs_per_group);
|
||||
int64_t target_rhs_row_vregs =
|
||||
llvm::alignTo(rhs_vregs.dim(0), rhs_row_vregs_per_group);
|
||||
int64_t target_rhs_col_vregs =
|
||||
llvm::alignTo(rhs_vregs.dim(1), rhs_col_vregs_per_group);
|
||||
int64_t target_acc_col_vregs =
|
||||
llvm::alignTo(acc_vregs.dim(1), acc_col_vregs_per_group);
|
||||
|
||||
xla::Array<Value> target_lhs_vregs({lhs_vregs.dim(0), target_lhs_col_vregs},
|
||||
lhs_zeros_vreg);
|
||||
@ -1908,10 +1920,11 @@ LogicalResult tpu_matmul_rule(RewriteContext &ctx, Operation &op,
|
||||
target_acc_vregs.UpdateSlice(acc_vregs, {0, 0});
|
||||
|
||||
// Now we can regroup vregs from target vregs.
|
||||
const auto lhs_col_ty = VectorType::get({padded_lhs_rows, mxu_size},
|
||||
lhs.getType().getElementType());
|
||||
const auto acc_col_ty = VectorType::get({padded_lhs_rows, mxu_size},
|
||||
acc.getType().getElementType());
|
||||
const auto lhs_col_ty = VectorType::get(
|
||||
{padded_lhs_rows, mxu_contracting_size}, lhs.getType().getElementType());
|
||||
const auto acc_col_ty =
|
||||
VectorType::get({padded_lhs_rows, mxu_noncontracting_size},
|
||||
acc.getType().getElementType());
|
||||
const ArrayAttr lhs_layout_attr =
|
||||
builder.getArrayAttr({builder.getAttr<VectorLayoutAttr>(layout_lhs)});
|
||||
const ArrayAttr rhs_layout_attr =
|
||||
@ -1919,40 +1932,39 @@ LogicalResult tpu_matmul_rule(RewriteContext &ctx, Operation &op,
|
||||
const ArrayAttr acc_layout_attr =
|
||||
builder.getArrayAttr({builder.getAttr<VectorLayoutAttr>(layout_acc)});
|
||||
|
||||
int64_t nk = llvm::divideCeil(lhs_shape[1], mxu_size);
|
||||
CHECK_EQ(nk, target_lhs_vregs.dim(1) / mxu_col_vregs);
|
||||
int64_t nk = llvm::divideCeil(lhs_shape[1], mxu_contracting_size);
|
||||
CHECK_EQ(nk, target_lhs_vregs.dim(1) / lhs_col_vregs_per_group);
|
||||
SmallVector<tpu::RollVectorsOp> lhs_cols(nk);
|
||||
for (int64_t i = 0; i < nk; ++i) {
|
||||
const xla::Array<Value> col_vregs = target_lhs_vregs.Slice(
|
||||
{0, i * mxu_col_vregs},
|
||||
{target_lhs_vregs.dim(0), (i + 1) * mxu_col_vregs});
|
||||
{0, i * lhs_col_vregs_per_group},
|
||||
{target_lhs_vregs.dim(0), (i + 1) * lhs_col_vregs_per_group});
|
||||
lhs_cols[i] = builder.create<tpu::RollVectorsOp>(
|
||||
op.getLoc(), lhs_col_ty, XlaArrayToFlatArrayRef(col_vregs));
|
||||
lhs_cols[i]->setAttr("out_layout", lhs_layout_attr);
|
||||
}
|
||||
// Here, "tile" is used as in the context of the MXU shape (NOT as in the
|
||||
// context of tiled layouts).
|
||||
const auto rhs_tile_ty =
|
||||
VectorType::get({mxu_size, mxu_size}, rhs.getType().getElementType());
|
||||
const int64_t rhs_vregs_per_tile = mxu_row_vregs * mxu_col_vregs;
|
||||
const auto rhs_group_ty = VectorType::get({rhs_row_size, rhs_col_size},
|
||||
rhs.getType().getElementType());
|
||||
const int64_t rhs_vregs_per_group =
|
||||
rhs_row_vregs_per_group * rhs_col_vregs_per_group;
|
||||
|
||||
int64_t nj;
|
||||
if (transpose_rhs) {
|
||||
nj = llvm::divideCeil(rhs_shape[0], mxu_size);
|
||||
CHECK_EQ(nk, llvm::divideCeil(rhs_shape[1], mxu_size));
|
||||
CHECK_EQ(nk, target_rhs_vregs.dim(1) / mxu_col_vregs);
|
||||
target_rhs_vregs.Reshape(
|
||||
{nj, rhs_vregs_per_tile / mxu_col_vregs, nk, mxu_col_vregs});
|
||||
nj = llvm::divideCeil(rhs_shape[0], rhs_row_size);
|
||||
CHECK_EQ(nk, llvm::divideCeil(rhs_shape[1], rhs_col_size));
|
||||
CHECK_EQ(nk, target_rhs_vregs.dim(1) / rhs_col_vregs_per_group);
|
||||
target_rhs_vregs.Reshape({nj, rhs_vregs_per_group / rhs_col_vregs_per_group,
|
||||
nk, rhs_col_vregs_per_group});
|
||||
target_rhs_vregs.TransposeDimensions({2, 0, 1, 3});
|
||||
target_rhs_vregs.Reshape({nk, nj, rhs_vregs_per_tile});
|
||||
target_rhs_vregs.Reshape({nk, nj, rhs_vregs_per_group});
|
||||
} else {
|
||||
nj = llvm::divideCeil(rhs_shape[1], mxu_size);
|
||||
CHECK_EQ(nk, llvm::divideCeil(rhs_shape[0], mxu_size));
|
||||
CHECK_EQ(nk, target_rhs_vregs.dim(0) / mxu_row_vregs);
|
||||
target_rhs_vregs.Reshape(
|
||||
{nk, rhs_vregs_per_tile / mxu_col_vregs, nj, mxu_col_vregs});
|
||||
nj = llvm::divideCeil(rhs_shape[1], rhs_col_size);
|
||||
CHECK_EQ(nk, llvm::divideCeil(rhs_shape[0], rhs_row_size));
|
||||
CHECK_EQ(nk, target_rhs_vregs.dim(0) / rhs_row_vregs_per_group);
|
||||
target_rhs_vregs.Reshape({nk, rhs_vregs_per_group / rhs_col_vregs_per_group,
|
||||
nj, rhs_col_vregs_per_group});
|
||||
target_rhs_vregs.TransposeDimensions({0, 2, 1, 3});
|
||||
target_rhs_vregs.Reshape({nk, nj, rhs_vregs_per_tile});
|
||||
target_rhs_vregs.Reshape({nk, nj, rhs_vregs_per_group});
|
||||
}
|
||||
|
||||
const tpu::ContractPrecisionAttr precision_attr = // May be null
|
||||
@ -1960,28 +1972,29 @@ LogicalResult tpu_matmul_rule(RewriteContext &ctx, Operation &op,
|
||||
for (int64_t j = 0; j < nj; ++j) {
|
||||
for (int64_t k = 0; k < nk; ++k) {
|
||||
// TODO(tlongeri): there should be a way to slice without copying
|
||||
xla::Array<Value> rhs_tile =
|
||||
target_rhs_vregs.Slice({k, j, 0}, {k + 1, j + 1, rhs_vregs_per_tile});
|
||||
auto rhs_rolled_tile = builder.create<tpu::RollVectorsOp>(
|
||||
op.getLoc(), rhs_tile_ty, XlaArrayToFlatArrayRef(rhs_tile));
|
||||
rhs_rolled_tile->setAttr("out_layout", rhs_layout_attr);
|
||||
xla::Array<Value> rhs_group = target_rhs_vregs.Slice(
|
||||
{k, j, 0}, {k + 1, j + 1, rhs_vregs_per_group});
|
||||
auto rhs_rolled_group = builder.create<tpu::RollVectorsOp>(
|
||||
op.getLoc(), rhs_group_ty, XlaArrayToFlatArrayRef(rhs_group));
|
||||
rhs_rolled_group->setAttr("out_layout", rhs_layout_attr);
|
||||
const xla::Array<Value> acc_col_vregs = target_acc_vregs.Slice(
|
||||
{0, j * mxu_col_vregs},
|
||||
{target_acc_vregs.dim(0), (j + 1) * mxu_col_vregs});
|
||||
{0, j * acc_col_vregs_per_group},
|
||||
{target_acc_vregs.dim(0), (j + 1) * acc_col_vregs_per_group});
|
||||
auto acc_col = builder.create<tpu::RollVectorsOp>(
|
||||
op.getLoc(), acc_col_ty, XlaArrayToFlatArrayRef(acc_col_vregs));
|
||||
acc_col->setAttr("out_layout", acc_layout_attr);
|
||||
auto new_acc_col = builder.create<tpu::MatmulOp>(
|
||||
op.getLoc(), acc_col_ty, lhs_cols[k], rhs_rolled_tile, acc_col,
|
||||
op.getLoc(), acc_col_ty, lhs_cols[k], rhs_rolled_group, acc_col,
|
||||
transpose_lhs, transpose_rhs, precision_attr);
|
||||
auto new_acc_vregs = builder.create<tpu::UnrollVectorsOp>(
|
||||
op.getLoc(),
|
||||
TypeRange(ValueRange(XlaArrayToFlatArrayRef(acc_col_vregs))),
|
||||
new_acc_col);
|
||||
new_acc_vregs->setAttr("in_layout", acc_layout_attr);
|
||||
updateSliceFromRange(target_acc_vregs, new_acc_vregs->getResults(),
|
||||
{0, j * mxu_col_vregs},
|
||||
{target_acc_vregs.dim(0), (j + 1) * mxu_col_vregs});
|
||||
updateSliceFromRange(
|
||||
target_acc_vregs, new_acc_vregs->getResults(),
|
||||
{0, j * acc_col_vregs_per_group},
|
||||
{target_acc_vregs.dim(0), (j + 1) * acc_col_vregs_per_group});
|
||||
}
|
||||
}
|
||||
op.replaceAllUsesWith(
|
||||
|
Loading…
x
Reference in New Issue
Block a user