Diffstat (limited to 'llvm/lib/Transforms/IPO/OpenMPOpt.cpp')
-rw-r--r-- | llvm/lib/Transforms/IPO/OpenMPOpt.cpp | 924
1 file changed, 625 insertions, 299 deletions
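Orientation note before the diff: this patch replaces the scalar mode and use-generic-state-machine arguments of __kmpc_target_init with a pointer to a single constant per-kernel environment global, which OpenMPOpt then reads and rewrites as a constant. The layout the pass indexes into is reproduced in comments at the top of the patch; spelled out as plain C++ it reads roughly as follows. This is a sketch mirroring those comments and the KERNEL_ENVIRONMENT_*_IDX constants; the authoritative definitions live in the device runtime headers, and IdentTy is only forward-declared here for self-containment.

#include <cstdint>
#include "llvm/Frontend/OpenMP/OMPDeviceConstants.h" // llvm::omp::OMPTgtExecModeFlags

struct IdentTy; // runtime source-location descriptor, opaque to this sketch

struct ConfigurationEnvironmentTy {          // KernelEnvironmentTy field 0
  uint8_t UseGenericStateMachine;            // Idx 0
  uint8_t MayUseNestedParallelism;           // Idx 1
  llvm::omp::OMPTgtExecModeFlags ExecMode;   // Idx 2
  int32_t MinThreads;                        // Idx 3
  int32_t MaxThreads;                        // Idx 4
  int32_t MinTeams;                          // Idx 5
  int32_t MaxTeams;                          // Idx 6
};

struct DynamicEnvironmentTy {
  uint16_t DebugIndentionLevel;
};

struct KernelEnvironmentTy {
  ConfigurationEnvironmentTy Configuration;  // Idx 0
  IdentTy *Ident;                            // Idx 1
  DynamicEnvironmentTy *DynamicEnv;
};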
diff --git a/llvm/lib/Transforms/IPO/OpenMPOpt.cpp b/llvm/lib/Transforms/IPO/OpenMPOpt.cpp index 588f3901e3cb..b2665161c090 100644 --- a/llvm/lib/Transforms/IPO/OpenMPOpt.cpp +++ b/llvm/lib/Transforms/IPO/OpenMPOpt.cpp @@ -33,6 +33,7 @@ #include "llvm/Analysis/OptimizationRemarkEmitter.h" #include "llvm/Analysis/ValueTracking.h" #include "llvm/Frontend/OpenMP/OMPConstants.h" +#include "llvm/Frontend/OpenMP/OMPDeviceConstants.h" #include "llvm/Frontend/OpenMP/OMPIRBuilder.h" #include "llvm/IR/Assumptions.h" #include "llvm/IR/BasicBlock.h" @@ -42,6 +43,7 @@ #include "llvm/IR/Function.h" #include "llvm/IR/GlobalValue.h" #include "llvm/IR/GlobalVariable.h" +#include "llvm/IR/InstrTypes.h" #include "llvm/IR/Instruction.h" #include "llvm/IR/Instructions.h" #include "llvm/IR/IntrinsicInst.h" @@ -156,6 +158,8 @@ STATISTIC(NumOpenMPRuntimeFunctionUsesIdentified, "Number of OpenMP runtime function uses identified"); STATISTIC(NumOpenMPTargetRegionKernels, "Number of OpenMP target region entry points (=kernels) identified"); +STATISTIC(NumNonOpenMPTargetRegionKernels, + "Number of non-OpenMP target region kernels identified"); STATISTIC(NumOpenMPTargetRegionKernelsSPMD, "Number of OpenMP target region entry points (=kernels) executed in " "SPMD-mode instead of generic-mode"); @@ -181,6 +185,92 @@ STATISTIC(NumBarriersEliminated, "Number of redundant barriers eliminated"); static constexpr auto TAG = "[" DEBUG_TYPE "]"; #endif +namespace KernelInfo { + +// struct ConfigurationEnvironmentTy { +// uint8_t UseGenericStateMachine; +// uint8_t MayUseNestedParallelism; +// llvm::omp::OMPTgtExecModeFlags ExecMode; +// int32_t MinThreads; +// int32_t MaxThreads; +// int32_t MinTeams; +// int32_t MaxTeams; +// }; + +// struct DynamicEnvironmentTy { +// uint16_t DebugIndentionLevel; +// }; + +// struct KernelEnvironmentTy { +// ConfigurationEnvironmentTy Configuration; +// IdentTy *Ident; +// DynamicEnvironmentTy *DynamicEnv; +// }; + +#define KERNEL_ENVIRONMENT_IDX(MEMBER, IDX) \ + constexpr const unsigned MEMBER##Idx = IDX; + +KERNEL_ENVIRONMENT_IDX(Configuration, 0) +KERNEL_ENVIRONMENT_IDX(Ident, 1) + +#undef KERNEL_ENVIRONMENT_IDX + +#define KERNEL_ENVIRONMENT_CONFIGURATION_IDX(MEMBER, IDX) \ + constexpr const unsigned MEMBER##Idx = IDX; + +KERNEL_ENVIRONMENT_CONFIGURATION_IDX(UseGenericStateMachine, 0) +KERNEL_ENVIRONMENT_CONFIGURATION_IDX(MayUseNestedParallelism, 1) +KERNEL_ENVIRONMENT_CONFIGURATION_IDX(ExecMode, 2) +KERNEL_ENVIRONMENT_CONFIGURATION_IDX(MinThreads, 3) +KERNEL_ENVIRONMENT_CONFIGURATION_IDX(MaxThreads, 4) +KERNEL_ENVIRONMENT_CONFIGURATION_IDX(MinTeams, 5) +KERNEL_ENVIRONMENT_CONFIGURATION_IDX(MaxTeams, 6) + +#undef KERNEL_ENVIRONMENT_CONFIGURATION_IDX + +#define KERNEL_ENVIRONMENT_GETTER(MEMBER, RETURNTYPE) \ + RETURNTYPE *get##MEMBER##FromKernelEnvironment(ConstantStruct *KernelEnvC) { \ + return cast<RETURNTYPE>(KernelEnvC->getAggregateElement(MEMBER##Idx)); \ + } + +KERNEL_ENVIRONMENT_GETTER(Ident, Constant) +KERNEL_ENVIRONMENT_GETTER(Configuration, ConstantStruct) + +#undef KERNEL_ENVIRONMENT_GETTER + +#define KERNEL_ENVIRONMENT_CONFIGURATION_GETTER(MEMBER) \ + ConstantInt *get##MEMBER##FromKernelEnvironment( \ + ConstantStruct *KernelEnvC) { \ + ConstantStruct *ConfigC = \ + getConfigurationFromKernelEnvironment(KernelEnvC); \ + return dyn_cast<ConstantInt>(ConfigC->getAggregateElement(MEMBER##Idx)); \ + } + +KERNEL_ENVIRONMENT_CONFIGURATION_GETTER(UseGenericStateMachine) +KERNEL_ENVIRONMENT_CONFIGURATION_GETTER(MayUseNestedParallelism) 
+KERNEL_ENVIRONMENT_CONFIGURATION_GETTER(ExecMode) +KERNEL_ENVIRONMENT_CONFIGURATION_GETTER(MinThreads) +KERNEL_ENVIRONMENT_CONFIGURATION_GETTER(MaxThreads) +KERNEL_ENVIRONMENT_CONFIGURATION_GETTER(MinTeams) +KERNEL_ENVIRONMENT_CONFIGURATION_GETTER(MaxTeams) + +#undef KERNEL_ENVIRONMENT_CONFIGURATION_GETTER + +GlobalVariable * +getKernelEnvironementGVFromKernelInitCB(CallBase *KernelInitCB) { + constexpr const int InitKernelEnvironmentArgNo = 0; + return cast<GlobalVariable>( + KernelInitCB->getArgOperand(InitKernelEnvironmentArgNo) + ->stripPointerCasts()); +} + +ConstantStruct *getKernelEnvironementFromKernelInitCB(CallBase *KernelInitCB) { + GlobalVariable *KernelEnvGV = + getKernelEnvironementGVFromKernelInitCB(KernelInitCB); + return cast<ConstantStruct>(KernelEnvGV->getInitializer()); +} +} // namespace KernelInfo + namespace { struct AAHeapToShared; @@ -196,6 +286,7 @@ struct OMPInformationCache : public InformationCache { : InformationCache(M, AG, Allocator, CGSCC), OMPBuilder(M), OpenMPPostLink(OpenMPPostLink) { + OMPBuilder.Config.IsTargetDevice = isOpenMPDevice(OMPBuilder.M); OMPBuilder.initialize(); initializeRuntimeFunctions(M); initializeInternalControlVars(); @@ -531,7 +622,7 @@ struct OMPInformationCache : public InformationCache { for (Function &F : M) { for (StringRef Prefix : {"__kmpc", "_ZN4ompx", "omp_"}) if (F.hasFnAttribute(Attribute::NoInline) && - F.getName().startswith(Prefix) && + F.getName().starts_with(Prefix) && !F.hasFnAttribute(Attribute::OptimizeNone)) F.removeFnAttr(Attribute::NoInline); } @@ -595,7 +686,7 @@ struct KernelInfoState : AbstractState { /// The parallel regions (identified by the outlined parallel functions) that /// can be reached from the associated function. - BooleanStateWithPtrSetVector<Function, /* InsertInvalidates */ false> + BooleanStateWithPtrSetVector<CallBase, /* InsertInvalidates */ false> ReachedKnownParallelRegions; /// State to track what parallel region we might reach. @@ -610,6 +701,10 @@ struct KernelInfoState : AbstractState { /// one we abort as the kernel is malformed. CallBase *KernelInitCB = nullptr; + /// The constant kernel environement as taken from and passed to + /// __kmpc_target_init. + ConstantStruct *KernelEnvC = nullptr; + /// The __kmpc_target_deinit call in this kernel, if any. If we find more than /// one we abort as the kernel is malformed. 
CallBase *KernelDeinitCB = nullptr; @@ -651,6 +746,7 @@ struct KernelInfoState : AbstractState { SPMDCompatibilityTracker.indicatePessimisticFixpoint(); ReachedKnownParallelRegions.indicatePessimisticFixpoint(); ReachedUnknownParallelRegions.indicatePessimisticFixpoint(); + NestedParallelism = true; return ChangeStatus::CHANGED; } @@ -680,6 +776,8 @@ struct KernelInfoState : AbstractState { return false; if (ParallelLevels != RHS.ParallelLevels) return false; + if (NestedParallelism != RHS.NestedParallelism) + return false; return true; } @@ -714,6 +812,12 @@ struct KernelInfoState : AbstractState { "assumptions."); KernelDeinitCB = KIS.KernelDeinitCB; } + if (KIS.KernelEnvC) { + if (KernelEnvC && KernelEnvC != KIS.KernelEnvC) + llvm_unreachable("Kernel that calls another kernel violates OpenMP-Opt " + "assumptions."); + KernelEnvC = KIS.KernelEnvC; + } SPMDCompatibilityTracker ^= KIS.SPMDCompatibilityTracker; ReachedKnownParallelRegions ^= KIS.ReachedKnownParallelRegions; ReachedUnknownParallelRegions ^= KIS.ReachedUnknownParallelRegions; @@ -875,6 +979,9 @@ struct OpenMPOpt { } } + if (OMPInfoCache.OpenMPPostLink) + Changed |= removeRuntimeSymbols(); + return Changed; } @@ -903,7 +1010,7 @@ struct OpenMPOpt { /// Print OpenMP GPU kernels for testing. void printKernels() const { for (Function *F : SCC) { - if (!omp::isKernel(*F)) + if (!omp::isOpenMPKernel(*F)) continue; auto Remark = [&](OptimizationRemarkAnalysis ORA) { @@ -1404,6 +1511,37 @@ private: return Changed; } + /// Tries to remove known runtime symbols that are optional from the module. + bool removeRuntimeSymbols() { + // The RPC client symbol is defined in `libc` and indicates that something + // required an RPC server. If its users were all optimized out then we can + // safely remove it. + // TODO: This should be somewhere more common in the future. + if (GlobalVariable *GV = M.getNamedGlobal("__llvm_libc_rpc_client")) { + if (!GV->getType()->isPointerTy()) + return false; + + Constant *C = GV->getInitializer(); + if (!C) + return false; + + // Check to see if the only user of the RPC client is the external handle. + GlobalVariable *Client = dyn_cast<GlobalVariable>(C->stripPointerCasts()); + if (!Client || Client->getNumUses() > 1 || + Client->user_back() != GV->getInitializer()) + return false; + + Client->replaceAllUsesWith(PoisonValue::get(Client->getType())); + Client->eraseFromParent(); + + GV->replaceAllUsesWith(PoisonValue::get(GV->getType())); + GV->eraseFromParent(); + + return true; + } + return false; + } + /// Tries to hide the latency of runtime calls that involve host to /// device memory transfers by splitting them into their "issue" and "wait" /// versions. The "issue" is moved upwards as much as possible. The "wait" is @@ -1858,7 +1996,7 @@ private: Function *F = I->getParent()->getParent(); auto &ORE = OREGetter(F); - if (RemarkName.startswith("OMP")) + if (RemarkName.starts_with("OMP")) ORE.emit([&]() { return RemarkCB(RemarkKind(DEBUG_TYPE, RemarkName, I)) << " [" << RemarkName << "]"; @@ -1874,7 +2012,7 @@ private: RemarkCallBack &&RemarkCB) const { auto &ORE = OREGetter(F); - if (RemarkName.startswith("OMP")) + if (RemarkName.starts_with("OMP")) ORE.emit([&]() { return RemarkCB(RemarkKind(DEBUG_TYPE, RemarkName, F)) << " [" << RemarkName << "]"; @@ -1944,7 +2082,7 @@ Kernel OpenMPOpt::getUniqueKernelFor(Function &F) { // TODO: We should use an AA to create an (optimistic and callback // call-aware) call graph. 
For now we stick to simple patterns that // are less powerful, basically the worst fixpoint. - if (isKernel(F)) { + if (isOpenMPKernel(F)) { CachedKernel = Kernel(&F); return *CachedKernel; } @@ -2535,6 +2673,17 @@ struct AAICVTrackerCallSiteReturned : AAICVTracker { } }; +/// Determines if \p BB exits the function unconditionally itself or reaches a +/// block that does through only unique successors. +static bool hasFunctionEndAsUniqueSuccessor(const BasicBlock *BB) { + if (succ_empty(BB)) + return true; + const BasicBlock *const Successor = BB->getUniqueSuccessor(); + if (!Successor) + return false; + return hasFunctionEndAsUniqueSuccessor(Successor); +} + struct AAExecutionDomainFunction : public AAExecutionDomain { AAExecutionDomainFunction(const IRPosition &IRP, Attributor &A) : AAExecutionDomain(IRP, A) {} @@ -2587,18 +2736,22 @@ struct AAExecutionDomainFunction : public AAExecutionDomain { if (!ED.IsReachedFromAlignedBarrierOnly || ED.EncounteredNonLocalSideEffect) return; + if (!ED.EncounteredAssumes.empty() && !A.isModulePass()) + return; - // We can remove this barrier, if it is one, or all aligned barriers - // reaching the kernel end. In the latter case we can transitively work - // our way back until we find a barrier that guards a side-effect if we - // are dealing with the kernel end here. + // We can remove this barrier, if it is one, or aligned barriers reaching + // the kernel end (if CB is nullptr). Aligned barriers reaching the kernel + // end should only be removed if the kernel end is their unique successor; + // otherwise, they may have side-effects that aren't accounted for in the + // kernel end in their other successors. If those barriers have other + // barriers reaching them, those can be transitively removed as well as + // long as the kernel end is also their unique successor. if (CB) { DeletedBarriers.insert(CB); A.deleteAfterManifest(*CB); ++NumBarriersEliminated; Changed = ChangeStatus::CHANGED; } else if (!ED.AlignedBarriers.empty()) { - NumBarriersEliminated += ED.AlignedBarriers.size(); Changed = ChangeStatus::CHANGED; SmallVector<CallBase *> Worklist(ED.AlignedBarriers.begin(), ED.AlignedBarriers.end()); @@ -2609,7 +2762,10 @@ struct AAExecutionDomainFunction : public AAExecutionDomain { continue; if (LastCB->getFunction() != getAnchorScope()) continue; + if (!hasFunctionEndAsUniqueSuccessor(LastCB->getParent())) + continue; if (!DeletedBarriers.count(LastCB)) { + ++NumBarriersEliminated; A.deleteAfterManifest(*LastCB); continue; } @@ -2633,7 +2789,7 @@ struct AAExecutionDomainFunction : public AAExecutionDomain { HandleAlignedBarrier(CB); // Handle the "kernel end barrier" for kernels too. - if (omp::isKernel(*getAnchorScope())) + if (omp::isOpenMPKernel(*getAnchorScope())) HandleAlignedBarrier(nullptr); return Changed; @@ -2779,9 +2935,11 @@ struct AAExecutionDomainFunction : public AAExecutionDomain { CB = CB ? 
OpenMPOpt::getCallIfRegularCall(*CB, &RFI) : nullptr; if (!CB) return false; - const int InitModeArgNo = 1; - auto *ModeCI = dyn_cast<ConstantInt>(CB->getOperand(InitModeArgNo)); - return ModeCI && (ModeCI->getSExtValue() & OMP_TGT_EXEC_MODE_GENERIC); + ConstantStruct *KernelEnvC = + KernelInfo::getKernelEnvironementFromKernelInitCB(CB); + ConstantInt *ExecModeC = + KernelInfo::getExecModeFromKernelEnvironment(KernelEnvC); + return ExecModeC->getSExtValue() & OMP_TGT_EXEC_MODE_GENERIC; } if (C->isZero()) { @@ -2884,11 +3042,11 @@ bool AAExecutionDomainFunction::handleCallees(Attributor &A, } else { // We could not find all predecessors, so this is either a kernel or a // function with external linkage (or with some other weird uses). - if (omp::isKernel(*getAnchorScope())) { + if (omp::isOpenMPKernel(*getAnchorScope())) { EntryBBED.IsExecutedByInitialThreadOnly = false; EntryBBED.IsReachedFromAlignedBarrierOnly = true; EntryBBED.EncounteredNonLocalSideEffect = false; - ExitED.IsReachingAlignedBarrierOnly = true; + ExitED.IsReachingAlignedBarrierOnly = false; } else { EntryBBED.IsExecutedByInitialThreadOnly = false; EntryBBED.IsReachedFromAlignedBarrierOnly = false; @@ -2938,7 +3096,7 @@ ChangeStatus AAExecutionDomainFunction::updateImpl(Attributor &A) { Function *F = getAnchorScope(); BasicBlock &EntryBB = F->getEntryBlock(); - bool IsKernel = omp::isKernel(*F); + bool IsKernel = omp::isOpenMPKernel(*F); SmallVector<Instruction *> SyncInstWorklist; for (auto &RIt : *RPOT) { @@ -3063,7 +3221,7 @@ ChangeStatus AAExecutionDomainFunction::updateImpl(Attributor &A) { if (EDAA && EDAA->getState().isValidState()) { const auto &CalleeED = EDAA->getFunctionExecutionDomain(); ED.IsReachedFromAlignedBarrierOnly = - CalleeED.IsReachedFromAlignedBarrierOnly; + CalleeED.IsReachedFromAlignedBarrierOnly; AlignedBarrierLastInBlock = ED.IsReachedFromAlignedBarrierOnly; if (IsNoSync || !CalleeED.IsReachedFromAlignedBarrierOnly) ED.EncounteredNonLocalSideEffect |= @@ -3442,6 +3600,10 @@ struct AAKernelInfo : public StateWrapper<KernelInfoState, AbstractAttribute> { using Base = StateWrapper<KernelInfoState, AbstractAttribute>; AAKernelInfo(const IRPosition &IRP, Attributor &A) : Base(IRP) {} + /// The callee value is tracked beyond a simple stripPointerCasts, so we allow + /// unknown callees. + static bool requiresCalleeForCallBase() { return false; } + /// Statistics are tracked as part of manifest for now. void trackStatistics() const override {} @@ -3468,7 +3630,8 @@ struct AAKernelInfo : public StateWrapper<KernelInfoState, AbstractAttribute> { ", #ParLevels: " + (ParallelLevels.isValidState() ? std::to_string(ParallelLevels.size()) - : "<invalid>"); + : "<invalid>") + + ", NestedPar: " + (NestedParallelism ? "yes" : "no"); } /// Create an abstract attribute biew for the position \p IRP. 
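The exec-mode hunk above shows the access pattern every consumer now follows: recover the constant environment from the __kmpc_target_init call site, then index into it with the KernelInfo getters instead of inspecting scalar call arguments. Spelled out as a free-standing helper (a sketch; the helper name isGenericModeKernel is illustrative and not part of the patch, and the KernelInfo identifiers are copied verbatim, including their spelling):

// Returns true if the kernel behind KernelInitCB runs in generic mode.
// Mirrors the rewritten check in AAExecutionDomainFunction above.
static bool isGenericModeKernel(CallBase *KernelInitCB) {
  ConstantStruct *KernelEnvC =
      KernelInfo::getKernelEnvironementFromKernelInitCB(KernelInitCB);
  ConstantInt *ExecModeC =
      KernelInfo::getExecModeFromKernelEnvironment(KernelEnvC);
  return ExecModeC->getSExtValue() & OMP_TGT_EXEC_MODE_GENERIC;
}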
@@ -3500,6 +3663,33 @@ struct AAKernelInfoFunction : AAKernelInfo { return GuardedInstructions; } + void setConfigurationOfKernelEnvironment(ConstantStruct *ConfigC) { + Constant *NewKernelEnvC = ConstantFoldInsertValueInstruction( + KernelEnvC, ConfigC, {KernelInfo::ConfigurationIdx}); + assert(NewKernelEnvC && "Failed to create new kernel environment"); + KernelEnvC = cast<ConstantStruct>(NewKernelEnvC); + } + +#define KERNEL_ENVIRONMENT_CONFIGURATION_SETTER(MEMBER) \ + void set##MEMBER##OfKernelEnvironment(ConstantInt *NewVal) { \ + ConstantStruct *ConfigC = \ + KernelInfo::getConfigurationFromKernelEnvironment(KernelEnvC); \ + Constant *NewConfigC = ConstantFoldInsertValueInstruction( \ + ConfigC, NewVal, {KernelInfo::MEMBER##Idx}); \ + assert(NewConfigC && "Failed to create new configuration environment"); \ + setConfigurationOfKernelEnvironment(cast<ConstantStruct>(NewConfigC)); \ + } + + KERNEL_ENVIRONMENT_CONFIGURATION_SETTER(UseGenericStateMachine) + KERNEL_ENVIRONMENT_CONFIGURATION_SETTER(MayUseNestedParallelism) + KERNEL_ENVIRONMENT_CONFIGURATION_SETTER(ExecMode) + KERNEL_ENVIRONMENT_CONFIGURATION_SETTER(MinThreads) + KERNEL_ENVIRONMENT_CONFIGURATION_SETTER(MaxThreads) + KERNEL_ENVIRONMENT_CONFIGURATION_SETTER(MinTeams) + KERNEL_ENVIRONMENT_CONFIGURATION_SETTER(MaxTeams) + +#undef KERNEL_ENVIRONMENT_CONFIGURATION_SETTER + /// See AbstractAttribute::initialize(...). void initialize(Attributor &A) override { // This is a high-level transform that might change the constant arguments @@ -3548,61 +3738,73 @@ struct AAKernelInfoFunction : AAKernelInfo { ReachingKernelEntries.insert(Fn); IsKernelEntry = true; - // For kernels we might need to initialize/finalize the IsSPMD state and - // we need to register a simplification callback so that the Attributor - // knows the constant arguments to __kmpc_target_init and - // __kmpc_target_deinit might actually change. - - Attributor::SimplifictionCallbackTy StateMachineSimplifyCB = - [&](const IRPosition &IRP, const AbstractAttribute *AA, - bool &UsedAssumedInformation) -> std::optional<Value *> { - return nullptr; - }; + KernelEnvC = + KernelInfo::getKernelEnvironementFromKernelInitCB(KernelInitCB); + GlobalVariable *KernelEnvGV = + KernelInfo::getKernelEnvironementGVFromKernelInitCB(KernelInitCB); - Attributor::SimplifictionCallbackTy ModeSimplifyCB = - [&](const IRPosition &IRP, const AbstractAttribute *AA, - bool &UsedAssumedInformation) -> std::optional<Value *> { - // IRP represents the "SPMDCompatibilityTracker" argument of an - // __kmpc_target_init or - // __kmpc_target_deinit call. We will answer this one with the internal - // state. - if (!SPMDCompatibilityTracker.isValidState()) - return nullptr; - if (!SPMDCompatibilityTracker.isAtFixpoint()) { - if (AA) - A.recordDependence(*this, *AA, DepClassTy::OPTIONAL); + Attributor::GlobalVariableSimplifictionCallbackTy + KernelConfigurationSimplifyCB = + [&](const GlobalVariable &GV, const AbstractAttribute *AA, + bool &UsedAssumedInformation) -> std::optional<Constant *> { + if (!isAtFixpoint()) { + if (!AA) + return nullptr; UsedAssumedInformation = true; - } else { - UsedAssumedInformation = false; + A.recordDependence(*this, *AA, DepClassTy::OPTIONAL); } - auto *Val = ConstantInt::getSigned( - IntegerType::getInt8Ty(IRP.getAnchorValue().getContext()), - SPMDCompatibilityTracker.isAssumed() ? 
OMP_TGT_EXEC_MODE_SPMD - : OMP_TGT_EXEC_MODE_GENERIC); - return Val; + return KernelEnvC; }; - constexpr const int InitModeArgNo = 1; - constexpr const int DeinitModeArgNo = 1; - constexpr const int InitUseStateMachineArgNo = 2; - A.registerSimplificationCallback( - IRPosition::callsite_argument(*KernelInitCB, InitUseStateMachineArgNo), - StateMachineSimplifyCB); - A.registerSimplificationCallback( - IRPosition::callsite_argument(*KernelInitCB, InitModeArgNo), - ModeSimplifyCB); - A.registerSimplificationCallback( - IRPosition::callsite_argument(*KernelDeinitCB, DeinitModeArgNo), - ModeSimplifyCB); + A.registerGlobalVariableSimplificationCallback( + *KernelEnvGV, KernelConfigurationSimplifyCB); // Check if we know we are in SPMD-mode already. - ConstantInt *ModeArg = - dyn_cast<ConstantInt>(KernelInitCB->getArgOperand(InitModeArgNo)); - if (ModeArg && (ModeArg->getSExtValue() & OMP_TGT_EXEC_MODE_SPMD)) + ConstantInt *ExecModeC = + KernelInfo::getExecModeFromKernelEnvironment(KernelEnvC); + ConstantInt *AssumedExecModeC = ConstantInt::get( + ExecModeC->getType(), + ExecModeC->getSExtValue() | OMP_TGT_EXEC_MODE_GENERIC_SPMD); + if (ExecModeC->getSExtValue() & OMP_TGT_EXEC_MODE_SPMD) SPMDCompatibilityTracker.indicateOptimisticFixpoint(); - // This is a generic region but SPMDization is disabled so stop tracking. else if (DisableOpenMPOptSPMDization) + // This is a generic region but SPMDization is disabled so stop + // tracking. SPMDCompatibilityTracker.indicatePessimisticFixpoint(); + else + setExecModeOfKernelEnvironment(AssumedExecModeC); + + const Triple T(Fn->getParent()->getTargetTriple()); + auto *Int32Ty = Type::getInt32Ty(Fn->getContext()); + auto [MinThreads, MaxThreads] = + OpenMPIRBuilder::readThreadBoundsForKernel(T, *Fn); + if (MinThreads) + setMinThreadsOfKernelEnvironment(ConstantInt::get(Int32Ty, MinThreads)); + if (MaxThreads) + setMaxThreadsOfKernelEnvironment(ConstantInt::get(Int32Ty, MaxThreads)); + auto [MinTeams, MaxTeams] = + OpenMPIRBuilder::readTeamBoundsForKernel(T, *Fn); + if (MinTeams) + setMinTeamsOfKernelEnvironment(ConstantInt::get(Int32Ty, MinTeams)); + if (MaxTeams) + setMaxTeamsOfKernelEnvironment(ConstantInt::get(Int32Ty, MaxTeams)); + + ConstantInt *MayUseNestedParallelismC = + KernelInfo::getMayUseNestedParallelismFromKernelEnvironment(KernelEnvC); + ConstantInt *AssumedMayUseNestedParallelismC = ConstantInt::get( + MayUseNestedParallelismC->getType(), NestedParallelism); + setMayUseNestedParallelismOfKernelEnvironment( + AssumedMayUseNestedParallelismC); + + if (!DisableOpenMPOptStateMachineRewrite) { + ConstantInt *UseGenericStateMachineC = + KernelInfo::getUseGenericStateMachineFromKernelEnvironment( + KernelEnvC); + ConstantInt *AssumedUseGenericStateMachineC = + ConstantInt::get(UseGenericStateMachineC->getType(), false); + setUseGenericStateMachineOfKernelEnvironment( + AssumedUseGenericStateMachineC); + } // Register virtual uses of functions we might need to preserve. auto RegisterVirtualUse = [&](RuntimeFunction RFKind, @@ -3703,22 +3905,32 @@ struct AAKernelInfoFunction : AAKernelInfo { if (!KernelInitCB || !KernelDeinitCB) return ChangeStatus::UNCHANGED; - /// Insert nested Parallelism global variable - Function *Kernel = getAnchorScope(); - Module &M = *Kernel->getParent(); - Type *Int8Ty = Type::getInt8Ty(M.getContext()); - auto *GV = new GlobalVariable( - M, Int8Ty, /* isConstant */ true, GlobalValue::WeakAnyLinkage, - ConstantInt::get(Int8Ty, NestedParallelism ? 
1 : 0), - Kernel->getName() + "_nested_parallelism"); - GV->setVisibility(GlobalValue::HiddenVisibility); - - // If we can we change the execution mode to SPMD-mode otherwise we build a - // custom state machine. ChangeStatus Changed = ChangeStatus::UNCHANGED; + + bool HasBuiltStateMachine = true; if (!changeToSPMDMode(A, Changed)) { if (!KernelInitCB->getCalledFunction()->isDeclaration()) - return buildCustomStateMachine(A); + HasBuiltStateMachine = buildCustomStateMachine(A, Changed); + else + HasBuiltStateMachine = false; + } + + // We need to reset KernelEnvC if specific rewriting is not done. + ConstantStruct *ExistingKernelEnvC = + KernelInfo::getKernelEnvironementFromKernelInitCB(KernelInitCB); + ConstantInt *OldUseGenericStateMachineVal = + KernelInfo::getUseGenericStateMachineFromKernelEnvironment( + ExistingKernelEnvC); + if (!HasBuiltStateMachine) + setUseGenericStateMachineOfKernelEnvironment( + OldUseGenericStateMachineVal); + + // At last, update the KernelEnvc + GlobalVariable *KernelEnvGV = + KernelInfo::getKernelEnvironementGVFromKernelInitCB(KernelInitCB); + if (KernelEnvGV->getInitializer() != KernelEnvC) { + KernelEnvGV->setInitializer(KernelEnvC); + Changed = ChangeStatus::CHANGED; } return Changed; @@ -3788,14 +4000,14 @@ struct AAKernelInfoFunction : AAKernelInfo { // Find escaping outputs from the guarded region to outside users and // broadcast their values to them. for (Instruction &I : *RegionStartBB) { - SmallPtrSet<Instruction *, 4> OutsideUsers; - for (User *Usr : I.users()) { - Instruction &UsrI = *cast<Instruction>(Usr); + SmallVector<Use *, 4> OutsideUses; + for (Use &U : I.uses()) { + Instruction &UsrI = *cast<Instruction>(U.getUser()); if (UsrI.getParent() != RegionStartBB) - OutsideUsers.insert(&UsrI); + OutsideUses.push_back(&U); } - if (OutsideUsers.empty()) + if (OutsideUses.empty()) continue; HasBroadcastValues = true; @@ -3818,8 +4030,8 @@ struct AAKernelInfoFunction : AAKernelInfo { RegionBarrierBB->getTerminator()); // Emit a load instruction and replace uses of the output value. - for (Instruction *UsrI : OutsideUsers) - UsrI->replaceUsesOfWith(&I, LoadI); + for (Use *U : OutsideUses) + A.changeUseAfterManifest(*U, *LoadI); } auto &OMPInfoCache = static_cast<OMPInformationCache &>(A.getInfoCache()); @@ -4043,19 +4255,14 @@ struct AAKernelInfoFunction : AAKernelInfo { auto *CB = cast<CallBase>(Kernel->user_back()); Kernel = CB->getCaller(); } - assert(omp::isKernel(*Kernel) && "Expected kernel function!"); + assert(omp::isOpenMPKernel(*Kernel) && "Expected kernel function!"); // Check if the kernel is already in SPMD mode, if so, return success. - GlobalVariable *ExecMode = Kernel->getParent()->getGlobalVariable( - (Kernel->getName() + "_exec_mode").str()); - assert(ExecMode && "Kernel without exec mode?"); - assert(ExecMode->getInitializer() && "ExecMode doesn't have initializer!"); - - // Set the global exec mode flag to indicate SPMD-Generic mode. - assert(isa<ConstantInt>(ExecMode->getInitializer()) && - "ExecMode is not an integer!"); - const int8_t ExecModeVal = - cast<ConstantInt>(ExecMode->getInitializer())->getSExtValue(); + ConstantStruct *ExistingKernelEnvC = + KernelInfo::getKernelEnvironementFromKernelInitCB(KernelInitCB); + auto *ExecModeC = + KernelInfo::getExecModeFromKernelEnvironment(ExistingKernelEnvC); + const int8_t ExecModeVal = ExecModeC->getSExtValue(); if (ExecModeVal != OMP_TGT_EXEC_MODE_GENERIC) return true; @@ -4073,27 +4280,8 @@ struct AAKernelInfoFunction : AAKernelInfo { // kernel is executed in. 
assert(ExecModeVal == OMP_TGT_EXEC_MODE_GENERIC && "Initially non-SPMD kernel has SPMD exec mode!"); - ExecMode->setInitializer( - ConstantInt::get(ExecMode->getInitializer()->getType(), - ExecModeVal | OMP_TGT_EXEC_MODE_GENERIC_SPMD)); - - // Next rewrite the init and deinit calls to indicate we use SPMD-mode now. - const int InitModeArgNo = 1; - const int DeinitModeArgNo = 1; - const int InitUseStateMachineArgNo = 2; - - auto &Ctx = getAnchorValue().getContext(); - A.changeUseAfterManifest( - KernelInitCB->getArgOperandUse(InitModeArgNo), - *ConstantInt::getSigned(IntegerType::getInt8Ty(Ctx), - OMP_TGT_EXEC_MODE_SPMD)); - A.changeUseAfterManifest( - KernelInitCB->getArgOperandUse(InitUseStateMachineArgNo), - *ConstantInt::getBool(Ctx, false)); - A.changeUseAfterManifest( - KernelDeinitCB->getArgOperandUse(DeinitModeArgNo), - *ConstantInt::getSigned(IntegerType::getInt8Ty(Ctx), - OMP_TGT_EXEC_MODE_SPMD)); + setExecModeOfKernelEnvironment(ConstantInt::get( + ExecModeC->getType(), ExecModeVal | OMP_TGT_EXEC_MODE_GENERIC_SPMD)); ++NumOpenMPTargetRegionKernelsSPMD; @@ -4104,46 +4292,47 @@ struct AAKernelInfoFunction : AAKernelInfo { return true; }; - ChangeStatus buildCustomStateMachine(Attributor &A) { + bool buildCustomStateMachine(Attributor &A, ChangeStatus &Changed) { // If we have disabled state machine rewrites, don't make a custom one if (DisableOpenMPOptStateMachineRewrite) - return ChangeStatus::UNCHANGED; + return false; // Don't rewrite the state machine if we are not in a valid state. if (!ReachedKnownParallelRegions.isValidState()) - return ChangeStatus::UNCHANGED; + return false; auto &OMPInfoCache = static_cast<OMPInformationCache &>(A.getInfoCache()); if (!OMPInfoCache.runtimeFnsAvailable( {OMPRTL___kmpc_get_hardware_num_threads_in_block, OMPRTL___kmpc_get_warp_size, OMPRTL___kmpc_barrier_simple_generic, OMPRTL___kmpc_kernel_parallel, OMPRTL___kmpc_kernel_end_parallel})) - return ChangeStatus::UNCHANGED; + return false; - const int InitModeArgNo = 1; - const int InitUseStateMachineArgNo = 2; + ConstantStruct *ExistingKernelEnvC = + KernelInfo::getKernelEnvironementFromKernelInitCB(KernelInitCB); // Check if the current configuration is non-SPMD and generic state machine. // If we already have SPMD mode or a custom state machine we do not need to // go any further. If it is anything but a constant something is weird and // we give up. - ConstantInt *UseStateMachine = dyn_cast<ConstantInt>( - KernelInitCB->getArgOperand(InitUseStateMachineArgNo)); - ConstantInt *Mode = - dyn_cast<ConstantInt>(KernelInitCB->getArgOperand(InitModeArgNo)); + ConstantInt *UseStateMachineC = + KernelInfo::getUseGenericStateMachineFromKernelEnvironment( + ExistingKernelEnvC); + ConstantInt *ModeC = + KernelInfo::getExecModeFromKernelEnvironment(ExistingKernelEnvC); // If we are stuck with generic mode, try to create a custom device (=GPU) // state machine which is specialized for the parallel regions that are // reachable by the kernel. - if (!UseStateMachine || UseStateMachine->isZero() || !Mode || - (Mode->getSExtValue() & OMP_TGT_EXEC_MODE_SPMD)) - return ChangeStatus::UNCHANGED; + if (UseStateMachineC->isZero() || + (ModeC->getSExtValue() & OMP_TGT_EXEC_MODE_SPMD)) + return false; + + Changed = ChangeStatus::CHANGED; // If not SPMD mode, indicate we use a custom state machine now. 
- auto &Ctx = getAnchorValue().getContext(); - auto *FalseVal = ConstantInt::getBool(Ctx, false); - A.changeUseAfterManifest( - KernelInitCB->getArgOperandUse(InitUseStateMachineArgNo), *FalseVal); + setUseGenericStateMachineOfKernelEnvironment( + ConstantInt::get(UseStateMachineC->getType(), false)); // If we don't actually need a state machine we are done here. This can // happen if there simply are no parallel regions. In the resulting kernel @@ -4157,7 +4346,7 @@ struct AAKernelInfoFunction : AAKernelInfo { }; A.emitRemark<OptimizationRemark>(KernelInitCB, "OMP130", Remark); - return ChangeStatus::CHANGED; + return true; } // Keep track in the statistics of our new shiny custom state machine. @@ -4222,6 +4411,7 @@ struct AAKernelInfoFunction : AAKernelInfo { // UserCodeEntryBB: // user code // __kmpc_target_deinit(...) // + auto &Ctx = getAnchorValue().getContext(); Function *Kernel = getAssociatedFunction(); assert(Kernel && "Expected an associated function!"); @@ -4292,7 +4482,7 @@ struct AAKernelInfoFunction : AAKernelInfo { // Create local storage for the work function pointer. const DataLayout &DL = M.getDataLayout(); - Type *VoidPtrTy = Type::getInt8PtrTy(Ctx); + Type *VoidPtrTy = PointerType::getUnqual(Ctx); Instruction *WorkFnAI = new AllocaInst(VoidPtrTy, DL.getAllocaAddrSpace(), nullptr, "worker.work_fn.addr", &Kernel->getEntryBlock().front()); @@ -4304,7 +4494,7 @@ struct AAKernelInfoFunction : AAKernelInfo { StateMachineBeginBB->end()), DLoc)); - Value *Ident = KernelInitCB->getArgOperand(0); + Value *Ident = KernelInfo::getIdentFromKernelEnvironment(KernelEnvC); Value *GTid = KernelInitCB; FunctionCallee BarrierFn = @@ -4337,9 +4527,6 @@ struct AAKernelInfoFunction : AAKernelInfo { FunctionType *ParallelRegionFnTy = FunctionType::get( Type::getVoidTy(Ctx), {Type::getInt16Ty(Ctx), Type::getInt32Ty(Ctx)}, false); - Value *WorkFnCast = BitCastInst::CreatePointerBitCastOrAddrSpaceCast( - WorkFn, ParallelRegionFnTy->getPointerTo(), "worker.work_fn.addr_cast", - StateMachineBeginBB); Instruction *IsDone = ICmpInst::Create(ICmpInst::ICmp, llvm::CmpInst::ICMP_EQ, WorkFn, @@ -4358,11 +4545,15 @@ struct AAKernelInfoFunction : AAKernelInfo { Value *ZeroArg = Constant::getNullValue(ParallelRegionFnTy->getParamType(0)); + const unsigned int WrapperFunctionArgNo = 6; + // Now that we have most of the CFG skeleton it is time for the if-cascade // that checks the function pointer we got from the runtime against the // parallel regions we expect, if there are any. for (int I = 0, E = ReachedKnownParallelRegions.size(); I < E; ++I) { - auto *ParallelRegion = ReachedKnownParallelRegions[I]; + auto *CB = ReachedKnownParallelRegions[I]; + auto *ParallelRegion = dyn_cast<Function>( + CB->getArgOperand(WrapperFunctionArgNo)->stripPointerCasts()); BasicBlock *PRExecuteBB = BasicBlock::Create( Ctx, "worker_state_machine.parallel_region.execute", Kernel, StateMachineEndParallelBB); @@ -4374,13 +4565,15 @@ struct AAKernelInfoFunction : AAKernelInfo { BasicBlock *PRNextBB = BasicBlock::Create(Ctx, "worker_state_machine.parallel_region.check", Kernel, StateMachineEndParallelBB); + A.registerManifestAddedBasicBlock(*PRExecuteBB); + A.registerManifestAddedBasicBlock(*PRNextBB); // Check if we need to compare the pointer at all or if we can just // call the parallel region function. 
Value *IsPR; if (I + 1 < E || !ReachedUnknownParallelRegions.empty()) { Instruction *CmpI = ICmpInst::Create( - ICmpInst::ICmp, llvm::CmpInst::ICMP_EQ, WorkFnCast, ParallelRegion, + ICmpInst::ICmp, llvm::CmpInst::ICMP_EQ, WorkFn, ParallelRegion, "worker.check_parallel_region", StateMachineIfCascadeCurrentBB); CmpI->setDebugLoc(DLoc); IsPR = CmpI; @@ -4400,7 +4593,7 @@ struct AAKernelInfoFunction : AAKernelInfo { if (!ReachedUnknownParallelRegions.empty()) { StateMachineIfCascadeCurrentBB->setName( "worker_state_machine.parallel_region.fallback.execute"); - CallInst::Create(ParallelRegionFnTy, WorkFnCast, {ZeroArg, GTid}, "", + CallInst::Create(ParallelRegionFnTy, WorkFn, {ZeroArg, GTid}, "", StateMachineIfCascadeCurrentBB) ->setDebugLoc(DLoc); } @@ -4423,7 +4616,7 @@ struct AAKernelInfoFunction : AAKernelInfo { BranchInst::Create(StateMachineBeginBB, StateMachineDoneBarrierBB) ->setDebugLoc(DLoc); - return ChangeStatus::CHANGED; + return true; } /// Fixpoint iteration update function. Will be called every time a dependence @@ -4431,6 +4624,46 @@ struct AAKernelInfoFunction : AAKernelInfo { ChangeStatus updateImpl(Attributor &A) override { KernelInfoState StateBefore = getState(); + // When we leave this function this RAII will make sure the member + // KernelEnvC is updated properly depending on the state. That member is + // used for simplification of values and needs to be up to date at all + // times. + struct UpdateKernelEnvCRAII { + AAKernelInfoFunction &AA; + + UpdateKernelEnvCRAII(AAKernelInfoFunction &AA) : AA(AA) {} + + ~UpdateKernelEnvCRAII() { + if (!AA.KernelEnvC) + return; + + ConstantStruct *ExistingKernelEnvC = + KernelInfo::getKernelEnvironementFromKernelInitCB(AA.KernelInitCB); + + if (!AA.isValidState()) { + AA.KernelEnvC = ExistingKernelEnvC; + return; + } + + if (!AA.ReachedKnownParallelRegions.isValidState()) + AA.setUseGenericStateMachineOfKernelEnvironment( + KernelInfo::getUseGenericStateMachineFromKernelEnvironment( + ExistingKernelEnvC)); + + if (!AA.SPMDCompatibilityTracker.isValidState()) + AA.setExecModeOfKernelEnvironment( + KernelInfo::getExecModeFromKernelEnvironment(ExistingKernelEnvC)); + + ConstantInt *MayUseNestedParallelismC = + KernelInfo::getMayUseNestedParallelismFromKernelEnvironment( + AA.KernelEnvC); + ConstantInt *NewMayUseNestedParallelismC = ConstantInt::get( + MayUseNestedParallelismC->getType(), AA.NestedParallelism); + AA.setMayUseNestedParallelismOfKernelEnvironment( + NewMayUseNestedParallelismC); + } + } RAII(*this); + // Callback to check a read/write instruction. auto CheckRWInst = [&](Instruction &I) { // We handle calls later. @@ -4634,15 +4867,13 @@ struct AAKernelInfoCallSite : AAKernelInfo { AAKernelInfo::initialize(A); CallBase &CB = cast<CallBase>(getAssociatedValue()); - Function *Callee = getAssociatedFunction(); - auto *AssumptionAA = A.getAAFor<AAAssumptionInfo>( *this, IRPosition::callsite_function(CB), DepClassTy::OPTIONAL); // Check for SPMD-mode assumptions. if (AssumptionAA && AssumptionAA->hasAssumption("ompx_spmd_amenable")) { - SPMDCompatibilityTracker.indicateOptimisticFixpoint(); indicateOptimisticFixpoint(); + return; } // First weed out calls we do not care about, that is readonly/readnone @@ -4657,124 +4888,156 @@ struct AAKernelInfoCallSite : AAKernelInfo { // we will handle them explicitly in the switch below. If it is not, we // will use an AAKernelInfo object on the callee to gather information and // merge that into the current state. The latter happens in the updateImpl. 
- auto &OMPInfoCache = static_cast<OMPInformationCache &>(A.getInfoCache()); - const auto &It = OMPInfoCache.RuntimeFunctionIDMap.find(Callee); - if (It == OMPInfoCache.RuntimeFunctionIDMap.end()) { - // Unknown caller or declarations are not analyzable, we give up. - if (!Callee || !A.isFunctionIPOAmendable(*Callee)) { + auto CheckCallee = [&](Function *Callee, unsigned NumCallees) { + auto &OMPInfoCache = static_cast<OMPInformationCache &>(A.getInfoCache()); + const auto &It = OMPInfoCache.RuntimeFunctionIDMap.find(Callee); + if (It == OMPInfoCache.RuntimeFunctionIDMap.end()) { + // Unknown caller or declarations are not analyzable, we give up. + if (!Callee || !A.isFunctionIPOAmendable(*Callee)) { - // Unknown callees might contain parallel regions, except if they have - // an appropriate assumption attached. - if (!AssumptionAA || - !(AssumptionAA->hasAssumption("omp_no_openmp") || - AssumptionAA->hasAssumption("omp_no_parallelism"))) - ReachedUnknownParallelRegions.insert(&CB); + // Unknown callees might contain parallel regions, except if they have + // an appropriate assumption attached. + if (!AssumptionAA || + !(AssumptionAA->hasAssumption("omp_no_openmp") || + AssumptionAA->hasAssumption("omp_no_parallelism"))) + ReachedUnknownParallelRegions.insert(&CB); - // If SPMDCompatibilityTracker is not fixed, we need to give up on the - // idea we can run something unknown in SPMD-mode. - if (!SPMDCompatibilityTracker.isAtFixpoint()) { - SPMDCompatibilityTracker.indicatePessimisticFixpoint(); - SPMDCompatibilityTracker.insert(&CB); - } + // If SPMDCompatibilityTracker is not fixed, we need to give up on the + // idea we can run something unknown in SPMD-mode. + if (!SPMDCompatibilityTracker.isAtFixpoint()) { + SPMDCompatibilityTracker.indicatePessimisticFixpoint(); + SPMDCompatibilityTracker.insert(&CB); + } - // We have updated the state for this unknown call properly, there won't - // be any change so we indicate a fixpoint. - indicateOptimisticFixpoint(); + // We have updated the state for this unknown call properly, there + // won't be any change so we indicate a fixpoint. + indicateOptimisticFixpoint(); + } + // If the callee is known and can be used in IPO, we will update the + // state based on the callee state in updateImpl. + return; + } + if (NumCallees > 1) { + indicatePessimisticFixpoint(); + return; } - // If the callee is known and can be used in IPO, we will update the state - // based on the callee state in updateImpl. - return; - } - const unsigned int WrapperFunctionArgNo = 6; - RuntimeFunction RF = It->getSecond(); - switch (RF) { - // All the functions we know are compatible with SPMD mode. 
- case OMPRTL___kmpc_is_spmd_exec_mode: - case OMPRTL___kmpc_distribute_static_fini: - case OMPRTL___kmpc_for_static_fini: - case OMPRTL___kmpc_global_thread_num: - case OMPRTL___kmpc_get_hardware_num_threads_in_block: - case OMPRTL___kmpc_get_hardware_num_blocks: - case OMPRTL___kmpc_single: - case OMPRTL___kmpc_end_single: - case OMPRTL___kmpc_master: - case OMPRTL___kmpc_end_master: - case OMPRTL___kmpc_barrier: - case OMPRTL___kmpc_nvptx_parallel_reduce_nowait_v2: - case OMPRTL___kmpc_nvptx_teams_reduce_nowait_v2: - case OMPRTL___kmpc_nvptx_end_reduce_nowait: - break; - case OMPRTL___kmpc_distribute_static_init_4: - case OMPRTL___kmpc_distribute_static_init_4u: - case OMPRTL___kmpc_distribute_static_init_8: - case OMPRTL___kmpc_distribute_static_init_8u: - case OMPRTL___kmpc_for_static_init_4: - case OMPRTL___kmpc_for_static_init_4u: - case OMPRTL___kmpc_for_static_init_8: - case OMPRTL___kmpc_for_static_init_8u: { - // Check the schedule and allow static schedule in SPMD mode. - unsigned ScheduleArgOpNo = 2; - auto *ScheduleTypeCI = - dyn_cast<ConstantInt>(CB.getArgOperand(ScheduleArgOpNo)); - unsigned ScheduleTypeVal = - ScheduleTypeCI ? ScheduleTypeCI->getZExtValue() : 0; - switch (OMPScheduleType(ScheduleTypeVal)) { - case OMPScheduleType::UnorderedStatic: - case OMPScheduleType::UnorderedStaticChunked: - case OMPScheduleType::OrderedDistribute: - case OMPScheduleType::OrderedDistributeChunked: + RuntimeFunction RF = It->getSecond(); + switch (RF) { + // All the functions we know are compatible with SPMD mode. + case OMPRTL___kmpc_is_spmd_exec_mode: + case OMPRTL___kmpc_distribute_static_fini: + case OMPRTL___kmpc_for_static_fini: + case OMPRTL___kmpc_global_thread_num: + case OMPRTL___kmpc_get_hardware_num_threads_in_block: + case OMPRTL___kmpc_get_hardware_num_blocks: + case OMPRTL___kmpc_single: + case OMPRTL___kmpc_end_single: + case OMPRTL___kmpc_master: + case OMPRTL___kmpc_end_master: + case OMPRTL___kmpc_barrier: + case OMPRTL___kmpc_nvptx_parallel_reduce_nowait_v2: + case OMPRTL___kmpc_nvptx_teams_reduce_nowait_v2: + case OMPRTL___kmpc_error: + case OMPRTL___kmpc_flush: + case OMPRTL___kmpc_get_hardware_thread_id_in_block: + case OMPRTL___kmpc_get_warp_size: + case OMPRTL_omp_get_thread_num: + case OMPRTL_omp_get_num_threads: + case OMPRTL_omp_get_max_threads: + case OMPRTL_omp_in_parallel: + case OMPRTL_omp_get_dynamic: + case OMPRTL_omp_get_cancellation: + case OMPRTL_omp_get_nested: + case OMPRTL_omp_get_schedule: + case OMPRTL_omp_get_thread_limit: + case OMPRTL_omp_get_supported_active_levels: + case OMPRTL_omp_get_max_active_levels: + case OMPRTL_omp_get_level: + case OMPRTL_omp_get_ancestor_thread_num: + case OMPRTL_omp_get_team_size: + case OMPRTL_omp_get_active_level: + case OMPRTL_omp_in_final: + case OMPRTL_omp_get_proc_bind: + case OMPRTL_omp_get_num_places: + case OMPRTL_omp_get_num_procs: + case OMPRTL_omp_get_place_proc_ids: + case OMPRTL_omp_get_place_num: + case OMPRTL_omp_get_partition_num_places: + case OMPRTL_omp_get_partition_place_nums: + case OMPRTL_omp_get_wtime: break; - default: + case OMPRTL___kmpc_distribute_static_init_4: + case OMPRTL___kmpc_distribute_static_init_4u: + case OMPRTL___kmpc_distribute_static_init_8: + case OMPRTL___kmpc_distribute_static_init_8u: + case OMPRTL___kmpc_for_static_init_4: + case OMPRTL___kmpc_for_static_init_4u: + case OMPRTL___kmpc_for_static_init_8: + case OMPRTL___kmpc_for_static_init_8u: { + // Check the schedule and allow static schedule in SPMD mode. 
+ unsigned ScheduleArgOpNo = 2; + auto *ScheduleTypeCI = + dyn_cast<ConstantInt>(CB.getArgOperand(ScheduleArgOpNo)); + unsigned ScheduleTypeVal = + ScheduleTypeCI ? ScheduleTypeCI->getZExtValue() : 0; + switch (OMPScheduleType(ScheduleTypeVal)) { + case OMPScheduleType::UnorderedStatic: + case OMPScheduleType::UnorderedStaticChunked: + case OMPScheduleType::OrderedDistribute: + case OMPScheduleType::OrderedDistributeChunked: + break; + default: + SPMDCompatibilityTracker.indicatePessimisticFixpoint(); + SPMDCompatibilityTracker.insert(&CB); + break; + }; + } break; + case OMPRTL___kmpc_target_init: + KernelInitCB = &CB; + break; + case OMPRTL___kmpc_target_deinit: + KernelDeinitCB = &CB; + break; + case OMPRTL___kmpc_parallel_51: + if (!handleParallel51(A, CB)) + indicatePessimisticFixpoint(); + return; + case OMPRTL___kmpc_omp_task: + // We do not look into tasks right now, just give up. SPMDCompatibilityTracker.indicatePessimisticFixpoint(); SPMDCompatibilityTracker.insert(&CB); + ReachedUnknownParallelRegions.insert(&CB); break; - }; - } break; - case OMPRTL___kmpc_target_init: - KernelInitCB = &CB; - break; - case OMPRTL___kmpc_target_deinit: - KernelDeinitCB = &CB; - break; - case OMPRTL___kmpc_parallel_51: - if (auto *ParallelRegion = dyn_cast<Function>( - CB.getArgOperand(WrapperFunctionArgNo)->stripPointerCasts())) { - ReachedKnownParallelRegions.insert(ParallelRegion); - /// Check nested parallelism - auto *FnAA = A.getAAFor<AAKernelInfo>( - *this, IRPosition::function(*ParallelRegion), DepClassTy::OPTIONAL); - NestedParallelism |= !FnAA || !FnAA->getState().isValidState() || - !FnAA->ReachedKnownParallelRegions.empty() || - !FnAA->ReachedUnknownParallelRegions.empty(); + case OMPRTL___kmpc_alloc_shared: + case OMPRTL___kmpc_free_shared: + // Return without setting a fixpoint, to be resolved in updateImpl. + return; + default: + // Unknown OpenMP runtime calls cannot be executed in SPMD-mode, + // generally. However, they do not hide parallel regions. + SPMDCompatibilityTracker.indicatePessimisticFixpoint(); + SPMDCompatibilityTracker.insert(&CB); break; } - // The condition above should usually get the parallel region function - // pointer and record it. In the off chance it doesn't we assume the - // worst. - ReachedUnknownParallelRegions.insert(&CB); - break; - case OMPRTL___kmpc_omp_task: - // We do not look into tasks right now, just give up. - SPMDCompatibilityTracker.indicatePessimisticFixpoint(); - SPMDCompatibilityTracker.insert(&CB); - ReachedUnknownParallelRegions.insert(&CB); - break; - case OMPRTL___kmpc_alloc_shared: - case OMPRTL___kmpc_free_shared: - // Return without setting a fixpoint, to be resolved in updateImpl. + // All other OpenMP runtime calls will not reach parallel regions so they + // can be safely ignored for now. Since it is a known OpenMP runtime call + // we have now modeled all effects and there is no need for any update. + indicateOptimisticFixpoint(); + }; + + const auto *AACE = + A.getAAFor<AACallEdges>(*this, getIRPosition(), DepClassTy::OPTIONAL); + if (!AACE || !AACE->getState().isValidState() || AACE->hasUnknownCallee()) { + CheckCallee(getAssociatedFunction(), 1); return; - default: - // Unknown OpenMP runtime calls cannot be executed in SPMD-mode, - // generally. However, they do not hide parallel regions. - SPMDCompatibilityTracker.indicatePessimisticFixpoint(); - SPMDCompatibilityTracker.insert(&CB); - break; } - // All other OpenMP runtime calls will not reach parallel regions so they - // can be safely ignored for now. 
Since it is a known OpenMP runtime call we - // have now modeled all effects and there is no need for any update. - indicateOptimisticFixpoint(); + const auto &OptimisticEdges = AACE->getOptimisticEdges(); + for (auto *Callee : OptimisticEdges) { + CheckCallee(Callee, OptimisticEdges.size()); + if (isAtFixpoint()) + break; + } } ChangeStatus updateImpl(Attributor &A) override { @@ -4782,62 +5045,115 @@ struct AAKernelInfoCallSite : AAKernelInfo { // call site specific liveness information and then it makes // sense to specialize attributes for call sites arguments instead of // redirecting requests to the callee argument. - Function *F = getAssociatedFunction(); - auto &OMPInfoCache = static_cast<OMPInformationCache &>(A.getInfoCache()); - const auto &It = OMPInfoCache.RuntimeFunctionIDMap.find(F); + KernelInfoState StateBefore = getState(); - // If F is not a runtime function, propagate the AAKernelInfo of the callee. - if (It == OMPInfoCache.RuntimeFunctionIDMap.end()) { - const IRPosition &FnPos = IRPosition::function(*F); - auto *FnAA = A.getAAFor<AAKernelInfo>(*this, FnPos, DepClassTy::REQUIRED); - if (!FnAA) + auto CheckCallee = [&](Function *F, int NumCallees) { + const auto &It = OMPInfoCache.RuntimeFunctionIDMap.find(F); + + // If F is not a runtime function, propagate the AAKernelInfo of the + // callee. + if (It == OMPInfoCache.RuntimeFunctionIDMap.end()) { + const IRPosition &FnPos = IRPosition::function(*F); + auto *FnAA = + A.getAAFor<AAKernelInfo>(*this, FnPos, DepClassTy::REQUIRED); + if (!FnAA) + return indicatePessimisticFixpoint(); + if (getState() == FnAA->getState()) + return ChangeStatus::UNCHANGED; + getState() = FnAA->getState(); + return ChangeStatus::CHANGED; + } + if (NumCallees > 1) return indicatePessimisticFixpoint(); - if (getState() == FnAA->getState()) - return ChangeStatus::UNCHANGED; - getState() = FnAA->getState(); - return ChangeStatus::CHANGED; - } - // F is a runtime function that allocates or frees memory, check - // AAHeapToStack and AAHeapToShared. - KernelInfoState StateBefore = getState(); - assert((It->getSecond() == OMPRTL___kmpc_alloc_shared || - It->getSecond() == OMPRTL___kmpc_free_shared) && - "Expected a __kmpc_alloc_shared or __kmpc_free_shared runtime call"); + CallBase &CB = cast<CallBase>(getAssociatedValue()); + if (It->getSecond() == OMPRTL___kmpc_parallel_51) { + if (!handleParallel51(A, CB)) + return indicatePessimisticFixpoint(); + return StateBefore == getState() ? ChangeStatus::UNCHANGED + : ChangeStatus::CHANGED; + } - CallBase &CB = cast<CallBase>(getAssociatedValue()); + // F is a runtime function that allocates or frees memory, check + // AAHeapToStack and AAHeapToShared. 
+ assert( + (It->getSecond() == OMPRTL___kmpc_alloc_shared || + It->getSecond() == OMPRTL___kmpc_free_shared) && + "Expected a __kmpc_alloc_shared or __kmpc_free_shared runtime call"); - auto *HeapToStackAA = A.getAAFor<AAHeapToStack>( - *this, IRPosition::function(*CB.getCaller()), DepClassTy::OPTIONAL); - auto *HeapToSharedAA = A.getAAFor<AAHeapToShared>( - *this, IRPosition::function(*CB.getCaller()), DepClassTy::OPTIONAL); + auto *HeapToStackAA = A.getAAFor<AAHeapToStack>( + *this, IRPosition::function(*CB.getCaller()), DepClassTy::OPTIONAL); + auto *HeapToSharedAA = A.getAAFor<AAHeapToShared>( + *this, IRPosition::function(*CB.getCaller()), DepClassTy::OPTIONAL); - RuntimeFunction RF = It->getSecond(); + RuntimeFunction RF = It->getSecond(); - switch (RF) { - // If neither HeapToStack nor HeapToShared assume the call is removed, - // assume SPMD incompatibility. - case OMPRTL___kmpc_alloc_shared: - if ((!HeapToStackAA || !HeapToStackAA->isAssumedHeapToStack(CB)) && - (!HeapToSharedAA || !HeapToSharedAA->isAssumedHeapToShared(CB))) - SPMDCompatibilityTracker.insert(&CB); - break; - case OMPRTL___kmpc_free_shared: - if ((!HeapToStackAA || - !HeapToStackAA->isAssumedHeapToStackRemovedFree(CB)) && - (!HeapToSharedAA || - !HeapToSharedAA->isAssumedHeapToSharedRemovedFree(CB))) + switch (RF) { + // If neither HeapToStack nor HeapToShared assume the call is removed, + // assume SPMD incompatibility. + case OMPRTL___kmpc_alloc_shared: + if ((!HeapToStackAA || !HeapToStackAA->isAssumedHeapToStack(CB)) && + (!HeapToSharedAA || !HeapToSharedAA->isAssumedHeapToShared(CB))) + SPMDCompatibilityTracker.insert(&CB); + break; + case OMPRTL___kmpc_free_shared: + if ((!HeapToStackAA || + !HeapToStackAA->isAssumedHeapToStackRemovedFree(CB)) && + (!HeapToSharedAA || + !HeapToSharedAA->isAssumedHeapToSharedRemovedFree(CB))) + SPMDCompatibilityTracker.insert(&CB); + break; + default: + SPMDCompatibilityTracker.indicatePessimisticFixpoint(); SPMDCompatibilityTracker.insert(&CB); - break; - default: - SPMDCompatibilityTracker.indicatePessimisticFixpoint(); - SPMDCompatibilityTracker.insert(&CB); + } + return ChangeStatus::CHANGED; + }; + + const auto *AACE = + A.getAAFor<AACallEdges>(*this, getIRPosition(), DepClassTy::OPTIONAL); + if (!AACE || !AACE->getState().isValidState() || AACE->hasUnknownCallee()) { + if (Function *F = getAssociatedFunction()) + CheckCallee(F, /*NumCallees=*/1); + } else { + const auto &OptimisticEdges = AACE->getOptimisticEdges(); + for (auto *Callee : OptimisticEdges) { + CheckCallee(Callee, OptimisticEdges.size()); + if (isAtFixpoint()) + break; + } } return StateBefore == getState() ? ChangeStatus::UNCHANGED : ChangeStatus::CHANGED; } + + /// Deal with a __kmpc_parallel_51 call (\p CB). Returns true if the call was + /// handled, if a problem occurred, false is returned. + bool handleParallel51(Attributor &A, CallBase &CB) { + const unsigned int NonWrapperFunctionArgNo = 5; + const unsigned int WrapperFunctionArgNo = 6; + auto ParallelRegionOpArgNo = SPMDCompatibilityTracker.isAssumed() + ? 
NonWrapperFunctionArgNo + : WrapperFunctionArgNo; + + auto *ParallelRegion = dyn_cast<Function>( + CB.getArgOperand(ParallelRegionOpArgNo)->stripPointerCasts()); + if (!ParallelRegion) + return false; + + ReachedKnownParallelRegions.insert(&CB); + /// Check nested parallelism + auto *FnAA = A.getAAFor<AAKernelInfo>( + *this, IRPosition::function(*ParallelRegion), DepClassTy::OPTIONAL); + NestedParallelism |= !FnAA || !FnAA->getState().isValidState() || + !FnAA->ReachedKnownParallelRegions.empty() || + !FnAA->ReachedKnownParallelRegions.isValidState() || + !FnAA->ReachedUnknownParallelRegions.isValidState() || + !FnAA->ReachedUnknownParallelRegions.empty(); + return true; + } }; struct AAFoldRuntimeCall @@ -5251,6 +5567,11 @@ void OpenMPOpt::registerAAsForFunction(Attributor &A, const Function &F) { UsedAssumedInformation, AA::Interprocedural); continue; } + if (auto *CI = dyn_cast<CallBase>(&I)) { + if (CI->isIndirectCall()) + A.getOrCreateAAFor<AAIndirectCallInfo>( + IRPosition::callsite_function(*CI)); + } if (auto *SI = dyn_cast<StoreInst>(&I)) { A.getOrCreateAAFor<AAIsDead>(IRPosition::value(*SI)); continue; @@ -5569,7 +5890,9 @@ PreservedAnalyses OpenMPOptCGSCCPass::run(LazyCallGraph::SCC &C, return PreservedAnalyses::all(); } -bool llvm::omp::isKernel(Function &Fn) { return Fn.hasFnAttribute("kernel"); } +bool llvm::omp::isOpenMPKernel(Function &Fn) { + return Fn.hasFnAttribute("kernel"); +} KernelSet llvm::omp::getDeviceKernels(Module &M) { // TODO: Create a more cross-platform way of determining device kernels. @@ -5591,10 +5914,13 @@ KernelSet llvm::omp::getDeviceKernels(Module &M) { if (!KernelFn) continue; - assert(isKernel(*KernelFn) && "Inconsistent kernel function annotation"); - ++NumOpenMPTargetRegionKernels; - - Kernels.insert(KernelFn); + // We are only interested in OpenMP target regions. Others, such as kernels + // generated by CUDA but linked together, are not interesting to this pass. + if (isOpenMPKernel(*KernelFn)) { + ++NumOpenMPTargetRegionKernels; + Kernels.insert(KernelFn); + } else + ++NumNonOpenMPTargetRegionKernels; } return Kernels; |
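Taken together, reads and writes of the environment stay purely on constants: the setters fold a new field value into a fresh ConstantStruct with ConstantFoldInsertValueInstruction, and manifest installs the tracked constant as the global's initializer only if it actually changed. Hand-expanded for the ExecMode field (a sketch of what the KERNEL_ENVIRONMENT_CONFIGURATION_SETTER macro in the patch generates, followed by the write-back from AAKernelInfoFunction::manifest; not new code, just the macro made explicit):

// Expansion of KERNEL_ENVIRONMENT_CONFIGURATION_SETTER(ExecMode): rebuild the
// nested configuration constant with the new value, then splice it back into
// the cached kernel environment constant.
void setExecModeOfKernelEnvironment(ConstantInt *NewVal) {
  ConstantStruct *ConfigC =
      KernelInfo::getConfigurationFromKernelEnvironment(KernelEnvC);
  Constant *NewConfigC = ConstantFoldInsertValueInstruction(
      ConfigC, NewVal, {KernelInfo::ExecModeIdx});
  assert(NewConfigC && "Failed to create new configuration environment");
  setConfigurationOfKernelEnvironment(cast<ConstantStruct>(NewConfigC));
}

// Write-back as done at the end of manifest: replace the global initializer
// only when the tracked constant differs, so unchanged kernels stay untouched.
GlobalVariable *KernelEnvGV =
    KernelInfo::getKernelEnvironementGVFromKernelInitCB(KernelInitCB);
if (KernelEnvGV->getInitializer() != KernelEnvC)
  KernelEnvGV->setInitializer(KernelEnvC);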
