diff options
author | Dimitry Andric <dim@FreeBSD.org> | 2023-09-02 21:17:18 +0000 |
---|---|---|
committer | Dimitry Andric <dim@FreeBSD.org> | 2023-12-08 17:34:50 +0000 |
commit | 06c3fb2749bda94cb5201f81ffdb8fa6c3161b2e (patch) | |
tree | 62f873df87c7c675557a179e0c4c83fe9f3087bc /contrib/llvm-project/llvm/lib/Analysis/UniformityAnalysis.cpp | |
parent | cf037972ea8863e2bab7461d77345367d2c1e054 (diff) | |
parent | 7fa27ce4a07f19b07799a767fc29416f3b625afb (diff) |
Diffstat (limited to 'contrib/llvm-project/llvm/lib/Analysis/UniformityAnalysis.cpp')
-rw-r--r-- | contrib/llvm-project/llvm/lib/Analysis/UniformityAnalysis.cpp | 74 |
1 files changed, 52 insertions, 22 deletions
diff --git a/contrib/llvm-project/llvm/lib/Analysis/UniformityAnalysis.cpp b/contrib/llvm-project/llvm/lib/Analysis/UniformityAnalysis.cpp index 8ed5af8a8d1c..bf0b194dcd70 100644 --- a/contrib/llvm-project/llvm/lib/Analysis/UniformityAnalysis.cpp +++ b/contrib/llvm-project/llvm/lib/Analysis/UniformityAnalysis.cpp @@ -1,4 +1,4 @@ -//===- ConvergenceUtils.cpp -----------------------------------------------===// +//===- UniformityAnalysis.cpp ---------------------------------------------===// // // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. // See https://llvm.org/LICENSE.txt for license information. @@ -26,18 +26,16 @@ bool llvm::GenericUniformityAnalysisImpl<SSAContext>::hasDivergentDefs( template <> bool llvm::GenericUniformityAnalysisImpl<SSAContext>::markDefsDivergent( - const Instruction &Instr, bool AllDefsDivergent) { - return markDivergent(&Instr); + const Instruction &Instr) { + return markDivergent(cast<Value>(&Instr)); } template <> void llvm::GenericUniformityAnalysisImpl<SSAContext>::initialize() { for (auto &I : instructions(F)) { - if (TTI->isSourceOfDivergence(&I)) { - assert(!I.isTerminator()); + if (TTI->isSourceOfDivergence(&I)) markDivergent(I); - } else if (TTI->isAlwaysUniform(&I)) { + else if (TTI->isAlwaysUniform(&I)) addUniformOverride(I); - } } for (auto &Arg : F.args()) { if (TTI->isSourceOfDivergence(&Arg)) { @@ -50,13 +48,8 @@ template <> void llvm::GenericUniformityAnalysisImpl<SSAContext>::pushUsers( const Value *V) { for (const auto *User : V->users()) { - const auto *UserInstr = dyn_cast<const Instruction>(User); - if (!UserInstr) - continue; - if (isAlwaysUniform(*UserInstr)) - continue; - if (markDivergent(*UserInstr)) { - Worklist.push_back(UserInstr); + if (const auto *UserInstr = dyn_cast<const Instruction>(User)) { + markDivergent(*UserInstr); } } } @@ -73,8 +66,7 @@ void llvm::GenericUniformityAnalysisImpl<SSAContext>::pushUsers( template <> bool llvm::GenericUniformityAnalysisImpl<SSAContext>::usesValueFromCycle( const Instruction &I, const Cycle &DefCycle) const { - if (isAlwaysUniform(I)) - return false; + assert(!isAlwaysUniform(I)); for (const Use &U : I.operands()) { if (auto *I = dyn_cast<Instruction>(&U)) { if (DefCycle.contains(I->getParent())) @@ -84,6 +76,33 @@ bool llvm::GenericUniformityAnalysisImpl<SSAContext>::usesValueFromCycle( return false; } +template <> +void llvm::GenericUniformityAnalysisImpl< + SSAContext>::propagateTemporalDivergence(const Instruction &I, + const Cycle &DefCycle) { + if (isDivergent(I)) + return; + for (auto *User : I.users()) { + auto *UserInstr = cast<Instruction>(User); + if (DefCycle.contains(UserInstr->getParent())) + continue; + markDivergent(*UserInstr); + } +} + +template <> +bool llvm::GenericUniformityAnalysisImpl<SSAContext>::isDivergentUse( + const Use &U) const { + const auto *V = U.get(); + if (isDivergent(V)) + return true; + if (const auto *DefInstr = dyn_cast<Instruction>(V)) { + const auto *UseInstr = cast<Instruction>(U.getUser()); + return isTemporalDivergent(*UseInstr->getParent(), *DefInstr); + } + return false; +} + // This ensures explicit instantiation of // GenericUniformityAnalysisImpl::ImplDeleter::operator() template class llvm::GenericUniformityInfo<SSAContext>; @@ -99,7 +118,12 @@ llvm::UniformityInfo UniformityInfoAnalysis::run(Function &F, auto &DT = FAM.getResult<DominatorTreeAnalysis>(F); auto &TTI = FAM.getResult<TargetIRAnalysis>(F); auto &CI = FAM.getResult<CycleAnalysis>(F); - return UniformityInfo{F, DT, CI, &TTI}; + UniformityInfo UI{F, DT, CI, &TTI}; + // Skip computation if we can assume everything is uniform. + if (TTI.hasBranchDivergence(&F)) + UI.compute(); + + return UI; } AnalysisKey UniformityInfoAnalysis::Key; @@ -125,17 +149,18 @@ UniformityInfoWrapperPass::UniformityInfoWrapperPass() : FunctionPass(ID) { initializeUniformityInfoWrapperPassPass(*PassRegistry::getPassRegistry()); } -INITIALIZE_PASS_BEGIN(UniformityInfoWrapperPass, "uniforminfo", - "Uniform Info Analysis", true, true) +INITIALIZE_PASS_BEGIN(UniformityInfoWrapperPass, "uniformity", + "Uniformity Analysis", true, true) INITIALIZE_PASS_DEPENDENCY(DominatorTreeWrapperPass) +INITIALIZE_PASS_DEPENDENCY(CycleInfoWrapperPass) INITIALIZE_PASS_DEPENDENCY(TargetTransformInfoWrapperPass) -INITIALIZE_PASS_END(UniformityInfoWrapperPass, "uniforminfo", - "Uniform Info Analysis", true, true) +INITIALIZE_PASS_END(UniformityInfoWrapperPass, "uniformity", + "Uniformity Analysis", true, true) void UniformityInfoWrapperPass::getAnalysisUsage(AnalysisUsage &AU) const { AU.setPreservesAll(); AU.addRequired<DominatorTreeWrapperPass>(); - AU.addRequired<CycleInfoWrapperPass>(); + AU.addRequiredTransitive<CycleInfoWrapperPass>(); AU.addRequired<TargetTransformInfoWrapperPass>(); } @@ -148,6 +173,11 @@ bool UniformityInfoWrapperPass::runOnFunction(Function &F) { m_function = &F; m_uniformityInfo = UniformityInfo{F, domTree, cycleInfo, &targetTransformInfo}; + + // Skip computation if we can assume everything is uniform. + if (targetTransformInfo.hasBranchDivergence(m_function)) + m_uniformityInfo.compute(); + return false; } |