[LV] Split RecurrenceDescriptor into RecurKind + FastMathFlags in LoopUtils. NFC (#132014)

Split off from #131300, this splits up RecurrenceDescriptor arguments so
that arbitrary recurrence kinds may be used down the line.
This commit is contained in:
Luke Lau 2025-03-19 23:56:57 +09:00 committed by GitHub
parent 3bba268013
commit f536f71580
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
3 changed files with 14 additions and 21 deletions

View File

@ -411,8 +411,8 @@ Value *createSimpleReduction(IRBuilderBase &B, Value *Src,
RecurKind RdxKind);
/// Overloaded function to generate vector-predication intrinsics for
/// reduction.
Value *createSimpleReduction(VectorBuilder &VB, Value *Src,
const RecurrenceDescriptor &Desc);
Value *createSimpleReduction(VectorBuilder &VB, Value *Src, RecurKind RdxKind,
FastMathFlags FMFs);
/// Create a reduction of the given vector \p Src for a reduction of the
/// kind RecurKind::IAnyOf or RecurKind::FAnyOf. The reduction operation is
@ -428,14 +428,12 @@ Value *createFindLastIVReduction(IRBuilderBase &B, Value *Src,
const RecurrenceDescriptor &Desc);
/// Create an ordered reduction intrinsic using the given recurrence
/// descriptor \p Desc.
Value *createOrderedReduction(IRBuilderBase &B,
const RecurrenceDescriptor &Desc, Value *Src,
/// kind \p RdxKind.
Value *createOrderedReduction(IRBuilderBase &B, RecurKind RdxKind, Value *Src,
Value *Start);
/// Overloaded function to generate vector-predication intrinsics for ordered
/// reduction.
Value *createOrderedReduction(VectorBuilder &VB,
const RecurrenceDescriptor &Desc, Value *Src,
Value *createOrderedReduction(VectorBuilder &VB, RecurKind RdxKind, Value *Src,
Value *Start);
/// Get the intersection (logical and) of all of the potential IR flags

View File

@ -1333,24 +1333,21 @@ Value *llvm::createSimpleReduction(IRBuilderBase &Builder, Value *Src,
}
Value *llvm::createSimpleReduction(VectorBuilder &VBuilder, Value *Src,
const RecurrenceDescriptor &Desc) {
RecurKind Kind = Desc.getRecurrenceKind();
RecurKind Kind, FastMathFlags FMFs) {
assert(!RecurrenceDescriptor::isAnyOfRecurrenceKind(Kind) &&
!RecurrenceDescriptor::isFindLastIVRecurrenceKind(Kind) &&
"AnyOf or FindLastIV reductions are not supported.");
Intrinsic::ID Id = getReductionIntrinsicID(Kind);
auto *SrcTy = cast<VectorType>(Src->getType());
Type *SrcEltTy = SrcTy->getElementType();
Value *Iden = getRecurrenceIdentity(Kind, SrcEltTy, Desc.getFastMathFlags());
Value *Iden = getRecurrenceIdentity(Kind, SrcEltTy, FMFs);
Value *Ops[] = {Iden, Src};
return VBuilder.createSimpleReduction(Id, SrcTy, Ops);
}
Value *llvm::createOrderedReduction(IRBuilderBase &B,
const RecurrenceDescriptor &Desc,
Value *llvm::createOrderedReduction(IRBuilderBase &B, RecurKind Kind,
Value *Src, Value *Start) {
assert((Desc.getRecurrenceKind() == RecurKind::FAdd ||
Desc.getRecurrenceKind() == RecurKind::FMulAdd) &&
assert((Kind == RecurKind::FAdd || Kind == RecurKind::FMulAdd) &&
"Unexpected reduction kind");
assert(Src->getType()->isVectorTy() && "Expected a vector type");
assert(!Start->getType()->isVectorTy() && "Expected a scalar type");
@ -1358,11 +1355,9 @@ Value *llvm::createOrderedReduction(IRBuilderBase &B,
return B.CreateFAddReduce(Start, Src);
}
Value *llvm::createOrderedReduction(VectorBuilder &VBuilder,
const RecurrenceDescriptor &Desc,
Value *llvm::createOrderedReduction(VectorBuilder &VBuilder, RecurKind Kind,
Value *Src, Value *Start) {
assert((Desc.getRecurrenceKind() == RecurKind::FAdd ||
Desc.getRecurrenceKind() == RecurKind::FMulAdd) &&
assert((Kind == RecurKind::FAdd || Kind == RecurKind::FMulAdd) &&
"Unexpected reduction kind");
assert(Src->getType()->isVectorTy() && "Expected a vector type");
assert(!Start->getType()->isVectorTy() && "Expected a scalar type");

View File

@ -2311,7 +2311,7 @@ void VPReductionRecipe::execute(VPTransformState &State) {
if (IsOrdered) {
if (State.VF.isVector())
NewRed =
createOrderedReduction(State.Builder, RdxDesc, NewVecOp, PrevInChain);
createOrderedReduction(State.Builder, Kind, NewVecOp, PrevInChain);
else
NewRed = State.Builder.CreateBinOp(
(Instruction::BinaryOps)RdxDesc.getOpcode(), PrevInChain, NewVecOp);
@ -2356,9 +2356,9 @@ void VPReductionEVLRecipe::execute(VPTransformState &State) {
Value *NewRed;
if (isOrdered()) {
NewRed = createOrderedReduction(VBuilder, RdxDesc, VecOp, Prev);
NewRed = createOrderedReduction(VBuilder, Kind, VecOp, Prev);
} else {
NewRed = createSimpleReduction(VBuilder, VecOp, RdxDesc);
NewRed = createSimpleReduction(VBuilder, VecOp, Kind, getFastMathFlags());
if (RecurrenceDescriptor::isMinMaxRecurrenceKind(Kind))
NewRed = createMinMaxOp(Builder, Kind, NewRed, Prev);
else