mirror of
https://github.com/llvm/llvm-project.git
synced 2025-04-27 16:56:06 +00:00
[MLIR] Canonicalize sub/add of a constant and another sub/add of a constant
Differential Revision: https://reviews.llvm.org/D101705
This commit is contained in:
parent
3ed6a6f6cd
commit
039bdcc0a8
@ -277,6 +277,7 @@ def AddIOp : IntBinaryOp<"addi", [Commutative]> {
|
||||
```
|
||||
}];
|
||||
let hasFolder = 1;
|
||||
let hasCanonicalizer = 1;
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
@ -1792,6 +1793,7 @@ def SubFOp : FloatBinaryOp<"subf"> {
|
||||
def SubIOp : IntBinaryOp<"subi"> {
|
||||
let summary = "integer subtraction operation";
|
||||
let hasFolder = 1;
|
||||
let hasCanonicalizer = 1;
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
@ -283,6 +283,62 @@ static SmallVector<int64_t, 4> extractFromI64ArrayAttr(Attribute attr) {
|
||||
}));
|
||||
}
|
||||
|
||||
/// Canonicalize a sum of a constant and (constant - something) to simply be
|
||||
/// a sum of constants minus something. This transformation does similar
|
||||
/// transformations for additions of a constant with a subtract/add of
|
||||
/// a constant. This may result in some operations being reordered (but should
|
||||
/// remain equivalent).
|
||||
struct AddConstantReorder : public OpRewritePattern<AddIOp> {
|
||||
using OpRewritePattern<AddIOp>::OpRewritePattern;
|
||||
|
||||
LogicalResult matchAndRewrite(AddIOp addop,
|
||||
PatternRewriter &rewriter) const override {
|
||||
for (int i = 0; i < 2; i++) {
|
||||
APInt origConst;
|
||||
APInt midConst;
|
||||
if (matchPattern(addop.getOperand(i), m_ConstantInt(&origConst))) {
|
||||
if (auto midAddOp = addop.getOperand(1 - i).getDefiningOp<AddIOp>()) {
|
||||
for (int j = 0; j < 2; j++) {
|
||||
if (matchPattern(midAddOp.getOperand(j),
|
||||
m_ConstantInt(&midConst))) {
|
||||
auto nextConstant = rewriter.create<ConstantOp>(
|
||||
addop.getLoc(), rewriter.getIntegerAttr(
|
||||
addop.getType(), origConst + midConst));
|
||||
rewriter.replaceOpWithNewOp<AddIOp>(addop, nextConstant,
|
||||
midAddOp.getOperand(1 - j));
|
||||
return success();
|
||||
}
|
||||
}
|
||||
}
|
||||
if (auto midSubOp = addop.getOperand(1 - i).getDefiningOp<SubIOp>()) {
|
||||
if (matchPattern(midSubOp.getOperand(0), m_ConstantInt(&midConst))) {
|
||||
auto nextConstant = rewriter.create<ConstantOp>(
|
||||
addop.getLoc(),
|
||||
rewriter.getIntegerAttr(addop.getType(), origConst + midConst));
|
||||
rewriter.replaceOpWithNewOp<SubIOp>(addop, nextConstant,
|
||||
midSubOp.getOperand(1));
|
||||
return success();
|
||||
}
|
||||
if (matchPattern(midSubOp.getOperand(1), m_ConstantInt(&midConst))) {
|
||||
auto nextConstant = rewriter.create<ConstantOp>(
|
||||
addop.getLoc(),
|
||||
rewriter.getIntegerAttr(addop.getType(), origConst - midConst));
|
||||
rewriter.replaceOpWithNewOp<AddIOp>(addop, nextConstant,
|
||||
midSubOp.getOperand(0));
|
||||
return success();
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
return failure();
|
||||
}
|
||||
};
|
||||
|
||||
void AddIOp::getCanonicalizationPatterns(OwningRewritePatternList &results,
|
||||
MLIRContext *context) {
|
||||
results.insert<AddConstantReorder>(context);
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// AndOp
|
||||
//===----------------------------------------------------------------------===//
|
||||
@ -1706,6 +1762,153 @@ OpFoldResult SubIOp::fold(ArrayRef<Attribute> operands) {
|
||||
[](APInt a, APInt b) { return a - b; });
|
||||
}
|
||||
|
||||
/// Canonicalize a sub of a constant and (constant +/- something) to simply be
|
||||
/// a single operation that merges the two constants.
|
||||
struct SubConstantReorder : public OpRewritePattern<SubIOp> {
|
||||
using OpRewritePattern<SubIOp>::OpRewritePattern;
|
||||
|
||||
LogicalResult matchAndRewrite(SubIOp subOp,
|
||||
PatternRewriter &rewriter) const override {
|
||||
APInt origConst;
|
||||
APInt midConst;
|
||||
|
||||
if (matchPattern(subOp.getOperand(0), m_ConstantInt(&origConst))) {
|
||||
if (auto midAddOp = subOp.getOperand(1).getDefiningOp<AddIOp>()) {
|
||||
// origConst - (midConst + something) == (origConst - midConst) -
|
||||
// something
|
||||
for (int j = 0; j < 2; j++) {
|
||||
if (matchPattern(midAddOp.getOperand(j), m_ConstantInt(&midConst))) {
|
||||
auto nextConstant = rewriter.create<ConstantOp>(
|
||||
subOp.getLoc(),
|
||||
rewriter.getIntegerAttr(subOp.getType(), origConst - midConst));
|
||||
rewriter.replaceOpWithNewOp<SubIOp>(subOp, nextConstant,
|
||||
midAddOp.getOperand(1 - j));
|
||||
return success();
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if (auto midSubOp = subOp.getOperand(0).getDefiningOp<SubIOp>()) {
|
||||
if (matchPattern(midSubOp.getOperand(0), m_ConstantInt(&midConst))) {
|
||||
// (midConst - something) - origConst == (midConst - origConst) -
|
||||
// something
|
||||
auto nextConstant = rewriter.create<ConstantOp>(
|
||||
subOp.getLoc(),
|
||||
rewriter.getIntegerAttr(subOp.getType(), midConst - origConst));
|
||||
rewriter.replaceOpWithNewOp<SubIOp>(subOp, nextConstant,
|
||||
midSubOp.getOperand(1));
|
||||
return success();
|
||||
}
|
||||
|
||||
if (matchPattern(midSubOp.getOperand(1), m_ConstantInt(&midConst))) {
|
||||
// (something - midConst) - origConst == something - (origConst +
|
||||
// midConst)
|
||||
auto nextConstant = rewriter.create<ConstantOp>(
|
||||
subOp.getLoc(),
|
||||
rewriter.getIntegerAttr(subOp.getType(), origConst + midConst));
|
||||
rewriter.replaceOpWithNewOp<SubIOp>(subOp, midSubOp.getOperand(0),
|
||||
nextConstant);
|
||||
return success();
|
||||
}
|
||||
}
|
||||
|
||||
if (auto midSubOp = subOp.getOperand(1).getDefiningOp<SubIOp>()) {
|
||||
if (matchPattern(midSubOp.getOperand(0), m_ConstantInt(&midConst))) {
|
||||
// origConst - (midConst - something) == (origConst - midConst) +
|
||||
// something
|
||||
auto nextConstant = rewriter.create<ConstantOp>(
|
||||
subOp.getLoc(),
|
||||
rewriter.getIntegerAttr(subOp.getType(), origConst - midConst));
|
||||
rewriter.replaceOpWithNewOp<AddIOp>(subOp, nextConstant,
|
||||
midSubOp.getOperand(1));
|
||||
return success();
|
||||
}
|
||||
|
||||
if (matchPattern(midSubOp.getOperand(1), m_ConstantInt(&midConst))) {
|
||||
// origConst - (something - midConst) == (origConst + midConst) -
|
||||
// something
|
||||
auto nextConstant = rewriter.create<ConstantOp>(
|
||||
subOp.getLoc(),
|
||||
rewriter.getIntegerAttr(subOp.getType(), origConst + midConst));
|
||||
rewriter.replaceOpWithNewOp<SubIOp>(subOp, nextConstant,
|
||||
midSubOp.getOperand(0));
|
||||
return success();
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if (matchPattern(subOp.getOperand(1), m_ConstantInt(&origConst))) {
|
||||
if (auto midAddOp = subOp.getOperand(0).getDefiningOp<AddIOp>()) {
|
||||
// (midConst + something) - origConst == (midConst - origConst) +
|
||||
// something
|
||||
for (int j = 0; j < 2; j++) {
|
||||
if (matchPattern(midAddOp.getOperand(j), m_ConstantInt(&midConst))) {
|
||||
auto nextConstant = rewriter.create<ConstantOp>(
|
||||
subOp.getLoc(),
|
||||
rewriter.getIntegerAttr(subOp.getType(), midConst - origConst));
|
||||
rewriter.replaceOpWithNewOp<AddIOp>(subOp, nextConstant,
|
||||
midAddOp.getOperand(1 - j));
|
||||
return success();
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if (auto midSubOp = subOp.getOperand(0).getDefiningOp<SubIOp>()) {
|
||||
if (matchPattern(midSubOp.getOperand(0), m_ConstantInt(&midConst))) {
|
||||
// (midConst - something) - origConst == (midConst - origConst) -
|
||||
// something
|
||||
auto nextConstant = rewriter.create<ConstantOp>(
|
||||
subOp.getLoc(),
|
||||
rewriter.getIntegerAttr(subOp.getType(), midConst - origConst));
|
||||
rewriter.replaceOpWithNewOp<SubIOp>(subOp, nextConstant,
|
||||
midSubOp.getOperand(1));
|
||||
return success();
|
||||
}
|
||||
|
||||
if (matchPattern(midSubOp.getOperand(1), m_ConstantInt(&midConst))) {
|
||||
// (something - midConst) - origConst == something - (midConst +
|
||||
// origConst)
|
||||
auto nextConstant = rewriter.create<ConstantOp>(
|
||||
subOp.getLoc(),
|
||||
rewriter.getIntegerAttr(subOp.getType(), midConst + origConst));
|
||||
rewriter.replaceOpWithNewOp<SubIOp>(subOp, midSubOp.getOperand(0),
|
||||
nextConstant);
|
||||
return success();
|
||||
}
|
||||
}
|
||||
|
||||
if (auto midSubOp = subOp.getOperand(1).getDefiningOp<SubIOp>()) {
|
||||
if (matchPattern(midSubOp.getOperand(0), m_ConstantInt(&midConst))) {
|
||||
// origConst - (midConst - something) == (origConst - midConst) +
|
||||
// something
|
||||
auto nextConstant = rewriter.create<ConstantOp>(
|
||||
subOp.getLoc(),
|
||||
rewriter.getIntegerAttr(subOp.getType(), origConst - midConst));
|
||||
rewriter.replaceOpWithNewOp<AddIOp>(subOp, nextConstant,
|
||||
midSubOp.getOperand(1));
|
||||
return success();
|
||||
}
|
||||
if (matchPattern(midSubOp.getOperand(1), m_ConstantInt(&midConst))) {
|
||||
// origConst - (something - midConst) == (origConst - midConst) -
|
||||
// something
|
||||
auto nextConstant = rewriter.create<ConstantOp>(
|
||||
subOp.getLoc(),
|
||||
rewriter.getIntegerAttr(subOp.getType(), origConst - midConst));
|
||||
rewriter.replaceOpWithNewOp<SubIOp>(subOp, nextConstant,
|
||||
midSubOp.getOperand(0));
|
||||
return success();
|
||||
}
|
||||
}
|
||||
}
|
||||
return failure();
|
||||
}
|
||||
};
|
||||
|
||||
void SubIOp::getCanonicalizationPatterns(OwningRewritePatternList &results,
|
||||
MLIRContext *context) {
|
||||
results.insert<SubConstantReorder>(context);
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// UIToFPOp
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
@ -428,3 +428,113 @@ func @truncConstant(%arg0: i8) -> i16 {
|
||||
%tr = trunci %c-2 : i32 to i16
|
||||
return %tr : i16
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: @tripleAddAdd
|
||||
// CHECK: %[[cres:.+]] = constant 59 : index
|
||||
// CHECK: %[[add:.+]] = addi %arg0, %[[cres]] : index
|
||||
// CHECK: return %[[add]]
|
||||
func @tripleAddAdd(%arg0: index) -> index {
|
||||
%c17 = constant 17 : index
|
||||
%c42 = constant 42 : index
|
||||
%add1 = addi %c17, %arg0 : index
|
||||
%add2 = addi %c42, %add1 : index
|
||||
return %add2 : index
|
||||
}
|
||||
|
||||
// CHECK-LABEL: @tripleAddSub0
|
||||
// CHECK: %[[cres:.+]] = constant 59 : index
|
||||
// CHECK: %[[add:.+]] = subi %[[cres]], %arg0 : index
|
||||
// CHECK: return %[[add]]
|
||||
func @tripleAddSub0(%arg0: index) -> index {
|
||||
%c17 = constant 17 : index
|
||||
%c42 = constant 42 : index
|
||||
%add1 = subi %c17, %arg0 : index
|
||||
%add2 = addi %c42, %add1 : index
|
||||
return %add2 : index
|
||||
}
|
||||
|
||||
// CHECK-LABEL: @tripleAddSub1
|
||||
// CHECK: %[[cres:.+]] = constant 25 : index
|
||||
// CHECK: %[[add:.+]] = addi %arg0, %[[cres]] : index
|
||||
// CHECK: return %[[add]]
|
||||
func @tripleAddSub1(%arg0: index) -> index {
|
||||
%c17 = constant 17 : index
|
||||
%c42 = constant 42 : index
|
||||
%add1 = subi %arg0, %c17 : index
|
||||
%add2 = addi %c42, %add1 : index
|
||||
return %add2 : index
|
||||
}
|
||||
|
||||
// CHECK-LABEL: @tripleSubAdd0
|
||||
// CHECK: %[[cres:.+]] = constant 25 : index
|
||||
// CHECK: %[[add:.+]] = subi %[[cres]], %arg0 : index
|
||||
// CHECK: return %[[add]]
|
||||
func @tripleSubAdd0(%arg0: index) -> index {
|
||||
%c17 = constant 17 : index
|
||||
%c42 = constant 42 : index
|
||||
%add1 = addi %c17, %arg0 : index
|
||||
%add2 = subi %c42, %add1 : index
|
||||
return %add2 : index
|
||||
}
|
||||
|
||||
// CHECK-LABEL: @tripleSubAdd1
|
||||
// CHECK: %[[cres:.+]] = constant -25 : index
|
||||
// CHECK: %[[add:.+]] = addi %arg0, %[[cres]] : index
|
||||
// CHECK: return %[[add]]
|
||||
func @tripleSubAdd1(%arg0: index) -> index {
|
||||
%c17 = constant 17 : index
|
||||
%c42 = constant 42 : index
|
||||
%add1 = addi %c17, %arg0 : index
|
||||
%add2 = subi %add1, %c42 : index
|
||||
return %add2 : index
|
||||
}
|
||||
|
||||
// CHECK-LABEL: @tripleSubSub0
|
||||
// CHECK: %[[cres:.+]] = constant 25 : index
|
||||
// CHECK: %[[add:.+]] = addi %arg0, %[[cres]] : index
|
||||
// CHECK: return %[[add]]
|
||||
func @tripleSubSub0(%arg0: index) -> index {
|
||||
%c17 = constant 17 : index
|
||||
%c42 = constant 42 : index
|
||||
%add1 = subi %c17, %arg0 : index
|
||||
%add2 = subi %c42, %add1 : index
|
||||
return %add2 : index
|
||||
}
|
||||
|
||||
// CHECK-LABEL: @tripleSubSub1
|
||||
// CHECK: %[[cres:.+]] = constant -25 : index
|
||||
// CHECK: %[[add:.+]] = subi %[[cres]], %arg0 : index
|
||||
// CHECK: return %[[add]]
|
||||
func @tripleSubSub1(%arg0: index) -> index {
|
||||
%c17 = constant 17 : index
|
||||
%c42 = constant 42 : index
|
||||
%add1 = subi %c17, %arg0 : index
|
||||
%add2 = subi %add1, %c42 : index
|
||||
return %add2 : index
|
||||
}
|
||||
|
||||
// CHECK-LABEL: @tripleSubSub2
|
||||
// CHECK: %[[cres:.+]] = constant 59 : index
|
||||
// CHECK: %[[add:.+]] = subi %[[cres]], %arg0 : index
|
||||
// CHECK: return %[[add]]
|
||||
func @tripleSubSub2(%arg0: index) -> index {
|
||||
%c17 = constant 17 : index
|
||||
%c42 = constant 42 : index
|
||||
%add1 = subi %arg0, %c17 : index
|
||||
%add2 = subi %c42, %add1 : index
|
||||
return %add2 : index
|
||||
}
|
||||
|
||||
// CHECK-LABEL: @tripleSubSub3
|
||||
// CHECK: %[[cres:.+]] = constant 59 : index
|
||||
// CHECK: %[[add:.+]] = subi %arg0, %[[cres]] : index
|
||||
// CHECK: return %[[add]]
|
||||
func @tripleSubSub3(%arg0: index) -> index {
|
||||
%c17 = constant 17 : index
|
||||
%c42 = constant 42 : index
|
||||
%add1 = subi %arg0, %c17 : index
|
||||
%add2 = subi %add1, %c42 : index
|
||||
return %add2 : index
|
||||
}
|
||||
|
Loading…
x
Reference in New Issue
Block a user