diff options
Diffstat (limited to 'contrib/llvm-project/llvm/lib/Analysis/LegacyDivergenceAnalysis.cpp')
| -rw-r--r-- | contrib/llvm-project/llvm/lib/Analysis/LegacyDivergenceAnalysis.cpp | 43 |
1 files changed, 31 insertions, 12 deletions
diff --git a/contrib/llvm-project/llvm/lib/Analysis/LegacyDivergenceAnalysis.cpp b/contrib/llvm-project/llvm/lib/Analysis/LegacyDivergenceAnalysis.cpp index 52212e1c42aa..0f274429f11f 100644 --- a/contrib/llvm-project/llvm/lib/Analysis/LegacyDivergenceAnalysis.cpp +++ b/contrib/llvm-project/llvm/lib/Analysis/LegacyDivergenceAnalysis.cpp @@ -64,10 +64,10 @@ // //===----------------------------------------------------------------------===// +#include "llvm/Analysis/LegacyDivergenceAnalysis.h" #include "llvm/ADT/PostOrderIterator.h" #include "llvm/Analysis/CFG.h" #include "llvm/Analysis/DivergenceAnalysis.h" -#include "llvm/Analysis/LegacyDivergenceAnalysis.h" #include "llvm/Analysis/Passes.h" #include "llvm/Analysis/PostDominators.h" #include "llvm/Analysis/TargetTransformInfo.h" @@ -75,6 +75,8 @@ #include "llvm/IR/InstIterator.h" #include "llvm/IR/Instructions.h" #include "llvm/IR/Value.h" +#include "llvm/InitializePasses.h" +#include "llvm/Support/CommandLine.h" #include "llvm/Support/Debug.h" #include "llvm/Support/raw_ostream.h" #include <vector> @@ -93,8 +95,9 @@ namespace { class DivergencePropagator { public: DivergencePropagator(Function &F, TargetTransformInfo &TTI, DominatorTree &DT, - PostDominatorTree &PDT, DenseSet<const Value *> &DV) - : F(F), TTI(TTI), DT(DT), PDT(PDT), DV(DV) {} + PostDominatorTree &PDT, DenseSet<const Value *> &DV, + DenseSet<const Use *> &DU) + : F(F), TTI(TTI), DT(DT), PDT(PDT), DV(DV), DU(DU) {} void populateWithSourcesOfDivergence(); void propagate(); @@ -118,11 +121,14 @@ private: PostDominatorTree &PDT; std::vector<Value *> Worklist; // Stack for DFS. DenseSet<const Value *> &DV; // Stores all divergent values. + DenseSet<const Use *> &DU; // Stores divergent uses of possibly uniform + // values. }; void DivergencePropagator::populateWithSourcesOfDivergence() { Worklist.clear(); DV.clear(); + DU.clear(); for (auto &I : instructions(F)) { if (TTI.isSourceOfDivergence(&I)) { Worklist.push_back(&I); @@ -197,8 +203,10 @@ void DivergencePropagator::exploreSyncDependency(Instruction *TI) { // dominators of TI until it is outside the influence region. BasicBlock *InfluencedBB = ThisBB; while (InfluenceRegion.count(InfluencedBB)) { - for (auto &I : *InfluencedBB) - findUsersOutsideInfluenceRegion(I, InfluenceRegion); + for (auto &I : *InfluencedBB) { + if (!DV.count(&I)) + findUsersOutsideInfluenceRegion(I, InfluenceRegion); + } DomTreeNode *IDomNode = DT.getNode(InfluencedBB)->getIDom(); if (IDomNode == nullptr) break; @@ -208,9 +216,10 @@ void DivergencePropagator::exploreSyncDependency(Instruction *TI) { void DivergencePropagator::findUsersOutsideInfluenceRegion( Instruction &I, const DenseSet<BasicBlock *> &InfluenceRegion) { - for (User *U : I.users()) { - Instruction *UserInst = cast<Instruction>(U); + for (Use &Use : I.uses()) { + Instruction *UserInst = cast<Instruction>(Use.getUser()); if (!InfluenceRegion.count(UserInst->getParent())) { + DU.insert(&Use); if (DV.insert(UserInst).second) Worklist.push_back(UserInst); } @@ -250,9 +259,8 @@ void DivergencePropagator::computeInfluenceRegion( void DivergencePropagator::exploreDataDependency(Value *V) { // Follow def-use chains of V. for (User *U : V->users()) { - Instruction *UserInst = cast<Instruction>(U); - if (!TTI.isAlwaysUniform(U) && DV.insert(UserInst).second) - Worklist.push_back(UserInst); + if (!TTI.isAlwaysUniform(U) && DV.insert(U).second) + Worklist.push_back(U); } } @@ -275,6 +283,9 @@ void DivergencePropagator::propagate() { // Register this pass. char LegacyDivergenceAnalysis::ID = 0; +LegacyDivergenceAnalysis::LegacyDivergenceAnalysis() : FunctionPass(ID) { + initializeLegacyDivergenceAnalysisPass(*PassRegistry::getPassRegistry()); +} INITIALIZE_PASS_BEGIN(LegacyDivergenceAnalysis, "divergence", "Legacy Divergence Analysis", false, true) INITIALIZE_PASS_DEPENDENCY(DominatorTreeWrapperPass) @@ -320,6 +331,7 @@ bool LegacyDivergenceAnalysis::runOnFunction(Function &F) { return false; DivergentValues.clear(); + DivergentUses.clear(); gpuDA = nullptr; auto &DT = getAnalysis<DominatorTreeWrapperPass>().getDomTree(); @@ -328,11 +340,11 @@ bool LegacyDivergenceAnalysis::runOnFunction(Function &F) { if (shouldUseGPUDivergenceAnalysis(F)) { // run the new GPU divergence analysis auto &LI = getAnalysis<LoopInfoWrapperPass>().getLoopInfo(); - gpuDA = llvm::make_unique<GPUDivergenceAnalysis>(F, DT, PDT, LI, TTI); + gpuDA = std::make_unique<GPUDivergenceAnalysis>(F, DT, PDT, LI, TTI); } else { // run LLVM's existing DivergenceAnalysis - DivergencePropagator DP(F, TTI, DT, PDT, DivergentValues); + DivergencePropagator DP(F, TTI, DT, PDT, DivergentValues, DivergentUses); DP.populateWithSourcesOfDivergence(); DP.propagate(); } @@ -351,6 +363,13 @@ bool LegacyDivergenceAnalysis::isDivergent(const Value *V) const { return DivergentValues.count(V); } +bool LegacyDivergenceAnalysis::isDivergentUse(const Use *U) const { + if (gpuDA) { + return gpuDA->isDivergentUse(*U); + } + return DivergentValues.count(U->get()) || DivergentUses.count(U); +} + void LegacyDivergenceAnalysis::print(raw_ostream &OS, const Module *) const { if ((!gpuDA || !gpuDA->hasDivergence()) && DivergentValues.empty()) return; |
