[AArch64][LV][SLP] Vectorizers use call cost for vectorized frem (#82488)

getArithmeticInstrCost is used by both LoopVectorizer and SLPVectorizer
to compute the cost of frem, which becomes a call cost on AArch64 when
TLI has a vector library function.

Add tests that do SLP vectorization for code that contains 2x double and
4x float frem instructions.
This commit is contained in:
Paschalis Mpeis 2024-03-14 17:20:29 +00:00 committed by GitHub
parent 611c62b30d
commit f795d1a8b1
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
5 changed files with 79 additions and 20 deletions

View File

@ -1247,13 +1247,16 @@ public:
/// cases or optimizations based on those values.
/// \p CxtI is the optional original context instruction, if one exists, to
/// provide even more information.
/// \p TLibInfo is used to search for platform specific vector library
/// functions for instructions that might be converted to calls (e.g. frem).
InstructionCost getArithmeticInstrCost(
unsigned Opcode, Type *Ty,
TTI::TargetCostKind CostKind = TTI::TCK_RecipThroughput,
TTI::OperandValueInfo Opd1Info = {TTI::OK_AnyValue, TTI::OP_None},
TTI::OperandValueInfo Opd2Info = {TTI::OK_AnyValue, TTI::OP_None},
ArrayRef<const Value *> Args = ArrayRef<const Value *>(),
const Instruction *CxtI = nullptr) const;
const Instruction *CxtI = nullptr,
const TargetLibraryInfo *TLibInfo = nullptr) const;
/// Returns the cost estimation for alternating opcode pattern that can be
/// lowered to a single instruction on the target. In X86 this is for the

View File

