mirror of
https://github.com/llvm/llvm-project.git
synced 2025-05-02 13:06:09 +00:00

This adds support for scalarizing these intrinsics as well as the X86TargetTransformInfo support to avoid scalarizing them in the cases X86 can handle. I've omitted handling special cases for constant masks for this first pass, though CodeGenPrepare can constant fold the branch conditions and remove some of the control flow anyway. Fixes PR40994 and covers most of PR3666. Might want to implement constant masks to close that. Differential Revision: https://reviews.llvm.org/D59180 llvm-svn: 356687
766 lines
26 KiB
C++
766 lines
26 KiB
C++
//===- ScalarizeMaskedMemIntrin.cpp - Scalarize unsupported masked mem ----===//
|
|
//                                    intrinsics
|
|
//
|
|
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
|
// See https://llvm.org/LICENSE.txt for license information.
|
|
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
|
//
|
|
//===----------------------------------------------------------------------===//
|
|
//
|
|
// This pass replaces masked memory intrinsics - when unsupported by the target
|
|
// - with a chain of basic blocks, that deal with the elements one-by-one if the
|
|
// appropriate mask bit is set.
|
|
//
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
#include "llvm/ADT/Twine.h"
|
|
#include "llvm/Analysis/TargetTransformInfo.h"
|
|
#include "llvm/CodeGen/TargetSubtargetInfo.h"
|
|
#include "llvm/IR/BasicBlock.h"
|
|
#include "llvm/IR/Constant.h"
|
|
#include "llvm/IR/Constants.h"
|
|
#include "llvm/IR/DerivedTypes.h"
|
|
#include "llvm/IR/Function.h"
|
|
#include "llvm/IR/IRBuilder.h"
|
|
#include "llvm/IR/InstrTypes.h"
|
|
#include "llvm/IR/Instruction.h"
|
|
#include "llvm/IR/Instructions.h"
|
|
#include "llvm/IR/IntrinsicInst.h"
|
|
#include "llvm/IR/Intrinsics.h"
|
|
#include "llvm/IR/Type.h"
|
|
#include "llvm/IR/Value.h"
|
|
#include "llvm/Pass.h"
|
|
#include "llvm/Support/Casting.h"
|
|
#include <algorithm>
|
|
#include <cassert>
|
|
|
|
using namespace llvm;
|
|
|
|
#define DEBUG_TYPE "scalarize-masked-mem-intrin"
|
|
|
|
namespace {
|
|
|
|
class ScalarizeMaskedMemIntrin : public FunctionPass {
|
|
const TargetTransformInfo *TTI = nullptr;
|
|
|
|
public:
|
|
static char ID; // Pass identification, replacement for typeid
|
|
|
|
explicit ScalarizeMaskedMemIntrin() : FunctionPass(ID) {
|
|
initializeScalarizeMaskedMemIntrinPass(*PassRegistry::getPassRegistry());
|
|
}
|
|
|
|
bool runOnFunction(Function &F) override;
|
|
|
|
StringRef getPassName() const override {
|
|
return "Scalarize Masked Memory Intrinsics";
|
|
}
|
|
|
|
void getAnalysisUsage(AnalysisUsage &AU) const override {
|
|
AU.addRequired<TargetTransformInfoWrapperPass>();
|
|
}
|
|
|
|
private:
|
|
bool optimizeBlock(BasicBlock &BB, bool &ModifiedDT);
|
|
bool optimizeCallInst(CallInst *CI, bool &ModifiedDT);
|
|
};
|
|
|
|
} // end anonymous namespace
|
|
|
|
char ScalarizeMaskedMemIntrin::ID = 0;
|
|
|
|
INITIALIZE_PASS(ScalarizeMaskedMemIntrin, DEBUG_TYPE,
|
|
"Scalarize unsupported masked memory intrinsics", false, false)
|
|
|
|
FunctionPass *llvm::createScalarizeMaskedMemIntrinPass() {
|
|
return new ScalarizeMaskedMemIntrin();
|
|
}
|
|
|
|
/// Return true if \p Mask is a constant vector whose every lane is a
/// ConstantInt. Any lane that cannot be extracted, or that is some other kind
/// of constant, makes the whole mask count as non-constant.
static bool isConstantIntVector(Value *Mask) {
  auto *ConstMask = dyn_cast<Constant>(Mask);
  if (ConstMask == nullptr)
    return false;

  for (unsigned Lane = 0, NumLanes = Mask->getType()->getVectorNumElements();
       Lane != NumLanes; ++Lane) {
    Constant *LaneVal = ConstMask->getAggregateElement(Lane);
    if (!(LaneVal && isa<ConstantInt>(LaneVal)))
      return false;
  }

  return true;
}
|
|
|
|
// Translate a masked load intrinsic like
|
|
// <16 x i32 > @llvm.masked.load( <16 x i32>* %addr, i32 align,
|
|
// <16 x i1> %mask, <16 x i32> %passthru)
|
|
// to a chain of basic blocks, with loading element one-by-one if
|
|
// the appropriate mask bit is set
|
|
//
|
|
// %1 = bitcast i8* %addr to i32*
|
|
// %2 = extractelement <16 x i1> %mask, i32 0
|
|
// br i1 %2, label %cond.load, label %else
|
|
//
|
|
// cond.load: ; preds = %0
|
|
// %3 = getelementptr i32* %1, i32 0
|
|
// %4 = load i32* %3
|
|
// %5 = insertelement <16 x i32> %passthru, i32 %4, i32 0
|
|
// br label %else
|
|
//
|
|
// else: ; preds = %0, %cond.load
|
|
// %res.phi.else = phi <16 x i32> [ %5, %cond.load ], [ undef, %0 ]
|
|
// %6 = extractelement <16 x i1> %mask, i32 1
|
|
// br i1 %6, label %cond.load1, label %else2
|
|
//
|
|
// cond.load1: ; preds = %else
|
|
// %7 = getelementptr i32* %1, i32 1
|
|
// %8 = load i32* %7
|
|
// %9 = insertelement <16 x i32> %res.phi.else, i32 %8, i32 1
|
|
// br label %else2
|
|
//
|
|
// else2: ; preds = %else, %cond.load1
|
|
// %res.phi.else3 = phi <16 x i32> [ %9, %cond.load1 ], [ %res.phi.else, %else ]
|
|
// %10 = extractelement <16 x i1> %mask, i32 2
|
|
// br i1 %10, label %cond.load4, label %else5
|
|
//
|
|
static void scalarizeMaskedLoad(CallInst *CI, bool &ModifiedDT) {
|
|
Value *Ptr = CI->getArgOperand(0);
|
|
Value *Alignment = CI->getArgOperand(1);
|
|
Value *Mask = CI->getArgOperand(2);
|
|
Value *Src0 = CI->getArgOperand(3);
|
|
|
|
unsigned AlignVal = cast<ConstantInt>(Alignment)->getZExtValue();
|
|
VectorType *VecType = cast<VectorType>(CI->getType());
|
|
|
|
Type *EltTy = VecType->getElementType();
|
|
|
|
IRBuilder<> Builder(CI->getContext());
|
|
Instruction *InsertPt = CI;
|
|
BasicBlock *IfBlock = CI->getParent();
|
|
|
|
Builder.SetInsertPoint(InsertPt);
|
|
Builder.SetCurrentDebugLocation(CI->getDebugLoc());
|
|
|
|
// Short-cut if the mask is all-true.
|
|
if (isa<Constant>(Mask) && cast<Constant>(Mask)->isAllOnesValue()) {
|
|
Value *NewI = Builder.CreateAlignedLoad(VecType, Ptr, AlignVal);
|
|
CI->replaceAllUsesWith(NewI);
|
|
CI->eraseFromParent();
|
|
return;
|
|
}
|
|
|
|
// Adjust alignment for the scalar instruction.
|
|
AlignVal = MinAlign(AlignVal, EltTy->getPrimitiveSizeInBits() / 8);
|
|
// Bitcast %addr from i8* to EltTy*
|
|
Type *NewPtrType =
|
|
EltTy->getPointerTo(Ptr->getType()->getPointerAddressSpace());
|
|
Value *FirstEltPtr = Builder.CreateBitCast(Ptr, NewPtrType);
|
|
unsigned VectorWidth = VecType->getNumElements();
|
|
|
|
// The result vector
|
|
Value *VResult = Src0;
|
|
|
|
if (isConstantIntVector(Mask)) {
|
|
for (unsigned Idx = 0; Idx < VectorWidth; ++Idx) {
|
|
if (cast<Constant>(Mask)->getAggregateElement(Idx)->isNullValue())
|
|
continue;
|
|
Value *Gep = Builder.CreateConstInBoundsGEP1_32(EltTy, FirstEltPtr, Idx);
|
|
LoadInst *Load = Builder.CreateAlignedLoad(EltTy, Gep, AlignVal);
|
|
VResult = Builder.CreateInsertElement(VResult, Load, Idx);
|
|
}
|
|
CI->replaceAllUsesWith(VResult);
|
|
CI->eraseFromParent();
|
|
return;
|
|
}
|
|
|
|
for (unsigned Idx = 0; Idx < VectorWidth; ++Idx) {
|
|
// Fill the "else" block, created in the previous iteration
|
|
//
|
|
// %res.phi.else3 = phi <16 x i32> [ %11, %cond.load1 ], [ %res.phi.else, %else ]
|
|
// %mask_1 = extractelement <16 x i1> %mask, i32 Idx
|
|
// br i1 %mask_1, label %cond.load, label %else
|
|
//
|
|
|
|
Value *Predicate = Builder.CreateExtractElement(Mask, Idx);
|
|
|
|
// Create "cond" block
|
|
//
|
|
// %EltAddr = getelementptr i32* %1, i32 0
|
|
// %Elt = load i32* %EltAddr
|
|
// VResult = insertelement <16 x i32> VResult, i32 %Elt, i32 Idx
|
|
//
|
|
BasicBlock *CondBlock = IfBlock->splitBasicBlock(InsertPt->getIterator(),
|
|
"cond.load");
|
|
Builder.SetInsertPoint(InsertPt);
|
|
|
|
Value *Gep = Builder.CreateConstInBoundsGEP1_32(EltTy, FirstEltPtr, Idx);
|
|
LoadInst *Load = Builder.CreateAlignedLoad(EltTy, Gep, AlignVal);
|
|
Value *NewVResult = Builder.CreateInsertElement(VResult, Load, Idx);
|
|
|
|
// Create "else" block, fill it in the next iteration
|
|
BasicBlock *NewIfBlock =
|
|
CondBlock->splitBasicBlock(InsertPt->getIterator(), "else");
|
|
Builder.SetInsertPoint(InsertPt);
|
|
Instruction *OldBr = IfBlock->getTerminator();
|
|
BranchInst::Create(CondBlock, NewIfBlock, Predicate, OldBr);
|
|
OldBr->eraseFromParent();
|
|
BasicBlock *PrevIfBlock = IfBlock;
|
|
IfBlock = NewIfBlock;
|
|
|
|
// Create the phi to join the new and previous value.
|
|
PHINode *Phi = Builder.CreatePHI(VecType, 2, "res.phi.else");
|
|
Phi->addIncoming(NewVResult, CondBlock);
|
|
Phi->addIncoming(VResult, PrevIfBlock);
|
|
VResult = Phi;
|
|
}
|
|
|
|
CI->replaceAllUsesWith(VResult);
|
|
CI->eraseFromParent();
|
|
|
|
ModifiedDT = true;
|
|
}
|
|
|
|
// Translate a masked store intrinsic, like
|
|
// void @llvm.masked.store(<16 x i32> %src, <16 x i32>* %addr, i32 align,
|
|
// <16 x i1> %mask)
|
|
// to a chain of basic blocks, that stores element one-by-one if
|
|
// the appropriate mask bit is set
|
|
//
|
|
// %1 = bitcast i8* %addr to i32*
|
|
// %2 = extractelement <16 x i1> %mask, i32 0
|
|
// br i1 %2, label %cond.store, label %else
|
|
//
|
|
// cond.store: ; preds = %0
|
|
// %3 = extractelement <16 x i32> %val, i32 0
|
|
// %4 = getelementptr i32* %1, i32 0
|
|
// store i32 %3, i32* %4
|
|
// br label %else
|
|
//
|
|
// else: ; preds = %0, %cond.store
|
|
// %5 = extractelement <16 x i1> %mask, i32 1
|
|
// br i1 %5, label %cond.store1, label %else2
|
|
//
|
|
// cond.store1: ; preds = %else
|
|
// %6 = extractelement <16 x i32> %val, i32 1
|
|
// %7 = getelementptr i32* %1, i32 1
|
|
// store i32 %6, i32* %7
|
|
// br label %else2
|
|
// . . .
|
|
static void scalarizeMaskedStore(CallInst *CI, bool &ModifiedDT) {
|
|
Value *Src = CI->getArgOperand(0);
|
|
Value *Ptr = CI->getArgOperand(1);
|
|
Value *Alignment = CI->getArgOperand(2);
|
|
Value *Mask = CI->getArgOperand(3);
|
|
|
|
unsigned AlignVal = cast<ConstantInt>(Alignment)->getZExtValue();
|
|
VectorType *VecType = cast<VectorType>(Src->getType());
|
|
|
|
Type *EltTy = VecType->getElementType();
|
|
|
|
IRBuilder<> Builder(CI->getContext());
|
|
Instruction *InsertPt = CI;
|
|
BasicBlock *IfBlock = CI->getParent();
|
|
Builder.SetInsertPoint(InsertPt);
|
|
Builder.SetCurrentDebugLocation(CI->getDebugLoc());
|
|
|
|
// Short-cut if the mask is all-true.
|
|
if (isa<Constant>(Mask) && cast<Constant>(Mask)->isAllOnesValue()) {
|
|
Builder.CreateAlignedStore(Src, Ptr, AlignVal);
|
|
CI->eraseFromParent();
|
|
return;
|
|
}
|
|
|
|
// Adjust alignment for the scalar instruction.
|
|
AlignVal = MinAlign(AlignVal, EltTy->getPrimitiveSizeInBits() / 8);
|
|
// Bitcast %addr from i8* to EltTy*
|
|
Type *NewPtrType =
|
|
EltTy->getPointerTo(Ptr->getType()->getPointerAddressSpace());
|
|
Value *FirstEltPtr = Builder.CreateBitCast(Ptr, NewPtrType);
|
|
unsigned VectorWidth = VecType->getNumElements();
|
|
|
|
if (isConstantIntVector(Mask)) {
|
|
for (unsigned Idx = 0; Idx < VectorWidth; ++Idx) {
|
|
if (cast<Constant>(Mask)->getAggregateElement(Idx)->isNullValue())
|
|
continue;
|
|
Value *OneElt = Builder.CreateExtractElement(Src, Idx);
|
|
Value *Gep = Builder.CreateConstInBoundsGEP1_32(EltTy, FirstEltPtr, Idx);
|
|
Builder.CreateAlignedStore(OneElt, Gep, AlignVal);
|
|
}
|
|
CI->eraseFromParent();
|
|
return;
|
|
}
|
|
|
|
for (unsigned Idx = 0; Idx < VectorWidth; ++Idx) {
|
|
// Fill the "else" block, created in the previous iteration
|
|
//
|
|
// %mask_1 = extractelement <16 x i1> %mask, i32 Idx
|
|
// br i1 %mask_1, label %cond.store, label %else
|
|
//
|
|
Value *Predicate = Builder.CreateExtractElement(Mask, Idx);
|
|
|
|
// Create "cond" block
|
|
//
|
|
// %OneElt = extractelement <16 x i32> %Src, i32 Idx
|
|
// %EltAddr = getelementptr i32* %1, i32 0
|
|
// %store i32 %OneElt, i32* %EltAddr
|
|
//
|
|
BasicBlock *CondBlock =
|
|
IfBlock->splitBasicBlock(InsertPt->getIterator(), "cond.store");
|
|
Builder.SetInsertPoint(InsertPt);
|
|
|
|
Value *OneElt = Builder.CreateExtractElement(Src, Idx);
|
|
Value *Gep = Builder.CreateConstInBoundsGEP1_32(EltTy, FirstEltPtr, Idx);
|
|
Builder.CreateAlignedStore(OneElt, Gep, AlignVal);
|
|
|
|
// Create "else" block, fill it in the next iteration
|
|
BasicBlock *NewIfBlock =
|
|
CondBlock->splitBasicBlock(InsertPt->getIterator(), "else");
|
|
Builder.SetInsertPoint(InsertPt);
|
|
Instruction *OldBr = IfBlock->getTerminator();
|
|
BranchInst::Create(CondBlock, NewIfBlock, Predicate, OldBr);
|
|
OldBr->eraseFromParent();
|
|
IfBlock = NewIfBlock;
|
|
}
|
|
CI->eraseFromParent();
|
|
|
|
ModifiedDT = true;
|
|
}
|
|
|
|
// Translate a masked gather intrinsic like
|
|
// <16 x i32 > @llvm.masked.gather.v16i32( <16 x i32*> %Ptrs, i32 4,
|
|
// <16 x i1> %Mask, <16 x i32> %Src)
|
|
// to a chain of basic blocks, with loading element one-by-one if
|
|
// the appropriate mask bit is set
|
|
//
|
|
// %Ptrs = getelementptr i32, i32* %base, <16 x i64> %ind
|
|
// %Mask0 = extractelement <16 x i1> %Mask, i32 0
|
|
// br i1 %Mask0, label %cond.load, label %else
|
|
//
|
|
// cond.load:
|
|
// %Ptr0 = extractelement <16 x i32*> %Ptrs, i32 0
|
|
// %Load0 = load i32, i32* %Ptr0, align 4
|
|
// %Res0 = insertelement <16 x i32> undef, i32 %Load0, i32 0
|
|
// br label %else
|
|
//
|
|
// else:
|
|
// %res.phi.else = phi <16 x i32>[%Res0, %cond.load], [undef, %0]
|
|
// %Mask1 = extractelement <16 x i1> %Mask, i32 1
|
|
// br i1 %Mask1, label %cond.load1, label %else2
|
|
//
|
|
// cond.load1:
|
|
// %Ptr1 = extractelement <16 x i32*> %Ptrs, i32 1
|
|
// %Load1 = load i32, i32* %Ptr1, align 4
|
|
// %Res1 = insertelement <16 x i32> %res.phi.else, i32 %Load1, i32 1
|
|
// br label %else2
|
|
// . . .
|
|
// %Result = select <16 x i1> %Mask, <16 x i32> %res.phi.select, <16 x i32> %Src
|
|
// ret <16 x i32> %Result
|
|
static void scalarizeMaskedGather(CallInst *CI, bool &ModifiedDT) {
|
|
Value *Ptrs = CI->getArgOperand(0);
|
|
Value *Alignment = CI->getArgOperand(1);
|
|
Value *Mask = CI->getArgOperand(2);
|
|
Value *Src0 = CI->getArgOperand(3);
|
|
|
|
VectorType *VecType = cast<VectorType>(CI->getType());
|
|
Type *EltTy = VecType->getElementType();
|
|
|
|
IRBuilder<> Builder(CI->getContext());
|
|
Instruction *InsertPt = CI;
|
|
BasicBlock *IfBlock = CI->getParent();
|
|
Builder.SetInsertPoint(InsertPt);
|
|
unsigned AlignVal = cast<ConstantInt>(Alignment)->getZExtValue();
|
|
|
|
Builder.SetCurrentDebugLocation(CI->getDebugLoc());
|
|
|
|
// The result vector
|
|
Value *VResult = Src0;
|
|
unsigned VectorWidth = VecType->getNumElements();
|
|
|
|
// Shorten the way if the mask is a vector of constants.
|
|
if (isConstantIntVector(Mask)) {
|
|
for (unsigned Idx = 0; Idx < VectorWidth; ++Idx) {
|
|
if (cast<Constant>(Mask)->getAggregateElement(Idx)->isNullValue())
|
|
continue;
|
|
Value *Ptr = Builder.CreateExtractElement(Ptrs, Idx, "Ptr" + Twine(Idx));
|
|
LoadInst *Load =
|
|
Builder.CreateAlignedLoad(EltTy, Ptr, AlignVal, "Load" + Twine(Idx));
|
|
VResult =
|
|
Builder.CreateInsertElement(VResult, Load, Idx, "Res" + Twine(Idx));
|
|
}
|
|
CI->replaceAllUsesWith(VResult);
|
|
CI->eraseFromParent();
|
|
return;
|
|
}
|
|
|
|
for (unsigned Idx = 0; Idx < VectorWidth; ++Idx) {
|
|
// Fill the "else" block, created in the previous iteration
|
|
//
|
|
// %Mask1 = extractelement <16 x i1> %Mask, i32 1
|
|
// br i1 %Mask1, label %cond.load, label %else
|
|
//
|
|
|
|
Value *Predicate =
|
|
Builder.CreateExtractElement(Mask, Idx, "Mask" + Twine(Idx));
|
|
|
|
// Create "cond" block
|
|
//
|
|
// %EltAddr = getelementptr i32* %1, i32 0
|
|
// %Elt = load i32* %EltAddr
|
|
// VResult = insertelement <16 x i32> VResult, i32 %Elt, i32 Idx
|
|
//
|
|
BasicBlock *CondBlock = IfBlock->splitBasicBlock(InsertPt, "cond.load");
|
|
Builder.SetInsertPoint(InsertPt);
|
|
|
|
Value *Ptr = Builder.CreateExtractElement(Ptrs, Idx, "Ptr" + Twine(Idx));
|
|
LoadInst *Load =
|
|
Builder.CreateAlignedLoad(EltTy, Ptr, AlignVal, "Load" + Twine(Idx));
|
|
Value *NewVResult =
|
|
Builder.CreateInsertElement(VResult, Load, Idx, "Res" + Twine(Idx));
|
|
|
|
// Create "else" block, fill it in the next iteration
|
|
BasicBlock *NewIfBlock = CondBlock->splitBasicBlock(InsertPt, "else");
|
|
Builder.SetInsertPoint(InsertPt);
|
|
Instruction *OldBr = IfBlock->getTerminator();
|
|
BranchInst::Create(CondBlock, NewIfBlock, Predicate, OldBr);
|
|
OldBr->eraseFromParent();
|
|
BasicBlock *PrevIfBlock = IfBlock;
|
|
IfBlock = NewIfBlock;
|
|
|
|
PHINode *Phi = Builder.CreatePHI(VecType, 2, "res.phi.else");
|
|
Phi->addIncoming(NewVResult, CondBlock);
|
|
Phi->addIncoming(VResult, PrevIfBlock);
|
|
VResult = Phi;
|
|
}
|
|
|
|
CI->replaceAllUsesWith(VResult);
|
|
CI->eraseFromParent();
|
|
|
|
ModifiedDT = true;
|
|
}
|
|
|
|
// Translate a masked scatter intrinsic, like
|
|
// void @llvm.masked.scatter.v16i32(<16 x i32> %Src, <16 x i32*>* %Ptrs, i32 4,
|
|
// <16 x i1> %Mask)
|
|
// to a chain of basic blocks, that stores element one-by-one if
|
|
// the appropriate mask bit is set.
|
|
//
|
|
// %Ptrs = getelementptr i32, i32* %ptr, <16 x i64> %ind
|
|
// %Mask0 = extractelement <16 x i1> %Mask, i32 0
|
|
// br i1 %Mask0, label %cond.store, label %else
|
|
//
|
|
// cond.store:
|
|
// %Elt0 = extractelement <16 x i32> %Src, i32 0
|
|
// %Ptr0 = extractelement <16 x i32*> %Ptrs, i32 0
|
|
// store i32 %Elt0, i32* %Ptr0, align 4
|
|
// br label %else
|
|
//
|
|
// else:
|
|
// %Mask1 = extractelement <16 x i1> %Mask, i32 1
|
|
// br i1 %Mask1, label %cond.store1, label %else2
|
|
//
|
|
// cond.store1:
|
|
// %Elt1 = extractelement <16 x i32> %Src, i32 1
|
|
// %Ptr1 = extractelement <16 x i32*> %Ptrs, i32 1
|
|
// store i32 %Elt1, i32* %Ptr1, align 4
|
|
// br label %else2
|
|
// . . .
|
|
static void scalarizeMaskedScatter(CallInst *CI, bool &ModifiedDT) {
|
|
Value *Src = CI->getArgOperand(0);
|
|
Value *Ptrs = CI->getArgOperand(1);
|
|
Value *Alignment = CI->getArgOperand(2);
|
|
Value *Mask = CI->getArgOperand(3);
|
|
|
|
assert(isa<VectorType>(Src->getType()) &&
|
|
"Unexpected data type in masked scatter intrinsic");
|
|
assert(isa<VectorType>(Ptrs->getType()) &&
|
|
isa<PointerType>(Ptrs->getType()->getVectorElementType()) &&
|
|
"Vector of pointers is expected in masked scatter intrinsic");
|
|
|
|
IRBuilder<> Builder(CI->getContext());
|
|
Instruction *InsertPt = CI;
|
|
BasicBlock *IfBlock = CI->getParent();
|
|
Builder.SetInsertPoint(InsertPt);
|
|
Builder.SetCurrentDebugLocation(CI->getDebugLoc());
|
|
|
|
unsigned AlignVal = cast<ConstantInt>(Alignment)->getZExtValue();
|
|
unsigned VectorWidth = Src->getType()->getVectorNumElements();
|
|
|
|
// Shorten the way if the mask is a vector of constants.
|
|
if (isConstantIntVector(Mask)) {
|
|
for (unsigned Idx = 0; Idx < VectorWidth; ++Idx) {
|
|
if (cast<ConstantVector>(Mask)->getAggregateElement(Idx)->isNullValue())
|
|
continue;
|
|
Value *OneElt =
|
|
Builder.CreateExtractElement(Src, Idx, "Elt" + Twine(Idx));
|
|
Value *Ptr = Builder.CreateExtractElement(Ptrs, Idx, "Ptr" + Twine(Idx));
|
|
Builder.CreateAlignedStore(OneElt, Ptr, AlignVal);
|
|
}
|
|
CI->eraseFromParent();
|
|
return;
|
|
}
|
|
|
|
for (unsigned Idx = 0; Idx < VectorWidth; ++Idx) {
|
|
// Fill the "else" block, created in the previous iteration
|
|
//
|
|
// %Mask1 = extractelement <16 x i1> %Mask, i32 Idx
|
|
// br i1 %Mask1, label %cond.store, label %else
|
|
//
|
|
Value *Predicate =
|
|
Builder.CreateExtractElement(Mask, Idx, "Mask" + Twine(Idx));
|
|
|
|
// Create "cond" block
|
|
//
|
|
// %Elt1 = extractelement <16 x i32> %Src, i32 1
|
|
// %Ptr1 = extractelement <16 x i32*> %Ptrs, i32 1
|
|
// %store i32 %Elt1, i32* %Ptr1
|
|
//
|
|
BasicBlock *CondBlock = IfBlock->splitBasicBlock(InsertPt, "cond.store");
|
|
Builder.SetInsertPoint(InsertPt);
|
|
|
|
Value *OneElt = Builder.CreateExtractElement(Src, Idx, "Elt" + Twine(Idx));
|
|
Value *Ptr = Builder.CreateExtractElement(Ptrs, Idx, "Ptr" + Twine(Idx));
|
|
Builder.CreateAlignedStore(OneElt, Ptr, AlignVal);
|
|
|
|
// Create "else" block, fill it in the next iteration
|
|
BasicBlock *NewIfBlock = CondBlock->splitBasicBlock(InsertPt, "else");
|
|
Builder.SetInsertPoint(InsertPt);
|
|
Instruction *OldBr = IfBlock->getTerminator();
|
|
BranchInst::Create(CondBlock, NewIfBlock, Predicate, OldBr);
|
|
OldBr->eraseFromParent();
|
|
IfBlock = NewIfBlock;
|
|
}
|
|
CI->eraseFromParent();
|
|
|
|
ModifiedDT = true;
|
|
}
|
|
|
|
// Expand a masked expandload intrinsic into a chain of basic blocks, one per
// vector lane. For each lane whose mask bit is set, one element is loaded from
// the current pointer and inserted into the result at that lane, and the
// pointer then advances by one element. Lanes whose mask bit is clear keep the
// value from the passthru operand and do not advance the pointer. The running
// result vector and the running pointer are both carried between blocks with
// phis. Element loads use alignment 1. Always creates new blocks, so
// ModifiedDT is always set.
static void scalarizeMaskedExpandLoad(CallInst *CI, bool &ModifiedDT) {
  Value *Ptr = CI->getArgOperand(0);
  Value *Mask = CI->getArgOperand(1);
  Value *PassThru = CI->getArgOperand(2);

  VectorType *VecType = cast<VectorType>(CI->getType());

  Type *EltTy = VecType->getElementType();

  // All new instructions go right before the call and inherit its debug
  // location.
  IRBuilder<> Builder(CI->getContext());
  Instruction *InsertPt = CI;
  BasicBlock *IfBlock = CI->getParent();

  Builder.SetInsertPoint(InsertPt);
  Builder.SetCurrentDebugLocation(CI->getDebugLoc());

  unsigned VectorWidth = VecType->getNumElements();

  // The result vector
  Value *VResult = PassThru;

  for (unsigned Idx = 0; Idx < VectorWidth; ++Idx) {
    // The pointer only has to be advanced (and merged with a phi) when another
    // lane follows this one.
    const bool HasMoreLanes = (Idx + 1) != VectorWidth;

    // Fill the "else" block, created in the previous iteration
    //
    //  %res.phi.else3 = phi <16 x i32> [ %11, %cond.load1 ], [ %res.phi.else, %else ]
    //  %mask_1 = extractelement <16 x i1> %mask, i32 Idx
    //  br i1 %mask_1, label %cond.load, label %else
    //
    Value *Predicate = Builder.CreateExtractElement(Mask, Idx);

    // Create "cond" block
    //
    //  %Elt = load i32* %Ptr
    //  VResult = insertelement <16 x i32> VResult, i32 %Elt, i32 Idx
    //
    BasicBlock *CondBlock =
        IfBlock->splitBasicBlock(InsertPt->getIterator(), "cond.load");
    Builder.SetInsertPoint(InsertPt);

    LoadInst *Load = Builder.CreateAlignedLoad(EltTy, Ptr, 1);
    Value *NewVResult = Builder.CreateInsertElement(VResult, Load, Idx);

    // Move the pointer if there are more blocks to come. Initialized to null
    // so it is never read uninitialized on the last lane.
    Value *NewPtr = nullptr;
    if (HasMoreLanes)
      NewPtr = Builder.CreateConstInBoundsGEP1_32(EltTy, Ptr, 1);

    // Create "else" block, fill it in the next iteration
    BasicBlock *NewIfBlock =
        CondBlock->splitBasicBlock(InsertPt->getIterator(), "else");
    Builder.SetInsertPoint(InsertPt);

    // Replace the unconditional branch left by the first split with a
    // conditional branch on the mask bit.
    Instruction *OldBr = IfBlock->getTerminator();
    BranchInst::Create(CondBlock, NewIfBlock, Predicate, OldBr);
    OldBr->eraseFromParent();
    BasicBlock *PrevIfBlock = IfBlock;
    IfBlock = NewIfBlock;

    // Create the phi to join the new and previous value.
    PHINode *ResultPhi = Builder.CreatePHI(VecType, 2, "res.phi.else");
    ResultPhi->addIncoming(NewVResult, CondBlock);
    ResultPhi->addIncoming(VResult, PrevIfBlock);
    VResult = ResultPhi;

    // Add a PHI for the pointer if this isn't the last iteration.
    if (HasMoreLanes) {
      PHINode *PtrPhi = Builder.CreatePHI(Ptr->getType(), 2, "ptr.phi.else");
      PtrPhi->addIncoming(NewPtr, CondBlock);
      PtrPhi->addIncoming(Ptr, PrevIfBlock);
      Ptr = PtrPhi;
    }
  }

  CI->replaceAllUsesWith(VResult);
  CI->eraseFromParent();

  ModifiedDT = true;
}
|
|
|
|
// Expand a masked compressstore intrinsic into a chain of basic blocks, one
// per vector lane. For each lane whose mask bit is set, that element of the
// source vector is stored to the current pointer, and the pointer then
// advances by one element. Lanes whose mask bit is clear store nothing and do
// not advance the pointer. The running pointer is carried between blocks with
// a phi. Element stores use alignment 1. Always creates new blocks, so
// ModifiedDT is always set.
static void scalarizeMaskedCompressStore(CallInst *CI, bool &ModifiedDT) {
  Value *Src = CI->getArgOperand(0);
  Value *Ptr = CI->getArgOperand(1);
  Value *Mask = CI->getArgOperand(2);

  VectorType *VecType = cast<VectorType>(Src->getType());

  // All new instructions go right before the call and inherit its debug
  // location.
  IRBuilder<> Builder(CI->getContext());
  Instruction *InsertPt = CI;
  BasicBlock *IfBlock = CI->getParent();

  Builder.SetInsertPoint(InsertPt);
  Builder.SetCurrentDebugLocation(CI->getDebugLoc());

  // Use the VectorType accessor, matching the other scalarizers in this file.
  Type *EltTy = VecType->getElementType();

  unsigned VectorWidth = VecType->getNumElements();

  for (unsigned Idx = 0; Idx < VectorWidth; ++Idx) {
    // The pointer only has to be advanced (and merged with a phi) when another
    // lane follows this one.
    const bool HasMoreLanes = (Idx + 1) != VectorWidth;

    // Fill the "else" block, created in the previous iteration
    //
    //  %mask_1 = extractelement <16 x i1> %mask, i32 Idx
    //  br i1 %mask_1, label %cond.store, label %else
    //
    Value *Predicate = Builder.CreateExtractElement(Mask, Idx);

    // Create "cond" block
    //
    //  %OneElt = extractelement <16 x i32> %Src, i32 Idx
    //  store i32 %OneElt, i32* %Ptr
    //
    BasicBlock *CondBlock =
        IfBlock->splitBasicBlock(InsertPt->getIterator(), "cond.store");
    Builder.SetInsertPoint(InsertPt);

    Value *OneElt = Builder.CreateExtractElement(Src, Idx);
    Builder.CreateAlignedStore(OneElt, Ptr, 1);

    // Move the pointer if there are more blocks to come. Initialized to null
    // so it is never read uninitialized on the last lane.
    Value *NewPtr = nullptr;
    if (HasMoreLanes)
      NewPtr = Builder.CreateConstInBoundsGEP1_32(EltTy, Ptr, 1);

    // Create "else" block, fill it in the next iteration
    BasicBlock *NewIfBlock =
        CondBlock->splitBasicBlock(InsertPt->getIterator(), "else");
    Builder.SetInsertPoint(InsertPt);

    // Replace the unconditional branch left by the first split with a
    // conditional branch on the mask bit.
    Instruction *OldBr = IfBlock->getTerminator();
    BranchInst::Create(CondBlock, NewIfBlock, Predicate, OldBr);
    OldBr->eraseFromParent();
    BasicBlock *PrevIfBlock = IfBlock;
    IfBlock = NewIfBlock;

    // Add a PHI for the pointer if this isn't the last iteration.
    if (HasMoreLanes) {
      PHINode *PtrPhi = Builder.CreatePHI(Ptr->getType(), 2, "ptr.phi.else");
      PtrPhi->addIncoming(NewPtr, CondBlock);
      PtrPhi->addIncoming(Ptr, PrevIfBlock);
      Ptr = PtrPhi;
    }
  }
  CI->eraseFromParent();

  ModifiedDT = true;
}
|
|
|
|
/// Repeatedly sweep the blocks of \p F, scalarizing unsupported masked memory
/// intrinsics, until a full sweep makes no change. Returns true if any sweep
/// changed the function.
bool ScalarizeMaskedMemIntrin::runOnFunction(Function &F) {
  TTI = &getAnalysis<TargetTransformInfoWrapperPass>().getTTI(F);

  bool ChangedAnything = false;
  bool ChangedThisSweep;
  do {
    ChangedThisSweep = false;
    for (Function::iterator BlockIt = F.begin(); BlockIt != F.end();) {
      // Advance before processing: the current block may get split.
      BasicBlock &CurBB = *BlockIt++;
      bool CFGChanged = false;
      ChangedThisSweep |= optimizeBlock(CurBB, CFGChanged);

      // New blocks were created; start the sweep over from the first block.
      if (CFGChanged)
        break;
    }

    ChangedAnything |= ChangedThisSweep;
  } while (ChangedThisSweep);

  return ChangedAnything;
}
|
|
|
|
/// Visit the instructions of \p BB in order and hand every call to
/// optimizeCallInst. Stops and returns true as soon as \p ModifiedDT is set,
/// because the block has then been split and the iteration is no longer
/// meaningful. Otherwise returns whether any call was rewritten.
bool ScalarizeMaskedMemIntrin::optimizeBlock(BasicBlock &BB, bool &ModifiedDT) {
  bool Changed = false;

  for (BasicBlock::iterator InstIt = BB.begin(); InstIt != BB.end();) {
    // Step past the instruction first; a rewrite erases the call.
    Instruction *Cur = &*InstIt++;
    if (auto *Call = dyn_cast<CallInst>(Cur))
      Changed |= optimizeCallInst(Call, ModifiedDT);
    if (ModifiedDT)
      return true;
  }

  return Changed;
}
|
|
|
|
/// If \p CI is a masked load/store/gather/scatter/expandload/compressstore
/// intrinsic that TTI does not report as legal for its data type, replace it
/// with scalar code and return true. Returns false for every other call and
/// for intrinsics the target supports. \p ModifiedDT is forwarded to the
/// scalarizer, which sets it when it creates new basic blocks.
bool ScalarizeMaskedMemIntrin::optimizeCallInst(CallInst *CI,
                                                bool &ModifiedDT) {
  auto *Intrin = dyn_cast<IntrinsicInst>(CI);
  if (!Intrin)
    return false;

  switch (Intrin->getIntrinsicID()) {
  case Intrinsic::masked_load:
    // Scalarize unsupported vector masked load
    if (!TTI->isLegalMaskedLoad(CI->getType())) {
      scalarizeMaskedLoad(CI, ModifiedDT);
      return true;
    }
    return false;
  case Intrinsic::masked_store:
    // The stored value is operand 0; its type decides legality.
    if (!TTI->isLegalMaskedStore(CI->getArgOperand(0)->getType())) {
      scalarizeMaskedStore(CI, ModifiedDT);
      return true;
    }
    return false;
  case Intrinsic::masked_gather:
    if (!TTI->isLegalMaskedGather(CI->getType())) {
      scalarizeMaskedGather(CI, ModifiedDT);
      return true;
    }
    return false;
  case Intrinsic::masked_scatter:
    if (!TTI->isLegalMaskedScatter(CI->getArgOperand(0)->getType())) {
      scalarizeMaskedScatter(CI, ModifiedDT);
      return true;
    }
    return false;
  case Intrinsic::masked_expandload:
    if (!TTI->isLegalMaskedExpandLoad(CI->getType())) {
      scalarizeMaskedExpandLoad(CI, ModifiedDT);
      return true;
    }
    return false;
  case Intrinsic::masked_compressstore:
    if (!TTI->isLegalMaskedCompressStore(CI->getArgOperand(0)->getType())) {
      scalarizeMaskedCompressStore(CI, ModifiedDT);
      return true;
    }
    return false;
  default:
    break;
  }

  return false;
}
|