Diffstat (limited to 'contrib/llvm-project/llvm/lib/Transforms/IPO')
29 files changed, 6624 insertions, 1461 deletions
diff --git a/contrib/llvm-project/llvm/lib/Transforms/IPO/AlwaysInliner.cpp b/contrib/llvm-project/llvm/lib/Transforms/IPO/AlwaysInliner.cpp index 53f9512f86f3..532599b42e0d 100644 --- a/contrib/llvm-project/llvm/lib/Transforms/IPO/AlwaysInliner.cpp +++ b/contrib/llvm-project/llvm/lib/Transforms/IPO/AlwaysInliner.cpp @@ -13,8 +13,10 @@ #include "llvm/Transforms/IPO/AlwaysInliner.h" #include "llvm/ADT/SetVector.h" +#include "llvm/Analysis/AliasAnalysis.h" #include "llvm/Analysis/AssumptionCache.h" #include "llvm/Analysis/InlineCost.h" +#include "llvm/Analysis/ProfileSummaryInfo.h" #include "llvm/Analysis/TargetLibraryInfo.h" #include "llvm/IR/CallingConv.h" #include "llvm/IR/DataLayout.h" @@ -39,12 +41,19 @@ PreservedAnalyses AlwaysInlinerPass::run(Module &M, auto GetAssumptionCache = [&](Function &F) -> AssumptionCache & { return FAM.getResult<AssumptionAnalysis>(F); }; - InlineFunctionInfo IFI(/*cg=*/nullptr, GetAssumptionCache); + auto &PSI = MAM.getResult<ProfileSummaryAnalysis>(M); SmallSetVector<CallBase *, 16> Calls; bool Changed = false; SmallVector<Function *, 16> InlinedFunctions; - for (Function &F : M) + for (Function &F : M) { + // When callee coroutine function is inlined into caller coroutine function + // before coro-split pass, + // coro-early pass can not handle this quiet well. + // So we won't inline the coroutine function if it have not been unsplited + if (F.isPresplitCoroutine()) + continue; + if (!F.isDeclaration() && F.hasFnAttribute(Attribute::AlwaysInline) && isInlineViable(F).isSuccess()) { Calls.clear(); @@ -54,18 +63,41 @@ PreservedAnalyses AlwaysInlinerPass::run(Module &M, if (CB->getCalledFunction() == &F) Calls.insert(CB); - for (CallBase *CB : Calls) - // FIXME: We really shouldn't be able to fail to inline at this point! - // We should do something to log or check the inline failures here. - Changed |= - InlineFunction(*CB, IFI, /*CalleeAAR=*/nullptr, InsertLifetime) - .isSuccess(); + for (CallBase *CB : Calls) { + Function *Caller = CB->getCaller(); + OptimizationRemarkEmitter ORE(Caller); + auto OIC = shouldInline( + *CB, + [&](CallBase &CB) { + return InlineCost::getAlways("always inline attribute"); + }, + ORE); + assert(OIC); + emitInlinedInto(ORE, CB->getDebugLoc(), CB->getParent(), F, *Caller, + *OIC, false, DEBUG_TYPE); + + InlineFunctionInfo IFI( + /*cg=*/nullptr, GetAssumptionCache, &PSI, + &FAM.getResult<BlockFrequencyAnalysis>(*(CB->getCaller())), + &FAM.getResult<BlockFrequencyAnalysis>(F)); + + InlineResult Res = InlineFunction( + *CB, IFI, &FAM.getResult<AAManager>(F), InsertLifetime); + assert(Res.isSuccess() && "unexpected failure to inline"); + (void)Res; + + // Merge the attributes based on the inlining. + AttributeFuncs::mergeAttributesForInlining(*Caller, F); + + Changed = true; + } // Remember to try and delete this function afterward. This both avoids // re-walking the rest of the module and avoids dealing with any iterator // invalidation issues while deleting functions. InlinedFunctions.push_back(&F); } + } // Remove any live functions. erase_if(InlinedFunctions, [&](Function *F) { @@ -158,6 +190,13 @@ InlineCost AlwaysInlinerLegacyPass::getInlineCost(CallBase &CB) { if (!Callee) return InlineCost::getNever("indirect call"); + // When callee coroutine function is inlined into caller coroutine function + // before coro-split pass, + // coro-early pass can not handle this quiet well. 
+ // So we won't inline the coroutine function if it have not been unsplited + if (Callee->isPresplitCoroutine()) + return InlineCost::getNever("unsplited coroutine call"); + // FIXME: We shouldn't even get here for declarations. if (Callee->isDeclaration()) return InlineCost::getNever("no definition"); diff --git a/contrib/llvm-project/llvm/lib/Transforms/IPO/Annotation2Metadata.cpp b/contrib/llvm-project/llvm/lib/Transforms/IPO/Annotation2Metadata.cpp new file mode 100644 index 000000000000..5ca4e24df8fc --- /dev/null +++ b/contrib/llvm-project/llvm/lib/Transforms/IPO/Annotation2Metadata.cpp @@ -0,0 +1,106 @@ +//===-- Annotation2Metadata.cpp - Add !annotation metadata. ---------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// Add !annotation metadata for entries in @llvm.global.anotations, generated +// using __attribute__((annotate("_name"))) on functions in Clang. +// +//===----------------------------------------------------------------------===// + +#include "llvm/Transforms/IPO/Annotation2Metadata.h" +#include "llvm/Analysis/OptimizationRemarkEmitter.h" +#include "llvm/IR/Constants.h" +#include "llvm/IR/Function.h" +#include "llvm/IR/InstIterator.h" +#include "llvm/IR/Module.h" +#include "llvm/InitializePasses.h" +#include "llvm/Pass.h" +#include "llvm/Transforms/IPO.h" + +using namespace llvm; + +#define DEBUG_TYPE "annotation2metadata" + +static bool convertAnnotation2Metadata(Module &M) { + // Only add !annotation metadata if the corresponding remarks pass is also + // enabled. + if (!OptimizationRemarkEmitter::allowExtraAnalysis(M.getContext(), + "annotation-remarks")) + return false; + + auto *Annotations = M.getGlobalVariable("llvm.global.annotations"); + auto *C = dyn_cast_or_null<Constant>(Annotations); + if (!C || C->getNumOperands() != 1) + return false; + + C = cast<Constant>(C->getOperand(0)); + + // Iterate over all entries in C and attach !annotation metadata to suitable + // entries. + for (auto &Op : C->operands()) { + // Look at the operands to check if we can use the entry to generate + // !annotation metadata. + auto *OpC = dyn_cast<ConstantStruct>(&Op); + if (!OpC || OpC->getNumOperands() != 4) + continue; + auto *StrGEP = dyn_cast<ConstantExpr>(OpC->getOperand(1)); + if (!StrGEP || StrGEP->getNumOperands() < 2) + continue; + auto *StrC = dyn_cast<GlobalValue>(StrGEP->getOperand(0)); + if (!StrC) + continue; + auto *StrData = dyn_cast<ConstantDataSequential>(StrC->getOperand(0)); + if (!StrData) + continue; + // Look through bitcast. + auto *Bitcast = dyn_cast<ConstantExpr>(OpC->getOperand(0)); + if (!Bitcast || Bitcast->getOpcode() != Instruction::BitCast) + continue; + auto *Fn = dyn_cast<Function>(Bitcast->getOperand(0)); + if (!Fn) + continue; + + // Add annotation to all instructions in the function. 
+ for (auto &I : instructions(Fn)) + I.addAnnotationMetadata(StrData->getAsCString()); + } + return true; +} + +namespace { +struct Annotation2MetadataLegacy : public ModulePass { + static char ID; + + Annotation2MetadataLegacy() : ModulePass(ID) { + initializeAnnotation2MetadataLegacyPass(*PassRegistry::getPassRegistry()); + } + + bool runOnModule(Module &M) override { return convertAnnotation2Metadata(M); } + + void getAnalysisUsage(AnalysisUsage &AU) const override { + AU.setPreservesAll(); + } +}; + +} // end anonymous namespace + +char Annotation2MetadataLegacy::ID = 0; + +INITIALIZE_PASS_BEGIN(Annotation2MetadataLegacy, DEBUG_TYPE, + "Annotation2Metadata", false, false) +INITIALIZE_PASS_END(Annotation2MetadataLegacy, DEBUG_TYPE, + "Annotation2Metadata", false, false) + +ModulePass *llvm::createAnnotation2MetadataLegacyPass() { + return new Annotation2MetadataLegacy(); +} + +PreservedAnalyses Annotation2MetadataPass::run(Module &M, + ModuleAnalysisManager &AM) { + convertAnnotation2Metadata(M); + return PreservedAnalyses::all(); +} diff --git a/contrib/llvm-project/llvm/lib/Transforms/IPO/ArgumentPromotion.cpp b/contrib/llvm-project/llvm/lib/Transforms/IPO/ArgumentPromotion.cpp index ad0d7eb51507..7998a1ae5c6e 100644 --- a/contrib/llvm-project/llvm/lib/Transforms/IPO/ArgumentPromotion.cpp +++ b/contrib/llvm-project/llvm/lib/Transforms/IPO/ArgumentPromotion.cpp @@ -33,11 +33,11 @@ #include "llvm/ADT/None.h" #include "llvm/ADT/Optional.h" #include "llvm/ADT/STLExtras.h" +#include "llvm/ADT/ScopeExit.h" #include "llvm/ADT/SmallPtrSet.h" #include "llvm/ADT/SmallVector.h" #include "llvm/ADT/Statistic.h" #include "llvm/ADT/Twine.h" -#include "llvm/Analysis/AliasAnalysis.h" #include "llvm/Analysis/AssumptionCache.h" #include "llvm/Analysis/BasicAliasAnalysis.h" #include "llvm/Analysis/CGSCCPassManager.h" @@ -142,7 +142,7 @@ doPromotion(Function *F, SmallPtrSetImpl<Argument *> &ArgsToPromote, // Simple byval argument? Just add all the struct element types. Type *AgTy = cast<PointerType>(I->getType())->getElementType(); StructType *STy = cast<StructType>(AgTy); - Params.insert(Params.end(), STy->element_begin(), STy->element_end()); + llvm::append_range(Params, STy->elements()); ArgAttrVec.insert(ArgAttrVec.end(), STy->getNumElements(), AttributeSet()); ++NumByValArgsPromoted; @@ -153,10 +153,6 @@ doPromotion(Function *F, SmallPtrSetImpl<Argument *> &ArgsToPromote, } else if (I->use_empty()) { // Dead argument (which are always marked as promotable) ++NumArgumentsDead; - - // There may be remaining metadata uses of the argument for things like - // llvm.dbg.value. Replace them with undef. - I->replaceAllUsesWith(UndefValue::get(I->getType())); } else { // Okay, this is being promoted. This means that the only uses are loads // or GEPs which are only used by loads @@ -164,13 +160,19 @@ doPromotion(Function *F, SmallPtrSetImpl<Argument *> &ArgsToPromote, // In this table, we will track which indices are loaded from the argument // (where direct loads are tracked as no indices). ScalarizeTable &ArgIndices = ScalarizedElements[&*I]; - for (User *U : I->users()) { + for (User *U : make_early_inc_range(I->users())) { Instruction *UI = cast<Instruction>(U); Type *SrcTy; if (LoadInst *L = dyn_cast<LoadInst>(UI)) SrcTy = L->getType(); else SrcTy = cast<GetElementPtrInst>(UI)->getSourceElementType(); + // Skip dead GEPs and remove them. 
+ if (isa<GetElementPtrInst>(UI) && UI->use_empty()) { + UI->eraseFromParent(); + continue; + } + IndicesVector Indices; Indices.reserve(UI->getNumOperands() - 1); // Since loads will only have a single operand, and GEPs only a single @@ -218,9 +220,11 @@ doPromotion(Function *F, SmallPtrSetImpl<Argument *> &ArgsToPromote, Function *NF = Function::Create(NFTy, F->getLinkage(), F->getAddressSpace(), F->getName()); NF->copyAttributesFrom(F); + NF->copyMetadata(F, 0); - // Patch the pointer to LLVM function in debug info descriptor. - NF->setSubprogram(F->getSubprogram()); + // The new function will have the !dbg metadata copied from the original + // function. The original function may not be deleted, and dbg metadata need + // to be unique so we need to drop it. F->setSubprogram(nullptr); LLVM_DEBUG(dbgs() << "ARG PROMOTION: Promoting to:" << *NF << "\n" @@ -414,6 +418,11 @@ doPromotion(Function *F, SmallPtrSetImpl<Argument *> &ArgsToPromote, continue; } + // There potentially are metadata uses for things like llvm.dbg.value. + // Replace them with undef, after handling the other regular uses. + auto RauwUndefMetadata = make_scope_exit( + [&]() { I->replaceAllUsesWith(UndefValue::get(I->getType())); }); + if (I->use_empty()) continue; @@ -433,6 +442,8 @@ doPromotion(Function *F, SmallPtrSetImpl<Argument *> &ArgsToPromote, << "' in function '" << F->getName() << "'\n"); } else { GetElementPtrInst *GEP = cast<GetElementPtrInst>(I->user_back()); + assert(!GEP->use_empty() && + "GEPs without uses should be cleaned up already"); IndicesVector Operands; Operands.reserve(GEP->getNumIndices()); for (User::op_iterator II = GEP->idx_begin(), IE = GEP->idx_end(); @@ -465,7 +476,6 @@ doPromotion(Function *F, SmallPtrSetImpl<Argument *> &ArgsToPromote, GEP->eraseFromParent(); } } - // Increment I2 past all of the arguments added for this promoted pointer. std::advance(I2, ArgIndices.size()); } @@ -672,11 +682,7 @@ static bool isSafeToPromoteArgument(Argument *Arg, Type *ByValTy, AAResults &AAR if (GEP->use_empty()) { // Dead GEP's cause trouble later. Just remove them if we run into // them. - GEP->eraseFromParent(); - // TODO: This runs the above loop over and over again for dead GEPs - // Couldn't we just do increment the UI iterator earlier and erase the - // use? - return isSafeToPromoteArgument(Arg, ByValTy, AAR, MaxElements); + continue; } if (!UpdateBaseTy(GEP->getSourceElementType())) @@ -816,14 +822,12 @@ static bool canPaddingBeAccessed(Argument *arg) { // Scan through the uses recursively to make sure the pointer is always used // sanely. 
- SmallVector<Value *, 16> WorkList; - WorkList.insert(WorkList.end(), arg->user_begin(), arg->user_end()); + SmallVector<Value *, 16> WorkList(arg->users()); while (!WorkList.empty()) { - Value *V = WorkList.back(); - WorkList.pop_back(); + Value *V = WorkList.pop_back_val(); if (isa<GetElementPtrInst>(V) || isa<PHINode>(V)) { if (PtrValues.insert(V).second) - WorkList.insert(WorkList.end(), V->user_begin(), V->user_end()); + llvm::append_range(WorkList, V->users()); } else if (StoreInst *Store = dyn_cast<StoreInst>(V)) { Stores.push_back(Store); } else if (!isa<LoadInst>(V)) { diff --git a/contrib/llvm-project/llvm/lib/Transforms/IPO/Attributor.cpp b/contrib/llvm-project/llvm/lib/Transforms/IPO/Attributor.cpp index f96dac5f3515..03ad45135001 100644 --- a/contrib/llvm-project/llvm/lib/Transforms/IPO/Attributor.cpp +++ b/contrib/llvm-project/llvm/lib/Transforms/IPO/Attributor.cpp @@ -15,31 +15,47 @@ #include "llvm/Transforms/IPO/Attributor.h" +#include "llvm/ADT/GraphTraits.h" +#include "llvm/ADT/PointerIntPair.h" #include "llvm/ADT/Statistic.h" +#include "llvm/ADT/TinyPtrVector.h" +#include "llvm/Analysis/InlineCost.h" #include "llvm/Analysis/LazyValueInfo.h" +#include "llvm/Analysis/MemorySSAUpdater.h" #include "llvm/Analysis/MustExecute.h" #include "llvm/Analysis/ValueTracking.h" +#include "llvm/IR/GlobalValue.h" #include "llvm/IR/IRBuilder.h" #include "llvm/IR/NoFolder.h" #include "llvm/IR/Verifier.h" #include "llvm/InitializePasses.h" #include "llvm/Support/Casting.h" #include "llvm/Support/CommandLine.h" +#include "llvm/Support/Debug.h" +#include "llvm/Support/DebugCounter.h" +#include "llvm/Support/FileSystem.h" +#include "llvm/Support/GraphWriter.h" +#include "llvm/Support/raw_ostream.h" #include "llvm/Transforms/Utils/BasicBlockUtils.h" +#include "llvm/Transforms/Utils/Cloning.h" #include "llvm/Transforms/Utils/Local.h" #include <cassert> +#include <string> using namespace llvm; #define DEBUG_TYPE "attributor" +DEBUG_COUNTER(ManifestDBGCounter, "attributor-manifest", + "Determine what attributes are manifested in the IR"); + STATISTIC(NumFnDeleted, "Number of function deleted"); STATISTIC(NumFnWithExactDefinition, "Number of functions with exact definitions"); STATISTIC(NumFnWithoutExactDefinition, "Number of functions without exact definitions"); -STATISTIC(NumFnShallowWrapperCreated, "Number of shallow wrappers created"); +STATISTIC(NumFnShallowWrappersCreated, "Number of shallow wrappers created"); STATISTIC(NumAttributesTimedOut, "Number of abstract attributes timed out before fixpoint"); STATISTIC(NumAttributesValidFixpoint, @@ -61,6 +77,14 @@ static cl::opt<unsigned> MaxFixpointIterations("attributor-max-iterations", cl::Hidden, cl::desc("Maximal number of fixpoint iterations."), cl::init(32)); + +static cl::opt<unsigned, true> MaxInitializationChainLengthX( + "attributor-max-initialization-chain-length", cl::Hidden, + cl::desc( + "Maximal number of chained initializations (to avoid stack overflows)"), + cl::location(MaxInitializationChainLength), cl::init(1024)); +unsigned llvm::MaxInitializationChainLength; + static cl::opt<bool> VerifyMaxFixpointIterations( "attributor-max-iterations-verify", cl::Hidden, cl::desc("Verify that max-iterations is a tight bound for a fixpoint"), @@ -79,20 +103,52 @@ static cl::opt<bool> "wrappers for non-exact definitions."), cl::init(false)); +static cl::opt<bool> + AllowDeepWrapper("attributor-allow-deep-wrappers", cl::Hidden, + cl::desc("Allow the Attributor to use IP information " + "derived from non-exact functions via cloning"), + 
cl::init(false)); + +// These options can only used for debug builds. +#ifndef NDEBUG static cl::list<std::string> SeedAllowList("attributor-seed-allow-list", cl::Hidden, - cl::desc("Comma seperated list of attrbute names that are " + cl::desc("Comma seperated list of attribute names that are " "allowed to be seeded."), cl::ZeroOrMore, cl::CommaSeparated); +static cl::list<std::string> FunctionSeedAllowList( + "attributor-function-seed-allow-list", cl::Hidden, + cl::desc("Comma seperated list of function names that are " + "allowed to be seeded."), + cl::ZeroOrMore, cl::CommaSeparated); +#endif + +static cl::opt<bool> + DumpDepGraph("attributor-dump-dep-graph", cl::Hidden, + cl::desc("Dump the dependency graph to dot files."), + cl::init(false)); + +static cl::opt<std::string> DepGraphDotFileNamePrefix( + "attributor-depgraph-dot-filename-prefix", cl::Hidden, + cl::desc("The prefix used for the CallGraph dot file names.")); + +static cl::opt<bool> ViewDepGraph("attributor-view-dep-graph", cl::Hidden, + cl::desc("View the dependency graph."), + cl::init(false)); + +static cl::opt<bool> PrintDependencies("attributor-print-dep", cl::Hidden, + cl::desc("Print attribute dependencies"), + cl::init(false)); + /// Logic operators for the change status enum class. /// ///{ -ChangeStatus llvm::operator|(ChangeStatus l, ChangeStatus r) { - return l == ChangeStatus::CHANGED ? l : r; +ChangeStatus llvm::operator|(ChangeStatus L, ChangeStatus R) { + return L == ChangeStatus::CHANGED ? L : R; } -ChangeStatus llvm::operator&(ChangeStatus l, ChangeStatus r) { - return l == ChangeStatus::UNCHANGED ? l : r; +ChangeStatus llvm::operator&(ChangeStatus L, ChangeStatus R) { + return L == ChangeStatus::UNCHANGED ? L : R; } ///} @@ -145,7 +201,7 @@ Argument *IRPosition::getAssociatedArgument() const { // Not an Argument and no argument number means this is not a call site // argument, thus we cannot find a callback argument to return. - int ArgNo = getArgNo(); + int ArgNo = getCallSiteArgNo(); if (ArgNo < 0) return nullptr; @@ -273,6 +329,13 @@ const IRPosition SubsumingPositionIterator::SubsumingPositionIterator(const IRPosition &IRP) { IRPositions.emplace_back(IRP); + // Helper to determine if operand bundles on a call site are benin or + // potentially problematic. We handle only llvm.assume for now. + auto CanIgnoreOperandBundles = [](const CallBase &CB) { + return (isa<IntrinsicInst>(CB) && + cast<IntrinsicInst>(CB).getIntrinsicID() == Intrinsic ::assume); + }; + const auto *CB = dyn_cast<CallBase>(&IRP.getAnchorValue()); switch (IRP.getPositionKind()) { case IRPosition::IRP_INVALID: @@ -287,7 +350,7 @@ SubsumingPositionIterator::SubsumingPositionIterator(const IRPosition &IRP) { assert(CB && "Expected call site!"); // TODO: We need to look at the operand bundles similar to the redirection // in CallBase. - if (!CB->hasOperandBundles()) + if (!CB->hasOperandBundles() || CanIgnoreOperandBundles(*CB)) if (const Function *Callee = CB->getCalledFunction()) IRPositions.emplace_back(IRPosition::function(*Callee)); return; @@ -295,7 +358,7 @@ SubsumingPositionIterator::SubsumingPositionIterator(const IRPosition &IRP) { assert(CB && "Expected call site!"); // TODO: We need to look at the operand bundles similar to the redirection // in CallBase. 
- if (!CB->hasOperandBundles()) { + if (!CB->hasOperandBundles() || CanIgnoreOperandBundles(*CB)) { if (const Function *Callee = CB->getCalledFunction()) { IRPositions.emplace_back(IRPosition::returned(*Callee)); IRPositions.emplace_back(IRPosition::function(*Callee)); @@ -312,16 +375,16 @@ SubsumingPositionIterator::SubsumingPositionIterator(const IRPosition &IRP) { IRPositions.emplace_back(IRPosition::callsite_function(*CB)); return; case IRPosition::IRP_CALL_SITE_ARGUMENT: { - int ArgNo = IRP.getArgNo(); - assert(CB && ArgNo >= 0 && "Expected call site!"); + assert(CB && "Expected call site!"); // TODO: We need to look at the operand bundles similar to the redirection // in CallBase. - if (!CB->hasOperandBundles()) { + if (!CB->hasOperandBundles() || CanIgnoreOperandBundles(*CB)) { const Function *Callee = CB->getCalledFunction(); - if (Callee && Callee->arg_size() > unsigned(ArgNo)) - IRPositions.emplace_back(IRPosition::argument(*Callee->getArg(ArgNo))); - if (Callee) + if (Callee) { + if (Argument *Arg = IRP.getAssociatedArgument()) + IRPositions.emplace_back(IRPosition::argument(*Arg)); IRPositions.emplace_back(IRPosition::function(*Callee)); + } } IRPositions.emplace_back(IRPosition::value(IRP.getAssociatedValue())); return; @@ -459,7 +522,7 @@ void IRPosition::verify() { "Expected call base argument operand for a 'call site argument' " "position"); assert(cast<CallBase>(U->getUser())->getArgOperandNo(U) == - unsigned(getArgNo()) && + unsigned(getCallSiteArgNo()) && "Argument number mismatch!"); assert(U->get() == &getAssociatedValue() && "Associated value mismatch!"); return; @@ -498,8 +561,10 @@ Attributor::getAssumedConstant(const Value &V, const AbstractAttribute &AA, Attributor::~Attributor() { // The abstract attributes are allocated via the BumpPtrAllocator Allocator, // thus we cannot delete them. We can, and want to, destruct them though. - for (AbstractAttribute *AA : AllAbstractAttributes) + for (auto &DepAA : DG.SyntheticRoot.Deps) { + AbstractAttribute *AA = cast<AbstractAttribute>(DepAA.getPointer()); AA->~AbstractAttribute(); + } } bool Attributor::isAssumedDead(const AbstractAttribute &AA, @@ -864,13 +929,15 @@ bool Attributor::checkForAllInstructions(function_ref<bool(Instruction &)> Pred, // TODO: use the function scope once we have call site AAReturnedValues. const IRPosition &QueryIRP = IRPosition::function(*AssociatedFunction); - const auto &LivenessAA = - getAAFor<AAIsDead>(QueryingAA, QueryIRP, /* TrackDependence */ false); + const auto *LivenessAA = + CheckBBLivenessOnly ? 
nullptr + : &(getAAFor<AAIsDead>(QueryingAA, QueryIRP, + /* TrackDependence */ false)); auto &OpcodeInstMap = InfoCache.getOpcodeInstMapForFunction(*AssociatedFunction); if (!checkForAllInstructionsImpl(this, OpcodeInstMap, Pred, &QueryingAA, - &LivenessAA, Opcodes, CheckBBLivenessOnly)) + LivenessAA, Opcodes, CheckBBLivenessOnly)) return false; return true; @@ -903,8 +970,9 @@ bool Attributor::checkForAllReadWriteInstructions( } void Attributor::runTillFixpoint() { + TimeTraceScope TimeScope("Attributor::runTillFixpoint"); LLVM_DEBUG(dbgs() << "[Attributor] Identified and initialized " - << AllAbstractAttributes.size() + << DG.SyntheticRoot.Deps.size() << " abstract attributes.\n"); // Now that all abstract attributes are collected and initialized we start @@ -914,11 +982,11 @@ void Attributor::runTillFixpoint() { SmallVector<AbstractAttribute *, 32> ChangedAAs; SetVector<AbstractAttribute *> Worklist, InvalidAAs; - Worklist.insert(AllAbstractAttributes.begin(), AllAbstractAttributes.end()); + Worklist.insert(DG.SyntheticRoot.begin(), DG.SyntheticRoot.end()); do { // Remember the size to determine new attributes. - size_t NumAAs = AllAbstractAttributes.size(); + size_t NumAAs = DG.SyntheticRoot.Deps.size(); LLVM_DEBUG(dbgs() << "\n\n[Attributor] #Iteration: " << IterationCounter << ", Worklist size: " << Worklist.size() << "\n"); @@ -935,7 +1003,7 @@ void Attributor::runTillFixpoint() { while (!InvalidAA->Deps.empty()) { const auto &Dep = InvalidAA->Deps.back(); InvalidAA->Deps.pop_back(); - AbstractAttribute *DepAA = Dep.getPointer(); + AbstractAttribute *DepAA = cast<AbstractAttribute>(Dep.getPointer()); if (Dep.getInt() == unsigned(DepClassTy::OPTIONAL)) { Worklist.insert(DepAA); continue; @@ -953,7 +1021,8 @@ void Attributor::runTillFixpoint() { // changed to the work list. for (AbstractAttribute *ChangedAA : ChangedAAs) while (!ChangedAA->Deps.empty()) { - Worklist.insert(ChangedAA->Deps.back().getPointer()); + Worklist.insert( + cast<AbstractAttribute>(ChangedAA->Deps.back().getPointer())); ChangedAA->Deps.pop_back(); } @@ -981,8 +1050,8 @@ void Attributor::runTillFixpoint() { // Add attributes to the changed set if they have been created in the last // iteration. - ChangedAAs.append(AllAbstractAttributes.begin() + NumAAs, - AllAbstractAttributes.end()); + ChangedAAs.append(DG.SyntheticRoot.begin() + NumAAs, + DG.SyntheticRoot.end()); // Reset the work list and repopulate with the changed abstract attributes. // Note that dependent ones are added above. @@ -1015,7 +1084,8 @@ void Attributor::runTillFixpoint() { } while (!ChangedAA->Deps.empty()) { - ChangedAAs.push_back(ChangedAA->Deps.back().getPointer()); + ChangedAAs.push_back( + cast<AbstractAttribute>(ChangedAA->Deps.back().getPointer())); ChangedAA->Deps.pop_back(); } } @@ -1037,12 +1107,14 @@ void Attributor::runTillFixpoint() { } ChangeStatus Attributor::manifestAttributes() { - size_t NumFinalAAs = AllAbstractAttributes.size(); + TimeTraceScope TimeScope("Attributor::manifestAttributes"); + size_t NumFinalAAs = DG.SyntheticRoot.Deps.size(); unsigned NumManifested = 0; unsigned NumAtFixpoint = 0; ChangeStatus ManifestChange = ChangeStatus::UNCHANGED; - for (AbstractAttribute *AA : AllAbstractAttributes) { + for (auto &DepAA : DG.SyntheticRoot.Deps) { + AbstractAttribute *AA = cast<AbstractAttribute>(DepAA.getPointer()); AbstractState &State = AA->getState(); // If there is not already a fixpoint reached, we can now take the @@ -1059,6 +1131,10 @@ ChangeStatus Attributor::manifestAttributes() { // Skip dead code. 
if (isAssumedDead(*AA, nullptr, /* CheckBBLivenessOnly */ true)) continue; + // Check if the manifest debug counter that allows skipping manifestation of + // AAs + if (!DebugCounter::shouldExecute(ManifestDBGCounter)) + continue; // Manifest the state and record if we changed the IR. ChangeStatus LocalChange = AA->manifest(*this); if (LocalChange == ChangeStatus::CHANGED && AreStatisticsEnabled()) @@ -1082,11 +1158,14 @@ ChangeStatus Attributor::manifestAttributes() { NumAttributesValidFixpoint += NumAtFixpoint; (void)NumFinalAAs; - if (NumFinalAAs != AllAbstractAttributes.size()) { - for (unsigned u = NumFinalAAs; u < AllAbstractAttributes.size(); ++u) - errs() << "Unexpected abstract attribute: " << *AllAbstractAttributes[u] + if (NumFinalAAs != DG.SyntheticRoot.Deps.size()) { + for (unsigned u = NumFinalAAs; u < DG.SyntheticRoot.Deps.size(); ++u) + errs() << "Unexpected abstract attribute: " + << cast<AbstractAttribute>(DG.SyntheticRoot.Deps[u].getPointer()) << " :: " - << AllAbstractAttributes[u]->getIRPosition().getAssociatedValue() + << cast<AbstractAttribute>(DG.SyntheticRoot.Deps[u].getPointer()) + ->getIRPosition() + .getAssociatedValue() << "\n"; llvm_unreachable("Expected the final number of abstract attributes to " "remain unchanged!"); @@ -1094,7 +1173,50 @@ ChangeStatus Attributor::manifestAttributes() { return ManifestChange; } +void Attributor::identifyDeadInternalFunctions() { + // Identify dead internal functions and delete them. This happens outside + // the other fixpoint analysis as we might treat potentially dead functions + // as live to lower the number of iterations. If they happen to be dead, the + // below fixpoint loop will identify and eliminate them. + SmallVector<Function *, 8> InternalFns; + for (Function *F : Functions) + if (F->hasLocalLinkage()) + InternalFns.push_back(F); + + SmallPtrSet<Function *, 8> LiveInternalFns; + bool FoundLiveInternal = true; + while (FoundLiveInternal) { + FoundLiveInternal = false; + for (unsigned u = 0, e = InternalFns.size(); u < e; ++u) { + Function *F = InternalFns[u]; + if (!F) + continue; + + bool AllCallSitesKnown; + if (checkForAllCallSites( + [&](AbstractCallSite ACS) { + Function *Callee = ACS.getInstruction()->getFunction(); + return ToBeDeletedFunctions.count(Callee) || + (Functions.count(Callee) && Callee->hasLocalLinkage() && + !LiveInternalFns.count(Callee)); + }, + *F, true, nullptr, AllCallSitesKnown)) { + continue; + } + + LiveInternalFns.insert(F); + InternalFns[u] = nullptr; + FoundLiveInternal = true; + } + } + + for (unsigned u = 0, e = InternalFns.size(); u < e; ++u) + if (Function *F = InternalFns[u]) + ToBeDeletedFunctions.insert(F); +} + ChangeStatus Attributor::cleanupIR() { + TimeTraceScope TimeScope("Attributor::cleanupIR"); // Delete stuff at the end to avoid invalid references and a nice order. LLVM_DEBUG(dbgs() << "\n[Attributor] Delete at least " << ToBeDeletedFunctions.size() << " functions and " @@ -1205,50 +1327,45 @@ ChangeStatus Attributor::cleanupIR() { DetatchDeadBlocks(ToBeDeletedBBs, nullptr); } - // Identify dead internal functions and delete them. This happens outside - // the other fixpoint analysis as we might treat potentially dead functions - // as live to lower the number of iterations. If they happen to be dead, the - // below fixpoint loop will identify and eliminate them. 
- SmallVector<Function *, 8> InternalFns; - for (Function *F : Functions) - if (F->hasLocalLinkage()) - InternalFns.push_back(F); - - bool FoundDeadFn = true; - while (FoundDeadFn) { - FoundDeadFn = false; - for (unsigned u = 0, e = InternalFns.size(); u < e; ++u) { - Function *F = InternalFns[u]; - if (!F) - continue; - - bool AllCallSitesKnown; - if (!checkForAllCallSites( - [this](AbstractCallSite ACS) { - return ToBeDeletedFunctions.count( - ACS.getInstruction()->getFunction()); - }, - *F, true, nullptr, AllCallSitesKnown)) - continue; - - ToBeDeletedFunctions.insert(F); - InternalFns[u] = nullptr; - FoundDeadFn = true; - } - } + identifyDeadInternalFunctions(); // Rewrite the functions as requested during manifest. ChangeStatus ManifestChange = rewriteFunctionSignatures(CGModifiedFunctions); for (Function *Fn : CGModifiedFunctions) - CGUpdater.reanalyzeFunction(*Fn); + if (!ToBeDeletedFunctions.count(Fn)) + CGUpdater.reanalyzeFunction(*Fn); - for (Function *Fn : ToBeDeletedFunctions) + for (Function *Fn : ToBeDeletedFunctions) { + if (!Functions.count(Fn)) + continue; CGUpdater.removeFunction(*Fn); + } + + if (!ToBeChangedUses.empty()) + ManifestChange = ChangeStatus::CHANGED; + + if (!ToBeChangedToUnreachableInsts.empty()) + ManifestChange = ChangeStatus::CHANGED; + + if (!ToBeDeletedFunctions.empty()) + ManifestChange = ChangeStatus::CHANGED; + + if (!ToBeDeletedBlocks.empty()) + ManifestChange = ChangeStatus::CHANGED; + + if (!ToBeDeletedInsts.empty()) + ManifestChange = ChangeStatus::CHANGED; + + if (!InvokeWithDeadSuccessor.empty()) + ManifestChange = ChangeStatus::CHANGED; + + if (!DeadInsts.empty()) + ManifestChange = ChangeStatus::CHANGED; NumFnDeleted += ToBeDeletedFunctions.size(); - LLVM_DEBUG(dbgs() << "[Attributor] Deleted " << NumFnDeleted + LLVM_DEBUG(dbgs() << "[Attributor] Deleted " << ToBeDeletedFunctions.size() << " functions after manifest.\n"); #ifdef EXPENSIVE_CHECKS @@ -1263,14 +1380,37 @@ ChangeStatus Attributor::cleanupIR() { } ChangeStatus Attributor::run() { - SeedingPeriod = false; + TimeTraceScope TimeScope("Attributor::run"); + + Phase = AttributorPhase::UPDATE; runTillFixpoint(); + + // dump graphs on demand + if (DumpDepGraph) + DG.dumpGraph(); + + if (ViewDepGraph) + DG.viewGraph(); + + if (PrintDependencies) + DG.print(); + + Phase = AttributorPhase::MANIFEST; ChangeStatus ManifestChange = manifestAttributes(); + + Phase = AttributorPhase::CLEANUP; ChangeStatus CleanupChange = cleanupIR(); + return ManifestChange | CleanupChange; } ChangeStatus Attributor::updateAA(AbstractAttribute &AA) { + TimeTraceScope TimeScope( + AA.getName() + std::to_string(AA.getIRPosition().getPositionKind()) + + "::updateAA"); + assert(Phase == AttributorPhase::UPDATE && + "We can update AA only in the update stage!"); + // Use a new dependence vector for this update. DependenceVector DV; DependenceStack.push_back(&DV); @@ -1298,23 +1438,7 @@ ChangeStatus Attributor::updateAA(AbstractAttribute &AA) { return CS; } -/// Create a shallow wrapper for \p F such that \p F has internal linkage -/// afterwards. It also sets the original \p F 's name to anonymous -/// -/// A wrapper is a function with the same type (and attributes) as \p F -/// that will only call \p F and return the result, if any. 
-/// -/// Assuming the declaration of looks like: -/// rty F(aty0 arg0, ..., atyN argN); -/// -/// The wrapper will then look as follows: -/// rty wrapper(aty0 arg0, ..., atyN argN) { -/// return F(arg0, ..., argN); -/// } -/// -static void createShallowWrapper(Function &F) { - assert(AllowShallowWrappers && - "Cannot create a wrapper if it is not allowed!"); +void Attributor::createShallowWrapper(Function &F) { assert(!F.isDeclaration() && "Cannot create a wrapper around a declaration!"); Module &M = *F.getParent(); @@ -1347,7 +1471,7 @@ static void createShallowWrapper(Function &F) { BasicBlock *EntryBB = BasicBlock::Create(Ctx, "entry", Wrapper); SmallVector<Value *, 8> Args; - auto FArgIt = F.arg_begin(); + Argument *FArgIt = F.arg_begin(); for (Argument &Arg : Wrapper->args()) { Args.push_back(&Arg); Arg.setName((FArgIt++)->getName()); @@ -1358,7 +1482,57 @@ static void createShallowWrapper(Function &F) { CI->addAttribute(AttributeList::FunctionIndex, Attribute::NoInline); ReturnInst::Create(Ctx, CI->getType()->isVoidTy() ? nullptr : CI, EntryBB); - NumFnShallowWrapperCreated++; + NumFnShallowWrappersCreated++; +} + +/// Make another copy of the function \p F such that the copied version has +/// internal linkage afterwards and can be analysed. Then we replace all uses +/// of the original function to the copied one +/// +/// Only non-exactly defined functions that have `linkonce_odr` or `weak_odr` +/// linkage can be internalized because these linkages guarantee that other +/// definitions with the same name have the same semantics as this one +/// +static Function *internalizeFunction(Function &F) { + assert(AllowDeepWrapper && "Cannot create a copy if not allowed."); + assert(!F.isDeclaration() && !F.hasExactDefinition() && + !GlobalValue::isInterposableLinkage(F.getLinkage()) && + "Trying to internalize function which cannot be internalized."); + + Module &M = *F.getParent(); + FunctionType *FnTy = F.getFunctionType(); + + // create a copy of the current function + Function *Copied = Function::Create(FnTy, F.getLinkage(), F.getAddressSpace(), + F.getName() + ".internalized"); + ValueToValueMapTy VMap; + auto *NewFArgIt = Copied->arg_begin(); + for (auto &Arg : F.args()) { + auto ArgName = Arg.getName(); + NewFArgIt->setName(ArgName); + VMap[&Arg] = &(*NewFArgIt++); + } + SmallVector<ReturnInst *, 8> Returns; + + // Copy the body of the original function to the new one + CloneFunctionInto(Copied, &F, VMap, /* ModuleLevelChanges */ false, Returns); + + // Set the linakage and visibility late as CloneFunctionInto has some implicit + // requirements. 
+ Copied->setVisibility(GlobalValue::DefaultVisibility); + Copied->setLinkage(GlobalValue::PrivateLinkage); + + // Copy metadata + SmallVector<std::pair<unsigned, MDNode *>, 1> MDs; + F.getAllMetadata(MDs); + for (auto MDIt : MDs) + Copied->addMetadata(MDIt.first, *MDIt.second); + + M.getFunctionList().insert(F.getIterator(), Copied); + F.replaceAllUsesWith(Copied); + Copied->setDSOLocal(true); + + return Copied; } bool Attributor::isValidFunctionSignatureRewrite( @@ -1461,9 +1635,17 @@ bool Attributor::registerFunctionSignatureRewrite( } bool Attributor::shouldSeedAttribute(AbstractAttribute &AA) { - if (SeedAllowList.size() == 0) - return true; - return std::count(SeedAllowList.begin(), SeedAllowList.end(), AA.getName()); + bool Result = true; +#ifndef NDEBUG + if (SeedAllowList.size() != 0) + Result = + std::count(SeedAllowList.begin(), SeedAllowList.end(), AA.getName()); + Function *Fn = AA.getAnchorScope(); + if (FunctionSeedAllowList.size() != 0 && Fn) + Result &= std::count(FunctionSeedAllowList.begin(), + FunctionSeedAllowList.end(), Fn->getName()); +#endif + return Result; } ChangeStatus Attributor::rewriteFunctionSignatures( @@ -1474,7 +1656,7 @@ ChangeStatus Attributor::rewriteFunctionSignatures( Function *OldFn = It.getFirst(); // Deleted functions do not require rewrites. - if (ToBeDeletedFunctions.count(OldFn)) + if (!Functions.count(OldFn) || ToBeDeletedFunctions.count(OldFn)) continue; const SmallVectorImpl<std::unique_ptr<ArgumentReplacementInfo>> &ARIs = @@ -1617,8 +1799,8 @@ ChangeStatus Attributor::rewriteFunctionSignatures( assert(Success && "Assumed call site replacement to succeed!"); // Rewire the arguments. - auto OldFnArgIt = OldFn->arg_begin(); - auto NewFnArgIt = NewFn->arg_begin(); + Argument *OldFnArgIt = OldFn->arg_begin(); + Argument *NewFnArgIt = NewFn->arg_begin(); for (unsigned OldArgNum = 0; OldArgNum < ARIs.size(); ++OldArgNum, ++OldFnArgIt) { if (const std::unique_ptr<ArgumentReplacementInfo> &ARI = @@ -1727,6 +1909,10 @@ void InformationCache::initializeInformationCache(const Function &CF, InlineableFunctions.insert(&F); } +AAResults *InformationCache::getAAResultsForFunction(const Function &F) { + return AG.getAnalysis<AAManager>(F); +} + InformationCache::FunctionInfo::~FunctionInfo() { // The instruction vectors are allocated using a BumpPtrAllocator, we need to // manually destroy them. @@ -1827,6 +2013,9 @@ void Attributor::identifyDefaultAbstractAttributes(Function &F) { // Every function might be simplified. getOrCreateAAFor<AAValueSimplify>(RetPos); + // Every returned value might be marked noundef. + getOrCreateAAFor<AANoUndef>(RetPos); + if (ReturnType->isPointerTy()) { // Every function with pointer return type might be marked align. @@ -1853,6 +2042,9 @@ void Attributor::identifyDefaultAbstractAttributes(Function &F) { // Every argument might be dead. getOrCreateAAFor<AAIsDead>(ArgPos); + // Every argument might be marked noundef. + getOrCreateAAFor<AANoUndef>(ArgPos); + if (Arg.getType()->isPointerTy()) { // Every argument with pointer type might be marked nonnull. getOrCreateAAFor<AANonNull>(ArgPos); @@ -1920,6 +2112,9 @@ void Attributor::identifyDefaultAbstractAttributes(Function &F) { // Call site argument might be simplified. getOrCreateAAFor<AAValueSimplify>(CBArgPos); + // Every call site argument might be marked "noundef". 
+ getOrCreateAAFor<AANoUndef>(CBArgPos); + if (!CB.getArgOperand(I)->getType()->isPointerTy()) continue; @@ -2005,7 +2200,8 @@ raw_ostream &llvm::operator<<(raw_ostream &OS, IRPosition::Kind AP) { raw_ostream &llvm::operator<<(raw_ostream &OS, const IRPosition &Pos) { const Value &AV = Pos.getAssociatedValue(); return OS << "{" << Pos.getPositionKind() << ":" << AV.getName() << " [" - << Pos.getAnchorValue().getName() << "@" << Pos.getArgNo() << "]}"; + << Pos.getAnchorValue().getName() << "@" << Pos.getCallSiteArgNo() + << "]}"; } raw_ostream &llvm::operator<<(raw_ostream &OS, const IntegerRangeState &S) { @@ -2027,9 +2223,48 @@ raw_ostream &llvm::operator<<(raw_ostream &OS, const AbstractAttribute &AA) { return OS; } +raw_ostream &llvm::operator<<(raw_ostream &OS, + const PotentialConstantIntValuesState &S) { + OS << "set-state(< {"; + if (!S.isValidState()) + OS << "full-set"; + else { + for (auto &it : S.getAssumedSet()) + OS << it << ", "; + if (S.undefIsContained()) + OS << "undef "; + } + OS << "} >)"; + + return OS; +} + void AbstractAttribute::print(raw_ostream &OS) const { - OS << "[P: " << getIRPosition() << "][" << getAsStr() << "][S: " << getState() - << "]"; + OS << "["; + OS << getName(); + OS << "] for CtxI "; + + if (auto *I = getCtxI()) { + OS << "'"; + I->print(OS); + OS << "'"; + } else + OS << "<<null inst>>"; + + OS << " at position " << getIRPosition() << " with state " << getAsStr() + << '\n'; +} + +void AbstractAttribute::printWithDeps(raw_ostream &OS) const { + print(OS); + + for (const auto &DepAA : Deps) { + auto *AA = DepAA.getPointer(); + OS << " updates "; + AA->print(OS); + } + + OS << '\n'; } ///} @@ -2055,7 +2290,31 @@ static bool runAttributorOnFunctions(InformationCache &InfoCache, if (AllowShallowWrappers) for (Function *F : Functions) if (!A.isFunctionIPOAmendable(*F)) - createShallowWrapper(*F); + Attributor::createShallowWrapper(*F); + + // Internalize non-exact functions + // TODO: for now we eagerly internalize functions without calculating the + // cost, we need a cost interface to determine whether internalizing + // a function is "benefitial" + if (AllowDeepWrapper) { + unsigned FunSize = Functions.size(); + for (unsigned u = 0; u < FunSize; u++) { + Function *F = Functions[u]; + if (!F->isDeclaration() && !F->isDefinitionExact() && F->getNumUses() && + !GlobalValue::isInterposableLinkage(F->getLinkage())) { + Function *NewF = internalizeFunction(*F); + Functions.insert(NewF); + + // Update call graph + CGUpdater.replaceFunctionWith(*F, *NewF); + for (const Use &U : NewF->uses()) + if (CallBase *CB = dyn_cast<CallBase>(U.getUser())) { + auto *CallerF = CB->getCaller(); + CGUpdater.reanalyzeFunction(*CallerF); + } + } + } + } for (Function *F : Functions) { if (F->hasExactDefinition()) @@ -2064,8 +2323,8 @@ static bool runAttributorOnFunctions(InformationCache &InfoCache, NumFnWithoutExactDefinition++; // We look at internal functions only on-demand but if any use is not a - // direct call or outside the current set of analyzed functions, we have to - // do it eagerly. + // direct call or outside the current set of analyzed functions, we have + // to do it eagerly. 
if (F->hasLocalLinkage()) { if (llvm::all_of(F->uses(), [&Functions](const Use &U) { const auto *CB = dyn_cast<CallBase>(U.getUser()); @@ -2081,11 +2340,41 @@ static bool runAttributorOnFunctions(InformationCache &InfoCache, } ChangeStatus Changed = A.run(); + LLVM_DEBUG(dbgs() << "[Attributor] Done with " << Functions.size() << " functions, result: " << Changed << ".\n"); return Changed == ChangeStatus::CHANGED; } +void AADepGraph::viewGraph() { llvm::ViewGraph(this, "Dependency Graph"); } + +void AADepGraph::dumpGraph() { + static std::atomic<int> CallTimes; + std::string Prefix; + + if (!DepGraphDotFileNamePrefix.empty()) + Prefix = DepGraphDotFileNamePrefix; + else + Prefix = "dep_graph"; + std::string Filename = + Prefix + "_" + std::to_string(CallTimes.load()) + ".dot"; + + outs() << "Dependency graph dump to " << Filename << ".\n"; + + std::error_code EC; + + raw_fd_ostream File(Filename, EC, sys::fs::OF_Text); + if (!EC) + llvm::WriteGraph(File, this); + + CallTimes++; +} + +void AADepGraph::print() { + for (auto DepAA : SyntheticRoot.Deps) + cast<AbstractAttribute>(DepAA.getPointer())->printWithDeps(outs()); +} + PreservedAnalyses AttributorPass::run(Module &M, ModuleAnalysisManager &AM) { FunctionAnalysisManager &FAM = AM.getResult<FunctionAnalysisManagerModuleProxy>(M).getManager(); @@ -2127,11 +2416,58 @@ PreservedAnalyses AttributorCGSCCPass::run(LazyCallGraph::SCC &C, InformationCache InfoCache(M, AG, Allocator, /* CGSCC */ &Functions); if (runAttributorOnFunctions(InfoCache, Functions, AG, CGUpdater)) { // FIXME: Think about passes we will preserve and add them here. - return PreservedAnalyses::none(); + PreservedAnalyses PA; + PA.preserve<FunctionAnalysisManagerCGSCCProxy>(); + return PA; } return PreservedAnalyses::all(); } +namespace llvm { + +template <> struct GraphTraits<AADepGraphNode *> { + using NodeRef = AADepGraphNode *; + using DepTy = PointerIntPair<AADepGraphNode *, 1>; + using EdgeRef = PointerIntPair<AADepGraphNode *, 1>; + + static NodeRef getEntryNode(AADepGraphNode *DGN) { return DGN; } + static NodeRef DepGetVal(DepTy &DT) { return DT.getPointer(); } + + using ChildIteratorType = + mapped_iterator<TinyPtrVector<DepTy>::iterator, decltype(&DepGetVal)>; + using ChildEdgeIteratorType = TinyPtrVector<DepTy>::iterator; + + static ChildIteratorType child_begin(NodeRef N) { return N->child_begin(); } + + static ChildIteratorType child_end(NodeRef N) { return N->child_end(); } +}; + +template <> +struct GraphTraits<AADepGraph *> : public GraphTraits<AADepGraphNode *> { + static NodeRef getEntryNode(AADepGraph *DG) { return DG->GetEntryNode(); } + + using nodes_iterator = + mapped_iterator<TinyPtrVector<DepTy>::iterator, decltype(&DepGetVal)>; + + static nodes_iterator nodes_begin(AADepGraph *DG) { return DG->begin(); } + + static nodes_iterator nodes_end(AADepGraph *DG) { return DG->end(); } +}; + +template <> struct DOTGraphTraits<AADepGraph *> : public DefaultDOTGraphTraits { + DOTGraphTraits(bool isSimple = false) : DefaultDOTGraphTraits(isSimple) {} + + static std::string getNodeLabel(const AADepGraphNode *Node, + const AADepGraph *DG) { + std::string AAString; + raw_string_ostream O(AAString); + Node->print(O); + return AAString; + } +}; + +} // end namespace llvm + namespace { struct AttributorLegacyPass : public ModulePass { @@ -2163,7 +2499,6 @@ struct AttributorLegacyPass : public ModulePass { }; struct AttributorCGSCCLegacyPass : public CallGraphSCCPass { - CallGraphUpdater CGUpdater; static char ID; AttributorCGSCCLegacyPass() : CallGraphSCCPass(ID) { 
@@ -2185,6 +2520,7 @@ struct AttributorCGSCCLegacyPass : public CallGraphSCCPass { AnalysisGetter AG; CallGraph &CG = const_cast<CallGraph &>(SCC.getCallGraph()); + CallGraphUpdater CGUpdater; CGUpdater.initialize(CG, SCC); Module &M = *Functions.back()->getParent(); BumpPtrAllocator Allocator; @@ -2192,8 +2528,6 @@ struct AttributorCGSCCLegacyPass : public CallGraphSCCPass { return runAttributorOnFunctions(InfoCache, Functions, AG, CGUpdater); } - bool doFinalization(CallGraph &CG) override { return CGUpdater.finalize(); } - void getAnalysisUsage(AnalysisUsage &AU) const override { // FIXME: Think about passes we will preserve and add them here. AU.addRequired<TargetLibraryInfoWrapperPass>(); diff --git a/contrib/llvm-project/llvm/lib/Transforms/IPO/AttributorAttributes.cpp b/contrib/llvm-project/llvm/lib/Transforms/IPO/AttributorAttributes.cpp index 7e9fd61eeb41..d6127a8df628 100644 --- a/contrib/llvm-project/llvm/lib/Transforms/IPO/AttributorAttributes.cpp +++ b/contrib/llvm-project/llvm/lib/Transforms/IPO/AttributorAttributes.cpp @@ -13,15 +13,20 @@ #include "llvm/Transforms/IPO/Attributor.h" +#include "llvm/ADT/SCCIterator.h" #include "llvm/ADT/SmallPtrSet.h" #include "llvm/ADT/Statistic.h" +#include "llvm/Analysis/AliasAnalysis.h" #include "llvm/Analysis/AssumeBundleQueries.h" +#include "llvm/Analysis/AssumptionCache.h" #include "llvm/Analysis/CaptureTracking.h" #include "llvm/Analysis/LazyValueInfo.h" #include "llvm/Analysis/MemoryBuiltins.h" #include "llvm/Analysis/ScalarEvolution.h" +#include "llvm/Analysis/TargetTransformInfo.h" #include "llvm/Analysis/ValueTracking.h" #include "llvm/IR/IRBuilder.h" +#include "llvm/IR/Instruction.h" #include "llvm/IR/IntrinsicInst.h" #include "llvm/IR/NoFolder.h" #include "llvm/Support/CommandLine.h" @@ -42,6 +47,16 @@ static cl::opt<bool> ManifestInternal( static cl::opt<int> MaxHeapToStackSize("max-heap-to-stack-size", cl::init(128), cl::Hidden); +template <> +unsigned llvm::PotentialConstantIntValuesState::MaxPotentialValues = 0; + +static cl::opt<unsigned, true> MaxPotentialValues( + "attributor-max-potential-values", cl::Hidden, + cl::desc("Maximum number of potential values to be " + "tracked for each position."), + cl::location(llvm::PotentialConstantIntValuesState::MaxPotentialValues), + cl::init(7)); + STATISTIC(NumAAs, "Number of abstract attributes created"); // Some helper macros to deal with statistics tracking. @@ -117,6 +132,8 @@ PIPE_OPERATOR(AAMemoryLocation) PIPE_OPERATOR(AAValueConstantRange) PIPE_OPERATOR(AAPrivatizablePtr) PIPE_OPERATOR(AAUndefinedBehavior) +PIPE_OPERATOR(AAPotentialValues) +PIPE_OPERATOR(AANoUndef) #undef PIPE_OPERATOR } // namespace llvm @@ -435,7 +452,7 @@ static void clampReturnedValueStates(Attributor &A, const AAType &QueryingAA, const AAType &AA = A.getAAFor<AAType>(QueryingAA, RVPos); LLVM_DEBUG(dbgs() << "[Attributor] RV: " << RV << " AA: " << AA.getAsStr() << " @ " << RVPos << "\n"); - const StateType &AAS = static_cast<const StateType &>(AA.getState()); + const StateType &AAS = AA.getState(); if (T.hasValue()) *T &= AAS; else @@ -485,7 +502,7 @@ static void clampCallSiteArgumentStates(Attributor &A, const AAType &QueryingAA, Optional<StateType> T; // The argument number which is also the call site argument number. 
- unsigned ArgNo = QueryingAA.getIRPosition().getArgNo(); + unsigned ArgNo = QueryingAA.getIRPosition().getCallSiteArgNo(); auto CallSiteCheck = [&](AbstractCallSite ACS) { const IRPosition &ACSArgPos = IRPosition::callsite_argument(ACS, ArgNo); @@ -497,7 +514,7 @@ static void clampCallSiteArgumentStates(Attributor &A, const AAType &QueryingAA, const AAType &AA = A.getAAFor<AAType>(QueryingAA, ACSArgPos); LLVM_DEBUG(dbgs() << "[Attributor] ACS: " << *ACS.getInstruction() << " AA: " << AA.getAsStr() << " @" << ACSArgPos << "\n"); - const StateType &AAS = static_cast<const StateType &>(AA.getState()); + const StateType &AAS = AA.getState(); if (T.hasValue()) *T &= AAS; else @@ -554,8 +571,7 @@ struct AACallSiteReturnedFromReturned : public BaseType { IRPosition FnPos = IRPosition::returned(*AssociatedFunction); const AAType &AA = A.getAAFor<AAType>(*this, FnPos); - return clampStateAndIndicateChange( - S, static_cast<const StateType &>(AA.getState())); + return clampStateAndIndicateChange(S, AA.getState()); } }; @@ -722,7 +738,7 @@ struct AANoUnwindCallSite final : AANoUnwindImpl { void initialize(Attributor &A) override { AANoUnwindImpl::initialize(A); Function *F = getAssociatedFunction(); - if (!F) + if (!F || F->isDeclaration()) indicatePessimisticFixpoint(); } @@ -735,9 +751,7 @@ struct AANoUnwindCallSite final : AANoUnwindImpl { Function *F = getAssociatedFunction(); const IRPosition &FnPos = IRPosition::function(*F); auto &FnAA = A.getAAFor<AANoUnwind>(*this, FnPos); - return clampStateAndIndicateChange( - getState(), - static_cast<const AANoUnwind::StateType &>(FnAA.getState())); + return clampStateAndIndicateChange(getState(), FnAA.getState()); } /// See AbstractAttribute::trackStatistics() @@ -783,7 +797,7 @@ public: ReturnedValues.clear(); Function *F = getAssociatedFunction(); - if (!F) { + if (!F || F->isDeclaration()) { indicatePessimisticFixpoint(); return; } @@ -1052,9 +1066,10 @@ ChangeStatus AAReturnedValuesImpl::updateImpl(Attributor &A) { // map, NewRVsMap. decltype(ReturnedValues) NewRVsMap; - auto HandleReturnValue = [&](Value *RV, SmallSetVector<ReturnInst *, 4> &RIs) { - LLVM_DEBUG(dbgs() << "[AAReturnedValues] Returned value: " << *RV - << " by #" << RIs.size() << " RIs\n"); + auto HandleReturnValue = [&](Value *RV, + SmallSetVector<ReturnInst *, 4> &RIs) { + LLVM_DEBUG(dbgs() << "[AAReturnedValues] Returned value: " << *RV << " by #" + << RIs.size() << " RIs\n"); CallBase *CB = dyn_cast<CallBase>(RV); if (!CB || UnresolvedCalls.count(CB)) return; @@ -1128,11 +1143,13 @@ ChangeStatus AAReturnedValuesImpl::updateImpl(Attributor &A) { RVState RVS({NewRVsMap, Unused, RetValAAIt.second}); VisitReturnedValue(*CB->getArgOperand(Arg->getArgNo()), RVS, CB); continue; - } else if (isa<CallBase>(RetVal)) { + } + if (isa<CallBase>(RetVal)) { // Call sites are resolved by the callee attribute over time, no need to // do anything for us. continue; - } else if (isa<Constant>(RetVal)) { + } + if (isa<Constant>(RetVal)) { // Constants are valid everywhere, we can simply take them. 
NewRVsMap[RetVal].insert(RIs.begin(), RIs.end()); continue; @@ -1373,7 +1390,7 @@ struct AANoSyncCallSite final : AANoSyncImpl { void initialize(Attributor &A) override { AANoSyncImpl::initialize(A); Function *F = getAssociatedFunction(); - if (!F) + if (!F || F->isDeclaration()) indicatePessimisticFixpoint(); } @@ -1386,8 +1403,7 @@ struct AANoSyncCallSite final : AANoSyncImpl { Function *F = getAssociatedFunction(); const IRPosition &FnPos = IRPosition::function(*F); auto &FnAA = A.getAAFor<AANoSync>(*this, FnPos); - return clampStateAndIndicateChange( - getState(), static_cast<const AANoSync::StateType &>(FnAA.getState())); + return clampStateAndIndicateChange(getState(), FnAA.getState()); } /// See AbstractAttribute::trackStatistics() @@ -1439,7 +1455,7 @@ struct AANoFreeCallSite final : AANoFreeImpl { void initialize(Attributor &A) override { AANoFreeImpl::initialize(A); Function *F = getAssociatedFunction(); - if (!F) + if (!F || F->isDeclaration()) indicatePessimisticFixpoint(); } @@ -1452,8 +1468,7 @@ struct AANoFreeCallSite final : AANoFreeImpl { Function *F = getAssociatedFunction(); const IRPosition &FnPos = IRPosition::function(*F); auto &FnAA = A.getAAFor<AANoFree>(*this, FnPos); - return clampStateAndIndicateChange( - getState(), static_cast<const AANoFree::StateType &>(FnAA.getState())); + return clampStateAndIndicateChange(getState(), FnAA.getState()); } /// See AbstractAttribute::trackStatistics() @@ -1535,8 +1550,7 @@ struct AANoFreeCallSiteArgument final : AANoFreeFloating { return indicatePessimisticFixpoint(); const IRPosition &ArgPos = IRPosition::argument(*Arg); auto &ArgAA = A.getAAFor<AANoFree>(*this, ArgPos); - return clampStateAndIndicateChange( - getState(), static_cast<const AANoFree::StateType &>(ArgAA.getState())); + return clampStateAndIndicateChange(getState(), ArgAA.getState()); } /// See AbstractAttribute::trackStatistics() @@ -1672,21 +1686,33 @@ struct AANonNullImpl : AANonNull { Value &V = getAssociatedValue(); if (!NullIsDefined && hasAttr({Attribute::NonNull, Attribute::Dereferenceable}, - /* IgnoreSubsumingPositions */ false, &A)) + /* IgnoreSubsumingPositions */ false, &A)) { indicateOptimisticFixpoint(); - else if (isa<ConstantPointerNull>(V)) + return; + } + + if (isa<ConstantPointerNull>(V)) { indicatePessimisticFixpoint(); - else - AANonNull::initialize(A); + return; + } + + AANonNull::initialize(A); bool CanBeNull = true; - if (V.getPointerDereferenceableBytes(A.getDataLayout(), CanBeNull)) - if (!CanBeNull) + if (V.getPointerDereferenceableBytes(A.getDataLayout(), CanBeNull)) { + if (!CanBeNull) { indicateOptimisticFixpoint(); + return; + } + } - if (!getState().isAtFixpoint()) - if (Instruction *CtxI = getCtxI()) - followUsesInMBEC(*this, A, getState(), *CtxI); + if (isa<GlobalValue>(&getAssociatedValue())) { + indicatePessimisticFixpoint(); + return; + } + + if (Instruction *CtxI = getCtxI()) + followUsesInMBEC(*this, A, getState(), *CtxI); } /// See followUsesInMBEC @@ -1717,13 +1743,6 @@ struct AANonNullFloating : public AANonNullImpl { /// See AbstractAttribute::updateImpl(...). ChangeStatus updateImpl(Attributor &A) override { - if (!NullIsDefined) { - const auto &DerefAA = - A.getAAFor<AADereferenceable>(*this, getIRPosition()); - if (DerefAA.getAssumedDereferenceableBytes()) - return ChangeStatus::UNCHANGED; - } - const DataLayout &DL = A.getDataLayout(); DominatorTree *DT = nullptr; @@ -1742,8 +1761,7 @@ struct AANonNullFloating : public AANonNullImpl { T.indicatePessimisticFixpoint(); } else { // Use abstract attribute information. 
- const AANonNull::StateType &NS = - static_cast<const AANonNull::StateType &>(AA.getState()); + const AANonNull::StateType &NS = AA.getState(); T ^= NS; } return T.isValidState(); @@ -1763,9 +1781,14 @@ struct AANonNullFloating : public AANonNullImpl { /// NonNull attribute for function return value. struct AANonNullReturned final - : AAReturnedFromReturnedValues<AANonNull, AANonNullImpl> { + : AAReturnedFromReturnedValues<AANonNull, AANonNull> { AANonNullReturned(const IRPosition &IRP, Attributor &A) - : AAReturnedFromReturnedValues<AANonNull, AANonNullImpl>(IRP, A) {} + : AAReturnedFromReturnedValues<AANonNull, AANonNull>(IRP, A) {} + + /// See AbstractAttribute::getAsStr(). + const std::string getAsStr() const override { + return getAssumed() ? "nonnull" : "may-null"; + } /// See AbstractAttribute::trackStatistics() void trackStatistics() const override { STATS_DECLTRACK_FNRET_ATTR(nonnull) } @@ -1879,7 +1902,7 @@ struct AANoRecurseCallSite final : AANoRecurseImpl { void initialize(Attributor &A) override { AANoRecurseImpl::initialize(A); Function *F = getAssociatedFunction(); - if (!F) + if (!F || F->isDeclaration()) indicatePessimisticFixpoint(); } @@ -1892,9 +1915,7 @@ struct AANoRecurseCallSite final : AANoRecurseImpl { Function *F = getAssociatedFunction(); const IRPosition &FnPos = IRPosition::function(*F); auto &FnAA = A.getAAFor<AANoRecurse>(*this, FnPos); - return clampStateAndIndicateChange( - getState(), - static_cast<const AANoRecurse::StateType &>(FnAA.getState())); + return clampStateAndIndicateChange(getState(), FnAA.getState()); } /// See AbstractAttribute::trackStatistics() @@ -1979,6 +2000,98 @@ struct AAUndefinedBehaviorImpl : public AAUndefinedBehavior { return true; }; + auto InspectCallSiteForUB = [&](Instruction &I) { + // Check whether a callsite always cause UB or not + + // Skip instructions that are already saved. + if (AssumedNoUBInsts.count(&I) || KnownUBInsts.count(&I)) + return true; + + // Check nonnull and noundef argument attribute violation for each + // callsite. + CallBase &CB = cast<CallBase>(I); + Function *Callee = CB.getCalledFunction(); + if (!Callee) + return true; + for (unsigned idx = 0; idx < CB.getNumArgOperands(); idx++) { + // If current argument is known to be simplified to null pointer and the + // corresponding argument position is known to have nonnull attribute, + // the argument is poison. Furthermore, if the argument is poison and + // the position is known to have noundef attriubte, this callsite is + // considered UB. + if (idx >= Callee->arg_size()) + break; + Value *ArgVal = CB.getArgOperand(idx); + if (!ArgVal) + continue; + // Here, we handle three cases. + // (1) Not having a value means it is dead. (we can replace the value + // with undef) + // (2) Simplified to undef. The argument violate noundef attriubte. + // (3) Simplified to null pointer where known to be nonnull. + // The argument is a poison value and violate noundef attribute. 
+ IRPosition CalleeArgumentIRP = IRPosition::callsite_argument(CB, idx); + auto &NoUndefAA = A.getAAFor<AANoUndef>(*this, CalleeArgumentIRP, + /* TrackDependence */ false); + if (!NoUndefAA.isKnownNoUndef()) + continue; + auto &ValueSimplifyAA = A.getAAFor<AAValueSimplify>( + *this, IRPosition::value(*ArgVal), /* TrackDependence */ false); + if (!ValueSimplifyAA.isKnown()) + continue; + Optional<Value *> SimplifiedVal = + ValueSimplifyAA.getAssumedSimplifiedValue(A); + if (!SimplifiedVal.hasValue() || + isa<UndefValue>(*SimplifiedVal.getValue())) { + KnownUBInsts.insert(&I); + continue; + } + if (!ArgVal->getType()->isPointerTy() || + !isa<ConstantPointerNull>(*SimplifiedVal.getValue())) + continue; + auto &NonNullAA = A.getAAFor<AANonNull>(*this, CalleeArgumentIRP, + /* TrackDependence */ false); + if (NonNullAA.isKnownNonNull()) + KnownUBInsts.insert(&I); + } + return true; + }; + + auto InspectReturnInstForUB = + [&](Value &V, const SmallSetVector<ReturnInst *, 4> RetInsts) { + // Check if a return instruction always cause UB or not + // Note: It is guaranteed that the returned position of the anchor + // scope has noundef attribute when this is called. + // We also ensure the return position is not "assumed dead" + // because the returned value was then potentially simplified to + // `undef` in AAReturnedValues without removing the `noundef` + // attribute yet. + + // When the returned position has noundef attriubte, UB occur in the + // following cases. + // (1) Returned value is known to be undef. + // (2) The value is known to be a null pointer and the returned + // position has nonnull attribute (because the returned value is + // poison). + bool FoundUB = false; + if (isa<UndefValue>(V)) { + FoundUB = true; + } else { + if (isa<ConstantPointerNull>(V)) { + auto &NonNullAA = A.getAAFor<AANonNull>( + *this, IRPosition::returned(*getAnchorScope()), + /* TrackDependence */ false); + if (NonNullAA.isKnownNonNull()) + FoundUB = true; + } + } + + if (FoundUB) + for (ReturnInst *RI : RetInsts) + KnownUBInsts.insert(RI); + return true; + }; + A.checkForAllInstructions(InspectMemAccessInstForUB, *this, {Instruction::Load, Instruction::Store, Instruction::AtomicCmpXchg, @@ -1986,6 +2099,22 @@ struct AAUndefinedBehaviorImpl : public AAUndefinedBehavior { /* CheckBBLivenessOnly */ true); A.checkForAllInstructions(InspectBrInstForUB, *this, {Instruction::Br}, /* CheckBBLivenessOnly */ true); + A.checkForAllCallLikeInstructions(InspectCallSiteForUB, *this); + + // If the returned position of the anchor scope has noundef attriubte, check + // all returned instructions. 
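
// As a standalone illustration (not part of this diff; all names below are made up),
// the two lambdas above flag code where a value that simplifies to undef, or to null
// at a position known to be nonnull, reaches a noundef position: such a value is
// poison, and poison at a noundef position is UB. The Clang/GCC attributes are one
// common way the corresponding IR attributes arise; recent clang versions also attach
// noundef to such positions. The snippet compiles on its own but deliberately encodes
// the UB patterns, so it is an illustration rather than something to execute.

__attribute__((nonnull)) void use(int *P);

__attribute__((returns_nonnull)) int *get(bool Bad) {
  if (Bad)
    return nullptr; // returning null at a nonnull return position yields poison;
                    // with noundef present this is the case InspectReturnInstForUB
                    // looks for
  static int X;
  return &X;
}

void caller() {
  use(nullptr); // null passed to a nonnull parameter: the case
                // InspectCallSiteForUB looks for once the parameter is also noundef
}
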
+ if (!getAnchorScope()->getReturnType()->isVoidTy()) { + const IRPosition &ReturnIRP = IRPosition::returned(*getAnchorScope()); + if (!A.isAssumedDead(ReturnIRP, this, nullptr)) { + auto &RetPosNoUndefAA = + A.getAAFor<AANoUndef>(*this, ReturnIRP, + /* TrackDependence */ false); + if (RetPosNoUndefAA.isKnownNoUndef()) + A.checkForAllReturnedValuesAndReturnInsts(InspectReturnInstForUB, + *this); + } + } + if (NoUBPrevSize != AssumedNoUBInsts.size() || UBPrevSize != KnownUBInsts.size()) return ChangeStatus::CHANGED; @@ -2153,7 +2282,7 @@ struct AAWillReturnImpl : public AAWillReturn { AAWillReturn::initialize(A); Function *F = getAnchorScope(); - if (!F || !A.isFunctionIPOAmendable(*F) || mayContainUnboundedCycle(*F, A)) + if (!F || F->isDeclaration() || mayContainUnboundedCycle(*F, A)) indicatePessimisticFixpoint(); } @@ -2197,9 +2326,9 @@ struct AAWillReturnCallSite final : AAWillReturnImpl { /// See AbstractAttribute::initialize(...). void initialize(Attributor &A) override { - AAWillReturnImpl::initialize(A); + AAWillReturn::initialize(A); Function *F = getAssociatedFunction(); - if (!F) + if (!F || !A.isFunctionIPOAmendable(*F)) indicatePessimisticFixpoint(); } @@ -2212,9 +2341,7 @@ struct AAWillReturnCallSite final : AAWillReturnImpl { Function *F = getAssociatedFunction(); const IRPosition &FnPos = IRPosition::function(*F); auto &FnAA = A.getAAFor<AAWillReturn>(*this, FnPos); - return clampStateAndIndicateChange( - getState(), - static_cast<const AAWillReturn::StateType &>(FnAA.getState())); + return clampStateAndIndicateChange(getState(), FnAA.getState()); } /// See AbstractAttribute::trackStatistics() @@ -2374,7 +2501,7 @@ struct AANoAliasCallSiteArgument final : AANoAliasImpl { void initialize(Attributor &A) override { // See callsite argument attribute and callee argument attribute. const auto &CB = cast<CallBase>(getAnchorValue()); - if (CB.paramHasAttr(getArgNo(), Attribute::NoAlias)) + if (CB.paramHasAttr(getCallSiteArgNo(), Attribute::NoAlias)) indicateOptimisticFixpoint(); Value &Val = getAssociatedValue(); if (isa<ConstantPointerNull>(Val) && @@ -2389,7 +2516,7 @@ struct AANoAliasCallSiteArgument final : AANoAliasImpl { const AAMemoryBehavior &MemBehaviorAA, const CallBase &CB, unsigned OtherArgNo) { // We do not need to worry about aliasing with the underlying IRP. - if (this->getArgNo() == (int)OtherArgNo) + if (this->getCalleeArgNo() == (int)OtherArgNo) return false; // If it is not a pointer or pointer vector we do not alias. @@ -2451,6 +2578,7 @@ struct AANoAliasCallSiteArgument final : AANoAliasImpl { A.recordDependence(NoAliasAA, *this, DepClassTy::OPTIONAL); const IRPosition &VIRP = IRPosition::value(getAssociatedValue()); + const Function *ScopeFn = VIRP.getAnchorScope(); auto &NoCaptureAA = A.getAAFor<AANoCapture>(*this, VIRP, /* TrackDependence */ false); // Check whether the value is captured in the scope using AANoCapture. @@ -2459,16 +2587,18 @@ struct AANoAliasCallSiteArgument final : AANoAliasImpl { auto UsePred = [&](const Use &U, bool &Follow) -> bool { Instruction *UserI = cast<Instruction>(U.getUser()); - // If user if curr instr and only use. - if (UserI == getCtxI() && UserI->hasOneUse()) + // If UserI is the curr instruction and there is a single potential use of + // the value in UserI we allow the use. + // TODO: We should inspect the operands and allow those that cannot alias + // with the value. 
+ if (UserI == getCtxI() && UserI->getNumOperands() == 1) return true; - const Function *ScopeFn = VIRP.getAnchorScope(); if (ScopeFn) { const auto &ReachabilityAA = A.getAAFor<AAReachability>(*this, IRPosition::function(*ScopeFn)); - if (!ReachabilityAA.isAssumedReachable(UserI, getCtxI())) + if (!ReachabilityAA.isAssumedReachable(A, *UserI, *getCtxI())) return true; if (auto *CB = dyn_cast<CallBase>(UserI)) { @@ -2554,6 +2684,14 @@ struct AANoAliasReturned final : AANoAliasImpl { AANoAliasReturned(const IRPosition &IRP, Attributor &A) : AANoAliasImpl(IRP, A) {} + /// See AbstractAttribute::initialize(...). + void initialize(Attributor &A) override { + AANoAliasImpl::initialize(A); + Function *F = getAssociatedFunction(); + if (!F || F->isDeclaration()) + indicatePessimisticFixpoint(); + } + /// See AbstractAttribute::updateImpl(...). virtual ChangeStatus updateImpl(Attributor &A) override { @@ -2595,7 +2733,7 @@ struct AANoAliasCallSiteReturned final : AANoAliasImpl { void initialize(Attributor &A) override { AANoAliasImpl::initialize(A); Function *F = getAssociatedFunction(); - if (!F) + if (!F || F->isDeclaration()) indicatePessimisticFixpoint(); } @@ -2608,8 +2746,7 @@ struct AANoAliasCallSiteReturned final : AANoAliasImpl { Function *F = getAssociatedFunction(); const IRPosition &FnPos = IRPosition::returned(*F); auto &FnAA = A.getAAFor<AANoAlias>(*this, FnPos); - return clampStateAndIndicateChange( - getState(), static_cast<const AANoAlias::StateType &>(FnAA.getState())); + return clampStateAndIndicateChange(getState(), FnAA.getState()); } /// See AbstractAttribute::trackStatistics() @@ -2799,14 +2936,13 @@ struct AAIsDeadCallSiteArgument : public AAIsDeadValueImpl { return indicatePessimisticFixpoint(); const IRPosition &ArgPos = IRPosition::argument(*Arg); auto &ArgAA = A.getAAFor<AAIsDead>(*this, ArgPos); - return clampStateAndIndicateChange( - getState(), static_cast<const AAIsDead::StateType &>(ArgAA.getState())); + return clampStateAndIndicateChange(getState(), ArgAA.getState()); } /// See AbstractAttribute::manifest(...). ChangeStatus manifest(Attributor &A) override { CallBase &CB = cast<CallBase>(getAnchorValue()); - Use &U = CB.getArgOperandUse(getArgNo()); + Use &U = CB.getArgOperandUse(getCallSiteArgNo()); assert(!isa<UndefValue>(U.get()) && "Expected undef values to be filtered out!"); UndefValue &UV = *UndefValue::get(U->getType()); @@ -2921,8 +3057,14 @@ struct AAIsDeadFunction : public AAIsDead { void initialize(Attributor &A) override { const Function *F = getAnchorScope(); if (F && !F->isDeclaration()) { - ToBeExploredFrom.insert(&F->getEntryBlock().front()); - assumeLive(A, F->getEntryBlock()); + // We only want to compute liveness once. If the function is not part of + // the SCC, skip it. + if (A.isRunOn(*const_cast<Function *>(F))) { + ToBeExploredFrom.insert(&F->getEntryBlock().front()); + assumeLive(A, F->getEntryBlock()); + } else { + indicatePessimisticFixpoint(); + } } } @@ -2985,6 +3127,10 @@ struct AAIsDeadFunction : public AAIsDead { /// See AbstractAttribute::updateImpl(...). ChangeStatus updateImpl(Attributor &A) override; + bool isEdgeDead(const BasicBlock *From, const BasicBlock *To) const override { + return !AssumedLiveEdges.count(std::make_pair(From, To)); + } + /// See AbstractAttribute::trackStatistics() void trackStatistics() const override {} @@ -3062,6 +3208,9 @@ struct AAIsDeadFunction : public AAIsDead { /// Collection of instructions that are known to not transfer control. 
SmallSetVector<const Instruction *, 8> KnownDeadEnds; + /// Collection of all assumed live edges + DenseSet<std::pair<const BasicBlock *, const BasicBlock *>> AssumedLiveEdges; + /// Collection of all assumed live BasicBlocks. DenseSet<const BasicBlock *> AssumedLiveBlocks; }; @@ -3177,18 +3326,23 @@ ChangeStatus AAIsDeadFunction::updateImpl(Attributor &A) { const Instruction *I = Worklist.pop_back_val(); LLVM_DEBUG(dbgs() << "[AAIsDead] Exploration inst: " << *I << "\n"); + // Fast forward for uninteresting instructions. We could look for UB here + // though. + while (!I->isTerminator() && !isa<CallBase>(I)) { + Change = ChangeStatus::CHANGED; + I = I->getNextNode(); + } + AliveSuccessors.clear(); bool UsedAssumedInformation = false; switch (I->getOpcode()) { // TODO: look for (assumed) UB to backwards propagate "deadness". default: - if (I->isTerminator()) { - for (const BasicBlock *SuccBB : successors(I->getParent())) - AliveSuccessors.push_back(&SuccBB->front()); - } else { - AliveSuccessors.push_back(I->getNextNode()); - } + assert(I->isTerminator() && + "Expected non-terminators to be handled already!"); + for (const BasicBlock *SuccBB : successors(I->getParent())) + AliveSuccessors.push_back(&SuccBB->front()); break; case Instruction::Call: UsedAssumedInformation = identifyAliveSuccessors(A, cast<CallInst>(*I), @@ -3227,6 +3381,9 @@ ChangeStatus AAIsDeadFunction::updateImpl(Attributor &A) { "Non-terminator expected to have a single successor!"); Worklist.push_back(AliveSuccessor); } else { + // record the assumed live edge + AssumedLiveEdges.insert( + std::make_pair(I->getParent(), AliveSuccessor->getParent())); if (assumeLive(A, *AliveSuccessor->getParent())) Worklist.push_back(AliveSuccessor); } @@ -3342,7 +3499,6 @@ struct AADereferenceableImpl : AADereferenceable { State.addAccessedBytes(Offset, Size); } } - return; } /// See followUsesInMBEC @@ -3420,12 +3576,11 @@ struct AADereferenceableFloating : AADereferenceableImpl { DerefBytes = Base->getPointerDereferenceableBytes(DL, CanBeNull); T.GlobalState.indicatePessimisticFixpoint(); } else { - const DerefState &DS = static_cast<const DerefState &>(AA.getState()); + const DerefState &DS = AA.getState(); DerefBytes = DS.DerefBytesState.getAssumed(); T.GlobalState &= DS.GlobalState; } - // For now we do not try to "increase" dereferenceability due to negative // indices as we first have to come up with code to deal with loops and // for overflows of the dereferenceable bytes. @@ -3697,14 +3852,27 @@ struct AAAlignFloating : AAAlignImpl { AAAlign::StateType &T, bool Stripped) -> bool { const auto &AA = A.getAAFor<AAAlign>(*this, IRPosition::value(V)); if (!Stripped && this == &AA) { + int64_t Offset; + unsigned Alignment = 1; + if (const Value *Base = + GetPointerBaseWithConstantOffset(&V, Offset, DL)) { + Align PA = Base->getPointerAlignment(DL); + // BasePointerAddr + Offset = Alignment * Q for some integer Q. + // So we can say that the maximum power of two which is a divisor of + // gcd(Offset, Alignment) is an alignment. + + uint32_t gcd = greatestCommonDivisor(uint32_t(abs((int32_t)Offset)), + uint32_t(PA.value())); + Alignment = llvm::PowerOf2Floor(gcd); + } else { + Alignment = V.getPointerAlignment(DL).value(); + } // Use only IR information if we did not strip anything. - Align PA = V.getPointerAlignment(DL); - T.takeKnownMaximum(PA.value()); + T.takeKnownMaximum(Alignment); T.indicatePessimisticFixpoint(); } else { // Use abstract attribute information. 
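
// The gcd-based alignment above can be checked with plain arithmetic: if the base
// pointer is BaseAlign-aligned and the value is base + Offset, both terms are
// multiples of gcd(|Offset|, BaseAlign), so a power of two dividing that gcd is a
// valid alignment for the value. A standalone C++17 sketch (not LLVM API; names are
// illustrative, and PowerOf2Floor is modeled by the lowest set bit, which matches it
// here because a gcd with a power of two is itself a power of two):

#include <cassert>
#include <cstdlib>
#include <numeric>

static unsigned alignFromBaseAndOffset(unsigned BaseAlign, int Offset) {
  if (Offset == 0)
    return BaseAlign;
  unsigned G = std::gcd(static_cast<unsigned>(std::abs(Offset)), BaseAlign);
  return G & ~(G - 1); // largest power of two dividing G
}

int main() {
  assert(alignFromBaseAndOffset(16, 20) == 4);  // gcd(20, 16) = 4
  assert(alignFromBaseAndOffset(16, 64) == 16); // gcd(64, 16) = 16
  assert(alignFromBaseAndOffset(8, 6) == 2);    // gcd(6, 8)  = 2
  return 0;
}
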
- const AAAlign::StateType &DS = - static_cast<const AAAlign::StateType &>(AA.getState()); + const AAAlign::StateType &DS = AA.getState(); T ^= DS; } return T.isValidState(); @@ -3727,8 +3895,16 @@ struct AAAlignFloating : AAAlignImpl { /// Align attribute for function return value. struct AAAlignReturned final : AAReturnedFromReturnedValues<AAAlign, AAAlignImpl> { - AAAlignReturned(const IRPosition &IRP, Attributor &A) - : AAReturnedFromReturnedValues<AAAlign, AAAlignImpl>(IRP, A) {} + using Base = AAReturnedFromReturnedValues<AAAlign, AAAlignImpl>; + AAAlignReturned(const IRPosition &IRP, Attributor &A) : Base(IRP, A) {} + + /// See AbstractAttribute::initialize(...). + void initialize(Attributor &A) override { + Base::initialize(A); + Function *F = getAssociatedFunction(); + if (!F || F->isDeclaration()) + indicatePessimisticFixpoint(); + } /// See AbstractAttribute::trackStatistics() void trackStatistics() const override { STATS_DECLTRACK_FNRET_ATTR(aligned) } @@ -3802,7 +3978,7 @@ struct AAAlignCallSiteReturned final void initialize(Attributor &A) override { Base::initialize(A); Function *F = getAssociatedFunction(); - if (!F) + if (!F || F->isDeclaration()) indicatePessimisticFixpoint(); } @@ -3818,7 +3994,7 @@ struct AANoReturnImpl : public AANoReturn { void initialize(Attributor &A) override { AANoReturn::initialize(A); Function *F = getAssociatedFunction(); - if (!F) + if (!F || F->isDeclaration()) indicatePessimisticFixpoint(); } @@ -3850,6 +4026,17 @@ struct AANoReturnCallSite final : AANoReturnImpl { AANoReturnCallSite(const IRPosition &IRP, Attributor &A) : AANoReturnImpl(IRP, A) {} + /// See AbstractAttribute::initialize(...). + void initialize(Attributor &A) override { + AANoReturnImpl::initialize(A); + if (Function *F = getAssociatedFunction()) { + const IRPosition &FnPos = IRPosition::function(*F); + auto &FnAA = A.getAAFor<AANoReturn>(*this, FnPos); + if (!FnAA.isAssumedNoReturn()) + indicatePessimisticFixpoint(); + } + } + /// See AbstractAttribute::updateImpl(...). ChangeStatus updateImpl(Attributor &A) override { // TODO: Once we have call site specific value information we can provide @@ -3859,9 +4046,7 @@ struct AANoReturnCallSite final : AANoReturnImpl { Function *F = getAssociatedFunction(); const IRPosition &FnPos = IRPosition::function(*F); auto &FnAA = A.getAAFor<AANoReturn>(*this, FnPos); - return clampStateAndIndicateChange( - getState(), - static_cast<const AANoReturn::StateType &>(FnAA.getState())); + return clampStateAndIndicateChange(getState(), FnAA.getState()); } /// See AbstractAttribute::trackStatistics() @@ -3894,7 +4079,8 @@ struct AANoCaptureImpl : public AANoCapture { return; } - const Function *F = getArgNo() >= 0 ? getAssociatedFunction() : AnchorScope; + const Function *F = + isArgumentPosition() ? getAssociatedFunction() : AnchorScope; // Check what state the associated function can actually capture. if (F) @@ -3913,7 +4099,7 @@ struct AANoCaptureImpl : public AANoCapture { if (!isAssumedNoCaptureMaybeReturned()) return; - if (getArgNo() >= 0) { + if (isArgumentPosition()) { if (isAssumedNoCapture()) Attrs.emplace_back(Attribute::get(Ctx, Attribute::NoCapture)); else if (ManifestInternal) @@ -3949,7 +4135,7 @@ struct AANoCaptureImpl : public AANoCapture { State.addKnownBits(NOT_CAPTURED_IN_RET); // Check existing "returned" attributes. 
- int ArgNo = IRP.getArgNo(); + int ArgNo = IRP.getCalleeArgNo(); if (F.doesNotThrow() && ArgNo >= 0) { for (unsigned u = 0, e = F.arg_size(); u < e; ++u) if (F.hasParamAttribute(u, Attribute::Returned)) { @@ -4125,13 +4311,13 @@ private: ChangeStatus AANoCaptureImpl::updateImpl(Attributor &A) { const IRPosition &IRP = getIRPosition(); - const Value *V = - getArgNo() >= 0 ? IRP.getAssociatedArgument() : &IRP.getAssociatedValue(); + const Value *V = isArgumentPosition() ? IRP.getAssociatedArgument() + : &IRP.getAssociatedValue(); if (!V) return indicatePessimisticFixpoint(); const Function *F = - getArgNo() >= 0 ? IRP.getAssociatedFunction() : IRP.getAnchorScope(); + isArgumentPosition() ? IRP.getAssociatedFunction() : IRP.getAnchorScope(); assert(F && "Expected a function!"); const IRPosition &FnPos = IRPosition::function(*F); const auto &IsDeadAA = @@ -4248,9 +4434,7 @@ struct AANoCaptureCallSiteArgument final : AANoCaptureImpl { return indicatePessimisticFixpoint(); const IRPosition &ArgPos = IRPosition::argument(*Arg); auto &ArgAA = A.getAAFor<AANoCapture>(*this, ArgPos); - return clampStateAndIndicateChange( - getState(), - static_cast<const AANoCapture::StateType &>(ArgAA.getState())); + return clampStateAndIndicateChange(getState(), ArgAA.getState()); } /// See AbstractAttribute::trackStatistics() @@ -4366,24 +4550,35 @@ struct AAValueSimplifyImpl : AAValueSimplify { return true; } - bool askSimplifiedValueForAAValueConstantRange(Attributor &A) { + /// Returns a candidate is found or not + template <typename AAType> bool askSimplifiedValueFor(Attributor &A) { if (!getAssociatedValue().getType()->isIntegerTy()) return false; - const auto &ValueConstantRangeAA = - A.getAAFor<AAValueConstantRange>(*this, getIRPosition()); + const auto &AA = + A.getAAFor<AAType>(*this, getIRPosition(), /* TrackDependence */ false); - Optional<ConstantInt *> COpt = - ValueConstantRangeAA.getAssumedConstantInt(A); - if (COpt.hasValue()) { - if (auto *C = COpt.getValue()) - SimplifiedAssociatedValue = C; - else - return false; - } else { + Optional<ConstantInt *> COpt = AA.getAssumedConstantInt(A); + + if (!COpt.hasValue()) { SimplifiedAssociatedValue = llvm::None; + A.recordDependence(AA, *this, DepClassTy::OPTIONAL); + return true; } - return true; + if (auto *C = COpt.getValue()) { + SimplifiedAssociatedValue = C; + A.recordDependence(AA, *this, DepClassTy::OPTIONAL); + return true; + } + return false; + } + + bool askSimplifiedValueForOtherAAs(Attributor &A) { + if (askSimplifiedValueFor<AAValueConstantRange>(A)) + return true; + if (askSimplifiedValueFor<AAPotentialValues>(A)) + return true; + return false; } /// See AbstractAttribute::manifest(...). @@ -4468,7 +4663,7 @@ struct AAValueSimplifyArgument final : AAValueSimplifyImpl { auto PredForCallSite = [&](AbstractCallSite ACS) { const IRPosition &ACSArgPos = - IRPosition::callsite_argument(ACS, getArgNo()); + IRPosition::callsite_argument(ACS, getCallSiteArgNo()); // Check if a coresponding argument was found or if it is on not // associated (which can happen for callback calls). if (ACSArgPos.getPositionKind() == IRPosition::IRP_INVALID) @@ -4490,7 +4685,7 @@ struct AAValueSimplifyArgument final : AAValueSimplifyImpl { bool AllCallSitesKnown; if (!A.checkForAllCallSites(PredForCallSite, *this, true, AllCallSitesKnown)) - if (!askSimplifiedValueForAAValueConstantRange(A)) + if (!askSimplifiedValueForOtherAAs(A)) return indicatePessimisticFixpoint(); // If a candicate was found in this update, return CHANGED. 
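
// The argument case above only succeeds when every call site agrees on one simplified
// value; otherwise it falls back to the constant-range and potential-values queries
// and, failing those, to the pessimistic fixpoint. A toy model of the merge rule
// across call sites (plain C++, illustrative names, ignoring undef handling and the
// fallbacks):

#include <cassert>
#include <optional>
#include <vector>

static std::optional<int>
simplifiedArgument(const std::vector<int> &CallSiteValues) {
  std::optional<int> Simplified;
  for (int V : CallSiteValues) {
    if (!Simplified) {
      Simplified = V; // first candidate seen
      continue;
    }
    if (*Simplified != V)
      return std::nullopt; // call sites disagree -> give up
  }
  return Simplified;
}

int main() {
  assert(simplifiedArgument({42, 42, 42}) == 42); // all call sites pass 42
  assert(!simplifiedArgument({42, 7}));           // conflicting values
  return 0;
}
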
@@ -4518,7 +4713,7 @@ struct AAValueSimplifyReturned : AAValueSimplifyImpl { }; if (!A.checkForAllReturnedValues(PredForReturned, *this)) - if (!askSimplifiedValueForAAValueConstantRange(A)) + if (!askSimplifiedValueForOtherAAs(A)) return indicatePessimisticFixpoint(); // If a candicate was found in this update, return CHANGED. @@ -4587,10 +4782,76 @@ struct AAValueSimplifyFloating : AAValueSimplifyImpl { indicatePessimisticFixpoint(); } + /// Check if \p ICmp is an equality comparison (==/!=) with at least one + /// nullptr. If so, try to simplify it using AANonNull on the other operand. + /// Return true if successful, in that case SimplifiedAssociatedValue will be + /// updated and \p Changed is set appropriately. + bool checkForNullPtrCompare(Attributor &A, ICmpInst *ICmp, + ChangeStatus &Changed) { + if (!ICmp) + return false; + if (!ICmp->isEquality()) + return false; + + // This is a comparison with == or !-. We check for nullptr now. + bool Op0IsNull = isa<ConstantPointerNull>(ICmp->getOperand(0)); + bool Op1IsNull = isa<ConstantPointerNull>(ICmp->getOperand(1)); + if (!Op0IsNull && !Op1IsNull) + return false; + + LLVMContext &Ctx = ICmp->getContext(); + // Check for `nullptr ==/!= nullptr` first: + if (Op0IsNull && Op1IsNull) { + Value *NewVal = ConstantInt::get( + Type::getInt1Ty(Ctx), ICmp->getPredicate() == CmpInst::ICMP_EQ); + assert(!SimplifiedAssociatedValue.hasValue() && + "Did not expect non-fixed value for constant comparison"); + SimplifiedAssociatedValue = NewVal; + indicateOptimisticFixpoint(); + Changed = ChangeStatus::CHANGED; + return true; + } + + // Left is the nullptr ==/!= non-nullptr case. We'll use AANonNull on the + // non-nullptr operand and if we assume it's non-null we can conclude the + // result of the comparison. + assert((Op0IsNull || Op1IsNull) && + "Expected nullptr versus non-nullptr comparison at this point"); + + // The index is the operand that we assume is not null. + unsigned PtrIdx = Op0IsNull; + auto &PtrNonNullAA = A.getAAFor<AANonNull>( + *this, IRPosition::value(*ICmp->getOperand(PtrIdx))); + if (!PtrNonNullAA.isAssumedNonNull()) + return false; + + // The new value depends on the predicate, true for != and false for ==. + Value *NewVal = ConstantInt::get(Type::getInt1Ty(Ctx), + ICmp->getPredicate() == CmpInst::ICMP_NE); + + assert((!SimplifiedAssociatedValue.hasValue() || + SimplifiedAssociatedValue == NewVal) && + "Did not expect to change value for zero-comparison"); + + bool HasValueBefore = SimplifiedAssociatedValue.hasValue(); + SimplifiedAssociatedValue = NewVal; + + if (PtrNonNullAA.isKnownNonNull()) + indicateOptimisticFixpoint(); + + Changed = HasValueBefore ? ChangeStatus::UNCHANGED : ChangeStatus ::CHANGED; + return true; + } + /// See AbstractAttribute::updateImpl(...). 
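
// checkForNullPtrCompare above decides equality comparisons against a null constant:
// null vs. null folds immediately, and null vs. a pointer that AANonNull reports as
// nonnull folds according to the predicate. A compact truth-table model (plain C++,
// illustrative names; the real code also distinguishes assumed from known nonnull to
// decide whether the result may still change):

#include <cassert>
#include <optional>

static std::optional<bool> foldNullCompare(bool IsEq, bool OtherIsNull,
                                           bool OtherKnownNonNull) {
  if (OtherIsNull)       // nullptr ==/!= nullptr
    return IsEq;
  if (OtherKnownNonNull) // nonnull pointer ==/!= nullptr
    return !IsEq;
  return std::nullopt;   // not enough information yet
}

int main() {
  assert(foldNullCompare(/*IsEq=*/true, true, false) == true);  // null == null
  assert(foldNullCompare(/*IsEq=*/true, false, true) == false); // p == null, p nonnull
  assert(foldNullCompare(/*IsEq=*/false, false, true) == true); // p != null, p nonnull
  assert(!foldNullCompare(/*IsEq=*/true, false, false));        // unknown -> no fold
  return 0;
}
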
ChangeStatus updateImpl(Attributor &A) override { bool HasValueBefore = SimplifiedAssociatedValue.hasValue(); + ChangeStatus Changed; + if (checkForNullPtrCompare(A, dyn_cast<ICmpInst>(&getAnchorValue()), + Changed)) + return Changed; + auto VisitValueCB = [&](Value &V, const Instruction *CtxI, bool &, bool Stripped) -> bool { auto &AA = A.getAAFor<AAValueSimplify>(*this, IRPosition::value(V)); @@ -4608,7 +4869,7 @@ struct AAValueSimplifyFloating : AAValueSimplifyImpl { if (!genericValueTraversal<AAValueSimplify, bool>( A, getIRPosition(), *this, Dummy, VisitValueCB, getCtxI(), /* UseValueSimplify */ false)) - if (!askSimplifiedValueForAAValueConstantRange(A)) + if (!askSimplifiedValueForOtherAAs(A)) return indicatePessimisticFixpoint(); // If a candicate was found in this update, return CHANGED. @@ -4683,7 +4944,8 @@ struct AAValueSimplifyCallSiteArgument : AAValueSimplifyFloating { ? dyn_cast<Constant>(SimplifiedAssociatedValue.getValue()) : UndefValue::get(V.getType()); if (C) { - Use &U = cast<CallBase>(&getAnchorValue())->getArgOperandUse(getArgNo()); + Use &U = cast<CallBase>(&getAnchorValue()) + ->getArgOperandUse(getCallSiteArgNo()); // We can replace the AssociatedValue with the constant. if (&V != C && V.getType() == C->getType()) { if (A.changeUseAfterManifest(U, *C)) @@ -5002,7 +5264,7 @@ struct AAPrivatizablePtrArgument final : public AAPrivatizablePtrImpl { return getAssociatedValue().getType()->getPointerElementType(); Optional<Type *> Ty; - unsigned ArgNo = getIRPosition().getArgNo(); + unsigned ArgNo = getIRPosition().getCallSiteArgNo(); // Make sure the associated call site argument has the same type at all call // sites and it is an allocation we know is safe to privatize, for now that @@ -5265,8 +5527,9 @@ struct AAPrivatizablePtrArgument final : public AAPrivatizablePtrImpl { new StoreInst(F.getArg(ArgNo + u), Ptr, &IP); } } else if (auto *PrivArrayType = dyn_cast<ArrayType>(PrivType)) { - Type *PointeePtrTy = PrivArrayType->getElementType()->getPointerTo(); - uint64_t PointeeTySize = DL.getTypeStoreSize(PointeePtrTy); + Type *PointeeTy = PrivArrayType->getElementType(); + Type *PointeePtrTy = PointeeTy->getPointerTo(); + uint64_t PointeeTySize = DL.getTypeStoreSize(PointeeTy); for (unsigned u = 0, e = PrivArrayType->getNumElements(); u < e; u++) { Value *Ptr = constructPointer(PointeePtrTy, &Base, u * PointeeTySize, IRB, DL); @@ -5312,7 +5575,7 @@ struct AAPrivatizablePtrArgument final : public AAPrivatizablePtrImpl { for (unsigned u = 0, e = PrivArrayType->getNumElements(); u < e; u++) { Value *Ptr = constructPointer(PointeePtrTy, Base, u * PointeeTySize, IRB, DL); - LoadInst *L = new LoadInst(PointeePtrTy, Ptr, "", IP); + LoadInst *L = new LoadInst(PointeeTy, Ptr, "", IP); L->setAlignment(Alignment); ReplacementValues.push_back(L); } @@ -5356,10 +5619,14 @@ struct AAPrivatizablePtrArgument final : public AAPrivatizablePtrImpl { Function &ReplacementFn, Function::arg_iterator ArgIt) { BasicBlock &EntryBB = ReplacementFn.getEntryBlock(); Instruction *IP = &*EntryBB.getFirstInsertionPt(); - auto *AI = new AllocaInst(PrivatizableType.getValue(), 0, - Arg->getName() + ".priv", IP); + Instruction *AI = new AllocaInst(PrivatizableType.getValue(), 0, + Arg->getName() + ".priv", IP); createInitialization(PrivatizableType.getValue(), *AI, ReplacementFn, ArgIt->getArgNo(), *IP); + + if (AI->getType() != Arg->getType()) + AI = + BitCastInst::CreateBitOrPointerCast(AI, Arg->getType(), "", IP); Arg->replaceAllUsesWith(AI); for (CallInst *CI : TailCalls) @@ -5418,8 +5685,7 @@ 
struct AAPrivatizablePtrFloating : public AAPrivatizablePtrImpl { /// See AAPrivatizablePtrImpl::identifyPrivatizableType(...) Optional<Type *> identifyPrivatizableType(Attributor &A) override { - Value *Obj = - GetUnderlyingObject(&getAssociatedValue(), A.getInfoCache().getDL()); + Value *Obj = getUnderlyingObject(&getAssociatedValue()); if (!Obj) { LLVM_DEBUG(dbgs() << "[AAPrivatizablePtr] No underlying object found!\n"); return nullptr; @@ -5539,7 +5805,7 @@ struct AAMemoryBehaviorImpl : public AAMemoryBehavior { void initialize(Attributor &A) override { intersectAssumedBits(BEST_STATE); getKnownStateFromValue(getIRPosition(), getState()); - IRAttribute::initialize(A); + AAMemoryBehavior::initialize(A); } /// Return the memory behavior information encoded in the IR for \p IRP. @@ -5634,9 +5900,7 @@ struct AAMemoryBehaviorFloating : AAMemoryBehaviorImpl { /// See AbstractAttribute::initialize(...). void initialize(Attributor &A) override { AAMemoryBehaviorImpl::initialize(A); - // Initialize the use vector with all direct uses of the associated value. - for (const Use &U : getAssociatedValue().uses()) - Uses.insert(&U); + addUsesOf(A, getAssociatedValue()); } /// See AbstractAttribute::updateImpl(...). @@ -5662,8 +5926,14 @@ private: void analyzeUseIn(Attributor &A, const Use *U, const Instruction *UserI); protected: + /// Add the uses of \p V to the `Uses` set we look at during the update step. + void addUsesOf(Attributor &A, const Value &V); + /// Container for (transitive) uses of the associated argument. - SetVector<const Use *> Uses; + SmallVector<const Use *, 8> Uses; + + /// Set to remember the uses we already traversed. + SmallPtrSet<const Use *, 8> Visited; }; /// Memory behavior attribute for function argument. @@ -5688,9 +5958,7 @@ struct AAMemoryBehaviorArgument : AAMemoryBehaviorFloating { if (!Arg || !A.isFunctionIPOAmendable(*(Arg->getParent()))) { indicatePessimisticFixpoint(); } else { - // Initialize the use vector with all direct uses of the associated value. - for (const Use &U : Arg->uses()) - Uses.insert(&U); + addUsesOf(A, *Arg); } } @@ -5725,14 +5993,21 @@ struct AAMemoryBehaviorCallSiteArgument final : AAMemoryBehaviorArgument { /// See AbstractAttribute::initialize(...). void initialize(Attributor &A) override { - if (Argument *Arg = getAssociatedArgument()) { - if (Arg->hasByValAttr()) { - addKnownBits(NO_WRITES); - removeKnownBits(NO_READS); - removeAssumedBits(NO_READS); - } + // If we don't have an associated attribute this is either a variadic call + // or an indirect call, either way, nothing to do here. + Argument *Arg = getAssociatedArgument(); + if (!Arg) { + indicatePessimisticFixpoint(); + return; + } + if (Arg->hasByValAttr()) { + addKnownBits(NO_WRITES); + removeKnownBits(NO_READS); + removeAssumedBits(NO_READS); } AAMemoryBehaviorArgument::initialize(A); + if (getAssociatedFunction()->isDeclaration()) + indicatePessimisticFixpoint(); } /// See AbstractAttribute::updateImpl(...). 
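
// The byval special case above is sound because a byval argument is passed as a
// private copy: the call may read the pointed-to memory to form that copy, but stores
// performed by the callee land in the copy, so the caller's memory is never written
// through this call site argument. Source-level intuition (plain C++ pass-by-value,
// illustrative names, not the IR byval attribute itself):

struct Buf {
  int Data[4];
};

static int callee(Buf B) { // B arrives as a private copy
  B.Data[0] = 42;          // mutates only the copy
  return B.Data[1];
}

int caller(Buf &Original) {
  // The call may read Original (to make the copy) but cannot write it, which matches
  // the known NO_WRITES / possible-read state set up above.
  return callee(Original);
}
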
@@ -5744,9 +6019,7 @@ struct AAMemoryBehaviorCallSiteArgument final : AAMemoryBehaviorArgument { Argument *Arg = getAssociatedArgument(); const IRPosition &ArgPos = IRPosition::argument(*Arg); auto &ArgAA = A.getAAFor<AAMemoryBehavior>(*this, ArgPos); - return clampStateAndIndicateChange( - getState(), - static_cast<const AAMemoryBehavior::StateType &>(ArgAA.getState())); + return clampStateAndIndicateChange(getState(), ArgAA.getState()); } /// See AbstractAttribute::trackStatistics() @@ -5765,6 +6038,14 @@ struct AAMemoryBehaviorCallSiteReturned final : AAMemoryBehaviorFloating { AAMemoryBehaviorCallSiteReturned(const IRPosition &IRP, Attributor &A) : AAMemoryBehaviorFloating(IRP, A) {} + /// See AbstractAttribute::initialize(...). + void initialize(Attributor &A) override { + AAMemoryBehaviorImpl::initialize(A); + Function *F = getAssociatedFunction(); + if (!F || F->isDeclaration()) + indicatePessimisticFixpoint(); + } + /// See AbstractAttribute::manifest(...). ChangeStatus manifest(Attributor &A) override { // We do not annotate returned values. @@ -5814,10 +6095,8 @@ struct AAMemoryBehaviorCallSite final : AAMemoryBehaviorImpl { void initialize(Attributor &A) override { AAMemoryBehaviorImpl::initialize(A); Function *F = getAssociatedFunction(); - if (!F || !A.isFunctionIPOAmendable(*F)) { + if (!F || F->isDeclaration()) indicatePessimisticFixpoint(); - return; - } } /// See AbstractAttribute::updateImpl(...). @@ -5829,9 +6108,7 @@ struct AAMemoryBehaviorCallSite final : AAMemoryBehaviorImpl { Function *F = getAssociatedFunction(); const IRPosition &FnPos = IRPosition::function(*F); auto &FnAA = A.getAAFor<AAMemoryBehavior>(*this, FnPos); - return clampStateAndIndicateChange( - getState(), - static_cast<const AAMemoryBehavior::StateType &>(FnAA.getState())); + return clampStateAndIndicateChange(getState(), FnAA.getState()); } /// See AbstractAttribute::trackStatistics() @@ -5933,8 +6210,7 @@ ChangeStatus AAMemoryBehaviorFloating::updateImpl(Attributor &A) { // Check if the users of UserI should also be visited. if (followUsersOfUseIn(A, U, UserI)) - for (const Use &UserIUse : UserI->uses()) - Uses.insert(&UserIUse); + addUsesOf(A, *UserI); // If UserI might touch memory we analyze the use in detail. if (UserI->mayReadOrWriteMemory()) @@ -5945,6 +6221,28 @@ ChangeStatus AAMemoryBehaviorFloating::updateImpl(Attributor &A) { : ChangeStatus::UNCHANGED; } +void AAMemoryBehaviorFloating::addUsesOf(Attributor &A, const Value &V) { + SmallVector<const Use *, 8> WL; + for (const Use &U : V.uses()) + WL.push_back(&U); + + while (!WL.empty()) { + const Use *U = WL.pop_back_val(); + if (!Visited.insert(U).second) + continue; + + const Instruction *UserI = cast<Instruction>(U->getUser()); + if (UserI->mayReadOrWriteMemory()) { + Uses.push_back(U); + continue; + } + if (!followUsersOfUseIn(A, U, UserI)) + continue; + for (const Use &UU : UserI->uses()) + WL.push_back(&UU); + } +} + bool AAMemoryBehaviorFloating::followUsersOfUseIn(Attributor &A, const Use *U, const Instruction *UserI) { // The loaded value is unrelated to the pointer argument, no need to @@ -6096,7 +6394,7 @@ struct AAMemoryLocationImpl : public AAMemoryLocation { void initialize(Attributor &A) override { intersectAssumedBits(BEST_STATE); getKnownStateFromValue(A, getIRPosition(), getState()); - IRAttribute::initialize(A); + AAMemoryLocation::initialize(A); } /// Return the memory behavior information encoded in the IR for \p IRP. 
@@ -6259,6 +6557,13 @@ protected: using AccessSet = SmallSet<AccessInfo, 2, AccessInfo>; AccessSet *AccessKind2Accesses[llvm::CTLog2<VALID_STATE>()]; + /// Categorize the pointer arguments of CB that might access memory in + /// AccessedLoc and update the state and access map accordingly. + void + categorizeArgumentPointerLocations(Attributor &A, CallBase &CB, + AAMemoryLocation::StateType &AccessedLocs, + bool &Changed); + /// Return the kind(s) of location that may be accessed by \p V. AAMemoryLocation::MemoryLocationsKind categorizeAccessedLocations(Attributor &A, Instruction &I, bool &Changed); @@ -6324,6 +6629,7 @@ void AAMemoryLocationImpl::categorizePtrValue( auto VisitValueCB = [&](Value &V, const Instruction *, AAMemoryLocation::StateType &T, bool Stripped) -> bool { + // TODO: recognize the TBAA used for constant accesses. MemoryLocationsKind MLK = NO_LOCATIONS; assert(!isa<GEPOperator>(V) && "GEPs should have been stripped."); if (isa<UndefValue>(V)) @@ -6334,6 +6640,13 @@ void AAMemoryLocationImpl::categorizePtrValue( else MLK = NO_ARGUMENT_MEM; } else if (auto *GV = dyn_cast<GlobalValue>(&V)) { + // Reading constant memory is not treated as a read "effect" by the + // function attr pass so we won't neither. Constants defined by TBAA are + // similar. (We know we do not write it because it is constant.) + if (auto *GVar = dyn_cast<GlobalVariable>(GV)) + if (GVar->isConstant()) + return true; + if (GV->hasLocalLinkage()) MLK = NO_GLOBAL_INTERNAL_MEM; else @@ -6380,6 +6693,30 @@ void AAMemoryLocationImpl::categorizePtrValue( } } +void AAMemoryLocationImpl::categorizeArgumentPointerLocations( + Attributor &A, CallBase &CB, AAMemoryLocation::StateType &AccessedLocs, + bool &Changed) { + for (unsigned ArgNo = 0, E = CB.getNumArgOperands(); ArgNo < E; ++ArgNo) { + + // Skip non-pointer arguments. + const Value *ArgOp = CB.getArgOperand(ArgNo); + if (!ArgOp->getType()->isPtrOrPtrVectorTy()) + continue; + + // Skip readnone arguments. + const IRPosition &ArgOpIRP = IRPosition::callsite_argument(CB, ArgNo); + const auto &ArgOpMemLocationAA = A.getAAFor<AAMemoryBehavior>( + *this, ArgOpIRP, /* TrackDependence */ true, DepClassTy::OPTIONAL); + + if (ArgOpMemLocationAA.isAssumedReadNone()) + continue; + + // Categorize potentially accessed pointer arguments as if there was an + // access instruction with them as pointer. + categorizePtrValue(A, CB, *ArgOp, AccessedLocs, Changed); + } +} + AAMemoryLocation::MemoryLocationsKind AAMemoryLocationImpl::categorizeAccessedLocations(Attributor &A, Instruction &I, bool &Changed) { @@ -6441,28 +6778,8 @@ AAMemoryLocationImpl::categorizeAccessedLocations(Attributor &A, Instruction &I, // Now handle argument memory if it might be accessed. bool HasArgAccesses = ((~CBAssumedNotAccessedLocs) & NO_ARGUMENT_MEM); - if (HasArgAccesses) { - for (unsigned ArgNo = 0, E = CB->getNumArgOperands(); ArgNo < E; - ++ArgNo) { - - // Skip non-pointer arguments. - const Value *ArgOp = CB->getArgOperand(ArgNo); - if (!ArgOp->getType()->isPtrOrPtrVectorTy()) - continue; - - // Skip readnone arguments. - const IRPosition &ArgOpIRP = IRPosition::callsite_argument(*CB, ArgNo); - const auto &ArgOpMemLocationAA = A.getAAFor<AAMemoryBehavior>( - *this, ArgOpIRP, /* TrackDependence */ true, DepClassTy::OPTIONAL); - - if (ArgOpMemLocationAA.isAssumedReadNone()) - continue; - - // Categorize potentially accessed pointer arguments as if there was an - // access instruction with them as pointer. 
- categorizePtrValue(A, I, *ArgOp, AccessedLocs, Changed); - } - } + if (HasArgAccesses) + categorizeArgumentPointerLocations(A, *CB, AccessedLocs, Changed); LLVM_DEBUG( dbgs() << "[AAMemoryLocation] Accessed state after argument handling: " @@ -6514,7 +6831,9 @@ struct AAMemoryLocationFunction final : public AAMemoryLocationImpl { LLVM_DEBUG(dbgs() << "[AAMemoryLocation] Accessed locations for " << I << ": " << getMemoryLocationsAsStr(MLK) << "\n"); removeAssumedBits(inverseLocation(MLK, false, false)); - return true; + // Stop once only the valid bit set in the *not assumed location*, thus + // once we don't actually exclude any memory locations in the state. + return getAssumedNotAccessedLocation() != VALID_STATE; }; if (!A.checkForAllReadWriteInstructions(CheckRWInst, *this)) @@ -6546,10 +6865,8 @@ struct AAMemoryLocationCallSite final : AAMemoryLocationImpl { void initialize(Attributor &A) override { AAMemoryLocationImpl::initialize(A); Function *F = getAssociatedFunction(); - if (!F || !A.isFunctionIPOAmendable(*F)) { + if (!F || F->isDeclaration()) indicatePessimisticFixpoint(); - return; - } } /// See AbstractAttribute::updateImpl(...). @@ -6655,7 +6972,6 @@ struct AAValueConstantRangeImpl : AAValueConstantRange { if (!LVI || !CtxI) return getWorstState(getBitWidth()); return LVI->getConstantRange(&getAssociatedValue(), - const_cast<BasicBlock *>(CtxI->getParent()), const_cast<Instruction *>(CtxI)); } @@ -6759,10 +7075,13 @@ struct AAValueConstantRangeImpl : AAValueConstantRange { auto &V = getAssociatedValue(); if (!AssumedConstantRange.isEmptySet() && !AssumedConstantRange.isSingleElement()) { - if (Instruction *I = dyn_cast<Instruction>(&V)) + if (Instruction *I = dyn_cast<Instruction>(&V)) { + assert(I == getCtxI() && "Should not annotate an instruction which is " + "not the context instruction"); if (isa<CallInst>(I) || isa<LoadInst>(I)) if (setRangeMetadataIfisBetterRange(I, AssumedConstantRange)) Changed = ChangeStatus::CHANGED; + } } return Changed; @@ -6831,6 +7150,9 @@ struct AAValueConstantRangeFloating : AAValueConstantRangeImpl { return; } + if (isa<CallBase>(&V)) + return; + if (isa<BinaryOperator>(&V) || isa<CmpInst>(&V) || isa<CastInst>(&V)) return; // If it is a load instruction with range metadata, use it. @@ -7068,11 +7390,641 @@ struct AAValueConstantRangeCallSiteArgument : AAValueConstantRangeFloating { AAValueConstantRangeCallSiteArgument(const IRPosition &IRP, Attributor &A) : AAValueConstantRangeFloating(IRP, A) {} + /// See AbstractAttribute::manifest() + ChangeStatus manifest(Attributor &A) override { + return ChangeStatus::UNCHANGED; + } + /// See AbstractAttribute::trackStatistics() void trackStatistics() const override { STATS_DECLTRACK_CSARG_ATTR(value_range) } }; + +/// ------------------ Potential Values Attribute ------------------------- + +struct AAPotentialValuesImpl : AAPotentialValues { + using StateType = PotentialConstantIntValuesState; + + AAPotentialValuesImpl(const IRPosition &IRP, Attributor &A) + : AAPotentialValues(IRP, A) {} + + /// See AbstractAttribute::getAsStr(). + const std::string getAsStr() const override { + std::string Str; + llvm::raw_string_ostream OS(Str); + OS << getState(); + return OS.str(); + } + + /// See AbstractAttribute::updateImpl(...). 
+ ChangeStatus updateImpl(Attributor &A) override { + return indicatePessimisticFixpoint(); + } +}; + +struct AAPotentialValuesArgument final + : AAArgumentFromCallSiteArguments<AAPotentialValues, AAPotentialValuesImpl, + PotentialConstantIntValuesState> { + using Base = + AAArgumentFromCallSiteArguments<AAPotentialValues, AAPotentialValuesImpl, + PotentialConstantIntValuesState>; + AAPotentialValuesArgument(const IRPosition &IRP, Attributor &A) + : Base(IRP, A) {} + + /// See AbstractAttribute::initialize(..). + void initialize(Attributor &A) override { + if (!getAnchorScope() || getAnchorScope()->isDeclaration()) { + indicatePessimisticFixpoint(); + } else { + Base::initialize(A); + } + } + + /// See AbstractAttribute::trackStatistics() + void trackStatistics() const override { + STATS_DECLTRACK_ARG_ATTR(potential_values) + } +}; + +struct AAPotentialValuesReturned + : AAReturnedFromReturnedValues<AAPotentialValues, AAPotentialValuesImpl> { + using Base = + AAReturnedFromReturnedValues<AAPotentialValues, AAPotentialValuesImpl>; + AAPotentialValuesReturned(const IRPosition &IRP, Attributor &A) + : Base(IRP, A) {} + + /// See AbstractAttribute::trackStatistics() + void trackStatistics() const override { + STATS_DECLTRACK_FNRET_ATTR(potential_values) + } +}; + +struct AAPotentialValuesFloating : AAPotentialValuesImpl { + AAPotentialValuesFloating(const IRPosition &IRP, Attributor &A) + : AAPotentialValuesImpl(IRP, A) {} + + /// See AbstractAttribute::initialize(..). + void initialize(Attributor &A) override { + Value &V = getAssociatedValue(); + + if (auto *C = dyn_cast<ConstantInt>(&V)) { + unionAssumed(C->getValue()); + indicateOptimisticFixpoint(); + return; + } + + if (isa<UndefValue>(&V)) { + unionAssumedWithUndef(); + indicateOptimisticFixpoint(); + return; + } + + if (isa<BinaryOperator>(&V) || isa<ICmpInst>(&V) || isa<CastInst>(&V)) + return; + + if (isa<SelectInst>(V) || isa<PHINode>(V)) + return; + + indicatePessimisticFixpoint(); + + LLVM_DEBUG(dbgs() << "[AAPotentialValues] We give up: " + << getAssociatedValue() << "\n"); + } + + static bool calculateICmpInst(const ICmpInst *ICI, const APInt &LHS, + const APInt &RHS) { + ICmpInst::Predicate Pred = ICI->getPredicate(); + switch (Pred) { + case ICmpInst::ICMP_UGT: + return LHS.ugt(RHS); + case ICmpInst::ICMP_SGT: + return LHS.sgt(RHS); + case ICmpInst::ICMP_EQ: + return LHS.eq(RHS); + case ICmpInst::ICMP_UGE: + return LHS.uge(RHS); + case ICmpInst::ICMP_SGE: + return LHS.sge(RHS); + case ICmpInst::ICMP_ULT: + return LHS.ult(RHS); + case ICmpInst::ICMP_SLT: + return LHS.slt(RHS); + case ICmpInst::ICMP_NE: + return LHS.ne(RHS); + case ICmpInst::ICMP_ULE: + return LHS.ule(RHS); + case ICmpInst::ICMP_SLE: + return LHS.sle(RHS); + default: + llvm_unreachable("Invalid ICmp predicate!"); + } + } + + static APInt calculateCastInst(const CastInst *CI, const APInt &Src, + uint32_t ResultBitWidth) { + Instruction::CastOps CastOp = CI->getOpcode(); + switch (CastOp) { + default: + llvm_unreachable("unsupported or not integer cast"); + case Instruction::Trunc: + return Src.trunc(ResultBitWidth); + case Instruction::SExt: + return Src.sext(ResultBitWidth); + case Instruction::ZExt: + return Src.zext(ResultBitWidth); + case Instruction::BitCast: + return Src; + } + } + + static APInt calculateBinaryOperator(const BinaryOperator *BinOp, + const APInt &LHS, const APInt &RHS, + bool &SkipOperation, bool &Unsupported) { + Instruction::BinaryOps BinOpcode = BinOp->getOpcode(); + // Unsupported is set to true when the binary operator is not 
supported. + // SkipOperation is set to true when UB occur with the given operand pair + // (LHS, RHS). + // TODO: we should look at nsw and nuw keywords to handle operations + // that create poison or undef value. + switch (BinOpcode) { + default: + Unsupported = true; + return LHS; + case Instruction::Add: + return LHS + RHS; + case Instruction::Sub: + return LHS - RHS; + case Instruction::Mul: + return LHS * RHS; + case Instruction::UDiv: + if (RHS.isNullValue()) { + SkipOperation = true; + return LHS; + } + return LHS.udiv(RHS); + case Instruction::SDiv: + if (RHS.isNullValue()) { + SkipOperation = true; + return LHS; + } + return LHS.sdiv(RHS); + case Instruction::URem: + if (RHS.isNullValue()) { + SkipOperation = true; + return LHS; + } + return LHS.urem(RHS); + case Instruction::SRem: + if (RHS.isNullValue()) { + SkipOperation = true; + return LHS; + } + return LHS.srem(RHS); + case Instruction::Shl: + return LHS.shl(RHS); + case Instruction::LShr: + return LHS.lshr(RHS); + case Instruction::AShr: + return LHS.ashr(RHS); + case Instruction::And: + return LHS & RHS; + case Instruction::Or: + return LHS | RHS; + case Instruction::Xor: + return LHS ^ RHS; + } + } + + bool calculateBinaryOperatorAndTakeUnion(const BinaryOperator *BinOp, + const APInt &LHS, const APInt &RHS) { + bool SkipOperation = false; + bool Unsupported = false; + APInt Result = + calculateBinaryOperator(BinOp, LHS, RHS, SkipOperation, Unsupported); + if (Unsupported) + return false; + // If SkipOperation is true, we can ignore this operand pair (L, R). + if (!SkipOperation) + unionAssumed(Result); + return isValidState(); + } + + ChangeStatus updateWithICmpInst(Attributor &A, ICmpInst *ICI) { + auto AssumedBefore = getAssumed(); + Value *LHS = ICI->getOperand(0); + Value *RHS = ICI->getOperand(1); + if (!LHS->getType()->isIntegerTy() || !RHS->getType()->isIntegerTy()) + return indicatePessimisticFixpoint(); + + auto &LHSAA = A.getAAFor<AAPotentialValues>(*this, IRPosition::value(*LHS)); + if (!LHSAA.isValidState()) + return indicatePessimisticFixpoint(); + + auto &RHSAA = A.getAAFor<AAPotentialValues>(*this, IRPosition::value(*RHS)); + if (!RHSAA.isValidState()) + return indicatePessimisticFixpoint(); + + const DenseSet<APInt> &LHSAAPVS = LHSAA.getAssumedSet(); + const DenseSet<APInt> &RHSAAPVS = RHSAA.getAssumedSet(); + + // TODO: make use of undef flag to limit potential values aggressively. + bool MaybeTrue = false, MaybeFalse = false; + const APInt Zero(RHS->getType()->getIntegerBitWidth(), 0); + if (LHSAA.undefIsContained() && RHSAA.undefIsContained()) { + // The result of any comparison between undefs can be soundly replaced + // with undef. 
+ unionAssumedWithUndef(); + } else if (LHSAA.undefIsContained()) { + bool MaybeTrue = false, MaybeFalse = false; + for (const APInt &R : RHSAAPVS) { + bool CmpResult = calculateICmpInst(ICI, Zero, R); + MaybeTrue |= CmpResult; + MaybeFalse |= !CmpResult; + if (MaybeTrue & MaybeFalse) + return indicatePessimisticFixpoint(); + } + } else if (RHSAA.undefIsContained()) { + for (const APInt &L : LHSAAPVS) { + bool CmpResult = calculateICmpInst(ICI, L, Zero); + MaybeTrue |= CmpResult; + MaybeFalse |= !CmpResult; + if (MaybeTrue & MaybeFalse) + return indicatePessimisticFixpoint(); + } + } else { + for (const APInt &L : LHSAAPVS) { + for (const APInt &R : RHSAAPVS) { + bool CmpResult = calculateICmpInst(ICI, L, R); + MaybeTrue |= CmpResult; + MaybeFalse |= !CmpResult; + if (MaybeTrue & MaybeFalse) + return indicatePessimisticFixpoint(); + } + } + } + if (MaybeTrue) + unionAssumed(APInt(/* numBits */ 1, /* val */ 1)); + if (MaybeFalse) + unionAssumed(APInt(/* numBits */ 1, /* val */ 0)); + return AssumedBefore == getAssumed() ? ChangeStatus::UNCHANGED + : ChangeStatus::CHANGED; + } + + ChangeStatus updateWithSelectInst(Attributor &A, SelectInst *SI) { + auto AssumedBefore = getAssumed(); + Value *LHS = SI->getTrueValue(); + Value *RHS = SI->getFalseValue(); + if (!LHS->getType()->isIntegerTy() || !RHS->getType()->isIntegerTy()) + return indicatePessimisticFixpoint(); + + // TODO: Use assumed simplified condition value + auto &LHSAA = A.getAAFor<AAPotentialValues>(*this, IRPosition::value(*LHS)); + if (!LHSAA.isValidState()) + return indicatePessimisticFixpoint(); + + auto &RHSAA = A.getAAFor<AAPotentialValues>(*this, IRPosition::value(*RHS)); + if (!RHSAA.isValidState()) + return indicatePessimisticFixpoint(); + + if (LHSAA.undefIsContained() && RHSAA.undefIsContained()) + // select i1 *, undef , undef => undef + unionAssumedWithUndef(); + else { + unionAssumed(LHSAA); + unionAssumed(RHSAA); + } + return AssumedBefore == getAssumed() ? ChangeStatus::UNCHANGED + : ChangeStatus::CHANGED; + } + + ChangeStatus updateWithCastInst(Attributor &A, CastInst *CI) { + auto AssumedBefore = getAssumed(); + if (!CI->isIntegerCast()) + return indicatePessimisticFixpoint(); + assert(CI->getNumOperands() == 1 && "Expected cast to be unary!"); + uint32_t ResultBitWidth = CI->getDestTy()->getIntegerBitWidth(); + Value *Src = CI->getOperand(0); + auto &SrcAA = A.getAAFor<AAPotentialValues>(*this, IRPosition::value(*Src)); + if (!SrcAA.isValidState()) + return indicatePessimisticFixpoint(); + const DenseSet<APInt> &SrcAAPVS = SrcAA.getAssumedSet(); + if (SrcAA.undefIsContained()) + unionAssumedWithUndef(); + else { + for (const APInt &S : SrcAAPVS) { + APInt T = calculateCastInst(CI, S, ResultBitWidth); + unionAssumed(T); + } + } + return AssumedBefore == getAssumed() ? 
ChangeStatus::UNCHANGED + : ChangeStatus::CHANGED; + } + + ChangeStatus updateWithBinaryOperator(Attributor &A, BinaryOperator *BinOp) { + auto AssumedBefore = getAssumed(); + Value *LHS = BinOp->getOperand(0); + Value *RHS = BinOp->getOperand(1); + if (!LHS->getType()->isIntegerTy() || !RHS->getType()->isIntegerTy()) + return indicatePessimisticFixpoint(); + + auto &LHSAA = A.getAAFor<AAPotentialValues>(*this, IRPosition::value(*LHS)); + if (!LHSAA.isValidState()) + return indicatePessimisticFixpoint(); + + auto &RHSAA = A.getAAFor<AAPotentialValues>(*this, IRPosition::value(*RHS)); + if (!RHSAA.isValidState()) + return indicatePessimisticFixpoint(); + + const DenseSet<APInt> &LHSAAPVS = LHSAA.getAssumedSet(); + const DenseSet<APInt> &RHSAAPVS = RHSAA.getAssumedSet(); + const APInt Zero = APInt(LHS->getType()->getIntegerBitWidth(), 0); + + // TODO: make use of undef flag to limit potential values aggressively. + if (LHSAA.undefIsContained() && RHSAA.undefIsContained()) { + if (!calculateBinaryOperatorAndTakeUnion(BinOp, Zero, Zero)) + return indicatePessimisticFixpoint(); + } else if (LHSAA.undefIsContained()) { + for (const APInt &R : RHSAAPVS) { + if (!calculateBinaryOperatorAndTakeUnion(BinOp, Zero, R)) + return indicatePessimisticFixpoint(); + } + } else if (RHSAA.undefIsContained()) { + for (const APInt &L : LHSAAPVS) { + if (!calculateBinaryOperatorAndTakeUnion(BinOp, L, Zero)) + return indicatePessimisticFixpoint(); + } + } else { + for (const APInt &L : LHSAAPVS) { + for (const APInt &R : RHSAAPVS) { + if (!calculateBinaryOperatorAndTakeUnion(BinOp, L, R)) + return indicatePessimisticFixpoint(); + } + } + } + return AssumedBefore == getAssumed() ? ChangeStatus::UNCHANGED + : ChangeStatus::CHANGED; + } + + ChangeStatus updateWithPHINode(Attributor &A, PHINode *PHI) { + auto AssumedBefore = getAssumed(); + for (unsigned u = 0, e = PHI->getNumIncomingValues(); u < e; u++) { + Value *IncomingValue = PHI->getIncomingValue(u); + auto &PotentialValuesAA = A.getAAFor<AAPotentialValues>( + *this, IRPosition::value(*IncomingValue)); + if (!PotentialValuesAA.isValidState()) + return indicatePessimisticFixpoint(); + if (PotentialValuesAA.undefIsContained()) + unionAssumedWithUndef(); + else + unionAssumed(PotentialValuesAA.getAssumed()); + } + return AssumedBefore == getAssumed() ? ChangeStatus::UNCHANGED + : ChangeStatus::CHANGED; + } + + /// See AbstractAttribute::updateImpl(...). + ChangeStatus updateImpl(Attributor &A) override { + Value &V = getAssociatedValue(); + Instruction *I = dyn_cast<Instruction>(&V); + + if (auto *ICI = dyn_cast<ICmpInst>(I)) + return updateWithICmpInst(A, ICI); + + if (auto *SI = dyn_cast<SelectInst>(I)) + return updateWithSelectInst(A, SI); + + if (auto *CI = dyn_cast<CastInst>(I)) + return updateWithCastInst(A, CI); + + if (auto *BinOp = dyn_cast<BinaryOperator>(I)) + return updateWithBinaryOperator(A, BinOp); + + if (auto *PHI = dyn_cast<PHINode>(I)) + return updateWithPHINode(A, PHI); + + return indicatePessimisticFixpoint(); + } + + /// See AbstractAttribute::trackStatistics() + void trackStatistics() const override { + STATS_DECLTRACK_FLOATING_ATTR(potential_values) + } +}; + +struct AAPotentialValuesFunction : AAPotentialValuesImpl { + AAPotentialValuesFunction(const IRPosition &IRP, Attributor &A) + : AAPotentialValuesImpl(IRP, A) {} + + /// See AbstractAttribute::initialize(...). 
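
// The updateWith* transfer functions above (icmp, select, cast, binary operator, phi)
// all follow one scheme: enumerate the operands' potential constant sets, union the
// pointwise results, and handle a possibly-undef operand separately. A toy model for
// an add (plain C++, illustrative names; the real state works on APInt, tracks an
// undef flag, and falls back to the pessimistic state once the set grows past a fixed
// limit):

#include <cassert>
#include <set>

static std::set<int> addSets(const std::set<int> &L, const std::set<int> &R) {
  std::set<int> Out;
  for (int A : L)
    for (int B : R)
      Out.insert(A + B); // image of '+' over the cartesian product
  return Out;
}

int main() {
  std::set<int> L{1, 2}, R{10, 20};
  assert(addSets(L, R) == (std::set<int>{11, 12, 21, 22}));
  return 0;
}
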
+ ChangeStatus updateImpl(Attributor &A) override { + llvm_unreachable("AAPotentialValues(Function|CallSite)::updateImpl will " + "not be called"); + } + + /// See AbstractAttribute::trackStatistics() + void trackStatistics() const override { + STATS_DECLTRACK_FN_ATTR(potential_values) + } +}; + +struct AAPotentialValuesCallSite : AAPotentialValuesFunction { + AAPotentialValuesCallSite(const IRPosition &IRP, Attributor &A) + : AAPotentialValuesFunction(IRP, A) {} + + /// See AbstractAttribute::trackStatistics() + void trackStatistics() const override { + STATS_DECLTRACK_CS_ATTR(potential_values) + } +}; + +struct AAPotentialValuesCallSiteReturned + : AACallSiteReturnedFromReturned<AAPotentialValues, AAPotentialValuesImpl> { + AAPotentialValuesCallSiteReturned(const IRPosition &IRP, Attributor &A) + : AACallSiteReturnedFromReturned<AAPotentialValues, + AAPotentialValuesImpl>(IRP, A) {} + + /// See AbstractAttribute::trackStatistics() + void trackStatistics() const override { + STATS_DECLTRACK_CSRET_ATTR(potential_values) + } +}; + +struct AAPotentialValuesCallSiteArgument : AAPotentialValuesFloating { + AAPotentialValuesCallSiteArgument(const IRPosition &IRP, Attributor &A) + : AAPotentialValuesFloating(IRP, A) {} + + /// See AbstractAttribute::initialize(..). + void initialize(Attributor &A) override { + Value &V = getAssociatedValue(); + + if (auto *C = dyn_cast<ConstantInt>(&V)) { + unionAssumed(C->getValue()); + indicateOptimisticFixpoint(); + return; + } + + if (isa<UndefValue>(&V)) { + unionAssumedWithUndef(); + indicateOptimisticFixpoint(); + return; + } + } + + /// See AbstractAttribute::updateImpl(...). + ChangeStatus updateImpl(Attributor &A) override { + Value &V = getAssociatedValue(); + auto AssumedBefore = getAssumed(); + auto &AA = A.getAAFor<AAPotentialValues>(*this, IRPosition::value(V)); + const auto &S = AA.getAssumed(); + unionAssumed(S); + return AssumedBefore == getAssumed() ? ChangeStatus::UNCHANGED + : ChangeStatus::CHANGED; + } + + /// See AbstractAttribute::trackStatistics() + void trackStatistics() const override { + STATS_DECLTRACK_CSARG_ATTR(potential_values) + } +}; + +/// ------------------------ NoUndef Attribute --------------------------------- +struct AANoUndefImpl : AANoUndef { + AANoUndefImpl(const IRPosition &IRP, Attributor &A) : AANoUndef(IRP, A) {} + + /// See AbstractAttribute::initialize(...). + void initialize(Attributor &A) override { + if (getIRPosition().hasAttr({Attribute::NoUndef})) { + indicateOptimisticFixpoint(); + return; + } + Value &V = getAssociatedValue(); + if (isa<UndefValue>(V)) + indicatePessimisticFixpoint(); + else if (isa<FreezeInst>(V)) + indicateOptimisticFixpoint(); + else if (getPositionKind() != IRPosition::IRP_RETURNED && + isGuaranteedNotToBeUndefOrPoison(&V)) + indicateOptimisticFixpoint(); + else + AANoUndef::initialize(A); + } + + /// See followUsesInMBEC + bool followUseInMBEC(Attributor &A, const Use *U, const Instruction *I, + AANoUndef::StateType &State) { + const Value *UseV = U->get(); + const DominatorTree *DT = nullptr; + AssumptionCache *AC = nullptr; + InformationCache &InfoCache = A.getInfoCache(); + if (Function *F = getAnchorScope()) { + DT = InfoCache.getAnalysisResultForFunction<DominatorTreeAnalysis>(*F); + AC = InfoCache.getAnalysisResultForFunction<AssumptionAnalysis>(*F); + } + State.setKnown(isGuaranteedNotToBeUndefOrPoison(UseV, AC, I, DT)); + bool TrackUse = false; + // Track use for instructions which must produce undef or poison bits when + // at least one operand contains such bits. 
+ if (isa<CastInst>(*I) || isa<GetElementPtrInst>(*I)) + TrackUse = true; + return TrackUse; + } + + /// See AbstractAttribute::getAsStr(). + const std::string getAsStr() const override { + return getAssumed() ? "noundef" : "may-undef-or-poison"; + } + + ChangeStatus manifest(Attributor &A) override { + // We don't manifest noundef attribute for dead positions because the + // associated values with dead positions would be replaced with undef + // values. + if (A.isAssumedDead(getIRPosition(), nullptr, nullptr)) + return ChangeStatus::UNCHANGED; + // A position whose simplified value does not have any value is + // considered to be dead. We don't manifest noundef in such positions for + // the same reason above. + auto &ValueSimplifyAA = A.getAAFor<AAValueSimplify>( + *this, getIRPosition(), /* TrackDependence */ false); + if (!ValueSimplifyAA.getAssumedSimplifiedValue(A).hasValue()) + return ChangeStatus::UNCHANGED; + return AANoUndef::manifest(A); + } +}; + +struct AANoUndefFloating : public AANoUndefImpl { + AANoUndefFloating(const IRPosition &IRP, Attributor &A) + : AANoUndefImpl(IRP, A) {} + + /// See AbstractAttribute::initialize(...). + void initialize(Attributor &A) override { + AANoUndefImpl::initialize(A); + if (!getState().isAtFixpoint()) + if (Instruction *CtxI = getCtxI()) + followUsesInMBEC(*this, A, getState(), *CtxI); + } + + /// See AbstractAttribute::updateImpl(...). + ChangeStatus updateImpl(Attributor &A) override { + auto VisitValueCB = [&](Value &V, const Instruction *CtxI, + AANoUndef::StateType &T, bool Stripped) -> bool { + const auto &AA = A.getAAFor<AANoUndef>(*this, IRPosition::value(V)); + if (!Stripped && this == &AA) { + T.indicatePessimisticFixpoint(); + } else { + const AANoUndef::StateType &S = + static_cast<const AANoUndef::StateType &>(AA.getState()); + T ^= S; + } + return T.isValidState(); + }; + + StateType T; + if (!genericValueTraversal<AANoUndef, StateType>( + A, getIRPosition(), *this, T, VisitValueCB, getCtxI())) + return indicatePessimisticFixpoint(); + + return clampStateAndIndicateChange(getState(), T); + } + + /// See AbstractAttribute::trackStatistics() + void trackStatistics() const override { STATS_DECLTRACK_FNRET_ATTR(noundef) } +}; + +struct AANoUndefReturned final + : AAReturnedFromReturnedValues<AANoUndef, AANoUndefImpl> { + AANoUndefReturned(const IRPosition &IRP, Attributor &A) + : AAReturnedFromReturnedValues<AANoUndef, AANoUndefImpl>(IRP, A) {} + + /// See AbstractAttribute::trackStatistics() + void trackStatistics() const override { STATS_DECLTRACK_FNRET_ATTR(noundef) } +}; + +struct AANoUndefArgument final + : AAArgumentFromCallSiteArguments<AANoUndef, AANoUndefImpl> { + AANoUndefArgument(const IRPosition &IRP, Attributor &A) + : AAArgumentFromCallSiteArguments<AANoUndef, AANoUndefImpl>(IRP, A) {} + + /// See AbstractAttribute::trackStatistics() + void trackStatistics() const override { STATS_DECLTRACK_ARG_ATTR(noundef) } +}; + +struct AANoUndefCallSiteArgument final : AANoUndefFloating { + AANoUndefCallSiteArgument(const IRPosition &IRP, Attributor &A) + : AANoUndefFloating(IRP, A) {} + + /// See AbstractAttribute::trackStatistics() + void trackStatistics() const override { STATS_DECLTRACK_CSARG_ATTR(noundef) } +}; + +struct AANoUndefCallSiteReturned final + : AACallSiteReturnedFromReturned<AANoUndef, AANoUndefImpl> { + AANoUndefCallSiteReturned(const IRPosition &IRP, Attributor &A) + : AACallSiteReturnedFromReturned<AANoUndef, AANoUndefImpl>(IRP, A) {} + + /// See AbstractAttribute::trackStatistics() + void trackStatistics() 
const override { STATS_DECLTRACK_CSRET_ATTR(noundef) } +}; } // namespace const char AAReturnedValues::ID = 0; @@ -7096,6 +8048,8 @@ const char AAPrivatizablePtr::ID = 0; const char AAMemoryBehavior::ID = 0; const char AAMemoryLocation::ID = 0; const char AAValueConstantRange::ID = 0; +const char AAPotentialValues::ID = 0; +const char AANoUndef::ID = 0; // Macro magic to create the static generator function for attributes that // follow the naming scheme. @@ -7205,6 +8159,8 @@ CREATE_VALUE_ABSTRACT_ATTRIBUTE_FOR_POSITION(AADereferenceable) CREATE_VALUE_ABSTRACT_ATTRIBUTE_FOR_POSITION(AAAlign) CREATE_VALUE_ABSTRACT_ATTRIBUTE_FOR_POSITION(AANoCapture) CREATE_VALUE_ABSTRACT_ATTRIBUTE_FOR_POSITION(AAValueConstantRange) +CREATE_VALUE_ABSTRACT_ATTRIBUTE_FOR_POSITION(AAPotentialValues) +CREATE_VALUE_ABSTRACT_ATTRIBUTE_FOR_POSITION(AANoUndef) CREATE_ALL_ABSTRACT_ATTRIBUTE_FOR_POSITION(AAValueSimplify) CREATE_ALL_ABSTRACT_ATTRIBUTE_FOR_POSITION(AAIsDead) diff --git a/contrib/llvm-project/llvm/lib/Transforms/IPO/BlockExtractor.cpp b/contrib/llvm-project/llvm/lib/Transforms/IPO/BlockExtractor.cpp index 1d1300c6cd1d..c6e222a096eb 100644 --- a/contrib/llvm-project/llvm/lib/Transforms/IPO/BlockExtractor.cpp +++ b/contrib/llvm-project/llvm/lib/Transforms/IPO/BlockExtractor.cpp @@ -11,10 +11,12 @@ // //===----------------------------------------------------------------------===// +#include "llvm/Transforms/IPO/BlockExtractor.h" #include "llvm/ADT/STLExtras.h" #include "llvm/ADT/Statistic.h" #include "llvm/IR/Instructions.h" #include "llvm/IR/Module.h" +#include "llvm/IR/PassManager.h" #include "llvm/InitializePasses.h" #include "llvm/Pass.h" #include "llvm/Support/CommandLine.h" @@ -38,13 +40,10 @@ cl::opt<bool> BlockExtractorEraseFuncs("extract-blocks-erase-funcs", cl::desc("Erase the existing functions"), cl::Hidden); namespace { -class BlockExtractor : public ModulePass { - SmallVector<SmallVector<BasicBlock *, 16>, 4> GroupsOfBlocks; - bool EraseFunctions; - /// Map a function name to groups of blocks. - SmallVector<std::pair<std::string, SmallVector<std::string, 4>>, 4> - BlocksByName; - +class BlockExtractor { +public: + BlockExtractor(bool EraseFunctions) : EraseFunctions(EraseFunctions) {} + bool runOnModule(Module &M); void init(const SmallVectorImpl<SmallVector<BasicBlock *, 16>> &GroupsOfBlocksToExtract) { for (const SmallVectorImpl<BasicBlock *> &GroupOfBlocks : @@ -57,11 +56,26 @@ class BlockExtractor : public ModulePass { loadFile(); } +private: + SmallVector<SmallVector<BasicBlock *, 16>, 4> GroupsOfBlocks; + bool EraseFunctions; + /// Map a function name to groups of blocks. + SmallVector<std::pair<std::string, SmallVector<std::string, 4>>, 4> + BlocksByName; + + void loadFile(); + void splitLandingPadPreds(Function &F); +}; + +class BlockExtractorLegacyPass : public ModulePass { + BlockExtractor BE; + bool runOnModule(Module &M) override; + public: static char ID; - BlockExtractor(const SmallVectorImpl<BasicBlock *> &BlocksToExtract, - bool EraseFunctions) - : ModulePass(ID), EraseFunctions(EraseFunctions) { + BlockExtractorLegacyPass(const SmallVectorImpl<BasicBlock *> &BlocksToExtract, + bool EraseFunctions) + : ModulePass(ID), BE(EraseFunctions) { // We want one group per element of the input list. 
SmallVector<SmallVector<BasicBlock *, 16>, 4> MassagedGroupsOfBlocks; for (BasicBlock *BB : BlocksToExtract) { @@ -69,39 +83,38 @@ public: NewGroup.push_back(BB); MassagedGroupsOfBlocks.push_back(NewGroup); } - init(MassagedGroupsOfBlocks); + BE.init(MassagedGroupsOfBlocks); } - BlockExtractor(const SmallVectorImpl<SmallVector<BasicBlock *, 16>> - &GroupsOfBlocksToExtract, - bool EraseFunctions) - : ModulePass(ID), EraseFunctions(EraseFunctions) { - init(GroupsOfBlocksToExtract); + BlockExtractorLegacyPass(const SmallVectorImpl<SmallVector<BasicBlock *, 16>> + &GroupsOfBlocksToExtract, + bool EraseFunctions) + : ModulePass(ID), BE(EraseFunctions) { + BE.init(GroupsOfBlocksToExtract); } - BlockExtractor() : BlockExtractor(SmallVector<BasicBlock *, 0>(), false) {} - bool runOnModule(Module &M) override; - -private: - void loadFile(); - void splitLandingPadPreds(Function &F); + BlockExtractorLegacyPass() + : BlockExtractorLegacyPass(SmallVector<BasicBlock *, 0>(), false) {} }; + } // end anonymous namespace -char BlockExtractor::ID = 0; -INITIALIZE_PASS(BlockExtractor, "extract-blocks", +char BlockExtractorLegacyPass::ID = 0; +INITIALIZE_PASS(BlockExtractorLegacyPass, "extract-blocks", "Extract basic blocks from module", false, false) -ModulePass *llvm::createBlockExtractorPass() { return new BlockExtractor(); } +ModulePass *llvm::createBlockExtractorPass() { + return new BlockExtractorLegacyPass(); +} ModulePass *llvm::createBlockExtractorPass( const SmallVectorImpl<BasicBlock *> &BlocksToExtract, bool EraseFunctions) { - return new BlockExtractor(BlocksToExtract, EraseFunctions); + return new BlockExtractorLegacyPass(BlocksToExtract, EraseFunctions); } ModulePass *llvm::createBlockExtractorPass( const SmallVectorImpl<SmallVector<BasicBlock *, 16>> &GroupsOfBlocksToExtract, bool EraseFunctions) { - return new BlockExtractor(GroupsOfBlocksToExtract, EraseFunctions); + return new BlockExtractorLegacyPass(GroupsOfBlocksToExtract, EraseFunctions); } /// Gets all of the blocks specified in the input file. @@ -233,3 +246,15 @@ bool BlockExtractor::runOnModule(Module &M) { return Changed; } + +bool BlockExtractorLegacyPass::runOnModule(Module &M) { + return BE.runOnModule(M); +} + +PreservedAnalyses BlockExtractorPass::run(Module &M, + ModuleAnalysisManager &AM) { + BlockExtractor BE(false); + BE.init(SmallVector<SmallVector<BasicBlock *, 16>, 0>()); + return BE.runOnModule(M) ? 
PreservedAnalyses::none() + : PreservedAnalyses::all(); +} diff --git a/contrib/llvm-project/llvm/lib/Transforms/IPO/DeadArgumentElimination.cpp b/contrib/llvm-project/llvm/lib/Transforms/IPO/DeadArgumentElimination.cpp index f2588938d964..0b763e423fe0 100644 --- a/contrib/llvm-project/llvm/lib/Transforms/IPO/DeadArgumentElimination.cpp +++ b/contrib/llvm-project/llvm/lib/Transforms/IPO/DeadArgumentElimination.cpp @@ -289,7 +289,7 @@ bool DeadArgumentEliminationPass::RemoveDeadArgumentsFromCallers(Function &Fn) { for (Argument &Arg : Fn.args()) { if (!Arg.hasSwiftErrorAttr() && Arg.use_empty() && - !Arg.hasPassPointeeByValueAttr()) { + !Arg.hasPassPointeeByValueCopyAttr()) { if (Arg.isUsedByMetadata()) { Arg.replaceAllUsesWith(UndefValue::get(Arg.getType())); Changed = true; diff --git a/contrib/llvm-project/llvm/lib/Transforms/IPO/ForceFunctionAttrs.cpp b/contrib/llvm-project/llvm/lib/Transforms/IPO/ForceFunctionAttrs.cpp index 2cb184e8d4f4..1a8bb225a626 100644 --- a/contrib/llvm-project/llvm/lib/Transforms/IPO/ForceFunctionAttrs.cpp +++ b/contrib/llvm-project/llvm/lib/Transforms/IPO/ForceFunctionAttrs.cpp @@ -26,6 +26,13 @@ static cl::list<std::string> "example -force-attribute=foo:noinline. This " "option can be specified multiple times.")); +static cl::list<std::string> ForceRemoveAttributes( + "force-remove-attribute", cl::Hidden, + cl::desc("Remove an attribute from a function. This should be a " + "pair of 'function-name:attribute-name', for " + "example -force-remove-attribute=foo:noinline. This " + "option can be specified multiple times.")); + static Attribute::AttrKind parseAttrKind(StringRef Kind) { return StringSwitch<Attribute::AttrKind>(Kind) .Case("alwaysinline", Attribute::AlwaysInline) @@ -70,31 +77,49 @@ static Attribute::AttrKind parseAttrKind(StringRef Kind) { } /// If F has any forced attributes given on the command line, add them. -static void addForcedAttributes(Function &F) { - for (auto &S : ForceAttributes) { +/// If F has any forced remove attributes given on the command line, remove +/// them. When both force and force-remove are given to a function, the latter +/// takes precedence. +static void forceAttributes(Function &F) { + auto ParseFunctionAndAttr = [&](StringRef S) { + auto Kind = Attribute::None; auto KV = StringRef(S).split(':'); if (KV.first != F.getName()) - continue; - - auto Kind = parseAttrKind(KV.second); + return Kind; + Kind = parseAttrKind(KV.second); if (Kind == Attribute::None) { LLVM_DEBUG(dbgs() << "ForcedAttribute: " << KV.second << " unknown or not handled!\n"); - continue; } - if (F.hasFnAttribute(Kind)) + return Kind; + }; + + for (auto &S : ForceAttributes) { + auto Kind = ParseFunctionAndAttr(S); + if (Kind == Attribute::None || F.hasFnAttribute(Kind)) continue; F.addFnAttr(Kind); } + + for (auto &S : ForceRemoveAttributes) { + auto Kind = ParseFunctionAndAttr(S); + if (Kind == Attribute::None || !F.hasFnAttribute(Kind)) + continue; + F.removeFnAttr(Kind); + } +} + +static bool hasForceAttributes() { + return !ForceAttributes.empty() || !ForceRemoveAttributes.empty(); } PreservedAnalyses ForceFunctionAttrsPass::run(Module &M, ModuleAnalysisManager &) { - if (ForceAttributes.empty()) + if (!hasForceAttributes()) return PreservedAnalyses::all(); for (Function &F : M.functions()) - addForcedAttributes(F); + forceAttributes(F); // Just conservatively invalidate analyses, this isn't likely to be important. 
return PreservedAnalyses::none(); @@ -109,11 +134,11 @@ struct ForceFunctionAttrsLegacyPass : public ModulePass { } bool runOnModule(Module &M) override { - if (ForceAttributes.empty()) + if (!hasForceAttributes()) return false; for (Function &F : M.functions()) - addForcedAttributes(F); + forceAttributes(F); // Conservatively assume we changed something. return true; diff --git a/contrib/llvm-project/llvm/lib/Transforms/IPO/FunctionAttrs.cpp b/contrib/llvm-project/llvm/lib/Transforms/IPO/FunctionAttrs.cpp index 4baeaa6e1630..30a1f81ad0e1 100644 --- a/contrib/llvm-project/llvm/lib/Transforms/IPO/FunctionAttrs.cpp +++ b/contrib/llvm-project/llvm/lib/Transforms/IPO/FunctionAttrs.cpp @@ -13,15 +13,16 @@ //===----------------------------------------------------------------------===// #include "llvm/Transforms/IPO/FunctionAttrs.h" +#include "llvm/ADT/ArrayRef.h" #include "llvm/ADT/SCCIterator.h" #include "llvm/ADT/STLExtras.h" #include "llvm/ADT/SetVector.h" #include "llvm/ADT/SmallPtrSet.h" #include "llvm/ADT/SmallVector.h" #include "llvm/ADT/Statistic.h" -#include "llvm/Analysis/AliasAnalysis.h" #include "llvm/Analysis/AssumptionCache.h" #include "llvm/Analysis/BasicAliasAnalysis.h" +#include "llvm/Analysis/CFG.h" #include "llvm/Analysis/CGSCCPassManager.h" #include "llvm/Analysis/CallGraph.h" #include "llvm/Analysis/CallGraphSCCPass.h" @@ -63,7 +64,7 @@ using namespace llvm; -#define DEBUG_TYPE "functionattrs" +#define DEBUG_TYPE "function-attrs" STATISTIC(NumReadNone, "Number of functions marked readnone"); STATISTIC(NumReadOnly, "Number of functions marked readonly"); @@ -77,6 +78,7 @@ STATISTIC(NumNonNullReturn, "Number of function returns marked nonnull"); STATISTIC(NumNoRecurse, "Number of functions marked as norecurse"); STATISTIC(NumNoUnwind, "Number of functions marked as nounwind"); STATISTIC(NumNoFree, "Number of functions marked as nofree"); +STATISTIC(NumWillReturn, "Number of functions marked as willreturn"); static cl::opt<bool> EnableNonnullArgPropagation( "enable-nonnull-arg-prop", cl::init(true), cl::Hidden, @@ -166,7 +168,7 @@ static MemoryAccessKind checkFunctionMemoryAccess(Function &F, bool ThisBody, AAMDNodes AAInfo; I->getAAMetadata(AAInfo); - MemoryLocation Loc(Arg, LocationSize::unknown(), AAInfo); + MemoryLocation Loc = MemoryLocation::getBeforeOrAfter(Arg, AAInfo); // Skip accesses to local or constant memory as they don't impact the // externally visible mod/ref behavior. @@ -281,16 +283,18 @@ static bool addReadAttrs(const SCCNodeSet &SCCNodes, AARGetterT &&AARGetter) { MadeChange = true; // Clear out any existing attributes. - F->removeFnAttr(Attribute::ReadOnly); - F->removeFnAttr(Attribute::ReadNone); - F->removeFnAttr(Attribute::WriteOnly); + AttrBuilder AttrsToRemove; + AttrsToRemove.addAttribute(Attribute::ReadOnly); + AttrsToRemove.addAttribute(Attribute::ReadNone); + AttrsToRemove.addAttribute(Attribute::WriteOnly); if (!WritesMemory && !ReadsMemory) { // Clear out any "access range attributes" if readnone was deduced. - F->removeFnAttr(Attribute::ArgMemOnly); - F->removeFnAttr(Attribute::InaccessibleMemOnly); - F->removeFnAttr(Attribute::InaccessibleMemOrArgMemOnly); + AttrsToRemove.addAttribute(Attribute::ArgMemOnly); + AttrsToRemove.addAttribute(Attribute::InaccessibleMemOnly); + AttrsToRemove.addAttribute(Attribute::InaccessibleMemOrArgMemOnly); } + F->removeAttributes(AttributeList::FunctionIndex, AttrsToRemove); // Add in the new attribute. 
if (WritesMemory && !ReadsMemory) @@ -639,7 +643,7 @@ static bool addArgumentAttrsFromCallsites(Function &F) { if (auto *CB = dyn_cast<CallBase>(&I)) { if (auto *CalledFunc = CB->getCalledFunction()) { for (auto &CSArg : CalledFunc->args()) { - if (!CSArg.hasNonNullAttr()) + if (!CSArg.hasNonNullAttr(/* AllowUndefOrPoison */ false)) continue; // If the non-null callsite argument operand is an argument to 'F' @@ -1216,6 +1220,11 @@ bool AttributeInferer::run(const SCCNodeSet &SCCNodes) { return Changed; } +struct SCCNodesResult { + SCCNodeSet SCCNodes; + bool HasUnknownCall; +}; + } // end anonymous namespace /// Helper for non-Convergent inference predicate InstrBreaksAttribute. @@ -1237,7 +1246,7 @@ static bool InstrBreaksNonThrowing(Instruction &I, const SCCNodeSet &SCCNodes) { // I is a may-throw call to a function inside our SCC. This doesn't // invalidate our current working assumption that the SCC is no-throw; we // just have to scan that other function. - if (SCCNodes.count(Callee) > 0) + if (SCCNodes.contains(Callee)) return false; } } @@ -1257,21 +1266,16 @@ static bool InstrBreaksNoFree(Instruction &I, const SCCNodeSet &SCCNodes) { if (Callee->doesNotFreeMemory()) return false; - if (SCCNodes.count(Callee) > 0) + if (SCCNodes.contains(Callee)) return false; return true; } -/// Infer attributes from all functions in the SCC by scanning every -/// instruction for compliance to the attribute assumptions. Currently it -/// does: -/// - removal of Convergent attribute -/// - addition of NoUnwind attribute +/// Attempt to remove convergent function attribute when possible. /// /// Returns true if any changes to function attributes were made. -static bool inferAttrsFromFunctionBodies(const SCCNodeSet &SCCNodes) { - +static bool inferConvergent(const SCCNodeSet &SCCNodes) { AttributeInferer AI; // Request to remove the convergent attribute from all functions in the SCC @@ -1293,6 +1297,18 @@ static bool inferAttrsFromFunctionBodies(const SCCNodeSet &SCCNodes) { F.setNotConvergent(); }, /* RequiresExactDefinition= */ false}); + // Perform all the requested attribute inference actions. + return AI.run(SCCNodes); +} + +/// Infer attributes from all functions in the SCC by scanning every +/// instruction for compliance to the attribute assumptions. Currently it +/// does: +/// - addition of NoUnwind attribute +/// +/// Returns true if any changes to function attributes were made. +static bool inferAttrsFromFunctionBodies(const SCCNodeSet &SCCNodes) { + AttributeInferer AI; if (!DisableNoUnwindInference) // Request to infer nounwind attribute for all the functions in the SCC if @@ -1343,14 +1359,6 @@ static bool inferAttrsFromFunctionBodies(const SCCNodeSet &SCCNodes) { return AI.run(SCCNodes); } -static bool setDoesNotRecurse(Function &F) { - if (F.doesNotRecurse()) - return false; - F.setDoesNotRecurse(); - ++NumNoRecurse; - return true; -} - static bool addNoRecurseAttrs(const SCCNodeSet &SCCNodes) { // Try and identify functions that do not recurse. @@ -1377,30 +1385,140 @@ static bool addNoRecurseAttrs(const SCCNodeSet &SCCNodes) { // Every call was to a non-recursive function other than this function, and // we have no indirect recursion as the SCC size is one. This function cannot // recurse. 
- return setDoesNotRecurse(*F); + F->setDoesNotRecurse(); + ++NumNoRecurse; + return true; +} + +static bool instructionDoesNotReturn(Instruction &I) { + if (auto *CB = dyn_cast<CallBase>(&I)) { + Function *Callee = CB->getCalledFunction(); + return Callee && Callee->doesNotReturn(); + } + return false; +} + +// A basic block can only return if it terminates with a ReturnInst and does not +// contain calls to noreturn functions. +static bool basicBlockCanReturn(BasicBlock &BB) { + if (!isa<ReturnInst>(BB.getTerminator())) + return false; + return none_of(BB, instructionDoesNotReturn); +} + +// Set the noreturn function attribute if possible. +static bool addNoReturnAttrs(const SCCNodeSet &SCCNodes) { + bool Changed = false; + + for (Function *F : SCCNodes) { + if (!F || !F->hasExactDefinition() || F->hasFnAttribute(Attribute::Naked) || + F->doesNotReturn()) + continue; + + // The function can return if any basic blocks can return. + // FIXME: this doesn't handle recursion or unreachable blocks. + if (none_of(*F, basicBlockCanReturn)) { + F->setDoesNotReturn(); + Changed = true; + } + } + + return Changed; +} + +static bool functionWillReturn(const Function &F) { + // Must-progress function without side-effects must return. + if (F.mustProgress() && F.onlyReadsMemory()) + return true; + + // Can only analyze functions with a definition. + if (F.isDeclaration()) + return false; + + // Functions with loops require more sophisticated analysis, as the loop + // may be infinite. For now, don't try to handle them. + SmallVector<std::pair<const BasicBlock *, const BasicBlock *>> Backedges; + FindFunctionBackedges(F, Backedges); + if (!Backedges.empty()) + return false; + + // If there are no loops, then the function is willreturn if all calls in + // it are willreturn. + return all_of(instructions(F), [](const Instruction &I) { + const auto *CB = dyn_cast<CallBase>(&I); + return !CB || CB->hasFnAttr(Attribute::WillReturn); + }); +} + +// Set the willreturn function attribute if possible. +static bool addWillReturn(const SCCNodeSet &SCCNodes) { + bool Changed = false; + + for (Function *F : SCCNodes) { + if (!F || F->willReturn() || !functionWillReturn(*F)) + continue; + + F->setWillReturn(); + NumWillReturn++; + Changed = true; + } + + return Changed; +} + +static SCCNodesResult createSCCNodeSet(ArrayRef<Function *> Functions) { + SCCNodesResult Res; + Res.HasUnknownCall = false; + for (Function *F : Functions) { + if (!F || F->hasOptNone() || F->hasFnAttribute(Attribute::Naked)) { + // Treat any function we're trying not to optimize as if it were an + // indirect call and omit it from the node set used below. + Res.HasUnknownCall = true; + continue; + } + // Track whether any functions in this SCC have an unknown call edge. + // Note: if this is ever a performance hit, we can common it with + // subsequent routines which also do scans over the instructions of the + // function. + if (!Res.HasUnknownCall) { + for (Instruction &I : instructions(*F)) { + if (auto *CB = dyn_cast<CallBase>(&I)) { + if (!CB->getCalledFunction()) { + Res.HasUnknownCall = true; + break; + } + } + } + } + Res.SCCNodes.insert(F); + } + return Res; } template <typename AARGetterT> -static bool deriveAttrsInPostOrder(SCCNodeSet &SCCNodes, - AARGetterT &&AARGetter, - bool HasUnknownCall) { +static bool deriveAttrsInPostOrder(ArrayRef<Function *> Functions, + AARGetterT &&AARGetter) { + SCCNodesResult Nodes = createSCCNodeSet(Functions); bool Changed = false; // Bail if the SCC only contains optnone functions. 
- if (SCCNodes.empty()) + if (Nodes.SCCNodes.empty()) return Changed; - Changed |= addArgumentReturnedAttrs(SCCNodes); - Changed |= addReadAttrs(SCCNodes, AARGetter); - Changed |= addArgumentAttrs(SCCNodes); + Changed |= addArgumentReturnedAttrs(Nodes.SCCNodes); + Changed |= addReadAttrs(Nodes.SCCNodes, AARGetter); + Changed |= addArgumentAttrs(Nodes.SCCNodes); + Changed |= inferConvergent(Nodes.SCCNodes); + Changed |= addNoReturnAttrs(Nodes.SCCNodes); + Changed |= addWillReturn(Nodes.SCCNodes); // If we have no external nodes participating in the SCC, we can deduce some // more precise attributes as well. - if (!HasUnknownCall) { - Changed |= addNoAliasAttrs(SCCNodes); - Changed |= addNonNullAttrs(SCCNodes); - Changed |= inferAttrsFromFunctionBodies(SCCNodes); - Changed |= addNoRecurseAttrs(SCCNodes); + if (!Nodes.HasUnknownCall) { + Changed |= addNoAliasAttrs(Nodes.SCCNodes); + Changed |= addNonNullAttrs(Nodes.SCCNodes); + Changed |= inferAttrsFromFunctionBodies(Nodes.SCCNodes); + Changed |= addNoRecurseAttrs(Nodes.SCCNodes); } return Changed; @@ -1419,35 +1537,12 @@ PreservedAnalyses PostOrderFunctionAttrsPass::run(LazyCallGraph::SCC &C, return FAM.getResult<AAManager>(F); }; - // Fill SCCNodes with the elements of the SCC. Also track whether there are - // any external or opt-none nodes that will prevent us from optimizing any - // part of the SCC. - SCCNodeSet SCCNodes; - bool HasUnknownCall = false; + SmallVector<Function *, 8> Functions; for (LazyCallGraph::Node &N : C) { - Function &F = N.getFunction(); - if (F.hasOptNone() || F.hasFnAttribute(Attribute::Naked)) { - // Treat any function we're trying not to optimize as if it were an - // indirect call and omit it from the node set used below. - HasUnknownCall = true; - continue; - } - // Track whether any functions in this SCC have an unknown call edge. - // Note: if this is ever a performance hit, we can common it with - // subsequent routines which also do scans over the instructions of the - // function. - if (!HasUnknownCall) - for (Instruction &I : instructions(F)) - if (auto *CB = dyn_cast<CallBase>(&I)) - if (!CB->getCalledFunction()) { - HasUnknownCall = true; - break; - } - - SCCNodes.insert(&F); + Functions.push_back(&N.getFunction()); } - if (deriveAttrsInPostOrder(SCCNodes, AARGetter, HasUnknownCall)) + if (deriveAttrsInPostOrder(Functions, AARGetter)) return PreservedAnalyses::none(); return PreservedAnalyses::all(); @@ -1477,11 +1572,11 @@ struct PostOrderFunctionAttrsLegacyPass : public CallGraphSCCPass { } // end anonymous namespace char PostOrderFunctionAttrsLegacyPass::ID = 0; -INITIALIZE_PASS_BEGIN(PostOrderFunctionAttrsLegacyPass, "functionattrs", +INITIALIZE_PASS_BEGIN(PostOrderFunctionAttrsLegacyPass, "function-attrs", "Deduce function attributes", false, false) INITIALIZE_PASS_DEPENDENCY(AssumptionCacheTracker) INITIALIZE_PASS_DEPENDENCY(CallGraphWrapperPass) -INITIALIZE_PASS_END(PostOrderFunctionAttrsLegacyPass, "functionattrs", +INITIALIZE_PASS_END(PostOrderFunctionAttrsLegacyPass, "function-attrs", "Deduce function attributes", false, false) Pass *llvm::createPostOrderFunctionAttrsLegacyPass() { @@ -1490,26 +1585,12 @@ Pass *llvm::createPostOrderFunctionAttrsLegacyPass() { template <typename AARGetterT> static bool runImpl(CallGraphSCC &SCC, AARGetterT AARGetter) { - - // Fill SCCNodes with the elements of the SCC. Used for quickly looking up - // whether a given CallGraphNode is in this SCC. 
Also track whether there are - // any external or opt-none nodes that will prevent us from optimizing any - // part of the SCC. - SCCNodeSet SCCNodes; - bool ExternalNode = false; + SmallVector<Function *, 8> Functions; for (CallGraphNode *I : SCC) { - Function *F = I->getFunction(); - if (!F || F->hasOptNone() || F->hasFnAttribute(Attribute::Naked)) { - // External node or function we're trying not to optimize - we both avoid - // transform them and avoid leveraging information they provide. - ExternalNode = true; - continue; - } - - SCCNodes.insert(F); + Functions.push_back(I->getFunction()); } - return deriveAttrsInPostOrder(SCCNodes, AARGetter, ExternalNode); + return deriveAttrsInPostOrder(Functions, AARGetter); } bool PostOrderFunctionAttrsLegacyPass::runOnSCC(CallGraphSCC &SCC) { @@ -1542,11 +1623,13 @@ struct ReversePostOrderFunctionAttrsLegacyPass : public ModulePass { char ReversePostOrderFunctionAttrsLegacyPass::ID = 0; -INITIALIZE_PASS_BEGIN(ReversePostOrderFunctionAttrsLegacyPass, "rpo-functionattrs", - "Deduce function attributes in RPO", false, false) +INITIALIZE_PASS_BEGIN(ReversePostOrderFunctionAttrsLegacyPass, + "rpo-function-attrs", "Deduce function attributes in RPO", + false, false) INITIALIZE_PASS_DEPENDENCY(CallGraphWrapperPass) -INITIALIZE_PASS_END(ReversePostOrderFunctionAttrsLegacyPass, "rpo-functionattrs", - "Deduce function attributes in RPO", false, false) +INITIALIZE_PASS_END(ReversePostOrderFunctionAttrsLegacyPass, + "rpo-function-attrs", "Deduce function attributes in RPO", + false, false) Pass *llvm::createReversePostOrderFunctionAttrsPass() { return new ReversePostOrderFunctionAttrsLegacyPass(); @@ -1578,7 +1661,9 @@ static bool addNoRecurseAttrsTopDown(Function &F) { if (!CB || !CB->getParent()->getParent()->doesNotRecurse()) return false; } - return setDoesNotRecurse(F); + F.setDoesNotRecurse(); + ++NumNoRecurse; + return true; } static bool deduceFunctionAttributeInRPO(Module &M, CallGraph &CG) { diff --git a/contrib/llvm-project/llvm/lib/Transforms/IPO/FunctionImport.cpp b/contrib/llvm-project/llvm/lib/Transforms/IPO/FunctionImport.cpp index 468bf19f2e48..18343030bc6a 100644 --- a/contrib/llvm-project/llvm/lib/Transforms/IPO/FunctionImport.cpp +++ b/contrib/llvm-project/llvm/lib/Transforms/IPO/FunctionImport.cpp @@ -124,14 +124,8 @@ static cl::opt<bool> ComputeDead("compute-dead", cl::init(true), cl::Hidden, cl::desc("Compute dead symbols")); static cl::opt<bool> EnableImportMetadata( - "enable-import-metadata", cl::init( -#if !defined(NDEBUG) - true /*Enabled with asserts.*/ -#else - false -#endif - ), - cl::Hidden, cl::desc("Enable import metadata like 'thinlto_src_module'")); + "enable-import-metadata", cl::init(false), cl::Hidden, + cl::desc("Enable import metadata like 'thinlto_src_module'")); /// Summary file to use for function importing when using -function-import from /// the command line. 
@@ -261,8 +255,8 @@ selectCallee(const ModuleSummaryIndex &Index, namespace { -using EdgeInfo = std::tuple<const FunctionSummary *, unsigned /* Threshold */, - GlobalValue::GUID>; +using EdgeInfo = + std::tuple<const GlobalValueSummary *, unsigned /* Threshold */>; } // anonymous namespace @@ -282,8 +276,9 @@ updateValueInfoForIndirectCalls(const ModuleSummaryIndex &Index, ValueInfo VI) { } static void computeImportForReferencedGlobals( - const FunctionSummary &Summary, const ModuleSummaryIndex &Index, + const GlobalValueSummary &Summary, const ModuleSummaryIndex &Index, const GVSummaryMapTy &DefinedGVSummaries, + SmallVectorImpl<EdgeInfo> &Worklist, FunctionImporter::ImportMapTy &ImportList, StringMap<FunctionImporter::ExportSetTy> *ExportLists) { for (auto &VI : Summary.refs()) { @@ -321,6 +316,11 @@ static void computeImportForReferencedGlobals( // which is more efficient than adding them here. if (ExportLists) (*ExportLists)[RefSummary->modulePath()].insert(VI); + + // If variable is not writeonly we attempt to recursively analyze + // its references in order to import referenced constants. + if (!Index.isWriteOnly(cast<GlobalVarSummary>(RefSummary.get()))) + Worklist.emplace_back(RefSummary.get(), 0); break; } } @@ -360,7 +360,7 @@ static void computeImportForFunction( StringMap<FunctionImporter::ExportSetTy> *ExportLists, FunctionImporter::ImportThresholdsTy &ImportThresholds) { computeImportForReferencedGlobals(Summary, Index, DefinedGVSummaries, - ImportList, ExportLists); + Worklist, ImportList, ExportLists); static int ImportCount = 0; for (auto &Edge : Summary.calls()) { ValueInfo VI = Edge.first; @@ -508,7 +508,7 @@ static void computeImportForFunction( ImportCount++; // Insert the newly imported function to the worklist. - Worklist.emplace_back(ResolvedCalleeSummary, AdjThreshold, VI.getGUID()); + Worklist.emplace_back(ResolvedCalleeSummary, AdjThreshold); } } @@ -549,13 +549,17 @@ static void ComputeImportForModule( // Process the newly imported functions and add callees to the worklist. 
while (!Worklist.empty()) { - auto FuncInfo = Worklist.pop_back_val(); - auto *Summary = std::get<0>(FuncInfo); - auto Threshold = std::get<1>(FuncInfo); - - computeImportForFunction(*Summary, Index, Threshold, DefinedGVSummaries, - Worklist, ImportList, ExportLists, - ImportThresholds); + auto GVInfo = Worklist.pop_back_val(); + auto *Summary = std::get<0>(GVInfo); + auto Threshold = std::get<1>(GVInfo); + + if (auto *FS = dyn_cast<FunctionSummary>(Summary)) + computeImportForFunction(*FS, Index, Threshold, DefinedGVSummaries, + Worklist, ImportList, ExportLists, + ImportThresholds); + else + computeImportForReferencedGlobals(*Summary, Index, DefinedGVSummaries, + Worklist, ImportList, ExportLists); } // Print stats about functions considered but rejected for importing @@ -884,6 +888,7 @@ void llvm::computeDeadSymbols( while (!Worklist.empty()) { auto VI = Worklist.pop_back_val(); for (auto &Summary : VI.getSummaryList()) { + Summary->setLive(true); if (auto *AS = dyn_cast<AliasSummary>(Summary.get())) { // If this is an alias, visit the aliasee VI to ensure that all copies // are marked live and it is added to the worklist for further @@ -891,8 +896,6 @@ void llvm::computeDeadSymbols( visit(AS->getAliaseeVI(), true); continue; } - - Summary->setLive(true); for (auto Ref : Summary->refs()) visit(Ref, false); if (auto *FS = dyn_cast<FunctionSummary>(Summary.get())) @@ -1311,7 +1314,7 @@ static bool doImportingForModule(Module &M) { // Next we need to promote to global scope and rename any local values that // are potentially exported to other modules. - if (renameModuleForThinLTO(M, *Index, /*clearDSOOnDeclarations=*/false, + if (renameModuleForThinLTO(M, *Index, /*ClearDSOLocalOnDeclarations=*/false, /*GlobalsToImport=*/nullptr)) { errs() << "Error renaming module\n"; return false; diff --git a/contrib/llvm-project/llvm/lib/Transforms/IPO/GlobalOpt.cpp b/contrib/llvm-project/llvm/lib/Transforms/IPO/GlobalOpt.cpp index 9524d9a36204..223a05e8ea02 100644 --- a/contrib/llvm-project/llvm/lib/Transforms/IPO/GlobalOpt.cpp +++ b/contrib/llvm-project/llvm/lib/Transforms/IPO/GlobalOpt.cpp @@ -268,6 +268,7 @@ CleanupPointerRootUsers(GlobalVariable *GV, I = J; } while (true); I->eraseFromParent(); + Changed = true; } } @@ -285,7 +286,7 @@ static bool CleanupConstantGlobalUsers( // we delete a constant array, we may also be holding pointer to one of its // elements (or an element of one of its elements if we're dealing with an // array of arrays) in the worklist. - SmallVector<WeakTrackingVH, 8> WorkList(V->user_begin(), V->user_end()); + SmallVector<WeakTrackingVH, 8> WorkList(V->users()); while (!WorkList.empty()) { Value *UV = WorkList.pop_back_val(); if (!UV) @@ -1879,7 +1880,8 @@ static bool isPointerValueDeadOnEntryToFunction( // and the number of bits loaded in L is less than or equal to // the number of bits stored in S. 
return DT.dominates(S, L) && - DL.getTypeStoreSize(LTy) <= DL.getTypeStoreSize(STy); + DL.getTypeStoreSize(LTy).getFixedSize() <= + DL.getTypeStoreSize(STy).getFixedSize(); })) return false; } @@ -1931,8 +1933,7 @@ static void makeAllConstantUsesInstructions(Constant *C) { SmallVector<Value*,4> UUsers; for (auto *U : Users) { UUsers.clear(); - for (auto *UU : U->users()) - UUsers.push_back(UU); + append_range(UUsers, U->users()); for (auto *UU : UUsers) { Instruction *UI = cast<Instruction>(UU); Instruction *NewU = U->getAsInstruction(); @@ -1989,12 +1990,13 @@ processInternalGlobal(GlobalVariable *GV, const GlobalStatus &GS, return true; } + bool Changed = false; + // If the global is never loaded (but may be stored to), it is dead. // Delete it now. if (!GS.IsLoaded) { LLVM_DEBUG(dbgs() << "GLOBAL NEVER LOADED: " << *GV << "\n"); - bool Changed; if (isLeakCheckerRoot(GV)) { // Delete any constant stores to the global. Changed = CleanupPointerRootUsers(GV, GetTLI); @@ -2020,11 +2022,14 @@ processInternalGlobal(GlobalVariable *GV, const GlobalStatus &GS, // Don't actually mark a global constant if it's atomic because atomic loads // are implemented by a trivial cmpxchg in some edge-cases and that usually // requires write access to the variable even if it's not actually changed. - if (GS.Ordering == AtomicOrdering::NotAtomic) + if (GS.Ordering == AtomicOrdering::NotAtomic) { + assert(!GV->isConstant() && "Expected a non-constant global"); GV->setConstant(true); + Changed = true; + } // Clean up any obviously simplifiable users now. - CleanupConstantGlobalUsers(GV, GV->getInitializer(), DL, GetTLI); + Changed |= CleanupConstantGlobalUsers(GV, GV->getInitializer(), DL, GetTLI); // If the global is dead now, just nuke it. if (GV->use_empty()) { @@ -2084,7 +2089,7 @@ processInternalGlobal(GlobalVariable *GV, const GlobalStatus &GS, } } - return false; + return Changed; } /// Analyze the specified global variable and optimize it if possible. 
If we @@ -2219,8 +2224,7 @@ isValidCandidateForColdCC(Function &F, BlockFrequencyInfo &CallerBFI = GetBFI(*CallerFunc); if (!isColdCallSite(CB, CallerBFI)) return false; - auto It = std::find(AllCallsCold.begin(), AllCallsCold.end(), CallerFunc); - if (It == AllCallsCold.end()) + if (!llvm::is_contained(AllCallsCold, CallerFunc)) return false; } return true; diff --git a/contrib/llvm-project/llvm/lib/Transforms/IPO/HotColdSplitting.cpp b/contrib/llvm-project/llvm/lib/Transforms/IPO/HotColdSplitting.cpp index d0bd0166534a..aa708ee520b1 100644 --- a/contrib/llvm-project/llvm/lib/Transforms/IPO/HotColdSplitting.cpp +++ b/contrib/llvm-project/llvm/lib/Transforms/IPO/HotColdSplitting.cpp @@ -29,7 +29,6 @@ #include "llvm/ADT/PostOrderIterator.h" #include "llvm/ADT/SmallVector.h" #include "llvm/ADT/Statistic.h" -#include "llvm/Analysis/AliasAnalysis.h" #include "llvm/Analysis/BlockFrequencyInfo.h" #include "llvm/Analysis/BranchProbabilityInfo.h" #include "llvm/Analysis/CFG.h" @@ -68,7 +67,9 @@ #include "llvm/Transforms/Utils/Local.h" #include "llvm/Transforms/Utils/ValueMapper.h" #include <algorithm> +#include <limits> #include <cassert> +#include <string> #define DEBUG_TYPE "hotcoldsplit" @@ -77,14 +78,29 @@ STATISTIC(NumColdRegionsOutlined, "Number of cold regions outlined."); using namespace llvm; -static cl::opt<bool> EnableStaticAnalyis("hot-cold-static-analysis", - cl::init(true), cl::Hidden); +static cl::opt<bool> EnableStaticAnalysis("hot-cold-static-analysis", + cl::init(true), cl::Hidden); static cl::opt<int> SplittingThreshold("hotcoldsplit-threshold", cl::init(2), cl::Hidden, cl::desc("Base penalty for splitting cold code (as a " "multiple of TCC_Basic)")); +static cl::opt<bool> EnableColdSection( + "enable-cold-section", cl::init(false), cl::Hidden, + cl::desc("Enable placement of extracted cold functions" + " into a separate section after hot-cold splitting.")); + +static cl::opt<std::string> + ColdSectionName("hotcoldsplit-cold-section-name", cl::init("__llvm_cold"), + cl::Hidden, + cl::desc("Name for the section containing cold functions " + "extracted by hot-cold splitting.")); + +static cl::opt<int> MaxParametersForSplit( + "hotcoldsplit-max-params", cl::init(4), cl::Hidden, + cl::desc("Maximum number of parameters for a split function")); + namespace { // Same as blockEndsInUnreachable in CodeGen/BranchFolding.cpp. Do not modify // this function unless you modify the MBB version as well. @@ -221,11 +237,11 @@ bool HotColdSplitting::shouldOutlineFrom(const Function &F) const { } /// Get the benefit score of outlining \p Region. -static int getOutliningBenefit(ArrayRef<BasicBlock *> Region, - TargetTransformInfo &TTI) { +static InstructionCost getOutliningBenefit(ArrayRef<BasicBlock *> Region, + TargetTransformInfo &TTI) { // Sum up the code size costs of non-terminator instructions. Tight coupling // with \ref getOutliningPenalty is needed to model the costs of terminators. - int Benefit = 0; + InstructionCost Benefit = 0; for (BasicBlock *BB : Region) for (Instruction &I : BB->instructionsWithoutDebug()) if (&I != BB->getTerminator()) @@ -246,18 +262,6 @@ static int getOutliningPenalty(ArrayRef<BasicBlock *> Region, if (SplittingThreshold <= 0) return Penalty; - // The typical code size cost for materializing an argument for the outlined - // call. 
- LLVM_DEBUG(dbgs() << "Applying penalty for: " << NumInputs << " inputs\n"); - const int CostForArgMaterialization = TargetTransformInfo::TCC_Basic; - Penalty += CostForArgMaterialization * NumInputs; - - // The typical code size cost for an output alloca, its associated store, and - // its associated reload. - LLVM_DEBUG(dbgs() << "Applying penalty for: " << NumOutputs << " outputs\n"); - const int CostForRegionOutput = 3 * TargetTransformInfo::TCC_Basic; - Penalty += CostForRegionOutput * NumOutputs; - // Find the number of distinct exit blocks for the region. Use a conservative // check to determine whether control returns from the region. bool NoBlocksReturn = true; @@ -271,13 +275,55 @@ static int getOutliningPenalty(ArrayRef<BasicBlock *> Region, } for (BasicBlock *SuccBB : successors(BB)) { - if (find(Region, SuccBB) == Region.end()) { + if (!is_contained(Region, SuccBB)) { NoBlocksReturn = false; SuccsOutsideRegion.insert(SuccBB); } } } + // Count the number of phis in exit blocks with >= 2 incoming values from the + // outlining region. These phis are split (\ref severSplitPHINodesOfExits), + // and new outputs are created to supply the split phis. CodeExtractor can't + // report these new outputs until extraction begins, but it's important to + // factor the cost of the outputs into the cost calculation. + unsigned NumSplitExitPhis = 0; + for (BasicBlock *ExitBB : SuccsOutsideRegion) { + for (PHINode &PN : ExitBB->phis()) { + // Find all incoming values from the outlining region. + int NumIncomingVals = 0; + for (unsigned i = 0; i < PN.getNumIncomingValues(); ++i) + if (find(Region, PN.getIncomingBlock(i)) != Region.end()) { + ++NumIncomingVals; + if (NumIncomingVals > 1) { + ++NumSplitExitPhis; + break; + } + } + } + } + + // Apply a penalty for calling the split function. Factor in the cost of + // materializing all of the parameters. + int NumOutputsAndSplitPhis = NumOutputs + NumSplitExitPhis; + int NumParams = NumInputs + NumOutputsAndSplitPhis; + if (NumParams > MaxParametersForSplit) { + LLVM_DEBUG(dbgs() << NumInputs << " inputs and " << NumOutputsAndSplitPhis + << " outputs exceeds parameter limit (" + << MaxParametersForSplit << ")\n"); + return std::numeric_limits<int>::max(); + } + const int CostForArgMaterialization = 2 * TargetTransformInfo::TCC_Basic; + LLVM_DEBUG(dbgs() << "Applying penalty for: " << NumParams << " params\n"); + Penalty += CostForArgMaterialization * NumParams; + + // Apply the typical code size cost for an output alloca and its associated + // reload in the caller. Also penalize the associated store in the callee. + LLVM_DEBUG(dbgs() << "Applying penalty for: " << NumOutputsAndSplitPhis + << " outputs/split phis\n"); + const int CostForRegionOutput = 3 * TargetTransformInfo::TCC_Basic; + Penalty += CostForRegionOutput * NumOutputsAndSplitPhis; + // Apply a `noreturn` bonus. if (NoBlocksReturn) { LLVM_DEBUG(dbgs() << "Applying bonus for: " << Region.size() @@ -287,7 +333,7 @@ static int getOutliningPenalty(ArrayRef<BasicBlock *> Region, // Apply a penalty for having more than one successor outside of the region. // This penalty accounts for the switch needed in the caller. - if (!SuccsOutsideRegion.empty()) { + if (SuccsOutsideRegion.size() > 1) { LLVM_DEBUG(dbgs() << "Applying penalty for: " << SuccsOutsideRegion.size() << " non-region successors\n"); Penalty += (SuccsOutsideRegion.size() - 1) * TargetTransformInfo::TCC_Basic; @@ -312,12 +358,12 @@ Function *HotColdSplitting::extractColdRegion( // splitting. 
SetVector<Value *> Inputs, Outputs, Sinks; CE.findInputsOutputs(Inputs, Outputs, Sinks); - int OutliningBenefit = getOutliningBenefit(Region, TTI); + InstructionCost OutliningBenefit = getOutliningBenefit(Region, TTI); int OutliningPenalty = getOutliningPenalty(Region, Inputs.size(), Outputs.size()); LLVM_DEBUG(dbgs() << "Split profitability: benefit = " << OutliningBenefit << ", penalty = " << OutliningPenalty << "\n"); - if (OutliningBenefit <= OutliningPenalty) + if (!OutliningBenefit.isValid() || OutliningBenefit <= OutliningPenalty) return nullptr; Function *OrigF = Region[0]->getParent(); @@ -331,8 +377,12 @@ Function *HotColdSplitting::extractColdRegion( } CI->setIsNoInline(); - if (OrigF->hasSection()) - OutF->setSection(OrigF->getSection()); + if (EnableColdSection) + OutF->setSection(ColdSectionName); + else { + if (OrigF->hasSection()) + OutF->setSection(OrigF->getSection()); + } markFunctionCold(*OutF, BFI != nullptr); @@ -575,7 +625,7 @@ bool HotColdSplitting::outlineColdRegions(Function &F, bool HasProfileSummary) { continue; bool Cold = (BFI && PSI->isColdBlock(BB, BFI)) || - (EnableStaticAnalyis && unlikelyExecuted(*BB)); + (EnableStaticAnalysis && unlikelyExecuted(*BB)); if (!Cold) continue; diff --git a/contrib/llvm-project/llvm/lib/Transforms/IPO/IPConstantPropagation.cpp b/contrib/llvm-project/llvm/lib/Transforms/IPO/IPConstantPropagation.cpp deleted file mode 100644 index 8d05a72d68da..000000000000 --- a/contrib/llvm-project/llvm/lib/Transforms/IPO/IPConstantPropagation.cpp +++ /dev/null @@ -1,308 +0,0 @@ -//===-- IPConstantPropagation.cpp - Propagate constants through calls -----===// -// -// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. -// See https://llvm.org/LICENSE.txt for license information. -// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception -// -//===----------------------------------------------------------------------===// -// -// This pass implements an _extremely_ simple interprocedural constant -// propagation pass. It could certainly be improved in many different ways, -// like using a worklist. This pass makes arguments dead, but does not remove -// them. The existing dead argument elimination pass should be run after this -// to clean up the mess. -// -//===----------------------------------------------------------------------===// - -#include "llvm/ADT/SmallVector.h" -#include "llvm/ADT/Statistic.h" -#include "llvm/Analysis/ValueTracking.h" -#include "llvm/IR/AbstractCallSite.h" -#include "llvm/IR/Constants.h" -#include "llvm/IR/Instructions.h" -#include "llvm/IR/Module.h" -#include "llvm/InitializePasses.h" -#include "llvm/Pass.h" -#include "llvm/Transforms/IPO.h" -using namespace llvm; - -#define DEBUG_TYPE "ipconstprop" - -STATISTIC(NumArgumentsProped, "Number of args turned into constants"); -STATISTIC(NumReturnValProped, "Number of return values turned into constants"); - -namespace { - /// IPCP - The interprocedural constant propagation pass - /// - struct IPCP : public ModulePass { - static char ID; // Pass identification, replacement for typeid - IPCP() : ModulePass(ID) { - initializeIPCPPass(*PassRegistry::getPassRegistry()); - } - - bool runOnModule(Module &M) override; - }; -} - -/// PropagateConstantsIntoArguments - Look at all uses of the specified -/// function. If all uses are direct call sites, and all pass a particular -/// constant in for an argument, propagate that constant in as the argument. 
-/// -static bool PropagateConstantsIntoArguments(Function &F) { - if (F.arg_empty() || F.use_empty()) return false; // No arguments? Early exit. - - // For each argument, keep track of its constant value and whether it is a - // constant or not. The bool is driven to true when found to be non-constant. - SmallVector<PointerIntPair<Constant *, 1, bool>, 16> ArgumentConstants; - ArgumentConstants.resize(F.arg_size()); - - unsigned NumNonconstant = 0; - for (Use &U : F.uses()) { - User *UR = U.getUser(); - // Ignore blockaddress uses. - if (isa<BlockAddress>(UR)) continue; - - // If no abstract call site was created we did not understand the use, bail. - AbstractCallSite ACS(&U); - if (!ACS) - return false; - - // Mismatched argument count is undefined behavior. Simply bail out to avoid - // handling of such situations below (avoiding asserts/crashes). - unsigned NumActualArgs = ACS.getNumArgOperands(); - if (F.isVarArg() ? ArgumentConstants.size() > NumActualArgs - : ArgumentConstants.size() != NumActualArgs) - return false; - - // Check out all of the potentially constant arguments. Note that we don't - // inspect varargs here. - Function::arg_iterator Arg = F.arg_begin(); - for (unsigned i = 0, e = ArgumentConstants.size(); i != e; ++i, ++Arg) { - - // If this argument is known non-constant, ignore it. - if (ArgumentConstants[i].getInt()) - continue; - - Value *V = ACS.getCallArgOperand(i); - Constant *C = dyn_cast_or_null<Constant>(V); - - // Mismatched argument type is undefined behavior. Simply bail out to avoid - // handling of such situations below (avoiding asserts/crashes). - if (C && Arg->getType() != C->getType()) - return false; - - // We can only propagate thread independent values through callbacks. - // This is different to direct/indirect call sites because for them we - // know the thread executing the caller and callee is the same. For - // callbacks this is not guaranteed, thus a thread dependent value could - // be different for the caller and callee, making it invalid to propagate. - if (C && ACS.isCallbackCall() && C->isThreadDependent()) { - // Argument became non-constant. If all arguments are non-constant now, - // give up on this function. - if (++NumNonconstant == ArgumentConstants.size()) - return false; - - ArgumentConstants[i].setInt(true); - continue; - } - - if (C && ArgumentConstants[i].getPointer() == nullptr) { - ArgumentConstants[i].setPointer(C); // First constant seen. - } else if (C && ArgumentConstants[i].getPointer() == C) { - // Still the constant value we think it is. - } else if (V == &*Arg) { - // Ignore recursive calls passing argument down. - } else { - // Argument became non-constant. If all arguments are non-constant now, - // give up on this function. - if (++NumNonconstant == ArgumentConstants.size()) - return false; - ArgumentConstants[i].setInt(true); - } - } - } - - // If we got to this point, there is a constant argument! - assert(NumNonconstant != ArgumentConstants.size()); - bool MadeChange = false; - Function::arg_iterator AI = F.arg_begin(); - for (unsigned i = 0, e = ArgumentConstants.size(); i != e; ++i, ++AI) { - // Do we have a constant argument? - if (ArgumentConstants[i].getInt() || AI->use_empty() || - (AI->hasByValAttr() && !F.onlyReadsMemory())) - continue; - - Value *V = ArgumentConstants[i].getPointer(); - if (!V) V = UndefValue::get(AI->getType()); - AI->replaceAllUsesWith(V); - ++NumArgumentsProped; - MadeChange = true; - } - return MadeChange; -} - - -// Check to see if this function returns one or more constants. 
If so, replace -// all callers that use those return values with the constant value. This will -// leave in the actual return values and instructions, but deadargelim will -// clean that up. -// -// Additionally if a function always returns one of its arguments directly, -// callers will be updated to use the value they pass in directly instead of -// using the return value. -static bool PropagateConstantReturn(Function &F) { - if (F.getReturnType()->isVoidTy()) - return false; // No return value. - - // We can infer and propagate the return value only when we know that the - // definition we'll get at link time is *exactly* the definition we see now. - // For more details, see GlobalValue::mayBeDerefined. - if (!F.isDefinitionExact()) - return false; - - // Don't touch naked functions. The may contain asm returning - // value we don't see, so we may end up interprocedurally propagating - // the return value incorrectly. - if (F.hasFnAttribute(Attribute::Naked)) - return false; - - // Check to see if this function returns a constant. - SmallVector<Value *,4> RetVals; - StructType *STy = dyn_cast<StructType>(F.getReturnType()); - if (STy) - for (unsigned i = 0, e = STy->getNumElements(); i < e; ++i) - RetVals.push_back(UndefValue::get(STy->getElementType(i))); - else - RetVals.push_back(UndefValue::get(F.getReturnType())); - - unsigned NumNonConstant = 0; - for (BasicBlock &BB : F) - if (ReturnInst *RI = dyn_cast<ReturnInst>(BB.getTerminator())) { - for (unsigned i = 0, e = RetVals.size(); i != e; ++i) { - // Already found conflicting return values? - Value *RV = RetVals[i]; - if (!RV) - continue; - - // Find the returned value - Value *V; - if (!STy) - V = RI->getOperand(0); - else - V = FindInsertedValue(RI->getOperand(0), i); - - if (V) { - // Ignore undefs, we can change them into anything - if (isa<UndefValue>(V)) - continue; - - // Try to see if all the rets return the same constant or argument. - if (isa<Constant>(V) || isa<Argument>(V)) { - if (isa<UndefValue>(RV)) { - // No value found yet? Try the current one. - RetVals[i] = V; - continue; - } - // Returning the same value? Good. - if (RV == V) - continue; - } - } - // Different or no known return value? Don't propagate this return - // value. - RetVals[i] = nullptr; - // All values non-constant? Stop looking. - if (++NumNonConstant == RetVals.size()) - return false; - } - } - - // If we got here, the function returns at least one constant value. Loop - // over all users, replacing any uses of the return value with the returned - // constant. - bool MadeChange = false; - for (Use &U : F.uses()) { - CallBase *CB = dyn_cast<CallBase>(U.getUser()); - - // Not a call instruction or a call instruction that's not calling F - // directly? - if (!CB || !CB->isCallee(&U)) - continue; - - // Call result not used? - if (CB->use_empty()) - continue; - - MadeChange = true; - - if (!STy) { - Value* New = RetVals[0]; - if (Argument *A = dyn_cast<Argument>(New)) - // Was an argument returned? Then find the corresponding argument in - // the call instruction and use that. 
- New = CB->getArgOperand(A->getArgNo()); - CB->replaceAllUsesWith(New); - continue; - } - - for (auto I = CB->user_begin(), E = CB->user_end(); I != E;) { - Instruction *Ins = cast<Instruction>(*I); - - // Increment now, so we can remove the use - ++I; - - // Find the index of the retval to replace with - int index = -1; - if (ExtractValueInst *EV = dyn_cast<ExtractValueInst>(Ins)) - if (EV->getNumIndices() == 1) - index = *EV->idx_begin(); - - // If this use uses a specific return value, and we have a replacement, - // replace it. - if (index != -1) { - Value *New = RetVals[index]; - if (New) { - if (Argument *A = dyn_cast<Argument>(New)) - // Was an argument returned? Then find the corresponding argument in - // the call instruction and use that. - New = CB->getArgOperand(A->getArgNo()); - Ins->replaceAllUsesWith(New); - Ins->eraseFromParent(); - } - } - } - } - - if (MadeChange) ++NumReturnValProped; - return MadeChange; -} - -char IPCP::ID = 0; -INITIALIZE_PASS(IPCP, "ipconstprop", - "Interprocedural constant propagation", false, false) - -ModulePass *llvm::createIPConstantPropagationPass() { return new IPCP(); } - -bool IPCP::runOnModule(Module &M) { - if (skipModule(M)) - return false; - - bool Changed = false; - bool LocalChange = true; - - // FIXME: instead of using smart algorithms, we just iterate until we stop - // making changes. - while (LocalChange) { - LocalChange = false; - for (Function &F : M) - if (!F.isDeclaration()) { - // Delete any klingons. - F.removeDeadConstantUsers(); - if (F.hasLocalLinkage()) - LocalChange |= PropagateConstantsIntoArguments(F); - Changed |= PropagateConstantReturn(F); - } - Changed |= LocalChange; - } - return Changed; -} diff --git a/contrib/llvm-project/llvm/lib/Transforms/IPO/IPO.cpp b/contrib/llvm-project/llvm/lib/Transforms/IPO/IPO.cpp index d37b9236380d..f4c12dd7f4cd 100644 --- a/contrib/llvm-project/llvm/lib/Transforms/IPO/IPO.cpp +++ b/contrib/llvm-project/llvm/lib/Transforms/IPO/IPO.cpp @@ -25,6 +25,7 @@ using namespace llvm; void llvm::initializeIPO(PassRegistry &Registry) { initializeOpenMPOptLegacyPassPass(Registry); initializeArgPromotionPass(Registry); + initializeAnnotation2MetadataLegacyPass(Registry); initializeCalledValuePropagationLegacyPassPass(Registry); initializeConstantMergeLegacyPassPass(Registry); initializeCrossDSOCFIPass(Registry); @@ -35,13 +36,13 @@ void llvm::initializeIPO(PassRegistry &Registry) { initializeGlobalOptLegacyPassPass(Registry); initializeGlobalSplitPass(Registry); initializeHotColdSplittingLegacyPassPass(Registry); - initializeIPCPPass(Registry); + initializeIROutlinerLegacyPassPass(Registry); initializeAlwaysInlinerLegacyPassPass(Registry); initializeSimpleInlinerPass(Registry); initializeInferFunctionAttrsLegacyPassPass(Registry); initializeInternalizeLegacyPassPass(Registry); - initializeLoopExtractorPass(Registry); - initializeBlockExtractorPass(Registry); + initializeLoopExtractorLegacyPassPass(Registry); + initializeBlockExtractorLegacyPassPass(Registry); initializeSingleLoopExtractorPass(Registry); initializeLowerTypeTestsPass(Registry); initializeMergeFunctionsLegacyPassPass(Registry); @@ -104,10 +105,6 @@ void LLVMAddGlobalOptimizerPass(LLVMPassManagerRef PM) { unwrap(PM)->add(createGlobalOptimizerPass()); } -void LLVMAddIPConstantPropagationPass(LLVMPassManagerRef PM) { - unwrap(PM)->add(createIPConstantPropagationPass()); -} - void LLVMAddPruneEHPass(LLVMPassManagerRef PM) { unwrap(PM)->add(createPruneEHPass()); } diff --git a/contrib/llvm-project/llvm/lib/Transforms/IPO/IROutliner.cpp 
b/contrib/llvm-project/llvm/lib/Transforms/IPO/IROutliner.cpp new file mode 100644 index 000000000000..4b6a4f3d8fc4 --- /dev/null +++ b/contrib/llvm-project/llvm/lib/Transforms/IPO/IROutliner.cpp @@ -0,0 +1,1764 @@ +//===- IROutliner.cpp -- Outline Similar Regions ----------------*- C++ -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +/// +/// \file +// Implementation for the IROutliner which is used by the IROutliner Pass. +// +//===----------------------------------------------------------------------===// + +#include "llvm/Transforms/IPO/IROutliner.h" +#include "llvm/Analysis/IRSimilarityIdentifier.h" +#include "llvm/Analysis/OptimizationRemarkEmitter.h" +#include "llvm/Analysis/TargetTransformInfo.h" +#include "llvm/IR/Attributes.h" +#include "llvm/IR/PassManager.h" +#include "llvm/InitializePasses.h" +#include "llvm/Pass.h" +#include "llvm/Support/CommandLine.h" +#include "llvm/Transforms/IPO.h" +#include <map> +#include <set> +#include <vector> + +#define DEBUG_TYPE "iroutliner" + +using namespace llvm; +using namespace IRSimilarity; + +// Set to true if the user wants the ir outliner to run on linkonceodr linkage +// functions. This is false by default because the linker can dedupe linkonceodr +// functions. Since the outliner is confined to a single module (modulo LTO), +// this is off by default. It should, however, be the default behavior in +// LTO. +static cl::opt<bool> EnableLinkOnceODRIROutlining( + "enable-linkonceodr-ir-outlining", cl::Hidden, + cl::desc("Enable the IR outliner on linkonceodr functions"), + cl::init(false)); + +// This is a debug option to test small pieces of code to ensure that outlining +// works correctly. +static cl::opt<bool> NoCostModel( + "ir-outlining-no-cost", cl::init(false), cl::ReallyHidden, + cl::desc("Debug option to outline greedily, without restriction that " + "calculated benefit outweighs cost")); + +/// The OutlinableGroup holds all the overarching information for outlining +/// a set of regions that are structurally similar to one another, such as the +/// types of the overall function, the output blocks, the sets of stores needed +/// and a list of the different regions. This information is used in the +/// deduplication of extracted regions with the same structure. +struct OutlinableGroup { + /// The sections that could be outlined + std::vector<OutlinableRegion *> Regions; + + /// The argument types for the function created as the overall function to + /// replace the extracted function for each region. + std::vector<Type *> ArgumentTypes; + /// The FunctionType for the overall function. + FunctionType *OutlinedFunctionType = nullptr; + /// The Function for the collective overall function. + Function *OutlinedFunction = nullptr; + + /// Flag for whether we should not consider this group of OutlinableRegions + /// for extraction. + bool IgnoreGroup = false; + + /// The return block for the overall function. + BasicBlock *EndBB = nullptr; + + /// A set containing the different GVN store sets needed. Each array contains + /// a sorted list of the different values that need to be stored into output + /// registers. + DenseSet<ArrayRef<unsigned>> OutputGVNCombinations; + + /// Flag for whether the \ref ArgumentTypes have been defined after the + /// extraction of the first region. 
+  bool InputTypesSet = false;
+
+  /// The number of input values in \ref ArgumentTypes. Anything after this
+  /// index in ArgumentTypes is an output argument.
+  unsigned NumAggregateInputs = 0;
+
+  /// The number of instructions that will be outlined by extracting \ref
+  /// Regions.
+  InstructionCost Benefit = 0;
+  /// The number of added instructions needed for the outlining of the \ref
+  /// Regions.
+  InstructionCost Cost = 0;
+
+  /// The argument that needs to be marked with the swifterr attribute. If not
+  /// needed, there is no value.
+  Optional<unsigned> SwiftErrorArgument;
+
+  /// For the \ref Regions, we look at every Value. If it is a constant,
+  /// we check whether it is the same in each Region.
+  ///
+  /// \param [in,out] NotSame contains the global value numbers where the
+  /// constant is not always the same, and must be passed in as an argument.
+  void findSameConstants(DenseSet<unsigned> &NotSame);
+
+  /// For the regions, look at each set of GVN stores needed and account for
+  /// each combination. Add an argument to the argument types if there is
+  /// more than one combination.
+  ///
+  /// \param [in] M - The module we are outlining from.
+  void collectGVNStoreSets(Module &M);
+};
+
+/// Move the contents of \p SourceBB to the end of \p TargetBB.
+/// \param SourceBB - the BasicBlock to pull Instructions from.
+/// \param TargetBB - the BasicBlock to put Instructions into.
+static void moveBBContents(BasicBlock &SourceBB, BasicBlock &TargetBB) {
+  BasicBlock::iterator BBCurr, BBEnd, BBNext;
+  for (BBCurr = SourceBB.begin(), BBEnd = SourceBB.end(); BBCurr != BBEnd;
+       BBCurr = BBNext) {
+    BBNext = std::next(BBCurr);
+    BBCurr->moveBefore(TargetBB, TargetBB.end());
+  }
+}
+
+void OutlinableRegion::splitCandidate() {
+  assert(!CandidateSplit && "Candidate already split!");
+
+  Instruction *StartInst = (*Candidate->begin()).Inst;
+  Instruction *EndInst = (*Candidate->end()).Inst;
+  assert(StartInst && EndInst && "Expected a start and end instruction?");
+  StartBB = StartInst->getParent();
+  PrevBB = StartBB;
+
+  // The basic block gets split like so:
+  // block:                 block:
+  //   inst1                  inst1
+  //   inst2                  inst2
+  //   region1                br block_to_outline
+  //   region2              block_to_outline:
+  //   region3     ->         region1
+  //   region4                region2
+  //   inst3                  region3
+  //   inst4                  region4
+  //                          br block_after_outline
+  //                        block_after_outline:
+  //                          inst3
+  //                          inst4
+
+  std::string OriginalName = PrevBB->getName().str();
+
+  StartBB = PrevBB->splitBasicBlock(StartInst, OriginalName + "_to_outline");
+
+  // This is the case for the inner block since we do not have to include
+  // multiple blocks.
+  EndBB = StartBB;
+  FollowBB = EndBB->splitBasicBlock(EndInst, OriginalName + "_after_outline");
+
+  CandidateSplit = true;
+}
+
+void OutlinableRegion::reattachCandidate() {
+  assert(CandidateSplit && "Candidate is not split!");
+
+  // The basic block gets reattached like so:
+  // block:                        block:
+  //   inst1                         inst1
+  //   inst2                         inst2
+  //   br block_to_outline           region1
+  // block_to_outline:        ->     region2
+  //   region1                       region3
+  //   region2                       region4
+  //   region3                       inst3
+  //   region4                       inst4
+  //   br block_after_outline
+  // block_after_outline:
+  //   inst3
+  //   inst4
+  assert(StartBB != nullptr && "StartBB for Candidate is not defined!");
+  assert(FollowBB != nullptr && "FollowBB for Candidate is not defined!");
+
+  // StartBB should only have one predecessor since we put an unconditional
+  // branch at the end of PrevBB when we split the BasicBlock.
+ PrevBB = StartBB->getSinglePredecessor(); + assert(PrevBB != nullptr && + "No Predecessor for the region start basic block!"); + + assert(PrevBB->getTerminator() && "Terminator removed from PrevBB!"); + assert(EndBB->getTerminator() && "Terminator removed from EndBB!"); + PrevBB->getTerminator()->eraseFromParent(); + EndBB->getTerminator()->eraseFromParent(); + + moveBBContents(*StartBB, *PrevBB); + + BasicBlock *PlacementBB = PrevBB; + if (StartBB != EndBB) + PlacementBB = EndBB; + moveBBContents(*FollowBB, *PlacementBB); + + PrevBB->replaceSuccessorsPhiUsesWith(StartBB, PrevBB); + PrevBB->replaceSuccessorsPhiUsesWith(FollowBB, PlacementBB); + StartBB->eraseFromParent(); + FollowBB->eraseFromParent(); + + // Make sure to save changes back to the StartBB. + StartBB = PrevBB; + EndBB = nullptr; + PrevBB = nullptr; + FollowBB = nullptr; + + CandidateSplit = false; +} + +/// Find whether \p V matches the Constants previously found for the \p GVN. +/// +/// \param V - The value to check for consistency. +/// \param GVN - The global value number assigned to \p V. +/// \param GVNToConstant - The mapping of global value number to Constants. +/// \returns true if the Value matches the Constant mapped to by V and false if +/// it \p V is a Constant but does not match. +/// \returns None if \p V is not a Constant. +static Optional<bool> +constantMatches(Value *V, unsigned GVN, + DenseMap<unsigned, Constant *> &GVNToConstant) { + // See if we have a constants + Constant *CST = dyn_cast<Constant>(V); + if (!CST) + return None; + + // Holds a mapping from a global value number to a Constant. + DenseMap<unsigned, Constant *>::iterator GVNToConstantIt; + bool Inserted; + + + // If we have a constant, try to make a new entry in the GVNToConstant. + std::tie(GVNToConstantIt, Inserted) = + GVNToConstant.insert(std::make_pair(GVN, CST)); + // If it was found and is not equal, it is not the same. We do not + // handle this case yet, and exit early. + if (Inserted || (GVNToConstantIt->second == CST)) + return true; + + return false; +} + +InstructionCost OutlinableRegion::getBenefit(TargetTransformInfo &TTI) { + InstructionCost Benefit = 0; + + // Estimate the benefit of outlining a specific sections of the program. We + // delegate mostly this task to the TargetTransformInfo so that if the target + // has specific changes, we can have a more accurate estimate. + + // However, getInstructionCost delegates the code size calculation for + // arithmetic instructions to getArithmeticInstrCost in + // include/Analysis/TargetTransformImpl.h, where it always estimates that the + // code size for a division and remainder instruction to be equal to 4, and + // everything else to 1. This is not an accurate representation of the + // division instruction for targets that have a native division instruction. + // To be overly conservative, we only add 1 to the number of instructions for + // each division instruction. + for (Instruction &I : *StartBB) { + switch (I.getOpcode()) { + case Instruction::FDiv: + case Instruction::FRem: + case Instruction::SDiv: + case Instruction::SRem: + case Instruction::UDiv: + case Instruction::URem: + Benefit += 1; + break; + default: + Benefit += TTI.getInstructionCost(&I, TargetTransformInfo::TCK_CodeSize); + break; + } + } + + return Benefit; +} + +/// Find whether \p Region matches the global value numbering to Constant +/// mapping found so far. 
+/// +/// \param Region - The OutlinableRegion we are checking for constants +/// \param GVNToConstant - The mapping of global value number to Constants. +/// \param NotSame - The set of global value numbers that do not have the same +/// constant in each region. +/// \returns true if all Constants are the same in every use of a Constant in \p +/// Region and false if not +static bool +collectRegionsConstants(OutlinableRegion &Region, + DenseMap<unsigned, Constant *> &GVNToConstant, + DenseSet<unsigned> &NotSame) { + bool ConstantsTheSame = true; + + IRSimilarityCandidate &C = *Region.Candidate; + for (IRInstructionData &ID : C) { + + // Iterate over the operands in an instruction. If the global value number, + // assigned by the IRSimilarityCandidate, has been seen before, we check if + // the the number has been found to be not the same value in each instance. + for (Value *V : ID.OperVals) { + Optional<unsigned> GVNOpt = C.getGVN(V); + assert(GVNOpt.hasValue() && "Expected a GVN for operand?"); + unsigned GVN = GVNOpt.getValue(); + + // Check if this global value has been found to not be the same already. + if (NotSame.contains(GVN)) { + if (isa<Constant>(V)) + ConstantsTheSame = false; + continue; + } + + // If it has been the same so far, we check the value for if the + // associated Constant value match the previous instances of the same + // global value number. If the global value does not map to a Constant, + // it is considered to not be the same value. + Optional<bool> ConstantMatches = constantMatches(V, GVN, GVNToConstant); + if (ConstantMatches.hasValue()) { + if (ConstantMatches.getValue()) + continue; + else + ConstantsTheSame = false; + } + + // While this value is a register, it might not have been previously, + // make sure we don't already have a constant mapped to this global value + // number. + if (GVNToConstant.find(GVN) != GVNToConstant.end()) + ConstantsTheSame = false; + + NotSame.insert(GVN); + } + } + + return ConstantsTheSame; +} + +void OutlinableGroup::findSameConstants(DenseSet<unsigned> &NotSame) { + DenseMap<unsigned, Constant *> GVNToConstant; + + for (OutlinableRegion *Region : Regions) + collectRegionsConstants(*Region, GVNToConstant, NotSame); +} + +void OutlinableGroup::collectGVNStoreSets(Module &M) { + for (OutlinableRegion *OS : Regions) + OutputGVNCombinations.insert(OS->GVNStores); + + // We are adding an extracted argument to decide between which output path + // to use in the basic block. It is used in a switch statement and only + // needs to be an integer. + if (OutputGVNCombinations.size() > 1) + ArgumentTypes.push_back(Type::getInt32Ty(M.getContext())); +} + +Function *IROutliner::createFunction(Module &M, OutlinableGroup &Group, + unsigned FunctionNameSuffix) { + assert(!Group.OutlinedFunction && "Function is already defined!"); + + Group.OutlinedFunctionType = FunctionType::get( + Type::getVoidTy(M.getContext()), Group.ArgumentTypes, false); + + // These functions will only be called from within the same module, so + // we can set an internal linkage. + Group.OutlinedFunction = Function::Create( + Group.OutlinedFunctionType, GlobalValue::InternalLinkage, + "outlined_ir_func_" + std::to_string(FunctionNameSuffix), M); + + // Transfer the swifterr attribute to the correct function parameter. 
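The FunctionType assembled just above is always void: the aggregate inputs come first, then one pointer parameter per distinct output location, and finally, when collectGVNStoreSets recorded more than one store combination, a trailing i32 selector. A hypothetical C++ analogue of such a signature, for a group with two inputs, one output and several output schemes (the name follows the "outlined_ir_func_" scheme above; the parameter types are purely illustrative):

// Outputs are written through pointers; the final parameter picks which set
// of output stores the body performs before branching to the return block.
void outlined_ir_func_0(int In0, int In1, int *Out0, int OutputScheme);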
+ if (Group.SwiftErrorArgument.hasValue()) + Group.OutlinedFunction->addParamAttr(Group.SwiftErrorArgument.getValue(), + Attribute::SwiftError); + + Group.OutlinedFunction->addFnAttr(Attribute::OptimizeForSize); + Group.OutlinedFunction->addFnAttr(Attribute::MinSize); + + return Group.OutlinedFunction; +} + +/// Move each BasicBlock in \p Old to \p New. +/// +/// \param [in] Old - the function to move the basic blocks from. +/// \param [in] New - The function to move the basic blocks to. +/// \returns the first return block for the function in New. +static BasicBlock *moveFunctionData(Function &Old, Function &New) { + Function::iterator CurrBB, NextBB, FinalBB; + BasicBlock *NewEnd = nullptr; + std::vector<Instruction *> DebugInsts; + for (CurrBB = Old.begin(), FinalBB = Old.end(); CurrBB != FinalBB; + CurrBB = NextBB) { + NextBB = std::next(CurrBB); + CurrBB->removeFromParent(); + CurrBB->insertInto(&New); + Instruction *I = CurrBB->getTerminator(); + if (isa<ReturnInst>(I)) + NewEnd = &(*CurrBB); + } + + assert(NewEnd && "No return instruction for new function?"); + return NewEnd; +} + +/// Find the the constants that will need to be lifted into arguments +/// as they are not the same in each instance of the region. +/// +/// \param [in] C - The IRSimilarityCandidate containing the region we are +/// analyzing. +/// \param [in] NotSame - The set of global value numbers that do not have a +/// single Constant across all OutlinableRegions similar to \p C. +/// \param [out] Inputs - The list containing the global value numbers of the +/// arguments needed for the region of code. +static void findConstants(IRSimilarityCandidate &C, DenseSet<unsigned> &NotSame, + std::vector<unsigned> &Inputs) { + DenseSet<unsigned> Seen; + // Iterate over the instructions, and find what constants will need to be + // extracted into arguments. + for (IRInstructionDataList::iterator IDIt = C.begin(), EndIDIt = C.end(); + IDIt != EndIDIt; IDIt++) { + for (Value *V : (*IDIt).OperVals) { + // Since these are stored before any outlining, they will be in the + // global value numbering. + unsigned GVN = C.getGVN(V).getValue(); + if (isa<Constant>(V)) + if (NotSame.contains(GVN) && !Seen.contains(GVN)) { + Inputs.push_back(GVN); + Seen.insert(GVN); + } + } + } +} + +/// Find the GVN for the inputs that have been found by the CodeExtractor. +/// +/// \param [in] C - The IRSimilarityCandidate containing the region we are +/// analyzing. +/// \param [in] CurrentInputs - The set of inputs found by the +/// CodeExtractor. +/// \param [out] EndInputNumbers - The global value numbers for the extracted +/// arguments. +/// \param [in] OutputMappings - The mapping of values that have been replaced +/// by a new output value. +/// \param [out] EndInputs - The global value numbers for the extracted +/// arguments. +static void mapInputsToGVNs(IRSimilarityCandidate &C, + SetVector<Value *> &CurrentInputs, + const DenseMap<Value *, Value *> &OutputMappings, + std::vector<unsigned> &EndInputNumbers) { + // Get the Global Value Number for each input. We check if the Value has been + // replaced by a different value at output, and use the original value before + // replacement. 
+ for (Value *Input : CurrentInputs) { + assert(Input && "Have a nullptr as an input"); + if (OutputMappings.find(Input) != OutputMappings.end()) + Input = OutputMappings.find(Input)->second; + assert(C.getGVN(Input).hasValue() && + "Could not find a numbering for the given input"); + EndInputNumbers.push_back(C.getGVN(Input).getValue()); + } +} + +/// Find the original value for the \p ArgInput values if any one of them was +/// replaced during a previous extraction. +/// +/// \param [in] ArgInputs - The inputs to be extracted by the code extractor. +/// \param [in] OutputMappings - The mapping of values that have been replaced +/// by a new output value. +/// \param [out] RemappedArgInputs - The remapped values according to +/// \p OutputMappings that will be extracted. +static void +remapExtractedInputs(const ArrayRef<Value *> ArgInputs, + const DenseMap<Value *, Value *> &OutputMappings, + SetVector<Value *> &RemappedArgInputs) { + // Get the global value number for each input that will be extracted as an + // argument by the code extractor, remapping if needed for reloaded values. + for (Value *Input : ArgInputs) { + if (OutputMappings.find(Input) != OutputMappings.end()) + Input = OutputMappings.find(Input)->second; + RemappedArgInputs.insert(Input); + } +} + +/// Find the input GVNs and the output values for a region of Instructions. +/// Using the code extractor, we collect the inputs to the extracted function. +/// +/// The \p Region can be identified as needing to be ignored in this function. +/// It should be checked whether it should be ignored after a call to this +/// function. +/// +/// \param [in,out] Region - The region of code to be analyzed. +/// \param [out] InputGVNs - The global value numbers for the extracted +/// arguments. +/// \param [in] NotSame - The global value numbers in the region that do not +/// have the same constant value in the regions structurally similar to +/// \p Region. +/// \param [in] OutputMappings - The mapping of values that have been replaced +/// by a new output value after extraction. +/// \param [out] ArgInputs - The values of the inputs to the extracted function. +/// \param [out] Outputs - The set of values extracted by the CodeExtractor +/// as outputs. +static void getCodeExtractorArguments( + OutlinableRegion &Region, std::vector<unsigned> &InputGVNs, + DenseSet<unsigned> &NotSame, DenseMap<Value *, Value *> &OutputMappings, + SetVector<Value *> &ArgInputs, SetVector<Value *> &Outputs) { + IRSimilarityCandidate &C = *Region.Candidate; + + // OverallInputs are the inputs to the region found by the CodeExtractor, + // SinkCands and HoistCands are used by the CodeExtractor to find sunken + // allocas of values whose lifetimes are contained completely within the + // outlined region. PremappedInputs are the arguments found by the + // CodeExtractor, removing conditions such as sunken allocas, but that + // may need to be remapped due to the extracted output values replacing + // the original values. We use DummyOutputs for this first run of finding + // inputs and outputs since the outputs could change during findAllocas, + // the correct set of extracted outputs will be in the final Outputs ValueSet. + SetVector<Value *> OverallInputs, PremappedInputs, SinkCands, HoistCands, + DummyOutputs; + + // Use the code extractor to get the inputs and outputs, without sunken + // allocas or removing llvm.assumes. 
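The findInputsOutputs query used next can also be exercised on its own. A minimal sketch over a single-block region with default CodeExtractor settings (the helper name and setup are ours, not part of the patch):

#include "llvm/ADT/SetVector.h"
#include "llvm/IR/BasicBlock.h"
#include "llvm/Transforms/Utils/CodeExtractor.h"
using namespace llvm;

static void inspectRegion(BasicBlock *BB) {
  CodeExtractor CE(ArrayRef<BasicBlock *>(BB));
  SetVector<Value *> Inputs, Outputs, SinkCands;
  CE.findInputsOutputs(Inputs, Outputs, SinkCands);
  // Inputs: values defined outside the region but used inside it.
  // Outputs: values defined inside the region that are used after it.
}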
+ CodeExtractor *CE = Region.CE; + CE->findInputsOutputs(OverallInputs, DummyOutputs, SinkCands); + assert(Region.StartBB && "Region must have a start BasicBlock!"); + Function *OrigF = Region.StartBB->getParent(); + CodeExtractorAnalysisCache CEAC(*OrigF); + BasicBlock *Dummy = nullptr; + + // The region may be ineligible due to VarArgs in the parent function. In this + // case we ignore the region. + if (!CE->isEligible()) { + Region.IgnoreRegion = true; + return; + } + + // Find if any values are going to be sunk into the function when extracted + CE->findAllocas(CEAC, SinkCands, HoistCands, Dummy); + CE->findInputsOutputs(PremappedInputs, Outputs, SinkCands); + + // TODO: Support regions with sunken allocas: values whose lifetimes are + // contained completely within the outlined region. These are not guaranteed + // to be the same in every region, so we must elevate them all to arguments + // when they appear. If these values are not equal, it means there is some + // Input in OverallInputs that was removed for ArgInputs. + if (OverallInputs.size() != PremappedInputs.size()) { + Region.IgnoreRegion = true; + return; + } + + findConstants(C, NotSame, InputGVNs); + + mapInputsToGVNs(C, OverallInputs, OutputMappings, InputGVNs); + + remapExtractedInputs(PremappedInputs.getArrayRef(), OutputMappings, + ArgInputs); + + // Sort the GVNs, since we now have constants included in the \ref InputGVNs + // we need to make sure they are in a deterministic order. + stable_sort(InputGVNs); +} + +/// Look over the inputs and map each input argument to an argument in the +/// overall function for the OutlinableRegions. This creates a way to replace +/// the arguments of the extracted function with the arguments of the new +/// overall function. +/// +/// \param [in,out] Region - The region of code to be analyzed. +/// \param [in] InputsGVNs - The global value numbering of the input values +/// collected. +/// \param [in] ArgInputs - The values of the arguments to the extracted +/// function. +static void +findExtractedInputToOverallInputMapping(OutlinableRegion &Region, + std::vector<unsigned> &InputGVNs, + SetVector<Value *> &ArgInputs) { + + IRSimilarityCandidate &C = *Region.Candidate; + OutlinableGroup &Group = *Region.Parent; + + // This counts the argument number in the overall function. + unsigned TypeIndex = 0; + + // This counts the argument number in the extracted function. + unsigned OriginalIndex = 0; + + // Find the mapping of the extracted arguments to the arguments for the + // overall function. Since there may be extra arguments in the overall + // function to account for the extracted constants, we have two different + // counters as we find extracted arguments, and as we come across overall + // arguments. + for (unsigned InputVal : InputGVNs) { + Optional<Value *> InputOpt = C.fromGVN(InputVal); + assert(InputOpt.hasValue() && "Global value number not found?"); + Value *Input = InputOpt.getValue(); + + if (!Group.InputTypesSet) { + Group.ArgumentTypes.push_back(Input->getType()); + // If the input value has a swifterr attribute, make sure to mark the + // argument in the overall function. + if (Input->isSwiftError()) { + assert( + !Group.SwiftErrorArgument.hasValue() && + "Argument already marked with swifterr for this OutlinableGroup!"); + Group.SwiftErrorArgument = TypeIndex; + } + } + + // Check if we have a constant. If we do add it to the overall argument + // number to Constant map for the region, and continue to the next input. 
+ if (Constant *CST = dyn_cast<Constant>(Input)) { + Region.AggArgToConstant.insert(std::make_pair(TypeIndex, CST)); + TypeIndex++; + continue; + } + + // It is not a constant, we create the mapping from extracted argument list + // to the overall argument list. + assert(ArgInputs.count(Input) && "Input cannot be found!"); + + Region.ExtractedArgToAgg.insert(std::make_pair(OriginalIndex, TypeIndex)); + Region.AggArgToExtracted.insert(std::make_pair(TypeIndex, OriginalIndex)); + OriginalIndex++; + TypeIndex++; + } + + // If the function type definitions for the OutlinableGroup holding the region + // have not been set, set the length of the inputs here. We should have the + // same inputs for all of the different regions contained in the + // OutlinableGroup since they are all structurally similar to one another. + if (!Group.InputTypesSet) { + Group.NumAggregateInputs = TypeIndex; + Group.InputTypesSet = true; + } + + Region.NumExtractedInputs = OriginalIndex; +} + +/// Create a mapping of the output arguments for the \p Region to the output +/// arguments of the overall outlined function. +/// +/// \param [in,out] Region - The region of code to be analyzed. +/// \param [in] Outputs - The values found by the code extractor. +static void +findExtractedOutputToOverallOutputMapping(OutlinableRegion &Region, + ArrayRef<Value *> Outputs) { + OutlinableGroup &Group = *Region.Parent; + IRSimilarityCandidate &C = *Region.Candidate; + + // This counts the argument number in the extracted function. + unsigned OriginalIndex = Region.NumExtractedInputs; + + // This counts the argument number in the overall function. + unsigned TypeIndex = Group.NumAggregateInputs; + bool TypeFound; + DenseSet<unsigned> AggArgsUsed; + + // Iterate over the output types and identify if there is an aggregate pointer + // type whose base type matches the current output type. If there is, we mark + // that we will use this output register for this value. If not we add another + // type to the overall argument type list. We also store the GVNs used for + // stores to identify which values will need to be moved into an special + // block that holds the stores to the output registers. + for (Value *Output : Outputs) { + TypeFound = false; + // We can do this since it is a result value, and will have a number + // that is necessarily the same. BUT if in the future, the instructions + // do not have to be in same order, but are functionally the same, we will + // have to use a different scheme, as one-to-one correspondence is not + // guaranteed. + unsigned GlobalValue = C.getGVN(Output).getValue(); + unsigned ArgumentSize = Group.ArgumentTypes.size(); + + for (unsigned Jdx = TypeIndex; Jdx < ArgumentSize; Jdx++) { + if (Group.ArgumentTypes[Jdx] != PointerType::getUnqual(Output->getType())) + continue; + + if (AggArgsUsed.contains(Jdx)) + continue; + + TypeFound = true; + AggArgsUsed.insert(Jdx); + Region.ExtractedArgToAgg.insert(std::make_pair(OriginalIndex, Jdx)); + Region.AggArgToExtracted.insert(std::make_pair(Jdx, OriginalIndex)); + Region.GVNStores.push_back(GlobalValue); + break; + } + + // We were unable to find an unused type in the output type set that matches + // the output, so we add a pointer type to the argument types of the overall + // function to handle this output and create a mapping to it. 
+ if (!TypeFound) { + Group.ArgumentTypes.push_back(PointerType::getUnqual(Output->getType())); + AggArgsUsed.insert(Group.ArgumentTypes.size() - 1); + Region.ExtractedArgToAgg.insert( + std::make_pair(OriginalIndex, Group.ArgumentTypes.size() - 1)); + Region.AggArgToExtracted.insert( + std::make_pair(Group.ArgumentTypes.size() - 1, OriginalIndex)); + Region.GVNStores.push_back(GlobalValue); + } + + stable_sort(Region.GVNStores); + OriginalIndex++; + TypeIndex++; + } +} + +void IROutliner::findAddInputsOutputs(Module &M, OutlinableRegion &Region, + DenseSet<unsigned> &NotSame) { + std::vector<unsigned> Inputs; + SetVector<Value *> ArgInputs, Outputs; + + getCodeExtractorArguments(Region, Inputs, NotSame, OutputMappings, ArgInputs, + Outputs); + + if (Region.IgnoreRegion) + return; + + // Map the inputs found by the CodeExtractor to the arguments found for + // the overall function. + findExtractedInputToOverallInputMapping(Region, Inputs, ArgInputs); + + // Map the outputs found by the CodeExtractor to the arguments found for + // the overall function. + findExtractedOutputToOverallOutputMapping(Region, Outputs.getArrayRef()); +} + +/// Replace the extracted function in the Region with a call to the overall +/// function constructed from the deduplicated similar regions, replacing and +/// remapping the values passed to the extracted function as arguments to the +/// new arguments of the overall function. +/// +/// \param [in] M - The module to outline from. +/// \param [in] Region - The regions of extracted code to be replaced with a new +/// function. +/// \returns a call instruction with the replaced function. +CallInst *replaceCalledFunction(Module &M, OutlinableRegion &Region) { + std::vector<Value *> NewCallArgs; + DenseMap<unsigned, unsigned>::iterator ArgPair; + + OutlinableGroup &Group = *Region.Parent; + CallInst *Call = Region.Call; + assert(Call && "Call to replace is nullptr?"); + Function *AggFunc = Group.OutlinedFunction; + assert(AggFunc && "Function to replace with is nullptr?"); + + // If the arguments are the same size, there are not values that need to be + // made argument, or different output registers to handle. We can simply + // replace the called function in this case. + if (AggFunc->arg_size() == Call->arg_size()) { + LLVM_DEBUG(dbgs() << "Replace call to " << *Call << " with call to " + << *AggFunc << " with same number of arguments\n"); + Call->setCalledFunction(AggFunc); + return Call; + } + + // We have a different number of arguments than the new function, so + // we need to use our previously mappings off extracted argument to overall + // function argument, and constants to overall function argument to create the + // new argument list. 
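Stripped of the argument-remapping bookkeeping, the replacement performed below reduces to creating a call to the aggregate function at the old call site, carrying the debug location over, and erasing the original call. A condensed sketch of that final step (the helper is illustrative; the loop that follows is what builds the NewCallArgs it would receive):

#include "llvm/IR/Function.h"
#include "llvm/IR/Instructions.h"
using namespace llvm;

static CallInst *rebuildCall(CallInst *OldCall, Function *NewCallee,
                             ArrayRef<Value *> NewArgs) {
  // Insert the new call directly before the old one, then drop the old call.
  CallInst *NewCall = CallInst::Create(NewCallee->getFunctionType(), NewCallee,
                                       NewArgs, "", OldCall);
  NewCall->setDebugLoc(OldCall->getDebugLoc());
  OldCall->eraseFromParent();
  return NewCall;
}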
+ for (unsigned AggArgIdx = 0; AggArgIdx < AggFunc->arg_size(); AggArgIdx++) { + + if (AggArgIdx == AggFunc->arg_size() - 1 && + Group.OutputGVNCombinations.size() > 1) { + // If we are on the last argument, and we need to differentiate between + // output blocks, add an integer to the argument list to determine + // what block to take + LLVM_DEBUG(dbgs() << "Set switch block argument to " + << Region.OutputBlockNum << "\n"); + NewCallArgs.push_back(ConstantInt::get(Type::getInt32Ty(M.getContext()), + Region.OutputBlockNum)); + continue; + } + + ArgPair = Region.AggArgToExtracted.find(AggArgIdx); + if (ArgPair != Region.AggArgToExtracted.end()) { + Value *ArgumentValue = Call->getArgOperand(ArgPair->second); + // If we found the mapping from the extracted function to the overall + // function, we simply add it to the argument list. We use the same + // value, it just needs to honor the new order of arguments. + LLVM_DEBUG(dbgs() << "Setting argument " << AggArgIdx << " to value " + << *ArgumentValue << "\n"); + NewCallArgs.push_back(ArgumentValue); + continue; + } + + // If it is a constant, we simply add it to the argument list as a value. + if (Region.AggArgToConstant.find(AggArgIdx) != + Region.AggArgToConstant.end()) { + Constant *CST = Region.AggArgToConstant.find(AggArgIdx)->second; + LLVM_DEBUG(dbgs() << "Setting argument " << AggArgIdx << " to value " + << *CST << "\n"); + NewCallArgs.push_back(CST); + continue; + } + + // Add a nullptr value if the argument is not found in the extracted + // function. If we cannot find a value, it means it is not in use + // for the region, so we should not pass anything to it. + LLVM_DEBUG(dbgs() << "Setting argument " << AggArgIdx << " to nullptr\n"); + NewCallArgs.push_back(ConstantPointerNull::get( + static_cast<PointerType *>(AggFunc->getArg(AggArgIdx)->getType()))); + } + + LLVM_DEBUG(dbgs() << "Replace call to " << *Call << " with call to " + << *AggFunc << " with new set of arguments\n"); + // Create the new call instruction and erase the old one. + Call = CallInst::Create(AggFunc->getFunctionType(), AggFunc, NewCallArgs, "", + Call); + + // It is possible that the call to the outlined function is either the first + // instruction is in the new block, the last instruction, or both. If either + // of these is the case, we need to make sure that we replace the instruction + // in the IRInstructionData struct with the new call. + CallInst *OldCall = Region.Call; + if (Region.NewFront->Inst == OldCall) + Region.NewFront->Inst = Call; + if (Region.NewBack->Inst == OldCall) + Region.NewBack->Inst = Call; + + // Transfer any debug information. + Call->setDebugLoc(Region.Call->getDebugLoc()); + + // Remove the old instruction. + OldCall->eraseFromParent(); + Region.Call = Call; + + // Make sure that the argument in the new function has the SwiftError + // argument. + if (Group.SwiftErrorArgument.hasValue()) + Call->addParamAttr(Group.SwiftErrorArgument.getValue(), + Attribute::SwiftError); + + return Call; +} + +// Within an extracted function, replace the argument uses of the extracted +// region with the arguments of the function for an OutlinableGroup. +// +/// \param [in] Region - The region of extracted code to be changed. +/// \param [in,out] OutputBB - The BasicBlock for the output stores for this +/// region. 
+static void replaceArgumentUses(OutlinableRegion &Region, + BasicBlock *OutputBB) { + OutlinableGroup &Group = *Region.Parent; + assert(Region.ExtractedFunction && "Region has no extracted function?"); + + for (unsigned ArgIdx = 0; ArgIdx < Region.ExtractedFunction->arg_size(); + ArgIdx++) { + assert(Region.ExtractedArgToAgg.find(ArgIdx) != + Region.ExtractedArgToAgg.end() && + "No mapping from extracted to outlined?"); + unsigned AggArgIdx = Region.ExtractedArgToAgg.find(ArgIdx)->second; + Argument *AggArg = Group.OutlinedFunction->getArg(AggArgIdx); + Argument *Arg = Region.ExtractedFunction->getArg(ArgIdx); + // The argument is an input, so we can simply replace it with the overall + // argument value + if (ArgIdx < Region.NumExtractedInputs) { + LLVM_DEBUG(dbgs() << "Replacing uses of input " << *Arg << " in function " + << *Region.ExtractedFunction << " with " << *AggArg + << " in function " << *Group.OutlinedFunction << "\n"); + Arg->replaceAllUsesWith(AggArg); + continue; + } + + // If we are replacing an output, we place the store value in its own + // block inside the overall function before replacing the use of the output + // in the function. + assert(Arg->hasOneUse() && "Output argument can only have one use"); + User *InstAsUser = Arg->user_back(); + assert(InstAsUser && "User is nullptr!"); + + Instruction *I = cast<Instruction>(InstAsUser); + I->setDebugLoc(DebugLoc()); + LLVM_DEBUG(dbgs() << "Move store for instruction " << *I << " to " + << *OutputBB << "\n"); + + I->moveBefore(*OutputBB, OutputBB->end()); + + LLVM_DEBUG(dbgs() << "Replacing uses of output " << *Arg << " in function " + << *Region.ExtractedFunction << " with " << *AggArg + << " in function " << *Group.OutlinedFunction << "\n"); + Arg->replaceAllUsesWith(AggArg); + } +} + +/// Within an extracted function, replace the constants that need to be lifted +/// into arguments with the actual argument. +/// +/// \param Region [in] - The region of extracted code to be changed. +void replaceConstants(OutlinableRegion &Region) { + OutlinableGroup &Group = *Region.Parent; + // Iterate over the constants that need to be elevated into arguments + for (std::pair<unsigned, Constant *> &Const : Region.AggArgToConstant) { + unsigned AggArgIdx = Const.first; + Function *OutlinedFunction = Group.OutlinedFunction; + assert(OutlinedFunction && "Overall Function is not defined?"); + Constant *CST = Const.second; + Argument *Arg = Group.OutlinedFunction->getArg(AggArgIdx); + // Identify the argument it will be elevated to, and replace instances of + // that constant in the function. + + // TODO: If in the future constants do not have one global value number, + // i.e. a constant 1 could be mapped to several values, this check will + // have to be more strict. It cannot be using only replaceUsesWithIf. + + LLVM_DEBUG(dbgs() << "Replacing uses of constant " << *CST + << " in function " << *OutlinedFunction << " with " + << *Arg << "\n"); + CST->replaceUsesWithIf(Arg, [OutlinedFunction](Use &U) { + if (Instruction *I = dyn_cast<Instruction>(U.getUser())) + return I->getFunction() == OutlinedFunction; + return false; + }); + } +} + +/// For the given function, find all the nondebug or lifetime instructions, +/// and return them as a vector. Exclude any blocks in \p ExludeBlocks. +/// +/// \param [in] F - The function we collect the instructions from. +/// \param [in] ExcludeBlocks - BasicBlocks to ignore. +/// \returns the list of instructions extracted. 
+static std::vector<Instruction *> +collectRelevantInstructions(Function &F, + DenseSet<BasicBlock *> &ExcludeBlocks) { + std::vector<Instruction *> RelevantInstructions; + + for (BasicBlock &BB : F) { + if (ExcludeBlocks.contains(&BB)) + continue; + + for (Instruction &Inst : BB) { + if (Inst.isLifetimeStartOrEnd()) + continue; + if (isa<DbgInfoIntrinsic>(Inst)) + continue; + + RelevantInstructions.push_back(&Inst); + } + } + + return RelevantInstructions; +} + +/// It is possible that there is a basic block that already performs the same +/// stores. This returns a duplicate block, if it exists +/// +/// \param OutputBB [in] the block we are looking for a duplicate of. +/// \param OutputStoreBBs [in] The existing output blocks. +/// \returns an optional value with the number output block if there is a match. +Optional<unsigned> +findDuplicateOutputBlock(BasicBlock *OutputBB, + ArrayRef<BasicBlock *> OutputStoreBBs) { + + bool WrongInst = false; + bool WrongSize = false; + unsigned MatchingNum = 0; + for (BasicBlock *CompBB : OutputStoreBBs) { + WrongInst = false; + if (CompBB->size() - 1 != OutputBB->size()) { + WrongSize = true; + MatchingNum++; + continue; + } + + WrongSize = false; + BasicBlock::iterator NIt = OutputBB->begin(); + for (Instruction &I : *CompBB) { + if (isa<BranchInst>(&I)) + continue; + + if (!I.isIdenticalTo(&(*NIt))) { + WrongInst = true; + break; + } + + NIt++; + } + if (!WrongInst && !WrongSize) + return MatchingNum; + + MatchingNum++; + } + + return None; +} + +/// For the outlined section, move needed the StoreInsts for the output +/// registers into their own block. Then, determine if there is a duplicate +/// output block already created. +/// +/// \param [in] OG - The OutlinableGroup of regions to be outlined. +/// \param [in] Region - The OutlinableRegion that is being analyzed. +/// \param [in,out] OutputBB - the block that stores for this region will be +/// placed in. +/// \param [in] EndBB - the final block of the extracted function. +/// \param [in] OutputMappings - OutputMappings the mapping of values that have +/// been replaced by a new output value. +/// \param [in,out] OutputStoreBBs - The existing output blocks. +static void +alignOutputBlockWithAggFunc(OutlinableGroup &OG, OutlinableRegion &Region, + BasicBlock *OutputBB, BasicBlock *EndBB, + const DenseMap<Value *, Value *> &OutputMappings, + std::vector<BasicBlock *> &OutputStoreBBs) { + DenseSet<unsigned> ValuesToFind(Region.GVNStores.begin(), + Region.GVNStores.end()); + + // We iterate over the instructions in the extracted function, and find the + // global value number of the instructions. If we find a value that should + // be contained in a store, we replace the uses of the value with the value + // from the overall function, so that the store is storing the correct + // value from the overall function. 
+ DenseSet<BasicBlock *> ExcludeBBs(OutputStoreBBs.begin(), + OutputStoreBBs.end()); + ExcludeBBs.insert(OutputBB); + std::vector<Instruction *> ExtractedFunctionInsts = + collectRelevantInstructions(*(Region.ExtractedFunction), ExcludeBBs); + std::vector<Instruction *> OverallFunctionInsts = + collectRelevantInstructions(*OG.OutlinedFunction, ExcludeBBs); + + assert(ExtractedFunctionInsts.size() == OverallFunctionInsts.size() && + "Number of relevant instructions not equal!"); + + unsigned NumInstructions = ExtractedFunctionInsts.size(); + for (unsigned Idx = 0; Idx < NumInstructions; Idx++) { + Value *V = ExtractedFunctionInsts[Idx]; + + if (OutputMappings.find(V) != OutputMappings.end()) + V = OutputMappings.find(V)->second; + Optional<unsigned> GVN = Region.Candidate->getGVN(V); + + // If we have found one of the stored values for output, replace the value + // with the corresponding one from the overall function. + if (GVN.hasValue() && ValuesToFind.erase(GVN.getValue())) { + V->replaceAllUsesWith(OverallFunctionInsts[Idx]); + if (ValuesToFind.size() == 0) + break; + } + + if (ValuesToFind.size() == 0) + break; + } + + assert(ValuesToFind.size() == 0 && "Not all store values were handled!"); + + // If the size of the block is 0, then there are no stores, and we do not + // need to save this block. + if (OutputBB->size() == 0) { + Region.OutputBlockNum = -1; + OutputBB->eraseFromParent(); + return; + } + + // Determine is there is a duplicate block. + Optional<unsigned> MatchingBB = + findDuplicateOutputBlock(OutputBB, OutputStoreBBs); + + // If there is, we remove the new output block. If it does not, + // we add it to our list of output blocks. + if (MatchingBB.hasValue()) { + LLVM_DEBUG(dbgs() << "Set output block for region in function" + << Region.ExtractedFunction << " to " + << MatchingBB.getValue()); + + Region.OutputBlockNum = MatchingBB.getValue(); + OutputBB->eraseFromParent(); + return; + } + + Region.OutputBlockNum = OutputStoreBBs.size(); + + LLVM_DEBUG(dbgs() << "Create output block for region in" + << Region.ExtractedFunction << " to " + << *OutputBB); + OutputStoreBBs.push_back(OutputBB); + BranchInst::Create(EndBB, OutputBB); +} + +/// Create the switch statement for outlined function to differentiate between +/// all the output blocks. +/// +/// For the outlined section, determine if an outlined block already exists that +/// matches the needed stores for the extracted section. +/// \param [in] M - The module we are outlining from. +/// \param [in] OG - The group of regions to be outlined. +/// \param [in] OS - The region that is being analyzed. +/// \param [in] EndBB - The final block of the extracted function. +/// \param [in,out] OutputStoreBBs - The existing output blocks. +void createSwitchStatement(Module &M, OutlinableGroup &OG, BasicBlock *EndBB, + ArrayRef<BasicBlock *> OutputStoreBBs) { + // We only need the switch statement if there is more than one store + // combination. 
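When several store combinations exist, the dispatch built below behaves like a switch over the extra selector argument, with an empty default that falls through to the shared return block. A source-level analogue with hypothetical names and values (this is not the generated code, only its shape):

void outlined_ir_func_0(int In0, int *Out0, int *Out1, int OutputScheme) {
  int Result = In0 + 1; // stand-in for the shared outlined body
  switch (OutputScheme) {
  case 0:               // store set needed by the first group of call sites
    *Out0 = Result;
    break;
  case 1:               // store set needed by the second group of call sites
    *Out1 = Result;
    break;
  default:              // call sites that need no output stores
    break;
  }
}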
+ if (OG.OutputGVNCombinations.size() > 1) { + Function *AggFunc = OG.OutlinedFunction; + // Create a final block + BasicBlock *ReturnBlock = + BasicBlock::Create(M.getContext(), "final_block", AggFunc); + Instruction *Term = EndBB->getTerminator(); + Term->moveBefore(*ReturnBlock, ReturnBlock->end()); + // Put the switch statement in the old end basic block for the function with + // a fall through to the new return block + LLVM_DEBUG(dbgs() << "Create switch statement in " << *AggFunc << " for " + << OutputStoreBBs.size() << "\n"); + SwitchInst *SwitchI = + SwitchInst::Create(AggFunc->getArg(AggFunc->arg_size() - 1), + ReturnBlock, OutputStoreBBs.size(), EndBB); + + unsigned Idx = 0; + for (BasicBlock *BB : OutputStoreBBs) { + SwitchI->addCase(ConstantInt::get(Type::getInt32Ty(M.getContext()), Idx), + BB); + Term = BB->getTerminator(); + Term->setSuccessor(0, ReturnBlock); + Idx++; + } + return; + } + + // If there needs to be stores, move them from the output block to the end + // block to save on branching instructions. + if (OutputStoreBBs.size() == 1) { + LLVM_DEBUG(dbgs() << "Move store instructions to the end block in " + << *OG.OutlinedFunction << "\n"); + BasicBlock *OutputBlock = OutputStoreBBs[0]; + Instruction *Term = OutputBlock->getTerminator(); + Term->eraseFromParent(); + Term = EndBB->getTerminator(); + moveBBContents(*OutputBlock, *EndBB); + Term->moveBefore(*EndBB, EndBB->end()); + OutputBlock->eraseFromParent(); + } +} + +/// Fill the new function that will serve as the replacement function for all of +/// the extracted regions of a certain structure from the first region in the +/// list of regions. Replace this first region's extracted function with the +/// new overall function. +/// +/// \param [in] M - The module we are outlining from. +/// \param [in] CurrentGroup - The group of regions to be outlined. +/// \param [in,out] OutputStoreBBs - The output blocks for each different +/// set of stores needed for the different functions. +/// \param [in,out] FuncsToRemove - Extracted functions to erase from module +/// once outlining is complete. +static void fillOverallFunction(Module &M, OutlinableGroup &CurrentGroup, + std::vector<BasicBlock *> &OutputStoreBBs, + std::vector<Function *> &FuncsToRemove) { + OutlinableRegion *CurrentOS = CurrentGroup.Regions[0]; + + // Move first extracted function's instructions into new function. + LLVM_DEBUG(dbgs() << "Move instructions from " + << *CurrentOS->ExtractedFunction << " to instruction " + << *CurrentGroup.OutlinedFunction << "\n"); + + CurrentGroup.EndBB = moveFunctionData(*CurrentOS->ExtractedFunction, + *CurrentGroup.OutlinedFunction); + + // Transfer the attributes from the function to the new function. + for (Attribute A : + CurrentOS->ExtractedFunction->getAttributes().getFnAttributes()) + CurrentGroup.OutlinedFunction->addFnAttr(A); + + // Create an output block for the first extracted function. + BasicBlock *NewBB = BasicBlock::Create( + M.getContext(), Twine("output_block_") + Twine(static_cast<unsigned>(0)), + CurrentGroup.OutlinedFunction); + CurrentOS->OutputBlockNum = 0; + + replaceArgumentUses(*CurrentOS, NewBB); + replaceConstants(*CurrentOS); + + // If the new basic block has no new stores, we can erase it from the module. + // It it does, we create a branch instruction to the last basic block from the + // new one. 
+ if (NewBB->size() == 0) { + CurrentOS->OutputBlockNum = -1; + NewBB->eraseFromParent(); + } else { + BranchInst::Create(CurrentGroup.EndBB, NewBB); + OutputStoreBBs.push_back(NewBB); + } + + // Replace the call to the extracted function with the outlined function. + CurrentOS->Call = replaceCalledFunction(M, *CurrentOS); + + // We only delete the extracted functions at the end since we may need to + // reference instructions contained in them for mapping purposes. + FuncsToRemove.push_back(CurrentOS->ExtractedFunction); +} + +void IROutliner::deduplicateExtractedSections( + Module &M, OutlinableGroup &CurrentGroup, + std::vector<Function *> &FuncsToRemove, unsigned &OutlinedFunctionNum) { + createFunction(M, CurrentGroup, OutlinedFunctionNum); + + std::vector<BasicBlock *> OutputStoreBBs; + + OutlinableRegion *CurrentOS; + + fillOverallFunction(M, CurrentGroup, OutputStoreBBs, FuncsToRemove); + + for (unsigned Idx = 1; Idx < CurrentGroup.Regions.size(); Idx++) { + CurrentOS = CurrentGroup.Regions[Idx]; + AttributeFuncs::mergeAttributesForOutlining(*CurrentGroup.OutlinedFunction, + *CurrentOS->ExtractedFunction); + + // Create a new BasicBlock to hold the needed store instructions. + BasicBlock *NewBB = BasicBlock::Create( + M.getContext(), "output_block_" + std::to_string(Idx), + CurrentGroup.OutlinedFunction); + replaceArgumentUses(*CurrentOS, NewBB); + + alignOutputBlockWithAggFunc(CurrentGroup, *CurrentOS, NewBB, + CurrentGroup.EndBB, OutputMappings, + OutputStoreBBs); + + CurrentOS->Call = replaceCalledFunction(M, *CurrentOS); + FuncsToRemove.push_back(CurrentOS->ExtractedFunction); + } + + // Create a switch statement to handle the different output schemes. + createSwitchStatement(M, CurrentGroup, CurrentGroup.EndBB, OutputStoreBBs); + + OutlinedFunctionNum++; +} + +void IROutliner::pruneIncompatibleRegions( + std::vector<IRSimilarityCandidate> &CandidateVec, + OutlinableGroup &CurrentGroup) { + bool PreviouslyOutlined; + + // Sort from beginning to end, so the IRSimilarityCandidates are in order. + stable_sort(CandidateVec, [](const IRSimilarityCandidate &LHS, + const IRSimilarityCandidate &RHS) { + return LHS.getStartIdx() < RHS.getStartIdx(); + }); + + unsigned CurrentEndIdx = 0; + for (IRSimilarityCandidate &IRSC : CandidateVec) { + PreviouslyOutlined = false; + unsigned StartIdx = IRSC.getStartIdx(); + unsigned EndIdx = IRSC.getEndIdx(); + + for (unsigned Idx = StartIdx; Idx <= EndIdx; Idx++) + if (Outlined.contains(Idx)) { + PreviouslyOutlined = true; + break; + } + + if (PreviouslyOutlined) + continue; + + // TODO: If in the future we can outline across BasicBlocks, we will need to + // check all BasicBlocks contained in the region. + if (IRSC.getStartBB()->hasAddressTaken()) + continue; + + if (IRSC.front()->Inst->getFunction()->hasLinkOnceODRLinkage() && + !OutlineFromLinkODRs) + continue; + + // Greedily prune out any regions that will overlap with already chosen + // regions. + if (CurrentEndIdx != 0 && StartIdx <= CurrentEndIdx) + continue; + + bool BadInst = any_of(IRSC, [this](IRInstructionData &ID) { + // We check if there is a discrepancy between the InstructionDataList + // and the actual next instruction in the module. If there is, it means + // that an extra instruction was added, likely by the CodeExtractor. + + // Since we do not have any similarity data about this particular + // instruction, we cannot confidently outline it, and must discard this + // candidate. 
+ if (std::next(ID.getIterator())->Inst != + ID.Inst->getNextNonDebugInstruction()) + return true; + return !this->InstructionClassifier.visit(ID.Inst); + }); + + if (BadInst) + continue; + + OutlinableRegion *OS = new (RegionAllocator.Allocate()) + OutlinableRegion(IRSC, CurrentGroup); + CurrentGroup.Regions.push_back(OS); + + CurrentEndIdx = EndIdx; + } +} + +InstructionCost +IROutliner::findBenefitFromAllRegions(OutlinableGroup &CurrentGroup) { + InstructionCost RegionBenefit = 0; + for (OutlinableRegion *Region : CurrentGroup.Regions) { + TargetTransformInfo &TTI = getTTI(*Region->StartBB->getParent()); + // We add the number of instructions in the region to the benefit as an + // estimate as to how much will be removed. + RegionBenefit += Region->getBenefit(TTI); + LLVM_DEBUG(dbgs() << "Adding: " << RegionBenefit + << " saved instructions to overfall benefit.\n"); + } + + return RegionBenefit; +} + +InstructionCost +IROutliner::findCostOutputReloads(OutlinableGroup &CurrentGroup) { + InstructionCost OverallCost = 0; + for (OutlinableRegion *Region : CurrentGroup.Regions) { + TargetTransformInfo &TTI = getTTI(*Region->StartBB->getParent()); + + // Each output incurs a load after the call, so we add that to the cost. + for (unsigned OutputGVN : Region->GVNStores) { + Optional<Value *> OV = Region->Candidate->fromGVN(OutputGVN); + assert(OV.hasValue() && "Could not find value for GVN?"); + Value *V = OV.getValue(); + InstructionCost LoadCost = + TTI.getMemoryOpCost(Instruction::Load, V->getType(), Align(1), 0, + TargetTransformInfo::TCK_CodeSize); + + LLVM_DEBUG(dbgs() << "Adding: " << LoadCost + << " instructions to cost for output of type " + << *V->getType() << "\n"); + OverallCost += LoadCost; + } + } + + return OverallCost; +} + +/// Find the extra instructions needed to handle any output values for the +/// region. +/// +/// \param [in] M - The Module to outline from. +/// \param [in] CurrentGroup - The collection of OutlinableRegions to analyze. +/// \param [in] TTI - The TargetTransformInfo used to collect information for +/// new instruction costs. +/// \returns the additional cost to handle the outputs. +static InstructionCost findCostForOutputBlocks(Module &M, + OutlinableGroup &CurrentGroup, + TargetTransformInfo &TTI) { + InstructionCost OutputCost = 0; + + for (const ArrayRef<unsigned> &OutputUse : + CurrentGroup.OutputGVNCombinations) { + IRSimilarityCandidate &Candidate = *CurrentGroup.Regions[0]->Candidate; + for (unsigned GVN : OutputUse) { + Optional<Value *> OV = Candidate.fromGVN(GVN); + assert(OV.hasValue() && "Could not find value for GVN?"); + Value *V = OV.getValue(); + InstructionCost StoreCost = + TTI.getMemoryOpCost(Instruction::Load, V->getType(), Align(1), 0, + TargetTransformInfo::TCK_CodeSize); + + // An instruction cost is added for each store set that needs to occur for + // various output combinations inside the function, plus a branch to + // return to the exit block. + LLVM_DEBUG(dbgs() << "Adding: " << StoreCost + << " instructions to cost for output of type " + << *V->getType() << "\n"); + OutputCost += StoreCost; + } + + InstructionCost BranchCost = + TTI.getCFInstrCost(Instruction::Br, TargetTransformInfo::TCK_CodeSize); + LLVM_DEBUG(dbgs() << "Adding " << BranchCost << " to the current cost for" + << " a branch instruction\n"); + OutputCost += BranchCost; + } + + // If there is more than one output scheme, we must have a comparison and + // branch for each different item in the switch statement. 
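Putting the cost model together: the benefit is the instructions removed across all regions, while the cost covers one copy of the region body kept in the new function, per-argument handling inside it, per-call-site argument setup, and the output-store and branching overhead being computed here. A back-of-the-envelope sketch with assumed numbers, not the pass's exact arithmetic:

static int exampleNetBenefit() {
  const int NumRegions = 3, RegionSize = 10, NumArgs = 4, TCCBasic = 1;
  int Benefit = NumRegions * RegionSize;            // instructions removed
  int Cost = RegionSize                             // body kept in the new function
             + NumArgs * TCCBasic                   // argument handling in the callee
             + 2 * NumArgs * TCCBasic * NumRegions; // call-site argument setup
  // Output stores, branches and any dispatch switch are added to Cost on top.
  return Benefit - Cost; // the group is only outlined when this is positive
}

With these assumed numbers the estimate is negative, which is why small regions with many arguments are typically rejected by the cost model.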
+ if (CurrentGroup.OutputGVNCombinations.size() > 1) { + InstructionCost ComparisonCost = TTI.getCmpSelInstrCost( + Instruction::ICmp, Type::getInt32Ty(M.getContext()), + Type::getInt32Ty(M.getContext()), CmpInst::BAD_ICMP_PREDICATE, + TargetTransformInfo::TCK_CodeSize); + InstructionCost BranchCost = + TTI.getCFInstrCost(Instruction::Br, TargetTransformInfo::TCK_CodeSize); + + unsigned DifferentBlocks = CurrentGroup.OutputGVNCombinations.size(); + InstructionCost TotalCost = ComparisonCost * BranchCost * DifferentBlocks; + + LLVM_DEBUG(dbgs() << "Adding: " << TotalCost + << " instructions for each switch case for each different" + << " output path in a function\n"); + OutputCost += TotalCost; + } + + return OutputCost; +} + +void IROutliner::findCostBenefit(Module &M, OutlinableGroup &CurrentGroup) { + InstructionCost RegionBenefit = findBenefitFromAllRegions(CurrentGroup); + CurrentGroup.Benefit += RegionBenefit; + LLVM_DEBUG(dbgs() << "Current Benefit: " << CurrentGroup.Benefit << "\n"); + + InstructionCost OutputReloadCost = findCostOutputReloads(CurrentGroup); + CurrentGroup.Cost += OutputReloadCost; + LLVM_DEBUG(dbgs() << "Current Cost: " << CurrentGroup.Cost << "\n"); + + InstructionCost AverageRegionBenefit = + RegionBenefit / CurrentGroup.Regions.size(); + unsigned OverallArgumentNum = CurrentGroup.ArgumentTypes.size(); + unsigned NumRegions = CurrentGroup.Regions.size(); + TargetTransformInfo &TTI = + getTTI(*CurrentGroup.Regions[0]->Candidate->getFunction()); + + // We add one region to the cost once, to account for the instructions added + // inside of the newly created function. + LLVM_DEBUG(dbgs() << "Adding: " << AverageRegionBenefit + << " instructions to cost for body of new function.\n"); + CurrentGroup.Cost += AverageRegionBenefit; + LLVM_DEBUG(dbgs() << "Current Cost: " << CurrentGroup.Cost << "\n"); + + // For each argument, we must add an instruction for loading the argument + // out of the register and into a value inside of the newly outlined function. + LLVM_DEBUG(dbgs() << "Adding: " << OverallArgumentNum + << " instructions to cost for each argument in the new" + << " function.\n"); + CurrentGroup.Cost += + OverallArgumentNum * TargetTransformInfo::TCC_Basic; + LLVM_DEBUG(dbgs() << "Current Cost: " << CurrentGroup.Cost << "\n"); + + // Each argument needs to either be loaded into a register or onto the stack. + // Some arguments will only be loaded into the stack once the argument + // registers are filled. + LLVM_DEBUG(dbgs() << "Adding: " << OverallArgumentNum + << " instructions to cost for each argument in the new" + << " function " << NumRegions << " times for the " + << "needed argument handling at the call site.\n"); + CurrentGroup.Cost += + 2 * OverallArgumentNum * TargetTransformInfo::TCC_Basic * NumRegions; + LLVM_DEBUG(dbgs() << "Current Cost: " << CurrentGroup.Cost << "\n"); + + CurrentGroup.Cost += findCostForOutputBlocks(M, CurrentGroup, TTI); + LLVM_DEBUG(dbgs() << "Current Cost: " << CurrentGroup.Cost << "\n"); +} + +void IROutliner::updateOutputMapping(OutlinableRegion &Region, + ArrayRef<Value *> Outputs, + LoadInst *LI) { + // For and load instructions following the call + Value *Operand = LI->getPointerOperand(); + Optional<unsigned> OutputIdx = None; + // Find if the operand it is an output register. 
+ for (unsigned ArgIdx = Region.NumExtractedInputs; + ArgIdx < Region.Call->arg_size(); ArgIdx++) { + if (Operand == Region.Call->getArgOperand(ArgIdx)) { + OutputIdx = ArgIdx - Region.NumExtractedInputs; + break; + } + } + + // If we found an output register, place a mapping of the new value + // to the original in the mapping. + if (!OutputIdx.hasValue()) + return; + + if (OutputMappings.find(Outputs[OutputIdx.getValue()]) == + OutputMappings.end()) { + LLVM_DEBUG(dbgs() << "Mapping extracted output " << *LI << " to " + << *Outputs[OutputIdx.getValue()] << "\n"); + OutputMappings.insert(std::make_pair(LI, Outputs[OutputIdx.getValue()])); + } else { + Value *Orig = OutputMappings.find(Outputs[OutputIdx.getValue()])->second; + LLVM_DEBUG(dbgs() << "Mapping extracted output " << *Orig << " to " + << *Outputs[OutputIdx.getValue()] << "\n"); + OutputMappings.insert(std::make_pair(LI, Orig)); + } +} + +bool IROutliner::extractSection(OutlinableRegion &Region) { + SetVector<Value *> ArgInputs, Outputs, SinkCands; + Region.CE->findInputsOutputs(ArgInputs, Outputs, SinkCands); + + assert(Region.StartBB && "StartBB for the OutlinableRegion is nullptr!"); + assert(Region.FollowBB && "FollowBB for the OutlinableRegion is nullptr!"); + Function *OrigF = Region.StartBB->getParent(); + CodeExtractorAnalysisCache CEAC(*OrigF); + Region.ExtractedFunction = Region.CE->extractCodeRegion(CEAC); + + // If the extraction was successful, find the BasicBlock, and reassign the + // OutlinableRegion blocks + if (!Region.ExtractedFunction) { + LLVM_DEBUG(dbgs() << "CodeExtractor failed to outline " << Region.StartBB + << "\n"); + Region.reattachCandidate(); + return false; + } + + BasicBlock *RewrittenBB = Region.FollowBB->getSinglePredecessor(); + Region.StartBB = RewrittenBB; + Region.EndBB = RewrittenBB; + + // The sequences of outlinable regions has now changed. We must fix the + // IRInstructionDataList for consistency. Although they may not be illegal + // instructions, they should not be compared with anything else as they + // should not be outlined in this round. So marking these as illegal is + // allowed. + IRInstructionDataList *IDL = Region.Candidate->front()->IDL; + Instruction *BeginRewritten = &*RewrittenBB->begin(); + Instruction *EndRewritten = &*RewrittenBB->begin(); + Region.NewFront = new (InstDataAllocator.Allocate()) IRInstructionData( + *BeginRewritten, InstructionClassifier.visit(*BeginRewritten), *IDL); + Region.NewBack = new (InstDataAllocator.Allocate()) IRInstructionData( + *EndRewritten, InstructionClassifier.visit(*EndRewritten), *IDL); + + // Insert the first IRInstructionData of the new region in front of the + // first IRInstructionData of the IRSimilarityCandidate. + IDL->insert(Region.Candidate->begin(), *Region.NewFront); + // Insert the first IRInstructionData of the new region after the + // last IRInstructionData of the IRSimilarityCandidate. + IDL->insert(Region.Candidate->end(), *Region.NewBack); + // Remove the IRInstructionData from the IRSimilarityCandidate. + IDL->erase(Region.Candidate->begin(), std::prev(Region.Candidate->end())); + + assert(RewrittenBB != nullptr && + "Could not find a predecessor after extraction!"); + + // Iterate over the new set of instructions to find the new call + // instruction. 
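The scan that follows amounts to locating the call to the freshly extracted function (and, for loads, recording output mappings). In isolation, the call lookup looks roughly like this (the helper name is ours):

#include "llvm/IR/BasicBlock.h"
#include "llvm/IR/Instructions.h"
using namespace llvm;

static CallInst *findCallTo(BasicBlock &BB, Function *Callee) {
  for (Instruction &I : BB)
    if (auto *CI = dyn_cast<CallInst>(&I))
      if (CI->getCalledFunction() == Callee)
        return CI;
  return nullptr;
}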
+ for (Instruction &I : *RewrittenBB) + if (CallInst *CI = dyn_cast<CallInst>(&I)) { + if (Region.ExtractedFunction == CI->getCalledFunction()) + Region.Call = CI; + } else if (LoadInst *LI = dyn_cast<LoadInst>(&I)) + updateOutputMapping(Region, Outputs.getArrayRef(), LI); + Region.reattachCandidate(); + return true; +} + +unsigned IROutliner::doOutline(Module &M) { + // Find the possible similarity sections. + IRSimilarityIdentifier &Identifier = getIRSI(M); + SimilarityGroupList &SimilarityCandidates = *Identifier.getSimilarity(); + + // Sort them by size of extracted sections + unsigned OutlinedFunctionNum = 0; + // If we only have one SimilarityGroup in SimilarityCandidates, we do not have + // to sort them by the potential number of instructions to be outlined + if (SimilarityCandidates.size() > 1) + llvm::stable_sort(SimilarityCandidates, + [](const std::vector<IRSimilarityCandidate> &LHS, + const std::vector<IRSimilarityCandidate> &RHS) { + return LHS[0].getLength() * LHS.size() > + RHS[0].getLength() * RHS.size(); + }); + + DenseSet<unsigned> NotSame; + std::vector<Function *> FuncsToRemove; + // Iterate over the possible sets of similarity. + for (SimilarityGroup &CandidateVec : SimilarityCandidates) { + OutlinableGroup CurrentGroup; + + // Remove entries that were previously outlined + pruneIncompatibleRegions(CandidateVec, CurrentGroup); + + // We pruned the number of regions to 0 to 1, meaning that it's not worth + // trying to outlined since there is no compatible similar instance of this + // code. + if (CurrentGroup.Regions.size() < 2) + continue; + + // Determine if there are any values that are the same constant throughout + // each section in the set. + NotSame.clear(); + CurrentGroup.findSameConstants(NotSame); + + if (CurrentGroup.IgnoreGroup) + continue; + + // Create a CodeExtractor for each outlinable region. Identify inputs and + // outputs for each section using the code extractor and create the argument + // types for the Aggregate Outlining Function. + std::vector<OutlinableRegion *> OutlinedRegions; + for (OutlinableRegion *OS : CurrentGroup.Regions) { + // Break the outlinable region out of its parent BasicBlock into its own + // BasicBlocks (see function implementation). 
+ OS->splitCandidate(); + std::vector<BasicBlock *> BE = {OS->StartBB}; + OS->CE = new (ExtractorAllocator.Allocate()) + CodeExtractor(BE, nullptr, false, nullptr, nullptr, nullptr, false, + false, "outlined"); + findAddInputsOutputs(M, *OS, NotSame); + if (!OS->IgnoreRegion) + OutlinedRegions.push_back(OS); + else + OS->reattachCandidate(); + } + + CurrentGroup.Regions = std::move(OutlinedRegions); + + if (CurrentGroup.Regions.empty()) + continue; + + CurrentGroup.collectGVNStoreSets(M); + + if (CostModel) + findCostBenefit(M, CurrentGroup); + + // If we are adhering to the cost model, reattach all the candidates + if (CurrentGroup.Cost >= CurrentGroup.Benefit && CostModel) { + for (OutlinableRegion *OS : CurrentGroup.Regions) + OS->reattachCandidate(); + OptimizationRemarkEmitter &ORE = getORE( + *CurrentGroup.Regions[0]->Candidate->getFunction()); + ORE.emit([&]() { + IRSimilarityCandidate *C = CurrentGroup.Regions[0]->Candidate; + OptimizationRemarkMissed R(DEBUG_TYPE, "WouldNotDecreaseSize", + C->frontInstruction()); + R << "did not outline " + << ore::NV(std::to_string(CurrentGroup.Regions.size())) + << " regions due to estimated increase of " + << ore::NV("InstructionIncrease", + CurrentGroup.Cost - CurrentGroup.Benefit) + << " instructions at locations "; + interleave( + CurrentGroup.Regions.begin(), CurrentGroup.Regions.end(), + [&R](OutlinableRegion *Region) { + R << ore::NV( + "DebugLoc", + Region->Candidate->frontInstruction()->getDebugLoc()); + }, + [&R]() { R << " "; }); + return R; + }); + continue; + } + + LLVM_DEBUG(dbgs() << "Outlining regions with cost " << CurrentGroup.Cost + << " and benefit " << CurrentGroup.Benefit << "\n"); + + // Create functions out of all the sections, and mark them as outlined. + OutlinedRegions.clear(); + for (OutlinableRegion *OS : CurrentGroup.Regions) { + bool FunctionOutlined = extractSection(*OS); + if (FunctionOutlined) { + unsigned StartIdx = OS->Candidate->getStartIdx(); + unsigned EndIdx = OS->Candidate->getEndIdx(); + for (unsigned Idx = StartIdx; Idx <= EndIdx; Idx++) + Outlined.insert(Idx); + + OutlinedRegions.push_back(OS); + } + } + + LLVM_DEBUG(dbgs() << "Outlined " << OutlinedRegions.size() + << " with benefit " << CurrentGroup.Benefit + << " and cost " << CurrentGroup.Cost << "\n"); + + CurrentGroup.Regions = std::move(OutlinedRegions); + + if (CurrentGroup.Regions.empty()) + continue; + + OptimizationRemarkEmitter &ORE = + getORE(*CurrentGroup.Regions[0]->Call->getFunction()); + ORE.emit([&]() { + IRSimilarityCandidate *C = CurrentGroup.Regions[0]->Candidate; + OptimizationRemark R(DEBUG_TYPE, "Outlined", C->front()->Inst); + R << "outlined " << ore::NV(std::to_string(CurrentGroup.Regions.size())) + << " regions with decrease of " + << ore::NV("Benefit", CurrentGroup.Benefit - CurrentGroup.Cost) + << " instructions at locations "; + interleave( + CurrentGroup.Regions.begin(), CurrentGroup.Regions.end(), + [&R](OutlinableRegion *Region) { + R << ore::NV("DebugLoc", + Region->Candidate->frontInstruction()->getDebugLoc()); + }, + [&R]() { R << " "; }); + return R; + }); + + deduplicateExtractedSections(M, CurrentGroup, FuncsToRemove, + OutlinedFunctionNum); + } + + for (Function *F : FuncsToRemove) + F->eraseFromParent(); + + return OutlinedFunctionNum; +} + +bool IROutliner::run(Module &M) { + CostModel = !NoCostModel; + OutlineFromLinkODRs = EnableLinkOnceODRIROutlining; + + return doOutline(M) > 0; +} + +// Pass Manager Boilerplate +class IROutlinerLegacyPass : public ModulePass { +public: + static char ID; + 
IROutlinerLegacyPass() : ModulePass(ID) { + initializeIROutlinerLegacyPassPass(*PassRegistry::getPassRegistry()); + } + + void getAnalysisUsage(AnalysisUsage &AU) const override { + AU.addRequired<OptimizationRemarkEmitterWrapperPass>(); + AU.addRequired<TargetTransformInfoWrapperPass>(); + AU.addRequired<IRSimilarityIdentifierWrapperPass>(); + } + + bool runOnModule(Module &M) override; +}; + +bool IROutlinerLegacyPass::runOnModule(Module &M) { + if (skipModule(M)) + return false; + + std::unique_ptr<OptimizationRemarkEmitter> ORE; + auto GORE = [&ORE](Function &F) -> OptimizationRemarkEmitter & { + ORE.reset(new OptimizationRemarkEmitter(&F)); + return *ORE.get(); + }; + + auto GTTI = [this](Function &F) -> TargetTransformInfo & { + return this->getAnalysis<TargetTransformInfoWrapperPass>().getTTI(F); + }; + + auto GIRSI = [this](Module &) -> IRSimilarityIdentifier & { + return this->getAnalysis<IRSimilarityIdentifierWrapperPass>().getIRSI(); + }; + + return IROutliner(GTTI, GIRSI, GORE).run(M); +} + +PreservedAnalyses IROutlinerPass::run(Module &M, ModuleAnalysisManager &AM) { + auto &FAM = AM.getResult<FunctionAnalysisManagerModuleProxy>(M).getManager(); + + std::function<TargetTransformInfo &(Function &)> GTTI = + [&FAM](Function &F) -> TargetTransformInfo & { + return FAM.getResult<TargetIRAnalysis>(F); + }; + + std::function<IRSimilarityIdentifier &(Module &)> GIRSI = + [&AM](Module &M) -> IRSimilarityIdentifier & { + return AM.getResult<IRSimilarityAnalysis>(M); + }; + + std::unique_ptr<OptimizationRemarkEmitter> ORE; + std::function<OptimizationRemarkEmitter &(Function &)> GORE = + [&ORE](Function &F) -> OptimizationRemarkEmitter & { + ORE.reset(new OptimizationRemarkEmitter(&F)); + return *ORE.get(); + }; + + if (IROutliner(GTTI, GIRSI, GORE).run(M)) + return PreservedAnalyses::none(); + return PreservedAnalyses::all(); +} + +char IROutlinerLegacyPass::ID = 0; +INITIALIZE_PASS_BEGIN(IROutlinerLegacyPass, "iroutliner", "IR Outliner", false, + false) +INITIALIZE_PASS_DEPENDENCY(IRSimilarityIdentifierWrapperPass) +INITIALIZE_PASS_DEPENDENCY(OptimizationRemarkEmitterWrapperPass) +INITIALIZE_PASS_DEPENDENCY(TargetTransformInfoWrapperPass) +INITIALIZE_PASS_END(IROutlinerLegacyPass, "iroutliner", "IR Outliner", false, + false) + +ModulePass *llvm::createIROutlinerPass() { return new IROutlinerLegacyPass(); } diff --git a/contrib/llvm-project/llvm/lib/Transforms/IPO/Inliner.cpp b/contrib/llvm-project/llvm/lib/Transforms/IPO/Inliner.cpp index 7d2260f4c169..e91b6c9b1d26 100644 --- a/contrib/llvm-project/llvm/lib/Transforms/IPO/Inliner.cpp +++ b/contrib/llvm-project/llvm/lib/Transforms/IPO/Inliner.cpp @@ -23,7 +23,6 @@ #include "llvm/ADT/SmallVector.h" #include "llvm/ADT/Statistic.h" #include "llvm/ADT/StringRef.h" -#include "llvm/Analysis/AliasAnalysis.h" #include "llvm/Analysis/AssumptionCache.h" #include "llvm/Analysis/BasicAliasAnalysis.h" #include "llvm/Analysis/BlockFrequencyInfo.h" @@ -37,6 +36,7 @@ #include "llvm/Analysis/ProfileSummaryInfo.h" #include "llvm/Analysis/TargetLibraryInfo.h" #include "llvm/Analysis/TargetTransformInfo.h" +#include "llvm/Analysis/Utils/ImportedFunctionsInliningStatistics.h" #include "llvm/IR/Attributes.h" #include "llvm/IR/BasicBlock.h" #include "llvm/IR/DataLayout.h" @@ -60,7 +60,6 @@ #include "llvm/Support/raw_ostream.h" #include "llvm/Transforms/Utils/CallPromotionUtils.h" #include "llvm/Transforms/Utils/Cloning.h" -#include "llvm/Transforms/Utils/ImportedFunctionsInliningStatistics.h" #include "llvm/Transforms/Utils/Local.h" #include 
"llvm/Transforms/Utils/ModuleUtils.h" #include <algorithm> @@ -91,24 +90,14 @@ static cl::opt<bool> DisableInlinedAllocaMerging("disable-inlined-alloca-merging", cl::init(false), cl::Hidden); -namespace { +extern cl::opt<InlinerFunctionImportStatsOpts> InlinerFunctionImportStats; -enum class InlinerFunctionImportStatsOpts { - No = 0, - Basic = 1, - Verbose = 2, -}; - -} // end anonymous namespace - -static cl::opt<InlinerFunctionImportStatsOpts> InlinerFunctionImportStats( - "inliner-function-import-stats", - cl::init(InlinerFunctionImportStatsOpts::No), - cl::values(clEnumValN(InlinerFunctionImportStatsOpts::Basic, "basic", - "basic statistics"), - clEnumValN(InlinerFunctionImportStatsOpts::Verbose, "verbose", - "printing of statistics for each inlined function")), - cl::Hidden, cl::desc("Enable inliner stats for imported functions")); +static cl::opt<std::string> CGSCCInlineReplayFile( + "cgscc-inline-replay", cl::init(""), cl::value_desc("filename"), + cl::desc( + "Optimization remarks file containing inline remarks to be replayed " + "by inlining from cgscc inline remarks."), + cl::Hidden); LegacyInlinerBase::LegacyInlinerBase(char &ID) : CallGraphSCCPass(ID) {} @@ -648,17 +637,12 @@ bool LegacyInlinerBase::removeDeadFunctions(CallGraph &CG, return true; } -InlinerPass::~InlinerPass() { - if (ImportedFunctionsStats) { - assert(InlinerFunctionImportStats != InlinerFunctionImportStatsOpts::No); - ImportedFunctionsStats->dump(InlinerFunctionImportStats == - InlinerFunctionImportStatsOpts::Verbose); - } -} - InlineAdvisor & InlinerPass::getAdvisor(const ModuleAnalysisManagerCGSCCProxy::Result &MAM, FunctionAnalysisManager &FAM, Module &M) { + if (OwnedAdvisor) + return *OwnedAdvisor; + auto *IAA = MAM.getCachedResult<InlineAdvisorAnalysis>(M); if (!IAA) { // It should still be possible to run the inliner as a stand-alone SCC pass, @@ -669,8 +653,16 @@ InlinerPass::getAdvisor(const ModuleAnalysisManagerCGSCCProxy::Result &MAM, // duration of the inliner pass, and thus the lifetime of the owned advisor. // The one we would get from the MAM can be invalidated as a result of the // inliner's activity. - OwnedDefaultAdvisor.emplace(FAM, getInlineParams()); - return *OwnedDefaultAdvisor; + OwnedAdvisor = + std::make_unique<DefaultInlineAdvisor>(M, FAM, getInlineParams()); + + if (!CGSCCInlineReplayFile.empty()) + OwnedAdvisor = std::make_unique<ReplayInlineAdvisor>( + M, FAM, M.getContext(), std::move(OwnedAdvisor), + CGSCCInlineReplayFile, + /*EmitRemarks=*/true); + + return *OwnedAdvisor; } assert(IAA->getAdvisor() && "Expected a present InlineAdvisorAnalysis also have an " @@ -698,20 +690,13 @@ PreservedAnalyses InlinerPass::run(LazyCallGraph::SCC &InitialC, auto AdvisorOnExit = make_scope_exit([&] { Advisor.onPassExit(); }); - if (!ImportedFunctionsStats && - InlinerFunctionImportStats != InlinerFunctionImportStatsOpts::No) { - ImportedFunctionsStats = - std::make_unique<ImportedFunctionsInliningStatistics>(); - ImportedFunctionsStats->setModuleInfo(M); - } - // We use a single common worklist for calls across the entire SCC. We // process these in-order and append new calls introduced during inlining to // the end. // // Note that this particular order of processing is actually critical to // avoid very bad behaviors. Consider *highly connected* call graphs where - // each function contains a small amonut of code and a couple of calls to + // each function contains a small amount of code and a couple of calls to // other functions. 
Because the LLVM inliner is fundamentally a bottom-up // inliner, it can handle gracefully the fact that these all appear to be // reasonable inlining candidates as it will flatten things until they become @@ -761,9 +746,8 @@ PreservedAnalyses InlinerPass::run(LazyCallGraph::SCC &InitialC, if (Calls.empty()) return PreservedAnalyses::all(); - // Capture updatable variables for the current SCC and RefSCC. + // Capture updatable variable for the current SCC. auto *C = &InitialC; - auto *RC = &C->getOuterRefSCC(); // When inlining a callee produces new call sites, we want to keep track of // the fact that they were inlined from the callee. This allows us to avoid @@ -791,12 +775,6 @@ PreservedAnalyses InlinerPass::run(LazyCallGraph::SCC &InitialC, LazyCallGraph::Node &N = *CG.lookup(F); if (CG.lookupSCC(N) != C) continue; - if (!Calls[I].first->getCalledFunction()->hasFnAttribute( - Attribute::AlwaysInline) && - F.hasOptNone()) { - setInlineRemark(*Calls[I].first, "optnone attribute"); - continue; - } LLVM_DEBUG(dbgs() << "Inlining calls in: " << F.getName() << "\n"); @@ -834,7 +812,7 @@ PreservedAnalyses InlinerPass::run(LazyCallGraph::SCC &InitialC, continue; } - auto Advice = Advisor.getAdvice(*CB); + auto Advice = Advisor.getAdvice(*CB, OnlyMandatory); // Check whether we want to inline this callsite. if (!Advice->isInliningRecommended()) { Advice->recordUnattemptedInlining(); @@ -848,7 +826,8 @@ PreservedAnalyses InlinerPass::run(LazyCallGraph::SCC &InitialC, &FAM.getResult<BlockFrequencyAnalysis>(*(CB->getCaller())), &FAM.getResult<BlockFrequencyAnalysis>(Callee)); - InlineResult IR = InlineFunction(*CB, IFI); + InlineResult IR = + InlineFunction(*CB, IFI, &FAM.getResult<AAManager>(*CB->getCaller())); if (!IR.isSuccess()) { Advice->recordUnsuccessfulInlining(IR); continue; @@ -879,9 +858,6 @@ PreservedAnalyses InlinerPass::run(LazyCallGraph::SCC &InitialC, } } - if (InlinerFunctionImportStats != InlinerFunctionImportStatsOpts::No) - ImportedFunctionsStats->recordInline(F, Callee); - // Merge the attributes based on the inlining. AttributeFuncs::mergeAttributesForInlining(F, Callee); @@ -906,7 +882,7 @@ PreservedAnalyses InlinerPass::run(LazyCallGraph::SCC &InitialC, // Note that after this point, it is an error to do anything other // than use the callee's address or delete it. Callee.dropAllReferences(); - assert(find(DeadFunctions, &Callee) == DeadFunctions.end() && + assert(!is_contained(DeadFunctions, &Callee) && "Cannot put cause a function to become dead twice!"); DeadFunctions.push_back(&Callee); CalleeWasDeleted = true; @@ -926,20 +902,6 @@ PreservedAnalyses InlinerPass::run(LazyCallGraph::SCC &InitialC, continue; Changed = true; - // Add all the inlined callees' edges as ref edges to the caller. These are - // by definition trivial edges as we always have *some* transitive ref edge - // chain. While in some cases these edges are direct calls inside the - // callee, they have to be modeled in the inliner as reference edges as - // there may be a reference edge anywhere along the chain from the current - // caller to the callee that causes the whole thing to appear like - // a (transitive) reference edge that will require promotion to a call edge - // below. - for (Function *InlinedCallee : InlinedCallees) { - LazyCallGraph::Node &CalleeN = *CG.lookup(*InlinedCallee); - for (LazyCallGraph::Edge &E : *CalleeN) - RC->insertTrivialRefEdge(N, E.getNode()); - } - // At this point, since we have made changes we have at least removed // a call instruction. 
However, in the process we do some incremental // simplification of the surrounding code. This simplification can @@ -952,9 +914,8 @@ PreservedAnalyses InlinerPass::run(LazyCallGraph::SCC &InitialC, // as we're going to mutate this particular function we want to make sure // the proxy is in place to forward any invalidation events. LazyCallGraph::SCC *OldC = C; - C = &updateCGAndAnalysisManagerForFunctionPass(CG, *C, N, AM, UR, FAM); + C = &updateCGAndAnalysisManagerForCGSCCPass(CG, *C, N, AM, UR, FAM); LLVM_DEBUG(dbgs() << "Updated inlining SCC: " << *C << "\n"); - RC = &C->getOuterRefSCC(); // If this causes an SCC to split apart into multiple smaller SCCs, there // is a subtle risk we need to prepare for. Other transformations may @@ -1033,6 +994,7 @@ PreservedAnalyses InlinerPass::run(LazyCallGraph::SCC &InitialC, ModuleInlinerWrapperPass::ModuleInlinerWrapperPass(InlineParams Params, bool Debugging, + bool MandatoryFirst, InliningAdvisorMode Mode, unsigned MaxDevirtIterations) : Params(Params), Mode(Mode), MaxDevirtIterations(MaxDevirtIterations), @@ -1042,13 +1004,15 @@ ModuleInlinerWrapperPass::ModuleInlinerWrapperPass(InlineParams Params, // into the callers so that our optimizations can reflect that. // For PreLinkThinLTO pass, we disable hot-caller heuristic for sample PGO // because it makes profile annotation in the backend inaccurate. + if (MandatoryFirst) + PM.addPass(InlinerPass(/*OnlyMandatory*/ true)); PM.addPass(InlinerPass()); } PreservedAnalyses ModuleInlinerWrapperPass::run(Module &M, ModuleAnalysisManager &MAM) { auto &IAA = MAM.getResult<InlineAdvisorAnalysis>(M); - if (!IAA.tryCreate(Params, Mode)) { + if (!IAA.tryCreate(Params, Mode, CGSCCInlineReplayFile)) { M.getContext().emitError( "Could not setup Inlining Advisor for the requested " "mode and/or options"); diff --git a/contrib/llvm-project/llvm/lib/Transforms/IPO/LoopExtractor.cpp b/contrib/llvm-project/llvm/lib/Transforms/IPO/LoopExtractor.cpp index f7f5b4cf6704..a497c0390bce 100644 --- a/contrib/llvm-project/llvm/lib/Transforms/IPO/LoopExtractor.cpp +++ b/contrib/llvm-project/llvm/lib/Transforms/IPO/LoopExtractor.cpp @@ -13,12 +13,14 @@ // //===----------------------------------------------------------------------===// +#include "llvm/Transforms/IPO/LoopExtractor.h" #include "llvm/ADT/Statistic.h" #include "llvm/Analysis/AssumptionCache.h" #include "llvm/Analysis/LoopInfo.h" #include "llvm/IR/Dominators.h" #include "llvm/IR/Instructions.h" #include "llvm/IR/Module.h" +#include "llvm/IR/PassManager.h" #include "llvm/InitializePasses.h" #include "llvm/Pass.h" #include "llvm/Support/CommandLine.h" @@ -36,51 +38,71 @@ using namespace llvm; STATISTIC(NumExtracted, "Number of loops extracted"); namespace { - struct LoopExtractor : public ModulePass { - static char ID; // Pass identification, replacement for typeid +struct LoopExtractorLegacyPass : public ModulePass { + static char ID; // Pass identification, replacement for typeid - // The number of natural loops to extract from the program into functions. 
- unsigned NumLoops; + unsigned NumLoops; - explicit LoopExtractor(unsigned numLoops = ~0) - : ModulePass(ID), NumLoops(numLoops) { - initializeLoopExtractorPass(*PassRegistry::getPassRegistry()); - } - - bool runOnModule(Module &M) override; - bool runOnFunction(Function &F); + explicit LoopExtractorLegacyPass(unsigned NumLoops = ~0) + : ModulePass(ID), NumLoops(NumLoops) { + initializeLoopExtractorLegacyPassPass(*PassRegistry::getPassRegistry()); + } - bool extractLoops(Loop::iterator From, Loop::iterator To, LoopInfo &LI, - DominatorTree &DT); - bool extractLoop(Loop *L, LoopInfo &LI, DominatorTree &DT); + bool runOnModule(Module &M) override; - void getAnalysisUsage(AnalysisUsage &AU) const override { - AU.addRequiredID(BreakCriticalEdgesID); - AU.addRequired<DominatorTreeWrapperPass>(); - AU.addRequired<LoopInfoWrapperPass>(); - AU.addPreserved<LoopInfoWrapperPass>(); - AU.addRequiredID(LoopSimplifyID); - AU.addUsedIfAvailable<AssumptionCacheTracker>(); - } - }; -} - -char LoopExtractor::ID = 0; -INITIALIZE_PASS_BEGIN(LoopExtractor, "loop-extract", + void getAnalysisUsage(AnalysisUsage &AU) const override { + AU.addRequiredID(BreakCriticalEdgesID); + AU.addRequired<DominatorTreeWrapperPass>(); + AU.addRequired<LoopInfoWrapperPass>(); + AU.addPreserved<LoopInfoWrapperPass>(); + AU.addRequiredID(LoopSimplifyID); + AU.addUsedIfAvailable<AssumptionCacheTracker>(); + } +}; + +struct LoopExtractor { + explicit LoopExtractor( + unsigned NumLoops, + function_ref<DominatorTree &(Function &)> LookupDomTree, + function_ref<LoopInfo &(Function &)> LookupLoopInfo, + function_ref<AssumptionCache *(Function &)> LookupAssumptionCache) + : NumLoops(NumLoops), LookupDomTree(LookupDomTree), + LookupLoopInfo(LookupLoopInfo), + LookupAssumptionCache(LookupAssumptionCache) {} + bool runOnModule(Module &M); + +private: + // The number of natural loops to extract from the program into functions. + unsigned NumLoops; + + function_ref<DominatorTree &(Function &)> LookupDomTree; + function_ref<LoopInfo &(Function &)> LookupLoopInfo; + function_ref<AssumptionCache *(Function &)> LookupAssumptionCache; + + bool runOnFunction(Function &F); + + bool extractLoops(Loop::iterator From, Loop::iterator To, LoopInfo &LI, + DominatorTree &DT); + bool extractLoop(Loop *L, LoopInfo &LI, DominatorTree &DT); +}; +} // namespace + +char LoopExtractorLegacyPass::ID = 0; +INITIALIZE_PASS_BEGIN(LoopExtractorLegacyPass, "loop-extract", "Extract loops into new functions", false, false) INITIALIZE_PASS_DEPENDENCY(BreakCriticalEdges) INITIALIZE_PASS_DEPENDENCY(DominatorTreeWrapperPass) INITIALIZE_PASS_DEPENDENCY(LoopInfoWrapperPass) INITIALIZE_PASS_DEPENDENCY(LoopSimplify) -INITIALIZE_PASS_END(LoopExtractor, "loop-extract", +INITIALIZE_PASS_END(LoopExtractorLegacyPass, "loop-extract", "Extract loops into new functions", false, false) namespace { /// SingleLoopExtractor - For bugpoint. - struct SingleLoopExtractor : public LoopExtractor { - static char ID; // Pass identification, replacement for typeid - SingleLoopExtractor() : LoopExtractor(1) {} - }; +struct SingleLoopExtractor : public LoopExtractorLegacyPass { + static char ID; // Pass identification, replacement for typeid + SingleLoopExtractor() : LoopExtractorLegacyPass(1) {} +}; } // End anonymous namespace char SingleLoopExtractor::ID = 0; @@ -90,12 +112,30 @@ INITIALIZE_PASS(SingleLoopExtractor, "loop-extract-single", // createLoopExtractorPass - This pass extracts all natural loops from the // program into a function if it can. 
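The new-pass-manager entry point added at the bottom of this file (LoopExtractorPass::run) reaches DominatorTreeAnalysis and LoopAnalysis through the FunctionAnalysisManagerModuleProxy. A minimal sketch of driving it from a standalone pipeline, assuming a standard PassBuilder setup (the helper name runLoopExtractor is invented for the example):

    #include "llvm/Analysis/CGSCCPassManager.h"
    #include "llvm/Analysis/LoopAnalysisManager.h"
    #include "llvm/IR/Module.h"
    #include "llvm/IR/PassManager.h"
    #include "llvm/Passes/PassBuilder.h"
    #include "llvm/Transforms/IPO/LoopExtractor.h"

    using namespace llvm;

    // Registers all analyses and cross-registers the proxies so that a module
    // pass can query function-level results, then runs loop extraction on M.
    static bool runLoopExtractor(Module &M) {
      PassBuilder PB;
      LoopAnalysisManager LAM;
      FunctionAnalysisManager FAM;
      CGSCCAnalysisManager CGAM;
      ModuleAnalysisManager MAM;
      PB.registerModuleAnalyses(MAM);
      PB.registerCGSCCAnalyses(CGAM);
      PB.registerFunctionAnalyses(FAM);
      PB.registerLoopAnalyses(LAM);
      PB.crossRegisterProxies(LAM, FAM, CGAM, MAM);

      ModulePassManager MPM;
      MPM.addPass(LoopExtractorPass());
      PreservedAnalyses PA = MPM.run(M, MAM);
      return !PA.areAllAnalysesPreserved();
    }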
// -Pass *llvm::createLoopExtractorPass() { return new LoopExtractor(); } +Pass *llvm::createLoopExtractorPass() { return new LoopExtractorLegacyPass(); } -bool LoopExtractor::runOnModule(Module &M) { +bool LoopExtractorLegacyPass::runOnModule(Module &M) { if (skipModule(M)) return false; + bool Changed = false; + auto LookupDomTree = [this](Function &F) -> DominatorTree & { + return this->getAnalysis<DominatorTreeWrapperPass>(F).getDomTree(); + }; + auto LookupLoopInfo = [this, &Changed](Function &F) -> LoopInfo & { + return this->getAnalysis<LoopInfoWrapperPass>(F, &Changed).getLoopInfo(); + }; + auto LookupACT = [this](Function &F) -> AssumptionCache * { + if (auto *ACT = this->getAnalysisIfAvailable<AssumptionCacheTracker>()) + return ACT->lookupAssumptionCache(F); + return nullptr; + }; + return LoopExtractor(NumLoops, LookupDomTree, LookupLoopInfo, LookupACT) + .runOnModule(M) || + Changed; +} + +bool LoopExtractor::runOnModule(Module &M) { if (M.empty()) return false; @@ -132,13 +172,13 @@ bool LoopExtractor::runOnFunction(Function &F) { return false; bool Changed = false; - LoopInfo &LI = getAnalysis<LoopInfoWrapperPass>(F, &Changed).getLoopInfo(); + LoopInfo &LI = LookupLoopInfo(F); // If there are no loops in the function. if (LI.empty()) return Changed; - DominatorTree &DT = getAnalysis<DominatorTreeWrapperPass>(F).getDomTree(); + DominatorTree &DT = LookupDomTree(F); // If there is more than one top-level loop in this function, extract all of // the loops. @@ -203,10 +243,8 @@ bool LoopExtractor::extractLoops(Loop::iterator From, Loop::iterator To, bool LoopExtractor::extractLoop(Loop *L, LoopInfo &LI, DominatorTree &DT) { assert(NumLoops != 0); - AssumptionCache *AC = nullptr; Function &Func = *L->getHeader()->getParent(); - if (auto *ACT = getAnalysisIfAvailable<AssumptionCacheTracker>()) - AC = ACT->lookupAssumptionCache(Func); + AssumptionCache *AC = LookupAssumptionCache(Func); CodeExtractorAnalysisCache CEAC(Func); CodeExtractor Extractor(DT, *L, false, nullptr, nullptr, AC); if (Extractor.extractCodeRegion(CEAC)) { @@ -224,3 +262,24 @@ bool LoopExtractor::extractLoop(Loop *L, LoopInfo &LI, DominatorTree &DT) { Pass *llvm::createSingleLoopExtractorPass() { return new SingleLoopExtractor(); } + +PreservedAnalyses LoopExtractorPass::run(Module &M, ModuleAnalysisManager &AM) { + auto &FAM = AM.getResult<FunctionAnalysisManagerModuleProxy>(M).getManager(); + auto LookupDomTree = [&FAM](Function &F) -> DominatorTree & { + return FAM.getResult<DominatorTreeAnalysis>(F); + }; + auto LookupLoopInfo = [&FAM](Function &F) -> LoopInfo & { + return FAM.getResult<LoopAnalysis>(F); + }; + auto LookupAssumptionCache = [&FAM](Function &F) -> AssumptionCache * { + return FAM.getCachedResult<AssumptionAnalysis>(F); + }; + if (!LoopExtractor(NumLoops, LookupDomTree, LookupLoopInfo, + LookupAssumptionCache) + .runOnModule(M)) + return PreservedAnalyses::all(); + + PreservedAnalyses PA; + PA.preserve<LoopAnalysis>(); + return PA; +} diff --git a/contrib/llvm-project/llvm/lib/Transforms/IPO/LowerTypeTests.cpp b/contrib/llvm-project/llvm/lib/Transforms/IPO/LowerTypeTests.cpp index 8eef7e3e7e99..8bd3036f1fc3 100644 --- a/contrib/llvm-project/llvm/lib/Transforms/IPO/LowerTypeTests.cpp +++ b/contrib/llvm-project/llvm/lib/Transforms/IPO/LowerTypeTests.cpp @@ -198,7 +198,7 @@ void GlobalLayoutBuilder::addFragment(const std::set<uint64_t> &F) { // indices from the old fragment in this fragment do not insert any more // indices. 
std::vector<uint64_t> &OldFragment = Fragments[OldFragmentIndex]; - Fragment.insert(Fragment.end(), OldFragment.begin(), OldFragment.end()); + llvm::append_range(Fragment, OldFragment); OldFragment.clear(); } } @@ -1205,6 +1205,7 @@ void LowerTypeTestsModule::verifyTypeMDNode(GlobalObject *GO, MDNode *Type) { static const unsigned kX86JumpTableEntrySize = 8; static const unsigned kARMJumpTableEntrySize = 4; +static const unsigned kARMBTIJumpTableEntrySize = 8; unsigned LowerTypeTestsModule::getJumpTableEntrySize() { switch (Arch) { @@ -1213,7 +1214,12 @@ unsigned LowerTypeTestsModule::getJumpTableEntrySize() { return kX86JumpTableEntrySize; case Triple::arm: case Triple::thumb: + return kARMJumpTableEntrySize; case Triple::aarch64: + if (const auto *BTE = mdconst::extract_or_null<ConstantInt>( + M.getModuleFlag("branch-target-enforcement"))) + if (BTE->getZExtValue()) + return kARMBTIJumpTableEntrySize; return kARMJumpTableEntrySize; default: report_fatal_error("Unsupported architecture for jump tables"); @@ -1232,7 +1238,13 @@ void LowerTypeTestsModule::createJumpTableEntry( if (JumpTableArch == Triple::x86 || JumpTableArch == Triple::x86_64) { AsmOS << "jmp ${" << ArgIndex << ":c}@plt\n"; AsmOS << "int3\nint3\nint3\n"; - } else if (JumpTableArch == Triple::arm || JumpTableArch == Triple::aarch64) { + } else if (JumpTableArch == Triple::arm) { + AsmOS << "b $" << ArgIndex << "\n"; + } else if (JumpTableArch == Triple::aarch64) { + if (const auto *BTE = mdconst::extract_or_null<ConstantInt>( + Dest->getParent()->getModuleFlag("branch-target-enforcement"))) + if (BTE->getZExtValue()) + AsmOS << "bti c\n"; AsmOS << "b $" << ArgIndex << "\n"; } else if (JumpTableArch == Triple::thumb) { AsmOS << "b.w $" << ArgIndex << "\n"; @@ -1326,7 +1338,7 @@ void LowerTypeTestsModule::replaceWeakDeclarationWithJumpTablePtr( static bool isThumbFunction(Function *F, Triple::ArchType ModuleArch) { Attribute TFAttr = F->getFnAttribute("target-features"); - if (!TFAttr.hasAttribute(Attribute::None)) { + if (TFAttr.isValid()) { SmallVector<StringRef, 6> Features; TFAttr.getValueAsString().split(Features, ','); for (StringRef Feature : Features) { @@ -1394,6 +1406,10 @@ void LowerTypeTestsModule::createJumpTable( // by Clang for -march=armv7. F->addFnAttr("target-cpu", "cortex-a8"); } + if (JumpTableArch == Triple::aarch64) { + F->addFnAttr("branch-target-enforcement", "false"); + F->addFnAttr("sign-return-address", "none"); + } // Make sure we don't emit .eh_frame for this function. 
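The module-flag query that gates the 8-byte BTI entry size and the extra "bti c" landing pad above follows a small reusable pattern; isolated into a helper (the name hasBranchTargetEnforcement is invented for the example), it is just:

    #include "llvm/IR/Constants.h"
    #include "llvm/IR/Metadata.h"
    #include "llvm/IR/Module.h"

    using namespace llvm;

    // True when the module was built with branch-target enforcement (BTI),
    // i.e. the "branch-target-enforcement" module flag exists and is non-zero.
    static bool hasBranchTargetEnforcement(const Module &M) {
      if (const auto *BTE = mdconst::extract_or_null<ConstantInt>(
              M.getModuleFlag("branch-target-enforcement")))
        return BTE->getZExtValue() != 0;
      return false;
    }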
F->addFnAttr(Attribute::NoUnwind); @@ -2239,9 +2255,13 @@ bool LowerTypeTestsModule::lower() { PreservedAnalyses LowerTypeTestsPass::run(Module &M, ModuleAnalysisManager &AM) { - bool Changed = - LowerTypeTestsModule(M, ExportSummary, ImportSummary, DropTypeTests) - .lower(); + bool Changed; + if (UseCommandLine) + Changed = LowerTypeTestsModule::runForTesting(M); + else + Changed = + LowerTypeTestsModule(M, ExportSummary, ImportSummary, DropTypeTests) + .lower(); if (!Changed) return PreservedAnalyses::all(); return PreservedAnalyses::none(); diff --git a/contrib/llvm-project/llvm/lib/Transforms/IPO/MergeFunctions.cpp b/contrib/llvm-project/llvm/lib/Transforms/IPO/MergeFunctions.cpp index 8cc19515f3db..ec5d86b72a1f 100644 --- a/contrib/llvm-project/llvm/lib/Transforms/IPO/MergeFunctions.cpp +++ b/contrib/llvm-project/llvm/lib/Transforms/IPO/MergeFunctions.cpp @@ -725,8 +725,10 @@ void MergeFunctions::writeThunk(Function *F, Function *G) { if (MergeFunctionsPDI) { DISubprogram *DIS = G->getSubprogram(); if (DIS) { - DebugLoc CIDbgLoc = DebugLoc::get(DIS->getScopeLine(), 0, DIS); - DebugLoc RIDbgLoc = DebugLoc::get(DIS->getScopeLine(), 0, DIS); + DebugLoc CIDbgLoc = + DILocation::get(DIS->getContext(), DIS->getScopeLine(), 0, DIS); + DebugLoc RIDbgLoc = + DILocation::get(DIS->getContext(), DIS->getScopeLine(), 0, DIS); CI->setDebugLoc(CIDbgLoc); RI->setDebugLoc(RIDbgLoc); } else { diff --git a/contrib/llvm-project/llvm/lib/Transforms/IPO/OpenMPOpt.cpp b/contrib/llvm-project/llvm/lib/Transforms/IPO/OpenMPOpt.cpp index f664a2417374..a5ba6edb9a00 100644 --- a/contrib/llvm-project/llvm/lib/Transforms/IPO/OpenMPOpt.cpp +++ b/contrib/llvm-project/llvm/lib/Transforms/IPO/OpenMPOpt.cpp @@ -19,13 +19,16 @@ #include "llvm/Analysis/CallGraph.h" #include "llvm/Analysis/CallGraphSCCPass.h" #include "llvm/Analysis/OptimizationRemarkEmitter.h" +#include "llvm/Analysis/ValueTracking.h" #include "llvm/Frontend/OpenMP/OMPConstants.h" #include "llvm/Frontend/OpenMP/OMPIRBuilder.h" #include "llvm/InitializePasses.h" #include "llvm/Support/CommandLine.h" #include "llvm/Transforms/IPO.h" #include "llvm/Transforms/IPO/Attributor.h" +#include "llvm/Transforms/Utils/BasicBlockUtils.h" #include "llvm/Transforms/Utils/CallGraphUpdater.h" +#include "llvm/Transforms/Utils/CodeExtractor.h" using namespace llvm; using namespace omp; @@ -37,11 +40,22 @@ static cl::opt<bool> DisableOpenMPOptimizations( cl::desc("Disable OpenMP specific optimizations."), cl::Hidden, cl::init(false)); +static cl::opt<bool> EnableParallelRegionMerging( + "openmp-opt-enable-merging", cl::ZeroOrMore, + cl::desc("Enable the OpenMP region merging optimization."), cl::Hidden, + cl::init(false)); + static cl::opt<bool> PrintICVValues("openmp-print-icv-values", cl::init(false), cl::Hidden); static cl::opt<bool> PrintOpenMPKernels("openmp-print-gpu-kernels", cl::init(false), cl::Hidden); +static cl::opt<bool> HideMemoryTransferLatency( + "openmp-hide-memory-transfer-latency", + cl::desc("[WIP] Tries to hide the latency of host to device memory" + " transfers"), + cl::Hidden, cl::init(false)); + STATISTIC(NumOpenMPRuntimeCallsDeduplicated, "Number of OpenMP runtime calls deduplicated"); STATISTIC(NumOpenMPParallelRegionsDeleted, @@ -55,70 +69,13 @@ STATISTIC(NumOpenMPTargetRegionKernels, STATISTIC( NumOpenMPParallelRegionsReplacedInGPUStateMachine, "Number of OpenMP parallel regions replaced with ID in GPU state machines"); +STATISTIC(NumOpenMPParallelRegionsMerged, + "Number of OpenMP parallel regions merged"); #if !defined(NDEBUG) static constexpr 
auto TAG = "[" DEBUG_TYPE "]"; #endif -/// Apply \p CB to all uses of \p F. If \p LookThroughConstantExprUses is -/// true, constant expression users are not given to \p CB but their uses are -/// traversed transitively. -template <typename CBTy> -static void foreachUse(Function &F, CBTy CB, - bool LookThroughConstantExprUses = true) { - SmallVector<Use *, 8> Worklist(make_pointer_range(F.uses())); - - for (unsigned idx = 0; idx < Worklist.size(); ++idx) { - Use &U = *Worklist[idx]; - - // Allow use in constant bitcasts and simply look through them. - if (LookThroughConstantExprUses && isa<ConstantExpr>(U.getUser())) { - for (Use &CEU : cast<ConstantExpr>(U.getUser())->uses()) - Worklist.push_back(&CEU); - continue; - } - - CB(U); - } -} - -/// Helper struct to store tracked ICV values at specif instructions. -struct ICVValue { - Instruction *Inst; - Value *TrackedValue; - - ICVValue(Instruction *I, Value *Val) : Inst(I), TrackedValue(Val) {} -}; - -namespace llvm { - -// Provide DenseMapInfo for ICVValue -template <> struct DenseMapInfo<ICVValue> { - using InstInfo = DenseMapInfo<Instruction *>; - using ValueInfo = DenseMapInfo<Value *>; - - static inline ICVValue getEmptyKey() { - return ICVValue(InstInfo::getEmptyKey(), ValueInfo::getEmptyKey()); - }; - - static inline ICVValue getTombstoneKey() { - return ICVValue(InstInfo::getTombstoneKey(), ValueInfo::getTombstoneKey()); - }; - - static unsigned getHashValue(const ICVValue &ICVVal) { - return detail::combineHashValue( - InstInfo::getHashValue(ICVVal.Inst), - ValueInfo::getHashValue(ICVVal.TrackedValue)); - } - - static bool isEqual(const ICVValue &LHS, const ICVValue &RHS) { - return InstInfo::isEqual(LHS.Inst, RHS.Inst) && - ValueInfo::isEqual(LHS.TrackedValue, RHS.TrackedValue); - } -}; - -} // end namespace llvm - namespace { struct AAICVTracker; @@ -131,7 +88,6 @@ struct OMPInformationCache : public InformationCache { SmallPtrSetImpl<Kernel> &Kernels) : InformationCache(M, AG, Allocator, &CGSCC), OMPBuilder(M), Kernels(Kernels) { - initializeModuleSlice(CGSCC); OMPBuilder.initialize(); initializeRuntimeFunctions(); @@ -258,46 +214,6 @@ struct OMPInformationCache : public InformationCache { DenseMap<Function *, std::shared_ptr<UseVector>> UsesMap; }; - /// Initialize the ModuleSlice member based on \p SCC. ModuleSlices contains - /// (a subset of) all functions that we can look at during this SCC traversal. - /// This includes functions (transitively) called from the SCC and the - /// (transitive) callers of SCC functions. We also can look at a function if - /// there is a "reference edge", i.a., if the function somehow uses (!=calls) - /// a function in the SCC or a caller of a function in the SCC. - void initializeModuleSlice(SetVector<Function *> &SCC) { - ModuleSlice.insert(SCC.begin(), SCC.end()); - - SmallPtrSet<Function *, 16> Seen; - SmallVector<Function *, 16> Worklist(SCC.begin(), SCC.end()); - while (!Worklist.empty()) { - Function *F = Worklist.pop_back_val(); - ModuleSlice.insert(F); - - for (Instruction &I : instructions(*F)) - if (auto *CB = dyn_cast<CallBase>(&I)) - if (Function *Callee = CB->getCalledFunction()) - if (Seen.insert(Callee).second) - Worklist.push_back(Callee); - } - - Seen.clear(); - Worklist.append(SCC.begin(), SCC.end()); - while (!Worklist.empty()) { - Function *F = Worklist.pop_back_val(); - ModuleSlice.insert(F); - - // Traverse all transitive uses. 
- foreachUse(*F, [&](Use &U) { - if (auto *UsrI = dyn_cast<Instruction>(U.getUser())) - if (Seen.insert(UsrI->getFunction()).second) - Worklist.push_back(UsrI->getFunction()); - }); - } - } - - /// The slice of the module we are allowed to look at. - SmallPtrSet<Function *, 8> ModuleSlice; - /// An OpenMP-IR-Builder instance OpenMPIRBuilder OMPBuilder; @@ -402,13 +318,17 @@ struct OMPInformationCache : public InformationCache { return NumUses; } + // Helper function to recollect uses of a runtime function. + void recollectUsesForFunction(RuntimeFunction RTF) { + auto &RFI = RFIs[RTF]; + RFI.clearUsesMap(); + collectUses(RFI, /*CollectStats*/ false); + } + // Helper function to recollect uses of all runtime functions. void recollectUses() { - for (int Idx = 0; Idx < RFIs.size(); ++Idx) { - auto &RFI = RFIs[static_cast<RuntimeFunction>(Idx)]; - RFI.clearUsesMap(); - collectUses(RFI, /*CollectStats*/ false); - } + for (int Idx = 0; Idx < RFIs.size(); ++Idx) + recollectUsesForFunction(static_cast<RuntimeFunction>(Idx)); } /// Helper to initialize all runtime function information for those defined @@ -472,6 +392,91 @@ struct OMPInformationCache : public InformationCache { SmallPtrSetImpl<Kernel> &Kernels; }; +/// Used to map the values physically (in the IR) stored in an offload +/// array, to a vector in memory. +struct OffloadArray { + /// Physical array (in the IR). + AllocaInst *Array = nullptr; + /// Mapped values. + SmallVector<Value *, 8> StoredValues; + /// Last stores made in the offload array. + SmallVector<StoreInst *, 8> LastAccesses; + + OffloadArray() = default; + + /// Initializes the OffloadArray with the values stored in \p Array before + /// instruction \p Before is reached. Returns false if the initialization + /// fails. + /// This MUST be used immediately after the construction of the object. + bool initialize(AllocaInst &Array, Instruction &Before) { + if (!Array.getAllocatedType()->isArrayTy()) + return false; + + if (!getValues(Array, Before)) + return false; + + this->Array = &Array; + return true; + } + + static const unsigned DeviceIDArgNum = 1; + static const unsigned BasePtrsArgNum = 3; + static const unsigned PtrsArgNum = 4; + static const unsigned SizesArgNum = 5; + +private: + /// Traverses the BasicBlock where \p Array is, collecting the stores made to + /// \p Array, leaving StoredValues with the values stored before the + /// instruction \p Before is reached. + bool getValues(AllocaInst &Array, Instruction &Before) { + // Initialize container. + const uint64_t NumValues = Array.getAllocatedType()->getArrayNumElements(); + StoredValues.assign(NumValues, nullptr); + LastAccesses.assign(NumValues, nullptr); + + // TODO: This assumes the instruction \p Before is in the same + // BasicBlock as Array. Make it general, for any control flow graph. + BasicBlock *BB = Array.getParent(); + if (BB != Before.getParent()) + return false; + + const DataLayout &DL = Array.getModule()->getDataLayout(); + const unsigned int PointerSize = DL.getPointerSize(); + + for (Instruction &I : *BB) { + if (&I == &Before) + break; + + if (!isa<StoreInst>(&I)) + continue; + + auto *S = cast<StoreInst>(&I); + int64_t Offset = -1; + auto *Dst = + GetPointerBaseWithConstantOffset(S->getPointerOperand(), Offset, DL); + if (Dst == &Array) { + int64_t Idx = Offset / PointerSize; + StoredValues[Idx] = getUnderlyingObject(S->getValueOperand()); + LastAccesses[Idx] = S; + } + } + + return isFilled(); + } + + /// Returns true if all values in StoredValues and + /// LastAccesses are not nullptrs. 
+ bool isFilled() { + const unsigned NumValues = StoredValues.size(); + for (unsigned I = 0; I < NumValues; ++I) { + if (!StoredValues[I] || !LastAccesses[I]) + return false; + } + + return true; + } +}; + struct OpenMPOpt { using OptimizationRemarkGetter = @@ -483,6 +488,12 @@ struct OpenMPOpt { : M(*(*SCC.begin())->getParent()), SCC(SCC), CGUpdater(CGUpdater), OREGetter(OREGetter), OMPInfoCache(OMPInfoCache), A(A) {} + /// Check if any remarks are enabled for openmp-opt + bool remarksEnabled() { + auto &Ctx = M.getContext(); + return Ctx.getDiagHandlerPtr()->isAnyRemarkEnabled(DEBUG_TYPE); + } + /// Run all OpenMP optimizations on the underlying SCC/ModuleSlice. bool run() { if (SCC.empty()) @@ -506,8 +517,18 @@ struct OpenMPOpt { // Recollect uses, in case Attributor deleted any. OMPInfoCache.recollectUses(); - Changed |= deduplicateRuntimeCalls(); Changed |= deleteParallelRegions(); + if (HideMemoryTransferLatency) + Changed |= hideMemTransfersLatency(); + if (remarksEnabled()) + analysisGlobalization(); + Changed |= deduplicateRuntimeCalls(); + if (EnableParallelRegionMerging) { + if (mergeParallelRegions()) { + deduplicateRuntimeCalls(); + Changed = true; + } + } return Changed; } @@ -515,7 +536,8 @@ struct OpenMPOpt { /// Print initial ICV values for testing. /// FIXME: This should be done from the Attributor once it is added. void printICVs() const { - InternalControlVar ICVs[] = {ICV_nthreads, ICV_active_levels, ICV_cancel}; + InternalControlVar ICVs[] = {ICV_nthreads, ICV_active_levels, ICV_cancel, + ICV_proc_bind}; for (Function *F : OMPInfoCache.ModuleSlice) { for (auto ICV : ICVs) { @@ -571,6 +593,394 @@ struct OpenMPOpt { } private: + /// Merge parallel regions when it is safe. + bool mergeParallelRegions() { + const unsigned CallbackCalleeOperand = 2; + const unsigned CallbackFirstArgOperand = 3; + using InsertPointTy = OpenMPIRBuilder::InsertPointTy; + + // Check if there are any __kmpc_fork_call calls to merge. + OMPInformationCache::RuntimeFunctionInfo &RFI = + OMPInfoCache.RFIs[OMPRTL___kmpc_fork_call]; + + if (!RFI.Declaration) + return false; + + // Unmergable calls that prevent merging a parallel region. + OMPInformationCache::RuntimeFunctionInfo UnmergableCallsInfo[] = { + OMPInfoCache.RFIs[OMPRTL___kmpc_push_proc_bind], + OMPInfoCache.RFIs[OMPRTL___kmpc_push_num_threads], + }; + + bool Changed = false; + LoopInfo *LI = nullptr; + DominatorTree *DT = nullptr; + + SmallDenseMap<BasicBlock *, SmallPtrSet<Instruction *, 4>> BB2PRMap; + + BasicBlock *StartBB = nullptr, *EndBB = nullptr; + auto BodyGenCB = [&](InsertPointTy AllocaIP, InsertPointTy CodeGenIP, + BasicBlock &ContinuationIP) { + BasicBlock *CGStartBB = CodeGenIP.getBlock(); + BasicBlock *CGEndBB = + SplitBlock(CGStartBB, &*CodeGenIP.getPoint(), DT, LI); + assert(StartBB != nullptr && "StartBB should not be null"); + CGStartBB->getTerminator()->setSuccessor(0, StartBB); + assert(EndBB != nullptr && "EndBB should not be null"); + EndBB->getTerminator()->setSuccessor(0, CGEndBB); + }; + + auto PrivCB = [&](InsertPointTy AllocaIP, InsertPointTy CodeGenIP, Value &, + Value &Inner, Value *&ReplacementValue) -> InsertPointTy { + ReplacementValue = &Inner; + return CodeGenIP; + }; + + auto FiniCB = [&](InsertPointTy CodeGenIP) {}; + + /// Create a sequential execution region within a merged parallel region, + /// encapsulated in a master construct with a barrier for synchronization. 
+ auto CreateSequentialRegion = [&](Function *OuterFn, + BasicBlock *OuterPredBB, + Instruction *SeqStartI, + Instruction *SeqEndI) { + // Isolate the instructions of the sequential region to a separate + // block. + BasicBlock *ParentBB = SeqStartI->getParent(); + BasicBlock *SeqEndBB = + SplitBlock(ParentBB, SeqEndI->getNextNode(), DT, LI); + BasicBlock *SeqAfterBB = + SplitBlock(SeqEndBB, &*SeqEndBB->getFirstInsertionPt(), DT, LI); + BasicBlock *SeqStartBB = + SplitBlock(ParentBB, SeqStartI, DT, LI, nullptr, "seq.par.merged"); + + assert(ParentBB->getUniqueSuccessor() == SeqStartBB && + "Expected a different CFG"); + const DebugLoc DL = ParentBB->getTerminator()->getDebugLoc(); + ParentBB->getTerminator()->eraseFromParent(); + + auto BodyGenCB = [&](InsertPointTy AllocaIP, InsertPointTy CodeGenIP, + BasicBlock &ContinuationIP) { + BasicBlock *CGStartBB = CodeGenIP.getBlock(); + BasicBlock *CGEndBB = + SplitBlock(CGStartBB, &*CodeGenIP.getPoint(), DT, LI); + assert(SeqStartBB != nullptr && "SeqStartBB should not be null"); + CGStartBB->getTerminator()->setSuccessor(0, SeqStartBB); + assert(SeqEndBB != nullptr && "SeqEndBB should not be null"); + SeqEndBB->getTerminator()->setSuccessor(0, CGEndBB); + }; + auto FiniCB = [&](InsertPointTy CodeGenIP) {}; + + // Find outputs from the sequential region to outside users and + // broadcast their values to them. + for (Instruction &I : *SeqStartBB) { + SmallPtrSet<Instruction *, 4> OutsideUsers; + for (User *Usr : I.users()) { + Instruction &UsrI = *cast<Instruction>(Usr); + // Ignore outputs to LT intrinsics, code extraction for the merged + // parallel region will fix them. + if (UsrI.isLifetimeStartOrEnd()) + continue; + + if (UsrI.getParent() != SeqStartBB) + OutsideUsers.insert(&UsrI); + } + + if (OutsideUsers.empty()) + continue; + + // Emit an alloca in the outer region to store the broadcasted + // value. + const DataLayout &DL = M.getDataLayout(); + AllocaInst *AllocaI = new AllocaInst( + I.getType(), DL.getAllocaAddrSpace(), nullptr, + I.getName() + ".seq.output.alloc", &OuterFn->front().front()); + + // Emit a store instruction in the sequential BB to update the + // value. + new StoreInst(&I, AllocaI, SeqStartBB->getTerminator()); + + // Emit a load instruction and replace the use of the output value + // with it. + for (Instruction *UsrI : OutsideUsers) { + LoadInst *LoadI = new LoadInst(I.getType(), AllocaI, + I.getName() + ".seq.output.load", UsrI); + UsrI->replaceUsesOfWith(&I, LoadI); + } + } + + OpenMPIRBuilder::LocationDescription Loc( + InsertPointTy(ParentBB, ParentBB->end()), DL); + InsertPointTy SeqAfterIP = + OMPInfoCache.OMPBuilder.createMaster(Loc, BodyGenCB, FiniCB); + + OMPInfoCache.OMPBuilder.createBarrier(SeqAfterIP, OMPD_parallel); + + BranchInst::Create(SeqAfterBB, SeqAfterIP.getBlock()); + + LLVM_DEBUG(dbgs() << TAG << "After sequential inlining " << *OuterFn + << "\n"); + }; + + // Helper to merge the __kmpc_fork_call calls in MergableCIs. They are all + // contained in BB and only separated by instructions that can be + // redundantly executed in parallel. The block BB is split before the first + // call (in MergableCIs) and after the last so the entire region we merge + // into a single parallel region is contained in a single basic block + // without any other instructions. We use the OpenMPIRBuilder to outline + // that block and call the resulting function via __kmpc_fork_call. 
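To picture what the Merge helper below produces (hedged: this is OpenMP source, while the pass works on the lowered __kmpc_fork_call sites, and merging only runs behind -openmp-opt-enable-merging), consider two adjacent parallel regions separated by one sequential statement:

    #include <cstdio>

    void example(int *A, int N) {
    #pragma omp parallel
      { std::printf("region 1: %d\n", A[0]); }

      int Mid = N / 2; // sequential code between the regions; after merging it
                       // is guarded by a master construct plus a barrier

    #pragma omp parallel
      { std::printf("region 2: %d\n", A[Mid]); }
    }

    // After merging, both bodies live in one outlined callback started by a
    // single __kmpc_fork_call, and the implicit join/fork in the middle is
    // replaced by an explicit barrier.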
+ auto Merge = [&](SmallVectorImpl<CallInst *> &MergableCIs, BasicBlock *BB) { + // TODO: Change the interface to allow single CIs expanded, e.g, to + // include an outer loop. + assert(MergableCIs.size() > 1 && "Assumed multiple mergable CIs"); + + auto Remark = [&](OptimizationRemark OR) { + OR << "Parallel region at " + << ore::NV("OpenMPParallelMergeFront", + MergableCIs.front()->getDebugLoc()) + << " merged with parallel regions at "; + for (auto *CI : llvm::drop_begin(MergableCIs)) { + OR << ore::NV("OpenMPParallelMerge", CI->getDebugLoc()); + if (CI != MergableCIs.back()) + OR << ", "; + } + return OR; + }; + + emitRemark<OptimizationRemark>(MergableCIs.front(), + "OpenMPParallelRegionMerging", Remark); + + Function *OriginalFn = BB->getParent(); + LLVM_DEBUG(dbgs() << TAG << "Merge " << MergableCIs.size() + << " parallel regions in " << OriginalFn->getName() + << "\n"); + + // Isolate the calls to merge in a separate block. + EndBB = SplitBlock(BB, MergableCIs.back()->getNextNode(), DT, LI); + BasicBlock *AfterBB = + SplitBlock(EndBB, &*EndBB->getFirstInsertionPt(), DT, LI); + StartBB = SplitBlock(BB, MergableCIs.front(), DT, LI, nullptr, + "omp.par.merged"); + + assert(BB->getUniqueSuccessor() == StartBB && "Expected a different CFG"); + const DebugLoc DL = BB->getTerminator()->getDebugLoc(); + BB->getTerminator()->eraseFromParent(); + + // Create sequential regions for sequential instructions that are + // in-between mergable parallel regions. + for (auto *It = MergableCIs.begin(), *End = MergableCIs.end() - 1; + It != End; ++It) { + Instruction *ForkCI = *It; + Instruction *NextForkCI = *(It + 1); + + // Continue if there are not in-between instructions. + if (ForkCI->getNextNode() == NextForkCI) + continue; + + CreateSequentialRegion(OriginalFn, BB, ForkCI->getNextNode(), + NextForkCI->getPrevNode()); + } + + OpenMPIRBuilder::LocationDescription Loc(InsertPointTy(BB, BB->end()), + DL); + IRBuilder<>::InsertPoint AllocaIP( + &OriginalFn->getEntryBlock(), + OriginalFn->getEntryBlock().getFirstInsertionPt()); + // Create the merged parallel region with default proc binding, to + // avoid overriding binding settings, and without explicit cancellation. + InsertPointTy AfterIP = OMPInfoCache.OMPBuilder.createParallel( + Loc, AllocaIP, BodyGenCB, PrivCB, FiniCB, nullptr, nullptr, + OMP_PROC_BIND_default, /* IsCancellable */ false); + BranchInst::Create(AfterBB, AfterIP.getBlock()); + + // Perform the actual outlining. + OMPInfoCache.OMPBuilder.finalize(/* AllowExtractorSinking */ true); + + Function *OutlinedFn = MergableCIs.front()->getCaller(); + + // Replace the __kmpc_fork_call calls with direct calls to the outlined + // callbacks. + SmallVector<Value *, 8> Args; + for (auto *CI : MergableCIs) { + Value *Callee = + CI->getArgOperand(CallbackCalleeOperand)->stripPointerCasts(); + FunctionType *FT = + cast<FunctionType>(Callee->getType()->getPointerElementType()); + Args.clear(); + Args.push_back(OutlinedFn->getArg(0)); + Args.push_back(OutlinedFn->getArg(1)); + for (unsigned U = CallbackFirstArgOperand, E = CI->getNumArgOperands(); + U < E; ++U) + Args.push_back(CI->getArgOperand(U)); + + CallInst *NewCI = CallInst::Create(FT, Callee, Args, "", CI); + if (CI->getDebugLoc()) + NewCI->setDebugLoc(CI->getDebugLoc()); + + // Forward parameter attributes from the callback to the callee. 
+ for (unsigned U = CallbackFirstArgOperand, E = CI->getNumArgOperands(); + U < E; ++U) + for (const Attribute &A : CI->getAttributes().getParamAttributes(U)) + NewCI->addParamAttr( + U - (CallbackFirstArgOperand - CallbackCalleeOperand), A); + + // Emit an explicit barrier to replace the implicit fork-join barrier. + if (CI != MergableCIs.back()) { + // TODO: Remove barrier if the merged parallel region includes the + // 'nowait' clause. + OMPInfoCache.OMPBuilder.createBarrier( + InsertPointTy(NewCI->getParent(), + NewCI->getNextNode()->getIterator()), + OMPD_parallel); + } + + auto Remark = [&](OptimizationRemark OR) { + return OR << "Parallel region at " + << ore::NV("OpenMPParallelMerge", CI->getDebugLoc()) + << " merged with " + << ore::NV("OpenMPParallelMergeFront", + MergableCIs.front()->getDebugLoc()); + }; + if (CI != MergableCIs.front()) + emitRemark<OptimizationRemark>(CI, "OpenMPParallelRegionMerging", + Remark); + + CI->eraseFromParent(); + } + + assert(OutlinedFn != OriginalFn && "Outlining failed"); + CGUpdater.registerOutlinedFunction(*OriginalFn, *OutlinedFn); + CGUpdater.reanalyzeFunction(*OriginalFn); + + NumOpenMPParallelRegionsMerged += MergableCIs.size(); + + return true; + }; + + // Helper function that identifes sequences of + // __kmpc_fork_call uses in a basic block. + auto DetectPRsCB = [&](Use &U, Function &F) { + CallInst *CI = getCallIfRegularCall(U, &RFI); + BB2PRMap[CI->getParent()].insert(CI); + + return false; + }; + + BB2PRMap.clear(); + RFI.foreachUse(SCC, DetectPRsCB); + SmallVector<SmallVector<CallInst *, 4>, 4> MergableCIsVector; + // Find mergable parallel regions within a basic block that are + // safe to merge, that is any in-between instructions can safely + // execute in parallel after merging. + // TODO: support merging across basic-blocks. + for (auto &It : BB2PRMap) { + auto &CIs = It.getSecond(); + if (CIs.size() < 2) + continue; + + BasicBlock *BB = It.getFirst(); + SmallVector<CallInst *, 4> MergableCIs; + + /// Returns true if the instruction is mergable, false otherwise. + /// A terminator instruction is unmergable by definition since merging + /// works within a BB. Instructions before the mergable region are + /// mergable if they are not calls to OpenMP runtime functions that may + /// set different execution parameters for subsequent parallel regions. + /// Instructions in-between parallel regions are mergable if they are not + /// calls to any non-intrinsic function since that may call a non-mergable + /// OpenMP runtime function. + auto IsMergable = [&](Instruction &I, bool IsBeforeMergableRegion) { + // We do not merge across BBs, hence return false (unmergable) if the + // instruction is a terminator. + if (I.isTerminator()) + return false; + + if (!isa<CallInst>(&I)) + return true; + + CallInst *CI = cast<CallInst>(&I); + if (IsBeforeMergableRegion) { + Function *CalledFunction = CI->getCalledFunction(); + if (!CalledFunction) + return false; + // Return false (unmergable) if the call before the parallel + // region calls an explicit affinity (proc_bind) or number of + // threads (num_threads) compiler-generated function. Those settings + // may be incompatible with following parallel regions. + // TODO: ICV tracking to detect compatibility. + for (const auto &RFI : UnmergableCallsInfo) { + if (CalledFunction == RFI.Declaration) + return false; + } + } else { + // Return false (unmergable) if there is a call instruction + // in-between parallel regions when it is not an intrinsic. 
It + // may call an unmergable OpenMP runtime function in its callpath. + // TODO: Keep track of possible OpenMP calls in the callpath. + if (!isa<IntrinsicInst>(CI)) + return false; + } + + return true; + }; + // Find maximal number of parallel region CIs that are safe to merge. + for (auto It = BB->begin(), End = BB->end(); It != End;) { + Instruction &I = *It; + ++It; + + if (CIs.count(&I)) { + MergableCIs.push_back(cast<CallInst>(&I)); + continue; + } + + // Continue expanding if the instruction is mergable. + if (IsMergable(I, MergableCIs.empty())) + continue; + + // Forward the instruction iterator to skip the next parallel region + // since there is an unmergable instruction which can affect it. + for (; It != End; ++It) { + Instruction &SkipI = *It; + if (CIs.count(&SkipI)) { + LLVM_DEBUG(dbgs() << TAG << "Skip parallel region " << SkipI + << " due to " << I << "\n"); + ++It; + break; + } + } + + // Store mergable regions found. + if (MergableCIs.size() > 1) { + MergableCIsVector.push_back(MergableCIs); + LLVM_DEBUG(dbgs() << TAG << "Found " << MergableCIs.size() + << " parallel regions in block " << BB->getName() + << " of function " << BB->getParent()->getName() + << "\n";); + } + + MergableCIs.clear(); + } + + if (!MergableCIsVector.empty()) { + Changed = true; + + for (auto &MergableCIs : MergableCIsVector) + Merge(MergableCIs, BB); + } + } + + if (Changed) { + /// Re-collect use for fork calls, emitted barrier calls, and + /// any emitted master/end_master calls. + OMPInfoCache.recollectUsesForFunction(OMPRTL___kmpc_fork_call); + OMPInfoCache.recollectUsesForFunction(OMPRTL___kmpc_barrier); + OMPInfoCache.recollectUsesForFunction(OMPRTL___kmpc_master); + OMPInfoCache.recollectUsesForFunction(OMPRTL___kmpc_end_master); + } + + return Changed; + } + /// Try to delete parallel regions if possible. bool deleteParallelRegions() { const unsigned CallbackCalleeOperand = 2; @@ -648,8 +1058,8 @@ private: for (Function *F : SCC) { for (auto DeduplicableRuntimeCallID : DeduplicableRuntimeCallIDs) - deduplicateRuntimeCalls(*F, - OMPInfoCache.RFIs[DeduplicableRuntimeCallID]); + Changed |= deduplicateRuntimeCalls( + *F, OMPInfoCache.RFIs[DeduplicableRuntimeCallID]); // __kmpc_global_thread_num is special as we can replace it with an // argument in enough cases to make it worth trying. @@ -666,6 +1076,223 @@ private: return Changed; } + /// Tries to hide the latency of runtime calls that involve host to + /// device memory transfers by splitting them into their "issue" and "wait" + /// versions. The "issue" is moved upwards as much as possible. The "wait" is + /// moved downards as much as possible. The "issue" issues the memory transfer + /// asynchronously, returning a handle. The "wait" waits in the returned + /// handle for the memory transfer to finish. + bool hideMemTransfersLatency() { + auto &RFI = OMPInfoCache.RFIs[OMPRTL___tgt_target_data_begin_mapper]; + bool Changed = false; + auto SplitMemTransfers = [&](Use &U, Function &Decl) { + auto *RTCall = getCallIfRegularCall(U, &RFI); + if (!RTCall) + return false; + + OffloadArray OffloadArrays[3]; + if (!getValuesInOffloadArrays(*RTCall, OffloadArrays)) + return false; + + LLVM_DEBUG(dumpValuesInOffloadArrays(OffloadArrays)); + + // TODO: Check if can be moved upwards. 
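For context on the call being split here (an aside, not part of the patch): a target data construct such as the one below lowers to a __tgt_target_data_begin_mapper call for its map clauses; the split lets the host-to-device copy be issued early and waited on as late as the surrounding host code allows.

    void offload(double *A, int N) {
      // The map(to: ...) clause is what emits the
      // __tgt_target_data_begin_mapper call this optimization splits.
    #pragma omp target data map(to: A[0:N])
      {
    #pragma omp target
        for (int i = 0; i < N; ++i)
          A[i] += 1.0;
      }
    }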
+ bool WasSplit = false; + Instruction *WaitMovementPoint = canBeMovedDownwards(*RTCall); + if (WaitMovementPoint) + WasSplit = splitTargetDataBeginRTC(*RTCall, *WaitMovementPoint); + + Changed |= WasSplit; + return WasSplit; + }; + RFI.foreachUse(SCC, SplitMemTransfers); + + return Changed; + } + + void analysisGlobalization() { + RuntimeFunction GlobalizationRuntimeIDs[] = { + OMPRTL___kmpc_data_sharing_coalesced_push_stack, + OMPRTL___kmpc_data_sharing_push_stack}; + + for (const auto GlobalizationCallID : GlobalizationRuntimeIDs) { + auto &RFI = OMPInfoCache.RFIs[GlobalizationCallID]; + + auto CheckGlobalization = [&](Use &U, Function &Decl) { + if (CallInst *CI = getCallIfRegularCall(U, &RFI)) { + auto Remark = [&](OptimizationRemarkAnalysis ORA) { + return ORA + << "Found thread data sharing on the GPU. " + << "Expect degraded performance due to data globalization."; + }; + emitRemark<OptimizationRemarkAnalysis>(CI, "OpenMPGlobalization", + Remark); + } + + return false; + }; + + RFI.foreachUse(SCC, CheckGlobalization); + } + } + + /// Maps the values stored in the offload arrays passed as arguments to + /// \p RuntimeCall into the offload arrays in \p OAs. + bool getValuesInOffloadArrays(CallInst &RuntimeCall, + MutableArrayRef<OffloadArray> OAs) { + assert(OAs.size() == 3 && "Need space for three offload arrays!"); + + // A runtime call that involves memory offloading looks something like: + // call void @__tgt_target_data_begin_mapper(arg0, arg1, + // i8** %offload_baseptrs, i8** %offload_ptrs, i64* %offload_sizes, + // ...) + // So, the idea is to access the allocas that allocate space for these + // offload arrays, offload_baseptrs, offload_ptrs, offload_sizes. + // Therefore: + // i8** %offload_baseptrs. + Value *BasePtrsArg = + RuntimeCall.getArgOperand(OffloadArray::BasePtrsArgNum); + // i8** %offload_ptrs. + Value *PtrsArg = RuntimeCall.getArgOperand(OffloadArray::PtrsArgNum); + // i8** %offload_sizes. + Value *SizesArg = RuntimeCall.getArgOperand(OffloadArray::SizesArgNum); + + // Get values stored in **offload_baseptrs. + auto *V = getUnderlyingObject(BasePtrsArg); + if (!isa<AllocaInst>(V)) + return false; + auto *BasePtrsArray = cast<AllocaInst>(V); + if (!OAs[0].initialize(*BasePtrsArray, RuntimeCall)) + return false; + + // Get values stored in **offload_baseptrs. + V = getUnderlyingObject(PtrsArg); + if (!isa<AllocaInst>(V)) + return false; + auto *PtrsArray = cast<AllocaInst>(V); + if (!OAs[1].initialize(*PtrsArray, RuntimeCall)) + return false; + + // Get values stored in **offload_sizes. + V = getUnderlyingObject(SizesArg); + // If it's a [constant] global array don't analyze it. + if (isa<GlobalValue>(V)) + return isa<Constant>(V); + if (!isa<AllocaInst>(V)) + return false; + + auto *SizesArray = cast<AllocaInst>(V); + if (!OAs[2].initialize(*SizesArray, RuntimeCall)) + return false; + + return true; + } + + /// Prints the values in the OffloadArrays \p OAs using LLVM_DEBUG. + /// For now this is a way to test that the function getValuesInOffloadArrays + /// is working properly. + /// TODO: Move this to a unittest when unittests are available for OpenMPOpt. 
+ void dumpValuesInOffloadArrays(ArrayRef<OffloadArray> OAs) { + assert(OAs.size() == 3 && "There are three offload arrays to debug!"); + + LLVM_DEBUG(dbgs() << TAG << " Successfully got offload values:\n"); + std::string ValuesStr; + raw_string_ostream Printer(ValuesStr); + std::string Separator = " --- "; + + for (auto *BP : OAs[0].StoredValues) { + BP->print(Printer); + Printer << Separator; + } + LLVM_DEBUG(dbgs() << "\t\toffload_baseptrs: " << Printer.str() << "\n"); + ValuesStr.clear(); + + for (auto *P : OAs[1].StoredValues) { + P->print(Printer); + Printer << Separator; + } + LLVM_DEBUG(dbgs() << "\t\toffload_ptrs: " << Printer.str() << "\n"); + ValuesStr.clear(); + + for (auto *S : OAs[2].StoredValues) { + S->print(Printer); + Printer << Separator; + } + LLVM_DEBUG(dbgs() << "\t\toffload_sizes: " << Printer.str() << "\n"); + } + + /// Returns the instruction where the "wait" counterpart \p RuntimeCall can be + /// moved. Returns nullptr if the movement is not possible, or not worth it. + Instruction *canBeMovedDownwards(CallInst &RuntimeCall) { + // FIXME: This traverses only the BasicBlock where RuntimeCall is. + // Make it traverse the CFG. + + Instruction *CurrentI = &RuntimeCall; + bool IsWorthIt = false; + while ((CurrentI = CurrentI->getNextNode())) { + + // TODO: Once we detect the regions to be offloaded we should use the + // alias analysis manager to check if CurrentI may modify one of + // the offloaded regions. + if (CurrentI->mayHaveSideEffects() || CurrentI->mayReadFromMemory()) { + if (IsWorthIt) + return CurrentI; + + return nullptr; + } + + // FIXME: For now if we move it over anything without side effect + // is worth it. + IsWorthIt = true; + } + + // Return end of BasicBlock. + return RuntimeCall.getParent()->getTerminator(); + } + + /// Splits \p RuntimeCall into its "issue" and "wait" counterparts. + bool splitTargetDataBeginRTC(CallInst &RuntimeCall, + Instruction &WaitMovementPoint) { + // Create stack allocated handle (__tgt_async_info) at the beginning of the + // function. Used for storing information of the async transfer, allowing to + // wait on it later. + auto &IRBuilder = OMPInfoCache.OMPBuilder; + auto *F = RuntimeCall.getCaller(); + Instruction *FirstInst = &(F->getEntryBlock().front()); + AllocaInst *Handle = new AllocaInst( + IRBuilder.AsyncInfo, F->getAddressSpace(), "handle", FirstInst); + + // Add "issue" runtime call declaration: + // declare %struct.tgt_async_info @__tgt_target_data_begin_issue(i64, i32, + // i8**, i8**, i64*, i64*) + FunctionCallee IssueDecl = IRBuilder.getOrCreateRuntimeFunction( + M, OMPRTL___tgt_target_data_begin_mapper_issue); + + // Change RuntimeCall call site for its asynchronous version. + SmallVector<Value *, 16> Args; + for (auto &Arg : RuntimeCall.args()) + Args.push_back(Arg.get()); + Args.push_back(Handle); + + CallInst *IssueCallsite = + CallInst::Create(IssueDecl, Args, /*NameStr=*/"", &RuntimeCall); + RuntimeCall.eraseFromParent(); + + // Add "wait" runtime call declaration: + // declare void @__tgt_target_data_begin_wait(i64, %struct.__tgt_async_info) + FunctionCallee WaitDecl = IRBuilder.getOrCreateRuntimeFunction( + M, OMPRTL___tgt_target_data_begin_mapper_wait); + + Value *WaitParams[2] = { + IssueCallsite->getArgOperand( + OffloadArray::DeviceIDArgNum), // device_id. + Handle // handle to wait on. 
+ }; + CallInst::Create(WaitDecl, WaitParams, /*NameStr=*/"", &WaitMovementPoint); + + return true; + } + static Value *combinedIdentStruct(Value *CurrentIdent, Value *NextIdent, bool GlobalOnly, bool &SingleChoice) { if (CurrentIdent == NextIdent) @@ -951,11 +1578,28 @@ private: /// Populate the Attributor with abstract attribute opportunities in the /// function. void registerAAs() { - for (Function *F : SCC) { - if (F->isDeclaration()) - continue; + if (SCC.empty()) + return; + + // Create CallSite AA for all Getters. + for (int Idx = 0; Idx < OMPInfoCache.ICVs.size() - 1; ++Idx) { + auto ICVInfo = OMPInfoCache.ICVs[static_cast<InternalControlVar>(Idx)]; + + auto &GetterRFI = OMPInfoCache.RFIs[ICVInfo.Getter]; + + auto CreateAA = [&](Use &U, Function &Caller) { + CallInst *CI = OpenMPOpt::getCallIfRegularCall(U, &GetterRFI); + if (!CI) + return false; + + auto &CB = cast<CallBase>(*CI); + + IRPosition CBPos = IRPosition::callsite_function(CB); + A.getOrCreateAAFor<AAICVTracker>(CBPos); + return false; + }; - A.getOrCreateAAFor<AAICVTracker>(IRPosition::function(*F)); + GetterRFI.foreachUse(SCC, CreateAA); } } }; @@ -979,8 +1623,16 @@ Kernel OpenMPOpt::getUniqueKernelFor(Function &F) { } CachedKernel = nullptr; - if (!F.hasLocalLinkage()) + if (!F.hasLocalLinkage()) { + + // See https://openmp.llvm.org/remarks/OptimizationRemarks.html + auto Remark = [&](OptimizationRemark OR) { + return OR << "[OMP100] Potentially unknown OpenMP target region caller"; + }; + emitRemarkOnFunction(&F, "OMP100", Remark); + return nullptr; + } } auto GetUniqueKernelForUse = [&](const Use &U) -> Kernel { @@ -1006,7 +1658,7 @@ Kernel OpenMPOpt::getUniqueKernelFor(Function &F) { // TODO: In the future we want to track more than just a unique kernel. SmallPtrSet<Kernel, 2> PotentialKernels; - foreachUse(F, [&](const Use &U) { + OMPInformationCache::foreachUse(F, [&](const Use &U) { PotentialKernels.insert(GetUniqueKernelForUse(U)); }); @@ -1037,7 +1689,7 @@ bool OpenMPOpt::rewriteDeviceCodeStateMachine() { unsigned NumDirectCalls = 0; SmallVector<Use *, 2> ToBeReplacedStateMachineUses; - foreachUse(*F, [&](Use &U) { + OMPInformationCache::foreachUse(*F, [&](Use &U) { if (auto *CB = dyn_cast<CallBase>(U.getUser())) if (CB->isCallee(&U)) { ++NumDirectCalls; @@ -1157,6 +1809,12 @@ struct AAICVTracker : public StateWrapper<BooleanState, AbstractAttribute> { using Base = StateWrapper<BooleanState, AbstractAttribute>; AAICVTracker(const IRPosition &IRP, Attributor &A) : Base(IRP) {} + void initialize(Attributor &A) override { + Function *F = getAnchorScope(); + if (!F || !A.isFunctionIPOAmendable(*F)) + indicatePessimisticFixpoint(); + } + /// Returns true if value is assumed to be tracked. bool isAssumedTracked() const { return getAssumed(); } @@ -1167,8 +1825,21 @@ struct AAICVTracker : public StateWrapper<BooleanState, AbstractAttribute> { static AAICVTracker &createForPosition(const IRPosition &IRP, Attributor &A); /// Return the value with which \p I can be replaced for specific \p ICV. - virtual Value *getReplacementValue(InternalControlVar ICV, - const Instruction *I, Attributor &A) = 0; + virtual Optional<Value *> getReplacementValue(InternalControlVar ICV, + const Instruction *I, + Attributor &A) const { + return None; + } + + /// Return an assumed unique ICV value if a single candidate is found. If + /// there cannot be one, return a nullptr. If it is not clear yet, return the + /// Optional::NoneType. 
+ virtual Optional<Value *> + getUniqueReplacementValue(InternalControlVar ICV) const = 0; + + // Currently only nthreads is being tracked. + // this array will only grow with time. + InternalControlVar TrackableICVs[1] = {ICV_nthreads}; /// See AbstractAttribute::getName() const std::string getName() const override { return "AAICVTracker"; } @@ -1189,57 +1860,20 @@ struct AAICVTrackerFunction : public AAICVTracker { : AAICVTracker(IRP, A) {} // FIXME: come up with better string. - const std::string getAsStr() const override { return "ICVTracker"; } + const std::string getAsStr() const override { return "ICVTrackerFunction"; } // FIXME: come up with some stats. void trackStatistics() const override {} - /// TODO: decide whether to deduplicate here, or use current - /// deduplicateRuntimeCalls function. + /// We don't manifest anything for this AA. ChangeStatus manifest(Attributor &A) override { - ChangeStatus Changed = ChangeStatus::UNCHANGED; - - for (InternalControlVar &ICV : TrackableICVs) - if (deduplicateICVGetters(ICV, A)) - Changed = ChangeStatus::CHANGED; - - return Changed; - } - - bool deduplicateICVGetters(InternalControlVar &ICV, Attributor &A) { - auto &OMPInfoCache = static_cast<OMPInformationCache &>(A.getInfoCache()); - auto &ICVInfo = OMPInfoCache.ICVs[ICV]; - auto &GetterRFI = OMPInfoCache.RFIs[ICVInfo.Getter]; - - bool Changed = false; - - auto ReplaceAndDeleteCB = [&](Use &U, Function &Caller) { - CallInst *CI = OpenMPOpt::getCallIfRegularCall(U, &GetterRFI); - Instruction *UserI = cast<Instruction>(U.getUser()); - Value *ReplVal = getReplacementValue(ICV, UserI, A); - - if (!ReplVal || !CI) - return false; - - A.removeCallSite(CI); - CI->replaceAllUsesWith(ReplVal); - CI->eraseFromParent(); - Changed = true; - return true; - }; - - GetterRFI.foreachUse(ReplaceAndDeleteCB, getAnchorScope()); - return Changed; + return ChangeStatus::UNCHANGED; } // Map of ICV to their values at specific program point. - EnumeratedArray<SmallSetVector<ICVValue, 4>, InternalControlVar, + EnumeratedArray<DenseMap<Instruction *, Value *>, InternalControlVar, InternalControlVar::ICV___last> - ICVValuesMap; - - // Currently only nthreads is being tracked. - // this array will only grow with time. - InternalControlVar TrackableICVs[1] = {ICV_nthreads}; + ICVReplacementValuesMap; ChangeStatus updateImpl(Attributor &A) override { ChangeStatus HasChanged = ChangeStatus::UNCHANGED; @@ -1251,6 +1885,7 @@ struct AAICVTrackerFunction : public AAICVTracker { for (InternalControlVar ICV : TrackableICVs) { auto &SetterRFI = OMPInfoCache.RFIs[OMPInfoCache.ICVs[ICV].Setter]; + auto &ValuesMap = ICVReplacementValuesMap[ICV]; auto TrackValues = [&](Use &U, Function &) { CallInst *CI = OpenMPOpt::getCallIfRegularCall(U); if (!CI) @@ -1258,51 +1893,342 @@ struct AAICVTrackerFunction : public AAICVTracker { // FIXME: handle setters with more that 1 arguments. /// Track new value. - if (ICVValuesMap[ICV].insert(ICVValue(CI, CI->getArgOperand(0)))) + if (ValuesMap.insert(std::make_pair(CI, CI->getArgOperand(0))).second) HasChanged = ChangeStatus::CHANGED; return false; }; + auto CallCheck = [&](Instruction &I) { + Optional<Value *> ReplVal = getValueForCall(A, &I, ICV); + if (ReplVal.hasValue() && + ValuesMap.insert(std::make_pair(&I, *ReplVal)).second) + HasChanged = ChangeStatus::CHANGED; + + return true; + }; + + // Track all changes of an ICV. 
SetterRFI.foreachUse(TrackValues, F); + + A.checkForAllInstructions(CallCheck, *this, {Instruction::Call}, + /* CheckBBLivenessOnly */ true); + + /// TODO: Figure out a way to avoid adding entry in + /// ICVReplacementValuesMap + Instruction *Entry = &F->getEntryBlock().front(); + if (HasChanged == ChangeStatus::CHANGED && !ValuesMap.count(Entry)) + ValuesMap.insert(std::make_pair(Entry, nullptr)); } return HasChanged; } - /// Return the value with which \p I can be replaced for specific \p ICV. - Value *getReplacementValue(InternalControlVar ICV, const Instruction *I, - Attributor &A) override { - const BasicBlock *CurrBB = I->getParent(); + /// Hepler to check if \p I is a call and get the value for it if it is + /// unique. + Optional<Value *> getValueForCall(Attributor &A, const Instruction *I, + InternalControlVar &ICV) const { + + const auto *CB = dyn_cast<CallBase>(I); + if (!CB || CB->hasFnAttr("no_openmp") || + CB->hasFnAttr("no_openmp_routines")) + return None; - auto &ValuesSet = ICVValuesMap[ICV]; auto &OMPInfoCache = static_cast<OMPInformationCache &>(A.getInfoCache()); auto &GetterRFI = OMPInfoCache.RFIs[OMPInfoCache.ICVs[ICV].Getter]; + auto &SetterRFI = OMPInfoCache.RFIs[OMPInfoCache.ICVs[ICV].Setter]; + Function *CalledFunction = CB->getCalledFunction(); - for (const auto &ICVVal : ValuesSet) { - if (CurrBB == ICVVal.Inst->getParent()) { - if (!ICVVal.Inst->comesBefore(I)) - continue; + // Indirect call, assume ICV changes. + if (CalledFunction == nullptr) + return nullptr; + if (CalledFunction == GetterRFI.Declaration) + return None; + if (CalledFunction == SetterRFI.Declaration) { + if (ICVReplacementValuesMap[ICV].count(I)) + return ICVReplacementValuesMap[ICV].lookup(I); - // both instructions are in the same BB and at \p I we know the ICV - // value. - while (I != ICVVal.Inst) { - // we don't yet know if a call might update an ICV. - // TODO: check callsite AA for value. - if (const auto *CB = dyn_cast<CallBase>(I)) - if (CB->getCalledFunction() != GetterRFI.Declaration) + return nullptr; + } + + // Since we don't know, assume it changes the ICV. + if (CalledFunction->isDeclaration()) + return nullptr; + + const auto &ICVTrackingAA = + A.getAAFor<AAICVTracker>(*this, IRPosition::callsite_returned(*CB)); + + if (ICVTrackingAA.isAssumedTracked()) + return ICVTrackingAA.getUniqueReplacementValue(ICV); + + // If we don't know, assume it changes. + return nullptr; + } + + // We don't check unique value for a function, so return None. + Optional<Value *> + getUniqueReplacementValue(InternalControlVar ICV) const override { + return None; + } + + /// Return the value with which \p I can be replaced for specific \p ICV. + Optional<Value *> getReplacementValue(InternalControlVar ICV, + const Instruction *I, + Attributor &A) const override { + const auto &ValuesMap = ICVReplacementValuesMap[ICV]; + if (ValuesMap.count(I)) + return ValuesMap.lookup(I); + + SmallVector<const Instruction *, 16> Worklist; + SmallPtrSet<const Instruction *, 16> Visited; + Worklist.push_back(I); + + Optional<Value *> ReplVal; + + while (!Worklist.empty()) { + const Instruction *CurrInst = Worklist.pop_back_val(); + if (!Visited.insert(CurrInst).second) + continue; + + const BasicBlock *CurrBB = CurrInst->getParent(); + + // Go up and look for all potential setters/calls that might change the + // ICV. + while ((CurrInst = CurrInst->getPrevNode())) { + if (ValuesMap.count(CurrInst)) { + Optional<Value *> NewReplVal = ValuesMap.lookup(CurrInst); + // Unknown value, track new. 
+ if (!ReplVal.hasValue()) { + ReplVal = NewReplVal; + break; + } + + // If we found a new value, we can't know the icv value anymore. + if (NewReplVal.hasValue()) + if (ReplVal != NewReplVal) return nullptr; - I = I->getPrevNode(); + break; } - // No call in between, return the value. - return ICVVal.TrackedValue; + Optional<Value *> NewReplVal = getValueForCall(A, CurrInst, ICV); + if (!NewReplVal.hasValue()) + continue; + + // Unknown value, track new. + if (!ReplVal.hasValue()) { + ReplVal = NewReplVal; + break; + } + + // if (NewReplVal.hasValue()) + // We found a new value, we can't know the icv value anymore. + if (ReplVal != NewReplVal) + return nullptr; } + + // If we are in the same BB and we have a value, we are done. + if (CurrBB == I->getParent() && ReplVal.hasValue()) + return ReplVal; + + // Go through all predecessors and add terminators for analysis. + for (const BasicBlock *Pred : predecessors(CurrBB)) + if (const Instruction *Terminator = Pred->getTerminator()) + Worklist.push_back(Terminator); } - // No value was tracked. - return nullptr; + return ReplVal; + } +}; + +struct AAICVTrackerFunctionReturned : AAICVTracker { + AAICVTrackerFunctionReturned(const IRPosition &IRP, Attributor &A) + : AAICVTracker(IRP, A) {} + + // FIXME: come up with better string. + const std::string getAsStr() const override { + return "ICVTrackerFunctionReturned"; + } + + // FIXME: come up with some stats. + void trackStatistics() const override {} + + /// We don't manifest anything for this AA. + ChangeStatus manifest(Attributor &A) override { + return ChangeStatus::UNCHANGED; + } + + // Map of ICV to their values at specific program point. + EnumeratedArray<Optional<Value *>, InternalControlVar, + InternalControlVar::ICV___last> + ICVReplacementValuesMap; + + /// Return the value with which \p I can be replaced for specific \p ICV. + Optional<Value *> + getUniqueReplacementValue(InternalControlVar ICV) const override { + return ICVReplacementValuesMap[ICV]; + } + + ChangeStatus updateImpl(Attributor &A) override { + ChangeStatus Changed = ChangeStatus::UNCHANGED; + const auto &ICVTrackingAA = A.getAAFor<AAICVTracker>( + *this, IRPosition::function(*getAnchorScope())); + + if (!ICVTrackingAA.isAssumedTracked()) + return indicatePessimisticFixpoint(); + + for (InternalControlVar ICV : TrackableICVs) { + Optional<Value *> &ReplVal = ICVReplacementValuesMap[ICV]; + Optional<Value *> UniqueICVValue; + + auto CheckReturnInst = [&](Instruction &I) { + Optional<Value *> NewReplVal = + ICVTrackingAA.getReplacementValue(ICV, &I, A); + + // If we found a second ICV value there is no unique returned value. + if (UniqueICVValue.hasValue() && UniqueICVValue != NewReplVal) + return false; + + UniqueICVValue = NewReplVal; + + return true; + }; + + if (!A.checkForAllInstructions(CheckReturnInst, *this, {Instruction::Ret}, + /* CheckBBLivenessOnly */ true)) + UniqueICVValue = nullptr; + + if (UniqueICVValue == ReplVal) + continue; + + ReplVal = UniqueICVValue; + Changed = ChangeStatus::CHANGED; + } + + return Changed; + } +}; + +struct AAICVTrackerCallSite : AAICVTracker { + AAICVTrackerCallSite(const IRPosition &IRP, Attributor &A) + : AAICVTracker(IRP, A) {} + + void initialize(Attributor &A) override { + Function *F = getAnchorScope(); + if (!F || !A.isFunctionIPOAmendable(*F)) + indicatePessimisticFixpoint(); + + // We only initialize this AA for getters, so we need to know which ICV it + // gets. 
+ auto &OMPInfoCache = static_cast<OMPInformationCache &>(A.getInfoCache()); + for (InternalControlVar ICV : TrackableICVs) { + auto ICVInfo = OMPInfoCache.ICVs[ICV]; + auto &Getter = OMPInfoCache.RFIs[ICVInfo.Getter]; + if (Getter.Declaration == getAssociatedFunction()) { + AssociatedICV = ICVInfo.Kind; + return; + } + } + + /// Unknown ICV. + indicatePessimisticFixpoint(); + } + + ChangeStatus manifest(Attributor &A) override { + if (!ReplVal.hasValue() || !ReplVal.getValue()) + return ChangeStatus::UNCHANGED; + + A.changeValueAfterManifest(*getCtxI(), **ReplVal); + A.deleteAfterManifest(*getCtxI()); + + return ChangeStatus::CHANGED; + } + + // FIXME: come up with better string. + const std::string getAsStr() const override { return "ICVTrackerCallSite"; } + + // FIXME: come up with some stats. + void trackStatistics() const override {} + + InternalControlVar AssociatedICV; + Optional<Value *> ReplVal; + + ChangeStatus updateImpl(Attributor &A) override { + const auto &ICVTrackingAA = A.getAAFor<AAICVTracker>( + *this, IRPosition::function(*getAnchorScope())); + + // We don't have any information, so we assume it changes the ICV. + if (!ICVTrackingAA.isAssumedTracked()) + return indicatePessimisticFixpoint(); + + Optional<Value *> NewReplVal = + ICVTrackingAA.getReplacementValue(AssociatedICV, getCtxI(), A); + + if (ReplVal == NewReplVal) + return ChangeStatus::UNCHANGED; + + ReplVal = NewReplVal; + return ChangeStatus::CHANGED; + } + + // Return the value with which associated value can be replaced for specific + // \p ICV. + Optional<Value *> + getUniqueReplacementValue(InternalControlVar ICV) const override { + return ReplVal; + } +}; + +struct AAICVTrackerCallSiteReturned : AAICVTracker { + AAICVTrackerCallSiteReturned(const IRPosition &IRP, Attributor &A) + : AAICVTracker(IRP, A) {} + + // FIXME: come up with better string. + const std::string getAsStr() const override { + return "ICVTrackerCallSiteReturned"; + } + + // FIXME: come up with some stats. + void trackStatistics() const override {} + + /// We don't manifest anything for this AA. + ChangeStatus manifest(Attributor &A) override { + return ChangeStatus::UNCHANGED; + } + + // Map of ICV to their values at specific program point. + EnumeratedArray<Optional<Value *>, InternalControlVar, + InternalControlVar::ICV___last> + ICVReplacementValuesMap; + + /// Return the value with which associated value can be replaced for specific + /// \p ICV. + Optional<Value *> + getUniqueReplacementValue(InternalControlVar ICV) const override { + return ICVReplacementValuesMap[ICV]; + } + + ChangeStatus updateImpl(Attributor &A) override { + ChangeStatus Changed = ChangeStatus::UNCHANGED; + const auto &ICVTrackingAA = A.getAAFor<AAICVTracker>( + *this, IRPosition::returned(*getAssociatedFunction())); + + // We don't have any information, so we assume it changes the ICV. 
+ if (!ICVTrackingAA.isAssumedTracked()) + return indicatePessimisticFixpoint(); + + for (InternalControlVar ICV : TrackableICVs) { + Optional<Value *> &ReplVal = ICVReplacementValuesMap[ICV]; + Optional<Value *> NewReplVal = + ICVTrackingAA.getUniqueReplacementValue(ICV); + + if (ReplVal == NewReplVal) + continue; + + ReplVal = NewReplVal; + Changed = ChangeStatus::CHANGED; + } + return Changed; } }; } // namespace @@ -1316,11 +2242,17 @@ AAICVTracker &AAICVTracker::createForPosition(const IRPosition &IRP, case IRPosition::IRP_INVALID: case IRPosition::IRP_FLOAT: case IRPosition::IRP_ARGUMENT: + case IRPosition::IRP_CALL_SITE_ARGUMENT: + llvm_unreachable("ICVTracker can only be created for function position!"); case IRPosition::IRP_RETURNED: + AA = new (A.Allocator) AAICVTrackerFunctionReturned(IRP, A); + break; case IRPosition::IRP_CALL_SITE_RETURNED: - case IRPosition::IRP_CALL_SITE_ARGUMENT: + AA = new (A.Allocator) AAICVTrackerCallSiteReturned(IRP, A); + break; case IRPosition::IRP_CALL_SITE: - llvm_unreachable("ICVTracker can only be created for function position!"); + AA = new (A.Allocator) AAICVTrackerCallSite(IRP, A); + break; case IRPosition::IRP_FUNCTION: AA = new (A.Allocator) AAICVTrackerFunction(IRP, A); break; @@ -1339,10 +2271,21 @@ PreservedAnalyses OpenMPOptPass::run(LazyCallGraph::SCC &C, return PreservedAnalyses::all(); SmallVector<Function *, 16> SCC; - for (LazyCallGraph::Node &N : C) - SCC.push_back(&N.getFunction()); + // If there are kernels in the module, we have to run on all SCC's. + bool SCCIsInteresting = !OMPInModule.getKernels().empty(); + for (LazyCallGraph::Node &N : C) { + Function *Fn = &N.getFunction(); + SCC.push_back(Fn); + + // Do we already know that the SCC contains kernels, + // or that OpenMP functions are called from this SCC? + if (SCCIsInteresting) + continue; + // If not, let's check that. + SCCIsInteresting |= OMPInModule.containsOMPRuntimeCalls(Fn); + } - if (SCC.empty()) + if (!SCCIsInteresting || SCC.empty()) return PreservedAnalyses::all(); FunctionAnalysisManager &FAM = @@ -1364,7 +2307,6 @@ PreservedAnalyses OpenMPOptPass::run(LazyCallGraph::SCC &C, Attributor A(Functions, InfoCache, CGUpdater); - // TODO: Compute the module slice we are allowed to look at. OpenMPOpt OMPOpt(SCC, CGUpdater, OREGetter, InfoCache, A); bool Changed = OMPOpt.run(); if (Changed) @@ -1401,12 +2343,23 @@ struct OpenMPOptLegacyPass : public CallGraphSCCPass { return false; SmallVector<Function *, 16> SCC; - for (CallGraphNode *CGN : CGSCC) - if (Function *Fn = CGN->getFunction()) - if (!Fn->isDeclaration()) - SCC.push_back(Fn); + // If there are kernels in the module, we have to run on all SCC's. + bool SCCIsInteresting = !OMPInModule.getKernels().empty(); + for (CallGraphNode *CGN : CGSCC) { + Function *Fn = CGN->getFunction(); + if (!Fn || Fn->isDeclaration()) + continue; + SCC.push_back(Fn); - if (SCC.empty()) + // Do we already know that the SCC contains kernels, + // or that OpenMP functions are called from this SCC? + if (SCCIsInteresting) + continue; + // If not, let's check that. + SCCIsInteresting |= OMPInModule.containsOMPRuntimeCalls(Fn); + } + + if (!SCCIsInteresting || SCC.empty()) return false; CallGraph &CG = getAnalysis<CallGraphWrapperPass>().getCallGraph(); @@ -1430,7 +2383,6 @@ struct OpenMPOptLegacyPass : public CallGraphSCCPass { Attributor A(Functions, InfoCache, CGUpdater); - // TODO: Compute the module slice we are allowed to look at. 
OpenMPOpt OMPOpt(SCC, CGUpdater, OREGetter, InfoCache, A); return OMPOpt.run(); } @@ -1468,13 +2420,19 @@ bool llvm::omp::containsOpenMP(Module &M, OpenMPInModule &OMPInModule) { if (OMPInModule.isKnown()) return OMPInModule; + auto RecordFunctionsContainingUsesOf = [&](Function *F) { + for (User *U : F->users()) + if (auto *I = dyn_cast<Instruction>(U)) + OMPInModule.FuncsWithOMPRuntimeCalls.insert(I->getFunction()); + }; + // MSVC doesn't like long if-else chains for some reason and instead just // issues an error. Work around it.. do { #define OMP_RTL(_Enum, _Name, ...) \ - if (M.getFunction(_Name)) { \ + if (Function *F = M.getFunction(_Name)) { \ + RecordFunctionsContainingUsesOf(F); \ OMPInModule = true; \ - break; \ } #include "llvm/Frontend/OpenMP/OMPKinds.def" } while (false); diff --git a/contrib/llvm-project/llvm/lib/Transforms/IPO/PartialInlining.cpp b/contrib/llvm-project/llvm/lib/Transforms/IPO/PartialInlining.cpp index 5d863f1330a4..2bbf4bf110ae 100644 --- a/contrib/llvm-project/llvm/lib/Transforms/IPO/PartialInlining.cpp +++ b/contrib/llvm-project/llvm/lib/Transforms/IPO/PartialInlining.cpp @@ -97,13 +97,6 @@ static cl::opt<bool> MarkOutlinedColdCC("pi-mark-coldcc", cl::init(false), cl::Hidden, cl::desc("Mark outline function calls with ColdCC")); -#ifndef NDEBUG -// Command line option to debug partial-inlining. The default is none: -static cl::opt<bool> TracePartialInlining("trace-partial-inlining", - cl::init(false), cl::Hidden, - cl::desc("Trace partial inlining.")); -#endif - // This is an option used by testing: static cl::opt<bool> SkipCostAnalysis("skip-partial-inlining-cost-analysis", cl::init(false), cl::ZeroOrMore, @@ -159,7 +152,7 @@ struct FunctionOutliningInfo { // Returns the number of blocks to be inlined including all blocks // in Entries and one return block. - unsigned GetNumInlinedBlocks() const { return Entries.size() + 1; } + unsigned getNumInlinedBlocks() const { return Entries.size() + 1; } // A set of blocks including the function entry that guard // the region to be outlined. @@ -215,7 +208,7 @@ struct PartialInlinerImpl { // function (only if we partially inlined early returns) as there is a // possibility to further "peel" early return statements that were left in the // outline function due to code size. - std::pair<bool, Function *> unswitchFunction(Function *F); + std::pair<bool, Function *> unswitchFunction(Function &F); // This class speculatively clones the function to be partial inlined. // At the end of partial inlining, the remaining callsites to the cloned @@ -226,16 +219,19 @@ struct PartialInlinerImpl { // multi-region outlining. FunctionCloner(Function *F, FunctionOutliningInfo *OI, OptimizationRemarkEmitter &ORE, - function_ref<AssumptionCache *(Function &)> LookupAC); + function_ref<AssumptionCache *(Function &)> LookupAC, + function_ref<TargetTransformInfo &(Function &)> GetTTI); FunctionCloner(Function *F, FunctionOutliningMultiRegionInfo *OMRI, OptimizationRemarkEmitter &ORE, - function_ref<AssumptionCache *(Function &)> LookupAC); + function_ref<AssumptionCache *(Function &)> LookupAC, + function_ref<TargetTransformInfo &(Function &)> GetTTI); + ~FunctionCloner(); // Prepare for function outlining: making sure there is only // one incoming edge from the extracted/outlined region to // the return block. - void NormalizeReturnBlock(); + void normalizeReturnBlock() const; // Do function outlining for cold regions. 
bool doMultiRegionFunctionOutlining(); @@ -266,6 +262,7 @@ struct PartialInlinerImpl { std::unique_ptr<BlockFrequencyInfo> ClonedFuncBFI = nullptr; OptimizationRemarkEmitter &ORE; function_ref<AssumptionCache *(Function &)> LookupAC; + function_ref<TargetTransformInfo &(Function &)> GetTTI; }; private: @@ -281,13 +278,14 @@ private: // The result is no larger than 1 and is represented using BP. // (Note that the outlined region's 'head' block can only have incoming // edges from the guarding entry blocks). - BranchProbability getOutliningCallBBRelativeFreq(FunctionCloner &Cloner); + BranchProbability + getOutliningCallBBRelativeFreq(FunctionCloner &Cloner) const; // Return true if the callee of CB should be partially inlined with // profit. bool shouldPartialInline(CallBase &CB, FunctionCloner &Cloner, BlockFrequency WeightedOutliningRcost, - OptimizationRemarkEmitter &ORE); + OptimizationRemarkEmitter &ORE) const; // Try to inline DuplicateFunction (cloned from F with call to // the OutlinedFunction into its callers. Return true @@ -296,10 +294,11 @@ private: // Compute the mapping from use site of DuplicationFunction to the enclosing // BB's profile count. - void computeCallsiteToProfCountMap(Function *DuplicateFunction, - DenseMap<User *, uint64_t> &SiteCountMap); + void + computeCallsiteToProfCountMap(Function *DuplicateFunction, + DenseMap<User *, uint64_t> &SiteCountMap) const; - bool IsLimitReached() { + bool isLimitReached() const { return (MaxNumPartialInlining != -1 && NumPartialInlining >= MaxNumPartialInlining); } @@ -311,12 +310,12 @@ private: return nullptr; } - static CallBase *getOneCallSiteTo(Function *F) { - User *User = *F->user_begin(); + static CallBase *getOneCallSiteTo(Function &F) { + User *User = *F.user_begin(); return getSupportedCallBase(User); } - std::tuple<DebugLoc, BasicBlock *> getOneDebugLoc(Function *F) { + std::tuple<DebugLoc, BasicBlock *> getOneDebugLoc(Function &F) const { CallBase *CB = getOneCallSiteTo(F); DebugLoc DLoc = CB->getDebugLoc(); BasicBlock *Block = CB->getParent(); @@ -329,16 +328,19 @@ private: // outlined function itself; // - The second value is the estimated size of the new call sequence in // basic block Cloner.OutliningCallBB; - std::tuple<int, int> computeOutliningCosts(FunctionCloner &Cloner); + std::tuple<int, int> computeOutliningCosts(FunctionCloner &Cloner) const; // Compute the 'InlineCost' of block BB. InlineCost is a proxy used to // approximate both the size and runtime cost (Note that in the current // inline cost analysis, there is no clear distinction there either). 
- static int computeBBInlineCost(BasicBlock *BB); + static int computeBBInlineCost(BasicBlock *BB, TargetTransformInfo *TTI); + + std::unique_ptr<FunctionOutliningInfo> + computeOutliningInfo(Function &F) const; - std::unique_ptr<FunctionOutliningInfo> computeOutliningInfo(Function *F); std::unique_ptr<FunctionOutliningMultiRegionInfo> - computeOutliningColdRegionsInfo(Function *F, OptimizationRemarkEmitter &ORE); + computeOutliningColdRegionsInfo(Function &F, + OptimizationRemarkEmitter &ORE) const; }; struct PartialInlinerLegacyPass : public ModulePass { @@ -390,20 +392,20 @@ struct PartialInlinerLegacyPass : public ModulePass { } // end anonymous namespace std::unique_ptr<FunctionOutliningMultiRegionInfo> -PartialInlinerImpl::computeOutliningColdRegionsInfo(Function *F, - OptimizationRemarkEmitter &ORE) { - BasicBlock *EntryBlock = &F->front(); +PartialInlinerImpl::computeOutliningColdRegionsInfo( + Function &F, OptimizationRemarkEmitter &ORE) const { + BasicBlock *EntryBlock = &F.front(); - DominatorTree DT(*F); + DominatorTree DT(F); LoopInfo LI(DT); - BranchProbabilityInfo BPI(*F, LI); + BranchProbabilityInfo BPI(F, LI); std::unique_ptr<BlockFrequencyInfo> ScopedBFI; BlockFrequencyInfo *BFI; if (!GetBFI) { - ScopedBFI.reset(new BlockFrequencyInfo(*F, BPI, LI)); + ScopedBFI.reset(new BlockFrequencyInfo(F, BPI, LI)); BFI = ScopedBFI.get(); } else - BFI = &(GetBFI(*F)); + BFI = &(GetBFI(F)); // Return if we don't have profiling information. if (!PSI.hasInstrumentationProfile()) @@ -412,11 +414,6 @@ PartialInlinerImpl::computeOutliningColdRegionsInfo(Function *F, std::unique_ptr<FunctionOutliningMultiRegionInfo> OutliningInfo = std::make_unique<FunctionOutliningMultiRegionInfo>(); - auto IsSingleEntry = [](SmallVectorImpl<BasicBlock *> &BlockList) { - BasicBlock *Dom = BlockList.front(); - return BlockList.size() > 1 && Dom->hasNPredecessors(1); - }; - auto IsSingleExit = [&ORE](SmallVectorImpl<BasicBlock *> &BlockList) -> BasicBlock * { BasicBlock *ExitBlock = nullptr; @@ -432,8 +429,9 @@ PartialInlinerImpl::computeOutliningColdRegionsInfo(Function *F, << " has more than one region exit edge."; }); return nullptr; - } else - ExitBlock = Block; + } + + ExitBlock = Block; } } } @@ -448,14 +446,14 @@ PartialInlinerImpl::computeOutliningColdRegionsInfo(Function *F, // Use the same computeBBInlineCost function to compute the cost savings of // the outlining the candidate region. + TargetTransformInfo *FTTI = &GetTTI(F); int OverallFunctionCost = 0; - for (auto &BB : *F) - OverallFunctionCost += computeBBInlineCost(&BB); + for (auto &BB : F) + OverallFunctionCost += computeBBInlineCost(&BB, FTTI); + + LLVM_DEBUG(dbgs() << "OverallFunctionCost = " << OverallFunctionCost + << "\n";); -#ifndef NDEBUG - if (TracePartialInlining) - dbgs() << "OverallFunctionCost = " << OverallFunctionCost << "\n"; -#endif int MinOutlineRegionCost = static_cast<int>(OverallFunctionCost * MinRegionSizeRatio); BranchProbability MinBranchProbability( @@ -467,6 +465,7 @@ PartialInlinerImpl::computeOutliningColdRegionsInfo(Function *F, DenseMap<BasicBlock *, bool> VisitedMap; DFS.push_back(CurrEntry); VisitedMap[CurrEntry] = true; + // Use Depth First Search on the basic blocks to find CFG edges that are // considered cold. // Cold regions considered must also have its inline cost compared to the @@ -474,88 +473,98 @@ PartialInlinerImpl::computeOutliningColdRegionsInfo(Function *F, // if it reduced the inline cost of the function by 'MinOutlineRegionCost' or // more. 
while (!DFS.empty()) { - auto *thisBB = DFS.back(); + auto *ThisBB = DFS.back(); DFS.pop_back(); // Only consider regions with predecessor blocks that are considered // not-cold (default: part of the top 99.99% of all block counters) // AND greater than our minimum block execution count (default: 100). - if (PSI.isColdBlock(thisBB, BFI) || - BBProfileCount(thisBB) < MinBlockCounterExecution) + if (PSI.isColdBlock(ThisBB, BFI) || + BBProfileCount(ThisBB) < MinBlockCounterExecution) continue; - for (auto SI = succ_begin(thisBB); SI != succ_end(thisBB); ++SI) { + for (auto SI = succ_begin(ThisBB); SI != succ_end(ThisBB); ++SI) { if (VisitedMap[*SI]) continue; VisitedMap[*SI] = true; DFS.push_back(*SI); // If branch isn't cold, we skip to the next one. - BranchProbability SuccProb = BPI.getEdgeProbability(thisBB, *SI); + BranchProbability SuccProb = BPI.getEdgeProbability(ThisBB, *SI); if (SuccProb > MinBranchProbability) continue; -#ifndef NDEBUG - if (TracePartialInlining) { - dbgs() << "Found cold edge: " << thisBB->getName() << "->" - << (*SI)->getName() << "\nBranch Probability = " << SuccProb - << "\n"; - } -#endif + + LLVM_DEBUG(dbgs() << "Found cold edge: " << ThisBB->getName() << "->" + << SI->getName() + << "\nBranch Probability = " << SuccProb << "\n";); + SmallVector<BasicBlock *, 8> DominateVector; DT.getDescendants(*SI, DominateVector); + assert(!DominateVector.empty() && + "SI should be reachable and have at least itself as descendant"); + // We can only outline single entry regions (for now). - if (!IsSingleEntry(DominateVector)) + if (!DominateVector.front()->hasNPredecessors(1)) { + LLVM_DEBUG(dbgs() << "ABORT: Block " << SI->getName() + << " doesn't have a single predecessor in the " + "dominator tree\n";); continue; + } + BasicBlock *ExitBlock = nullptr; // We can only outline single exit regions (for now). - if (!(ExitBlock = IsSingleExit(DominateVector))) + if (!(ExitBlock = IsSingleExit(DominateVector))) { + LLVM_DEBUG(dbgs() << "ABORT: Block " << SI->getName() + << " doesn't have a unique successor\n";); continue; + } + int OutlineRegionCost = 0; for (auto *BB : DominateVector) - OutlineRegionCost += computeBBInlineCost(BB); + OutlineRegionCost += computeBBInlineCost(BB, &GetTTI(*BB->getParent())); -#ifndef NDEBUG - if (TracePartialInlining) - dbgs() << "OutlineRegionCost = " << OutlineRegionCost << "\n"; -#endif + LLVM_DEBUG(dbgs() << "OutlineRegionCost = " << OutlineRegionCost + << "\n";); - if (OutlineRegionCost < MinOutlineRegionCost) { + if (!SkipCostAnalysis && OutlineRegionCost < MinOutlineRegionCost) { ORE.emit([&]() { return OptimizationRemarkAnalysis(DEBUG_TYPE, "TooCostly", &SI->front()) - << ore::NV("Callee", F) << " inline cost-savings smaller than " + << ore::NV("Callee", &F) + << " inline cost-savings smaller than " << ore::NV("Cost", MinOutlineRegionCost); }); + + LLVM_DEBUG(dbgs() << "ABORT: Outline region cost is smaller than " + << MinOutlineRegionCost << "\n";); continue; } + // For now, ignore blocks that belong to a SISE region that is a // candidate for outlining. In the future, we may want to look // at inner regions because the outer region may have live-exit // variables. 
for (auto *BB : DominateVector) VisitedMap[BB] = true; + // ReturnBlock here means the block after the outline call BasicBlock *ReturnBlock = ExitBlock->getSingleSuccessor(); - // assert(ReturnBlock && "ReturnBlock is NULL somehow!"); FunctionOutliningMultiRegionInfo::OutlineRegionInfo RegInfo( DominateVector, DominateVector.front(), ExitBlock, ReturnBlock); OutliningInfo->ORI.push_back(RegInfo); -#ifndef NDEBUG - if (TracePartialInlining) { - dbgs() << "Found Cold Candidate starting at block: " - << DominateVector.front()->getName() << "\n"; - } -#endif + LLVM_DEBUG(dbgs() << "Found Cold Candidate starting at block: " + << DominateVector.front()->getName() << "\n";); ColdCandidateFound = true; NumColdRegionsFound++; } } + if (ColdCandidateFound) return OutliningInfo; - else - return std::unique_ptr<FunctionOutliningMultiRegionInfo>(); + + return std::unique_ptr<FunctionOutliningMultiRegionInfo>(); } std::unique_ptr<FunctionOutliningInfo> -PartialInlinerImpl::computeOutliningInfo(Function *F) { - BasicBlock *EntryBlock = &F->front(); +PartialInlinerImpl::computeOutliningInfo(Function &F) const { + BasicBlock *EntryBlock = &F.front(); BranchInst *BR = dyn_cast<BranchInst>(EntryBlock->getTerminator()); if (!BR || BR->isUnconditional()) return std::unique_ptr<FunctionOutliningInfo>(); @@ -598,7 +607,7 @@ PartialInlinerImpl::computeOutliningInfo(Function *F) { // The number of blocks to be inlined has already reached // the limit. When MaxNumInlineBlocks is set to 0 or 1, this // disables partial inlining for the function. - if (OutliningInfo->GetNumInlinedBlocks() >= MaxNumInlineBlocks) + if (OutliningInfo->getNumInlinedBlocks() >= MaxNumInlineBlocks) break; if (succ_size(CurrEntry) != 2) @@ -618,8 +627,7 @@ PartialInlinerImpl::computeOutliningInfo(Function *F) { break; } - BasicBlock *CommSucc; - BasicBlock *OtherSucc; + BasicBlock *CommSucc, *OtherSucc; std::tie(CommSucc, OtherSucc) = GetCommonSucc(Succ1, Succ2); if (!CommSucc) @@ -635,7 +643,7 @@ PartialInlinerImpl::computeOutliningInfo(Function *F) { // Do sanity check of the entries: threre should not // be any successors (not in the entry set) other than // {ReturnBlock, NonReturnBlock} - assert(OutliningInfo->Entries[0] == &F->front() && + assert(OutliningInfo->Entries[0] == &F.front() && "Function Entry must be the first in Entries vector"); DenseSet<BasicBlock *> Entries; for (BasicBlock *E : OutliningInfo->Entries) @@ -644,7 +652,7 @@ PartialInlinerImpl::computeOutliningInfo(Function *F) { // Returns true of BB has Predecessor which is not // in Entries set. 
auto HasNonEntryPred = [Entries](BasicBlock *BB) { - for (auto Pred : predecessors(BB)) { + for (auto *Pred : predecessors(BB)) { if (!Entries.count(Pred)) return true; } @@ -653,7 +661,7 @@ PartialInlinerImpl::computeOutliningInfo(Function *F) { auto CheckAndNormalizeCandidate = [Entries, HasNonEntryPred](FunctionOutliningInfo *OutliningInfo) { for (BasicBlock *E : OutliningInfo->Entries) { - for (auto Succ : successors(E)) { + for (auto *Succ : successors(E)) { if (Entries.count(Succ)) continue; if (Succ == OutliningInfo->ReturnBlock) @@ -673,7 +681,7 @@ PartialInlinerImpl::computeOutliningInfo(Function *F) { // Now further growing the candidate's inlining region by // peeling off dominating blocks from the outlining region: - while (OutliningInfo->GetNumInlinedBlocks() < MaxNumInlineBlocks) { + while (OutliningInfo->getNumInlinedBlocks() < MaxNumInlineBlocks) { BasicBlock *Cand = OutliningInfo->NonReturnBlock; if (succ_size(Cand) != 2) break; @@ -703,11 +711,11 @@ PartialInlinerImpl::computeOutliningInfo(Function *F) { } // Check if there is PGO data or user annotated branch data: -static bool hasProfileData(Function *F, FunctionOutliningInfo *OI) { - if (F->hasProfileData()) +static bool hasProfileData(const Function &F, const FunctionOutliningInfo &OI) { + if (F.hasProfileData()) return true; // Now check if any of the entry block has MD_prof data: - for (auto *E : OI->Entries) { + for (auto *E : OI.Entries) { BranchInst *BR = dyn_cast<BranchInst>(E->getTerminator()); if (!BR || BR->isUnconditional()) continue; @@ -718,8 +726,8 @@ static bool hasProfileData(Function *F, FunctionOutliningInfo *OI) { return false; } -BranchProbability -PartialInlinerImpl::getOutliningCallBBRelativeFreq(FunctionCloner &Cloner) { +BranchProbability PartialInlinerImpl::getOutliningCallBBRelativeFreq( + FunctionCloner &Cloner) const { BasicBlock *OutliningCallBB = Cloner.OutlinedFunctions.back().second; auto EntryFreq = Cloner.ClonedFuncBFI->getBlockFreq(&Cloner.ClonedFunc->getEntryBlock()); @@ -728,13 +736,13 @@ PartialInlinerImpl::getOutliningCallBBRelativeFreq(FunctionCloner &Cloner) { // FIXME Hackery needed because ClonedFuncBFI is based on the function BEFORE // we outlined any regions, so we may encounter situations where the // OutliningCallFreq is *slightly* bigger than the EntryFreq. - if (OutliningCallFreq.getFrequency() > EntryFreq.getFrequency()) { + if (OutliningCallFreq.getFrequency() > EntryFreq.getFrequency()) OutliningCallFreq = EntryFreq; - } + auto OutlineRegionRelFreq = BranchProbability::getBranchProbability( OutliningCallFreq.getFrequency(), EntryFreq.getFrequency()); - if (hasProfileData(Cloner.OrigFunc, Cloner.ClonedOI.get())) + if (hasProfileData(*Cloner.OrigFunc, *Cloner.ClonedOI.get())) return OutlineRegionRelFreq; // When profile data is not available, we need to be conservative in @@ -760,7 +768,7 @@ PartialInlinerImpl::getOutliningCallBBRelativeFreq(FunctionCloner &Cloner) { bool PartialInlinerImpl::shouldPartialInline( CallBase &CB, FunctionCloner &Cloner, BlockFrequency WeightedOutliningRcost, - OptimizationRemarkEmitter &ORE) { + OptimizationRemarkEmitter &ORE) const { using namespace ore; Function *Callee = CB.getCalledFunction(); @@ -843,7 +851,8 @@ bool PartialInlinerImpl::shouldPartialInline( // TODO: Ideally we should share Inliner's InlineCost Analysis code. // For now use a simplified version. The returned 'InlineCost' will be used // to esimate the size cost as well as runtime cost of the BB. 
-int PartialInlinerImpl::computeBBInlineCost(BasicBlock *BB) { +int PartialInlinerImpl::computeBBInlineCost(BasicBlock *BB, + TargetTransformInfo *TTI) { int InlineCost = 0; const DataLayout &DL = BB->getParent()->getParent()->getDataLayout(); for (Instruction &I : BB->instructionsWithoutDebug()) { @@ -866,6 +875,21 @@ int PartialInlinerImpl::computeBBInlineCost(BasicBlock *BB) { if (I.isLifetimeStartOrEnd()) continue; + if (auto *II = dyn_cast<IntrinsicInst>(&I)) { + Intrinsic::ID IID = II->getIntrinsicID(); + SmallVector<Type *, 4> Tys; + FastMathFlags FMF; + for (Value *Val : II->args()) + Tys.push_back(Val->getType()); + + if (auto *FPMO = dyn_cast<FPMathOperator>(II)) + FMF = FPMO->getFastMathFlags(); + + IntrinsicCostAttributes ICA(IID, II->getType(), Tys, FMF); + InlineCost += TTI->getIntrinsicInstrCost(ICA, TTI::TCK_SizeAndLatency); + continue; + } + if (CallInst *CI = dyn_cast<CallInst>(&I)) { InlineCost += getCallsiteCost(*CI, DL); continue; @@ -886,18 +910,20 @@ int PartialInlinerImpl::computeBBInlineCost(BasicBlock *BB) { } std::tuple<int, int> -PartialInlinerImpl::computeOutliningCosts(FunctionCloner &Cloner) { +PartialInlinerImpl::computeOutliningCosts(FunctionCloner &Cloner) const { int OutliningFuncCallCost = 0, OutlinedFunctionCost = 0; for (auto FuncBBPair : Cloner.OutlinedFunctions) { Function *OutlinedFunc = FuncBBPair.first; BasicBlock* OutliningCallBB = FuncBBPair.second; // Now compute the cost of the call sequence to the outlined function // 'OutlinedFunction' in BB 'OutliningCallBB': - OutliningFuncCallCost += computeBBInlineCost(OutliningCallBB); + auto *OutlinedFuncTTI = &GetTTI(*OutlinedFunc); + OutliningFuncCallCost += + computeBBInlineCost(OutliningCallBB, OutlinedFuncTTI); // Now compute the cost of the extracted/outlined function itself: for (BasicBlock &BB : *OutlinedFunc) - OutlinedFunctionCost += computeBBInlineCost(&BB); + OutlinedFunctionCost += computeBBInlineCost(&BB, OutlinedFuncTTI); } assert(OutlinedFunctionCost >= Cloner.OutlinedRegionCost && "Outlined function cost should be no less than the outlined region"); @@ -921,7 +947,7 @@ PartialInlinerImpl::computeOutliningCosts(FunctionCloner &Cloner) { // after the function is partially inlined into the callsite. void PartialInlinerImpl::computeCallsiteToProfCountMap( Function *DuplicateFunction, - DenseMap<User *, uint64_t> &CallSiteToProfCountMap) { + DenseMap<User *, uint64_t> &CallSiteToProfCountMap) const { std::vector<User *> Users(DuplicateFunction->user_begin(), DuplicateFunction->user_end()); Function *CurrentCaller = nullptr; @@ -962,8 +988,9 @@ void PartialInlinerImpl::computeCallsiteToProfCountMap( PartialInlinerImpl::FunctionCloner::FunctionCloner( Function *F, FunctionOutliningInfo *OI, OptimizationRemarkEmitter &ORE, - function_ref<AssumptionCache *(Function &)> LookupAC) - : OrigFunc(F), ORE(ORE), LookupAC(LookupAC) { + function_ref<AssumptionCache *(Function &)> LookupAC, + function_ref<TargetTransformInfo &(Function &)> GetTTI) + : OrigFunc(F), ORE(ORE), LookupAC(LookupAC), GetTTI(GetTTI) { ClonedOI = std::make_unique<FunctionOutliningInfo>(); // Clone the function, so that we can hack away on it. 
@@ -972,9 +999,9 @@ PartialInlinerImpl::FunctionCloner::FunctionCloner( ClonedOI->ReturnBlock = cast<BasicBlock>(VMap[OI->ReturnBlock]); ClonedOI->NonReturnBlock = cast<BasicBlock>(VMap[OI->NonReturnBlock]); - for (BasicBlock *BB : OI->Entries) { + for (BasicBlock *BB : OI->Entries) ClonedOI->Entries.push_back(cast<BasicBlock>(VMap[BB])); - } + for (BasicBlock *E : OI->ReturnBlockPreds) { BasicBlock *NewE = cast<BasicBlock>(VMap[E]); ClonedOI->ReturnBlockPreds.push_back(NewE); @@ -987,8 +1014,9 @@ PartialInlinerImpl::FunctionCloner::FunctionCloner( PartialInlinerImpl::FunctionCloner::FunctionCloner( Function *F, FunctionOutliningMultiRegionInfo *OI, OptimizationRemarkEmitter &ORE, - function_ref<AssumptionCache *(Function &)> LookupAC) - : OrigFunc(F), ORE(ORE), LookupAC(LookupAC) { + function_ref<AssumptionCache *(Function &)> LookupAC, + function_ref<TargetTransformInfo &(Function &)> GetTTI) + : OrigFunc(F), ORE(ORE), LookupAC(LookupAC), GetTTI(GetTTI) { ClonedOMRI = std::make_unique<FunctionOutliningMultiRegionInfo>(); // Clone the function, so that we can hack away on it. @@ -1000,9 +1028,9 @@ PartialInlinerImpl::FunctionCloner::FunctionCloner( for (FunctionOutliningMultiRegionInfo::OutlineRegionInfo RegionInfo : OI->ORI) { SmallVector<BasicBlock *, 8> Region; - for (BasicBlock *BB : RegionInfo.Region) { + for (BasicBlock *BB : RegionInfo.Region) Region.push_back(cast<BasicBlock>(VMap[BB])); - } + BasicBlock *NewEntryBlock = cast<BasicBlock>(VMap[RegionInfo.EntryBlock]); BasicBlock *NewExitBlock = cast<BasicBlock>(VMap[RegionInfo.ExitBlock]); BasicBlock *NewReturnBlock = nullptr; @@ -1017,8 +1045,8 @@ PartialInlinerImpl::FunctionCloner::FunctionCloner( F->replaceAllUsesWith(ClonedFunc); } -void PartialInlinerImpl::FunctionCloner::NormalizeReturnBlock() { - auto getFirstPHI = [](BasicBlock *BB) { +void PartialInlinerImpl::FunctionCloner::normalizeReturnBlock() const { + auto GetFirstPHI = [](BasicBlock *BB) { BasicBlock::iterator I = BB->begin(); PHINode *FirstPhi = nullptr; while (I != BB->end()) { @@ -1044,7 +1072,7 @@ void PartialInlinerImpl::FunctionCloner::NormalizeReturnBlock() { // of which will go outside. 
BasicBlock *PreReturn = ClonedOI->ReturnBlock; // only split block when necessary: - PHINode *FirstPhi = getFirstPHI(PreReturn); + PHINode *FirstPhi = GetFirstPHI(PreReturn); unsigned NumPredsFromEntries = ClonedOI->ReturnBlockPreds.size(); if (!FirstPhi || FirstPhi->getNumIncomingValues() <= NumPredsFromEntries + 1) @@ -1092,17 +1120,16 @@ void PartialInlinerImpl::FunctionCloner::NormalizeReturnBlock() { for (auto *DP : DeadPhis) DP->eraseFromParent(); - for (auto E : ClonedOI->ReturnBlockPreds) { + for (auto *E : ClonedOI->ReturnBlockPreds) E->getTerminator()->replaceUsesOfWith(PreReturn, ClonedOI->ReturnBlock); - } } bool PartialInlinerImpl::FunctionCloner::doMultiRegionFunctionOutlining() { - auto ComputeRegionCost = [](SmallVectorImpl<BasicBlock *> &Region) { + auto ComputeRegionCost = [&](SmallVectorImpl<BasicBlock *> &Region) { int Cost = 0; for (BasicBlock* BB : Region) - Cost += computeBBInlineCost(BB); + Cost += computeBBInlineCost(BB, &GetTTI(*BB->getParent())); return Cost; }; @@ -1135,24 +1162,21 @@ bool PartialInlinerImpl::FunctionCloner::doMultiRegionFunctionOutlining() { CE.findInputsOutputs(Inputs, Outputs, Sinks); -#ifndef NDEBUG - if (TracePartialInlining) { + LLVM_DEBUG({ dbgs() << "inputs: " << Inputs.size() << "\n"; dbgs() << "outputs: " << Outputs.size() << "\n"; for (Value *value : Inputs) dbgs() << "value used in func: " << *value << "\n"; for (Value *output : Outputs) dbgs() << "instr used in func: " << *output << "\n"; - } -#endif + }); + // Do not extract regions that have live exit variables. if (Outputs.size() > 0 && !ForceLiveExit) continue; - Function *OutlinedFunc = CE.extractCodeRegion(CEAC); - - if (OutlinedFunc) { - CallBase *OCS = PartialInlinerImpl::getOneCallSiteTo(OutlinedFunc); + if (Function *OutlinedFunc = CE.extractCodeRegion(CEAC)) { + CallBase *OCS = PartialInlinerImpl::getOneCallSiteTo(*OutlinedFunc); BasicBlock *OutliningCallBB = OCS->getParent(); assert(OutliningCallBB->getParent() == ClonedFunc); OutlinedFunctions.push_back(std::make_pair(OutlinedFunc,OutliningCallBB)); @@ -1181,8 +1205,7 @@ PartialInlinerImpl::FunctionCloner::doSingleRegionFunctionOutlining() { // (i.e. not to be extracted to the out of line function) auto ToBeInlined = [&, this](BasicBlock *BB) { return BB == ClonedOI->ReturnBlock || - (std::find(ClonedOI->Entries.begin(), ClonedOI->Entries.end(), BB) != - ClonedOI->Entries.end()); + llvm::is_contained(ClonedOI->Entries, BB); }; assert(ClonedOI && "Expecting OutlineInfo for single region outline"); @@ -1197,9 +1220,10 @@ PartialInlinerImpl::FunctionCloner::doSingleRegionFunctionOutlining() { // Gather up the blocks that we're going to extract. std::vector<BasicBlock *> ToExtract; + auto *ClonedFuncTTI = &GetTTI(*ClonedFunc); ToExtract.push_back(ClonedOI->NonReturnBlock); - OutlinedRegionCost += - PartialInlinerImpl::computeBBInlineCost(ClonedOI->NonReturnBlock); + OutlinedRegionCost += PartialInlinerImpl::computeBBInlineCost( + ClonedOI->NonReturnBlock, ClonedFuncTTI); for (BasicBlock &BB : *ClonedFunc) if (!ToBeInlined(&BB) && &BB != ClonedOI->NonReturnBlock) { ToExtract.push_back(&BB); @@ -1207,7 +1231,7 @@ PartialInlinerImpl::FunctionCloner::doSingleRegionFunctionOutlining() { // into the outlined function which may make the outlining // overhead (the difference of the outlined function cost // and OutliningRegionCost) look larger. - OutlinedRegionCost += computeBBInlineCost(&BB); + OutlinedRegionCost += computeBBInlineCost(&BB, ClonedFuncTTI); } // Extract the body of the if. 
@@ -1220,8 +1244,7 @@ PartialInlinerImpl::FunctionCloner::doSingleRegionFunctionOutlining() { if (OutlinedFunc) { BasicBlock *OutliningCallBB = - PartialInlinerImpl::getOneCallSiteTo(OutlinedFunc) - ->getParent(); + PartialInlinerImpl::getOneCallSiteTo(*OutlinedFunc)->getParent(); assert(OutliningCallBB->getParent() == ClonedFunc); OutlinedFunctions.push_back(std::make_pair(OutlinedFunc, OutliningCallBB)); } else @@ -1250,52 +1273,48 @@ PartialInlinerImpl::FunctionCloner::~FunctionCloner() { } } -std::pair<bool, Function *> PartialInlinerImpl::unswitchFunction(Function *F) { - - if (F->hasAddressTaken()) +std::pair<bool, Function *> PartialInlinerImpl::unswitchFunction(Function &F) { + if (F.hasAddressTaken()) return {false, nullptr}; // Let inliner handle it - if (F->hasFnAttribute(Attribute::AlwaysInline)) + if (F.hasFnAttribute(Attribute::AlwaysInline)) return {false, nullptr}; - if (F->hasFnAttribute(Attribute::NoInline)) + if (F.hasFnAttribute(Attribute::NoInline)) return {false, nullptr}; - if (PSI.isFunctionEntryCold(F)) + if (PSI.isFunctionEntryCold(&F)) return {false, nullptr}; - if (F->users().empty()) + if (F.users().empty()) return {false, nullptr}; - OptimizationRemarkEmitter ORE(F); + OptimizationRemarkEmitter ORE(&F); // Only try to outline cold regions if we have a profile summary, which // implies we have profiling information. - if (PSI.hasProfileSummary() && F->hasProfileData() && + if (PSI.hasProfileSummary() && F.hasProfileData() && !DisableMultiRegionPartialInline) { std::unique_ptr<FunctionOutliningMultiRegionInfo> OMRI = computeOutliningColdRegionsInfo(F, ORE); if (OMRI) { - FunctionCloner Cloner(F, OMRI.get(), ORE, LookupAssumptionCache); + FunctionCloner Cloner(&F, OMRI.get(), ORE, LookupAssumptionCache, GetTTI); -#ifndef NDEBUG - if (TracePartialInlining) { + LLVM_DEBUG({ dbgs() << "HotCountThreshold = " << PSI.getHotCountThreshold() << "\n"; dbgs() << "ColdCountThreshold = " << PSI.getColdCountThreshold() << "\n"; - } -#endif + }); + bool DidOutline = Cloner.doMultiRegionFunctionOutlining(); if (DidOutline) { -#ifndef NDEBUG - if (TracePartialInlining) { + LLVM_DEBUG({ dbgs() << ">>>>>> Outlined (Cloned) Function >>>>>>\n"; Cloner.ClonedFunc->print(dbgs()); dbgs() << "<<<<<< Outlined (Cloned) Function <<<<<<\n"; - } -#endif + }); if (tryPartialInline(Cloner)) return {true, nullptr}; @@ -1310,17 +1329,15 @@ std::pair<bool, Function *> PartialInlinerImpl::unswitchFunction(Function *F) { if (!OI) return {false, nullptr}; - FunctionCloner Cloner(F, OI.get(), ORE, LookupAssumptionCache); - Cloner.NormalizeReturnBlock(); + FunctionCloner Cloner(&F, OI.get(), ORE, LookupAssumptionCache, GetTTI); + Cloner.normalizeReturnBlock(); Function *OutlinedFunction = Cloner.doSingleRegionFunctionOutlining(); if (!OutlinedFunction) return {false, nullptr}; - bool AnyInline = tryPartialInline(Cloner); - - if (AnyInline) + if (tryPartialInline(Cloner)) return {true, OutlinedFunction}; return {false, nullptr}; @@ -1338,9 +1355,9 @@ bool PartialInlinerImpl::tryPartialInline(FunctionCloner &Cloner) { // Only calculate RelativeToEntryFreq when we are doing single region // outlining. BranchProbability RelativeToEntryFreq; - if (Cloner.ClonedOI) { + if (Cloner.ClonedOI) RelativeToEntryFreq = getOutliningCallBBRelativeFreq(Cloner); - } else + else // RelativeToEntryFreq doesn't make sense when we have more than one // outlined call because each call will have a different relative frequency // to the entry block. 
We can consider using the average, but the @@ -1358,7 +1375,7 @@ bool PartialInlinerImpl::tryPartialInline(FunctionCloner &Cloner) { OptimizationRemarkEmitter OrigFuncORE(Cloner.OrigFunc); DebugLoc DLoc; BasicBlock *Block; - std::tie(DLoc, Block) = getOneDebugLoc(Cloner.ClonedFunc); + std::tie(DLoc, Block) = getOneDebugLoc(*Cloner.ClonedFunc); OrigFuncORE.emit([&]() { return OptimizationRemarkAnalysis(DEBUG_TYPE, "OutlineRegionTooSmall", DLoc, Block) @@ -1389,7 +1406,7 @@ bool PartialInlinerImpl::tryPartialInline(FunctionCloner &Cloner) { for (User *User : Users) { CallBase *CB = getSupportedCallBase(User); - if (IsLimitReached()) + if (isLimitReached()) continue; OptimizationRemarkEmitter CallerORE(CB->getCaller()); @@ -1426,7 +1443,6 @@ bool PartialInlinerImpl::tryPartialInline(FunctionCloner &Cloner) { NumPartialInlined++; else NumColdOutlinePartialInlined++; - } if (AnyInline) { @@ -1439,7 +1455,6 @@ bool PartialInlinerImpl::tryPartialInline(FunctionCloner &Cloner) { return OptimizationRemark(DEBUG_TYPE, "PartiallyInlined", Cloner.OrigFunc) << "Partially inlined into at least one caller"; }); - } return AnyInline; @@ -1473,7 +1488,7 @@ bool PartialInlinerImpl::run(Module &M) { if (Recursive) continue; - std::pair<bool, Function * > Result = unswitchFunction(CurrFunc); + std::pair<bool, Function *> Result = unswitchFunction(*CurrFunc); if (Result.second) Worklist.push_back(Result.second); Changed |= Result.first; diff --git a/contrib/llvm-project/llvm/lib/Transforms/IPO/PassManagerBuilder.cpp b/contrib/llvm-project/llvm/lib/Transforms/IPO/PassManagerBuilder.cpp index d73d42c52074..068328391dff 100644 --- a/contrib/llvm-project/llvm/lib/Transforms/IPO/PassManagerBuilder.cpp +++ b/contrib/llvm-project/llvm/lib/Transforms/IPO/PassManagerBuilder.cpp @@ -51,16 +51,16 @@ using namespace llvm; -static cl::opt<bool> - RunPartialInlining("enable-partial-inlining", cl::init(false), cl::Hidden, - cl::ZeroOrMore, cl::desc("Run Partial inlinining pass")); +cl::opt<bool> RunPartialInlining("enable-partial-inlining", cl::init(false), + cl::Hidden, cl::ZeroOrMore, + cl::desc("Run Partial inlinining pass")); static cl::opt<bool> UseGVNAfterVectorization("use-gvn-after-vectorization", cl::init(false), cl::Hidden, cl::desc("Run GVN instead of Early CSE after vectorization passes")); -static cl::opt<bool> ExtraVectorizerPasses( +cl::opt<bool> ExtraVectorizerPasses( "extra-vectorizer-passes", cl::init(false), cl::Hidden, cl::desc("Run cleanup optimization passes after vectorization.")); @@ -68,29 +68,33 @@ static cl::opt<bool> RunLoopRerolling("reroll-loops", cl::Hidden, cl::desc("Run the loop rerolling pass")); -static cl::opt<bool> RunNewGVN("enable-newgvn", cl::init(false), cl::Hidden, - cl::desc("Run the NewGVN pass")); +cl::opt<bool> RunNewGVN("enable-newgvn", cl::init(false), cl::Hidden, + cl::desc("Run the NewGVN pass")); // Experimental option to use CFL-AA enum class CFLAAType { None, Steensgaard, Andersen, Both }; -static cl::opt<CFLAAType> - UseCFLAA("use-cfl-aa", cl::init(CFLAAType::None), cl::Hidden, +static cl::opt<::CFLAAType> + UseCFLAA("use-cfl-aa", cl::init(::CFLAAType::None), cl::Hidden, cl::desc("Enable the new, experimental CFL alias analysis"), - cl::values(clEnumValN(CFLAAType::None, "none", "Disable CFL-AA"), - clEnumValN(CFLAAType::Steensgaard, "steens", + cl::values(clEnumValN(::CFLAAType::None, "none", "Disable CFL-AA"), + clEnumValN(::CFLAAType::Steensgaard, "steens", "Enable unification-based CFL-AA"), - clEnumValN(CFLAAType::Andersen, "anders", + 
clEnumValN(::CFLAAType::Andersen, "anders", "Enable inclusion-based CFL-AA"), - clEnumValN(CFLAAType::Both, "both", + clEnumValN(::CFLAAType::Both, "both", "Enable both variants of CFL-AA"))); static cl::opt<bool> EnableLoopInterchange( "enable-loopinterchange", cl::init(false), cl::Hidden, cl::desc("Enable the new, experimental LoopInterchange Pass")); -static cl::opt<bool> EnableUnrollAndJam("enable-unroll-and-jam", - cl::init(false), cl::Hidden, - cl::desc("Enable Unroll And Jam Pass")); +cl::opt<bool> EnableUnrollAndJam("enable-unroll-and-jam", cl::init(false), + cl::Hidden, + cl::desc("Enable Unroll And Jam Pass")); + +cl::opt<bool> EnableLoopFlatten("enable-loop-flatten", cl::init(false), + cl::Hidden, + cl::desc("Enable the LoopFlatten Pass")); static cl::opt<bool> EnablePrepareForThinLTO("prepare-for-thinlto", cl::init(false), cl::Hidden, @@ -103,22 +107,25 @@ static cl::opt<bool> cl::opt<bool> EnableHotColdSplit("hot-cold-split", cl::init(false), cl::ZeroOrMore, cl::desc("Enable hot-cold splitting pass")); +cl::opt<bool> EnableIROutliner("ir-outliner", cl::init(false), cl::Hidden, + cl::desc("Enable ir outliner pass")); + static cl::opt<bool> UseLoopVersioningLICM( "enable-loop-versioning-licm", cl::init(false), cl::Hidden, cl::desc("Enable the experimental Loop Versioning LICM pass")); -static cl::opt<bool> +cl::opt<bool> DisablePreInliner("disable-preinline", cl::init(false), cl::Hidden, cl::desc("Disable pre-instrumentation inliner")); -static cl::opt<int> PreInlineThreshold( +cl::opt<int> PreInlineThreshold( "preinline-threshold", cl::Hidden, cl::init(75), cl::ZeroOrMore, cl::desc("Control the amount of inlining in pre-instrumentation inliner " "(default = 75)")); -static cl::opt<bool> EnableGVNHoist( - "enable-gvn-hoist", cl::init(false), cl::ZeroOrMore, - cl::desc("Enable the GVN hoisting pass (default = off)")); +cl::opt<bool> + EnableGVNHoist("enable-gvn-hoist", cl::init(false), cl::ZeroOrMore, + cl::desc("Enable the GVN hoisting pass (default = off)")); static cl::opt<bool> DisableLibCallsShrinkWrap("disable-libcalls-shrinkwrap", cl::init(false), @@ -130,13 +137,13 @@ static cl::opt<bool> EnableSimpleLoopUnswitch( cl::desc("Enable the simple loop unswitch pass. Also enables independent " "cleanup passes integrated into the loop pass manager pipeline.")); -static cl::opt<bool> EnableGVNSink( - "enable-gvn-sink", cl::init(false), cl::ZeroOrMore, - cl::desc("Enable the GVN sinking pass (default = off)")); +cl::opt<bool> + EnableGVNSink("enable-gvn-sink", cl::init(false), cl::ZeroOrMore, + cl::desc("Enable the GVN sinking pass (default = off)")); // This option is used in simplifying testing SampleFDO optimizations for // profile loading. 
-static cl::opt<bool> +cl::opt<bool> EnableCHR("enable-chr", cl::init(true), cl::Hidden, cl::desc("Enable control height reduction optimization (CHR)")); @@ -149,9 +156,14 @@ cl::opt<bool> EnableOrderFileInstrumentation( "enable-order-file-instrumentation", cl::init(false), cl::Hidden, cl::desc("Enable order file instrumentation (default = off)")); -static cl::opt<bool> - EnableMatrix("enable-matrix", cl::init(false), cl::Hidden, - cl::desc("Enable lowering of the matrix intrinsics")); +cl::opt<bool> EnableMatrix( + "enable-matrix", cl::init(false), cl::Hidden, + cl::desc("Enable lowering of the matrix intrinsics")); + +cl::opt<bool> EnableConstraintElimination( + "enable-constraint-elimination", cl::init(false), cl::Hidden, + cl::desc( + "Enable pass to eliminate conditions based on linear constraints.")); cl::opt<AttributorRunOption> AttributorRun( "attributor-enable", cl::Hidden, cl::init(AttributorRunOption::NONE), @@ -264,13 +276,13 @@ void PassManagerBuilder::addExtensionsToPM(ExtensionPointTy ETy, void PassManagerBuilder::addInitialAliasAnalysisPasses( legacy::PassManagerBase &PM) const { switch (UseCFLAA) { - case CFLAAType::Steensgaard: + case ::CFLAAType::Steensgaard: PM.add(createCFLSteensAAWrapperPass()); break; - case CFLAAType::Andersen: + case ::CFLAAType::Andersen: PM.add(createCFLAndersAAWrapperPass()); break; - case CFLAAType::Both: + case ::CFLAAType::Both: PM.add(createCFLSteensAAWrapperPass()); PM.add(createCFLAndersAAWrapperPass()); break; @@ -294,6 +306,13 @@ void PassManagerBuilder::populateFunctionPassManager( if (LibraryInfo) FPM.add(new TargetLibraryInfoWrapperPass(*LibraryInfo)); + // The backends do not handle matrix intrinsics currently. + // Make sure they are also lowered in O0. + // FIXME: A lightweight version of the pass should run in the backend + // pipeline on demand. + if (EnableMatrix && OptLevel == 0) + FPM.add(createLowerMatrixIntrinsicsMinimalPass()); + if (OptLevel == 0) return; addInitialAliasAnalysisPasses(FPM); @@ -314,19 +333,21 @@ void PassManagerBuilder::addPGOInstrPasses(legacy::PassManagerBase &MPM, return; // Perform the preinline and cleanup passes for O1 and above. - // And avoid doing them if optimizing for size. // We will not do this inline for context sensitive PGO (when IsCS is true). - if (OptLevel > 0 && SizeLevel == 0 && !DisablePreInliner && - PGOSampleUse.empty() && !IsCS) { + if (OptLevel > 0 && !DisablePreInliner && PGOSampleUse.empty() && !IsCS) { // Create preinline pass. We construct an InlineParams object and specify // the threshold here to avoid the command line options of the regular // inliner to influence pre-inlining. The only fields of InlineParams we // care about are DefaultThreshold and HintThreshold. InlineParams IP; IP.DefaultThreshold = PreInlineThreshold; - // FIXME: The hint threshold has the same value used by the regular inliner. - // This should probably be lowered after performance testing. - IP.HintThreshold = 325; + // FIXME: The hint threshold has the same value used by the regular inliner + // when not optimzing for size. This should probably be lowered after + // performance testing. + // Use PreInlineThreshold for both -Os and -Oz. Not running preinliner makes + // the instrumented binary unusably large. Even if PreInlineThreshold is not + // correct thresold for -Oz, it is better than not running preinliner. + IP.HintThreshold = SizeLevel > 0 ? 
PreInlineThreshold : 325; MPM.add(createFunctionInliningPass(IP)); MPM.add(createSROAPass()); @@ -374,6 +395,9 @@ void PassManagerBuilder::addFunctionSimplificationPasses( } } + if (EnableConstraintElimination) + MPM.add(createConstraintEliminationPass()); + if (OptLevel > 1) { // Speculative execution if the target has divergent branches; otherwise nop. MPM.add(createSpeculativeExecutionIfHasBranchDivergencePass()); @@ -409,7 +433,7 @@ void PassManagerBuilder::addFunctionSimplificationPasses( MPM.add(createLoopSimplifyCFGPass()); } // Rotate Loop - disable header duplication at -Oz - MPM.add(createLoopRotatePass(SizeLevel == 2 ? 0 : -1)); + MPM.add(createLoopRotatePass(SizeLevel == 2 ? 0 : -1, PrepareForLTO)); // TODO: Investigate promotion cap for O1. MPM.add(createLICMPass(LicmMssaOptCap, LicmMssaNoAccForPromotionCap)); if (EnableSimpleLoopUnswitch) @@ -422,20 +446,27 @@ void PassManagerBuilder::addFunctionSimplificationPasses( MPM.add(createCFGSimplificationPass()); MPM.add(createInstructionCombiningPass()); // We resume loop passes creating a second loop pipeline here. - MPM.add(createIndVarSimplifyPass()); // Canonicalize indvars + if (EnableLoopFlatten) { + MPM.add(createLoopFlattenPass()); // Flatten loops + MPM.add(createLoopSimplifyCFGPass()); + } MPM.add(createLoopIdiomPass()); // Recognize idioms like memset. + MPM.add(createIndVarSimplifyPass()); // Canonicalize indvars addExtensionsToPM(EP_LateLoopOptimizations, MPM); MPM.add(createLoopDeletionPass()); // Delete dead loops if (EnableLoopInterchange) MPM.add(createLoopInterchangePass()); // Interchange loops - // Unroll small loops + // Unroll small loops and perform peeling. MPM.add(createSimpleLoopUnrollPass(OptLevel, DisableUnrollLoops, ForgetAllSCEVInLoopUnroll)); addExtensionsToPM(EP_LoopOptimizerEnd, MPM); // This ends the loop pass pipelines. + // Break up allocas that may now be splittable after loop unrolling. + MPM.add(createSROAPass()); + if (OptLevel > 1) { MPM.add(createMergedLoadStoreMotionPass()); // Merge ld/st in diamonds MPM.add(NewGVN ? createNewGVNPass() @@ -444,6 +475,9 @@ void PassManagerBuilder::addFunctionSimplificationPasses( MPM.add(createMemCpyOptPass()); // Remove memcpy / form memset MPM.add(createSCCPPass()); // Constant prop with SCCP + if (EnableConstraintElimination) + MPM.add(createConstraintEliminationPass()); + // Delete dead bit computations (instcombine runs after to fold away the dead // computations, and then ADCE will run later to exploit any new DCE // opportunities that creates). @@ -456,6 +490,11 @@ void PassManagerBuilder::addFunctionSimplificationPasses( if (OptLevel > 1) { MPM.add(createJumpThreadingPass()); // Thread jumps MPM.add(createCorrelatedValuePropagationPass()); + } + MPM.add(createAggressiveDCEPass()); // Delete dead instructions + + // TODO: Investigate if this is too expensive at O1. + if (OptLevel > 1) { MPM.add(createDeadStoreEliminationPass()); // Delete dead stores MPM.add(createLICMPass(LicmMssaOptCap, LicmMssaNoAccForPromotionCap)); } @@ -465,8 +504,6 @@ void PassManagerBuilder::addFunctionSimplificationPasses( if (RerollLoops) MPM.add(createLoopRerollPass()); - // TODO: Investigate if this is too expensive at O1. - MPM.add(createAggressiveDCEPass()); // Delete dead instructions MPM.add(createCFGSimplificationPass()); // Merge & remove BBs // Clean up after everything. MPM.add(createInstructionCombiningPass()); @@ -483,6 +520,8 @@ void PassManagerBuilder::populateModulePassManager( // is handled separately, so just check this is not the ThinLTO post-link. 
bool DefaultOrPreLinkPipeline = !PerformThinLTO; + MPM.add(createAnnotation2MetadataLegacyPass()); + if (!PGOSampleUse.empty()) { MPM.add(createPruneEHPass()); // In ThinLTO mode, when flattened profile is used, all the available @@ -533,6 +572,8 @@ void PassManagerBuilder::populateModulePassManager( // new unnamed globals. MPM.add(createNameAnonGlobalPass()); } + + MPM.add(createAnnotationRemarksLegacyPass()); return; } @@ -736,7 +777,7 @@ void PassManagerBuilder::populateModulePassManager( // Re-rotate loops in all our loop nests. These may have fallout out of // rotated form due to GVN or other transformations, and the vectorizer relies // on the rotated form. Disable header duplication at -Oz. - MPM.add(createLoopRotatePass(SizeLevel == 2 ? 0 : -1)); + MPM.add(createLoopRotatePass(SizeLevel == 2 ? 0 : -1, PrepareForLTO)); // Distribute loops to allow partial vectorization. I.e. isolate dependences // into separate loop that would otherwise inhibit vectorization. This is @@ -777,7 +818,14 @@ void PassManagerBuilder::populateModulePassManager( // convert to more optimized IR using more aggressive simplify CFG options. // The extra sinking transform can create larger basic blocks, so do this // before SLP vectorization. - MPM.add(createCFGSimplificationPass(1, true, true, false, true)); + // FIXME: study whether hoisting and/or sinking of common instructions should + // be delayed until after SLP vectorizer. + MPM.add(createCFGSimplificationPass(SimplifyCFGOptions() + .forwardSwitchCondToPhi(true) + .convertSwitchToLookupTable(true) + .needCanonicalLoops(false) + .hoistCommonInsts(true) + .sinkCommonInsts(true))); if (SLPVectorize) { MPM.add(createSLPVectorizerPass()); // Vectorize parallel scalar chains. @@ -835,6 +883,9 @@ void PassManagerBuilder::populateModulePassManager( if (EnableHotColdSplit && !(PrepareForLTO || PrepareForThinLTO)) MPM.add(createHotColdSplittingPass()); + if (EnableIROutliner) + MPM.add(createIROutlinerPass()); + if (MergeFunctions) MPM.add(createMergeFunctionsPass()); @@ -866,6 +917,8 @@ void PassManagerBuilder::populateModulePassManager( // Rename anon globals to be able to handle them in the summary MPM.add(createNameAnonGlobalPass()); } + + MPM.add(createAnnotationRemarksLegacyPass()); } void PassManagerBuilder::addLTOOptimizationPasses(legacy::PassManagerBase &PM) { @@ -984,7 +1037,7 @@ void PassManagerBuilder::addLTOOptimizationPasses(legacy::PassManagerBase &PM) { // The IPO passes may leave cruft around. Clean up after them. PM.add(createInstructionCombiningPass()); addExtensionsToPM(EP_Peephole, PM); - PM.add(createJumpThreadingPass()); + PM.add(createJumpThreadingPass(/*FreezeSelectCond*/ true)); // Break up allocas PM.add(createSROAPass()); @@ -1000,23 +1053,29 @@ void PassManagerBuilder::addLTOOptimizationPasses(legacy::PassManagerBase &PM) { PM.add(createGlobalsAAWrapperPass()); // IP alias analysis. PM.add(createLICMPass(LicmMssaOptCap, LicmMssaNoAccForPromotionCap)); - PM.add(createMergedLoadStoreMotionPass()); // Merge ld/st in diamonds. PM.add(NewGVN ? createNewGVNPass() : createGVNPass(DisableGVNLoadPRE)); // Remove redundancies. PM.add(createMemCpyOptPass()); // Remove dead memcpys. // Nuke dead stores. PM.add(createDeadStoreEliminationPass()); + PM.add(createMergedLoadStoreMotionPass()); // Merge ld/st in diamonds. // More loops are countable; try to optimize them. 
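// LoopFlatten (guarded by -enable-loop-flatten above) targets perfect
// two-level nests that index memory linearly; a minimal sketch of the kind
// of source it rewrites, using a made-up function for illustration:
//
//   void scale(float *a, unsigned m, unsigned n, float k) {
//     for (unsigned i = 0; i < m; ++i)
//       for (unsigned j = 0; j < n; ++j)
//         a[i * n + j] *= k;   // flattened into a single loop of m * n
//   }                          // iterations, leaving one induction variable
//                              // for indvars/unrolling to work on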
+ if (EnableLoopFlatten) + PM.add(createLoopFlattenPass()); PM.add(createIndVarSimplifyPass()); PM.add(createLoopDeletionPass()); if (EnableLoopInterchange) PM.add(createLoopInterchangePass()); - // Unroll small loops + if (EnableConstraintElimination) + PM.add(createConstraintEliminationPass()); + + // Unroll small loops and perform peeling. PM.add(createSimpleLoopUnrollPass(OptLevel, DisableUnrollLoops, ForgetAllSCEVInLoopUnroll)); + PM.add(createLoopDistributePass()); PM.add(createLoopVectorizePass(true, !LoopVectorize)); // The vectorizer may have significantly shortened a loop body; unroll again. PM.add(createLoopUnrollPass(OptLevel, DisableUnrollLoops, @@ -1028,7 +1087,8 @@ void PassManagerBuilder::addLTOOptimizationPasses(legacy::PassManagerBase &PM) { // we may have exposed more scalar opportunities. Run parts of the scalar // optimizer again at this point. PM.add(createInstructionCombiningPass()); // Initial cleanup - PM.add(createCFGSimplificationPass()); // if-convert + PM.add(createCFGSimplificationPass(SimplifyCFGOptions() // if-convert + .hoistCommonInsts(true))); PM.add(createSCCPPass()); // Propagate exposed constants PM.add(createInstructionCombiningPass()); // Clean up again PM.add(createBitTrackingDCEPass()); @@ -1047,7 +1107,7 @@ void PassManagerBuilder::addLTOOptimizationPasses(legacy::PassManagerBase &PM) { PM.add(createInstructionCombiningPass()); addExtensionsToPM(EP_Peephole, PM); - PM.add(createJumpThreadingPass()); + PM.add(createJumpThreadingPass(/*FreezeSelectCond*/ true)); } void PassManagerBuilder::addLateLTOOptimizationPasses( @@ -1058,7 +1118,8 @@ void PassManagerBuilder::addLateLTOOptimizationPasses( PM.add(createHotColdSplittingPass()); // Delete basic blocks, which optimization passes may have killed. - PM.add(createCFGSimplificationPass()); + PM.add( + createCFGSimplificationPass(SimplifyCFGOptions().hoistCommonInsts(true))); // Drop bodies of available externally objects to improve GlobalDCE. PM.add(createEliminateAvailableExternallyPass()); @@ -1140,6 +1201,8 @@ void PassManagerBuilder::populateLTOPassManager(legacy::PassManagerBase &PM) { addExtensionsToPM(EP_FullLinkTimeOptimizationLast, PM); + PM.add(createAnnotationRemarksLegacyPass()); + if (VerifyOutput) PM.add(createVerifierPass()); } diff --git a/contrib/llvm-project/llvm/lib/Transforms/IPO/PruneEH.cpp b/contrib/llvm-project/llvm/lib/Transforms/IPO/PruneEH.cpp index a16dc664db64..3f3b18771cd5 100644 --- a/contrib/llvm-project/llvm/lib/Transforms/IPO/PruneEH.cpp +++ b/contrib/llvm-project/llvm/lib/Transforms/IPO/PruneEH.cpp @@ -13,6 +13,7 @@ // //===----------------------------------------------------------------------===// +#include "llvm/ADT/SetVector.h" #include "llvm/ADT/SmallPtrSet.h" #include "llvm/ADT/Statistic.h" #include "llvm/Analysis/CallGraph.h" @@ -27,8 +28,10 @@ #include "llvm/InitializePasses.h" #include "llvm/Support/raw_ostream.h" #include "llvm/Transforms/IPO.h" +#include "llvm/Transforms/Utils/CallGraphUpdater.h" #include "llvm/Transforms/Utils/Local.h" #include <algorithm> + using namespace llvm; #define DEBUG_TYPE "prune-eh" @@ -45,11 +48,10 @@ namespace { // runOnSCC - Analyze the SCC, performing the transformation if possible. 
bool runOnSCC(CallGraphSCC &SCC) override; - }; } -static bool SimplifyFunction(Function *F, CallGraph &CG); -static void DeleteBasicBlock(BasicBlock *BB, CallGraph &CG); +static bool SimplifyFunction(Function *F, CallGraphUpdater &CGU); +static void DeleteBasicBlock(BasicBlock *BB, CallGraphUpdater &CGU); char PruneEH::ID = 0; INITIALIZE_PASS_BEGIN(PruneEH, "prune-eh", @@ -60,20 +62,17 @@ INITIALIZE_PASS_END(PruneEH, "prune-eh", Pass *llvm::createPruneEHPass() { return new PruneEH(); } -static bool runImpl(CallGraphSCC &SCC, CallGraph &CG) { - SmallPtrSet<CallGraphNode *, 8> SCCNodes; +static bool runImpl(CallGraphUpdater &CGU, SetVector<Function *> &Functions) { +#ifndef NDEBUG + for (auto *F : Functions) + assert(F && "null Function"); +#endif bool MadeChange = false; - // Fill SCCNodes with the elements of the SCC. Used for quickly - // looking up whether a given CallGraphNode is in this SCC. - for (CallGraphNode *I : SCC) - SCCNodes.insert(I); - // First pass, scan all of the functions in the SCC, simplifying them // according to what we know. - for (CallGraphNode *I : SCC) - if (Function *F = I->getFunction()) - MadeChange |= SimplifyFunction(F, CG); + for (Function *F : Functions) + MadeChange |= SimplifyFunction(F, CGU); // Next, check to see if any callees might throw or if there are any external // functions in this SCC: if so, we cannot prune any functions in this SCC. @@ -83,13 +82,8 @@ static bool runImpl(CallGraphSCC &SCC, CallGraph &CG) { // obviously the SCC might throw. // bool SCCMightUnwind = false, SCCMightReturn = false; - for (CallGraphSCC::iterator I = SCC.begin(), E = SCC.end(); - (!SCCMightUnwind || !SCCMightReturn) && I != E; ++I) { - Function *F = (*I)->getFunction(); - if (!F) { - SCCMightUnwind = true; - SCCMightReturn = true; - } else if (!F->hasExactDefinition()) { + for (Function *F : Functions) { + if (!F->hasExactDefinition()) { SCCMightUnwind |= !F->doesNotThrow(); SCCMightReturn |= !F->doesNotReturn(); } else { @@ -125,10 +119,9 @@ static bool runImpl(CallGraphSCC &SCC, CallGraph &CG) { bool InstMightUnwind = true; if (const auto *CI = dyn_cast<CallInst>(&I)) { if (Function *Callee = CI->getCalledFunction()) { - CallGraphNode *CalleeNode = CG[Callee]; // If the callee is outside our current SCC then we may throw // because it might. If it is inside, do nothing. - if (SCCNodes.count(CalleeNode) > 0) + if (Functions.contains(Callee)) InstMightUnwind = false; } } @@ -140,18 +133,15 @@ static bool runImpl(CallGraphSCC &SCC, CallGraph &CG) { if (IA->hasSideEffects()) SCCMightReturn = true; } - + } if (SCCMightUnwind && SCCMightReturn) break; - } } } // If the SCC doesn't unwind or doesn't throw, note this fact. if (!SCCMightUnwind || !SCCMightReturn) - for (CallGraphNode *I : SCC) { - Function *F = I->getFunction(); - + for (Function *F : Functions) { if (!SCCMightUnwind && !F->hasFnAttribute(Attribute::NoUnwind)) { F->addFnAttr(Attribute::NoUnwind); MadeChange = true; @@ -163,30 +153,35 @@ static bool runImpl(CallGraphSCC &SCC, CallGraph &CG) { } } - for (CallGraphNode *I : SCC) { + for (Function *F : Functions) { // Convert any invoke instructions to non-throwing functions in this node // into call instructions with a branch. This makes the exception blocks // dead. 
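// For illustration only (not part of this change): once every function in
// the SCC is proven nounwind, an invoke of it can be lowered to a plain
// call and its landing pad becomes unreachable, e.g.
//
//   void leaf() { /* nothing in the body can throw */ }   // inferred nounwind
//   void caller() {
//     try { leaf(); }     // invoke -> call with a branch to the normal dest;
//     catch (...) {}      // the handler block is now dead and is removed via
//   }                     // DeleteBasicBlock() below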
- if (Function *F = I->getFunction()) - MadeChange |= SimplifyFunction(F, CG); + MadeChange |= SimplifyFunction(F, CGU); } return MadeChange; } - bool PruneEH::runOnSCC(CallGraphSCC &SCC) { if (skipSCC(SCC)) return false; + SetVector<Function *> Functions; + for (auto &N : SCC) { + if (auto *F = N->getFunction()) + Functions.insert(F); + } CallGraph &CG = getAnalysis<CallGraphWrapperPass>().getCallGraph(); - return runImpl(SCC, CG); + CallGraphUpdater CGU; + CGU.initialize(CG, SCC); + return runImpl(CGU, Functions); } // SimplifyFunction - Given information about callees, simplify the specified // function if we have invokes to non-unwinding functions or code after calls to // no-return functions. -static bool SimplifyFunction(Function *F, CallGraph &CG) { +static bool SimplifyFunction(Function *F, CallGraphUpdater &CGU) { bool MadeChange = false; for (Function::iterator BB = F->begin(), E = F->end(); BB != E; ++BB) { if (InvokeInst *II = dyn_cast<InvokeInst>(BB->getTerminator())) @@ -196,7 +191,7 @@ static bool SimplifyFunction(Function *F, CallGraph &CG) { // If the unwind block is now dead, nuke it. if (pred_empty(UnwindBlock)) - DeleteBasicBlock(UnwindBlock, CG); // Delete the new BB. + DeleteBasicBlock(UnwindBlock, CGU); // Delete the new BB. ++NumRemoved; MadeChange = true; @@ -216,7 +211,7 @@ static bool SimplifyFunction(Function *F, CallGraph &CG) { BB->getInstList().pop_back(); new UnreachableInst(BB->getContext(), &*BB); - DeleteBasicBlock(New, CG); // Delete the new BB. + DeleteBasicBlock(New, CGU); // Delete the new BB. MadeChange = true; ++NumUnreach; break; @@ -229,12 +224,11 @@ static bool SimplifyFunction(Function *F, CallGraph &CG) { /// DeleteBasicBlock - remove the specified basic block from the program, /// updating the callgraph to reflect any now-obsolete edges due to calls that /// exist in the BB. -static void DeleteBasicBlock(BasicBlock *BB, CallGraph &CG) { +static void DeleteBasicBlock(BasicBlock *BB, CallGraphUpdater &CGU) { assert(pred_empty(BB) && "BB is not dead!"); Instruction *TokenInst = nullptr; - CallGraphNode *CGN = CG[BB->getParent()]; for (BasicBlock::iterator I = BB->end(), E = BB->begin(); I != E; ) { --I; @@ -246,9 +240,9 @@ static void DeleteBasicBlock(BasicBlock *BB, CallGraph &CG) { if (auto *Call = dyn_cast<CallBase>(&*I)) { const Function *Callee = Call->getCalledFunction(); if (!Callee || !Intrinsic::isLeaf(Callee->getIntrinsicID())) - CGN->removeCallEdgeFor(*Call); + CGU.removeCallSite(*Call); else if (!Callee->isIntrinsic()) - CGN->removeCallEdgeFor(*Call); + CGU.removeCallSite(*Call); } if (!I->use_empty()) diff --git a/contrib/llvm-project/llvm/lib/Transforms/IPO/SampleContextTracker.cpp b/contrib/llvm-project/llvm/lib/Transforms/IPO/SampleContextTracker.cpp new file mode 100644 index 000000000000..37fc27e91100 --- /dev/null +++ b/contrib/llvm-project/llvm/lib/Transforms/IPO/SampleContextTracker.cpp @@ -0,0 +1,521 @@ +//===- SampleContextTracker.cpp - Context-sensitive Profile Tracker -------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file implements the SampleContextTracker used by CSSPGO. 
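// (CSSPGO: context-sensitive sample-based PGO.)  The tracker keeps one
// FunctionSamples per full inlining context and arranges them in a trie;
// a sketch of the shape of that data, with the context notation assumed
// rather than taken from this change:
//
//   []                       root
//   [main]                   samples for main itself
//   [main:3 @ foo]           foo's samples when reached from main, line offset 3
//   [main:3 @ foo:2 @ bar]   bar's samples along that same path
//
// Promoting a node re-parents its whole subtree (e.g. directly under the
// root when the callee is not inlined) and strips the removed caller
// frames from every context string in the subtree.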
+// +//===----------------------------------------------------------------------===// + +#include "llvm/Transforms/IPO/SampleContextTracker.h" +#include "llvm/ADT/StringMap.h" +#include "llvm/ADT/StringRef.h" +#include "llvm/IR/DebugInfoMetadata.h" +#include "llvm/IR/Instructions.h" +#include "llvm/ProfileData/SampleProf.h" +#include <map> +#include <queue> +#include <vector> + +using namespace llvm; +using namespace sampleprof; + +#define DEBUG_TYPE "sample-context-tracker" + +namespace llvm { + +ContextTrieNode *ContextTrieNode::getChildContext(const LineLocation &CallSite, + StringRef CalleeName) { + if (CalleeName.empty()) + return getChildContext(CallSite); + + uint32_t Hash = nodeHash(CalleeName, CallSite); + auto It = AllChildContext.find(Hash); + if (It != AllChildContext.end()) + return &It->second; + return nullptr; +} + +ContextTrieNode * +ContextTrieNode::getChildContext(const LineLocation &CallSite) { + // CSFDO-TODO: This could be slow, change AllChildContext so we can + // do point look up for child node by call site alone. + // CSFDO-TODO: Return the child with max count for indirect call + ContextTrieNode *ChildNodeRet = nullptr; + for (auto &It : AllChildContext) { + ContextTrieNode &ChildNode = It.second; + if (ChildNode.CallSiteLoc == CallSite) { + if (ChildNodeRet) + return nullptr; + else + ChildNodeRet = &ChildNode; + } + } + + return ChildNodeRet; +} + +ContextTrieNode &ContextTrieNode::moveToChildContext( + const LineLocation &CallSite, ContextTrieNode &&NodeToMove, + StringRef ContextStrToRemove, bool DeleteNode) { + uint32_t Hash = nodeHash(NodeToMove.getFuncName(), CallSite); + assert(!AllChildContext.count(Hash) && "Node to remove must exist"); + LineLocation OldCallSite = NodeToMove.CallSiteLoc; + ContextTrieNode &OldParentContext = *NodeToMove.getParentContext(); + AllChildContext[Hash] = NodeToMove; + ContextTrieNode &NewNode = AllChildContext[Hash]; + NewNode.CallSiteLoc = CallSite; + + // Walk through nodes in the moved the subtree, and update + // FunctionSamples' context as for the context promotion. + // We also need to set new parant link for all children. + std::queue<ContextTrieNode *> NodeToUpdate; + NewNode.setParentContext(this); + NodeToUpdate.push(&NewNode); + + while (!NodeToUpdate.empty()) { + ContextTrieNode *Node = NodeToUpdate.front(); + NodeToUpdate.pop(); + FunctionSamples *FSamples = Node->getFunctionSamples(); + + if (FSamples) { + FSamples->getContext().promoteOnPath(ContextStrToRemove); + FSamples->getContext().setState(SyntheticContext); + LLVM_DEBUG(dbgs() << " Context promoted to: " << FSamples->getContext() + << "\n"); + } + + for (auto &It : Node->getAllChildContext()) { + ContextTrieNode *ChildNode = &It.second; + ChildNode->setParentContext(Node); + NodeToUpdate.push(ChildNode); + } + } + + // Original context no longer needed, destroy if requested. 
+ if (DeleteNode) + OldParentContext.removeChildContext(OldCallSite, NewNode.getFuncName()); + + return NewNode; +} + +void ContextTrieNode::removeChildContext(const LineLocation &CallSite, + StringRef CalleeName) { + uint32_t Hash = nodeHash(CalleeName, CallSite); + // Note this essentially calls dtor and destroys that child context + AllChildContext.erase(Hash); +} + +std::map<uint32_t, ContextTrieNode> &ContextTrieNode::getAllChildContext() { + return AllChildContext; +} + +const StringRef ContextTrieNode::getFuncName() const { return FuncName; } + +FunctionSamples *ContextTrieNode::getFunctionSamples() const { + return FuncSamples; +} + +void ContextTrieNode::setFunctionSamples(FunctionSamples *FSamples) { + FuncSamples = FSamples; +} + +LineLocation ContextTrieNode::getCallSiteLoc() const { return CallSiteLoc; } + +ContextTrieNode *ContextTrieNode::getParentContext() const { + return ParentContext; +} + +void ContextTrieNode::setParentContext(ContextTrieNode *Parent) { + ParentContext = Parent; +} + +void ContextTrieNode::dump() { + dbgs() << "Node: " << FuncName << "\n" + << " Callsite: " << CallSiteLoc << "\n" + << " Children:\n"; + + for (auto &It : AllChildContext) { + dbgs() << " Node: " << It.second.getFuncName() << "\n"; + } +} + +uint32_t ContextTrieNode::nodeHash(StringRef ChildName, + const LineLocation &Callsite) { + // We still use child's name for child hash, this is + // because for children of root node, we don't have + // different line/discriminator, and we'll rely on name + // to differentiate children. + uint32_t NameHash = std::hash<std::string>{}(ChildName.str()); + uint32_t LocId = (Callsite.LineOffset << 16) | Callsite.Discriminator; + return NameHash + (LocId << 5) + LocId; +} + +ContextTrieNode *ContextTrieNode::getOrCreateChildContext( + const LineLocation &CallSite, StringRef CalleeName, bool AllowCreate) { + uint32_t Hash = nodeHash(CalleeName, CallSite); + auto It = AllChildContext.find(Hash); + if (It != AllChildContext.end()) { + assert(It->second.getFuncName() == CalleeName && + "Hash collision for child context node"); + return &It->second; + } + + if (!AllowCreate) + return nullptr; + + AllChildContext[Hash] = ContextTrieNode(this, CalleeName, nullptr, CallSite); + return &AllChildContext[Hash]; +} + +// Profiler tracker than manages profiles and its associated context +SampleContextTracker::SampleContextTracker( + StringMap<FunctionSamples> &Profiles) { + for (auto &FuncSample : Profiles) { + FunctionSamples *FSamples = &FuncSample.second; + SampleContext Context(FuncSample.first(), RawContext); + LLVM_DEBUG(dbgs() << "Tracking Context for function: " << Context << "\n"); + if (!Context.isBaseContext()) + FuncToCtxtProfileSet[Context.getName()].insert(FSamples); + ContextTrieNode *NewNode = getOrCreateContextPath(Context, true); + assert(!NewNode->getFunctionSamples() && + "New node can't have sample profile"); + NewNode->setFunctionSamples(FSamples); + } +} + +FunctionSamples * +SampleContextTracker::getCalleeContextSamplesFor(const CallBase &Inst, + StringRef CalleeName) { + LLVM_DEBUG(dbgs() << "Getting callee context for instr: " << Inst << "\n"); + // CSFDO-TODO: We use CalleeName to differentiate indirect call + // We need to get sample for indirect callee too. 
+ DILocation *DIL = Inst.getDebugLoc(); + if (!DIL) + return nullptr; + + ContextTrieNode *CalleeContext = getCalleeContextFor(DIL, CalleeName); + if (CalleeContext) { + FunctionSamples *FSamples = CalleeContext->getFunctionSamples(); + LLVM_DEBUG(if (FSamples) { + dbgs() << " Callee context found: " << FSamples->getContext() << "\n"; + }); + return FSamples; + } + + return nullptr; +} + +FunctionSamples * +SampleContextTracker::getContextSamplesFor(const DILocation *DIL) { + assert(DIL && "Expect non-null location"); + + ContextTrieNode *ContextNode = getContextFor(DIL); + if (!ContextNode) + return nullptr; + + // We may have inlined callees during pre-LTO compilation, in which case + // we need to rely on the inline stack from !dbg to mark context profile + // as inlined, instead of `MarkContextSamplesInlined` during inlining. + // Sample profile loader walks through all instructions to get profile, + // which calls this function. So once that is done, all previously inlined + // context profile should be marked properly. + FunctionSamples *Samples = ContextNode->getFunctionSamples(); + if (Samples && ContextNode->getParentContext() != &RootContext) + Samples->getContext().setState(InlinedContext); + + return Samples; +} + +FunctionSamples * +SampleContextTracker::getContextSamplesFor(const SampleContext &Context) { + ContextTrieNode *Node = getContextFor(Context); + if (!Node) + return nullptr; + + return Node->getFunctionSamples(); +} + +FunctionSamples *SampleContextTracker::getBaseSamplesFor(const Function &Func, + bool MergeContext) { + StringRef CanonName = FunctionSamples::getCanonicalFnName(Func); + return getBaseSamplesFor(CanonName, MergeContext); +} + +FunctionSamples *SampleContextTracker::getBaseSamplesFor(StringRef Name, + bool MergeContext) { + LLVM_DEBUG(dbgs() << "Getting base profile for function: " << Name << "\n"); + // Base profile is top-level node (child of root node), so try to retrieve + // existing top-level node for given function first. If it exists, it could be + // that we've merged base profile before, or there's actually context-less + // profile from the input (e.g. due to unreliable stack walking). + ContextTrieNode *Node = getTopLevelContextNode(Name); + if (MergeContext) { + LLVM_DEBUG(dbgs() << " Merging context profile into base profile: " << Name + << "\n"); + + // We have profile for function under different contexts, + // create synthetic base profile and merge context profiles + // into base profile. 
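// A worked example of the merge, with names and counts invented for
// illustration: if the input contains
//   [main:3 @ foo] with 100 samples   and   [bar:2 @ foo] with 40 samples
// and foo is not inlined at either site, both context profiles get promoted
// to the top level and merged, producing a synthetic base profile
//   [foo] with 140 samples
// while the originals are marked MergedContext so they are not counted twice.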
+ for (auto *CSamples : FuncToCtxtProfileSet[Name]) { + SampleContext &Context = CSamples->getContext(); + ContextTrieNode *FromNode = getContextFor(Context); + if (FromNode == Node) + continue; + + // Skip inlined context profile and also don't re-merge any context + if (Context.hasState(InlinedContext) || Context.hasState(MergedContext)) + continue; + + ContextTrieNode &ToNode = promoteMergeContextSamplesTree(*FromNode); + assert((!Node || Node == &ToNode) && "Expect only one base profile"); + Node = &ToNode; + } + } + + // Still no profile even after merge/promotion (if allowed) + if (!Node) + return nullptr; + + return Node->getFunctionSamples(); +} + +void SampleContextTracker::markContextSamplesInlined( + const FunctionSamples *InlinedSamples) { + assert(InlinedSamples && "Expect non-null inlined samples"); + LLVM_DEBUG(dbgs() << "Marking context profile as inlined: " + << InlinedSamples->getContext() << "\n"); + InlinedSamples->getContext().setState(InlinedContext); +} + +void SampleContextTracker::promoteMergeContextSamplesTree( + const Instruction &Inst, StringRef CalleeName) { + LLVM_DEBUG(dbgs() << "Promoting and merging context tree for instr: \n" + << Inst << "\n"); + // CSFDO-TODO: We also need to promote context profile from indirect + // calls. We won't have callee names from those from call instr. + if (CalleeName.empty()) + return; + + // Get the caller context for the call instruction, we don't use callee + // name from call because there can be context from indirect calls too. + DILocation *DIL = Inst.getDebugLoc(); + ContextTrieNode *CallerNode = getContextFor(DIL); + if (!CallerNode) + return; + + // Get the context that needs to be promoted + LineLocation CallSite(FunctionSamples::getOffset(DIL), + DIL->getBaseDiscriminator()); + ContextTrieNode *NodeToPromo = + CallerNode->getChildContext(CallSite, CalleeName); + if (!NodeToPromo) + return; + + promoteMergeContextSamplesTree(*NodeToPromo); +} + +ContextTrieNode &SampleContextTracker::promoteMergeContextSamplesTree( + ContextTrieNode &NodeToPromo) { + // Promote the input node to be directly under root. This can happen + // when we decided to not inline a function under context represented + // by the input node. The promote and merge is then needed to reflect + // the context profile in the base (context-less) profile. 
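// Concretely (context notation assumed, not part of this change): promoting
// the node for [main:3 @ foo:2 @ bar] directly under the root strips the
// calling context "main:3 @ foo:2", so its samples are republished under the
// base context [bar], and a child such as [main:3 @ foo:2 @ bar:5 @ baz]
// follows along as [bar:5 @ baz].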
+ FunctionSamples *FromSamples = NodeToPromo.getFunctionSamples(); + assert(FromSamples && "Shouldn't promote a context without profile"); + LLVM_DEBUG(dbgs() << " Found context tree root to promote: " + << FromSamples->getContext() << "\n"); + + StringRef ContextStrToRemove = FromSamples->getContext().getCallingContext(); + return promoteMergeContextSamplesTree(NodeToPromo, RootContext, + ContextStrToRemove); +} + +void SampleContextTracker::dump() { + dbgs() << "Context Profile Tree:\n"; + std::queue<ContextTrieNode *> NodeQueue; + NodeQueue.push(&RootContext); + + while (!NodeQueue.empty()) { + ContextTrieNode *Node = NodeQueue.front(); + NodeQueue.pop(); + Node->dump(); + + for (auto &It : Node->getAllChildContext()) { + ContextTrieNode *ChildNode = &It.second; + NodeQueue.push(ChildNode); + } + } +} + +ContextTrieNode * +SampleContextTracker::getContextFor(const SampleContext &Context) { + return getOrCreateContextPath(Context, false); +} + +ContextTrieNode * +SampleContextTracker::getCalleeContextFor(const DILocation *DIL, + StringRef CalleeName) { + assert(DIL && "Expect non-null location"); + + // CSSPGO-TODO: need to support indirect callee + if (CalleeName.empty()) + return nullptr; + + ContextTrieNode *CallContext = getContextFor(DIL); + if (!CallContext) + return nullptr; + + return CallContext->getChildContext( + LineLocation(FunctionSamples::getOffset(DIL), + DIL->getBaseDiscriminator()), + CalleeName); +} + +ContextTrieNode *SampleContextTracker::getContextFor(const DILocation *DIL) { + assert(DIL && "Expect non-null location"); + SmallVector<std::pair<LineLocation, StringRef>, 10> S; + + // Use C++ linkage name if possible. + const DILocation *PrevDIL = DIL; + for (DIL = DIL->getInlinedAt(); DIL; DIL = DIL->getInlinedAt()) { + StringRef Name = PrevDIL->getScope()->getSubprogram()->getLinkageName(); + if (Name.empty()) + Name = PrevDIL->getScope()->getSubprogram()->getName(); + S.push_back( + std::make_pair(LineLocation(FunctionSamples::getOffset(DIL), + DIL->getBaseDiscriminator()), Name)); + PrevDIL = DIL; + } + + // Push root node, note that root node like main may only + // a name, but not linkage name. 
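// Illustration of the stack being built here (names invented): for an
// instruction in bar that was inlined into foo, which was in turn inlined
// into main, the loop above pushes
//   S = { (callsite-in-foo, "bar"), (callsite-in-main, "foo") }
// and the push below appends ({0,0}, "main").  The walk that follows then
// consumes S back-to-front, descending the trie main -> foo -> bar to reach
// the node for the instruction's full context.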
+ StringRef RootName = PrevDIL->getScope()->getSubprogram()->getLinkageName(); + if (RootName.empty()) + RootName = PrevDIL->getScope()->getSubprogram()->getName(); + S.push_back(std::make_pair(LineLocation(0, 0), RootName)); + + ContextTrieNode *ContextNode = &RootContext; + int I = S.size(); + while (--I >= 0 && ContextNode) { + LineLocation &CallSite = S[I].first; + StringRef &CalleeName = S[I].second; + ContextNode = ContextNode->getChildContext(CallSite, CalleeName); + } + + if (I < 0) + return ContextNode; + + return nullptr; +} + +ContextTrieNode * +SampleContextTracker::getOrCreateContextPath(const SampleContext &Context, + bool AllowCreate) { + ContextTrieNode *ContextNode = &RootContext; + StringRef ContextRemain = Context; + StringRef ChildContext; + StringRef CalleeName; + LineLocation CallSiteLoc(0, 0); + + while (ContextNode && !ContextRemain.empty()) { + auto ContextSplit = SampleContext::splitContextString(ContextRemain); + ChildContext = ContextSplit.first; + ContextRemain = ContextSplit.second; + LineLocation NextCallSiteLoc(0, 0); + SampleContext::decodeContextString(ChildContext, CalleeName, + NextCallSiteLoc); + + // Create child node at parent line/disc location + if (AllowCreate) { + ContextNode = + ContextNode->getOrCreateChildContext(CallSiteLoc, CalleeName); + } else { + ContextNode = ContextNode->getChildContext(CallSiteLoc, CalleeName); + } + CallSiteLoc = NextCallSiteLoc; + } + + assert((!AllowCreate || ContextNode) && + "Node must exist if creation is allowed"); + return ContextNode; +} + +ContextTrieNode *SampleContextTracker::getTopLevelContextNode(StringRef FName) { + return RootContext.getChildContext(LineLocation(0, 0), FName); +} + +ContextTrieNode &SampleContextTracker::addTopLevelContextNode(StringRef FName) { + assert(!getTopLevelContextNode(FName) && "Node to add must not exist"); + return *RootContext.getOrCreateChildContext(LineLocation(0, 0), FName); +} + +void SampleContextTracker::mergeContextNode(ContextTrieNode &FromNode, + ContextTrieNode &ToNode, + StringRef ContextStrToRemove) { + FunctionSamples *FromSamples = FromNode.getFunctionSamples(); + FunctionSamples *ToSamples = ToNode.getFunctionSamples(); + if (FromSamples && ToSamples) { + // Merge/duplicate FromSamples into ToSamples + ToSamples->merge(*FromSamples); + ToSamples->getContext().setState(SyntheticContext); + FromSamples->getContext().setState(MergedContext); + } else if (FromSamples) { + // Transfer FromSamples from FromNode to ToNode + ToNode.setFunctionSamples(FromSamples); + FromSamples->getContext().setState(SyntheticContext); + FromSamples->getContext().promoteOnPath(ContextStrToRemove); + FromNode.setFunctionSamples(nullptr); + } +} + +ContextTrieNode &SampleContextTracker::promoteMergeContextSamplesTree( + ContextTrieNode &FromNode, ContextTrieNode &ToNodeParent, + StringRef ContextStrToRemove) { + assert(!ContextStrToRemove.empty() && "Context to remove can't be empty"); + + // Ignore call site location if destination is top level under root + LineLocation NewCallSiteLoc = LineLocation(0, 0); + LineLocation OldCallSiteLoc = FromNode.getCallSiteLoc(); + ContextTrieNode &FromNodeParent = *FromNode.getParentContext(); + ContextTrieNode *ToNode = nullptr; + bool MoveToRoot = (&ToNodeParent == &RootContext); + if (!MoveToRoot) { + NewCallSiteLoc = OldCallSiteLoc; + } + + // Locate destination node, create/move if not existing + ToNode = ToNodeParent.getChildContext(NewCallSiteLoc, FromNode.getFuncName()); + if (!ToNode) { + // Do not delete node to move from its parent here 
because + // caller is iterating over children of that parent node. + ToNode = &ToNodeParent.moveToChildContext( + NewCallSiteLoc, std::move(FromNode), ContextStrToRemove, false); + } else { + // Destination node exists, merge samples for the context tree + mergeContextNode(FromNode, *ToNode, ContextStrToRemove); + LLVM_DEBUG(dbgs() << " Context promoted and merged to: " + << ToNode->getFunctionSamples()->getContext() << "\n"); + + // Recursively promote and merge children + for (auto &It : FromNode.getAllChildContext()) { + ContextTrieNode &FromChildNode = It.second; + promoteMergeContextSamplesTree(FromChildNode, *ToNode, + ContextStrToRemove); + } + + // Remove children once they're all merged + FromNode.getAllChildContext().clear(); + } + + // For root of subtree, remove itself from old parent too + if (MoveToRoot) + FromNodeParent.removeChildContext(OldCallSiteLoc, ToNode->getFuncName()); + + return *ToNode; +} + +} // namespace llvm diff --git a/contrib/llvm-project/llvm/lib/Transforms/IPO/SampleProfile.cpp b/contrib/llvm-project/llvm/lib/Transforms/IPO/SampleProfile.cpp index b6871e260532..264ac4065e8c 100644 --- a/contrib/llvm-project/llvm/lib/Transforms/IPO/SampleProfile.cpp +++ b/contrib/llvm-project/llvm/lib/Transforms/IPO/SampleProfile.cpp @@ -43,6 +43,7 @@ #include "llvm/Analysis/OptimizationRemarkEmitter.h" #include "llvm/Analysis/PostDominators.h" #include "llvm/Analysis/ProfileSummaryInfo.h" +#include "llvm/Analysis/ReplayInlineAdvisor.h" #include "llvm/Analysis/TargetLibraryInfo.h" #include "llvm/Analysis/TargetTransformInfo.h" #include "llvm/IR/BasicBlock.h" @@ -75,10 +76,11 @@ #include "llvm/Support/GenericDomTree.h" #include "llvm/Support/raw_ostream.h" #include "llvm/Transforms/IPO.h" +#include "llvm/Transforms/IPO/SampleContextTracker.h" +#include "llvm/Transforms/IPO/SampleProfileProbe.h" #include "llvm/Transforms/Instrumentation.h" #include "llvm/Transforms/Utils/CallPromotionUtils.h" #include "llvm/Transforms/Utils/Cloning.h" -#include "llvm/Transforms/Utils/MisExpect.h" #include <algorithm> #include <cassert> #include <cstdint> @@ -102,6 +104,9 @@ STATISTIC(NumCSInlined, "Number of functions inlined with context sensitive profile"); STATISTIC(NumCSNotInlined, "Number of functions not inlined with context sensitive profile"); +STATISTIC(NumMismatchedProfile, + "Number of functions with CFG mismatched profile"); +STATISTIC(NumMatchedProfile, "Number of functions with CFG matched profile"); // Command line option to specify the file to read samples from. This is // mainly used for debugging. 
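// For reference, a tiny profile in the text format the reader accepts (names
// and counts are invented; the layout follows the documented
// "name:total:entry" / "lineoffset[.discriminator]: count" convention):
//
//   main:18000:0
//    3: 1500
//    4.1: 1300
//    6: foo:9000
//     1: 4500
//     2: 4400
//
// where "6: foo:9000" opens an inlined-callsite subsection whose body is
// indented one level deeper.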
@@ -170,6 +175,13 @@ static cl::opt<int> SampleColdCallSiteThreshold( "sample-profile-cold-inline-threshold", cl::Hidden, cl::init(45), cl::desc("Threshold for inlining cold callsites")); +static cl::opt<std::string> ProfileInlineReplayFile( + "sample-profile-inline-replay", cl::init(""), cl::value_desc("filename"), + cl::desc( + "Optimization remarks file containing inline remarks to be replayed " + "by inlining from sample profile loader."), + cl::Hidden); + namespace { using BlockWeightMap = DenseMap<const BasicBlock *, uint64_t>; @@ -309,17 +321,16 @@ private: class SampleProfileLoader { public: SampleProfileLoader( - StringRef Name, StringRef RemapName, bool IsThinLTOPreLink, + StringRef Name, StringRef RemapName, ThinOrFullLTOPhase LTOPhase, std::function<AssumptionCache &(Function &)> GetAssumptionCache, std::function<TargetTransformInfo &(Function &)> GetTargetTransformInfo, std::function<const TargetLibraryInfo &(Function &)> GetTLI) : GetAC(std::move(GetAssumptionCache)), GetTTI(std::move(GetTargetTransformInfo)), GetTLI(std::move(GetTLI)), CoverageTracker(*this), Filename(std::string(Name)), - RemappingFilename(std::string(RemapName)), - IsThinLTOPreLink(IsThinLTOPreLink) {} + RemappingFilename(std::string(RemapName)), LTOPhase(LTOPhase) {} - bool doInitialization(Module &M); + bool doInitialization(Module &M, FunctionAnalysisManager *FAM = nullptr); bool runOnModule(Module &M, ModuleAnalysisManager *AM, ProfileSummaryInfo *_PSI, CallGraph *CG); @@ -332,6 +343,7 @@ protected: unsigned getFunctionLoc(Function &F); bool emitAnnotations(Function &F); ErrorOr<uint64_t> getInstWeight(const Instruction &I); + ErrorOr<uint64_t> getProbeWeight(const Instruction &I); ErrorOr<uint64_t> getBlockWeight(const BasicBlock *BB); const FunctionSamples *findCalleeFunctionSamples(const CallBase &I) const; std::vector<const FunctionSamples *> @@ -417,6 +429,9 @@ protected: /// Profile reader object. std::unique_ptr<SampleProfileReader> Reader; + /// Profile tracker for different context. + std::unique_ptr<SampleContextTracker> ContextTracker; + /// Samples collected for the body of this function. FunctionSamples *Samples = nullptr; @@ -429,11 +444,15 @@ protected: /// Flag indicating whether the profile input loaded successfully. bool ProfileIsValid = false; - /// Flag indicating if the pass is invoked in ThinLTO compile phase. + /// Flag indicating whether input profile is context-sensitive + bool ProfileIsCS = false; + + /// Flag indicating which LTO/ThinLTO phase the pass is invoked in. /// - /// In this phase, in annotation, we should not promote indirect calls. - /// Instead, we will mark GUIDs that needs to be annotated to the function. - bool IsThinLTOPreLink; + /// We need to know the LTO phase because for example in ThinLTOPrelink + /// phase, in annotation, we should not promote indirect calls. Instead, + /// we will mark GUIDs that needs to be annotated to the function. + ThinOrFullLTOPhase LTOPhase; /// Profile Summary Info computed from sample profile. ProfileSummaryInfo *PSI = nullptr; @@ -473,6 +492,12 @@ protected: // overriden by -profile-sample-accurate or profile-sample-accurate // attribute. bool ProfAccForSymsInList; + + // External inline advisor used to replay inline decision from remarks. + std::unique_ptr<ReplayInlineAdvisor> ExternalInlineAdvisor; + + // A pseudo probe helper to correlate the imported sample counts. 
+ std::unique_ptr<PseudoProbeManager> ProbeManager; }; class SampleProfileLoaderLegacyPass : public ModulePass { @@ -480,10 +505,11 @@ public: // Class identification, replacement for typeinfo static char ID; - SampleProfileLoaderLegacyPass(StringRef Name = SampleProfileFile, - bool IsThinLTOPreLink = false) + SampleProfileLoaderLegacyPass( + StringRef Name = SampleProfileFile, + ThinOrFullLTOPhase LTOPhase = ThinOrFullLTOPhase::None) : ModulePass(ID), SampleLoader( - Name, SampleProfileRemappingFile, IsThinLTOPreLink, + Name, SampleProfileRemappingFile, LTOPhase, [&](Function &F) -> AssumptionCache & { return ACT->getAssumptionCache(F); }, @@ -705,6 +731,9 @@ void SampleProfileLoader::printBlockWeight(raw_ostream &OS, /// /// \returns the weight of \p Inst. ErrorOr<uint64_t> SampleProfileLoader::getInstWeight(const Instruction &Inst) { + if (FunctionSamples::ProfileIsProbeBased) + return getProbeWeight(Inst); + const DebugLoc &DLoc = Inst.getDebugLoc(); if (!DLoc) return std::error_code(); @@ -723,9 +752,10 @@ ErrorOr<uint64_t> SampleProfileLoader::getInstWeight(const Instruction &Inst) { // (findCalleeFunctionSamples returns non-empty result), but not inlined here, // it means that the inlined callsite has no sample, thus the call // instruction should have 0 count. - if (auto *CB = dyn_cast<CallBase>(&Inst)) - if (!CB->isIndirectCall() && findCalleeFunctionSamples(*CB)) - return 0; + if (!ProfileIsCS) + if (const auto *CB = dyn_cast<CallBase>(&Inst)) + if (!CB->isIndirectCall() && findCalleeFunctionSamples(*CB)) + return 0; const DILocation *DIL = DLoc; uint32_t LineOffset = FunctionSamples::getOffset(DIL); @@ -757,6 +787,47 @@ ErrorOr<uint64_t> SampleProfileLoader::getInstWeight(const Instruction &Inst) { return R; } +ErrorOr<uint64_t> SampleProfileLoader::getProbeWeight(const Instruction &Inst) { + assert(FunctionSamples::ProfileIsProbeBased && + "Profile is not pseudo probe based"); + Optional<PseudoProbe> Probe = extractProbe(Inst); + if (!Probe) + return std::error_code(); + + const FunctionSamples *FS = findFunctionSamples(Inst); + if (!FS) + return std::error_code(); + + // If a direct call/invoke instruction is inlined in profile + // (findCalleeFunctionSamples returns non-empty result), but not inlined here, + // it means that the inlined callsite has no sample, thus the call + // instruction should have 0 count. + if (const auto *CB = dyn_cast<CallBase>(&Inst)) + if (!CB->isIndirectCall() && findCalleeFunctionSamples(*CB)) + return 0; + + const ErrorOr<uint64_t> &R = FS->findSamplesAt(Probe->Id, 0); + if (R) { + uint64_t Samples = R.get(); + bool FirstMark = CoverageTracker.markSamplesUsed(FS, Probe->Id, 0, Samples); + if (FirstMark) { + ORE->emit([&]() { + OptimizationRemarkAnalysis Remark(DEBUG_TYPE, "AppliedSamples", &Inst); + Remark << "Applied " << ore::NV("NumSamples", Samples); + Remark << " samples from profile (ProbeId="; + Remark << ore::NV("ProbeId", Probe->Id); + Remark << ")"; + return Remark; + }); + } + + LLVM_DEBUG(dbgs() << " " << Probe->Id << ":" << Inst + << " - weight: " << R.get() << ")\n"); + return Samples; + } + return R; +} + /// Compute the weight of a basic block. 
/// /// The weight of basic block \p BB is the maximum weight of all the @@ -820,17 +891,18 @@ SampleProfileLoader::findCalleeFunctionSamples(const CallBase &Inst) const { } StringRef CalleeName; - if (const CallInst *CI = dyn_cast<CallInst>(&Inst)) - if (Function *Callee = CI->getCalledFunction()) - CalleeName = Callee->getName(); + if (Function *Callee = Inst.getCalledFunction()) + CalleeName = FunctionSamples::getCanonicalFnName(*Callee); + + if (ProfileIsCS) + return ContextTracker->getCalleeContextSamplesFor(Inst, CalleeName); const FunctionSamples *FS = findFunctionSamples(Inst); if (FS == nullptr) return nullptr; - return FS->findFunctionSamplesAt(LineLocation(FunctionSamples::getOffset(DIL), - DIL->getBaseDiscriminator()), - CalleeName); + return FS->findFunctionSamplesAt(FunctionSamples::getCallSiteIdentifier(DIL), + CalleeName, Reader->getRemapper()); } /// Returns a vector of FunctionSamples that are the indirect call targets @@ -850,16 +922,13 @@ SampleProfileLoader::findIndirectCallFunctionSamples( if (FS == nullptr) return R; - uint32_t LineOffset = FunctionSamples::getOffset(DIL); - uint32_t Discriminator = DIL->getBaseDiscriminator(); - - auto T = FS->findCallTargetMapAt(LineOffset, Discriminator); + auto CallSite = FunctionSamples::getCallSiteIdentifier(DIL); + auto T = FS->findCallTargetMapAt(CallSite); Sum = 0; if (T) for (const auto &T_C : T.get()) Sum += T_C.second; - if (const FunctionSamplesMap *M = FS->findFunctionSamplesMapAt(LineLocation( - FunctionSamples::getOffset(DIL), DIL->getBaseDiscriminator()))) { + if (const FunctionSamplesMap *M = FS->findFunctionSamplesMapAt(CallSite)) { if (M->empty()) return R; for (const auto &NameFS : *M) { @@ -887,17 +956,38 @@ SampleProfileLoader::findIndirectCallFunctionSamples( /// \returns the FunctionSamples pointer to the inlined instance. const FunctionSamples * SampleProfileLoader::findFunctionSamples(const Instruction &Inst) const { + if (FunctionSamples::ProfileIsProbeBased) { + Optional<PseudoProbe> Probe = extractProbe(Inst); + if (!Probe) + return nullptr; + } + const DILocation *DIL = Inst.getDebugLoc(); if (!DIL) return Samples; auto it = DILocation2SampleMap.try_emplace(DIL,nullptr); - if (it.second) - it.first->second = Samples->findFunctionSamples(DIL); + if (it.second) { + if (ProfileIsCS) + it.first->second = ContextTracker->getContextSamplesFor(DIL); + else + it.first->second = + Samples->findFunctionSamples(DIL, Reader->getRemapper()); + } return it.first->second; } bool SampleProfileLoader::inlineCallInstruction(CallBase &CB) { + if (ExternalInlineAdvisor) { + auto Advice = ExternalInlineAdvisor->getAdvice(CB); + if (!Advice->isInliningRecommended()) { + Advice->recordUnattemptedInlining(); + return false; + } + // Dummy record, we don't use it for replay. 
+ Advice->recordInlining(); + } + Function *CalledFunction = CB.getCalledFunction(); assert(CalledFunction); DebugLoc DLoc = CB.getDebugLoc(); @@ -938,6 +1028,12 @@ bool SampleProfileLoader::shouldInlineColdCallee(CallBase &CallInst) { InlineCost Cost = getInlineCost(CallInst, getInlineParams(), GetTTI(*Callee), GetAC, GetTLI); + if (Cost.isNever()) + return false; + + if (Cost.isAlways()) + return true; + return Cost.getCost() <= SampleColdCallSiteThreshold; } @@ -995,8 +1091,10 @@ bool SampleProfileLoader::inlineHotFunctions( const FunctionSamples *FS = nullptr; if (auto *CB = dyn_cast<CallBase>(&I)) { if (!isa<IntrinsicInst>(I) && (FS = findCalleeFunctionSamples(*CB))) { + assert((!FunctionSamples::UseMD5 || FS->GUIDToFuncNameMap) && + "GUIDToFuncNameMap has to be populated"); AllCandidates.push_back(CB); - if (FS->getEntrySamples() > 0) + if (FS->getEntrySamples() > 0 || ProfileIsCS) localNotInlinedCallSites.try_emplace(CB, FS); if (callsiteIsHot(FS, PSI)) Hot = true; @@ -1005,7 +1103,7 @@ bool SampleProfileLoader::inlineHotFunctions( } } } - if (Hot) { + if (Hot || ExternalInlineAdvisor) { CIS.insert(CIS.begin(), AllCandidates.begin(), AllCandidates.end()); emitOptimizationRemarksForInlineCandidates(AllCandidates, F, true); } else { @@ -1023,29 +1121,28 @@ bool SampleProfileLoader::inlineHotFunctions( continue; uint64_t Sum; for (const auto *FS : findIndirectCallFunctionSamples(*I, Sum)) { - if (IsThinLTOPreLink) { + if (LTOPhase == ThinOrFullLTOPhase::ThinLTOPreLink) { FS->findInlinedFunctions(InlinedGUIDs, F.getParent(), PSI->getOrCompHotCountThreshold()); continue; } - auto CalleeFunctionName = FS->getFuncName(); + if (!callsiteIsHot(FS, PSI)) + continue; + + const char *Reason = "Callee function not available"; + // R->getValue() != &F is to prevent promoting a recursive call. // If it is a recursive call, we do not inline it as it could bloat // the code exponentially. There is way to better handle this, e.g. // clone the caller first, and inline the cloned caller if it is // recursive. As llvm does not inline recursive calls, we will // simply ignore it instead of handling it explicitly. - if (CalleeFunctionName == F.getName()) - continue; - - if (!callsiteIsHot(FS, PSI)) - continue; - - const char *Reason = "Callee function not available"; + auto CalleeFunctionName = FS->getFuncName(); auto R = SymbolMap.find(CalleeFunctionName); if (R != SymbolMap.end() && R->getValue() && !R->getValue()->isDeclaration() && R->getValue()->getSubprogram() && R->getValue()->hasFnAttribute("use-sample-profile") && + R->getValue() != &F && isLegalToPromote(*I, R->getValue(), &Reason)) { uint64_t C = FS->getEntrySamples(); auto &DI = @@ -1055,6 +1152,8 @@ bool SampleProfileLoader::inlineHotFunctions( // If profile mismatches, we should not attempt to inline DI. 
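// DI is the direct call produced by promoteIndirectCall just above; the
// promotion conceptually rewrites (sketch, not this change's code):
//
//   before:  r = fp(args);                  // indirect call, hot target T
//   after:   if (fp == &T) r = T(args);     // DI: direct call, now inlinable
//            else          r = fp(args);    // original call kept as fallback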
if ((isa<CallInst>(DI) || isa<InvokeInst>(DI)) && inlineCallInstruction(cast<CallBase>(DI))) { + if (ProfileIsCS) + ContextTracker->markContextSamplesInlined(FS); localNotInlinedCallSites.erase(I); LocalChanged = true; ++NumCSInlined; @@ -1068,11 +1167,14 @@ bool SampleProfileLoader::inlineHotFunctions( } else if (CalledFunction && CalledFunction->getSubprogram() && !CalledFunction->isDeclaration()) { if (inlineCallInstruction(*I)) { + if (ProfileIsCS) + ContextTracker->markContextSamplesInlined( + localNotInlinedCallSites[I]); localNotInlinedCallSites.erase(I); LocalChanged = true; ++NumCSInlined; } - } else if (IsThinLTOPreLink) { + } else if (LTOPhase == ThinOrFullLTOPhase::ThinLTOPreLink) { findCalleeFunctionSamples(*I)->findInlinedFunctions( InlinedGUIDs, F.getParent(), PSI->getOrCompHotCountThreshold()); } @@ -1104,16 +1206,23 @@ bool SampleProfileLoader::inlineHotFunctions( } if (ProfileMergeInlinee) { - // Use entry samples as head samples during the merge, as inlinees - // don't have head samples. - assert(FS->getHeadSamples() == 0 && "Expect 0 head sample for inlinee"); - const_cast<FunctionSamples *>(FS)->addHeadSamples(FS->getEntrySamples()); - - // Note that we have to do the merge right after processing function. - // This allows OutlineFS's profile to be used for annotation during - // top-down processing of functions' annotation. - FunctionSamples *OutlineFS = Reader->getOrCreateSamplesFor(*Callee); - OutlineFS->merge(*FS); + // A function call can be replicated by optimizations like callsite + // splitting or jump threading and the replicates end up sharing the + // sample nested callee profile instead of slicing the original inlinee's + // profile. We want to do merge exactly once by filtering out callee + // profiles with a non-zero head sample count. + if (FS->getHeadSamples() == 0) { + // Use entry samples as head samples during the merge, as inlinees + // don't have head samples. + const_cast<FunctionSamples *>(FS)->addHeadSamples( + FS->getEntrySamples()); + + // Note that we have to do the merge right after processing function. + // This allows OutlineFS's profile to be used for annotation during + // top-down processing of functions' annotation. + FunctionSamples *OutlineFS = Reader->getOrCreateSamplesFor(*Callee); + OutlineFS->merge(*FS); + } } else { auto pair = notInlinedCallInfo.try_emplace(Callee, NotInlinedProfileInfo{0}); @@ -1538,13 +1647,11 @@ void SampleProfileLoader::propagateWeights(Function &F) { if (!DLoc) continue; const DILocation *DIL = DLoc; - uint32_t LineOffset = FunctionSamples::getOffset(DIL); - uint32_t Discriminator = DIL->getBaseDiscriminator(); - const FunctionSamples *FS = findFunctionSamples(I); if (!FS) continue; - auto T = FS->findCallTargetMapAt(LineOffset, Discriminator); + auto CallSite = FunctionSamples::getCallSiteIdentifier(DIL); + auto T = FS->findCallTargetMapAt(CallSite); if (!T || T.get().empty()) continue; SmallVector<InstrProfValueData, 2> SortedCallTargets = @@ -1598,8 +1705,6 @@ void SampleProfileLoader::propagateWeights(Function &F) { } } - misexpect::verifyMisExpect(TI, Weights, TI->getContext()); - uint64_t TempWeight; // Only set weights if there is at least one non-zero weight. // In any other case, let the analyzer set weights. 
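// The weights computed above end up as !prof metadata on the terminator; a
// minimal sketch of the resulting IR, with the counts invented:
//
//   br i1 %cmp, label %if.then, label %if.else, !prof !42
//   ...
//   !42 = !{!"branch_weights", i32 900, i32 100}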
@@ -1710,11 +1815,22 @@ void SampleProfileLoader::computeDominanceAndLoopInfo(Function &F) { bool SampleProfileLoader::emitAnnotations(Function &F) { bool Changed = false; - if (getFunctionLoc(F) == 0) - return false; + if (FunctionSamples::ProfileIsProbeBased) { + if (!ProbeManager->profileIsValid(F, *Samples)) { + LLVM_DEBUG( + dbgs() << "Profile is invalid due to CFG mismatch for Function " + << F.getName()); + ++NumMismatchedProfile; + return false; + } + ++NumMatchedProfile; + } else { + if (getFunctionLoc(F) == 0) + return false; - LLVM_DEBUG(dbgs() << "Line number for the first instruction in " - << F.getName() << ": " << getFunctionLoc(F) << "\n"); + LLVM_DEBUG(dbgs() << "Line number for the first instruction in " + << F.getName() << ": " << getFunctionLoc(F) << "\n"); + } DenseSet<GlobalValue::GUID> InlinedGUIDs; Changed |= inlineHotFunctions(F, InlinedGUIDs); @@ -1818,10 +1934,10 @@ SampleProfileLoader::buildFunctionOrder(Module &M, CallGraph *CG) { return FunctionOrderList; } -bool SampleProfileLoader::doInitialization(Module &M) { +bool SampleProfileLoader::doInitialization(Module &M, + FunctionAnalysisManager *FAM) { auto &Ctx = M.getContext(); - std::unique_ptr<SampleProfileReaderItaniumRemapper> RemapReader; auto ReaderOrErr = SampleProfileReader::create(Filename, Ctx, RemappingFilename); if (std::error_code EC = ReaderOrErr.getError()) { @@ -1830,8 +1946,14 @@ bool SampleProfileLoader::doInitialization(Module &M) { return false; } Reader = std::move(ReaderOrErr.get()); + Reader->setSkipFlatProf(LTOPhase == ThinOrFullLTOPhase::ThinLTOPostLink); Reader->collectFuncsFrom(M); - ProfileIsValid = (Reader->read() == sampleprof_error::success); + if (std::error_code EC = Reader->read()) { + std::string Msg = "profile reading failed: " + EC.message(); + Ctx.diagnose(DiagnosticInfoSampleProfile(Filename, Msg)); + return false; + } + PSL = Reader->getProfileSymbolList(); // While profile-sample-accurate is on, ignore symbol list. @@ -1843,6 +1965,35 @@ bool SampleProfileLoader::doInitialization(Module &M) { NamesInProfile.insert(NameTable->begin(), NameTable->end()); } + if (FAM && !ProfileInlineReplayFile.empty()) { + ExternalInlineAdvisor = std::make_unique<ReplayInlineAdvisor>( + M, *FAM, Ctx, /*OriginalAdvisor=*/nullptr, ProfileInlineReplayFile, + /*EmitRemarks=*/false); + if (!ExternalInlineAdvisor->areReplayRemarksLoaded()) + ExternalInlineAdvisor.reset(); + } + + // Apply tweaks if context-sensitive profile is available. + if (Reader->profileIsCS()) { + ProfileIsCS = true; + FunctionSamples::ProfileIsCS = true; + + // Tracker for profiles under different context + ContextTracker = + std::make_unique<SampleContextTracker>(Reader->getProfiles()); + } + + // Load pseudo probe descriptors for probe-based function samples. 
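// The descriptors live in module-level named metadata emitted by
// SampleProfileProbePass; the exact shape is an assumption here, but per the
// reader below each entry carries a function GUID and a CFG checksum,
// roughly:
//
//   !llvm.pseudo_probe_desc = !{!1, !2}
//   !1 = !{i64 6699318081062747564, i64 4294967295, !"foo"}
//   !2 = !{i64 1234567890123456789, i64 563022570642068, !"bar"}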
+ if (Reader->profileIsProbeBased()) { + ProbeManager = std::make_unique<PseudoProbeManager>(M); + if (!ProbeManager->moduleIsProbed(M)) { + const char *Msg = + "Pseudo-probe-based profile requires SampleProfileProbePass"; + Ctx.diagnose(DiagnosticInfoSampleProfile(Filename, Msg)); + return false; + } + } + return true; } @@ -1856,8 +2007,6 @@ ModulePass *llvm::createSampleProfileLoaderPass(StringRef Name) { bool SampleProfileLoader::runOnModule(Module &M, ModuleAnalysisManager *AM, ProfileSummaryInfo *_PSI, CallGraph *CG) { - if (!ProfileIsValid) - return false; GUIDToFuncNameMapper Mapper(M, *Reader, GUIDToFuncNameMap); PSI = _PSI; @@ -1870,6 +2019,7 @@ bool SampleProfileLoader::runOnModule(Module &M, ModuleAnalysisManager *AM, for (const auto &I : Reader->getProfiles()) TotalCollectedSamples += I.second.getTotalSamples(); + auto Remapper = Reader->getRemapper(); // Populate the symbol map. for (const auto &N_F : M.getValueSymbolTable()) { StringRef OrigName = N_F.getKey(); @@ -1887,6 +2037,15 @@ bool SampleProfileLoader::runOnModule(Module &M, ModuleAnalysisManager *AM, // to nullptr to avoid confusion. if (!r.second) r.first->second = nullptr; + OrigName = NewName; + } + // Insert the remapped names into SymbolMap. + if (Remapper) { + if (auto MapName = Remapper->lookUpNameInProfile(OrigName)) { + if (*MapName == OrigName) + continue; + SymbolMap.insert(std::make_pair(*MapName, F)); + } } } @@ -1898,9 +2057,10 @@ bool SampleProfileLoader::runOnModule(Module &M, ModuleAnalysisManager *AM, } // Account for cold calls not inlined.... - for (const std::pair<Function *, NotInlinedProfileInfo> &pair : - notInlinedCallInfo) - updateProfileCallee(pair.first, pair.second.entryCount); + if (!ProfileIsCS) + for (const std::pair<Function *, NotInlinedProfileInfo> &pair : + notInlinedCallInfo) + updateProfileCallee(pair.first, pair.second.entryCount); return retval; } @@ -1915,7 +2075,6 @@ bool SampleProfileLoaderLegacyPass::runOnModule(Module &M) { } bool SampleProfileLoader::runOnFunction(Function &F, ModuleAnalysisManager *AM) { - DILocation2SampleMap.clear(); // By default the entry count is initialized to -1, which will be treated // conservatively by getEntryCount as the same as unknown (None). This is @@ -1957,7 +2116,10 @@ bool SampleProfileLoader::runOnFunction(Function &F, ModuleAnalysisManager *AM) initialEntryCount = -1; } - F.setEntryCount(ProfileCount(initialEntryCount, Function::PCT_Real)); + // Initialize entry count when the function has no existing entry + // count value. + if (!F.getEntryCount().hasValue()) + F.setEntryCount(ProfileCount(initialEntryCount, Function::PCT_Real)); std::unique_ptr<OptimizationRemarkEmitter> OwnedORE; if (AM) { auto &FAM = @@ -1968,7 +2130,12 @@ bool SampleProfileLoader::runOnFunction(Function &F, ModuleAnalysisManager *AM) OwnedORE = std::make_unique<OptimizationRemarkEmitter>(&F); ORE = OwnedORE.get(); } - Samples = Reader->getSamplesFor(F); + + if (ProfileIsCS) + Samples = ContextTracker->getBaseSamplesFor(F); + else + Samples = Reader->getSamplesFor(F); + if (Samples && !Samples->empty()) return emitAnnotations(F); return false; @@ -1993,9 +2160,9 @@ PreservedAnalyses SampleProfileLoaderPass::run(Module &M, ProfileFileName.empty() ? SampleProfileFile : ProfileFileName, ProfileRemappingFileName.empty() ? 
SampleProfileRemappingFile : ProfileRemappingFileName, - IsThinLTOPreLink, GetAssumptionCache, GetTTI, GetTLI); + LTOPhase, GetAssumptionCache, GetTTI, GetTLI); - if (!SampleLoader.doInitialization(M)) + if (!SampleLoader.doInitialization(M, &FAM)) return PreservedAnalyses::all(); ProfileSummaryInfo *PSI = &AM.getResult<ProfileSummaryAnalysis>(M); diff --git a/contrib/llvm-project/llvm/lib/Transforms/IPO/SampleProfileProbe.cpp b/contrib/llvm-project/llvm/lib/Transforms/IPO/SampleProfileProbe.cpp new file mode 100644 index 000000000000..7cecd20b78d8 --- /dev/null +++ b/contrib/llvm-project/llvm/lib/Transforms/IPO/SampleProfileProbe.cpp @@ -0,0 +1,276 @@ +//===- SampleProfileProbe.cpp - Pseudo probe Instrumentation -------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file implements the SampleProfileProber transformation. +// +//===----------------------------------------------------------------------===// + +#include "llvm/Transforms/IPO/SampleProfileProbe.h" +#include "llvm/ADT/Statistic.h" +#include "llvm/Analysis/TargetLibraryInfo.h" +#include "llvm/IR/BasicBlock.h" +#include "llvm/IR/CFG.h" +#include "llvm/IR/Constant.h" +#include "llvm/IR/Constants.h" +#include "llvm/IR/DebugInfoMetadata.h" +#include "llvm/IR/GlobalValue.h" +#include "llvm/IR/GlobalVariable.h" +#include "llvm/IR/IRBuilder.h" +#include "llvm/IR/Instruction.h" +#include "llvm/IR/MDBuilder.h" +#include "llvm/ProfileData/SampleProf.h" +#include "llvm/Support/CRC.h" +#include "llvm/Transforms/Instrumentation.h" +#include "llvm/Transforms/Utils/ModuleUtils.h" +#include <vector> + +using namespace llvm; +#define DEBUG_TYPE "sample-profile-probe" + +STATISTIC(ArtificialDbgLine, + "Number of probes that have an artificial debug line"); + +PseudoProbeManager::PseudoProbeManager(const Module &M) { + if (NamedMDNode *FuncInfo = M.getNamedMetadata(PseudoProbeDescMetadataName)) { + for (const auto *Operand : FuncInfo->operands()) { + const auto *MD = cast<MDNode>(Operand); + auto GUID = + mdconst::dyn_extract<ConstantInt>(MD->getOperand(0))->getZExtValue(); + auto Hash = + mdconst::dyn_extract<ConstantInt>(MD->getOperand(1))->getZExtValue(); + GUIDToProbeDescMap.try_emplace(GUID, PseudoProbeDescriptor(GUID, Hash)); + } + } +} + +const PseudoProbeDescriptor * +PseudoProbeManager::getDesc(const Function &F) const { + auto I = GUIDToProbeDescMap.find( + Function::getGUID(FunctionSamples::getCanonicalFnName(F))); + return I == GUIDToProbeDescMap.end() ? 
nullptr : &I->second; +} + +bool PseudoProbeManager::moduleIsProbed(const Module &M) const { + return M.getNamedMetadata(PseudoProbeDescMetadataName); +} + +bool PseudoProbeManager::profileIsValid(const Function &F, + const FunctionSamples &Samples) const { + const auto *Desc = getDesc(F); + if (!Desc) { + LLVM_DEBUG(dbgs() << "Probe descriptor missing for Function " << F.getName() + << "\n"); + return false; + } else { + if (Desc->getFunctionHash() != Samples.getFunctionHash()) { + LLVM_DEBUG(dbgs() << "Hash mismatch for Function " << F.getName() + << "\n"); + return false; + } + } + return true; +} + +SampleProfileProber::SampleProfileProber(Function &Func, + const std::string &CurModuleUniqueId) + : F(&Func), CurModuleUniqueId(CurModuleUniqueId) { + BlockProbeIds.clear(); + CallProbeIds.clear(); + LastProbeId = (uint32_t)PseudoProbeReservedId::Last; + computeProbeIdForBlocks(); + computeProbeIdForCallsites(); + computeCFGHash(); +} + +// Compute Hash value for the CFG: the lower 32 bits are CRC32 of the index +// value of each BB in the CFG. The higher 32 bits record the number of edges +// preceded by the number of indirect calls. +// This is derived from FuncPGOInstrumentation<Edge, BBInfo>::computeCFGHash(). +void SampleProfileProber::computeCFGHash() { + std::vector<uint8_t> Indexes; + JamCRC JC; + for (auto &BB : *F) { + auto *TI = BB.getTerminator(); + for (unsigned I = 0, E = TI->getNumSuccessors(); I != E; ++I) { + auto *Succ = TI->getSuccessor(I); + auto Index = getBlockId(Succ); + for (int J = 0; J < 4; J++) + Indexes.push_back((uint8_t)(Index >> (J * 8))); + } + } + + JC.update(Indexes); + + FunctionHash = (uint64_t)CallProbeIds.size() << 48 | + (uint64_t)Indexes.size() << 32 | JC.getCRC(); + // Reserve bit 60-63 for other information purpose. + FunctionHash &= 0x0FFFFFFFFFFFFFFF; + assert(FunctionHash && "Function checksum should not be zero"); + LLVM_DEBUG(dbgs() << "\nFunction Hash Computation for " << F->getName() + << ":\n" + << " CRC = " << JC.getCRC() << ", Edges = " + << Indexes.size() << ", ICSites = " << CallProbeIds.size() + << ", Hash = " << FunctionHash << "\n"); +} + +void SampleProfileProber::computeProbeIdForBlocks() { + for (auto &BB : *F) { + BlockProbeIds[&BB] = ++LastProbeId; + } +} + +void SampleProfileProber::computeProbeIdForCallsites() { + for (auto &BB : *F) { + for (auto &I : BB) { + if (!isa<CallBase>(I)) + continue; + if (isa<IntrinsicInst>(&I)) + continue; + CallProbeIds[&I] = ++LastProbeId; + } + } +} + +uint32_t SampleProfileProber::getBlockId(const BasicBlock *BB) const { + auto I = BlockProbeIds.find(const_cast<BasicBlock *>(BB)); + return I == BlockProbeIds.end() ? 0 : I->second; +} + +uint32_t SampleProfileProber::getCallsiteId(const Instruction *Call) const { + auto Iter = CallProbeIds.find(const_cast<Instruction *>(Call)); + return Iter == CallProbeIds.end() ? 0 : Iter->second; +} + +void SampleProfileProber::instrumentOneFunc(Function &F, TargetMachine *TM) { + Module *M = F.getParent(); + MDBuilder MDB(F.getContext()); + // Compute a GUID without considering the function's linkage type. This is + // fine since function name is the only key in the profile database. + uint64_t Guid = Function::getGUID(F.getName()); + + // Assign an artificial debug line to a probe that doesn't come with a real + // line. A probe not having a debug line will get an incomplete inline + // context. This will cause samples collected on the probe to be counted + // into the base profile instead of a context profile. 
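In practice the fix is simply to synthesize a line-0 DILocation scoped to the function's DISubprogram, which is exactly what the AssignDebugLoc lambda right below does. A stripped-down sketch of that idea; the helper name is ours and it omits the statistics and debug printing:

#include "llvm/IR/DebugInfoMetadata.h"
#include "llvm/IR/Function.h"
#include "llvm/IR/Instruction.h"

// Sketch only: give an instruction an artificial line-0 debug location so it
// still carries an inline context, without pretending to map to real source.
static void giveArtificialDebugLoc(llvm::Function &F, llvm::Instruction &I) {
  if (I.getDebugLoc())
    return; // already has a real location
  if (llvm::DISubprogram *SP = F.getSubprogram())
    I.setDebugLoc(llvm::DILocation::get(SP->getContext(), /*Line=*/0,
                                        /*Column=*/0, SP));
}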
The line number + // itself is not important though. + auto AssignDebugLoc = [&](Instruction *I) { + assert((isa<PseudoProbeInst>(I) || isa<CallBase>(I)) && + "Expecting pseudo probe or call instructions"); + if (!I->getDebugLoc()) { + if (auto *SP = F.getSubprogram()) { + auto DIL = DILocation::get(SP->getContext(), 0, 0, SP); + I->setDebugLoc(DIL); + ArtificialDbgLine++; + LLVM_DEBUG({ + dbgs() << "\nIn Function " << F.getName() + << " Probe gets an artificial debug line\n"; + I->dump(); + }); + } + } + }; + + // Probe basic blocks. + for (auto &I : BlockProbeIds) { + BasicBlock *BB = I.first; + uint32_t Index = I.second; + // Insert a probe before an instruction with a valid debug line number which + // will be assigned to the probe. The line number will be used later to + // model the inline context when the probe is inlined into other functions. + // Debug instructions, phi nodes and lifetime markers do not have a valid + // line number. Real instructions generated by optimizations may not come + // with a line number either. + auto HasValidDbgLine = [](Instruction *J) { + return !isa<PHINode>(J) && !isa<DbgInfoIntrinsic>(J) && + !J->isLifetimeStartOrEnd() && J->getDebugLoc(); + }; + + Instruction *J = &*BB->getFirstInsertionPt(); + while (J != BB->getTerminator() && !HasValidDbgLine(J)) { + J = J->getNextNode(); + } + + IRBuilder<> Builder(J); + assert(Builder.GetInsertPoint() != BB->end() && + "Cannot get the probing point"); + Function *ProbeFn = + llvm::Intrinsic::getDeclaration(M, Intrinsic::pseudoprobe); + Value *Args[] = {Builder.getInt64(Guid), Builder.getInt64(Index), + Builder.getInt32(0)}; + auto *Probe = Builder.CreateCall(ProbeFn, Args); + AssignDebugLoc(Probe); + } + + // Probe both direct calls and indirect calls. Direct calls are probed so that + // their probe ID can be used as a call site identifier to represent a + // calling context. + for (auto &I : CallProbeIds) { + auto *Call = I.first; + uint32_t Index = I.second; + uint32_t Type = cast<CallBase>(Call)->getCalledFunction() + ? (uint32_t)PseudoProbeType::DirectCall + : (uint32_t)PseudoProbeType::IndirectCall; + AssignDebugLoc(Call); + // Leverage the 32-bit discriminator field of debug data to store the ID and + // type of a callsite probe. This gets rid of the dependency on plumbing a + // customized metadata through the codegen pipeline. + uint32_t V = PseudoProbeDwarfDiscriminator::packProbeData(Index, Type); + if (auto DIL = Call->getDebugLoc()) { + DIL = DIL->cloneWithDiscriminator(V); + Call->setDebugLoc(DIL); + } + } + + // Create module-level metadata that contains function info necessary to + // synthesize probe-based sample counts, which are + // - FunctionGUID + // - FunctionHash. + // - FunctionName + auto Hash = getFunctionHash(); + auto *MD = MDB.createPseudoProbeDesc(Guid, Hash, &F); + auto *NMD = M->getNamedMetadata(PseudoProbeDescMetadataName); + assert(NMD && "llvm.pseudo_probe_desc should be pre-created"); + NMD->addOperand(MD); + + // Preserve a comdat group to hold all probes materialized later. This + // ensures that when the function is considered dead and removed, the + // materialized probes are disposed of too. + // Imported functions are defined in another module. They do not need + // the following handling since the same care will be taken for them in their + // original module. The pseudo probes inserted into imported functions + // above will naturally not be emitted since the imported function is free + // from object emission.
However they will be emitted together with the + // inliner functions that the imported function is inlined into. We are not + // creating a comdat group for an import function since it's useless anyway. + if (!F.isDeclarationForLinker()) { + if (TM) { + auto Triple = TM->getTargetTriple(); + if (Triple.supportsCOMDAT() && TM->getFunctionSections()) { + GetOrCreateFunctionComdat(F, Triple, CurModuleUniqueId); + } + } + } +} + +PreservedAnalyses SampleProfileProbePass::run(Module &M, + ModuleAnalysisManager &AM) { + auto ModuleId = getUniqueModuleId(&M); + // Create the pseudo probe desc metadata beforehand. + // Note that modules with only data but no functions will require this to + // be set up so that they will be known as probed later. + M.getOrInsertNamedMetadata(PseudoProbeDescMetadataName); + + for (auto &F : M) { + if (F.isDeclaration()) + continue; + SampleProfileProber ProbeManager(F, ModuleId); + ProbeManager.instrumentOneFunc(F, TM); + } + + return PreservedAnalyses::none(); +} diff --git a/contrib/llvm-project/llvm/lib/Transforms/IPO/StripSymbols.cpp b/contrib/llvm-project/llvm/lib/Transforms/IPO/StripSymbols.cpp index 088091df770f..4fc71847a070 100644 --- a/contrib/llvm-project/llvm/lib/Transforms/IPO/StripSymbols.cpp +++ b/contrib/llvm-project/llvm/lib/Transforms/IPO/StripSymbols.cpp @@ -19,18 +19,21 @@ // //===----------------------------------------------------------------------===// +#include "llvm/Transforms/IPO/StripSymbols.h" #include "llvm/ADT/SmallPtrSet.h" #include "llvm/IR/Constants.h" #include "llvm/IR/DebugInfo.h" #include "llvm/IR/DerivedTypes.h" #include "llvm/IR/Instructions.h" #include "llvm/IR/Module.h" +#include "llvm/IR/PassManager.h" #include "llvm/IR/TypeFinder.h" #include "llvm/IR/ValueSymbolTable.h" #include "llvm/InitializePasses.h" #include "llvm/Pass.h" #include "llvm/Transforms/IPO.h" #include "llvm/Transforms/Utils/Local.h" + using namespace llvm; namespace { @@ -249,9 +252,7 @@ bool StripNonDebugSymbols::runOnModule(Module &M) { return StripSymbolNames(M, true); } -bool StripDebugDeclare::runOnModule(Module &M) { - if (skipModule(M)) - return false; +static bool stripDebugDeclareImpl(Module &M) { Function *Declare = M.getFunction("llvm.dbg.declare"); std::vector<Constant*> DeadConstants; @@ -289,17 +290,13 @@ bool StripDebugDeclare::runOnModule(Module &M) { return true; } -/// Remove any debug info for global variables/functions in the given module for -/// which said global variable/function no longer exists (i.e. is null). -/// -/// Debugging information is encoded in llvm IR using metadata. This is designed -/// such a way that debug info for symbols preserved even if symbols are -/// optimized away by the optimizer. This special pass removes debug info for -/// such symbols. -bool StripDeadDebugInfo::runOnModule(Module &M) { +bool StripDebugDeclare::runOnModule(Module &M) { if (skipModule(M)) return false; + return stripDebugDeclareImpl(M); +} +static bool stripDeadDebugInfoImpl(Module &M) { bool Changed = false; LLVMContext &C = M.getContext(); @@ -380,3 +377,40 @@ bool StripDeadDebugInfo::runOnModule(Module &M) { return Changed; } + +/// Remove any debug info for global variables/functions in the given module for +/// which said global variable/function no longer exists (i.e. is null). +/// +/// Debugging information is encoded in llvm IR using metadata. This is designed +/// such a way that debug info for symbols preserved even if symbols are +/// optimized away by the optimizer. This special pass removes debug info for +/// such symbols. 
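The StripSymbols changes in this file follow a common porting recipe: hoist the module rewriting into a file-local static helper, keep the legacy runOnModule() as a thin wrapper around it, and add a new-PM run() method that calls the same helper (those wrappers appear right below). A generic sketch of the shape, with placeholder names rather than the actual Strip* passes:

#include "llvm/IR/Module.h"
#include "llvm/IR/PassManager.h"

// Sketch only: shared-implementation pattern for supporting both pass
// managers. The Strip* wrappers below simply return PreservedAnalyses::all(),
// since stripping names and debug metadata is not expected to invalidate
// other analyses; a pass that changes IR structure would report none().
static bool myStripImpl(llvm::Module &M) {
  bool Changed = false;
  // ... the actual module rewriting lives here ...
  return Changed;
}

struct MyStripPass : llvm::PassInfoMixin<MyStripPass> {
  llvm::PreservedAnalyses run(llvm::Module &M, llvm::ModuleAnalysisManager &) {
    return myStripImpl(M) ? llvm::PreservedAnalyses::none()
                          : llvm::PreservedAnalyses::all();
  }
};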
+bool StripDeadDebugInfo::runOnModule(Module &M) { + if (skipModule(M)) + return false; + return stripDeadDebugInfoImpl(M); +} + +PreservedAnalyses StripSymbolsPass::run(Module &M, ModuleAnalysisManager &AM) { + StripDebugInfo(M); + StripSymbolNames(M, false); + return PreservedAnalyses::all(); +} + +PreservedAnalyses StripNonDebugSymbolsPass::run(Module &M, + ModuleAnalysisManager &AM) { + StripSymbolNames(M, true); + return PreservedAnalyses::all(); +} + +PreservedAnalyses StripDebugDeclarePass::run(Module &M, + ModuleAnalysisManager &AM) { + stripDebugDeclareImpl(M); + return PreservedAnalyses::all(); +} + +PreservedAnalyses StripDeadDebugInfoPass::run(Module &M, + ModuleAnalysisManager &AM) { + stripDeadDebugInfoImpl(M); + return PreservedAnalyses::all(); +} diff --git a/contrib/llvm-project/llvm/lib/Transforms/IPO/ThinLTOBitcodeWriter.cpp b/contrib/llvm-project/llvm/lib/Transforms/IPO/ThinLTOBitcodeWriter.cpp index 87a18171787f..225b4fe95f67 100644 --- a/contrib/llvm-project/llvm/lib/Transforms/IPO/ThinLTOBitcodeWriter.cpp +++ b/contrib/llvm-project/llvm/lib/Transforms/IPO/ThinLTOBitcodeWriter.cpp @@ -14,6 +14,7 @@ #include "llvm/Bitcode/BitcodeWriter.h" #include "llvm/IR/Constants.h" #include "llvm/IR/DebugInfo.h" +#include "llvm/IR/Instructions.h" #include "llvm/IR/Intrinsics.h" #include "llvm/IR/Module.h" #include "llvm/IR/PassManager.h" @@ -260,7 +261,7 @@ void splitAndWriteThinLTOBitcode( if (!RT || RT->getBitWidth() > 64 || F->arg_empty() || !F->arg_begin()->use_empty()) return; - for (auto &Arg : make_range(std::next(F->arg_begin()), F->arg_end())) { + for (auto &Arg : drop_begin(F->args())) { auto *ArgT = dyn_cast<IntegerType>(Arg.getType()); if (!ArgT || ArgT->getBitWidth() > 64) return; @@ -333,8 +334,7 @@ void splitAndWriteThinLTOBitcode( Linkage = CFL_Declaration; Elts.push_back(ConstantAsMetadata::get( llvm::ConstantInt::get(Type::getInt8Ty(Ctx), Linkage))); - for (auto Type : Types) - Elts.push_back(Type); + append_range(Elts, Types); CfiFunctionMDs.push_back(MDTuple::get(Ctx, Elts)); } diff --git a/contrib/llvm-project/llvm/lib/Transforms/IPO/WholeProgramDevirt.cpp b/contrib/llvm-project/llvm/lib/Transforms/IPO/WholeProgramDevirt.cpp index 5a25f9857665..cf1ff405c493 100644 --- a/contrib/llvm-project/llvm/lib/Transforms/IPO/WholeProgramDevirt.cpp +++ b/contrib/llvm-project/llvm/lib/Transforms/IPO/WholeProgramDevirt.cpp @@ -59,7 +59,7 @@ #include "llvm/ADT/SmallVector.h" #include "llvm/ADT/Triple.h" #include "llvm/ADT/iterator_range.h" -#include "llvm/Analysis/AliasAnalysis.h" +#include "llvm/Analysis/AssumptionCache.h" #include "llvm/Analysis/BasicAliasAnalysis.h" #include "llvm/Analysis/OptimizationRemarkEmitter.h" #include "llvm/Analysis/TypeMetadataUtils.h" @@ -470,7 +470,7 @@ CallSiteInfo &VTableSlotInfo::findCallSiteInfo(CallBase &CB) { auto *CBType = dyn_cast<IntegerType>(CB.getType()); if (!CBType || CBType->getBitWidth() > 64 || CB.arg_empty()) return CSInfo; - for (auto &&Arg : make_range(CB.arg_begin() + 1, CB.arg_end())) { + for (auto &&Arg : drop_begin(CB.args())) { auto *CI = dyn_cast<ConstantInt>(Arg); if (!CI || CI->getBitWidth() > 64) return CSInfo; @@ -753,6 +753,11 @@ PreservedAnalyses WholeProgramDevirtPass::run(Module &M, auto LookupDomTree = [&FAM](Function &F) -> DominatorTree & { return FAM.getResult<DominatorTreeAnalysis>(F); }; + if (UseCommandLine) { + if (DevirtModule::runForTesting(M, AARGetter, OREGetter, LookupDomTree)) + return PreservedAnalyses::all(); + return PreservedAnalyses::none(); + } if (!DevirtModule(M, AARGetter, OREGetter, 
LookupDomTree, ExportSummary, ImportSummary) .run()) @@ -1025,6 +1030,10 @@ bool DevirtIndex::tryFindVirtualCallTargets( void DevirtModule::applySingleImplDevirt(VTableSlotInfo &SlotInfo, Constant *TheFn, bool &IsExported) { + // Don't devirtualize function if we're told to skip it + // in -wholeprogramdevirt-skip. + if (FunctionsToSkip.match(TheFn->stripPointerCasts()->getName())) + return; auto Apply = [&](CallSiteInfo &CSInfo) { for (auto &&VCallSite : CSInfo.CallSites) { if (RemarksEnabled) @@ -1258,7 +1267,7 @@ void DevirtModule::applyICallBranchFunnel(VTableSlotInfo &SlotInfo, // Jump tables are only profitable if the retpoline mitigation is enabled. Attribute FSAttr = CB.getCaller()->getFnAttribute("target-features"); - if (FSAttr.hasAttribute(Attribute::None) || + if (!FSAttr.isValid() || !FSAttr.getValueAsString().contains("+retpoline")) continue; @@ -1270,8 +1279,7 @@ void DevirtModule::applyICallBranchFunnel(VTableSlotInfo &SlotInfo, // x86_64. std::vector<Type *> NewArgs; NewArgs.push_back(Int8PtrTy); - for (Type *T : CB.getFunctionType()->params()) - NewArgs.push_back(T); + append_range(NewArgs, CB.getFunctionType()->params()); FunctionType *NewFT = FunctionType::get(CB.getFunctionType()->getReturnType(), NewArgs, CB.getFunctionType()->isVarArg()); @@ -1280,7 +1288,7 @@ void DevirtModule::applyICallBranchFunnel(VTableSlotInfo &SlotInfo, IRBuilder<> IRB(&CB); std::vector<Value *> Args; Args.push_back(IRB.CreateBitCast(VCallSite.VTable, Int8PtrTy)); - Args.insert(Args.end(), CB.arg_begin(), CB.arg_end()); + llvm::append_range(Args, CB.args()); CallBase *NewCS = nullptr; if (isa<CallInst>(CB)) @@ -2205,6 +2213,4 @@ void DevirtIndex::run() { if (PrintSummaryDevirt) for (const auto &DT : DevirtTargets) errs() << "Devirtualized call to " << DT << "\n"; - - return; } |
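One small API shift in the WholeProgramDevirt hunk is worth calling out: the old test against Attribute::None is replaced by Attribute::isValid(), so a string attribute fetched with getFnAttribute() is now checked for validity before its value is read. A minimal sketch of the updated check, with the helper name being ours:

#include "llvm/IR/Attributes.h"
#include "llvm/IR/Function.h"

// Sketch only: decide whether the caller was compiled with retpolines, the
// precondition the branch-funnel code above uses before emitting jump tables.
static bool callerHasRetpoline(const llvm::Function &Caller) {
  llvm::Attribute FSAttr = Caller.getFnAttribute("target-features");
  return FSAttr.isValid() &&
         FSAttr.getValueAsString().contains("+retpoline");
}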