diff options
Diffstat (limited to 'contrib/llvm-project/llvm/lib/Transforms/IPO/OpenMPOpt.cpp')
| -rw-r--r-- | contrib/llvm-project/llvm/lib/Transforms/IPO/OpenMPOpt.cpp | 37 |
1 files changed, 32 insertions, 5 deletions
diff --git a/contrib/llvm-project/llvm/lib/Transforms/IPO/OpenMPOpt.cpp b/contrib/llvm-project/llvm/lib/Transforms/IPO/OpenMPOpt.cpp index f342c35fa283..055ee6b50296 100644 --- a/contrib/llvm-project/llvm/lib/Transforms/IPO/OpenMPOpt.cpp +++ b/contrib/llvm-project/llvm/lib/Transforms/IPO/OpenMPOpt.cpp @@ -1885,6 +1885,7 @@ private: OMPRTL___kmpc_barrier_simple_generic); ExternalizationRAII ThreadId(OMPInfoCache, OMPRTL___kmpc_get_hardware_thread_id_in_block); + ExternalizationRAII WarpSize(OMPInfoCache, OMPRTL___kmpc_get_warp_size); registerAAs(IsModulePass); @@ -3727,12 +3728,37 @@ struct AAKernelInfoFunction : AAKernelInfo { CheckRWInst, *this, UsedAssumedInformationInCheckRWInst)) SPMDCompatibilityTracker.indicatePessimisticFixpoint(); + bool UsedAssumedInformationFromReachingKernels = false; if (!IsKernelEntry) { - updateReachingKernelEntries(A); updateParallelLevels(A); + bool AllReachingKernelsKnown = true; + updateReachingKernelEntries(A, AllReachingKernelsKnown); + UsedAssumedInformationFromReachingKernels = !AllReachingKernelsKnown; + if (!ParallelLevels.isValidState()) SPMDCompatibilityTracker.indicatePessimisticFixpoint(); + else if (!ReachingKernelEntries.isValidState()) + SPMDCompatibilityTracker.indicatePessimisticFixpoint(); + else if (!SPMDCompatibilityTracker.empty()) { + // Check if all reaching kernels agree on the mode as we can otherwise + // not guard instructions. We might not be sure about the mode so we + // we cannot fix the internal spmd-zation state either. + int SPMD = 0, Generic = 0; + for (auto *Kernel : ReachingKernelEntries) { + auto &CBAA = A.getAAFor<AAKernelInfo>( + *this, IRPosition::function(*Kernel), DepClassTy::OPTIONAL); + if (CBAA.SPMDCompatibilityTracker.isValidState() && + CBAA.SPMDCompatibilityTracker.isAssumed()) + ++SPMD; + else + ++Generic; + if (!CBAA.SPMDCompatibilityTracker.isAtFixpoint()) + UsedAssumedInformationFromReachingKernels = true; + } + if (SPMD != 0 && Generic != 0) + SPMDCompatibilityTracker.indicatePessimisticFixpoint(); + } } // Callback to check a call instruction. @@ -3779,7 +3805,8 @@ struct AAKernelInfoFunction : AAKernelInfo { // If we haven't used any assumed information for the SPMD state we can fix // it. if (!UsedAssumedInformationInCheckRWInst && - !UsedAssumedInformationInCheckCallInst && AllSPMDStatesWereFixed) + !UsedAssumedInformationInCheckCallInst && + !UsedAssumedInformationFromReachingKernels && AllSPMDStatesWereFixed) SPMDCompatibilityTracker.indicateOptimisticFixpoint(); return StateBefore == getState() ? ChangeStatus::UNCHANGED @@ -3788,7 +3815,8 @@ struct AAKernelInfoFunction : AAKernelInfo { private: /// Update info regarding reaching kernels. - void updateReachingKernelEntries(Attributor &A) { + void updateReachingKernelEntries(Attributor &A, + bool &AllReachingKernelsKnown) { auto PredCallSite = [&](AbstractCallSite ACS) { Function *Caller = ACS.getInstruction()->getFunction(); @@ -3808,10 +3836,9 @@ private: return true; }; - bool AllCallSitesKnown; if (!A.checkForAllCallSites(PredCallSite, *this, true /* RequireAllCallSites */, - AllCallSitesKnown)) + AllReachingKernelsKnown)) ReachingKernelEntries.indicatePessimisticFixpoint(); } |
