diff --git a/mlir/lib/Interfaces/Utils/InferIntRangeCommon.cpp b/mlir/lib/Interfaces/Utils/InferIntRangeCommon.cpp index 2b2d937d55d8..6af229cae10a 100644 --- a/mlir/lib/Interfaces/Utils/InferIntRangeCommon.cpp +++ b/mlir/lib/Interfaces/Utils/InferIntRangeCommon.cpp @@ -544,15 +544,30 @@ mlir::intrange::inferXor(ArrayRef argRanges) { ConstantIntRanges mlir::intrange::inferShl(ArrayRef argRanges) { const ConstantIntRanges &lhs = argRanges[0], &rhs = argRanges[1]; + const APInt &lhsSMin = lhs.smin(), &lhsSMax = lhs.smax(), + &lhsUMax = lhs.umax(), &rhsUMin = rhs.umin(), + &rhsUMax = rhs.umax(); + ConstArithFn shl = [](const APInt &l, const APInt &r) -> std::optional { return r.uge(r.getBitWidth()) ? std::optional() : l.shl(r); }; + + // The minMax inference does not work when there is danger of overflow. In the + // signed case, this leads to the obvious problem that the sign bit might + // change. In the unsigned case, it also leads to problems because the largest + // LHS shifted by the largest RHS does not necessarily result in the largest + // result anymore. + assert(rhsUMax.isNonNegative() && "Unexpected negative shift count"); + if (rhsUMax.uge(lhsSMin.getNumSignBits()) || + rhsUMax.uge(lhsSMax.getNumSignBits())) + return ConstantIntRanges::maxRange(lhsUMax.getBitWidth()); + ConstantIntRanges urange = - minMaxBy(shl, {lhs.umin(), lhs.umax()}, {rhs.umin(), rhs.umax()}, + minMaxBy(shl, {lhs.umin(), lhsUMax}, {rhsUMin, rhsUMax}, /*isSigned=*/false); ConstantIntRanges srange = - minMaxBy(shl, {lhs.smin(), lhs.smax()}, {rhs.umin(), rhs.umax()}, + minMaxBy(shl, {lhsSMin, lhsSMax}, {rhsUMin, rhsUMax}, /*isSigned=*/true); return urange.intersection(srange); } diff --git a/mlir/test/Dialect/Arith/int-range-opts.mlir b/mlir/test/Dialect/Arith/int-range-opts.mlir index be0a7e8ccd70..4c3c0854ed02 100644 --- a/mlir/test/Dialect/Arith/int-range-opts.mlir +++ b/mlir/test/Dialect/Arith/int-range-opts.mlir @@ -71,3 +71,32 @@ func.func @test() -> i1 { %1 = arith.cmpi sle, %0, %cst1 : index return %1: i1 } + +// ----- + +// CHECK-LABEL: func @test +// CHECK: test.reflect_bounds {smax = 24 : index, smin = 0 : index, umax = 24 : index, umin = 0 : index} +func.func @test() -> index { + %cst1 = arith.constant 1 : i8 + %0 = test.with_bounds { umin = 0 : index, umax = 12 : index, smin = 0 : index, smax = 12 : index } + %i8val = arith.index_cast %0 : index to i8 + %shifted = arith.shli %i8val, %cst1 : i8 + %si = arith.index_cast %shifted : i8 to index + %1 = test.reflect_bounds %si + return %1: index +} + +// ----- + +// CHECK-LABEL: func @test +// CHECK: test.reflect_bounds {smax = 127 : index, smin = -128 : index, umax = -1 : index, umin = 0 : index} +func.func @test() -> index { + %cst1 = arith.constant 1 : i8 + %0 = test.with_bounds { umin = 0 : index, umax = 127 : index, smin = 0 : index, smax = 127 : index } + %i8val = arith.index_cast %0 : index to i8 + %shifted = arith.shli %i8val, %cst1 : i8 + %si = arith.index_cast %shifted : i8 to index + %1 = test.reflect_bounds %si + return %1: index +} +