@ -9,6 +9,7 @@
#include "llvm/Analysis/TargetTransformInfo.h"
#include "llvm/Analysis/CFG.h"
#include "llvm/Analysis/LoopIterator.h"
#include "llvm/Analysis/TargetLibraryInfo.h"
#include "llvm/Analysis/TargetTransformInfoImpl.h"
#include "llvm/IR/CFG.h"
#include "llvm/IR/Dominators.h"
@ -874,7 +875,22 @@ TargetTransformInfo::getOperandInfo(const Value *V) {
InstructionCost TargetTransformInfo::getArithmeticInstrCost(
unsigned Opcode, Type *Ty, TTI::TargetCostKind CostKind,
OperandValueInfo Op1Info, OperandValueInfo Op2Info,
ArrayRef<const Value *> Args, const Instruction *CxtI) const {
ArrayRef<const Value *> Args, const Instruction *CxtI,
const TargetLibraryInfo *TLibInfo) const {
// Use call cost for frem instructions that have platform-specific vector math
// functions, as those will be replaced with calls later by SelectionDAG or
// the ReplaceWithVecLib pass.
if (TLibInfo && Opcode == Instruction::FRem) {
VectorType *VecTy = dyn_cast<VectorType>(Ty);
LibFunc Func;
if (VecTy &&
TLibInfo->getLibFunc(Instruction::FRem, Ty->getScalarType(), Func) &&
TLibInfo->isFunctionVectorizable(TLibInfo->getName(Func),
VecTy->getElementCount()))
return getCallInstrCost(nullptr, VecTy, {VecTy, VecTy}, CostKind);
}
InstructionCost Cost =
TTIImpl->getArithmeticInstrCost(Opcode, Ty, CostKind,
Op1Info, Op2Info,

View File

@ -6911,25 +6911,10 @@ LoopVectorizationCostModel::getInstructionCost(Instruction *I, ElementCount VF,
Op2Info.Kind = TargetTransformInfo::OK_UniformValue;
SmallVector<const Value *, 4> Operands(I->operand_values());
auto InstrCost = TTI.getArithmeticInstrCost(
return TTI.getArithmeticInstrCost(
I->getOpcode(), VectorTy, CostKind,
{TargetTransformInfo::OK_AnyValue, TargetTransformInfo::OP_None},
Op2Info, Operands, I);
// Some targets can replace frem with vector library calls.
InstructionCost VecCallCost = InstructionCost::getInvalid();
if (I->getOpcode() == Instruction::FRem) {
LibFunc Func;
if (TLI->getLibFunc(I->getOpcode(), I->getType(), Func) &&
TLI->isFunctionVectorizable(TLI->getName(Func), VF)) {
SmallVector<Type *, 4> OpTypes;
for (auto &Op : I->operands())
OpTypes.push_back(Op->getType());
VecCallCost =
TTI.getCallInstrCost(nullptr, VectorTy, OpTypes, CostKind);
}
}
return std::min(InstrCost, VecCallCost);
Op2Info, Operands, I, TLI);
}
case Instruction::FNeg: {
return TTI.getArithmeticInstrCost(

View File

@ -8902,7 +8902,7 @@ BoUpSLP::getEntryCost(const TreeEntry *E, ArrayRef<Value *> VectorizedVals,
TTI::OperandValueInfo Op1Info = getOperandInfo(E->getOperand(0));
TTI::OperandValueInfo Op2Info = getOperandInfo(E->getOperand(OpIdx));
return TTI->getArithmeticInstrCost(ShuffleOrOp, VecTy, CostKind, Op1Info,
Op2Info) +
Op2Info, std::nullopt, nullptr, TLI) +
CommonCost;
};
return GetCostDiff(GetScalarCost, GetVectorCost);

View File

@ -0,0 +1,55 @@
; NOTE: Assertions have been autogenerated by utils/update_test_checks.py UTC_ARGS: --version 4
; RUN: opt < %s -S -mtriple=aarch64 -vector-library=ArmPL -passes=slp-vectorizer | FileCheck %s
@a = common global ptr null, align 8
; SLP test case with two scalar `frem double` operations on elements 0 and 1
; of @a. Both operands of each frem are loaded from the same element, and the
; results are stored back to those same elements. The autogenerated assertions
; below expect the scalar code to become two <2 x double> loads, a single
; <2 x double> frem, and one <2 x double> store.
define void @frem_v2double() {
; CHECK-LABEL: define void @frem_v2double() {
; CHECK-NEXT: entry:
; CHECK-NEXT: [[TMP0:%.*]] = load <2 x double>, ptr @a, align 8
; CHECK-NEXT: [[TMP1:%.*]] = load <2 x double>, ptr @a, align 8
; CHECK-NEXT: [[TMP2:%.*]] = frem <2 x double> [[TMP0]], [[TMP1]]
; CHECK-NEXT: store <2 x double> [[TMP2]], ptr @a, align 8
; CHECK-NEXT: ret void
;
entry:
; Left-hand operands: elements 0 and 1 of @a.
%a0 = load double, ptr getelementptr inbounds (double, ptr @a, i64 0), align 8
%a1 = load double, ptr getelementptr inbounds (double, ptr @a, i64 1), align 8
; Right-hand operands: the same two elements, loaded again.
%b0 = load double, ptr getelementptr inbounds (double, ptr @a, i64 0), align 8
%b1 = load double, ptr getelementptr inbounds (double, ptr @a, i64 1), align 8
; One scalar frem per lane; these are the candidates for vectorization.
%r0 = frem double %a0, %b0
%r1 = frem double %a1, %b1
; Results are written back to consecutive elements of @a.
store double %r0, ptr getelementptr inbounds (double, ptr @a, i64 0), align 8
store double %r1, ptr getelementptr inbounds (double, ptr @a, i64 1), align 8
ret void
}
; SLP test case with four scalar `frem float` operations on elements 0..3 of
; @a. Both operands of each frem are loaded from the same element, and the
; results are stored back to those same elements. The autogenerated assertions
; below expect the scalar code to become two <4 x float> loads, a single
; <4 x float> frem, and one <4 x float> store.
define void @frem_v4float() {
; CHECK-LABEL: define void @frem_v4float() {
; CHECK-NEXT: entry:
; CHECK-NEXT: [[TMP0:%.*]] = load <4 x float>, ptr @a, align 8
; CHECK-NEXT: [[TMP1:%.*]] = load <4 x float>, ptr @a, align 8
; CHECK-NEXT: [[TMP2:%.*]] = frem <4 x float> [[TMP0]], [[TMP1]]
; CHECK-NEXT: store <4 x float> [[TMP2]], ptr @a, align 8
; CHECK-NEXT: ret void
;
entry:
; Left-hand operands: elements 0..3 of @a.
%a0 = load float, ptr getelementptr inbounds (float, ptr @a, i64 0), align 8
%a1 = load float, ptr getelementptr inbounds (float, ptr @a, i64 1), align 8
%a2 = load float, ptr getelementptr inbounds (float, ptr @a, i64 2), align 8
%a3 = load float, ptr getelementptr inbounds (float, ptr @a, i64 3), align 8
; Right-hand operands: the same four elements, loaded again.
%b0 = load float, ptr getelementptr inbounds (float, ptr @a, i64 0), align 8
%b1 = load float, ptr getelementptr inbounds (float, ptr @a, i64 1), align 8
%b2 = load float, ptr getelementptr inbounds (float, ptr @a, i64 2), align 8
%b3 = load float, ptr getelementptr inbounds (float, ptr @a, i64 3), align 8
; One scalar frem per lane; these are the candidates for vectorization.
%r0 = frem float %a0, %b0
%r1 = frem float %a1, %b1
%r2 = frem float %a2, %b2
%r3 = frem float %a3, %b3
; Results are written back to consecutive elements of @a.
store float %r0, ptr getelementptr inbounds (float, ptr @a, i64 0), align 8
store float %r1, ptr getelementptr inbounds (float, ptr @a, i64 1), align 8
store float %r2, ptr getelementptr inbounds (float, ptr @a, i64 2), align 8
store float %r3, ptr getelementptr inbounds (float, ptr @a, i64 3), align 8
ret void
}