Diffstat (limited to 'llvm/lib/Transforms/IPO/OpenMPOpt.cpp')
-rw-r--r-- llvm/lib/Transforms/IPO/OpenMPOpt.cpp | 79
1 file changed, 66 insertions, 13 deletions
diff --git a/llvm/lib/Transforms/IPO/OpenMPOpt.cpp b/llvm/lib/Transforms/IPO/OpenMPOpt.cpp
index bee154dab10f..eb499a1aa912 100644
--- a/llvm/lib/Transforms/IPO/OpenMPOpt.cpp
+++ b/llvm/lib/Transforms/IPO/OpenMPOpt.cpp
@@ -188,9 +188,9 @@ struct AAICVTracker;
struct OMPInformationCache : public InformationCache {
OMPInformationCache(Module &M, AnalysisGetter &AG,
BumpPtrAllocator &Allocator, SetVector<Function *> *CGSCC,
- KernelSet &Kernels)
+ KernelSet &Kernels, bool OpenMPPostLink)
: InformationCache(M, AG, Allocator, CGSCC), OMPBuilder(M),
- Kernels(Kernels) {
+ Kernels(Kernels), OpenMPPostLink(OpenMPPostLink) {
OMPBuilder.initialize();
initializeRuntimeFunctions(M);
@@ -448,6 +448,24 @@ struct OMPInformationCache : public InformationCache {
CI->setCallingConv(Fn->getCallingConv());
}
+ /// Helper function to determine if it's legal to create a call to the runtime
+ /// functions.
+ bool runtimeFnsAvailable(ArrayRef<RuntimeFunction> Fns) {
+ // We can always emit calls if we haven't yet linked in the runtime.
+ if (!OpenMPPostLink)
+ return true;
+
+ // Once the runtime has already been linked in, we cannot emit calls to any
+ // undefined functions.
+ for (RuntimeFunction Fn : Fns) {
+ RuntimeFunctionInfo &RFI = RFIs[Fn];
+
+ if (RFI.Declaration && RFI.Declaration->isDeclaration())
+ return false;
+ }
+ return true;
+ }
+
/// Helper to initialize all runtime function information for those defined
/// in OpenMPKinds.def.
void initializeRuntimeFunctions(Module &M) {
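The helper's contract, in isolation: before the OpenMP device runtime has been linked into the module, the pass is free to introduce calls to any runtime function, since the definitions arrive later; after linking, a function that is still only a declaration can never be resolved, so no new calls to it may be created. Below is a minimal standalone sketch of that logic using plain standard-library containers; the enum values and the FnState map are illustrative stand-ins for the pass's RuntimeFunctionInfo table, not LLVM's actual types.

#include <initializer_list>
#include <map>

// Illustrative subset of runtime functions; the real list comes from OpenMPKinds.def.
enum class RuntimeFunction { GetHardwareThreadIdInBlock, BarrierSimpleSPMD };
// Each function is absent from the module, present only as a declaration, or defined.
enum class FnState { NotPresent, DeclarationOnly, Defined };

struct RuntimeModel {
  bool OpenMPPostLink = false;            // has the device runtime been linked in?
  std::map<RuntimeFunction, FnState> Fns; // simplified stand-in for the RFIs table

  // Mirrors runtimeFnsAvailable(): pre-link we may still emit new calls;
  // post-link, a declaration-only function can no longer be resolved.
  bool runtimeFnsAvailable(std::initializer_list<RuntimeFunction> Needed) const {
    if (!OpenMPPostLink)
      return true;
    for (RuntimeFunction Fn : Needed) {
      auto It = Fns.find(Fn);
      if (It != Fns.end() && It->second == FnState::DeclarationOnly)
        return false;
    }
    return true;
  }
};

int main() {
  RuntimeModel M;
  M.OpenMPPostLink = true;
  M.Fns[RuntimeFunction::BarrierSimpleSPMD] = FnState::DeclarationOnly;
  // Post-link with an unresolved declaration: the transformation must bail out.
  return M.runtimeFnsAvailable({RuntimeFunction::BarrierSimpleSPMD}) ? 1 : 0;
}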
@@ -523,6 +541,9 @@ struct OMPInformationCache : public InformationCache {
/// Collection of known OpenMP runtime functions.
DenseSet<const Function *> RTLFunctions;
+
+ /// Indicates if we have already linked in the OpenMP device library.
+ bool OpenMPPostLink = false;
};
template <typename Ty, bool InsertInvalidates = true>
@@ -1412,7 +1433,10 @@ private:
Changed |= WasSplit;
return WasSplit;
};
- RFI.foreachUse(SCC, SplitMemTransfers);
+ if (OMPInfoCache.runtimeFnsAvailable(
+ {OMPRTL___tgt_target_data_begin_mapper_issue,
+ OMPRTL___tgt_target_data_begin_mapper_wait}))
+ RFI.foreachUse(SCC, SplitMemTransfers);
return Changed;
}
@@ -2656,7 +2680,9 @@ struct AAExecutionDomainFunction : public AAExecutionDomain {
bool isExecutedInAlignedRegion(Attributor &A,
const Instruction &I) const override {
- if (!isValidState() || isa<CallBase>(I))
+ assert(I.getFunction() == getAnchorScope() &&
+ "Instruction is out of scope!");
+ if (!isValidState())
return false;
const Instruction *CurI;
@@ -2667,14 +2693,18 @@ struct AAExecutionDomainFunction : public AAExecutionDomain {
auto *CB = dyn_cast<CallBase>(CurI);
if (!CB)
continue;
+ if (CB != &I && AlignedBarriers.contains(const_cast<CallBase *>(CB))) {
+ break;
+ }
const auto &It = CEDMap.find(CB);
if (It == CEDMap.end())
continue;
- if (!It->getSecond().IsReachedFromAlignedBarrierOnly)
+ if (!It->getSecond().IsReachingAlignedBarrierOnly)
return false;
+ break;
} while ((CurI = CurI->getNextNonDebugInstruction()));
- if (!CurI && !BEDMap.lookup(I.getParent()).IsReachedFromAlignedBarrierOnly)
+ if (!CurI && !BEDMap.lookup(I.getParent()).IsReachingAlignedBarrierOnly)
return false;
// Check backward until a call or the block beginning is reached.
@@ -2683,12 +2713,16 @@ struct AAExecutionDomainFunction : public AAExecutionDomain {
auto *CB = dyn_cast<CallBase>(CurI);
if (!CB)
continue;
+ if (CB != &I && AlignedBarriers.contains(const_cast<CallBase *>(CB))) {
+ break;
+ }
const auto &It = CEDMap.find(CB);
if (It == CEDMap.end())
continue;
if (!AA::isNoSyncInst(A, *CB, *this)) {
- if (It->getSecond().IsReachedFromAlignedBarrierOnly)
+ if (It->getSecond().IsReachedFromAlignedBarrierOnly) {
break;
+ }
return false;
}
@@ -2984,7 +3018,8 @@ ChangeStatus AAExecutionDomainFunction::updateImpl(Attributor &A) {
if (EDAA.getState().isValidState()) {
const auto &CalleeED = EDAA.getFunctionExecutionDomain();
ED.IsReachedFromAlignedBarrierOnly =
- CalleeED.IsReachedFromAlignedBarrierOnly;
+ CallED.IsReachedFromAlignedBarrierOnly =
+ CalleeED.IsReachedFromAlignedBarrierOnly;
AlignedBarrierLastInBlock = ED.IsReachedFromAlignedBarrierOnly;
if (IsNoSync || !CalleeED.IsReachedFromAlignedBarrierOnly)
ED.EncounteredNonLocalSideEffect |=
@@ -2999,8 +3034,9 @@ ChangeStatus AAExecutionDomainFunction::updateImpl(Attributor &A) {
continue;
}
}
- ED.IsReachedFromAlignedBarrierOnly =
- IsNoSync && ED.IsReachedFromAlignedBarrierOnly;
+ if (!IsNoSync)
+ ED.IsReachedFromAlignedBarrierOnly =
+ CallED.IsReachedFromAlignedBarrierOnly = false;
AlignedBarrierLastInBlock &= ED.IsReachedFromAlignedBarrierOnly;
ED.EncounteredNonLocalSideEffect |= !CB->doesNotAccessMemory();
if (!IsNoSync)
@@ -3914,6 +3950,12 @@ struct AAKernelInfoFunction : AAKernelInfo {
bool changeToSPMDMode(Attributor &A, ChangeStatus &Changed) {
auto &OMPInfoCache = static_cast<OMPInformationCache &>(A.getInfoCache());
+ // We cannot change to SPMD mode if the runtime functions aren't available.
+ if (!OMPInfoCache.runtimeFnsAvailable(
+ {OMPRTL___kmpc_get_hardware_thread_id_in_block,
+ OMPRTL___kmpc_barrier_simple_spmd}))
+ return false;
+
if (!SPMDCompatibilityTracker.isAssumed()) {
for (Instruction *NonCompatibleI : SPMDCompatibilityTracker) {
if (!NonCompatibleI)
@@ -4021,6 +4063,13 @@ struct AAKernelInfoFunction : AAKernelInfo {
if (!ReachedKnownParallelRegions.isValidState())
return ChangeStatus::UNCHANGED;
+ auto &OMPInfoCache = static_cast<OMPInformationCache &>(A.getInfoCache());
+ if (!OMPInfoCache.runtimeFnsAvailable(
+ {OMPRTL___kmpc_get_hardware_num_threads_in_block,
+ OMPRTL___kmpc_get_warp_size, OMPRTL___kmpc_barrier_simple_generic,
+ OMPRTL___kmpc_kernel_parallel, OMPRTL___kmpc_kernel_end_parallel}))
+ return ChangeStatus::UNCHANGED;
+
const int InitModeArgNo = 1;
const int InitUseStateMachineArgNo = 2;
@@ -4167,7 +4216,6 @@ struct AAKernelInfoFunction : AAKernelInfo {
BranchInst::Create(IsWorkerCheckBB, UserCodeEntryBB, IsWorker, InitBB);
Module &M = *Kernel->getParent();
- auto &OMPInfoCache = static_cast<OMPInformationCache &>(A.getInfoCache());
FunctionCallee BlockHwSizeFn =
OMPInfoCache.OMPBuilder.getOrCreateRuntimeFunction(
M, OMPRTL___kmpc_get_hardware_num_threads_in_block);
@@ -5343,7 +5391,10 @@ PreservedAnalyses OpenMPOptPass::run(Module &M, ModuleAnalysisManager &AM) {
BumpPtrAllocator Allocator;
CallGraphUpdater CGUpdater;
- OMPInformationCache InfoCache(M, AG, Allocator, /*CGSCC*/ nullptr, Kernels);
+ bool PostLink = LTOPhase == ThinOrFullLTOPhase::FullLTOPostLink ||
+ LTOPhase == ThinOrFullLTOPhase::ThinLTOPreLink;
+ OMPInformationCache InfoCache(M, AG, Allocator, /*CGSCC*/ nullptr, Kernels,
+ PostLink);
unsigned MaxFixpointIterations =
(isOpenMPDevice(M)) ? SetFixpointIterations : 32;
@@ -5417,9 +5468,11 @@ PreservedAnalyses OpenMPOptCGSCCPass::run(LazyCallGraph::SCC &C,
CallGraphUpdater CGUpdater;
CGUpdater.initialize(CG, C, AM, UR);
+ bool PostLink = LTOPhase == ThinOrFullLTOPhase::FullLTOPostLink ||
+ LTOPhase == ThinOrFullLTOPhase::ThinLTOPreLink;
SetVector<Function *> Functions(SCC.begin(), SCC.end());
OMPInformationCache InfoCache(*(Functions.back()->getParent()), AG, Allocator,
- /*CGSCC*/ &Functions, Kernels);
+ /*CGSCC*/ &Functions, Kernels, PostLink);
unsigned MaxFixpointIterations =
(isOpenMPDevice(M)) ? SetFixpointIterations : 32;
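For context, both pass entry points derive the new PostLink flag the same way before constructing the information cache. A compressed, self-contained sketch of that derivation follows; the enum mirrors the names of ThinOrFullLTOPhase, but the surrounding scaffolding is illustrative rather than the pass manager's real plumbing.

// Reduced mirror of llvm::ThinOrFullLTOPhase for illustration only.
enum class LTOPhase {
  None,
  ThinLTOPreLink,
  ThinLTOPostLink,
  FullLTOPreLink,
  FullLTOPostLink
};

// Matches the PostLink computation in OpenMPOptPass::run and
// OpenMPOptCGSCCPass::run: the OpenMP device library is treated as already
// linked in at the full-LTO post-link or thin-LTO pre-link phase, and the
// resulting flag is handed to OMPInformationCache.
bool isOpenMPPostLink(LTOPhase Phase) {
  return Phase == LTOPhase::FullLTOPostLink ||
         Phase == LTOPhase::ThinLTOPreLink;
}

int main() { return isOpenMPPostLink(LTOPhase::FullLTOPostLink) ? 0 : 1; }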