aboutsummaryrefslogtreecommitdiff
path: root/contrib/llvm-project/llvm/lib/Analysis/UniformityAnalysis.cpp
diff options
context:
space:
mode:
authorDimitry Andric <dim@FreeBSD.org>2023-09-02 21:17:18 +0000
committerDimitry Andric <dim@FreeBSD.org>2023-12-08 17:34:50 +0000
commit06c3fb2749bda94cb5201f81ffdb8fa6c3161b2e (patch)
tree62f873df87c7c675557a179e0c4c83fe9f3087bc /contrib/llvm-project/llvm/lib/Analysis/UniformityAnalysis.cpp
parentcf037972ea8863e2bab7461d77345367d2c1e054 (diff)
parent7fa27ce4a07f19b07799a767fc29416f3b625afb (diff)
Diffstat (limited to 'contrib/llvm-project/llvm/lib/Analysis/UniformityAnalysis.cpp')
-rw-r--r--contrib/llvm-project/llvm/lib/Analysis/UniformityAnalysis.cpp74
1 files changed, 52 insertions, 22 deletions
diff --git a/contrib/llvm-project/llvm/lib/Analysis/UniformityAnalysis.cpp b/contrib/llvm-project/llvm/lib/Analysis/UniformityAnalysis.cpp
index 8ed5af8a8d1c..bf0b194dcd70 100644
--- a/contrib/llvm-project/llvm/lib/Analysis/UniformityAnalysis.cpp
+++ b/contrib/llvm-project/llvm/lib/Analysis/UniformityAnalysis.cpp
@@ -1,4 +1,4 @@
-//===- ConvergenceUtils.cpp -----------------------------------------------===//
+//===- UniformityAnalysis.cpp ---------------------------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
@@ -26,18 +26,16 @@ bool llvm::GenericUniformityAnalysisImpl<SSAContext>::hasDivergentDefs(
template <>
bool llvm::GenericUniformityAnalysisImpl<SSAContext>::markDefsDivergent(
- const Instruction &Instr, bool AllDefsDivergent) {
- return markDivergent(&Instr);
+ const Instruction &Instr) {
+ return markDivergent(cast<Value>(&Instr));
}
template <> void llvm::GenericUniformityAnalysisImpl<SSAContext>::initialize() {
for (auto &I : instructions(F)) {
- if (TTI->isSourceOfDivergence(&I)) {
- assert(!I.isTerminator());
+ if (TTI->isSourceOfDivergence(&I))
markDivergent(I);
- } else if (TTI->isAlwaysUniform(&I)) {
+ else if (TTI->isAlwaysUniform(&I))
addUniformOverride(I);
- }
}
for (auto &Arg : F.args()) {
if (TTI->isSourceOfDivergence(&Arg)) {
@@ -50,13 +48,8 @@ template <>
void llvm::GenericUniformityAnalysisImpl<SSAContext>::pushUsers(
const Value *V) {
for (const auto *User : V->users()) {
- const auto *UserInstr = dyn_cast<const Instruction>(User);
- if (!UserInstr)
- continue;
- if (isAlwaysUniform(*UserInstr))
- continue;
- if (markDivergent(*UserInstr)) {
- Worklist.push_back(UserInstr);
+ if (const auto *UserInstr = dyn_cast<const Instruction>(User)) {
+ markDivergent(*UserInstr);
}
}
}
@@ -73,8 +66,7 @@ void llvm::GenericUniformityAnalysisImpl<SSAContext>::pushUsers(
template <>
bool llvm::GenericUniformityAnalysisImpl<SSAContext>::usesValueFromCycle(
const Instruction &I, const Cycle &DefCycle) const {
- if (isAlwaysUniform(I))
- return false;
+ assert(!isAlwaysUniform(I));
for (const Use &U : I.operands()) {
if (auto *I = dyn_cast<Instruction>(&U)) {
if (DefCycle.contains(I->getParent()))
@@ -84,6 +76,33 @@ bool llvm::GenericUniformityAnalysisImpl<SSAContext>::usesValueFromCycle(
return false;
}
+template <>
+void llvm::GenericUniformityAnalysisImpl<
+ SSAContext>::propagateTemporalDivergence(const Instruction &I,
+ const Cycle &DefCycle) {
+ if (isDivergent(I))
+ return;
+ for (auto *User : I.users()) {
+ auto *UserInstr = cast<Instruction>(User);
+ if (DefCycle.contains(UserInstr->getParent()))
+ continue;
+ markDivergent(*UserInstr);
+ }
+}
+
+template <>
+bool llvm::GenericUniformityAnalysisImpl<SSAContext>::isDivergentUse(
+ const Use &U) const {
+ const auto *V = U.get();
+ if (isDivergent(V))
+ return true;
+ if (const auto *DefInstr = dyn_cast<Instruction>(V)) {
+ const auto *UseInstr = cast<Instruction>(U.getUser());
+ return isTemporalDivergent(*UseInstr->getParent(), *DefInstr);
+ }
+ return false;
+}
+
// This ensures explicit instantiation of
// GenericUniformityAnalysisImpl::ImplDeleter::operator()
template class llvm::GenericUniformityInfo<SSAContext>;
@@ -99,7 +118,12 @@ llvm::UniformityInfo UniformityInfoAnalysis::run(Function &F,
auto &DT = FAM.getResult<DominatorTreeAnalysis>(F);
auto &TTI = FAM.getResult<TargetIRAnalysis>(F);
auto &CI = FAM.getResult<CycleAnalysis>(F);
- return UniformityInfo{F, DT, CI, &TTI};
+ UniformityInfo UI{F, DT, CI, &TTI};
+ // Skip computation if we can assume everything is uniform.
+ if (TTI.hasBranchDivergence(&F))
+ UI.compute();
+
+ return UI;
}
AnalysisKey UniformityInfoAnalysis::Key;
@@ -125,17 +149,18 @@ UniformityInfoWrapperPass::UniformityInfoWrapperPass() : FunctionPass(ID) {
initializeUniformityInfoWrapperPassPass(*PassRegistry::getPassRegistry());
}
-INITIALIZE_PASS_BEGIN(UniformityInfoWrapperPass, "uniforminfo",
- "Uniform Info Analysis", true, true)
+INITIALIZE_PASS_BEGIN(UniformityInfoWrapperPass, "uniformity",
+ "Uniformity Analysis", true, true)
INITIALIZE_PASS_DEPENDENCY(DominatorTreeWrapperPass)
+INITIALIZE_PASS_DEPENDENCY(CycleInfoWrapperPass)
INITIALIZE_PASS_DEPENDENCY(TargetTransformInfoWrapperPass)
-INITIALIZE_PASS_END(UniformityInfoWrapperPass, "uniforminfo",
- "Uniform Info Analysis", true, true)
+INITIALIZE_PASS_END(UniformityInfoWrapperPass, "uniformity",
+ "Uniformity Analysis", true, true)
void UniformityInfoWrapperPass::getAnalysisUsage(AnalysisUsage &AU) const {
AU.setPreservesAll();
AU.addRequired<DominatorTreeWrapperPass>();
- AU.addRequired<CycleInfoWrapperPass>();
+ AU.addRequiredTransitive<CycleInfoWrapperPass>();
AU.addRequired<TargetTransformInfoWrapperPass>();
}
@@ -148,6 +173,11 @@ bool UniformityInfoWrapperPass::runOnFunction(Function &F) {
m_function = &F;
m_uniformityInfo =
UniformityInfo{F, domTree, cycleInfo, &targetTransformInfo};
+
+ // Skip computation if we can assume everything is uniform.
+ if (targetTransformInfo.hasBranchDivergence(m_function))
+ m_uniformityInfo.compute();
+
return false;
}