aboutsummaryrefslogtreecommitdiff
path: root/contrib/llvm-project/llvm/lib/Transforms/IPO/OpenMPOpt.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'contrib/llvm-project/llvm/lib/Transforms/IPO/OpenMPOpt.cpp')
-rw-r--r--contrib/llvm-project/llvm/lib/Transforms/IPO/OpenMPOpt.cpp37
1 files changed, 32 insertions, 5 deletions
diff --git a/contrib/llvm-project/llvm/lib/Transforms/IPO/OpenMPOpt.cpp b/contrib/llvm-project/llvm/lib/Transforms/IPO/OpenMPOpt.cpp
index f342c35fa283..055ee6b50296 100644
--- a/contrib/llvm-project/llvm/lib/Transforms/IPO/OpenMPOpt.cpp
+++ b/contrib/llvm-project/llvm/lib/Transforms/IPO/OpenMPOpt.cpp
@@ -1885,6 +1885,7 @@ private:
OMPRTL___kmpc_barrier_simple_generic);
ExternalizationRAII ThreadId(OMPInfoCache,
OMPRTL___kmpc_get_hardware_thread_id_in_block);
+ ExternalizationRAII WarpSize(OMPInfoCache, OMPRTL___kmpc_get_warp_size);
registerAAs(IsModulePass);
@@ -3727,12 +3728,37 @@ struct AAKernelInfoFunction : AAKernelInfo {
CheckRWInst, *this, UsedAssumedInformationInCheckRWInst))
SPMDCompatibilityTracker.indicatePessimisticFixpoint();
+ bool UsedAssumedInformationFromReachingKernels = false;
if (!IsKernelEntry) {
- updateReachingKernelEntries(A);
updateParallelLevels(A);
+ bool AllReachingKernelsKnown = true;
+ updateReachingKernelEntries(A, AllReachingKernelsKnown);
+ UsedAssumedInformationFromReachingKernels = !AllReachingKernelsKnown;
+
if (!ParallelLevels.isValidState())
SPMDCompatibilityTracker.indicatePessimisticFixpoint();
+ else if (!ReachingKernelEntries.isValidState())
+ SPMDCompatibilityTracker.indicatePessimisticFixpoint();
+ else if (!SPMDCompatibilityTracker.empty()) {
+ // Check if all reaching kernels agree on the mode as we can otherwise
+ // not guard instructions. We might not be sure about the mode so we
+ // we cannot fix the internal spmd-zation state either.
+ int SPMD = 0, Generic = 0;
+ for (auto *Kernel : ReachingKernelEntries) {
+ auto &CBAA = A.getAAFor<AAKernelInfo>(
+ *this, IRPosition::function(*Kernel), DepClassTy::OPTIONAL);
+ if (CBAA.SPMDCompatibilityTracker.isValidState() &&
+ CBAA.SPMDCompatibilityTracker.isAssumed())
+ ++SPMD;
+ else
+ ++Generic;
+ if (!CBAA.SPMDCompatibilityTracker.isAtFixpoint())
+ UsedAssumedInformationFromReachingKernels = true;
+ }
+ if (SPMD != 0 && Generic != 0)
+ SPMDCompatibilityTracker.indicatePessimisticFixpoint();
+ }
}
// Callback to check a call instruction.
@@ -3779,7 +3805,8 @@ struct AAKernelInfoFunction : AAKernelInfo {
// If we haven't used any assumed information for the SPMD state we can fix
// it.
if (!UsedAssumedInformationInCheckRWInst &&
- !UsedAssumedInformationInCheckCallInst && AllSPMDStatesWereFixed)
+ !UsedAssumedInformationInCheckCallInst &&
+ !UsedAssumedInformationFromReachingKernels && AllSPMDStatesWereFixed)
SPMDCompatibilityTracker.indicateOptimisticFixpoint();
return StateBefore == getState() ? ChangeStatus::UNCHANGED
@@ -3788,7 +3815,8 @@ struct AAKernelInfoFunction : AAKernelInfo {
private:
/// Update info regarding reaching kernels.
- void updateReachingKernelEntries(Attributor &A) {
+ void updateReachingKernelEntries(Attributor &A,
+ bool &AllReachingKernelsKnown) {
auto PredCallSite = [&](AbstractCallSite ACS) {
Function *Caller = ACS.getInstruction()->getFunction();
@@ -3808,10 +3836,9 @@ private:
return true;
};
- bool AllCallSitesKnown;
if (!A.checkForAllCallSites(PredCallSite, *this,
true /* RequireAllCallSites */,
- AllCallSitesKnown))
+ AllReachingKernelsKnown))
ReachingKernelEntries.indicatePessimisticFixpoint();
}