diff --git a/mlir/lib/Interfaces/Utils/InferIntRangeCommon.cpp b/mlir/lib/Interfaces/Utils/InferIntRangeCommon.cpp index 0b085b10b2b3..2c1276d577a5 100644 --- a/mlir/lib/Interfaces/Utils/InferIntRangeCommon.cpp +++ b/mlir/lib/Interfaces/Utils/InferIntRangeCommon.cpp @@ -558,6 +558,20 @@ mlir::intrange::inferOr(ArrayRef argRanges) { ConstantIntRanges mlir::intrange::inferXor(ArrayRef argRanges) { + // TODO: The code below doesn't work for bitwidths > i1. + // For input ranges lhs=[2060639849, 2060639850], rhs=[2060639849, 2060639849] + // widenBitwiseBounds will produce: + // lhs: + // 2060639848 01111010110100101101111001101000 + // 2060639851 01111010110100101101111001101011 + // rhs: + // 2060639849 01111010110100101101111001101001 + // 2060639849 01111010110100101101111001101001 + // None of those combinations xor to 0, while intermediate values does. + unsigned width = argRanges[0].umin().getBitWidth(); + if (width > 1) + return ConstantIntRanges::maxRange(width); + auto [lhsZeros, lhsOnes] = widenBitwiseBounds(argRanges[0]); auto [rhsZeros, rhsOnes] = widenBitwiseBounds(argRanges[1]); auto xori = [](const APInt &a, const APInt &b) -> std::optional { diff --git a/mlir/test/Dialect/Arith/int-range-interface.mlir b/mlir/test/Dialect/Arith/int-range-interface.mlir index afb0b4929bce..4db846fa4656 100644 --- a/mlir/test/Dialect/Arith/int-range-interface.mlir +++ b/mlir/test/Dialect/Arith/int-range-interface.mlir @@ -454,9 +454,35 @@ func.func @ori(%arg0 : i128, %arg1 : i128) -> i1 { func.return %2 : i1 } +// CHECK-LABEL: func @xori_issue_82168 +// arith.cmpi was erroneously folded to %false, see Issue #82168. +// CHECK: %[[R:.*]] = arith.cmpi eq, %{{.*}}, %{{.*}} : i64 +// CHECK: return %[[R]] +func.func @xori_issue_82168() -> i1 { + %c0_i64 = arith.constant 0 : i64 + %c2060639849_i64 = arith.constant 2060639849 : i64 + %2 = test.with_bounds { umin = 2060639849 : i64, umax = 2060639850 : i64, smin = 2060639849 : i64, smax = 2060639850 : i64 } : i64 + %3 = arith.xori %2, %c2060639849_i64 : i64 + %4 = arith.cmpi eq, %3, %c0_i64 : i64 + func.return %4 : i1 +} + +// CHECK-LABEL: func @xori_i1 +// CHECK-DAG: %[[true:.*]] = arith.constant true +// CHECK-DAG: %[[false:.*]] = arith.constant false +// CHECK: return %[[true]], %[[false]] +func.func @xori_i1() -> (i1, i1) { + %true = arith.constant true + %1 = test.with_bounds { umin = 0 : i1, umax = 0 : i1, smin = 0 : i1, smax = 0 : i1 } : i1 + %2 = test.with_bounds { umin = 1 : i1, umax = 1 : i1, smin = 1 : i1, smax = 1 : i1 } : i1 + %3 = arith.xori %1, %true : i1 + %4 = arith.xori %2, %true : i1 + func.return %3, %4 : i1, i1 +} + // CHECK-LABEL: func @xori -// CHECK: %[[false:.*]] = arith.constant false -// CHECK: return %[[false]] +// TODO: xor folding is temporarily disabled +// CHECK-NOT: arith.constant false func.func @xori(%arg0 : i64, %arg1 : i64) -> i1 { %c0 = arith.constant 0 : i64 %c7 = arith.constant 7 : i64