[MLIR] Canonicalize sub/add of a constant and another sub/add of a constant

Differential Revision: https://reviews.llvm.org/D101705
This commit is contained in:
William S. Moses 2021-05-01 22:54:23 -04:00
parent 3ed6a6f6cd
commit 039bdcc0a8
3 changed files with 315 additions and 0 deletions

View File

@ -277,6 +277,7 @@ def AddIOp : IntBinaryOp<"addi", [Commutative]> {
```
}];
let hasFolder = 1;
let hasCanonicalizer = 1;
}
//===----------------------------------------------------------------------===//
@ -1792,6 +1793,7 @@ def SubFOp : FloatBinaryOp<"subf"> {
def SubIOp : IntBinaryOp<"subi"> {
let summary = "integer subtraction operation";
let hasFolder = 1;
let hasCanonicalizer = 1;
}
//===----------------------------------------------------------------------===//

View File

@ -283,6 +283,62 @@ static SmallVector<int64_t, 4> extractFromI64ArrayAttr(Attribute attr) {
}));
}
/// Canonicalize a sum of a constant and (constant - something) to simply be
/// a sum of constants minus something. This transformation does similar
/// transformations for additions of a constant with a subtract/add of
/// a constant. This may result in some operations being reordered (but should
/// remain equivalent).
struct AddConstantReorder : public OpRewritePattern<AddIOp> {
using OpRewritePattern<AddIOp>::OpRewritePattern;
LogicalResult matchAndRewrite(AddIOp addop,
PatternRewriter &rewriter) const override {
for (int i = 0; i < 2; i++) {
APInt origConst;
APInt midConst;
if (matchPattern(addop.getOperand(i), m_ConstantInt(&origConst))) {
if (auto midAddOp = addop.getOperand(1 - i).getDefiningOp<AddIOp>()) {
for (int j = 0; j < 2; j++) {
if (matchPattern(midAddOp.getOperand(j),
m_ConstantInt(&midConst))) {
auto nextConstant = rewriter.create<ConstantOp>(
addop.getLoc(), rewriter.getIntegerAttr(
addop.getType(), origConst + midConst));
rewriter.replaceOpWithNewOp<AddIOp>(addop, nextConstant,
midAddOp.getOperand(1 - j));
return success();
}
}
}
if (auto midSubOp = addop.getOperand(1 - i).getDefiningOp<SubIOp>()) {
if (matchPattern(midSubOp.getOperand(0), m_ConstantInt(&midConst))) {
auto nextConstant = rewriter.create<ConstantOp>(
addop.getLoc(),
rewriter.getIntegerAttr(addop.getType(), origConst + midConst));
rewriter.replaceOpWithNewOp<SubIOp>(addop, nextConstant,
midSubOp.getOperand(1));
return success();
}
if (matchPattern(midSubOp.getOperand(1), m_ConstantInt(&midConst))) {
auto nextConstant = rewriter.create<ConstantOp>(
addop.getLoc(),
rewriter.getIntegerAttr(addop.getType(), origConst - midConst));
rewriter.replaceOpWithNewOp<AddIOp>(addop, nextConstant,
midSubOp.getOperand(0));
return success();
}
}
}
}
return failure();
}
};
void AddIOp::getCanonicalizationPatterns(OwningRewritePatternList &results,
MLIRContext *context) {
results.insert<AddConstantReorder>(context);
}
//===----------------------------------------------------------------------===//
// AndOp
//===----------------------------------------------------------------------===//
@ -1706,6 +1762,153 @@ OpFoldResult SubIOp::fold(ArrayRef<Attribute> operands) {
[](APInt a, APInt b) { return a - b; });
}
/// Canonicalize a sub of a constant and (constant +/- something) to simply be
/// a single operation that merges the two constants.
struct SubConstantReorder : public OpRewritePattern<SubIOp> {
using OpRewritePattern<SubIOp>::OpRewritePattern;
LogicalResult matchAndRewrite(SubIOp subOp,
PatternRewriter &rewriter) const override {
APInt origConst;
APInt midConst;
if (matchPattern(subOp.getOperand(0), m_ConstantInt(&origConst))) {
if (auto midAddOp = subOp.getOperand(1).getDefiningOp<AddIOp>()) {
// origConst - (midConst + something) == (origConst - midConst) -
// something
for (int j = 0; j < 2; j++) {
if (matchPattern(midAddOp.getOperand(j), m_ConstantInt(&midConst))) {
auto nextConstant = rewriter.create<ConstantOp>(
subOp.getLoc(),
rewriter.getIntegerAttr(subOp.getType(), origConst - midConst));
rewriter.replaceOpWithNewOp<SubIOp>(subOp, nextConstant,
midAddOp.getOperand(1 - j));
return success();
}
}
}
if (auto midSubOp = subOp.getOperand(0).getDefiningOp<SubIOp>()) {
if (matchPattern(midSubOp.getOperand(0), m_ConstantInt(&midConst))) {
// (midConst - something) - origConst == (midConst - origConst) -
// something
auto nextConstant = rewriter.create<ConstantOp>(
subOp.getLoc(),
rewriter.getIntegerAttr(subOp.getType(), midConst - origConst));
rewriter.replaceOpWithNewOp<SubIOp>(subOp, nextConstant,
midSubOp.getOperand(1));
return success();
}
if (matchPattern(midSubOp.getOperand(1), m_ConstantInt(&midConst))) {
// (something - midConst) - origConst == something - (origConst +
// midConst)
auto nextConstant = rewriter.create<ConstantOp>(
subOp.getLoc(),
rewriter.getIntegerAttr(subOp.getType(), origConst + midConst));
rewriter.replaceOpWithNewOp<SubIOp>(subOp, midSubOp.getOperand(0),
nextConstant);
return success();
}
}
if (auto midSubOp = subOp.getOperand(1).getDefiningOp<SubIOp>()) {
if (matchPattern(midSubOp.getOperand(0), m_ConstantInt(&midConst))) {
// origConst - (midConst - something) == (origConst - midConst) +
// something
auto nextConstant = rewriter.create<ConstantOp>(
subOp.getLoc(),
rewriter.getIntegerAttr(subOp.getType(), origConst - midConst));
rewriter.replaceOpWithNewOp<AddIOp>(subOp, nextConstant,
midSubOp.getOperand(1));
return success();
}
if (matchPattern(midSubOp.getOperand(1), m_ConstantInt(&midConst))) {
// origConst - (something - midConst) == (origConst + midConst) -
// something
auto nextConstant = rewriter.create<ConstantOp>(
subOp.getLoc(),
rewriter.getIntegerAttr(subOp.getType(), origConst + midConst));
rewriter.replaceOpWithNewOp<SubIOp>(subOp, nextConstant,
midSubOp.getOperand(0));
return success();
}
}
}
if (matchPattern(subOp.getOperand(1), m_ConstantInt(&origConst))) {
if (auto midAddOp = subOp.getOperand(0).getDefiningOp<AddIOp>()) {
// (midConst + something) - origConst == (midConst - origConst) +
// something
for (int j = 0; j < 2; j++) {
if (matchPattern(midAddOp.getOperand(j), m_ConstantInt(&midConst))) {
auto nextConstant = rewriter.create<ConstantOp>(
subOp.getLoc(),
rewriter.getIntegerAttr(subOp.getType(), midConst - origConst));
rewriter.replaceOpWithNewOp<AddIOp>(subOp, nextConstant,
midAddOp.getOperand(1 - j));
return success();
}
}
}
if (auto midSubOp = subOp.getOperand(0).getDefiningOp<SubIOp>()) {
if (matchPattern(midSubOp.getOperand(0), m_ConstantInt(&midConst))) {
// (midConst - something) - origConst == (midConst - origConst) -
// something
auto nextConstant = rewriter.create<ConstantOp>(
subOp.getLoc(),
rewriter.getIntegerAttr(subOp.getType(), midConst - origConst));
rewriter.replaceOpWithNewOp<SubIOp>(subOp, nextConstant,
midSubOp.getOperand(1));
return success();
}
if (matchPattern(midSubOp.getOperand(1), m_ConstantInt(&midConst))) {
// (something - midConst) - origConst == something - (midConst +
// origConst)
auto nextConstant = rewriter.create<ConstantOp>(
subOp.getLoc(),
rewriter.getIntegerAttr(subOp.getType(), midConst + origConst));
rewriter.replaceOpWithNewOp<SubIOp>(subOp, midSubOp.getOperand(0),
nextConstant);
return success();
}
}
if (auto midSubOp = subOp.getOperand(1).getDefiningOp<SubIOp>()) {
if (matchPattern(midSubOp.getOperand(0), m_ConstantInt(&midConst))) {
// origConst - (midConst - something) == (origConst - midConst) +
// something
auto nextConstant = rewriter.create<ConstantOp>(
subOp.getLoc(),
rewriter.getIntegerAttr(subOp.getType(), origConst - midConst));
rewriter.replaceOpWithNewOp<AddIOp>(subOp, nextConstant,
midSubOp.getOperand(1));
return success();
}
if (matchPattern(midSubOp.getOperand(1), m_ConstantInt(&midConst))) {
// origConst - (something - midConst) == (origConst - midConst) -
// something
auto nextConstant = rewriter.create<ConstantOp>(
subOp.getLoc(),
rewriter.getIntegerAttr(subOp.getType(), origConst - midConst));
rewriter.replaceOpWithNewOp<SubIOp>(subOp, nextConstant,
midSubOp.getOperand(0));
return success();
}
}
}
return failure();
}
};
void SubIOp::getCanonicalizationPatterns(OwningRewritePatternList &results,
MLIRContext *context) {
results.insert<SubConstantReorder>(context);
}
//===----------------------------------------------------------------------===//
// UIToFPOp
//===----------------------------------------------------------------------===//

View File

@ -428,3 +428,113 @@ func @truncConstant(%arg0: i8) -> i16 {
%tr = trunci %c-2 : i32 to i16
return %tr : i16
}
// -----
// CHECK-LABEL: @tripleAddAdd
// CHECK: %[[cres:.+]] = constant 59 : index
// CHECK: %[[add:.+]] = addi %arg0, %[[cres]] : index
// CHECK: return %[[add]]
func @tripleAddAdd(%arg0: index) -> index {
%c17 = constant 17 : index
%c42 = constant 42 : index
%add1 = addi %c17, %arg0 : index
%add2 = addi %c42, %add1 : index
return %add2 : index
}
// CHECK-LABEL: @tripleAddSub0
// CHECK: %[[cres:.+]] = constant 59 : index
// CHECK: %[[add:.+]] = subi %[[cres]], %arg0 : index
// CHECK: return %[[add]]
func @tripleAddSub0(%arg0: index) -> index {
%c17 = constant 17 : index
%c42 = constant 42 : index
%add1 = subi %c17, %arg0 : index
%add2 = addi %c42, %add1 : index
return %add2 : index
}
// CHECK-LABEL: @tripleAddSub1
// CHECK: %[[cres:.+]] = constant 25 : index
// CHECK: %[[add:.+]] = addi %arg0, %[[cres]] : index
// CHECK: return %[[add]]
func @tripleAddSub1(%arg0: index) -> index {
%c17 = constant 17 : index
%c42 = constant 42 : index
%add1 = subi %arg0, %c17 : index
%add2 = addi %c42, %add1 : index
return %add2 : index
}
// CHECK-LABEL: @tripleSubAdd0
// CHECK: %[[cres:.+]] = constant 25 : index
// CHECK: %[[add:.+]] = subi %[[cres]], %arg0 : index
// CHECK: return %[[add]]
func @tripleSubAdd0(%arg0: index) -> index {
%c17 = constant 17 : index
%c42 = constant 42 : index
%add1 = addi %c17, %arg0 : index
%add2 = subi %c42, %add1 : index
return %add2 : index
}
// CHECK-LABEL: @tripleSubAdd1
// CHECK: %[[cres:.+]] = constant -25 : index
// CHECK: %[[add:.+]] = addi %arg0, %[[cres]] : index
// CHECK: return %[[add]]
func @tripleSubAdd1(%arg0: index) -> index {
%c17 = constant 17 : index
%c42 = constant 42 : index
%add1 = addi %c17, %arg0 : index
%add2 = subi %add1, %c42 : index
return %add2 : index
}
// CHECK-LABEL: @tripleSubSub0
// CHECK: %[[cres:.+]] = constant 25 : index
// CHECK: %[[add:.+]] = addi %arg0, %[[cres]] : index
// CHECK: return %[[add]]
func @tripleSubSub0(%arg0: index) -> index {
%c17 = constant 17 : index
%c42 = constant 42 : index
%add1 = subi %c17, %arg0 : index
%add2 = subi %c42, %add1 : index
return %add2 : index
}
// CHECK-LABEL: @tripleSubSub1
// CHECK: %[[cres:.+]] = constant -25 : index
// CHECK: %[[add:.+]] = subi %[[cres]], %arg0 : index
// CHECK: return %[[add]]
func @tripleSubSub1(%arg0: index) -> index {
%c17 = constant 17 : index
%c42 = constant 42 : index
%add1 = subi %c17, %arg0 : index
%add2 = subi %add1, %c42 : index
return %add2 : index
}
// CHECK-LABEL: @tripleSubSub2
// CHECK: %[[cres:.+]] = constant 59 : index
// CHECK: %[[add:.+]] = subi %[[cres]], %arg0 : index
// CHECK: return %[[add]]
func @tripleSubSub2(%arg0: index) -> index {
%c17 = constant 17 : index
%c42 = constant 42 : index
%add1 = subi %arg0, %c17 : index
%add2 = subi %c42, %add1 : index
return %add2 : index
}
// CHECK-LABEL: @tripleSubSub3
// CHECK: %[[cres:.+]] = constant 59 : index
// CHECK: %[[add:.+]] = subi %arg0, %[[cres]] : index
// CHECK: return %[[add]]
func @tripleSubSub3(%arg0: index) -> index {
%c17 = constant 17 : index
%c42 = constant 42 : index
%add1 = subi %arg0, %c17 : index
%add2 = subi %add1, %c42 : index
return %add2 : index
}