//===- llvm/unittest/Support/DivisionByConstantTest.cpp -------------------===// // // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. // See https://llvm.org/LICENSE.txt for license information. // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // //===----------------------------------------------------------------------===// #include "llvm/ADT/APInt.h" #include "llvm/Support/DivisionByConstantInfo.h" #include "gtest/gtest.h" using namespace llvm; namespace { template static void EnumerateAPInts(unsigned Bits, Fn TestFn) { APInt N(Bits, 0); do { TestFn(N); } while (++N != 0); } APInt MULHS(APInt X, APInt Y) { unsigned Bits = X.getBitWidth(); unsigned WideBits = 2 * Bits; return (X.sext(WideBits) * Y.sext(WideBits)).lshr(Bits).trunc(Bits); } APInt SignedDivideUsingMagic(APInt Numerator, APInt Divisor, SignedDivisionByConstantInfo Magics) { unsigned Bits = Numerator.getBitWidth(); APInt Factor(Bits, 0); APInt ShiftMask(Bits, -1); if (Divisor.isOne() || Divisor.isAllOnes()) { // If d is +1/-1, we just multiply the numerator by +1/-1. Factor = Divisor.getSExtValue(); Magics.Magic = 0; Magics.ShiftAmount = 0; ShiftMask = 0; } else if (Divisor.isStrictlyPositive() && Magics.Magic.isNegative()) { // If d > 0 and m < 0, add the numerator. Factor = 1; } else if (Divisor.isNegative() && Magics.Magic.isStrictlyPositive()) { // If d < 0 and m > 0, subtract the numerator. Factor = -1; } // Multiply the numerator by the magic value. APInt Q = MULHS(Numerator, Magics.Magic); // (Optionally) Add/subtract the numerator using Factor. Factor = Numerator * Factor; Q = Q + Factor; // Shift right algebraic by shift value. Q = Q.ashr(Magics.ShiftAmount); // Extract the sign bit, mask it and add it to the quotient. unsigned SignShift = Bits - 1; APInt T = Q.lshr(SignShift); T = T & ShiftMask; return Q + T; } TEST(SignedDivisionByConstantTest, Test) { for (unsigned Bits = 1; Bits <= 32; ++Bits) { if (Bits < 3) continue; // Not supported by `SignedDivisionByConstantInfo::get()`. if (Bits > 12) continue; // Unreasonably slow. EnumerateAPInts(Bits, [Bits](const APInt &Divisor) { if (Divisor.isZero()) return; // Division by zero is undefined behavior. SignedDivisionByConstantInfo Magics; if (!(Divisor.isOne() || Divisor.isAllOnes())) Magics = SignedDivisionByConstantInfo::get(Divisor); EnumerateAPInts(Bits, [Divisor, Magics, Bits](const APInt &Numerator) { if (Numerator.isMinSignedValue() && Divisor.isAllOnes()) return; // Overflow is undefined behavior. APInt NativeResult = Numerator.sdiv(Divisor); APInt MagicResult = SignedDivideUsingMagic(Numerator, Divisor, Magics); ASSERT_EQ(MagicResult, NativeResult) << " ... given the operation: srem i" << Bits << " " << Numerator << ", " << Divisor; }); }); } } APInt MULHU(APInt X, APInt Y) { unsigned Bits = X.getBitWidth(); unsigned WideBits = 2 * Bits; return (X.zext(WideBits) * Y.zext(WideBits)).lshr(Bits).trunc(Bits); } APInt UnsignedDivideUsingMagic(const APInt &Numerator, const APInt &Divisor, bool LZOptimization, bool AllowEvenDivisorOptimization, bool ForceNPQ, UnsignedDivisionByConstantInfo Magics) { assert(!Divisor.isOne() && "Division by 1 is not supported using Magic."); unsigned Bits = Numerator.getBitWidth(); if (LZOptimization) { unsigned LeadingZeros = Numerator.countl_zero(); // Clip to the number of leading zeros in the divisor. LeadingZeros = std::min(LeadingZeros, Divisor.countl_zero()); if (LeadingZeros > 0) { Magics = UnsignedDivisionByConstantInfo::get( Divisor, LeadingZeros, AllowEvenDivisorOptimization); assert(!Magics.IsAdd && "Should use cheap fixup now"); } } assert(Magics.PreShift < Divisor.getBitWidth() && "We shouldn't generate an undefined shift!"); assert(Magics.PostShift < Divisor.getBitWidth() && "We shouldn't generate an undefined shift!"); assert((!Magics.IsAdd || Magics.PreShift == 0) && "Unexpected pre-shift"); unsigned PreShift = Magics.PreShift; unsigned PostShift = Magics.PostShift; bool UseNPQ = Magics.IsAdd; APInt NPQFactor = UseNPQ ? APInt::getSignedMinValue(Bits) : APInt::getZero(Bits); APInt Q = Numerator.lshr(PreShift); // Multiply the numerator by the magic value. Q = MULHU(Q, Magics.Magic); if (UseNPQ || ForceNPQ) { APInt NPQ = Numerator - Q; // For vectors we might have a mix of non-NPQ/NPQ paths, so use // MULHU to act as a SRL-by-1 for NPQ, else multiply by zero. APInt NPQ_Scalar = NPQ.lshr(1); (void)NPQ_Scalar; NPQ = MULHU(NPQ, NPQFactor); assert(!UseNPQ || NPQ == NPQ_Scalar); Q = NPQ + Q; } Q = Q.lshr(PostShift); return Q; } TEST(UnsignedDivisionByConstantTest, Test) { for (unsigned Bits = 1; Bits <= 32; ++Bits) { if (Bits < 2) continue; // Not supported by `UnsignedDivisionByConstantInfo::get()`. if (Bits > 10) continue; // Unreasonably slow. EnumerateAPInts(Bits, [Bits](const APInt &Divisor) { if (Divisor.isZero()) return; // Division by zero is undefined behavior. if (Divisor.isOne()) return; // Division by one is the numerator. const UnsignedDivisionByConstantInfo Magics = UnsignedDivisionByConstantInfo::get(Divisor); EnumerateAPInts(Bits, [Divisor, Magics, Bits](const APInt &Numerator) { APInt NativeResult = Numerator.udiv(Divisor); for (bool LZOptimization : {true, false}) { for (bool AllowEvenDivisorOptimization : {true, false}) { for (bool ForceNPQ : {false, true}) { APInt MagicResult = UnsignedDivideUsingMagic( Numerator, Divisor, LZOptimization, AllowEvenDivisorOptimization, ForceNPQ, Magics); ASSERT_EQ(MagicResult, NativeResult) << " ... given the operation: urem i" << Bits << " " << Numerator << ", " << Divisor << " (allow LZ optimization = " << LZOptimization << ", allow even divisior optimization = " << AllowEvenDivisorOptimization << ", force NPQ = " << ForceNPQ << ")"; } } } }); }); } } } // end anonymous namespace