Rewrite vector.contraction with bf16 accumulator and output into a
contraction with f32 accumulator and output, where the accumulator is
extended to f32 before the contraction and the output is truncated back
afterwards. For targets that do not support bf16 matmul, the lhs and rhs
are also extended to f32.

PiperOrigin-RevId: 642051952
This commit is contained in:
jax authors 2024-06-10 16:02:00 -07:00 committed by jax authors
parent 9d9dd36219
commit 71c19b779d
3 changed files with 102 additions and 30 deletions

View File

@ -708,7 +708,10 @@ def LinalgVectorizationPass : Pass<"linalg-vectorization", "::mlir::func::FuncOp
"::mlir::tpu::TPUDialect",
];
let constructor = "::mlir::tpu::createLinalgVectorizationPass(false)";
let options = [Option<"supports_bf16_alu_instructions", "supports-bf16-alu-instructions", "bool", "", "">];
let options = [
Option<"supports_bf16_alu_instructions", "supports-bf16-alu-instructions", "bool", "", "">,
Option<"supports_bf16_matmul", "supports-bf16-matmul", "bool", "", "">,
];
}
#endif // TPU_ATTRS

View File

@ -64,7 +64,8 @@ std::unique_ptr<OperationPass<func::FuncOp>>
createLogicalToPhysicalDeviceIdPass(int64_t total_devices);
std::unique_ptr<OperationPass<func::FuncOp>> createLinalgVectorizationPass(
bool supports_bf16_alu_instructions = false);
bool supports_bf16_alu_instructions = false,
bool supports_bf16_matmul = false);
std::unique_ptr<OperationPass<func::FuncOp>> createDebugAssertInsertionPass();

View File

