[mlir][intrange] Fix arith.shl inference in case of overflow (#91737)

When an overflow happens during shift left, i.e. the last sign bit or
the most significant data bit gets shifted out, the current approach of
inferring the range of results does not work anymore.

This patch checks for possible overflow and returns the max range in
that case.

Fix https://github.com/llvm/llvm-project/issues/82158
This commit is contained in:
Felix Schneider 2024-05-13 19:27:38 +02:00 committed by GitHub
parent cf40c93b5b
commit 0f7906645d
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 46 additions and 2 deletions

View File

@ -544,15 +544,30 @@ mlir::intrange::inferXor(ArrayRef<ConstantIntRanges> argRanges) {
ConstantIntRanges
mlir::intrange::inferShl(ArrayRef<ConstantIntRanges> argRanges) {
const ConstantIntRanges &lhs = argRanges[0], &rhs = argRanges[1];
const APInt &lhsSMin = lhs.smin(), &lhsSMax = lhs.smax(),
&lhsUMax = lhs.umax(), &rhsUMin = rhs.umin(),
&rhsUMax = rhs.umax();
ConstArithFn shl = [](const APInt &l,
const APInt &r) -> std::optional<APInt> {
return r.uge(r.getBitWidth()) ? std::optional<APInt>() : l.shl(r);
};
// The minMax inference does not work when there is danger of overflow. In the
// signed case, this leads to the obvious problem that the sign bit might
// change. In the unsigned case, it also leads to problems because the largest
// LHS shifted by the largest RHS does not necessarily result in the largest
// result anymore.
assert(rhsUMax.isNonNegative() && "Unexpected negative shift count");
if (rhsUMax.uge(lhsSMin.getNumSignBits()) ||
rhsUMax.uge(lhsSMax.getNumSignBits()))
return ConstantIntRanges::maxRange(lhsUMax.getBitWidth());
ConstantIntRanges urange =
minMaxBy(shl, {lhs.umin(), lhs.umax()}, {rhs.umin(), rhs.umax()},
minMaxBy(shl, {lhs.umin(), lhsUMax}, {rhsUMin, rhsUMax},
/*isSigned=*/false);
ConstantIntRanges srange =
minMaxBy(shl, {lhs.smin(), lhs.smax()}, {rhs.umin(), rhs.umax()},
minMaxBy(shl, {lhsSMin, lhsSMax}, {rhsUMin, rhsUMax},
/*isSigned=*/true);
return urange.intersection(srange);
}

View File

@ -71,3 +71,32 @@ func.func @test() -> i1 {
%1 = arith.cmpi sle, %0, %cst1 : index
return %1: i1
}
// -----
// CHECK-LABEL: func @test
// CHECK: test.reflect_bounds {smax = 24 : index, smin = 0 : index, umax = 24 : index, umin = 0 : index}
func.func @test() -> index {
%cst1 = arith.constant 1 : i8
%0 = test.with_bounds { umin = 0 : index, umax = 12 : index, smin = 0 : index, smax = 12 : index }
%i8val = arith.index_cast %0 : index to i8
%shifted = arith.shli %i8val, %cst1 : i8
%si = arith.index_cast %shifted : i8 to index
%1 = test.reflect_bounds %si
return %1: index
}
// -----
// CHECK-LABEL: func @test
// CHECK: test.reflect_bounds {smax = 127 : index, smin = -128 : index, umax = -1 : index, umin = 0 : index}
func.func @test() -> index {
%cst1 = arith.constant 1 : i8
%0 = test.with_bounds { umin = 0 : index, umax = 127 : index, smin = 0 : index, smax = 127 : index }
%i8val = arith.index_cast %0 : index to i8
%shifted = arith.shli %i8val, %cst1 : i8
%si = arith.index_cast %shifted : i8 to index
%1 = test.reflect_bounds %si
return %1: index
}