mirror of
https://github.com/llvm/llvm-project.git
synced 2025-04-26 22:56:05 +00:00

Record all uses outside a cycle with a divergent exit during propagateTemporalDivergence in the uniformity analysis. Using this list of candidates for temporal-divergence lowering, and excluding known lane masks produced by control-flow intrinsics, find sources inside the cycle that are uniform and not i1. Temporal-divergence lowering (non-i1): create a copy (v_mov) to a VGPR with an implicit exec operand (to stop other passes from moving this copy outside the cycle), and use this VGPR outside the cycle instead of the original uniform source.
246 lines
7.7 KiB
C++
246 lines
7.7 KiB
C++
//===- MachineUniformityAnalysis.cpp --------------------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
|
|
|
|
#include "llvm/CodeGen/MachineUniformityAnalysis.h"
|
|
#include "llvm/ADT/GenericUniformityImpl.h"
|
|
#include "llvm/CodeGen/MachineCycleAnalysis.h"
|
|
#include "llvm/CodeGen/MachineDominators.h"
|
|
#include "llvm/CodeGen/MachineRegisterInfo.h"
|
|
#include "llvm/CodeGen/MachineSSAContext.h"
|
|
#include "llvm/CodeGen/TargetInstrInfo.h"
|
|
#include "llvm/InitializePasses.h"
|
|
|
|
using namespace llvm;
|
|
|
|
template <>
|
|
bool llvm::GenericUniformityAnalysisImpl<MachineSSAContext>::hasDivergentDefs(
|
|
const MachineInstr &I) const {
|
|
for (auto &op : I.all_defs()) {
|
|
if (isDivergent(op.getReg()))
|
|
return true;
|
|
}
|
|
return false;
|
|
}
|
|
|
|
template <>
|
|
bool llvm::GenericUniformityAnalysisImpl<MachineSSAContext>::markDefsDivergent(
|
|
const MachineInstr &Instr) {
|
|
bool insertedDivergent = false;
|
|
const auto &MRI = F.getRegInfo();
|
|
const auto &RBI = *F.getSubtarget().getRegBankInfo();
|
|
const auto &TRI = *MRI.getTargetRegisterInfo();
|
|
for (auto &op : Instr.all_defs()) {
|
|
if (!op.getReg().isVirtual())
|
|
continue;
|
|
assert(!op.getSubReg());
|
|
if (TRI.isUniformReg(MRI, RBI, op.getReg()))
|
|
continue;
|
|
insertedDivergent |= markDivergent(op.getReg());
|
|
}
|
|
return insertedDivergent;
|
|
}
|
|
|
|
template <>
|
|
void llvm::GenericUniformityAnalysisImpl<MachineSSAContext>::initialize() {
|
|
const auto &InstrInfo = *F.getSubtarget().getInstrInfo();
|
|
|
|
for (const MachineBasicBlock &block : F) {
|
|
for (const MachineInstr &instr : block) {
|
|
auto uniformity = InstrInfo.getInstructionUniformity(instr);
|
|
if (uniformity == InstructionUniformity::AlwaysUniform) {
|
|
addUniformOverride(instr);
|
|
continue;
|
|
}
|
|
|
|
if (uniformity == InstructionUniformity::NeverUniform) {
|
|
markDivergent(instr);
|
|
}
|
|
}
|
|
}
|
|
}
|
|
|
|
template <>
|
|
void llvm::GenericUniformityAnalysisImpl<MachineSSAContext>::pushUsers(
|
|
Register Reg) {
|
|
assert(isDivergent(Reg));
|
|
const auto &RegInfo = F.getRegInfo();
|
|
for (MachineInstr &UserInstr : RegInfo.use_instructions(Reg)) {
|
|
markDivergent(UserInstr);
|
|
}
|
|
}
|
|
|
|
template <>
|
|
void llvm::GenericUniformityAnalysisImpl<MachineSSAContext>::pushUsers(
|
|
const MachineInstr &Instr) {
|
|
assert(!isAlwaysUniform(Instr));
|
|
if (Instr.isTerminator())
|
|
return;
|
|
for (const MachineOperand &op : Instr.all_defs()) {
|
|
auto Reg = op.getReg();
|
|
if (isDivergent(Reg))
|
|
pushUsers(Reg);
|
|
}
|
|
}
|
|
|
|
template <>
|
|
bool llvm::GenericUniformityAnalysisImpl<MachineSSAContext>::usesValueFromCycle(
|
|
const MachineInstr &I, const MachineCycle &DefCycle) const {
|
|
assert(!isAlwaysUniform(I));
|
|
for (auto &Op : I.operands()) {
|
|
if (!Op.isReg() || !Op.readsReg())
|
|
continue;
|
|
auto Reg = Op.getReg();
|
|
|
|
// FIXME: Physical registers need to be properly checked instead of always
|
|
// returning true
|
|
if (Reg.isPhysical())
|
|
return true;
|
|
|
|
auto *Def = F.getRegInfo().getVRegDef(Reg);
|
|
if (DefCycle.contains(Def->getParent()))
|
|
return true;
|
|
}
|
|
return false;
|
|
}
|
|
|
|
template <>
|
|
void llvm::GenericUniformityAnalysisImpl<MachineSSAContext>::
|
|
propagateTemporalDivergence(const MachineInstr &I,
|
|
const MachineCycle &DefCycle) {
|
|
const auto &RegInfo = F.getRegInfo();
|
|
for (auto &Op : I.all_defs()) {
|
|
if (!Op.getReg().isVirtual())
|
|
continue;
|
|
auto Reg = Op.getReg();
|
|
for (MachineInstr &UserInstr : RegInfo.use_instructions(Reg)) {
|
|
if (DefCycle.contains(UserInstr.getParent()))
|
|
continue;
|
|
markDivergent(UserInstr);
|
|
|
|
recordTemporalDivergence(Reg, &UserInstr, &DefCycle);
|
|
}
|
|
}
|
|
}
|
|
|
|
template <>
|
|
bool llvm::GenericUniformityAnalysisImpl<MachineSSAContext>::isDivergentUse(
|
|
const MachineOperand &U) const {
|
|
if (!U.isReg())
|
|
return false;
|
|
|
|
auto Reg = U.getReg();
|
|
if (isDivergent(Reg))
|
|
return true;
|
|
|
|
const auto &RegInfo = F.getRegInfo();
|
|
auto *Def = RegInfo.getOneDef(Reg);
|
|
if (!Def)
|
|
return true;
|
|
|
|
auto *DefInstr = Def->getParent();
|
|
auto *UseInstr = U.getParent();
|
|
return isTemporalDivergent(*UseInstr->getParent(), *DefInstr);
|
|
}
|
|
|
|
// This ensures explicit instantiation of
|
|
// GenericUniformityAnalysisImpl::ImplDeleter::operator()
|
|
template class llvm::GenericUniformityInfo<MachineSSAContext>;
|
|
template struct llvm::GenericUniformityAnalysisImplDeleter<
|
|
llvm::GenericUniformityAnalysisImpl<MachineSSAContext>>;
|
|
|
|
/// Builds the uniformity info for \p F from its cycle info and dominator tree.
///
/// The fixed-point computation runs only when \p HasBranchDivergence is set.
/// Otherwise the freshly constructed info is returned without calling
/// compute(). \p F must be in SSA form (asserted).
MachineUniformityInfo llvm::computeMachineUniformityInfo(
    MachineFunction &F, const MachineCycleInfo &cycleInfo,
    const MachineDominatorTree &domTree, bool HasBranchDivergence) {
  assert(F.getRegInfo().isSSA() && "Expected to be run on SSA form!");
  MachineUniformityInfo Result(domTree, cycleInfo);
  if (!HasBranchDivergence)
    return Result;
  Result.compute();
  return Result;
}
|
|
|
|
namespace {
|
|
|
|
class MachineUniformityInfoPrinterPass : public MachineFunctionPass {
|
|
public:
|
|
static char ID;
|
|
|
|
MachineUniformityInfoPrinterPass();
|
|
|
|
bool runOnMachineFunction(MachineFunction &F) override;
|
|
void getAnalysisUsage(AnalysisUsage &AU) const override;
|
|
};
|
|
|
|
} // namespace
|
|
|
|
char MachineUniformityAnalysisPass::ID = 0;
|
|
|
|
MachineUniformityAnalysisPass::MachineUniformityAnalysisPass()
|
|
: MachineFunctionPass(ID) {
|
|
initializeMachineUniformityAnalysisPassPass(*PassRegistry::getPassRegistry());
|
|
}
|
|
|
|
// Register the analysis under the "machine-uniformity" command-line name.
// The trailing (false, true) arguments are (CFGOnly, IsAnalysis).
INITIALIZE_PASS_BEGIN(MachineUniformityAnalysisPass, "machine-uniformity",
                      "Machine Uniformity Info Analysis", false, true)