@ -181,17 +181,15 @@ struct TransferReadOfConstant
// Rewrite `vector.transfer_read(arith.select)` as `arith.select` with
// `transfer_read` applied to its operands.
struct TransferReadOfSelect
: public ::mlir::OpRewritePattern<::mlir::vector::TransferReadOp> {
using OpRewritePattern<::mlir::vector::TransferReadOp>::OpRewritePattern;
struct TransferReadOfSelect : public OpRewritePattern<vector::TransferReadOp> {
using OpRewritePattern<vector::TransferReadOp>::OpRewritePattern;
::mlir::LogicalResult matchAndRewrite(
::mlir::vector::TransferReadOp op,
::mlir::PatternRewriter &rewriter) const override {
LogicalResult matchAndRewrite(vector::TransferReadOp op,
PatternRewriter &rewriter) const override {
if (failed(checkPreconditions(op, rewriter))) {
return failure();
}
auto select = op.getSource().getDefiningOp<::mlir::arith::SelectOp>();
auto select = op.getSource().getDefiningOp<arith::SelectOp>();
if (!select) {
return rewriter.notifyMatchFailure(op, "source not an arith.select");
}
@ -214,27 +212,25 @@ struct TransferReadOfSelect
auto transfer_read = [&](Value value, RankedTensorType type) {
return createTransferReadOp(op, value, type, rewriter);
};
rewriter.replaceOpWithNewOp<::mlir::arith::SelectOp>(
rewriter.replaceOpWithNewOp<arith::SelectOp>(
op, transfer_read(select.getCondition(), condition_type),
transfer_read(select.getTrueValue(), true_value_ty),
transfer_read(select.getFalseValue(), false_value_ty));
return ::mlir::success();
return success();
}
};
// Rewrite `vector.transfer_read(arith.cmpi)` as `arith.cmpi` with
// `transfer_read` applied to its operands.
struct TransferReadOfCmpI
: public ::mlir::OpRewritePattern<::mlir::vector::TransferReadOp> {
using OpRewritePattern<::mlir::vector::TransferReadOp>::OpRewritePattern;
struct TransferReadOfCmpI : public OpRewritePattern<vector::TransferReadOp> {
using OpRewritePattern<vector::TransferReadOp>::OpRewritePattern;
::mlir::LogicalResult matchAndRewrite(
::mlir::vector::TransferReadOp op,
::mlir::PatternRewriter &rewriter) const override {
LogicalResult matchAndRewrite(vector::TransferReadOp op,
PatternRewriter &rewriter) const override {
if (failed(checkPreconditions(op, rewriter))) {
return failure();
}
auto cmp = op.getSource().getDefiningOp<::mlir::arith::CmpIOp>();
auto cmp = op.getSource().getDefiningOp<arith::CmpIOp>();
if (!cmp) {
return rewriter.notifyMatchFailure(op, "source not an arith.cmpi");
}
@ -249,25 +245,23 @@ struct TransferReadOfCmpI
auto transfer_read = [&](Value value, RankedTensorType type) {
return createTransferReadOp(op, value, type, rewriter);
};
rewriter.replaceOpWithNewOp<::mlir::arith::CmpIOp>(
rewriter.replaceOpWithNewOp<arith::CmpIOp>(
op, cmp.getPredicate(), transfer_read(cmp.getLhs(), lhs_type),
transfer_read(cmp.getRhs(), rhs_type));
return ::mlir::success();
return success();
}
};
// Rewrite `vector.transfer_read(tensor.splat)` as `vector.broadcast`.
struct TransferReadOfSplat
: public ::mlir::OpRewritePattern<::mlir::vector::TransferReadOp> {
using OpRewritePattern<::mlir::vector::TransferReadOp>::OpRewritePattern;
struct TransferReadOfSplat : public OpRewritePattern<vector::TransferReadOp> {
using OpRewritePattern<vector::TransferReadOp>::OpRewritePattern;
::mlir::LogicalResult matchAndRewrite(
::mlir::vector::TransferReadOp op,
::mlir::PatternRewriter &rewriter) const override {
LogicalResult matchAndRewrite(vector::TransferReadOp op,
PatternRewriter &rewriter) const override {
if (failed(checkPreconditions(op, rewriter))) {
return failure();
}
auto splat = op.getSource().getDefiningOp<::mlir::tensor::SplatOp>();
auto splat = op.getSource().getDefiningOp<tensor::SplatOp>();
if (!splat) {
return rewriter.notifyMatchFailure(op, "source not a tensor.splat");
}
@ -276,7 +270,7 @@ struct TransferReadOfSplat
}
rewriter.replaceOpWithNewOp<vector::BroadcastOp>(op, op.getVectorType(),
splat.getInput());
return ::mlir::success();
return success();
}
};
@ -354,6 +348,77 @@ class GenericBitwidthConvert : public RewritePattern {
const bool supports_bf16_alu_instructions_;
};
// Pattern that widens the bf16 parts of a `vector.contraction` to f32. A bf16
// accumulator is extended to f32 ahead of the contraction and the new result
// is truncated back to the original result type. If the target cannot do bf16
// matmul, bf16 lhs/rhs operands are extended to f32 too.
struct ContractionBitwidthConvert
    : public OpRewritePattern<vector::ContractionOp> {
  using OpRewritePattern<vector::ContractionOp>::OpRewritePattern;

  // `supports_bf16_matmul` controls whether bf16 lhs/rhs operands may be left
  // as they are.
  ContractionBitwidthConvert(bool supports_bf16_matmul, MLIRContext *ctx)
      : OpRewritePattern(ctx), supports_bf16_matmul_(supports_bf16_matmul) {}

  // Fails to match when the accumulator is not a shaped type, or when neither
  // the operands nor the accumulator need widening. Otherwise replaces `op`
  // with an equivalent contraction on the widened values.
  LogicalResult matchAndRewrite(vector::ContractionOp op,
                                PatternRewriter &rewriter) const override {
    // ContractionOp guarantees that lhs/rhs share an element type and that
    // the accumulator/result share an element type, so inspecting the lhs and
    // the accumulator is enough.
    auto accumulator_type = dyn_cast<ShapedType>(op.getAccType());
    if (!accumulator_type) {
      return rewriter.notifyMatchFailure(op,
                                         "accumulator is not a shaped type");
    }

    // A bf16 accumulator always gets widened; bf16 operands only when the
    // target lacks bf16 matmul.
    const bool widen_accumulator = accumulator_type.getElementType().isBF16();
    const bool widen_inputs =
        !supports_bf16_matmul_ && op.getLhsType().getElementType().isBF16();
    if (!(widen_inputs || widen_accumulator)) {
      return rewriter.notifyMatchFailure(op, "no bf16 operands or accumulator");
    }

    auto loc = op.getLoc();
    auto f32 = rewriter.getF32Type();
    // Emits an `arith.extf` producing an f32 vector with the shape of `type`.
    auto extend_to_f32 = [&](Value value, auto type) -> Value {
      return rewriter.create<arith::ExtFOp>(
          loc, VectorType::get(type.getShape(), f32), value);
    };

    Value new_lhs = op.getLhs();
    Value new_rhs = op.getRhs();
    if (widen_inputs) {
      new_lhs = extend_to_f32(new_lhs, op.getLhsType());
      new_rhs = extend_to_f32(new_rhs, op.getRhsType());
    }
    Value new_acc = widen_accumulator
                        ? extend_to_f32(op.getAcc(), accumulator_type)
                        : op.getAcc();

    auto widened = rewriter.create<vector::ContractionOp>(
        loc, new_lhs, new_rhs, new_acc, op.getIndexingMaps(),
        op.getIteratorTypes(), op.getKind());

    // Without a widened accumulator the result type is unchanged, so the new
    // contraction can stand in for the old one directly.
    if (!widen_accumulator) {
      rewriter.replaceOp(op, widened);
      return success();
    }
    // Otherwise narrow the f32 result back to the original result type.
    rewriter.replaceOpWithNewOp<arith::TruncFOp>(
        op, dyn_cast<ShapedType>(op.getResultType()), widened);
    return success();
  }

 private:
  const bool supports_bf16_matmul_;
};
struct LinalgVectorizationPass
: public impl::LinalgVectorizationPassBase<LinalgVectorizationPass> {
explicit LinalgVectorizationPass(
@ -406,6 +471,7 @@ struct LinalgVectorizationPass
patterns.add<GenericBitwidthConvert>(ternary_op_name, ctx,
supports_bf16_alu_instructions);
}
patterns.add<ContractionBitwidthConvert>(supports_bf16_matmul, ctx);
// We do not want to apply the vector patterns above to the ops that are
// unrelated to the original linalg op.
@ -413,7 +479,8 @@ struct LinalgVectorizationPass
func.walk([&](Operation *op) {
if (dyn_cast<linalg::LinalgOp>(op) ||
dyn_cast<vector::TransferReadOp>(op) ||
dyn_cast<vector::TransferWriteOp>(op)) {
dyn_cast<vector::TransferWriteOp>(op) ||
dyn_cast<vector::ContractionOp>(op)) {
linalgOps.push_back(op);
}
});
@ -426,9 +493,10 @@ struct LinalgVectorizationPass
} // namespace
std::unique_ptr<OperationPass<func::FuncOp>> createLinalgVectorizationPass(
bool supports_bf16_alu_instructions) {
bool supports_bf16_alu_instructions, bool supports_bf16_matmul) {
LinalgVectorizationPassOptions options;
options.supports_bf16_alu_instructions = supports_bf16_alu_instructions;
options.supports_bf16_matmul = supports_bf16_matmul;
return std::make_unique<LinalgVectorizationPass>(options);
}