Mirror of https://github.com/ROCm/jax.git (synced 2025-04-17 12:26:07 +00:00)
Rewrite `vector.contraction` with bf16 accumulator and output into a contraction with f32 accumulator and output, where the accumulator is extended and the output truncated. For targets that do not support bf16 matmul, the lhs and rhs are extended to f32.

PiperOrigin-RevId: 642051952
parent 9d9dd36219
commit 71c19b779d
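
As a quick illustration of the rewrite described above, here is a minimal MLIR sketch. The shapes, maps, and SSA names are hypothetical, not taken from this commit; it shows the case where the target supports bf16 matmul, so only the accumulator is extended and the result truncated back.

Before:

    %out = vector.contract {
        indexing_maps = [affine_map<(m, n, k) -> (m, k)>,
                         affine_map<(m, n, k) -> (k, n)>,
                         affine_map<(m, n, k) -> (m, n)>],
        iterator_types = ["parallel", "parallel", "reduction"],
        kind = #vector.kind<add>}
        %lhs, %rhs, %acc
        : vector<8x16xbf16>, vector<16x8xbf16> into vector<8x8xbf16>

After:

    %acc_f32 = arith.extf %acc : vector<8x8xbf16> to vector<8x8xf32>
    %out_f32 = vector.contract {
        indexing_maps = [affine_map<(m, n, k) -> (m, k)>,
                         affine_map<(m, n, k) -> (k, n)>,
                         affine_map<(m, n, k) -> (m, n)>],
        iterator_types = ["parallel", "parallel", "reduction"],
        kind = #vector.kind<add>}
        %lhs, %rhs, %acc_f32
        : vector<8x16xbf16>, vector<16x8xbf16> into vector<8x8xf32>
    %out = arith.truncf %out_f32 : vector<8x8xf32> to vector<8x8xbf16>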
@@ -708,7 +708,10 @@ def LinalgVectorizationPass : Pass<"linalg-vectorization", "::mlir::func::FuncOp
     "::mlir::tpu::TPUDialect",
   ];
   let constructor = "::mlir::tpu::createLinalgVectorizationPass(false)";
-  let options = [Option<"supports_bf16_alu_instructions", "supports-bf16-alu-instructions", "bool", "", "">];
+  let options = [
+    Option<"supports_bf16_alu_instructions", "supports-bf16-alu-instructions", "bool", "", "">,
+    Option<"supports_bf16_matmul", "supports-bf16-matmul", "bool", "", "">,
+  ];
 }
 
 #endif  // TPU_ATTRS
@@ -64,7 +64,8 @@ std::unique_ptr<OperationPass<func::FuncOp>>
 createLogicalToPhysicalDeviceIdPass(int64_t total_devices);
 
 std::unique_ptr<OperationPass<func::FuncOp>> createLinalgVectorizationPass(
-    bool supports_bf16_alu_instructions = false);
+    bool supports_bf16_alu_instructions = false,
+    bool supports_bf16_matmul = false);
 
 std::unique_ptr<OperationPass<func::FuncOp>> createDebugAssertInsertionPass();
 
@@ -181,17 +181,15 @@ struct TransferReadOfConstant
 
 // Rewrite `vector.transfer_read(arith.select)` as `arith.select` with
 // `transfer_read` applied to its operands.
-struct TransferReadOfSelect
-    : public ::mlir::OpRewritePattern<::mlir::vector::TransferReadOp> {
-  using OpRewritePattern<::mlir::vector::TransferReadOp>::OpRewritePattern;
+struct TransferReadOfSelect : public OpRewritePattern<vector::TransferReadOp> {
+  using OpRewritePattern<vector::TransferReadOp>::OpRewritePattern;
 
-  ::mlir::LogicalResult matchAndRewrite(
-      ::mlir::vector::TransferReadOp op,
-      ::mlir::PatternRewriter &rewriter) const override {
+  LogicalResult matchAndRewrite(vector::TransferReadOp op,
+                                PatternRewriter &rewriter) const override {
     if (failed(checkPreconditions(op, rewriter))) {
       return failure();
     }
-    auto select = op.getSource().getDefiningOp<::mlir::arith::SelectOp>();
+    auto select = op.getSource().getDefiningOp<arith::SelectOp>();
     if (!select) {
       return rewriter.notifyMatchFailure(op, "source not an arith.select");
     }
@@ -214,27 +212,25 @@ struct TransferReadOfSelect
     auto transfer_read = [&](Value value, RankedTensorType type) {
       return createTransferReadOp(op, value, type, rewriter);
     };
-    rewriter.replaceOpWithNewOp<::mlir::arith::SelectOp>(
+    rewriter.replaceOpWithNewOp<arith::SelectOp>(
         op, transfer_read(select.getCondition(), condition_type),
         transfer_read(select.getTrueValue(), true_value_ty),
        transfer_read(select.getFalseValue(), false_value_ty));
-    return ::mlir::success();
+    return success();
   }
 };
 
 // Rewrite `vector.transfer_read(arith.cmpi)` as `arith.cmpi` with
 // `transfer_read` applied to its operands.
-struct TransferReadOfCmpI
-    : public ::mlir::OpRewritePattern<::mlir::vector::TransferReadOp> {
-  using OpRewritePattern<::mlir::vector::TransferReadOp>::OpRewritePattern;
+struct TransferReadOfCmpI : public OpRewritePattern<vector::TransferReadOp> {
+  using OpRewritePattern<vector::TransferReadOp>::OpRewritePattern;
 
-  ::mlir::LogicalResult matchAndRewrite(
-      ::mlir::vector::TransferReadOp op,
-      ::mlir::PatternRewriter &rewriter) const override {
+  LogicalResult matchAndRewrite(vector::TransferReadOp op,
+                                PatternRewriter &rewriter) const override {
     if (failed(checkPreconditions(op, rewriter))) {
       return failure();
     }
-    auto cmp = op.getSource().getDefiningOp<::mlir::arith::CmpIOp>();
+    auto cmp = op.getSource().getDefiningOp<arith::CmpIOp>();
     if (!cmp) {
       return rewriter.notifyMatchFailure(op, "source not an arith.cmpi");
     }
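
For orientation, a minimal MLIR sketch of what TransferReadOfSelect produces (hypothetical shapes; %c0, %pad, and %false stand for index and padding constants, elided here; the pattern itself predates this commit and is only losing its namespace qualifiers):

Before:

    %s = arith.select %cond, %a, %b : tensor<8x128xi1>, tensor<8x128xf32>
    %v = vector.transfer_read %s[%c0, %c0], %pad
        : tensor<8x128xf32>, vector<8x128xf32>

After:

    %vc = vector.transfer_read %cond[%c0, %c0], %false
        : tensor<8x128xi1>, vector<8x128xi1>
    %va = vector.transfer_read %a[%c0, %c0], %pad
        : tensor<8x128xf32>, vector<8x128xf32>
    %vb = vector.transfer_read %b[%c0, %c0], %pad
        : tensor<8x128xf32>, vector<8x128xf32>
    %v = arith.select %vc, %va, %vb : vector<8x128xi1>, vector<8x128xf32>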
@@ -249,25 +245,23 @@ struct TransferReadOfCmpI
     auto transfer_read = [&](Value value, RankedTensorType type) {
       return createTransferReadOp(op, value, type, rewriter);
     };
-    rewriter.replaceOpWithNewOp<::mlir::arith::CmpIOp>(
+    rewriter.replaceOpWithNewOp<arith::CmpIOp>(
         op, cmp.getPredicate(), transfer_read(cmp.getLhs(), lhs_type),
         transfer_read(cmp.getRhs(), rhs_type));
-    return ::mlir::success();
+    return success();
   }
 };
 
 // Rewrite `vector.transfer_read(tensor.splat)` as `vector.broadcast`.
-struct TransferReadOfSplat
-    : public ::mlir::OpRewritePattern<::mlir::vector::TransferReadOp> {
-  using OpRewritePattern<::mlir::vector::TransferReadOp>::OpRewritePattern;
+struct TransferReadOfSplat : public OpRewritePattern<vector::TransferReadOp> {
+  using OpRewritePattern<vector::TransferReadOp>::OpRewritePattern;
 
-  ::mlir::LogicalResult matchAndRewrite(
-      ::mlir::vector::TransferReadOp op,
-      ::mlir::PatternRewriter &rewriter) const override {
+  LogicalResult matchAndRewrite(vector::TransferReadOp op,
+                                PatternRewriter &rewriter) const override {
     if (failed(checkPreconditions(op, rewriter))) {
       return failure();
     }
-    auto splat = op.getSource().getDefiningOp<::mlir::tensor::SplatOp>();
+    auto splat = op.getSource().getDefiningOp<tensor::SplatOp>();
     if (!splat) {
       return rewriter.notifyMatchFailure(op, "source not a tensor.splat");
     }
@@ -276,7 +270,7 @@ struct TransferReadOfSplat
     }
     rewriter.replaceOpWithNewOp<vector::BroadcastOp>(op, op.getVectorType(),
                                                      splat.getInput());
-    return ::mlir::success();
+    return success();
   }
 };
 
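
Similarly, a minimal sketch of the TransferReadOfSplat rewrite (hypothetical shapes and names; again a pre-existing pattern that this commit only cleans up):

Before:

    %c0 = arith.constant 0 : index
    %pad = arith.constant 0.0 : f32
    %t = tensor.splat %scalar : tensor<8x128xf32>
    %v = vector.transfer_read %t[%c0, %c0], %pad
        : tensor<8x128xf32>, vector<8x128xf32>

After:

    %v = vector.broadcast %scalar : f32 to vector<8x128xf32>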
@@ -354,6 +348,77 @@ class GenericBitwidthConvert : public RewritePattern {
   const bool supports_bf16_alu_instructions_;
 };
 
+// Rewrite `vector.contraction` with bf16 accumulator and output into a
+// contraction with f32 accumulator and output, where the accumulator is
+// extended and the output truncated. For targets that do not support bf16
+// matmul, the lhs and rhs are extended to f32.
+struct ContractionBitwidthConvert
+    : public OpRewritePattern<vector::ContractionOp> {
+  using OpRewritePattern<vector::ContractionOp>::OpRewritePattern;
+
+  ContractionBitwidthConvert(bool supports_bf16_matmul, MLIRContext *ctx)
+      : OpRewritePattern(ctx), supports_bf16_matmul_(supports_bf16_matmul) {}
+
+  LogicalResult matchAndRewrite(vector::ContractionOp op,
+                                PatternRewriter &rewriter) const override {
+    // The ContractionOp contract is that (1) lhs and rhs have the same element
+    // type, and (2) the accumulator and result have the same element type.
+
+    // If the target does not support bf16 matmul and we have bf16 operands, we
+    // need to extend the lhs and rhs to f32.
+    const bool extend_operands =
+        op.getLhsType().getElementType().isBF16() && !supports_bf16_matmul_;
+    // Determine if the accumulator is bf16 and hence needs to be extended to
+    // f32.
+    ShapedType acc_ty = dyn_cast<ShapedType>(op.getAccType());
+    if (acc_ty == nullptr) {
+      return rewriter.notifyMatchFailure(op,
+                                         "accumulator is not a shaped type");
+    }
+    const bool extend_acc = acc_ty.getElementType().isBF16();
+
+    if (!extend_operands && !extend_acc) {
+      return rewriter.notifyMatchFailure(op, "no bf16 operands or accumulator");
+    }
+
+    Value lhs = op.getLhs();
+    Value rhs = op.getRhs();
+    if (extend_operands) {
+      lhs = rewriter.create<arith::ExtFOp>(
+          op.getLoc(),
+          VectorType::get(op.getLhsType().getShape(), rewriter.getF32Type()),
+          lhs);
+      rhs = rewriter.create<arith::ExtFOp>(
+          op.getLoc(),
+          VectorType::get(op.getRhsType().getShape(), rewriter.getF32Type()),
+          rhs);
+    }
+
+    Value acc = op.getAcc();
+    if (extend_acc) {
+      acc = rewriter.create<arith::ExtFOp>(
+          op.getLoc(),
+          VectorType::get(acc_ty.getShape(), rewriter.getF32Type()),
+          op.getAcc());
+    }
+
+    vector::ContractionOp contraction = rewriter.create<vector::ContractionOp>(
+        op.getLoc(), lhs, rhs, acc, op.getIndexingMaps(), op.getIteratorTypes(),
+        op.getKind());
+
+    if (extend_acc) {
+      rewriter.replaceOpWithNewOp<arith::TruncFOp>(
+          op, dyn_cast<ShapedType>(op.getResultType()), contraction);
+    } else {
+      rewriter.replaceOp(op, contraction);
+    }
+    return success();
+  }
+
+ private:
+  const bool supports_bf16_matmul_;
+};
+
 struct LinalgVectorizationPass
     : public impl::LinalgVectorizationPassBase<LinalgVectorizationPass> {
   explicit LinalgVectorizationPass(
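
The sketch near the top of this page covered the accumulator-only case; the other branch of the new pattern fires when supports_bf16_matmul_ is false, extending the bf16 lhs and rhs as well so the whole contraction runs in f32 (shapes again hypothetical):

    %lhs_f32 = arith.extf %lhs : vector<8x16xbf16> to vector<8x16xf32>
    %rhs_f32 = arith.extf %rhs : vector<16x8xbf16> to vector<16x8xf32>
    %acc_f32 = arith.extf %acc : vector<8x8xbf16> to vector<8x8xf32>
    %out_f32 = vector.contract {
        indexing_maps = [affine_map<(m, n, k) -> (m, k)>,
                         affine_map<(m, n, k) -> (k, n)>,
                         affine_map<(m, n, k) -> (m, n)>],
        iterator_types = ["parallel", "parallel", "reduction"],
        kind = #vector.kind<add>}
        %lhs_f32, %rhs_f32, %acc_f32
        : vector<8x16xf32>, vector<16x8xf32> into vector<8x8xf32>
    %out = arith.truncf %out_f32 : vector<8x8xf32> to vector<8x8xbf16>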
@@ -406,6 +471,7 @@ struct LinalgVectorizationPass
       patterns.add<GenericBitwidthConvert>(ternary_op_name, ctx,
                                            supports_bf16_alu_instructions);
     }
+    patterns.add<ContractionBitwidthConvert>(supports_bf16_matmul, ctx);
 
     // We do not want to apply the vector patterns above to the ops that are
     // unrelated to the original linalg op.
@@ -413,7 +479,8 @@ struct LinalgVectorizationPass
     func.walk([&](Operation *op) {
       if (dyn_cast<linalg::LinalgOp>(op) ||
           dyn_cast<vector::TransferReadOp>(op) ||
-          dyn_cast<vector::TransferWriteOp>(op)) {
+          dyn_cast<vector::TransferWriteOp>(op) ||
+          dyn_cast<vector::ContractionOp>(op)) {
         linalgOps.push_back(op);
       }
     });
@@ -426,9 +493,10 @@ struct LinalgVectorizationPass
 }  // namespace
 
 std::unique_ptr<OperationPass<func::FuncOp>> createLinalgVectorizationPass(
-    bool supports_bf16_alu_instructions) {
+    bool supports_bf16_alu_instructions, bool supports_bf16_matmul) {
   LinalgVectorizationPassOptions options;
   options.supports_bf16_alu_instructions = supports_bf16_alu_instructions;
+  options.supports_bf16_matmul = supports_bf16_matmul;
   return std::make_unique<LinalgVectorizationPass>(options);
 }
 
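
Assuming the pass is registered for textual pass pipelines (registration is not shown in this diff), the new option would be toggled the same way as the existing one, e.g.:

    mlir-opt --pass-pipeline='builtin.module(func.func(linalg-vectorization{supports-bf16-alu-instructions=true supports-bf16-matmul=true}))' input.mlir

The pass argument and option names come from the Pass definition in the first hunk; the mlir-opt invocation and input file are placeholders.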