INITIALIZE_PASS_DEPENDENCY(MachineCycleInfoWrapperPass)
INITIALIZE_PASS_DEPENDENCY(MachineDominatorTreeWrapperPass)
INITIALIZE_PASS_END(MachineUniformityAnalysisPass, "machine-uniformity",
                    "Machine Uniformity Info Analysis", false, true)
|
|
|
|
/// Declares this pass's analysis requirements.
///
/// Nothing in the function is modified, so everything is preserved. Cycle
/// info is required transitively; NOTE(review): presumably because the stored
/// uniformity result keeps referring to it, so confirm. The dominator tree is
/// only needed while computing.
void MachineUniformityAnalysisPass::getAnalysisUsage(AnalysisUsage &AU) const {
  AU.setPreservesAll();
  AU.addRequiredTransitive<MachineCycleInfoWrapperPass>();
  AU.addRequired<MachineDominatorTreeWrapperPass>();
  MachineFunctionPass::getAnalysisUsage(AU);
}
|
|
|
|
/// Computes and caches the uniformity info for \p MF.
/// \returns false because the function is never modified.
bool MachineUniformityAnalysisPass::runOnMachineFunction(MachineFunction &MF) {
  auto &DT = getAnalysis<MachineDominatorTreeWrapperPass>().getDomTree();
  auto &Cycles = getAnalysis<MachineCycleInfoWrapperPass>().getCycleInfo();

  // FIXME: Query TTI::hasBranchDivergence. -run-pass seems to end up with a
  // default NoTTI
  const bool HasBranchDivergence = true;
  UI = computeMachineUniformityInfo(MF, Cycles, DT, HasBranchDivergence);
  return false;
}
|
|
|
|
/// Writes a header naming the analysed function to \p OS, followed by the
/// uniformity info's own dump. The Module argument is unused.
void MachineUniformityAnalysisPass::print(raw_ostream &OS,
                                          const Module *) const {
  const auto &Fn = UI.getFunction();
  OS << "MachineUniformityInfo for function: " << Fn.getName() << "\n";
  UI.print(OS);
}
|
|
|
|
char MachineUniformityInfoPrinterPass::ID = 0;
|
|
|
|
MachineUniformityInfoPrinterPass::MachineUniformityInfoPrinterPass()
|
|
: MachineFunctionPass(ID) {
|
|
initializeMachineUniformityInfoPrinterPassPass(
|
|
*PassRegistry::getPassRegistry());
|
|
}
|
|
|
|
// Register the printer under "print-machine-uniformity".
// The trailing (true, true) arguments are (CFGOnly, IsAnalysis).
INITIALIZE_PASS_BEGIN(MachineUniformityInfoPrinterPass,
                      "print-machine-uniformity",
                      "Print Machine Uniformity Info Analysis", true, true)
INITIALIZE_PASS_DEPENDENCY(MachineUniformityAnalysisPass)
INITIALIZE_PASS_END(MachineUniformityInfoPrinterPass,
                    "print-machine-uniformity",
                    "Print Machine Uniformity Info Analysis", true, true)
|
|
|
|
/// The printer only reads the uniformity analysis, so it preserves
/// everything.
void MachineUniformityInfoPrinterPass::getAnalysisUsage(
    AnalysisUsage &AU) const {
  AU.setPreservesAll();
  AU.addRequired<MachineUniformityAnalysisPass>();
  MachineFunctionPass::getAnalysisUsage(AU);
}
|
|
|
|
/// Dumps the uniformity analysis result for the current function to errs().
/// \returns false because the function is never modified.
bool MachineUniformityInfoPrinterPass::runOnMachineFunction(
    MachineFunction &F) {
  const MachineUniformityAnalysisPass &Analysis =
      getAnalysis<MachineUniformityAnalysisPass>();
  Analysis.print(errs());
  return false;
}
|