aboutsummaryrefslogtreecommitdiff
path: root/llvm/lib/Analysis/UniformityAnalysis.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'llvm/lib/Analysis/UniformityAnalysis.cpp')
-rw-r--r--llvm/lib/Analysis/UniformityAnalysis.cpp161
1 files changed, 161 insertions, 0 deletions
diff --git a/llvm/lib/Analysis/UniformityAnalysis.cpp b/llvm/lib/Analysis/UniformityAnalysis.cpp
new file mode 100644
index 000000000000..8ed5af8a8d1c
--- /dev/null
+++ b/llvm/lib/Analysis/UniformityAnalysis.cpp
@@ -0,0 +1,161 @@
+//===- ConvergenceUtils.cpp -----------------------------------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "llvm/Analysis/UniformityAnalysis.h"
+#include "llvm/ADT/GenericUniformityImpl.h"
+#include "llvm/Analysis/CycleAnalysis.h"
+#include "llvm/Analysis/TargetTransformInfo.h"
+#include "llvm/IR/Constants.h"
+#include "llvm/IR/Dominators.h"
+#include "llvm/IR/InstIterator.h"
+#include "llvm/IR/Instructions.h"
+#include "llvm/InitializePasses.h"
+
+using namespace llvm;
+
+template <>
+bool llvm::GenericUniformityAnalysisImpl<SSAContext>::hasDivergentDefs(
+ const Instruction &I) const {
+ return isDivergent((const Value *)&I);
+}
+
+template <>
+bool llvm::GenericUniformityAnalysisImpl<SSAContext>::markDefsDivergent(
+ const Instruction &Instr, bool AllDefsDivergent) {
+ return markDivergent(&Instr);
+}
+
+template <> void llvm::GenericUniformityAnalysisImpl<SSAContext>::initialize() {
+ for (auto &I : instructions(F)) {
+ if (TTI->isSourceOfDivergence(&I)) {
+ assert(!I.isTerminator());
+ markDivergent(I);
+ } else if (TTI->isAlwaysUniform(&I)) {
+ addUniformOverride(I);
+ }
+ }
+ for (auto &Arg : F.args()) {
+ if (TTI->isSourceOfDivergence(&Arg)) {
+ markDivergent(&Arg);
+ }
+ }
+}
+
+template <>
+void llvm::GenericUniformityAnalysisImpl<SSAContext>::pushUsers(
+ const Value *V) {
+ for (const auto *User : V->users()) {
+ const auto *UserInstr = dyn_cast<const Instruction>(User);
+ if (!UserInstr)
+ continue;
+ if (isAlwaysUniform(*UserInstr))
+ continue;
+ if (markDivergent(*UserInstr)) {
+ Worklist.push_back(UserInstr);
+ }
+ }
+}
+
+template <>
+void llvm::GenericUniformityAnalysisImpl<SSAContext>::pushUsers(
+ const Instruction &Instr) {
+ assert(!isAlwaysUniform(Instr));
+ if (Instr.isTerminator())
+ return;
+ pushUsers(cast<Value>(&Instr));
+}
+
+template <>
+bool llvm::GenericUniformityAnalysisImpl<SSAContext>::usesValueFromCycle(
+ const Instruction &I, const Cycle &DefCycle) const {
+ if (isAlwaysUniform(I))
+ return false;
+ for (const Use &U : I.operands()) {
+ if (auto *I = dyn_cast<Instruction>(&U)) {
+ if (DefCycle.contains(I->getParent()))
+ return true;
+ }
+ }
+ return false;
+}
+
+// This ensures explicit instantiation of
+// GenericUniformityAnalysisImpl::ImplDeleter::operator()
+template class llvm::GenericUniformityInfo<SSAContext>;
+template struct llvm::GenericUniformityAnalysisImplDeleter<
+ llvm::GenericUniformityAnalysisImpl<SSAContext>>;
+
+//===----------------------------------------------------------------------===//
+// UniformityInfoAnalysis and related pass implementations
+//===----------------------------------------------------------------------===//
+
+llvm::UniformityInfo UniformityInfoAnalysis::run(Function &F,
+ FunctionAnalysisManager &FAM) {
+ auto &DT = FAM.getResult<DominatorTreeAnalysis>(F);
+ auto &TTI = FAM.getResult<TargetIRAnalysis>(F);
+ auto &CI = FAM.getResult<CycleAnalysis>(F);
+ return UniformityInfo{F, DT, CI, &TTI};
+}
+
+AnalysisKey UniformityInfoAnalysis::Key;
+
+UniformityInfoPrinterPass::UniformityInfoPrinterPass(raw_ostream &OS)
+ : OS(OS) {}
+
+PreservedAnalyses UniformityInfoPrinterPass::run(Function &F,
+ FunctionAnalysisManager &AM) {
+ OS << "UniformityInfo for function '" << F.getName() << "':\n";
+ AM.getResult<UniformityInfoAnalysis>(F).print(OS);
+
+ return PreservedAnalyses::all();
+}
+
+//===----------------------------------------------------------------------===//
+// UniformityInfoWrapperPass Implementation
+//===----------------------------------------------------------------------===//
+
+char UniformityInfoWrapperPass::ID = 0;
+
+UniformityInfoWrapperPass::UniformityInfoWrapperPass() : FunctionPass(ID) {
+ initializeUniformityInfoWrapperPassPass(*PassRegistry::getPassRegistry());
+}
+
+INITIALIZE_PASS_BEGIN(UniformityInfoWrapperPass, "uniforminfo",
+ "Uniform Info Analysis", true, true)
+INITIALIZE_PASS_DEPENDENCY(DominatorTreeWrapperPass)
+INITIALIZE_PASS_DEPENDENCY(TargetTransformInfoWrapperPass)
+INITIALIZE_PASS_END(UniformityInfoWrapperPass, "uniforminfo",
+ "Uniform Info Analysis", true, true)
+
+void UniformityInfoWrapperPass::getAnalysisUsage(AnalysisUsage &AU) const {
+ AU.setPreservesAll();
+ AU.addRequired<DominatorTreeWrapperPass>();
+ AU.addRequired<CycleInfoWrapperPass>();
+ AU.addRequired<TargetTransformInfoWrapperPass>();
+}
+
+bool UniformityInfoWrapperPass::runOnFunction(Function &F) {
+ auto &cycleInfo = getAnalysis<CycleInfoWrapperPass>().getResult();
+ auto &domTree = getAnalysis<DominatorTreeWrapperPass>().getDomTree();
+ auto &targetTransformInfo =
+ getAnalysis<TargetTransformInfoWrapperPass>().getTTI(F);
+
+ m_function = &F;
+ m_uniformityInfo =
+ UniformityInfo{F, domTree, cycleInfo, &targetTransformInfo};
+ return false;
+}
+
+void UniformityInfoWrapperPass::print(raw_ostream &OS, const Module *) const {
+ OS << "UniformityInfo for function '" << m_function->getName() << "':\n";
+}
+
+void UniformityInfoWrapperPass::releaseMemory() {
+ m_uniformityInfo = UniformityInfo{};
+ m_function = nullptr;
+}