Diffstat (limited to 'llvm/lib/Transforms/IPO/OpenMPOpt.cpp')
-rw-r--r--   llvm/lib/Transforms/IPO/OpenMPOpt.cpp   719
1 file changed, 595 insertions, 124 deletions
diff --git a/llvm/lib/Transforms/IPO/OpenMPOpt.cpp b/llvm/lib/Transforms/IPO/OpenMPOpt.cpp index b80349352719..f342c35fa283 100644 --- a/llvm/lib/Transforms/IPO/OpenMPOpt.cpp +++ b/llvm/lib/Transforms/IPO/OpenMPOpt.cpp @@ -22,6 +22,7 @@ #include "llvm/ADT/EnumeratedArray.h" #include "llvm/ADT/PostOrderIterator.h" #include "llvm/ADT/Statistic.h" +#include "llvm/ADT/StringRef.h" #include "llvm/Analysis/CallGraph.h" #include "llvm/Analysis/CallGraphSCCPass.h" #include "llvm/Analysis/OptimizationRemarkEmitter.h" @@ -33,6 +34,8 @@ #include "llvm/IR/GlobalValue.h" #include "llvm/IR/Instruction.h" #include "llvm/IR/IntrinsicInst.h" +#include "llvm/IR/IntrinsicsAMDGPU.h" +#include "llvm/IR/IntrinsicsNVPTX.h" #include "llvm/InitializePasses.h" #include "llvm/Support/CommandLine.h" #include "llvm/Transforms/IPO.h" @@ -41,6 +44,8 @@ #include "llvm/Transforms/Utils/CallGraphUpdater.h" #include "llvm/Transforms/Utils/CodeExtractor.h" +#include <algorithm> + using namespace llvm; using namespace omp; @@ -72,6 +77,46 @@ static cl::opt<bool> HideMemoryTransferLatency( " transfers"), cl::Hidden, cl::init(false)); +static cl::opt<bool> DisableOpenMPOptDeglobalization( + "openmp-opt-disable-deglobalization", cl::ZeroOrMore, + cl::desc("Disable OpenMP optimizations involving deglobalization."), + cl::Hidden, cl::init(false)); + +static cl::opt<bool> DisableOpenMPOptSPMDization( + "openmp-opt-disable-spmdization", cl::ZeroOrMore, + cl::desc("Disable OpenMP optimizations involving SPMD-ization."), + cl::Hidden, cl::init(false)); + +static cl::opt<bool> DisableOpenMPOptFolding( + "openmp-opt-disable-folding", cl::ZeroOrMore, + cl::desc("Disable OpenMP optimizations involving folding."), cl::Hidden, + cl::init(false)); + +static cl::opt<bool> DisableOpenMPOptStateMachineRewrite( + "openmp-opt-disable-state-machine-rewrite", cl::ZeroOrMore, + cl::desc("Disable OpenMP optimizations that replace the state machine."), + cl::Hidden, cl::init(false)); + +static cl::opt<bool> PrintModuleAfterOptimizations( + "openmp-opt-print-module", cl::ZeroOrMore, + cl::desc("Print the current module after OpenMP optimizations."), + cl::Hidden, cl::init(false)); + +static cl::opt<bool> AlwaysInlineDeviceFunctions( + "openmp-opt-inline-device", cl::ZeroOrMore, + cl::desc("Inline all applicible functions on the device."), cl::Hidden, + cl::init(false)); + +static cl::opt<bool> + EnableVerboseRemarks("openmp-opt-verbose-remarks", cl::ZeroOrMore, + cl::desc("Enables more verbose remarks."), cl::Hidden, + cl::init(false)); + +static cl::opt<unsigned> + SetFixpointIterations("openmp-opt-max-iterations", cl::Hidden, + cl::desc("Maximal number of attributor iterations."), + cl::init(256)); + STATISTIC(NumOpenMPRuntimeCallsDeduplicated, "Number of OpenMP runtime calls deduplicated"); STATISTIC(NumOpenMPParallelRegionsDeleted, @@ -328,7 +373,7 @@ struct OMPInformationCache : public InformationCache { if (F->arg_size() != RTFArgTypes.size()) return false; - auto RTFTyIt = RTFArgTypes.begin(); + auto *RTFTyIt = RTFArgTypes.begin(); for (Argument &Arg : F->args()) { if (Arg.getType() != *RTFTyIt) return false; @@ -503,7 +548,7 @@ struct KernelInfoState : AbstractState { /// State to track if we are in SPMD-mode, assumed or know, and why we decided /// we cannot be. If it is assumed, then RequiresFullRuntime should also be /// false. - BooleanStateWithPtrSetVector<Instruction> SPMDCompatibilityTracker; + BooleanStateWithPtrSetVector<Instruction, false> SPMDCompatibilityTracker; /// The __kmpc_target_init call in this kernel, if any. 
If we find more than /// one we abort as the kernel is malformed. @@ -542,7 +587,9 @@ struct KernelInfoState : AbstractState { /// See AbstractState::indicatePessimisticFixpoint(...) ChangeStatus indicatePessimisticFixpoint() override { IsAtFixpoint = true; + ReachingKernelEntries.indicatePessimisticFixpoint(); SPMDCompatibilityTracker.indicatePessimisticFixpoint(); + ReachedKnownParallelRegions.indicatePessimisticFixpoint(); ReachedUnknownParallelRegions.indicatePessimisticFixpoint(); return ChangeStatus::CHANGED; } @@ -550,6 +597,10 @@ struct KernelInfoState : AbstractState { /// See AbstractState::indicateOptimisticFixpoint(...) ChangeStatus indicateOptimisticFixpoint() override { IsAtFixpoint = true; + ReachingKernelEntries.indicateOptimisticFixpoint(); + SPMDCompatibilityTracker.indicateOptimisticFixpoint(); + ReachedKnownParallelRegions.indicateOptimisticFixpoint(); + ReachedUnknownParallelRegions.indicateOptimisticFixpoint(); return ChangeStatus::UNCHANGED; } @@ -569,6 +620,12 @@ struct KernelInfoState : AbstractState { return true; } + /// Returns true if this kernel contains any OpenMP parallel regions. + bool mayContainParallelRegion() { + return !ReachedKnownParallelRegions.empty() || + !ReachedUnknownParallelRegions.empty(); + } + /// Return empty set as the best state of potential values. static KernelInfoState getBestState() { return KernelInfoState(true); } @@ -584,12 +641,14 @@ struct KernelInfoState : AbstractState { // Do not merge two different _init and _deinit call sites. if (KIS.KernelInitCB) { if (KernelInitCB && KernelInitCB != KIS.KernelInitCB) - indicatePessimisticFixpoint(); + llvm_unreachable("Kernel that calls another kernel violates OpenMP-Opt " + "assumptions."); KernelInitCB = KIS.KernelInitCB; } if (KIS.KernelDeinitCB) { if (KernelDeinitCB && KernelDeinitCB != KIS.KernelDeinitCB) - indicatePessimisticFixpoint(); + llvm_unreachable("Kernel that calls another kernel violates OpenMP-Opt " + "assumptions."); KernelDeinitCB = KIS.KernelDeinitCB; } SPMDCompatibilityTracker ^= KIS.SPMDCompatibilityTracker; @@ -1032,8 +1091,8 @@ private: Args.clear(); Args.push_back(OutlinedFn->getArg(0)); Args.push_back(OutlinedFn->getArg(1)); - for (unsigned U = CallbackFirstArgOperand, E = CI->getNumArgOperands(); - U < E; ++U) + for (unsigned U = CallbackFirstArgOperand, E = CI->arg_size(); U < E; + ++U) Args.push_back(CI->getArgOperand(U)); CallInst *NewCI = CallInst::Create(FT, Callee, Args, "", CI); @@ -1041,9 +1100,9 @@ private: NewCI->setDebugLoc(CI->getDebugLoc()); // Forward parameter attributes from the callback to the callee. - for (unsigned U = CallbackFirstArgOperand, E = CI->getNumArgOperands(); - U < E; ++U) - for (const Attribute &A : CI->getAttributes().getParamAttributes(U)) + for (unsigned U = CallbackFirstArgOperand, E = CI->arg_size(); U < E; + ++U) + for (const Attribute &A : CI->getAttributes().getParamAttrs(U)) NewCI->addParamAttr( U - (CallbackFirstArgOperand - CallbackCalleeOperand), A); @@ -1563,13 +1622,13 @@ private: // TODO: Use dominance to find a good position instead. 
auto CanBeMoved = [this](CallBase &CB) { - unsigned NumArgs = CB.getNumArgOperands(); + unsigned NumArgs = CB.arg_size(); if (NumArgs == 0) return true; if (CB.getArgOperand(0)->getType() != OMPInfoCache.OMPBuilder.IdentPtr) return false; - for (unsigned u = 1; u < NumArgs; ++u) - if (isa<Instruction>(CB.getArgOperand(u))) + for (unsigned U = 1; U < NumArgs; ++U) + if (isa<Instruction>(CB.getArgOperand(U))) return false; return true; }; @@ -1612,7 +1671,7 @@ private: // valid at the new location. For now we just pick a global one, either // existing and used by one of the calls, or created from scratch. if (CallBase *CI = dyn_cast<CallBase>(ReplVal)) { - if (CI->getNumArgOperands() > 0 && + if (!CI->arg_empty() && CI->getArgOperand(0)->getType() == OMPInfoCache.OMPBuilder.IdentPtr) { Value *Ident = getCombinedIdentFromCallUsesIn(RFI, F, /* GlobalOnly */ true); @@ -1695,8 +1754,8 @@ private: // Transitively search for more arguments by looking at the users of the // ones we know already. During the search the GTIdArgs vector is extended // so we cannot cache the size nor can we use a range based for. - for (unsigned u = 0; u < GTIdArgs.size(); ++u) - AddUserArgs(*GTIdArgs[u]); + for (unsigned U = 0; U < GTIdArgs.size(); ++U) + AddUserArgs(*GTIdArgs[U]); } /// Kernel (=GPU) optimizations and utility functions @@ -1822,6 +1881,10 @@ private: OMPRTL___kmpc_kernel_end_parallel); ExternalizationRAII BarrierSPMD(OMPInfoCache, OMPRTL___kmpc_barrier_simple_spmd); + ExternalizationRAII BarrierGeneric(OMPInfoCache, + OMPRTL___kmpc_barrier_simple_generic); + ExternalizationRAII ThreadId(OMPInfoCache, + OMPRTL___kmpc_get_hardware_thread_id_in_block); registerAAs(IsModulePass); @@ -1918,6 +1981,10 @@ bool OpenMPOpt::rewriteDeviceCodeStateMachine() { if (!KernelParallelRFI) return Changed; + // If we have disabled state machine changes, exit + if (DisableOpenMPOptStateMachineRewrite) + return Changed; + for (Function *F : SCC) { // Check if the function is a use in a __kmpc_parallel_51 call at @@ -1996,7 +2063,8 @@ bool OpenMPOpt::rewriteDeviceCodeStateMachine() { UndefValue::get(Int8Ty), F->getName() + ".ID"); for (Use *U : ToBeReplacedStateMachineUses) - U->set(ConstantExpr::getBitCast(ID, U->get()->getType())); + U->set(ConstantExpr::getPointerBitCastOrAddrSpaceCast( + ID, U->get()->getType())); ++NumOpenMPParallelRegionsReplacedInGPUStateMachine; @@ -2508,9 +2576,8 @@ ChangeStatus AAExecutionDomainFunction::updateImpl(Attributor &A) { auto &OMPInfoCache = static_cast<OMPInformationCache &>(A.getInfoCache()); auto &RFI = OMPInfoCache.RFIs[OMPRTL___kmpc_target_init]; - // Check if the edge into the successor block compares the __kmpc_target_init - // result with -1. If we are in non-SPMD-mode that signals only the main - // thread will execute the edge. + // Check if the edge into the successor block contains a condition that only + // lets the main thread execute it. auto IsInitialThreadOnly = [&](BranchInst *Edge, BasicBlock *SuccessorBB) { if (!Edge || !Edge->isConditional()) return false; @@ -2525,16 +2592,27 @@ ChangeStatus AAExecutionDomainFunction::updateImpl(Attributor &A) { if (!C) return false; - // Match: -1 == __kmpc_target_init (for non-SPMD kernels only!) + // Match: -1 == __kmpc_target_init (for non-SPMD kernels only!) if (C->isAllOnesValue()) { auto *CB = dyn_cast<CallBase>(Cmp->getOperand(0)); CB = CB ? 
OpenMPOpt::getCallIfRegularCall(*CB, &RFI) : nullptr; if (!CB) return false; - const int InitIsSPMDArgNo = 1; - auto *IsSPMDModeCI = - dyn_cast<ConstantInt>(CB->getOperand(InitIsSPMDArgNo)); - return IsSPMDModeCI && IsSPMDModeCI->isZero(); + const int InitModeArgNo = 1; + auto *ModeCI = dyn_cast<ConstantInt>(CB->getOperand(InitModeArgNo)); + return ModeCI && (ModeCI->getSExtValue() & OMP_TGT_EXEC_MODE_GENERIC); + } + + if (C->isZero()) { + // Match: 0 == llvm.nvvm.read.ptx.sreg.tid.x() + if (auto *II = dyn_cast<IntrinsicInst>(Cmp->getOperand(0))) + if (II->getIntrinsicID() == Intrinsic::nvvm_read_ptx_sreg_tid_x) + return true; + + // Match: 0 == llvm.amdgcn.workitem.id.x() + if (auto *II = dyn_cast<IntrinsicInst>(Cmp->getOperand(0))) + if (II->getIntrinsicID() == Intrinsic::amdgcn_workitem_id_x) + return true; } return false; @@ -2543,15 +2621,14 @@ ChangeStatus AAExecutionDomainFunction::updateImpl(Attributor &A) { // Merge all the predecessor states into the current basic block. A basic // block is executed by a single thread if all of its predecessors are. auto MergePredecessorStates = [&](BasicBlock *BB) { - if (pred_begin(BB) == pred_end(BB)) + if (pred_empty(BB)) return SingleThreadedBBs.contains(BB); bool IsInitialThread = true; - for (auto PredBB = pred_begin(BB), PredEndBB = pred_end(BB); - PredBB != PredEndBB; ++PredBB) { - if (!IsInitialThreadOnly(dyn_cast<BranchInst>((*PredBB)->getTerminator()), + for (BasicBlock *PredBB : predecessors(BB)) { + if (!IsInitialThreadOnly(dyn_cast<BranchInst>(PredBB->getTerminator()), BB)) - IsInitialThread &= SingleThreadedBBs.contains(*PredBB); + IsInitialThread &= SingleThreadedBBs.contains(PredBB); } return IsInitialThread; @@ -2683,9 +2760,8 @@ struct AAHeapToSharedFunction : public AAHeapToShared { ConstantInt *AllocSize = dyn_cast<ConstantInt>(CB->getArgOperand(0)); - LLVM_DEBUG(dbgs() << TAG << "Replace globalization call in " - << CB->getCaller()->getName() << " with " - << AllocSize->getZExtValue() + LLVM_DEBUG(dbgs() << TAG << "Replace globalization call " << *CB + << " with " << AllocSize->getZExtValue() << " bytes of shared memory\n"); // Create a new shared memory buffer of the same size as the allocation @@ -2734,7 +2810,7 @@ struct AAHeapToSharedFunction : public AAHeapToShared { const auto &ED = A.getAAFor<AAExecutionDomain>( *this, IRPosition::function(*F), DepClassTy::REQUIRED); if (CallBase *CB = dyn_cast<CallBase>(U)) - if (!dyn_cast<ConstantInt>(CB->getArgOperand(0)) || + if (!isa<ConstantInt>(CB->getArgOperand(0)) || !ED.isExecutedByInitialThreadOnly(*CB)) MallocCalls.erase(CB); } @@ -2769,9 +2845,17 @@ struct AAKernelInfo : public StateWrapper<KernelInfoState, AbstractAttribute> { std::string(SPMDCompatibilityTracker.isAtFixpoint() ? " [FIX]" : "") + std::string(" #PRs: ") + - std::to_string(ReachedKnownParallelRegions.size()) + + (ReachedKnownParallelRegions.isValidState() + ? std::to_string(ReachedKnownParallelRegions.size()) + : "<invalid>") + ", #Unknown PRs: " + - std::to_string(ReachedUnknownParallelRegions.size()); + (ReachedUnknownParallelRegions.isValidState() + ? std::to_string(ReachedUnknownParallelRegions.size()) + : "<invalid>") + + ", #Reaching Kernels: " + + (ReachingKernelEntries.isValidState() + ? std::to_string(ReachingKernelEntries.size()) + : "<invalid>"); } /// Create an abstract attribute biew for the position \p IRP. 
@@ -2797,6 +2881,12 @@ struct AAKernelInfoFunction : AAKernelInfo { AAKernelInfoFunction(const IRPosition &IRP, Attributor &A) : AAKernelInfo(IRP, A) {} + SmallPtrSet<Instruction *, 4> GuardedInstructions; + + SmallPtrSetImpl<Instruction *> &getGuardedInstructions() { + return GuardedInstructions; + } + /// See AbstractAttribute::initialize(...). void initialize(Attributor &A) override { // This is a high-level transform that might change the constant arguments @@ -2843,8 +2933,11 @@ struct AAKernelInfoFunction : AAKernelInfo { }, Fn); - assert((KernelInitCB && KernelDeinitCB) && - "Kernel without __kmpc_target_init or __kmpc_target_deinit!"); + // Ignore kernels without initializers such as global constructors. + if (!KernelInitCB || !KernelDeinitCB) { + indicateOptimisticFixpoint(); + return; + } // For kernels we might need to initialize/finalize the IsSPMD state and // we need to register a simplification callback so that the Attributor @@ -2859,7 +2952,10 @@ struct AAKernelInfoFunction : AAKernelInfo { // state. As long as we are not in an invalid state, we will create a // custom state machine so the value should be a `i1 false`. If we are // in an invalid state, we won't change the value that is in the IR. - if (!isValidState()) + if (!ReachedKnownParallelRegions.isValidState()) + return nullptr; + // If we have disabled state machine rewrites, don't make a custom one. + if (DisableOpenMPOptStateMachineRewrite) return nullptr; if (AA) A.recordDependence(*this, *AA, DepClassTy::OPTIONAL); @@ -2869,7 +2965,7 @@ struct AAKernelInfoFunction : AAKernelInfo { return FalseVal; }; - Attributor::SimplifictionCallbackTy IsSPMDModeSimplifyCB = + Attributor::SimplifictionCallbackTy ModeSimplifyCB = [&](const IRPosition &IRP, const AbstractAttribute *AA, bool &UsedAssumedInformation) -> Optional<Value *> { // IRP represents the "SPMDCompatibilityTracker" argument of an @@ -2885,8 +2981,10 @@ struct AAKernelInfoFunction : AAKernelInfo { } else { UsedAssumedInformation = false; } - auto *Val = ConstantInt::getBool(IRP.getAnchorValue().getContext(), - SPMDCompatibilityTracker.isAssumed()); + auto *Val = ConstantInt::getSigned( + IntegerType::getInt8Ty(IRP.getAnchorValue().getContext()), + SPMDCompatibilityTracker.isAssumed() ? 
OMP_TGT_EXEC_MODE_SPMD + : OMP_TGT_EXEC_MODE_GENERIC); return Val; }; @@ -2911,8 +3009,8 @@ struct AAKernelInfoFunction : AAKernelInfo { return Val; }; - constexpr const int InitIsSPMDArgNo = 1; - constexpr const int DeinitIsSPMDArgNo = 1; + constexpr const int InitModeArgNo = 1; + constexpr const int DeinitModeArgNo = 1; constexpr const int InitUseStateMachineArgNo = 2; constexpr const int InitRequiresFullRuntimeArgNo = 3; constexpr const int DeinitRequiresFullRuntimeArgNo = 2; @@ -2920,11 +3018,11 @@ struct AAKernelInfoFunction : AAKernelInfo { IRPosition::callsite_argument(*KernelInitCB, InitUseStateMachineArgNo), StateMachineSimplifyCB); A.registerSimplificationCallback( - IRPosition::callsite_argument(*KernelInitCB, InitIsSPMDArgNo), - IsSPMDModeSimplifyCB); + IRPosition::callsite_argument(*KernelInitCB, InitModeArgNo), + ModeSimplifyCB); A.registerSimplificationCallback( - IRPosition::callsite_argument(*KernelDeinitCB, DeinitIsSPMDArgNo), - IsSPMDModeSimplifyCB); + IRPosition::callsite_argument(*KernelDeinitCB, DeinitModeArgNo), + ModeSimplifyCB); A.registerSimplificationCallback( IRPosition::callsite_argument(*KernelInitCB, InitRequiresFullRuntimeArgNo), @@ -2935,10 +3033,25 @@ struct AAKernelInfoFunction : AAKernelInfo { IsGenericModeSimplifyCB); // Check if we know we are in SPMD-mode already. - ConstantInt *IsSPMDArg = - dyn_cast<ConstantInt>(KernelInitCB->getArgOperand(InitIsSPMDArgNo)); - if (IsSPMDArg && !IsSPMDArg->isZero()) + ConstantInt *ModeArg = + dyn_cast<ConstantInt>(KernelInitCB->getArgOperand(InitModeArgNo)); + if (ModeArg && (ModeArg->getSExtValue() & OMP_TGT_EXEC_MODE_SPMD)) SPMDCompatibilityTracker.indicateOptimisticFixpoint(); + // This is a generic region but SPMDization is disabled so stop tracking. + else if (DisableOpenMPOptSPMDization) + SPMDCompatibilityTracker.indicatePessimisticFixpoint(); + } + + /// Sanitize the string \p S such that it is a suitable global symbol name. + static std::string sanitizeForGlobalName(std::string S) { + std::replace_if( + S.begin(), S.end(), + [](const char C) { + return !((C >= 'a' && C <= 'z') || (C >= 'A' && C <= 'Z') || + (C >= '0' && C <= '9') || C == '_'); + }, + '.'); + return S; } /// Modify the IR based on the KernelInfoState as the fixpoint iteration is @@ -2949,19 +3062,16 @@ struct AAKernelInfoFunction : AAKernelInfo { if (!KernelInitCB || !KernelDeinitCB) return ChangeStatus::UNCHANGED; - // Known SPMD-mode kernels need no manifest changes. - if (SPMDCompatibilityTracker.isKnown()) - return ChangeStatus::UNCHANGED; - // If we can we change the execution mode to SPMD-mode otherwise we build a // custom state machine. - if (!changeToSPMDMode(A)) - buildCustomStateMachine(A); + ChangeStatus Changed = ChangeStatus::UNCHANGED; + if (!changeToSPMDMode(A, Changed)) + return buildCustomStateMachine(A); - return ChangeStatus::CHANGED; + return Changed; } - bool changeToSPMDMode(Attributor &A) { + bool changeToSPMDMode(Attributor &A, ChangeStatus &Changed) { auto &OMPInfoCache = static_cast<OMPInformationCache &>(A.getInfoCache()); if (!SPMDCompatibilityTracker.isAssumed()) { @@ -2993,38 +3103,259 @@ struct AAKernelInfoFunction : AAKernelInfo { return false; } - // Adjust the global exec mode flag that tells the runtime what mode this - // kernel is executed in. + // Check if the kernel is already in SPMD mode, if so, return success. 
Function *Kernel = getAnchorScope(); GlobalVariable *ExecMode = Kernel->getParent()->getGlobalVariable( (Kernel->getName() + "_exec_mode").str()); assert(ExecMode && "Kernel without exec mode?"); - assert(ExecMode->getInitializer() && - ExecMode->getInitializer()->isOneValue() && - "Initially non-SPMD kernel has SPMD exec mode!"); + assert(ExecMode->getInitializer() && "ExecMode doesn't have initializer!"); // Set the global exec mode flag to indicate SPMD-Generic mode. - constexpr int SPMDGeneric = 2; - if (!ExecMode->getInitializer()->isZeroValue()) - ExecMode->setInitializer( - ConstantInt::get(ExecMode->getInitializer()->getType(), SPMDGeneric)); + assert(isa<ConstantInt>(ExecMode->getInitializer()) && + "ExecMode is not an integer!"); + const int8_t ExecModeVal = + cast<ConstantInt>(ExecMode->getInitializer())->getSExtValue(); + if (ExecModeVal != OMP_TGT_EXEC_MODE_GENERIC) + return true; + + // We will now unconditionally modify the IR, indicate a change. + Changed = ChangeStatus::CHANGED; + + auto CreateGuardedRegion = [&](Instruction *RegionStartI, + Instruction *RegionEndI) { + LoopInfo *LI = nullptr; + DominatorTree *DT = nullptr; + MemorySSAUpdater *MSU = nullptr; + using InsertPointTy = OpenMPIRBuilder::InsertPointTy; + + BasicBlock *ParentBB = RegionStartI->getParent(); + Function *Fn = ParentBB->getParent(); + Module &M = *Fn->getParent(); + + // Create all the blocks and logic. + // ParentBB: + // goto RegionCheckTidBB + // RegionCheckTidBB: + // Tid = __kmpc_hardware_thread_id() + // if (Tid != 0) + // goto RegionBarrierBB + // RegionStartBB: + // <execute instructions guarded> + // goto RegionEndBB + // RegionEndBB: + // <store escaping values to shared mem> + // goto RegionBarrierBB + // RegionBarrierBB: + // __kmpc_simple_barrier_spmd() + // // second barrier is omitted if lacking escaping values. + // <load escaping values from shared mem> + // __kmpc_simple_barrier_spmd() + // goto RegionExitBB + // RegionExitBB: + // <execute rest of instructions> + + BasicBlock *RegionEndBB = SplitBlock(ParentBB, RegionEndI->getNextNode(), + DT, LI, MSU, "region.guarded.end"); + BasicBlock *RegionBarrierBB = + SplitBlock(RegionEndBB, &*RegionEndBB->getFirstInsertionPt(), DT, LI, + MSU, "region.barrier"); + BasicBlock *RegionExitBB = + SplitBlock(RegionBarrierBB, &*RegionBarrierBB->getFirstInsertionPt(), + DT, LI, MSU, "region.exit"); + BasicBlock *RegionStartBB = + SplitBlock(ParentBB, RegionStartI, DT, LI, MSU, "region.guarded"); + + assert(ParentBB->getUniqueSuccessor() == RegionStartBB && + "Expected a different CFG"); + + BasicBlock *RegionCheckTidBB = SplitBlock( + ParentBB, ParentBB->getTerminator(), DT, LI, MSU, "region.check.tid"); + + // Register basic blocks with the Attributor. + A.registerManifestAddedBasicBlock(*RegionEndBB); + A.registerManifestAddedBasicBlock(*RegionBarrierBB); + A.registerManifestAddedBasicBlock(*RegionExitBB); + A.registerManifestAddedBasicBlock(*RegionStartBB); + A.registerManifestAddedBasicBlock(*RegionCheckTidBB); + + bool HasBroadcastValues = false; + // Find escaping outputs from the guarded region to outside users and + // broadcast their values to them. 
+ for (Instruction &I : *RegionStartBB) { + SmallPtrSet<Instruction *, 4> OutsideUsers; + for (User *Usr : I.users()) { + Instruction &UsrI = *cast<Instruction>(Usr); + if (UsrI.getParent() != RegionStartBB) + OutsideUsers.insert(&UsrI); + } + + if (OutsideUsers.empty()) + continue; + + HasBroadcastValues = true; + + // Emit a global variable in shared memory to store the broadcasted + // value. + auto *SharedMem = new GlobalVariable( + M, I.getType(), /* IsConstant */ false, + GlobalValue::InternalLinkage, UndefValue::get(I.getType()), + sanitizeForGlobalName( + (I.getName() + ".guarded.output.alloc").str()), + nullptr, GlobalValue::NotThreadLocal, + static_cast<unsigned>(AddressSpace::Shared)); + + // Emit a store instruction to update the value. + new StoreInst(&I, SharedMem, RegionEndBB->getTerminator()); + + LoadInst *LoadI = new LoadInst(I.getType(), SharedMem, + I.getName() + ".guarded.output.load", + RegionBarrierBB->getTerminator()); + + // Emit a load instruction and replace uses of the output value. + for (Instruction *UsrI : OutsideUsers) + UsrI->replaceUsesOfWith(&I, LoadI); + } + + auto &OMPInfoCache = static_cast<OMPInformationCache &>(A.getInfoCache()); + + // Go to tid check BB in ParentBB. + const DebugLoc DL = ParentBB->getTerminator()->getDebugLoc(); + ParentBB->getTerminator()->eraseFromParent(); + OpenMPIRBuilder::LocationDescription Loc( + InsertPointTy(ParentBB, ParentBB->end()), DL); + OMPInfoCache.OMPBuilder.updateToLocation(Loc); + auto *SrcLocStr = OMPInfoCache.OMPBuilder.getOrCreateSrcLocStr(Loc); + Value *Ident = OMPInfoCache.OMPBuilder.getOrCreateIdent(SrcLocStr); + BranchInst::Create(RegionCheckTidBB, ParentBB)->setDebugLoc(DL); + + // Add check for Tid in RegionCheckTidBB + RegionCheckTidBB->getTerminator()->eraseFromParent(); + OpenMPIRBuilder::LocationDescription LocRegionCheckTid( + InsertPointTy(RegionCheckTidBB, RegionCheckTidBB->end()), DL); + OMPInfoCache.OMPBuilder.updateToLocation(LocRegionCheckTid); + FunctionCallee HardwareTidFn = + OMPInfoCache.OMPBuilder.getOrCreateRuntimeFunction( + M, OMPRTL___kmpc_get_hardware_thread_id_in_block); + Value *Tid = + OMPInfoCache.OMPBuilder.Builder.CreateCall(HardwareTidFn, {}); + Value *TidCheck = OMPInfoCache.OMPBuilder.Builder.CreateIsNull(Tid); + OMPInfoCache.OMPBuilder.Builder + .CreateCondBr(TidCheck, RegionStartBB, RegionBarrierBB) + ->setDebugLoc(DL); + + // First barrier for synchronization, ensures main thread has updated + // values. + FunctionCallee BarrierFn = + OMPInfoCache.OMPBuilder.getOrCreateRuntimeFunction( + M, OMPRTL___kmpc_barrier_simple_spmd); + OMPInfoCache.OMPBuilder.updateToLocation(InsertPointTy( + RegionBarrierBB, RegionBarrierBB->getFirstInsertionPt())); + OMPInfoCache.OMPBuilder.Builder.CreateCall(BarrierFn, {Ident, Tid}) + ->setDebugLoc(DL); + + // Second barrier ensures workers have read broadcast values. 
+ if (HasBroadcastValues) + CallInst::Create(BarrierFn, {Ident, Tid}, "", + RegionBarrierBB->getTerminator()) + ->setDebugLoc(DL); + }; + + auto &AllocSharedRFI = OMPInfoCache.RFIs[OMPRTL___kmpc_alloc_shared]; + SmallPtrSet<BasicBlock *, 8> Visited; + for (Instruction *GuardedI : SPMDCompatibilityTracker) { + BasicBlock *BB = GuardedI->getParent(); + if (!Visited.insert(BB).second) + continue; + + SmallVector<std::pair<Instruction *, Instruction *>> Reorders; + Instruction *LastEffect = nullptr; + BasicBlock::reverse_iterator IP = BB->rbegin(), IPEnd = BB->rend(); + while (++IP != IPEnd) { + if (!IP->mayHaveSideEffects() && !IP->mayReadFromMemory()) + continue; + Instruction *I = &*IP; + if (OpenMPOpt::getCallIfRegularCall(*I, &AllocSharedRFI)) + continue; + if (!I->user_empty() || !SPMDCompatibilityTracker.contains(I)) { + LastEffect = nullptr; + continue; + } + if (LastEffect) + Reorders.push_back({I, LastEffect}); + LastEffect = &*IP; + } + for (auto &Reorder : Reorders) + Reorder.first->moveBefore(Reorder.second); + } + + SmallVector<std::pair<Instruction *, Instruction *>, 4> GuardedRegions; + + for (Instruction *GuardedI : SPMDCompatibilityTracker) { + BasicBlock *BB = GuardedI->getParent(); + auto *CalleeAA = A.lookupAAFor<AAKernelInfo>( + IRPosition::function(*GuardedI->getFunction()), nullptr, + DepClassTy::NONE); + assert(CalleeAA != nullptr && "Expected Callee AAKernelInfo"); + auto &CalleeAAFunction = *cast<AAKernelInfoFunction>(CalleeAA); + // Continue if instruction is already guarded. + if (CalleeAAFunction.getGuardedInstructions().contains(GuardedI)) + continue; + + Instruction *GuardedRegionStart = nullptr, *GuardedRegionEnd = nullptr; + for (Instruction &I : *BB) { + // If instruction I needs to be guarded update the guarded region + // bounds. + if (SPMDCompatibilityTracker.contains(&I)) { + CalleeAAFunction.getGuardedInstructions().insert(&I); + if (GuardedRegionStart) + GuardedRegionEnd = &I; + else + GuardedRegionStart = GuardedRegionEnd = &I; + + continue; + } + + // Instruction I does not need guarding, store + // any region found and reset bounds. + if (GuardedRegionStart) { + GuardedRegions.push_back( + std::make_pair(GuardedRegionStart, GuardedRegionEnd)); + GuardedRegionStart = nullptr; + GuardedRegionEnd = nullptr; + } + } + } + + for (auto &GR : GuardedRegions) + CreateGuardedRegion(GR.first, GR.second); + + // Adjust the global exec mode flag that tells the runtime what mode this + // kernel is executed in. + assert(ExecModeVal == OMP_TGT_EXEC_MODE_GENERIC && + "Initially non-SPMD kernel has SPMD exec mode!"); + ExecMode->setInitializer( + ConstantInt::get(ExecMode->getInitializer()->getType(), + ExecModeVal | OMP_TGT_EXEC_MODE_GENERIC_SPMD)); // Next rewrite the init and deinit calls to indicate we use SPMD-mode now. 
- const int InitIsSPMDArgNo = 1; - const int DeinitIsSPMDArgNo = 1; + const int InitModeArgNo = 1; + const int DeinitModeArgNo = 1; const int InitUseStateMachineArgNo = 2; const int InitRequiresFullRuntimeArgNo = 3; const int DeinitRequiresFullRuntimeArgNo = 2; auto &Ctx = getAnchorValue().getContext(); - A.changeUseAfterManifest(KernelInitCB->getArgOperandUse(InitIsSPMDArgNo), - *ConstantInt::getBool(Ctx, 1)); + A.changeUseAfterManifest( + KernelInitCB->getArgOperandUse(InitModeArgNo), + *ConstantInt::getSigned(IntegerType::getInt8Ty(Ctx), + OMP_TGT_EXEC_MODE_SPMD)); A.changeUseAfterManifest( KernelInitCB->getArgOperandUse(InitUseStateMachineArgNo), *ConstantInt::getBool(Ctx, 0)); A.changeUseAfterManifest( - KernelDeinitCB->getArgOperandUse(DeinitIsSPMDArgNo), - *ConstantInt::getBool(Ctx, 1)); + KernelDeinitCB->getArgOperandUse(DeinitModeArgNo), + *ConstantInt::getSigned(IntegerType::getInt8Ty(Ctx), + OMP_TGT_EXEC_MODE_SPMD)); A.changeUseAfterManifest( KernelInitCB->getArgOperandUse(InitRequiresFullRuntimeArgNo), *ConstantInt::getBool(Ctx, 0)); @@ -3042,10 +3373,15 @@ struct AAKernelInfoFunction : AAKernelInfo { }; ChangeStatus buildCustomStateMachine(Attributor &A) { - assert(ReachedKnownParallelRegions.isValidState() && - "Custom state machine with invalid parallel region states?"); + // If we have disabled state machine rewrites, don't make a custom one + if (DisableOpenMPOptStateMachineRewrite) + return ChangeStatus::UNCHANGED; + + // Don't rewrite the state machine if we are not in a valid state. + if (!ReachedKnownParallelRegions.isValidState()) + return ChangeStatus::UNCHANGED; - const int InitIsSPMDArgNo = 1; + const int InitModeArgNo = 1; const int InitUseStateMachineArgNo = 2; // Check if the current configuration is non-SPMD and generic state machine. @@ -3054,14 +3390,14 @@ struct AAKernelInfoFunction : AAKernelInfo { // we give up. ConstantInt *UseStateMachine = dyn_cast<ConstantInt>( KernelInitCB->getArgOperand(InitUseStateMachineArgNo)); - ConstantInt *IsSPMD = - dyn_cast<ConstantInt>(KernelInitCB->getArgOperand(InitIsSPMDArgNo)); + ConstantInt *Mode = + dyn_cast<ConstantInt>(KernelInitCB->getArgOperand(InitModeArgNo)); // If we are stuck with generic mode, try to create a custom device (=GPU) // state machine which is specialized for the parallel regions that are // reachable by the kernel. - if (!UseStateMachine || UseStateMachine->isZero() || !IsSPMD || - !IsSPMD->isZero()) + if (!UseStateMachine || UseStateMachine->isZero() || !Mode || + (Mode->getSExtValue() & OMP_TGT_EXEC_MODE_SPMD)) return ChangeStatus::UNCHANGED; // If not SPMD mode, indicate we use a custom state machine now. @@ -3074,8 +3410,7 @@ struct AAKernelInfoFunction : AAKernelInfo { // happen if there simply are no parallel regions. In the resulting kernel // all worker threads will simply exit right away, leaving the main thread // to do the work alone. - if (ReachedKnownParallelRegions.empty() && - ReachedUnknownParallelRegions.empty()) { + if (!mayContainParallelRegion()) { ++NumOpenMPTargetRegionKernelsWithoutStateMachine; auto Remark = [&](OptimizationRemark OR) { @@ -3121,9 +3456,14 @@ struct AAKernelInfoFunction : AAKernelInfo { // Create all the blocks: // // InitCB = __kmpc_target_init(...) 
- // bool IsWorker = InitCB >= 0; + // BlockHwSize = + // __kmpc_get_hardware_num_threads_in_block(); + // WarpSize = __kmpc_get_warp_size(); + // BlockSize = BlockHwSize - WarpSize; + // if (InitCB >= BlockSize) return; + // IsWorkerCheckBB: bool IsWorker = InitCB >= 0; // if (IsWorker) { - // SMBeginBB: __kmpc_barrier_simple_spmd(...); + // SMBeginBB: __kmpc_barrier_simple_generic(...); // void *WorkFn; // bool Active = __kmpc_kernel_parallel(&WorkFn); // if (!WorkFn) return; @@ -3137,7 +3477,7 @@ struct AAKernelInfoFunction : AAKernelInfo { // ((WorkFnTy*)WorkFn)(...); // SMEndParallelBB: __kmpc_kernel_end_parallel(...); // } - // SMDoneBB: __kmpc_barrier_simple_spmd(...); + // SMDoneBB: __kmpc_barrier_simple_generic(...); // goto SMBeginBB; // } // UserCodeEntryBB: // user code @@ -3149,6 +3489,8 @@ struct AAKernelInfoFunction : AAKernelInfo { BasicBlock *InitBB = KernelInitCB->getParent(); BasicBlock *UserCodeEntryBB = InitBB->splitBasicBlock( KernelInitCB->getNextNode(), "thread.user_code.check"); + BasicBlock *IsWorkerCheckBB = + BasicBlock::Create(Ctx, "is_worker_check", Kernel, UserCodeEntryBB); BasicBlock *StateMachineBeginBB = BasicBlock::Create( Ctx, "worker_state_machine.begin", Kernel, UserCodeEntryBB); BasicBlock *StateMachineFinishedBB = BasicBlock::Create( @@ -3165,6 +3507,7 @@ struct AAKernelInfoFunction : AAKernelInfo { Ctx, "worker_state_machine.done.barrier", Kernel, UserCodeEntryBB); A.registerManifestAddedBasicBlock(*InitBB); A.registerManifestAddedBasicBlock(*UserCodeEntryBB); + A.registerManifestAddedBasicBlock(*IsWorkerCheckBB); A.registerManifestAddedBasicBlock(*StateMachineBeginBB); A.registerManifestAddedBasicBlock(*StateMachineFinishedBB); A.registerManifestAddedBasicBlock(*StateMachineIsActiveCheckBB); @@ -3174,22 +3517,47 @@ struct AAKernelInfoFunction : AAKernelInfo { const DebugLoc &DLoc = KernelInitCB->getDebugLoc(); ReturnInst::Create(Ctx, StateMachineFinishedBB)->setDebugLoc(DLoc); - InitBB->getTerminator()->eraseFromParent(); + + Module &M = *Kernel->getParent(); + auto &OMPInfoCache = static_cast<OMPInformationCache &>(A.getInfoCache()); + FunctionCallee BlockHwSizeFn = + OMPInfoCache.OMPBuilder.getOrCreateRuntimeFunction( + M, OMPRTL___kmpc_get_hardware_num_threads_in_block); + FunctionCallee WarpSizeFn = + OMPInfoCache.OMPBuilder.getOrCreateRuntimeFunction( + M, OMPRTL___kmpc_get_warp_size); + Instruction *BlockHwSize = + CallInst::Create(BlockHwSizeFn, "block.hw_size", InitBB); + BlockHwSize->setDebugLoc(DLoc); + Instruction *WarpSize = CallInst::Create(WarpSizeFn, "warp.size", InitBB); + WarpSize->setDebugLoc(DLoc); + Instruction *BlockSize = + BinaryOperator::CreateSub(BlockHwSize, WarpSize, "block.size", InitBB); + BlockSize->setDebugLoc(DLoc); + Instruction *IsMainOrWorker = + ICmpInst::Create(ICmpInst::ICmp, llvm::CmpInst::ICMP_SLT, KernelInitCB, + BlockSize, "thread.is_main_or_worker", InitBB); + IsMainOrWorker->setDebugLoc(DLoc); + BranchInst::Create(IsWorkerCheckBB, StateMachineFinishedBB, IsMainOrWorker, + InitBB); + Instruction *IsWorker = ICmpInst::Create(ICmpInst::ICmp, llvm::CmpInst::ICMP_NE, KernelInitCB, ConstantInt::get(KernelInitCB->getType(), -1), - "thread.is_worker", InitBB); + "thread.is_worker", IsWorkerCheckBB); IsWorker->setDebugLoc(DLoc); - BranchInst::Create(StateMachineBeginBB, UserCodeEntryBB, IsWorker, InitBB); + BranchInst::Create(StateMachineBeginBB, UserCodeEntryBB, IsWorker, + IsWorkerCheckBB); // Create local storage for the work function pointer. 
+ const DataLayout &DL = M.getDataLayout(); Type *VoidPtrTy = Type::getInt8PtrTy(Ctx); - AllocaInst *WorkFnAI = new AllocaInst(VoidPtrTy, 0, "worker.work_fn.addr", - &Kernel->getEntryBlock().front()); + Instruction *WorkFnAI = + new AllocaInst(VoidPtrTy, DL.getAllocaAddrSpace(), nullptr, + "worker.work_fn.addr", &Kernel->getEntryBlock().front()); WorkFnAI->setDebugLoc(DLoc); - auto &OMPInfoCache = static_cast<OMPInformationCache &>(A.getInfoCache()); OMPInfoCache.OMPBuilder.updateToLocation( OpenMPIRBuilder::LocationDescription( IRBuilder<>::InsertPoint(StateMachineBeginBB, @@ -3199,13 +3567,23 @@ struct AAKernelInfoFunction : AAKernelInfo { Value *Ident = KernelInitCB->getArgOperand(0); Value *GTid = KernelInitCB; - Module &M = *Kernel->getParent(); FunctionCallee BarrierFn = OMPInfoCache.OMPBuilder.getOrCreateRuntimeFunction( - M, OMPRTL___kmpc_barrier_simple_spmd); + M, OMPRTL___kmpc_barrier_simple_generic); CallInst::Create(BarrierFn, {Ident, GTid}, "", StateMachineBeginBB) ->setDebugLoc(DLoc); + if (WorkFnAI->getType()->getPointerAddressSpace() != + (unsigned int)AddressSpace::Generic) { + WorkFnAI = new AddrSpaceCastInst( + WorkFnAI, + PointerType::getWithSamePointeeType( + cast<PointerType>(WorkFnAI->getType()), + (unsigned int)AddressSpace::Generic), + WorkFnAI->getName() + ".generic", StateMachineBeginBB); + WorkFnAI->setDebugLoc(DLoc); + } + FunctionCallee KernelParallelFn = OMPInfoCache.OMPBuilder.getOrCreateRuntimeFunction( M, OMPRTL___kmpc_kernel_parallel); @@ -3243,8 +3621,8 @@ struct AAKernelInfoFunction : AAKernelInfo { // Now that we have most of the CFG skeleton it is time for the if-cascade // that checks the function pointer we got from the runtime against the // parallel regions we expect, if there are any. - for (int i = 0, e = ReachedKnownParallelRegions.size(); i < e; ++i) { - auto *ParallelRegion = ReachedKnownParallelRegions[i]; + for (int I = 0, E = ReachedKnownParallelRegions.size(); I < E; ++I) { + auto *ParallelRegion = ReachedKnownParallelRegions[I]; BasicBlock *PRExecuteBB = BasicBlock::Create( Ctx, "worker_state_machine.parallel_region.execute", Kernel, StateMachineEndParallelBB); @@ -3260,7 +3638,7 @@ struct AAKernelInfoFunction : AAKernelInfo { // Check if we need to compare the pointer at all or if we can just // call the parallel region function. Value *IsPR; - if (i + 1 < e || !ReachedUnknownParallelRegions.empty()) { + if (I + 1 < E || !ReachedUnknownParallelRegions.empty()) { Instruction *CmpI = ICmpInst::Create( ICmpInst::ICmp, llvm::CmpInst::ICMP_EQ, WorkFnCast, ParallelRegion, "worker.check_parallel_region", StateMachineIfCascadeCurrentBB); @@ -3324,8 +3702,21 @@ struct AAKernelInfoFunction : AAKernelInfo { if (llvm::all_of(Objects, [](const Value *Obj) { return isa<AllocaInst>(Obj); })) return true; + // Check for AAHeapToStack moved objects which must not be guarded. + auto &HS = A.getAAFor<AAHeapToStack>( + *this, IRPosition::function(*I.getFunction()), + DepClassTy::OPTIONAL); + if (llvm::all_of(Objects, [&HS](const Value *Obj) { + auto *CB = dyn_cast<CallBase>(Obj); + if (!CB) + return false; + return HS.isAssumedHeapToStack(*CB); + })) { + return true; + } } - // For now we give up on everything but stores. + + // Insert instruction that needs guarding. 
SPMDCompatibilityTracker.insert(&I); return true; }; @@ -3339,9 +3730,13 @@ struct AAKernelInfoFunction : AAKernelInfo { if (!IsKernelEntry) { updateReachingKernelEntries(A); updateParallelLevels(A); + + if (!ParallelLevels.isValidState()) + SPMDCompatibilityTracker.indicatePessimisticFixpoint(); } // Callback to check a call instruction. + bool AllParallelRegionStatesWereFixed = true; bool AllSPMDStatesWereFixed = true; auto CheckCallInst = [&](Instruction &I) { auto &CB = cast<CallBase>(I); @@ -3349,13 +3744,37 @@ struct AAKernelInfoFunction : AAKernelInfo { *this, IRPosition::callsite_function(CB), DepClassTy::OPTIONAL); getState() ^= CBAA.getState(); AllSPMDStatesWereFixed &= CBAA.SPMDCompatibilityTracker.isAtFixpoint(); + AllParallelRegionStatesWereFixed &= + CBAA.ReachedKnownParallelRegions.isAtFixpoint(); + AllParallelRegionStatesWereFixed &= + CBAA.ReachedUnknownParallelRegions.isAtFixpoint(); return true; }; bool UsedAssumedInformationInCheckCallInst = false; if (!A.checkForAllCallLikeInstructions( - CheckCallInst, *this, UsedAssumedInformationInCheckCallInst)) + CheckCallInst, *this, UsedAssumedInformationInCheckCallInst)) { + LLVM_DEBUG(dbgs() << TAG + << "Failed to visit all call-like instructions!\n";); return indicatePessimisticFixpoint(); + } + + // If we haven't used any assumed information for the reached parallel + // region states we can fix it. + if (!UsedAssumedInformationInCheckCallInst && + AllParallelRegionStatesWereFixed) { + ReachedKnownParallelRegions.indicateOptimisticFixpoint(); + ReachedUnknownParallelRegions.indicateOptimisticFixpoint(); + } + + // If we are sure there are no parallel regions in the kernel we do not + // want SPMD mode. + if (IsKernelEntry && ReachedUnknownParallelRegions.isAtFixpoint() && + ReachedKnownParallelRegions.isAtFixpoint() && + ReachedUnknownParallelRegions.isValidState() && + ReachedKnownParallelRegions.isValidState() && + !mayContainParallelRegion()) + SPMDCompatibilityTracker.indicatePessimisticFixpoint(); // If we haven't used any assumed information for the SPMD state we can fix // it. @@ -3454,14 +3873,14 @@ struct AAKernelInfoCallSite : AAKernelInfo { CallBase &CB = cast<CallBase>(getAssociatedValue()); Function *Callee = getAssociatedFunction(); - // Helper to lookup an assumption string. - auto HasAssumption = [](Function *Fn, StringRef AssumptionStr) { - return Fn && hasAssumption(*Fn, AssumptionStr); - }; + auto &AssumptionAA = A.getAAFor<AAAssumptionInfo>( + *this, IRPosition::callsite_function(CB), DepClassTy::OPTIONAL); // Check for SPMD-mode assumptions. - if (HasAssumption(Callee, "ompx_spmd_amenable")) + if (AssumptionAA.hasAssumption("ompx_spmd_amenable")) { SPMDCompatibilityTracker.indicateOptimisticFixpoint(); + indicateOptimisticFixpoint(); + } // First weed out calls we do not care about, that is readonly/readnone // calls, intrinsics, and "no_openmp" calls. Neither of these can reach a @@ -3483,14 +3902,16 @@ struct AAKernelInfoCallSite : AAKernelInfo { // Unknown callees might contain parallel regions, except if they have // an appropriate assumption attached. - if (!(HasAssumption(Callee, "omp_no_openmp") || - HasAssumption(Callee, "omp_no_parallelism"))) + if (!(AssumptionAA.hasAssumption("omp_no_openmp") || + AssumptionAA.hasAssumption("omp_no_parallelism"))) ReachedUnknownParallelRegions.insert(&CB); // If SPMDCompatibilityTracker is not fixed, we need to give up on the // idea we can run something unknown in SPMD-mode. 
- if (!SPMDCompatibilityTracker.isAtFixpoint()) + if (!SPMDCompatibilityTracker.isAtFixpoint()) { + SPMDCompatibilityTracker.indicatePessimisticFixpoint(); SPMDCompatibilityTracker.insert(&CB); + } // We have updated the state for this unknown call properly, there won't // be any change so we indicate a fixpoint. @@ -3506,6 +3927,7 @@ struct AAKernelInfoCallSite : AAKernelInfo { switch (RF) { // All the functions we know are compatible with SPMD mode. case OMPRTL___kmpc_is_spmd_exec_mode: + case OMPRTL___kmpc_distribute_static_fini: case OMPRTL___kmpc_for_static_fini: case OMPRTL___kmpc_global_thread_num: case OMPRTL___kmpc_get_hardware_num_threads_in_block: @@ -3516,6 +3938,10 @@ struct AAKernelInfoCallSite : AAKernelInfo { case OMPRTL___kmpc_end_master: case OMPRTL___kmpc_barrier: break; + case OMPRTL___kmpc_distribute_static_init_4: + case OMPRTL___kmpc_distribute_static_init_4u: + case OMPRTL___kmpc_distribute_static_init_8: + case OMPRTL___kmpc_distribute_static_init_8u: case OMPRTL___kmpc_for_static_init_4: case OMPRTL___kmpc_for_static_init_4u: case OMPRTL___kmpc_for_static_init_8: @@ -3533,6 +3959,7 @@ struct AAKernelInfoCallSite : AAKernelInfo { case OMPScheduleType::DistributeChunked: break; default: + SPMDCompatibilityTracker.indicatePessimisticFixpoint(); SPMDCompatibilityTracker.insert(&CB); break; }; @@ -3565,7 +3992,7 @@ struct AAKernelInfoCallSite : AAKernelInfo { return; default: // Unknown OpenMP runtime calls cannot be executed in SPMD-mode, - // generally. + // generally. However, they do not hide parallel regions. SPMDCompatibilityTracker.insert(&CB); break; } @@ -3685,6 +4112,9 @@ struct AAFoldRuntimeCallCallSiteReturned : AAFoldRuntimeCall { } void initialize(Attributor &A) override { + if (DisableOpenMPOptFolding) + indicatePessimisticFixpoint(); + Function *Callee = getAssociatedFunction(); auto &OMPInfoCache = static_cast<OMPInformationCache &>(A.getInfoCache()); @@ -3741,11 +4171,24 @@ struct AAFoldRuntimeCallCallSiteReturned : AAFoldRuntimeCall { ChangeStatus Changed = ChangeStatus::UNCHANGED; if (SimplifiedValue.hasValue() && SimplifiedValue.getValue()) { - Instruction &CB = *getCtxI(); - A.changeValueAfterManifest(CB, **SimplifiedValue); - A.deleteAfterManifest(CB); + Instruction &I = *getCtxI(); + A.changeValueAfterManifest(I, **SimplifiedValue); + A.deleteAfterManifest(I); + + CallBase *CB = dyn_cast<CallBase>(&I); + auto Remark = [&](OptimizationRemark OR) { + if (auto *C = dyn_cast<ConstantInt>(*SimplifiedValue)) + return OR << "Replacing OpenMP runtime call " + << CB->getCalledFunction()->getName() << " with " + << ore::NV("FoldedValue", C->getZExtValue()) << "."; + return OR << "Replacing OpenMP runtime call " + << CB->getCalledFunction()->getName() << "."; + }; + + if (CB && EnableVerboseRemarks) + A.emitRemark<OptimizationRemark>(CB, "OMP180", Remark); - LLVM_DEBUG(dbgs() << TAG << "Folding runtime call: " << CB << " with " + LLVM_DEBUG(dbgs() << TAG << "Replacing runtime call: " << I << " with " << **SimplifiedValue << "\n"); Changed = ChangeStatus::CHANGED; @@ -3979,7 +4422,6 @@ void OpenMPOpt::registerAAs(bool IsModulePass) { DepClassTy::NONE, /* ForceUpdate */ false, /* UpdateAfterInit */ false); - registerFoldRuntimeCall(OMPRTL___kmpc_is_generic_main_thread_id); registerFoldRuntimeCall(OMPRTL___kmpc_is_spmd_exec_mode); registerFoldRuntimeCall(OMPRTL___kmpc_parallel_level); @@ -4012,7 +4454,8 @@ void OpenMPOpt::registerAAs(bool IsModulePass) { A.getOrCreateAAFor<AAHeapToShared>(IRPosition::function(F)); return false; }; - 
GlobalizationRFI.foreachUse(SCC, CreateAA); + if (!DisableOpenMPOptDeglobalization) + GlobalizationRFI.foreachUse(SCC, CreateAA); // Create an ExecutionDomain AA for every function and a HeapToStack AA for // every function if there is a device kernel. @@ -4024,7 +4467,8 @@ void OpenMPOpt::registerAAs(bool IsModulePass) { continue; A.getOrCreateAAFor<AAExecutionDomain>(IRPosition::function(*F)); - A.getOrCreateAAFor<AAHeapToStack>(IRPosition::function(*F)); + if (!DisableOpenMPOptDeglobalization) + A.getOrCreateAAFor<AAHeapToStack>(IRPosition::function(*F)); for (auto &I : instructions(*F)) { if (auto *LI = dyn_cast<LoadInst>(&I)) { @@ -4176,28 +4620,32 @@ PreservedAnalyses OpenMPOptPass::run(Module &M, ModuleAnalysisManager &AM) { ORE.emit([&]() { OptimizationRemarkAnalysis ORA(DEBUG_TYPE, "OMP140", &F); return ORA << "Could not internalize function. " - << "Some optimizations may not be possible."; + << "Some optimizations may not be possible. [OMP140]"; }); }; // Create internal copies of each function if this is a kernel Module. This // allows iterprocedural passes to see every call edge. - DenseSet<const Function *> InternalizedFuncs; - if (isOpenMPDevice(M)) + DenseMap<Function *, Function *> InternalizedMap; + if (isOpenMPDevice(M)) { + SmallPtrSet<Function *, 16> InternalizeFns; for (Function &F : M) if (!F.isDeclaration() && !Kernels.contains(&F) && IsCalled(F) && !DisableInternalization) { - if (Attributor::internalizeFunction(F, /* Force */ true)) { - InternalizedFuncs.insert(&F); + if (Attributor::isInternalizable(F)) { + InternalizeFns.insert(&F); } else if (!F.hasLocalLinkage() && !F.hasFnAttribute(Attribute::Cold)) { EmitRemark(F); } } + Attributor::internalizeFunctions(InternalizeFns, InternalizedMap); + } + // Look at every function in the Module unless it was internalized. SmallVector<Function *, 16> SCC; for (Function &F : M) - if (!F.isDeclaration() && !InternalizedFuncs.contains(&F)) + if (!F.isDeclaration() && !InternalizedMap.lookup(&F)) SCC.push_back(&F); if (SCC.empty()) @@ -4215,12 +4663,24 @@ PreservedAnalyses OpenMPOptPass::run(Module &M, ModuleAnalysisManager &AM) { SetVector<Function *> Functions(SCC.begin(), SCC.end()); OMPInformationCache InfoCache(M, AG, Allocator, /*CGSCC*/ Functions, Kernels); - unsigned MaxFixpointIterations = (isOpenMPDevice(M)) ? 128 : 32; + unsigned MaxFixpointIterations = + (isOpenMPDevice(M)) ? SetFixpointIterations : 32; Attributor A(Functions, InfoCache, CGUpdater, nullptr, true, false, MaxFixpointIterations, OREGetter, DEBUG_TYPE); OpenMPOpt OMPOpt(SCC, CGUpdater, OREGetter, InfoCache, A); bool Changed = OMPOpt.run(true); + + // Optionally inline device functions for potentially better performance. + if (AlwaysInlineDeviceFunctions && isOpenMPDevice(M)) + for (Function &F : M) + if (!F.isDeclaration() && !Kernels.contains(&F) && + !F.hasFnAttribute(Attribute::NoInline)) + F.addFnAttr(Attribute::AlwaysInline); + + if (PrintModuleAfterOptimizations) + LLVM_DEBUG(dbgs() << TAG << "Module after OpenMPOpt Module Pass:\n" << M); + if (Changed) return PreservedAnalyses::none(); @@ -4267,12 +4727,17 @@ PreservedAnalyses OpenMPOptCGSCCPass::run(LazyCallGraph::SCC &C, OMPInformationCache InfoCache(*(Functions.back()->getParent()), AG, Allocator, /*CGSCC*/ Functions, Kernels); - unsigned MaxFixpointIterations = (isOpenMPDevice(M)) ? 128 : 32; + unsigned MaxFixpointIterations = + (isOpenMPDevice(M)) ? 
SetFixpointIterations : 32; Attributor A(Functions, InfoCache, CGUpdater, nullptr, false, true, MaxFixpointIterations, OREGetter, DEBUG_TYPE); OpenMPOpt OMPOpt(SCC, CGUpdater, OREGetter, InfoCache, A); bool Changed = OMPOpt.run(false); + + if (PrintModuleAfterOptimizations) + LLVM_DEBUG(dbgs() << TAG << "Module after OpenMPOpt CGSCC Pass:\n" << M); + if (Changed) return PreservedAnalyses::none(); @@ -4333,12 +4798,18 @@ struct OpenMPOptCGSCCLegacyPass : public CallGraphSCCPass { Allocator, /*CGSCC*/ Functions, Kernels); - unsigned MaxFixpointIterations = (isOpenMPDevice(M)) ? 128 : 32; + unsigned MaxFixpointIterations = + (isOpenMPDevice(M)) ? SetFixpointIterations : 32; Attributor A(Functions, InfoCache, CGUpdater, nullptr, false, true, MaxFixpointIterations, OREGetter, DEBUG_TYPE); OpenMPOpt OMPOpt(SCC, CGUpdater, OREGetter, InfoCache, A); - return OMPOpt.run(false); + bool Result = OMPOpt.run(false); + + if (PrintModuleAfterOptimizations) + LLVM_DEBUG(dbgs() << TAG << "Module after OpenMPOpt CGSCC Pass:\n" << M); + + return Result; } bool doFinalization(CallGraph &CG) override { return CGUpdater.finalize(); } |
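To make the SPMD-ization change above concrete, here is an illustrative OpenMP source snippet (not taken from the patch; all names and values are invented) of the kind of generic-mode target kernel that changeToSPMDMode can now convert. The serial statement before the parallel region is what ends up in a guarded region: the inserted region.check.tid / region.guarded / region.barrier blocks let only hardware thread 0 execute it, broadcast any escaping value through shared memory, and synchronize with __kmpc_barrier_simple_spmd so the worker threads observe the value before the parallel loop runs.

// Illustrative only: the patch operates on LLVM IR, not on source like this.
#include <cstdio>

int main() {
  const int N = 128;
  int data[N];
  for (int i = 0; i < N; ++i)
    data[i] = i;

#pragma omp target map(tofrom : data[0 : N])
  {
    // Serial code before the parallel region: after SPMD-ization this is
    // guarded so only thread 0 runs it, and 'scale' is broadcast to the
    // other threads through shared memory between two simple barriers.
    int scale = 2;

#pragma omp parallel for
    for (int i = 0; i < N; ++i)
      data[i] *= scale;
  }

  printf("%d\n", data[1]); // prints 2
  return 0;
}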

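The rewritten state machine prologue also changes how a thread decides its role. Assuming, as the new IsWorkerCheckBB logic implies, that __kmpc_target_init returns -1 for the main thread and the hardware thread id otherwise, a minimal host-side model of the emitted checks looks like this (the enum and the roleOf helper are invented for illustration; only the BlockSize arithmetic and the two compares mirror the diff):

#include <cstdio>

// Mirrors the prologue emitted before the worker state machine:
//   BlockSize = __kmpc_get_hardware_num_threads_in_block() - __kmpc_get_warp_size();
//   if (!(InitCB < BlockSize)) return;   // surplus threads of the main warp exit
//   IsWorker = InitCB != -1;             // -1 marks the main thread
enum class Role { Main, Worker, Retired };

static Role roleOf(int InitResult, int BlockHwSize, int WarpSize) {
  int BlockSize = BlockHwSize - WarpSize;   // threads usable as workers
  if (InitResult >= BlockSize)
    return Role::Retired;                   // branches to the finished block
  return InitResult == -1 ? Role::Main : Role::Worker;
}

int main() {
  // With 128 hardware threads per block and a warp size of 32, thread ids
  // 96..127 are retired, except the main thread, which sees -1.
  printf("%d %d %d\n", (int)roleOf(-1, 128, 32), (int)roleOf(5, 128, 32),
         (int)roleOf(100, 128, 32));
  return 0;
}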