diff --git a/llvm/include/llvm/IR/ConstantRange.h b/llvm/include/llvm/IR/ConstantRange.h index 4ec20403c5ab..fea4d0da1d0d 100644 --- a/llvm/include/llvm/IR/ConstantRange.h +++ b/llvm/include/llvm/IR/ConstantRange.h @@ -332,6 +332,14 @@ public: ConstantRange unionWith(const ConstantRange &CR, PreferredRangeType Type = Smallest) const; + /// Intersect the two ranges and return the result if it can be represented + /// exactly, otherwise return None. + Optional exactIntersectWith(const ConstantRange &CR) const; + + /// Union the two ranges and return the result if it can be represented + /// exactly, otherwise return None. + Optional exactUnionWith(const ConstantRange &CR) const; + /// Return a new range representing the possible values resulting /// from an application of the specified cast operator to this range. \p /// BitWidth is the target bitwidth of the cast. For casts which don't diff --git a/llvm/lib/IR/ConstantRange.cpp b/llvm/lib/IR/ConstantRange.cpp index bbb07bfe3172..a0f2179bddb4 100644 --- a/llvm/lib/IR/ConstantRange.cpp +++ b/llvm/lib/IR/ConstantRange.cpp @@ -681,6 +681,24 @@ ConstantRange ConstantRange::unionWith(const ConstantRange &CR, return ConstantRange(std::move(L), std::move(U)); } +Optional +ConstantRange::exactIntersectWith(const ConstantRange &CR) const { + // TODO: This can be implemented more efficiently. + ConstantRange Result = intersectWith(CR); + if (Result == inverse().unionWith(CR.inverse()).inverse()) + return Result; + return None; +} + +Optional +ConstantRange::exactUnionWith(const ConstantRange &CR) const { + // TODO: This can be implemented more efficiently. + ConstantRange Result = unionWith(CR); + if (Result == inverse().intersectWith(CR.inverse()).inverse()) + return Result; + return None; +} + ConstantRange ConstantRange::castOp(Instruction::CastOps CastOp, uint32_t ResultBitWidth) const { switch (CastOp) { diff --git a/llvm/lib/Transforms/Scalar/GuardWidening.cpp b/llvm/lib/Transforms/Scalar/GuardWidening.cpp index b1f393765cb9..82b81003ef21 100644 --- a/llvm/lib/Transforms/Scalar/GuardWidening.cpp +++ b/llvm/lib/Transforms/Scalar/GuardWidening.cpp @@ -518,27 +518,20 @@ bool GuardWideningImpl::widenCondCommon(Value *Cond0, Value *Cond1, ConstantRange CR1 = ConstantRange::makeExactICmpRegion(Pred1, RHS1->getValue()); - // SubsetIntersect is a subset of the actual mathematical intersection of - // CR0 and CR1, while SupersetIntersect is a superset of the actual - // mathematical intersection. If these two ConstantRanges are equal, then - // we know we were able to represent the actual mathematical intersection - // of CR0 and CR1, and can use the same to generate an icmp instruction. - // // Given what we're doing here and the semantics of guards, it would - // actually be correct to just use SubsetIntersect, but that may be too + // be correct to use a subset intersection, but that may be too // aggressive in cases we care about. - auto SubsetIntersect = CR0.inverse().unionWith(CR1.inverse()).inverse(); - auto SupersetIntersect = CR0.intersectWith(CR1); - - APInt NewRHSAP; - CmpInst::Predicate Pred; - if (SubsetIntersect == SupersetIntersect && - SubsetIntersect.getEquivalentICmp(Pred, NewRHSAP)) { - if (InsertPt) { - ConstantInt *NewRHS = ConstantInt::get(Cond0->getContext(), NewRHSAP); - Result = new ICmpInst(InsertPt, Pred, LHS, NewRHS, "wide.chk"); + if (Optional Intersect = CR0.exactIntersectWith(CR1)) { + APInt NewRHSAP; + CmpInst::Predicate Pred; + if (Intersect->getEquivalentICmp(Pred, NewRHSAP)) { + if (InsertPt) { + ConstantInt *NewRHS = + ConstantInt::get(Cond0->getContext(), NewRHSAP); + Result = new ICmpInst(InsertPt, Pred, LHS, NewRHS, "wide.chk"); + } + return true; } - return true; } } } diff --git a/llvm/unittests/IR/ConstantRangeTest.cpp b/llvm/unittests/IR/ConstantRangeTest.cpp index 2de1ae73d9df..17cc29b0b268 100644 --- a/llvm/unittests/IR/ConstantRangeTest.cpp +++ b/llvm/unittests/IR/ConstantRangeTest.cpp @@ -558,8 +558,8 @@ TEST_F(ConstantRangeTest, IntersectWith) { EXPECT_EQ(LHS.intersectWith(RHS), ConstantRange(APInt(32, 15), APInt(32, 0))); } -template -void testBinarySetOperationExhaustive(Fn1 OpFn, Fn2 InResultFn) { +template +void testBinarySetOperationExhaustive(Fn1 OpFn, Fn2 ExactOpFn, Fn3 InResultFn) { unsigned Bits = 4; EnumerateTwoConstantRanges(Bits, [=](const ConstantRange &CR1, const ConstantRange &CR2) { @@ -577,6 +577,13 @@ void testBinarySetOperationExhaustive(Fn1 OpFn, Fn2 InResultFn) { ConstantRange SignedCR = OpFn(CR1, CR2, ConstantRange::Signed); TestRange(SignedCR, Elems, PreferSmallestNonFullSigned, {CR1, CR2}); + + Optional ExactCR = ExactOpFn(CR1, CR2); + if (SmallestCR.isSizeLargerThan(Elems.count())) { + EXPECT_TRUE(!ExactCR.hasValue()); + } else { + EXPECT_EQ(SmallestCR, *ExactCR); + } }); } @@ -586,6 +593,9 @@ TEST_F(ConstantRangeTest, IntersectWithExhaustive) { ConstantRange::PreferredRangeType Type) { return CR1.intersectWith(CR2, Type); }, + [](const ConstantRange &CR1, const ConstantRange &CR2) { + return CR1.exactIntersectWith(CR2); + }, [](const ConstantRange &CR1, const ConstantRange &CR2, const APInt &N) { return CR1.contains(N) && CR2.contains(N); }); @@ -597,6 +607,9 @@ TEST_F(ConstantRangeTest, UnionWithExhaustive) { ConstantRange::PreferredRangeType Type) { return CR1.unionWith(CR2, Type); }, + [](const ConstantRange &CR1, const ConstantRange &CR2) { + return CR1.exactUnionWith(CR2); + }, [](const ConstantRange &CR1, const ConstantRange &CR2, const APInt &N) { return CR1.contains(N) || CR2.contains(N); });