llvm-project/llvm/lib/IR/ReplaceConstant.cpp
Nikita Popov 5ef768d22b
[AMDGPULowerBufferFatPointers] Expand const exprs using fat pointers (#95558)
Expand all constant expressions that use fat pointers upfront, so that
the rewriting logic only has to deal with instructions and not the
constant expression variants as well.

My primary motivation is to remove the creation of illegal constant
expressions (mul and shl) from this pass, but this also cuts down quite
a bit on the amount of duplicate logic.
2024-06-17 09:28:09 +02:00

123 lines
4.2 KiB
C++

//===- ReplaceConstant.cpp - Replace LLVM constant expressions-------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements a utility function for replacing LLVM constant
// expressions by instructions.
//
//===----------------------------------------------------------------------===//
#include "llvm/IR/ReplaceConstant.h"
#include "llvm/ADT/DenseMap.h"
#include "llvm/ADT/SetVector.h"
#include "llvm/IR/Constants.h"
#include "llvm/IR/Instructions.h"
namespace llvm {
/// Returns true if \p U is a constant that this file knows how to rewrite as
/// instructions: either a constant expression or a constant aggregate.
static bool isExpandableUser(User *U) {
  return isa<ConstantExpr, ConstantAggregate>(U);
}
/// Materializes the expandable constant \p C as instructions inserted before
/// \p InsertPt.
///
/// A constant expression becomes its single equivalent instruction. A constant
/// struct/array becomes an insertvalue chain, and a constant vector becomes an
/// insertelement chain, both starting from poison. The instructions are
/// returned in creation order, so the last one yields the value of \p C.
static SmallVector<Instruction *, 4> expandUser(BasicBlock::iterator InsertPt,
                                                Constant *C) {
  SmallVector<Instruction *, 4> Created;

  if (auto *Expr = dyn_cast<ConstantExpr>(C)) {
    Instruction *Equivalent = Expr->getAsInstruction();
    Equivalent->insertBefore(*InsertPt->getParent(), InsertPt);
    Created.push_back(Equivalent);
    return Created;
  }

  if (isa<ConstantStruct>(C) || isa<ConstantArray>(C)) {
    Value *Agg = PoisonValue::get(C->getType());
    for (unsigned OpNo = 0, NumOps = C->getNumOperands(); OpNo != NumOps;
         ++OpNo) {
      Agg = InsertValueInst::Create(Agg, C->getOperand(OpNo), OpNo, "",
                                    InsertPt);
      Created.push_back(cast<Instruction>(Agg));
    }
    return Created;
  }

  if (isa<ConstantVector>(C)) {
    Type *LaneIdxTy = Type::getInt32Ty(C->getContext());
    Value *Vec = PoisonValue::get(C->getType());
    for (unsigned Lane = 0, NumLanes = C->getNumOperands(); Lane != NumLanes;
         ++Lane) {
      Vec = InsertElementInst::Create(Vec, C->getOperand(Lane),
                                      ConstantInt::get(LaneIdxTy, Lane), "",
                                      InsertPt);
      Created.push_back(cast<Instruction>(Vec));
    }
    return Created;
  }

  llvm_unreachable("Not an expandable user");
}
/// Rewrites instruction uses of constant expressions / constant aggregates
/// that (transitively) use any of \p Consts into equivalent instructions.
///
/// \param Consts              Roots of the rewrite. Unless \p IncludeSelf is
///                            set, their expandable users (and the users of
///                            those, transitively) are expanded.
/// \param RestrictToFunc      If non-null, only instructions whose parent
///                            function is \p RestrictToFunc are rewritten.
/// \param RemoveDeadConstants If true, removeDeadConstantUsers() is called on
///                            every entry of \p Consts after rewriting.
/// \param IncludeSelf         If true, the entries of \p Consts must be
///                            expandable themselves and are expanded too.
/// \returns true if at least one instruction operand was rewritten.
bool convertUsersOfConstantsToInstructions(ArrayRef<Constant *> Consts,
                                           Function *RestrictToFunc,
                                           bool RemoveDeadConstants,
                                           bool IncludeSelf) {
  // Find all expandable direct users of Consts.
  SmallVector<Constant *> Stack;
  for (Constant *C : Consts) {
    if (IncludeSelf) {
      assert(isExpandableUser(C) && "One of the constants is not expandable");
      Stack.push_back(C);
    } else {
      for (User *U : C->users())
        if (isExpandableUser(U))
          Stack.push_back(cast<Constant>(U));
    }
  }

  // Include transitive users.
  SetVector<Constant *> ExpandableUsers;
  while (!Stack.empty()) {
    Constant *C = Stack.pop_back_val();
    if (!ExpandableUsers.insert(C))
      continue;

    for (auto *Nested : C->users())
      if (isExpandableUser(Nested))
        Stack.push_back(cast<Constant>(Nested));
  }

  // Find all instructions that use any of the expandable users.
  SetVector<Instruction *> InstructionWorklist;
  for (Constant *C : ExpandableUsers)
    for (User *U : C->users())
      if (auto *I = dyn_cast<Instruction>(U))
        if (!RestrictToFunc || I->getFunction() == RestrictToFunc)
          InstructionWorklist.insert(I);

  // Replace those expandable operands with instructions. Newly created
  // instructions are pushed onto the worklist so that any expandable operands
  // they carry are expanded in turn.
  bool Changed = false;
  while (!InstructionWorklist.empty()) {
    Instruction *I = InstructionWorklist.pop_back_val();
    DebugLoc Loc = I->getDebugLoc();
    auto *Phi = dyn_cast<PHINode>(I);

    // A phi may list the same predecessor block more than once, and all such
    // entries must carry the same value. Expanding each entry separately would
    // create distinct instructions and break that rule, so the expansion of a
    // constant for a given incoming block is computed once and reused.
    SmallDenseMap<std::pair<BasicBlock *, Constant *>, Instruction *, 4>
        PhiExpansions;

    for (Use &U : I->operands()) {
      auto *C = dyn_cast<Constant>(U.get());
      if (!C || !ExpandableUsers.contains(C))
        continue;

      // Non-phi users get the expansion right before themselves; phi users
      // get it in the corresponding incoming block.
      BasicBlock::iterator BI = I->getIterator();
      BasicBlock *IncomingBB = nullptr;
      if (Phi) {
        IncomingBB = Phi->getIncomingBlock(U);
        auto Cached = PhiExpansions.find({IncomingBB, C});
        if (Cached != PhiExpansions.end()) {
          U.set(Cached->second);
          continue;
        }
        BI = IncomingBB->getFirstInsertionPt();
        assert(BI != IncomingBB->end() && "Unexpected empty basic block");
      }

      Changed = true;
      auto NewInsts = expandUser(BI, C);
      for (auto *NI : NewInsts)
        NI->setDebugLoc(Loc);
      InstructionWorklist.insert(NewInsts.begin(), NewInsts.end());
      U.set(NewInsts.back());
      if (Phi)
        PhiExpansions[{IncomingBB, C}] = NewInsts.back();
    }
  }

  if (RemoveDeadConstants)
    for (Constant *C : Consts)
      C->removeDeadConstantUsers();

  return Changed;
}
} // namespace llvm