[VPlan] Implement VPWidenCastRecipe::computeCost(). (NFCI) (#111339)

This patch implement `VPWidenCastRecipe::computeCost()` and skip cast
recipies in the in-loop reduction.
This commit is contained in:
Elvis Wang 2024-10-22 12:23:49 +08:00 committed by GitHub
parent a4819bd46d
commit b3edc764f7
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
4 changed files with 73 additions and 2 deletions

View File

@ -7307,12 +7307,30 @@ LoopVectorizationPlanner::precomputeCosts(VPlan &Plan, ElementCount VF,
const auto &ChainOps = RdxDesc.getReductionOpChain(RedPhi, OrigLoop);
SetVector<Instruction *> ChainOpsAndOperands(ChainOps.begin(),
ChainOps.end());
auto IsZExtOrSExt = [](const unsigned Opcode) -> bool {
return Opcode == Instruction::ZExt || Opcode == Instruction::SExt;
};
// Also include the operands of instructions in the chain, as the cost-model
// may mark extends as free.
//
// For ARM, some of the instruction can folded into the reducion
// instruction. So we need to mark all folded instructions free.
// For example: We can fold reduce(mul(ext(A), ext(B))) into one
// instruction.
for (auto *ChainOp : ChainOps) {
for (Value *Op : ChainOp->operands()) {
if (auto *I = dyn_cast<Instruction>(Op))
if (auto *I = dyn_cast<Instruction>(Op)) {
ChainOpsAndOperands.insert(I);
if (I->getOpcode() == Instruction::Mul) {
auto *Ext0 = dyn_cast<Instruction>(I->getOperand(0));
auto *Ext1 = dyn_cast<Instruction>(I->getOperand(1));
if (Ext0 && IsZExtOrSExt(Ext0->getOpcode()) && Ext1 &&
Ext0->getOpcode() == Ext1->getOpcode()) {
ChainOpsAndOperands.insert(Ext0);
ChainOpsAndOperands.insert(Ext1);
}
}
}
}
}

View File

@ -1603,6 +1603,10 @@ public:
/// Produce widened copies of the cast.
void execute(VPTransformState &State) override;
/// Return the cost of this VPWidenCastRecipe.
InstructionCost computeCost(ElementCount VF,
VPCostContext &Ctx) const override;
#if !defined(NDEBUG) || defined(LLVM_ENABLE_DUMP)
/// Print the recipe.
void print(raw_ostream &O, const Twine &Indent,

View File

@ -1522,6 +1522,55 @@ void VPWidenCastRecipe::execute(VPTransformState &State) {
State.addMetadata(Cast, cast_or_null<Instruction>(getUnderlyingValue()));
}
InstructionCost VPWidenCastRecipe::computeCost(ElementCount VF,
VPCostContext &Ctx) const {
// Computes the CastContextHint from a recipes that may access memory.
auto ComputeCCH = [&](const VPRecipeBase *R) -> TTI::CastContextHint {
if (VF.isScalar())
return TTI::CastContextHint::Normal;
if (isa<VPInterleaveRecipe>(R))
return TTI::CastContextHint::Interleave;
if (const auto *ReplicateRecipe = dyn_cast<VPReplicateRecipe>(R))
return ReplicateRecipe->isPredicated() ? TTI::CastContextHint::Masked
: TTI::CastContextHint::Normal;
const auto *WidenMemoryRecipe = dyn_cast<VPWidenMemoryRecipe>(R);
if (WidenMemoryRecipe == nullptr)
return TTI::CastContextHint::None;
if (!WidenMemoryRecipe->isConsecutive())
return TTI::CastContextHint::GatherScatter;
if (WidenMemoryRecipe->isReverse())
return TTI::CastContextHint::Reversed;
if (WidenMemoryRecipe->isMasked())
return TTI::CastContextHint::Masked;
return TTI::CastContextHint::Normal;
};
VPValue *Operand = getOperand(0);
TTI::CastContextHint CCH = TTI::CastContextHint::None;
// For Trunc/FPTrunc, get the context from the only user.
if ((Opcode == Instruction::Trunc || Opcode == Instruction::FPTrunc) &&
!hasMoreThanOneUniqueUser() && getNumUsers() > 0) {
if (auto *StoreRecipe = dyn_cast<VPRecipeBase>(*user_begin()))
CCH = ComputeCCH(StoreRecipe);
}
// For Z/Sext, get the context from the operand.
else if (Opcode == Instruction::ZExt || Opcode == Instruction::SExt ||
Opcode == Instruction::FPExt) {
if (Operand->isLiveIn())
CCH = TTI::CastContextHint::Normal;
else if (Operand->getDefiningRecipe())
CCH = ComputeCCH(Operand->getDefiningRecipe());
}
auto *SrcTy =
cast<VectorType>(ToVectorTy(Ctx.Types.inferScalarType(Operand), VF));
auto *DestTy = cast<VectorType>(ToVectorTy(getResultType(), VF));
// Arm TTI will use the underlying instruction to determine the cost.
return Ctx.TTI.getCastInstrCost(
Opcode, DestTy, SrcTy, CCH, TTI::TCK_RecipThroughput,
dyn_cast_if_present<Instruction>(getUnderlyingValue()));
}
#if !defined(NDEBUG) || defined(LLVM_ENABLE_DUMP)
void VPWidenCastRecipe::print(raw_ostream &O, const Twine &Indent,
VPSlotTracker &SlotTracker) const {

View File

@ -135,7 +135,7 @@ public:
}
/// Returns true if the value has more than one unique user.
bool hasMoreThanOneUniqueUser() {
bool hasMoreThanOneUniqueUser() const {
if (getNumUsers() == 0)
return false;