[VPlan] Implement VPWidenSelectRecipe::computeCost.

Implement VPlan-based cost computation for VPWidenSelectRecipe.
This commit is contained in:
Florian Hahn 2024-10-22 03:10:04 +01:00
parent b5bcdb5cfa
commit 1d9b3222f3
No known key found for this signature in database
GPG Key ID: 9E54DEA47A8F4434
4 changed files with 90 additions and 6 deletions

View File

@ -1701,3 +1701,11 @@ void LoopVectorizationPlanner::printPlans(raw_ostream &O) {
Plan->print(O);
}
#endif
TargetTransformInfo::OperandValueInfo
VPCostContext::getOperandInfo(VPValue *V) const {
if (!V->isLiveIn())
return {};
return TTI::getOperandInfo(V->getLiveInIRValue());
}

View File

@ -38,6 +38,7 @@
#include "llvm/Analysis/DomTreeUpdater.h"
#include "llvm/Analysis/IVDescriptors.h"
#include "llvm/Analysis/LoopInfo.h"
#include "llvm/Analysis/TargetTransformInfo.h"
#include "llvm/Analysis/VectorUtils.h"
#include "llvm/IR/DebugLoc.h"
#include "llvm/IR/FMF.h"
@ -738,6 +739,9 @@ struct VPCostContext {
/// Return true if the cost for \p UI shouldn't be computed, e.g. because it
/// has already been pre-computed.
bool skipCostComputation(Instruction *UI, bool IsVector) const;
/// Returns the OperandInfo for \p V, if it is a live-in.
TargetTransformInfo::OperandValueInfo getOperandInfo(VPValue *V) const;
};
/// VPRecipeBase is a base class modeling a sequence of one or more output IR
@ -1844,6 +1848,10 @@ struct VPWidenSelectRecipe : public VPSingleDefRecipe {
/// Produce a widened version of the select instruction.
void execute(VPTransformState &State) override;
/// Return the cost of this VPWidenSelectRecipe.
InstructionCost computeCost(ElementCount VF,
VPCostContext &Ctx) const override;
#if !defined(NDEBUG) || defined(LLVM_ENABLE_DUMP)
/// Print the recipe.
void print(raw_ostream &O, const Twine &Indent,

View File

@ -75,8 +75,8 @@ template <unsigned BitWidth = 0> struct specific_intval {
if (!CI)
return false;
assert((BitWidth == 0 || CI->getBitWidth() == BitWidth) &&
"Trying the match constant with unexpected bitwidth.");
if (BitWidth != 0 && CI->getBitWidth() != BitWidth)
return false;
return APInt::isSameValue(CI->getValue(), Val);
}
};
@ -87,6 +87,8 @@ inline specific_intval<0> m_SpecificInt(uint64_t V) {
inline specific_intval<1> m_False() { return specific_intval<1>(APInt(64, 0)); }
inline specific_intval<1> m_True() { return specific_intval<1>(APInt(64, 1)); }
/// Matching combinators
template <typename LTy, typename RTy> struct match_combine_or {
LTy L;
@ -122,7 +124,8 @@ struct MatchRecipeAndOpcode<Opcode, RecipeTy> {
auto *DefR = dyn_cast<RecipeTy>(R);
// Check for recipes that do not have opcodes.
if constexpr (std::is_same<RecipeTy, VPScalarIVStepsRecipe>::value ||
std::is_same<RecipeTy, VPCanonicalIVPHIRecipe>::value)
std::is_same<RecipeTy, VPCanonicalIVPHIRecipe>::value ||
std::is_same<RecipeTy, VPWidenSelectRecipe>::value)
return DefR;
else
return DefR && DefR->getOpcode() == Opcode;
@ -322,10 +325,34 @@ m_c_BinaryOr(const Op0_t &Op0, const Op1_t &Op1) {
return m_BinaryOr<Op0_t, Op1_t, /*Commutative*/ true>(Op0, Op1);
}
template <typename Op0_t, typename Op1_t, typename Op2_t, unsigned Opcode>
using AllTernaryRecipe_match =
Recipe_match<std::tuple<Op0_t, Op1_t, Op2_t>, Opcode, false,
VPReplicateRecipe, VPInstruction, VPWidenSelectRecipe>;
template <typename Op0_t, typename Op1_t, typename Op2_t>
inline AllTernaryRecipe_match<Op0_t, Op1_t, Op2_t, Instruction::Select>
m_Select(const Op0_t &Op0, const Op1_t &Op1, const Op2_t &Op2) {
return AllTernaryRecipe_match<Op0_t, Op1_t, Op2_t, Instruction::Select>(
{Op0, Op1, Op2});
}
template <typename Op0_t, typename Op1_t>
inline BinaryVPInstruction_match<Op0_t, Op1_t, VPInstruction::LogicalAnd>
inline match_combine_or<
BinaryVPInstruction_match<Op0_t, Op1_t, VPInstruction::LogicalAnd>,
AllTernaryRecipe_match<Op0_t, Op1_t, specific_intval<1>,
Instruction::Select>>
m_LogicalAnd(const Op0_t &Op0, const Op1_t &Op1) {
return m_VPInstruction<VPInstruction::LogicalAnd, Op0_t, Op1_t>(Op0, Op1);
return m_CombineOr(
m_VPInstruction<VPInstruction::LogicalAnd, Op0_t, Op1_t>(Op0, Op1),
m_Select(Op0, Op1, m_False()));
}
template <typename Op0_t, typename Op1_t>
inline AllTernaryRecipe_match<Op0_t, specific_intval<1>, Op1_t,
Instruction::Select>
m_LogicalOr(const Op0_t &Op0, const Op1_t &Op1) {
return m_Select(Op0, m_True(), Op1);
}
using VPCanonicalIVPHI_match =
@ -344,7 +371,6 @@ inline VPScalarIVSteps_match<Op0_t, Op1_t> m_ScalarIVSteps(const Op0_t &Op0,
const Op1_t &Op1) {
return VPScalarIVSteps_match<Op0_t, Op1_t>(Op0, Op1);
}
} // namespace VPlanPatternMatch
} // namespace llvm

View File

@ -13,6 +13,7 @@
#include "VPlan.h"
#include "VPlanAnalysis.h"
#include "VPlanPatternMatch.h"
#include "VPlanUtils.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SmallVector.h"
@ -23,6 +24,7 @@
#include "llvm/IR/Instruction.h"
#include "llvm/IR/Instructions.h"
#include "llvm/IR/Intrinsics.h"
#include "llvm/IR/PatternMatch.h"
#include "llvm/IR/Type.h"
#include "llvm/IR/Value.h"
#include "llvm/IR/VectorBuilder.h"
@ -1200,6 +1202,46 @@ void VPWidenSelectRecipe::execute(VPTransformState &State) {
State.addMetadata(Sel, dyn_cast_or_null<Instruction>(getUnderlyingValue()));
}
InstructionCost VPWidenSelectRecipe::computeCost(ElementCount VF,
VPCostContext &Ctx) const {
SelectInst *SI = cast<SelectInst>(getUnderlyingValue());
bool ScalarCond = getOperand(0)->isDefinedOutsideLoopRegions();
Type *ScalarTy = Ctx.Types.inferScalarType(this);
Type *VectorTy = ToVectorTy(Ctx.Types.inferScalarType(this), VF);
TTI::TargetCostKind CostKind = TTI::TCK_RecipThroughput;
VPValue *Op0, *Op1;
using namespace llvm::VPlanPatternMatch;
if (!ScalarCond && ScalarTy->getScalarSizeInBits() == 1 &&
(match(this, m_LogicalAnd(m_VPValue(Op0), m_VPValue(Op1))) ||
match(this, m_LogicalOr(m_VPValue(Op0), m_VPValue(Op1))))) {
// select x, y, false --> x & y
// select x, true, y --> x | y
const auto [Op1VK, Op1VP] = Ctx.getOperandInfo(Op0);
const auto [Op2VK, Op2VP] = Ctx.getOperandInfo(Op1);
SmallVector<const Value *, 2> Operands;
if (all_of(operands(),
[](VPValue *Op) { return Op->getUnderlyingValue(); }))
Operands.append(SI->op_begin(), SI->op_end());
bool IsLogicalOr = match(this, m_LogicalOr(m_VPValue(Op0), m_VPValue(Op1)));
return Ctx.TTI.getArithmeticInstrCost(
IsLogicalOr ? Instruction::Or : Instruction::And, VectorTy, CostKind,
{Op1VK, Op1VP}, {Op2VK, Op2VP}, Operands, SI);
}
Type *CondTy = Ctx.Types.inferScalarType(getOperand(0));
if (!ScalarCond)
CondTy = VectorType::get(CondTy, VF);
CmpInst::Predicate Pred = CmpInst::BAD_ICMP_PREDICATE;
if (auto *Cmp = dyn_cast<CmpInst>(SI->getCondition()))
Pred = Cmp->getPredicate();
return Ctx.TTI.getCmpSelInstrCost(Instruction::Select, VectorTy, CondTy, Pred,
CostKind, {TTI::OK_AnyValue, TTI::OP_None},
{TTI::OK_AnyValue, TTI::OP_None}, SI);
}
VPRecipeWithIRFlags::FastMathFlagsTy::FastMathFlagsTy(
const FastMathFlags &FMF) {
AllowReassoc = FMF.allowReassoc();