diff options
author | Dimitry Andric <dim@FreeBSD.org> | 2023-02-11 12:38:04 +0000 |
---|---|---|
committer | Dimitry Andric <dim@FreeBSD.org> | 2023-02-11 12:38:11 +0000 |
commit | e3b557809604d036af6e00c60f012c2025b59a5e (patch) | |
tree | 8a11ba2269a3b669601e2fd41145b174008f4da8 /llvm/lib/CodeGen/MachineUniformityAnalysis.cpp | |
parent | 08e8dd7b9db7bb4a9de26d44c1cbfd24e869c014 (diff) |
Diffstat (limited to 'llvm/lib/CodeGen/MachineUniformityAnalysis.cpp')
-rw-r--r-- | llvm/lib/CodeGen/MachineUniformityAnalysis.cpp | 223 |
1 files changed, 223 insertions, 0 deletions
diff --git a/llvm/lib/CodeGen/MachineUniformityAnalysis.cpp b/llvm/lib/CodeGen/MachineUniformityAnalysis.cpp new file mode 100644 index 000000000000..2fe5e40a58c2 --- /dev/null +++ b/llvm/lib/CodeGen/MachineUniformityAnalysis.cpp @@ -0,0 +1,223 @@ +//===- MachineUniformityAnalysis.cpp --------------------------------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#include "llvm/CodeGen/MachineUniformityAnalysis.h" +#include "llvm/ADT/GenericUniformityImpl.h" +#include "llvm/CodeGen/MachineCycleAnalysis.h" +#include "llvm/CodeGen/MachineDominators.h" +#include "llvm/CodeGen/MachineRegisterInfo.h" +#include "llvm/CodeGen/MachineSSAContext.h" +#include "llvm/CodeGen/TargetInstrInfo.h" +#include "llvm/InitializePasses.h" + +using namespace llvm; + +template <> +bool llvm::GenericUniformityAnalysisImpl<MachineSSAContext>::hasDivergentDefs( + const MachineInstr &I) const { + for (auto &op : I.operands()) { + if (!op.isReg() || !op.isDef()) + continue; + if (isDivergent(op.getReg())) + return true; + } + return false; +} + +template <> +bool llvm::GenericUniformityAnalysisImpl<MachineSSAContext>::markDefsDivergent( + const MachineInstr &Instr, bool AllDefsDivergent) { + bool insertedDivergent = false; + const auto &MRI = F.getRegInfo(); + const auto &TRI = *MRI.getTargetRegisterInfo(); + for (auto &op : Instr.operands()) { + if (!op.isReg() || !op.isDef()) + continue; + if (!op.getReg().isVirtual()) + continue; + assert(!op.getSubReg()); + if (!AllDefsDivergent) { + auto *RC = MRI.getRegClassOrNull(op.getReg()); + if (RC && !TRI.isDivergentRegClass(RC)) + continue; + } + insertedDivergent |= markDivergent(op.getReg()); + } + return insertedDivergent; +} + +template <> +void llvm::GenericUniformityAnalysisImpl<MachineSSAContext>::initialize() { + const auto &InstrInfo = *F.getSubtarget().getInstrInfo(); + + for (const MachineBasicBlock &block : F) { + for (const MachineInstr &instr : block) { + auto uniformity = InstrInfo.getInstructionUniformity(instr); + if (uniformity == InstructionUniformity::AlwaysUniform) { + addUniformOverride(instr); + continue; + } + + if (uniformity == InstructionUniformity::NeverUniform) { + markDefsDivergent(instr, /* AllDefsDivergent = */ false); + } + } + } +} + +template <> +void llvm::GenericUniformityAnalysisImpl<MachineSSAContext>::pushUsers( + Register Reg) { + const auto &RegInfo = F.getRegInfo(); + for (MachineInstr &UserInstr : RegInfo.use_instructions(Reg)) { + if (isAlwaysUniform(UserInstr)) + continue; + if (markDivergent(UserInstr)) + Worklist.push_back(&UserInstr); + } +} + +template <> +void llvm::GenericUniformityAnalysisImpl<MachineSSAContext>::pushUsers( + const MachineInstr &Instr) { + assert(!isAlwaysUniform(Instr)); + if (Instr.isTerminator()) + return; + for (const MachineOperand &op : Instr.operands()) { + if (op.isReg() && op.isDef() && op.getReg().isVirtual()) + pushUsers(op.getReg()); + } +} + +template <> +bool llvm::GenericUniformityAnalysisImpl<MachineSSAContext>::usesValueFromCycle( + const MachineInstr &I, const MachineCycle &DefCycle) const { + assert(!isAlwaysUniform(I)); + for (auto &Op : I.operands()) { + if (!Op.isReg() || !Op.readsReg()) + continue; + auto Reg = Op.getReg(); + assert(Reg.isVirtual()); + auto *Def = F.getRegInfo().getVRegDef(Reg); + if (DefCycle.contains(Def->getParent())) + return true; + } + return false; +} + +// This ensures explicit instantiation of +// GenericUniformityAnalysisImpl::ImplDeleter::operator() +template class llvm::GenericUniformityInfo<MachineSSAContext>; +template struct llvm::GenericUniformityAnalysisImplDeleter< + llvm::GenericUniformityAnalysisImpl<MachineSSAContext>>; + +MachineUniformityInfo +llvm::computeMachineUniformityInfo(MachineFunction &F, + const MachineCycleInfo &cycleInfo, + const MachineDomTree &domTree) { + assert(F.getRegInfo().isSSA() && "Expected to be run on SSA form!"); + return MachineUniformityInfo(F, domTree, cycleInfo); +} + +namespace { + +/// Legacy analysis pass which computes a \ref MachineUniformityInfo. +class MachineUniformityAnalysisPass : public MachineFunctionPass { + MachineUniformityInfo UI; + +public: + static char ID; + + MachineUniformityAnalysisPass(); + + MachineUniformityInfo &getUniformityInfo() { return UI; } + const MachineUniformityInfo &getUniformityInfo() const { return UI; } + + bool runOnMachineFunction(MachineFunction &F) override; + void getAnalysisUsage(AnalysisUsage &AU) const override; + void print(raw_ostream &OS, const Module *M = nullptr) const override; + + // TODO: verify analysis +}; + +class MachineUniformityInfoPrinterPass : public MachineFunctionPass { +public: + static char ID; + + MachineUniformityInfoPrinterPass(); + + bool runOnMachineFunction(MachineFunction &F) override; + void getAnalysisUsage(AnalysisUsage &AU) const override; +}; + +} // namespace + +char MachineUniformityAnalysisPass::ID = 0; + +MachineUniformityAnalysisPass::MachineUniformityAnalysisPass() + : MachineFunctionPass(ID) { + initializeMachineUniformityAnalysisPassPass(*PassRegistry::getPassRegistry()); +} + +INITIALIZE_PASS_BEGIN(MachineUniformityAnalysisPass, "machine-uniformity", + "Machine Uniformity Info Analysis", true, true) +INITIALIZE_PASS_DEPENDENCY(MachineCycleInfoWrapperPass) +INITIALIZE_PASS_DEPENDENCY(MachineDominatorTree) +INITIALIZE_PASS_END(MachineUniformityAnalysisPass, "machine-uniformity", + "Machine Uniformity Info Analysis", true, true) + +void MachineUniformityAnalysisPass::getAnalysisUsage(AnalysisUsage &AU) const { + AU.setPreservesAll(); + AU.addRequired<MachineCycleInfoWrapperPass>(); + AU.addRequired<MachineDominatorTree>(); + MachineFunctionPass::getAnalysisUsage(AU); +} + +bool MachineUniformityAnalysisPass::runOnMachineFunction(MachineFunction &MF) { + auto &DomTree = getAnalysis<MachineDominatorTree>().getBase(); + auto &CI = getAnalysis<MachineCycleInfoWrapperPass>().getCycleInfo(); + UI = computeMachineUniformityInfo(MF, CI, DomTree); + return false; +} + +void MachineUniformityAnalysisPass::print(raw_ostream &OS, + const Module *) const { + OS << "MachineUniformityInfo for function: " << UI.getFunction().getName() + << "\n"; + UI.print(OS); +} + +char MachineUniformityInfoPrinterPass::ID = 0; + +MachineUniformityInfoPrinterPass::MachineUniformityInfoPrinterPass() + : MachineFunctionPass(ID) { + initializeMachineUniformityInfoPrinterPassPass( + *PassRegistry::getPassRegistry()); +} + +INITIALIZE_PASS_BEGIN(MachineUniformityInfoPrinterPass, + "print-machine-uniformity", + "Print Machine Uniformity Info Analysis", true, true) +INITIALIZE_PASS_DEPENDENCY(MachineUniformityAnalysisPass) +INITIALIZE_PASS_END(MachineUniformityInfoPrinterPass, + "print-machine-uniformity", + "Print Machine Uniformity Info Analysis", true, true) + +void MachineUniformityInfoPrinterPass::getAnalysisUsage( + AnalysisUsage &AU) const { + AU.setPreservesAll(); + AU.addRequired<MachineUniformityAnalysisPass>(); + MachineFunctionPass::getAnalysisUsage(AU); +} + +bool MachineUniformityInfoPrinterPass::runOnMachineFunction( + MachineFunction &F) { + auto &UI = getAnalysis<MachineUniformityAnalysisPass>(); + UI.print(errs()); + return false; +} |