More simplifications for affine binary op exprs.

- simplify operations with identity elements (multiply by 1, add with 0).
- simplify successive add/mul: fold constants, propagate constants to the
  right.
- simplify floordiv and ceildiv when the divisor is a constant and the LHS is a
  multiply expression whose RHS is a constant multiple of that divisor
  (e.g. (i * 128) floordiv 64 becomes i * 2).
- fix an affine expression printing bug on paren emission.

- while at it, fix the affine-map test cases file (memrefs using layout maps
  that were duplicates of existing ones should be emitted pointing to the
  unique'd one).

PiperOrigin-RevId: 207046738
This commit is contained in:
Uday Bondhugula 2018-08-01 22:02:00 -07:00 committed by jpienaar
parent 1e793eb8dc
commit b92378e8fa
3 changed files with 189 additions and 48 deletions

View File

@ -27,32 +27,65 @@ AffineMap::AffineMap(unsigned numDims, unsigned numSymbols, unsigned numResults,
: numDims(numDims), numSymbols(numSymbols), numResults(numResults),
results(results), rangeSizes(rangeSizes) {}
/// Fold to a constant when possible. Canonicalize so that only the RHS is a
/// constant. (4 + d0 becomes d0 + 4). If only one of them is a symbolic
/// expressions, make it the RHS. Return nullptr if it can't be simplified.
/// Simplify add expression. Return nullptr if it can't be simplified.
AffineExpr *AffineBinaryOpExpr::simplifyAdd(AffineExpr *lhs, AffineExpr *rhs,
MLIRContext *context) {
if (auto *l = dyn_cast<AffineConstantExpr>(lhs))
if (auto *r = dyn_cast<AffineConstantExpr>(rhs))
return AffineConstantExpr::get(l->getValue() + r->getValue(), context);
auto *lhsConst = dyn_cast<AffineConstantExpr>(lhs);
auto *rhsConst = dyn_cast<AffineConstantExpr>(rhs);
// Fold if both LHS, RHS are a constant.
if (lhsConst && rhsConst)
return AffineConstantExpr::get(lhsConst->getValue() + rhsConst->getValue(),
context);
// Canonicalize so that only the RHS is a constant. (4 + d0 becomes d0 + 4).
// If only one of them is a symbolic expression, make it the RHS.
if (isa<AffineConstantExpr>(lhs) ||
(lhs->isSymbolicOrConstant() && !rhs->isSymbolicOrConstant()))
(lhs->isSymbolicOrConstant() && !rhs->isSymbolicOrConstant())) {
return AffineBinaryOpExpr::get(Kind::Add, rhs, lhs, context);
}
// At this point, if there was a constant, it would be on the right.
// Addition with a zero is a noop, return the other input.
if (rhsConst) {
if (rhsConst->getValue() == 0)
return lhs;
}
// Fold successive additions like (d0 + 2) + 3 into d0 + 5.
auto *lBin = dyn_cast<AffineBinaryOpExpr>(lhs);
if (lBin && rhsConst && lBin->getKind() == Kind::Add) {
if (auto *lrhs = dyn_cast<AffineConstantExpr>(lBin->getRHS()))
return AffineBinaryOpExpr::get(
Kind::Add, lBin->getLHS(),
AffineConstantExpr::get(lrhs->getValue() + rhsConst->getValue(),
context),
context);
}
// When doing successive additions, bring constant to the right: turn (d0 + 2)
// + d1 into (d0 + d1) + 2.
if (lBin && lBin->getKind() == Kind::Add) {
if (auto *lrhs = dyn_cast<AffineConstantExpr>(lBin->getRHS())) {
return AffineBinaryOpExpr::get(
Kind::Add,
AffineBinaryOpExpr::get(Kind::Add, lBin->getLHS(), rhs, context),
lrhs, context);
}
}
return nullptr;
// TODO(someone): implement more simplification like x + 0 -> x; (x + 2) + 4
// -> (x + 6). Do this in a systematic way in conjunction with other
// simplifications as opposed to incremental hacks.
}
/// Simplify a multiply expression. Fold it to a constant when possible, and
/// make the symbolic/constant operand the RHS.
/// Simplify a multiply expression. Return nullptr if it can't be simplified.
AffineExpr *AffineBinaryOpExpr::simplifyMul(AffineExpr *lhs, AffineExpr *rhs,
MLIRContext *context) {
if (auto *l = dyn_cast<AffineConstantExpr>(lhs))
if (auto *r = dyn_cast<AffineConstantExpr>(rhs))
return AffineConstantExpr::get(l->getValue() * r->getValue(), context);
auto *lhsConst = dyn_cast<AffineConstantExpr>(lhs);
auto *rhsConst = dyn_cast<AffineConstantExpr>(rhs);
if (lhsConst && rhsConst)
return AffineConstantExpr::get(lhsConst->getValue() * rhsConst->getValue(),
context);
assert(lhs->isSymbolicOrConstant() || rhs->isSymbolicOrConstant());
@ -64,33 +97,100 @@ AffineExpr *AffineBinaryOpExpr::simplifyMul(AffineExpr *lhs, AffineExpr *rhs,
return AffineBinaryOpExpr::get(Kind::Mul, rhs, lhs, context);
}
// At this point, if there was a constant, it would be on the right.
// Multiplication with a one is a noop, return the other input.
if (rhsConst) {
if (rhsConst->getValue() == 1)
return lhs;
// Multiplication with zero.
if (rhsConst->getValue() == 0)
return rhsConst;
}
// Fold successive multiplications: eg: (d0 * 2) * 3 into d0 * 6.
auto *lBin = dyn_cast<AffineBinaryOpExpr>(lhs);
if (lBin && rhsConst && lBin->getKind() == Kind::Mul) {
if (auto *lrhs = dyn_cast<AffineConstantExpr>(lBin->getRHS()))
return AffineBinaryOpExpr::get(
Kind::Mul, lBin->getLHS(),
AffineConstantExpr::get(lrhs->getValue() * rhsConst->getValue(),
context),
context);
}
// When doing successive multiplication, bring constant to the right: turn (d0
// * 2) * d1 into (d0 * d1) * 2.
if (lBin && lBin->getKind() == Kind::Mul) {
if (auto *lrhs = dyn_cast<AffineConstantExpr>(lBin->getRHS())) {
return AffineBinaryOpExpr::get(
Kind::Mul,
AffineBinaryOpExpr::get(Kind::Mul, lBin->getLHS(), rhs, context),
lrhs, context);
}
}
return nullptr;
// TODO(someone): implement some more simplification/canonicalization such as
// 1*x is same as x, and in general, move it in the form d_i*expr where d_i is
// a dimensional identifier. So, 2*(d0 + 4) + s0*d0 becomes (2 + s0)*d0 + 8.
}
AffineExpr *AffineBinaryOpExpr::simplifyFloorDiv(AffineExpr *lhs,
AffineExpr *rhs,
MLIRContext *context) {
if (auto *l = dyn_cast<AffineConstantExpr>(lhs))
if (auto *r = dyn_cast<AffineConstantExpr>(rhs))
return AffineConstantExpr::get(l->getValue() / r->getValue(), context);
auto *lhsConst = dyn_cast<AffineConstantExpr>(lhs);
auto *rhsConst = dyn_cast<AffineConstantExpr>(rhs);
if (lhsConst && rhsConst)
return AffineConstantExpr::get(lhsConst->getValue() / rhsConst->getValue(),
context);
// Fold floordiv of a multiply with a constant that is a multiple of the
// divisor. Eg: (i * 128) floordiv 64 = i * 2.
if (rhsConst) {
auto *lBin = dyn_cast<AffineBinaryOpExpr>(lhs);
if (lBin && lBin->getKind() == Kind::Mul) {
if (auto *lrhs = dyn_cast<AffineConstantExpr>(lBin->getRHS())) {
// rhsConst is known to be positive if a constant.
if (lrhs->getValue() % rhsConst->getValue() == 0)
return AffineBinaryOpExpr::get(
Kind::Mul, lBin->getLHS(),
AffineConstantExpr::get(lrhs->getValue() / rhsConst->getValue(),
context),
context);
}
}
}
return nullptr;
// TODO(someone): implement more simplification along the lines described in
// simplifyMod TODO. For eg: 128*N floordiv 128 is N.
}
AffineExpr *AffineBinaryOpExpr::simplifyCeilDiv(AffineExpr *lhs,
AffineExpr *rhs,
MLIRContext *context) {
if (auto *l = dyn_cast<AffineConstantExpr>(lhs))
if (auto *r = dyn_cast<AffineConstantExpr>(rhs))
return AffineConstantExpr::get(
(int64_t)llvm::divideCeil((uint64_t)l->getValue(),
(uint64_t)r->getValue()),
context);
auto *lhsConst = dyn_cast<AffineConstantExpr>(lhs);
auto *rhsConst = dyn_cast<AffineConstantExpr>(rhs);
if (lhsConst && rhsConst)
return AffineConstantExpr::get(
(int64_t)llvm::divideCeil((uint64_t)lhsConst->getValue(),
(uint64_t)rhsConst->getValue()),
context);
// Fold ceildiv of a multiply with a constant that is a multiple of the
// divisor. Eg: (i * 128) ceildiv 64 = i * 2.
if (rhsConst) {
auto *lBin = dyn_cast<AffineBinaryOpExpr>(lhs);
if (lBin && lBin->getKind() == Kind::Mul) {
if (auto *lrhs = dyn_cast<AffineConstantExpr>(lBin->getRHS())) {
// rhsConst is known to be positive if a constant.
if (lrhs->getValue() % rhsConst->getValue() == 0)
return AffineBinaryOpExpr::get(
Kind::Mul, lBin->getLHS(),
AffineConstantExpr::get(lrhs->getValue() / rhsConst->getValue(),
context),
context);
}
}
}
return nullptr;
// TODO(someone): implement more simplification along the lines described in

View File

@ -454,9 +454,9 @@ void ModulePrinter::printAffineExprInternal(
if (rrhs->getValue() < -1) {
printAffineExprInternal(binOp->getLHS(), BindingStrength::Weak);
os << " - (";
os << " - ";
printAffineExprInternal(rhs->getLHS(), BindingStrength::Strong);
os << " * " << -rrhs->getValue() << ')';
os << " * " << -rrhs->getValue();
if (enclosingTightness == BindingStrength::Strong)
os << ')';
return;

View File

@ -7,16 +7,26 @@
#map1 = (i, j)[s0] -> (i, j)
// CHECK: #map{{[0-9]+}} = () -> (0)
// A map may have 0 inputs. However, an affine_apply always takes at least one input.
#map2 = () -> (0)
// All three maps are unique'd as one map and so there
// should be only one output.
// CHECK: #map{{[0-9]+}} = (d0, d1) -> (d0 + 1, d1)
#map3 = (i, j) -> (i+1, j)
// CHECK: #map{{[0-9]+}} = (d0, d1) -> (d0 + 1, d1 * 4 + 2)
#map3 = (i, j) -> (i+1, 4*j + 2)
// CHECK-EMPTY
#map3a = (i, j) -> (1+i, j)
#map3a = (i, j) -> (1+i, 4*j + 2)
// CHECK-EMPTY
#map3b = (i, j) -> (2+3-2*2+i, j)
#map3b = (i, j) -> (2 + 3 - 2*2 + i, 4*j + 2)
#map3c = (i, j) -> (i +1 + 0, 4*j + 2)
#map3d = (i, j) -> (i + 3 + 2 - 4, 4*j + 2)
#map3e = (i, j) -> (1*i+3*2-2*2-1, 4*j + 2)
#map3f = (i, j) -> (i + 1, 4*j*1 + 2)
#map3g = (i, j) -> (i + 1, 2*2*j + 2)
#map3h = (i, j) -> (i + 1, 2*j*2 + 2)
#map3i = (i, j) -> (i + 1, j*2*2 + 2)
#map3j = (i, j) -> (i + 1, j*1*4 + 2)
#map3k = (i, j) -> (i + 1, j*4*1 + 2)
// CHECK: #map{{[0-9]+}} = (d0, d1) -> (d0 + 2, d1)
#map4 = (i, j) -> (3+3-2*2+i, j)
@ -30,7 +40,7 @@
// CHECK: #map{{[0-9]+}} = (d0, d1)[s0] -> (d0 + d1 + s0, d1)
#map7 = (i, j)[s0] -> (i + j + s0, j)
// CHECK: #map{{[0-9]+}} = (d0, d1)[s0] -> (d0 + 5 + d1 + s0, d1)
// CHECK: #map{{[0-9]+}} = (d0, d1)[s0] -> (d0 + d1 + s0 + 5, d1)
#map8 = (i, j)[s0] -> (5 + i + j + s0, j)
// CHECK: #map{{[0-9]+}} = (d0, d1)[s0] -> (d0 + d1 + 5, d1)
@ -42,7 +52,7 @@
// CHECK: #map{{[0-9]+}} = (d0, d1)[s0] -> (d0 * 2, d1 * 3)
#map11 = (i, j)[s0] -> (2*i, 3*j)
// CHECK: #map{{[0-9]+}} = (d0, d1)[s0] -> (d0 + 12 + (d1 + s0 * 3) * 5, d1)
// CHECK: #map{{[0-9]+}} = (d0, d1)[s0] -> (d0 + (d1 + s0 * 3) * 5 + 12, d1)
#map12 = (i, j)[s0] -> (i + 2*6 + 5*(j+s0*3), j)
// CHECK: #map{{[0-9]+}} = (d0, d1)[s0] -> (d0 * 5 + d1, d1)
@ -51,8 +61,8 @@
// CHECK: #map{{[0-9]+}} = (d0, d1)[s0] -> (d0 + d1, d1)
#map14 = (i, j)[s0] -> ((i + j), (j))
// CHECK: #map{{[0-9]+}} = (d0, d1)[s0] -> (d0 + d1 + 5, d1 + 3)
#map15 = (i, j)[s0] -> ((i + j)+5, (j)+3)
// CHECK: #map{{[0-9]+}} = (d0, d1)[s0] -> (d0 + d1 + 7, d1 + 3)
#map15 = (i, j)[s0] -> ((i + j + 2) + 5, (j)+3)
// CHECK: #map{{[0-9]+}} = (d0, d1)[s0] -> (d0, 0)
#map16 = (i, j)[s1] -> (i, 0)
@ -66,7 +76,7 @@
// CHECK: #map{{[0-9]+}} = (d0, d1) -> (d0, d0 + d1 * 3)
#map20 = (i, j) -> (i, i + 3*j)
// CHECK: #map{{[0-9]+}} = (d0, d1)[s0] -> (d0, d0 * ((s0 * s0) * 9) + 2 + 1)
// CHECK: #map{{[0-9]+}} = (d0, d1)[s0] -> (d0, d0 * ((s0 * s0) * 9) + 3)
#map18 = (i, j)[N] -> (i, 2 + N*N*9*i + 1)
// CHECK: #map{{[0-9]+}} = (d0, d1) -> (1, d0 + d1 * 3 + 5)
@ -105,9 +115,9 @@
// CHECK: #map{{[0-9]+}} = (d0, d1, d2)[s0, s1, s2] -> ((d0 * s1) * s2 + d1 * s1 + d2)
#map35 = (i, j, k)[s0, s1, s2] -> (i*s1*s2 + j*s1 + k)
// Constant folding.
// CHECK: #map{{[0-9]+}} = (d0, d1) -> (8, 4, 1, 3, 2, 4)
#map36 = (i, j) -> (5+3, 2*2, 8-7, 100 floordiv 32, 5 mod 3, 10 ceildiv 3)
// CHECK: #map{{[0-9]+}} = (d0, d1) -> (4, 11, 512, 15)
#map37 = (i, j) -> (5 mod 3 + 2, 5*3 - 4, 128 * (500 ceildiv 128), 40 floordiv 7 * 3)
@ -123,14 +133,21 @@
// CHECK: #map{{[0-9]+}} = (d0, d1)[s0, s1] -> (d0, d1) size (s0, s1 + 10)
#map41 = (i, j)[N, M] -> (i, j) size (N, M+10)
// CHECK: #map{{[0-9]+}} = (d0, d1)[s0, s1] -> (d0, d1) size (128, s0 * 2 + 5 + s1)
// CHECK: #map{{[0-9]+}} = (d0, d1)[s0, s1] -> (d0, d1) size (128, s0 * 2 + s1 + 5)
#map42 = (i, j)[N, M] -> (i, j) size (64 + 64, 5 + 2*N + M)
// CHECK: #map{{[0-9]+}} = (d0, d1)[s0] -> ((d0 * 5) floordiv 4, (d1 ceildiv 7) mod s0)
#map43 = (i, j) [s0] -> ( i * 5 floordiv 4, j ceildiv 7 mod s0)
// CHECK: #map{{[0-9]+}} = (d0, d1) -> (d0 - d1 * 2)
#map44 = (i, j) -> (i - 2*j)
// CHECK: #map{{[0-9]+}} = (d0, d1) -> (d0 - d1 * 2, (d1 * 6) floordiv 4)
#map44 = (i, j) -> (i - 2*j, j * 6 floordiv 4)
// Simplifications
// CHECK: #map{{[0-9]+}} = (d0, d1, d2)[s0] -> (d0 + d1 + d2 + 1, d2 + d1, (d0 * s0) * 8)
#map45 = (i, j, k) [N] -> (1 + i + 3 + j - 3 + k, k + 5 + j - 5, 2*i*4*N)
// CHECK: #map{{[0-9]+}} = (d0, d1, d2) -> (0, d0 * 2, 0, d0, d0 * 4)
#map46 = (i, j, k) -> (i*0, i * 128 floordiv 64, j * 0 floordiv 64, i * 64 ceildiv 64, i * 512 ceildiv 128)
// CHECK: extfunc @f0(memref<2x4xi8, #map{{[0-9]+}}, 1>)
extfunc @f0(memref<2x4xi8, #map0, 1>)
@ -143,6 +160,28 @@ extfunc @f2(memref<2xi8, #map2, 1>)
// CHECK: extfunc @f3(memref<2x4xi8, #map{{[0-9]+}}, 1>)
extfunc @f3(memref<2x4xi8, #map3, 1>)
// CHECK: extfunc @f3a(memref<2x4xi8, #map{{[0-9]+}}, 1>)
extfunc @f3a(memref<2x4xi8, #map3a, 1>)
// CHECK: extfunc @f3b(memref<2x4xi8, #map{{[0-9]+}}, 1>)
extfunc @f3b(memref<2x4xi8, #map3b, 1>)
// CHECK: extfunc @f3c(memref<2x4xi8, #map{{[0-9]+}}, 1>)
extfunc @f3c(memref<2x4xi8, #map3c, 1>)
// CHECK: extfunc @f3d(memref<2x4xi8, #map{{[0-9]+}}, 1>)
extfunc @f3d(memref<2x4xi8, #map3d, 1>)
// CHECK: extfunc @f3e(memref<2x4xi8, #map{{[0-9]+}}, 1>)
extfunc @f3e(memref<2x4xi8, #map3e, 1>)
// CHECK: extfunc @f3f(memref<2x4xi8, #map{{[0-9]+}}, 1>)
extfunc @f3f(memref<2x4xi8, #map3f, 1>)
// CHECK: extfunc @f3g(memref<2x4xi8, #map{{[0-9]+}}, 1>)
extfunc @f3g(memref<2x4xi8, #map3g, 1>)
// CHECK: extfunc @f3h(memref<2x4xi8, #map{{[0-9]+}}, 1>)
extfunc @f3h(memref<2x4xi8, #map3h, 1>)
// CHECK: extfunc @f3i(memref<2x4xi8, #map{{[0-9]+}}, 1>)
extfunc @f3i(memref<2x4xi8, #map3i, 1>)
// CHECK: extfunc @f3j(memref<2x4xi8, #map{{[0-9]+}}, 1>)
extfunc @f3j(memref<2x4xi8, #map3j, 1>)
// CHECK: extfunc @f3k(memref<2x4xi8, #map{{[0-9]+}}, 1>)
extfunc @f3k(memref<2x4xi8, #map3k, 1>)
// CHECK: extfunc @f4(memref<2x4xi8, #map{{[0-9]+}}, 1>)
extfunc @f4(memref<2x4xi8, #map4, 1>)
@ -253,11 +292,13 @@ extfunc @f41(memref<2x4xi8, #map41, 1>)
extfunc @f42(memref<2x4xi8, #map42, 1>)
// CHECK: extfunc @f43(memref<2x4xi8, #map{{[0-9]+}}>)
extfunc @f43(memref<2x4xi8, #map42>)
extfunc @f43(memref<2x4xi8, #map43>)
// CHECK: extfunc @f44(memref<2x4xi8, #map{{[0-9]+}}>)
extfunc @f44(memref<2x4xi8, #map43>)
extfunc @f44(memref<2x4xi8, #map44>)
// CHECK: extfunc @f45(memref<2xi8, #map{{[0-9]+}}>)
extfunc @f45(memref<2xi8, #map44>)
// CHECK: extfunc @f45(memref<100x100x100xi8, #map{{[0-9]+}}>)
extfunc @f45(memref<100x100x100xi8, #map45>)
// CHECK: extfunc @f45(memref<100x100x100xi8, #map{{[0-9]+}}>)
extfunc @f45(memref<100x100x100xi8, #map46>)