diff options
Diffstat (limited to 'contrib/llvm-project/llvm/lib/Transforms/IPO/ArgumentPromotion.cpp')
-rw-r--r-- | contrib/llvm-project/llvm/lib/Transforms/IPO/ArgumentPromotion.cpp | 1175 |
1 files changed, 1175 insertions, 0 deletions
diff --git a/contrib/llvm-project/llvm/lib/Transforms/IPO/ArgumentPromotion.cpp b/contrib/llvm-project/llvm/lib/Transforms/IPO/ArgumentPromotion.cpp new file mode 100644 index 000000000000..7998a1ae5c6e --- /dev/null +++ b/contrib/llvm-project/llvm/lib/Transforms/IPO/ArgumentPromotion.cpp @@ -0,0 +1,1175 @@ +//===- ArgumentPromotion.cpp - Promote by-reference arguments -------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This pass promotes "by reference" arguments to be "by value" arguments. In +// practice, this means looking for internal functions that have pointer +// arguments. If it can prove, through the use of alias analysis, that an +// argument is *only* loaded, then it can pass the value into the function +// instead of the address of the value. This can cause recursive simplification +// of code and lead to the elimination of allocas (especially in C++ template +// code like the STL). +// +// This pass also handles aggregate arguments that are passed into a function, +// scalarizing them if the elements of the aggregate are only loaded. Note that +// by default it refuses to scalarize aggregates which would require passing in +// more than three operands to the function, because passing thousands of +// operands for a large array or structure is unprofitable! This limit can be +// configured or disabled, however. +// +// Note that this transformation could also be done for arguments that are only +// stored to (returning the value instead), but does not currently. This case +// would be best handled when and if LLVM begins supporting multiple return +// values from functions. +// +//===----------------------------------------------------------------------===// + +#include "llvm/Transforms/IPO/ArgumentPromotion.h" +#include "llvm/ADT/DepthFirstIterator.h" +#include "llvm/ADT/None.h" +#include "llvm/ADT/Optional.h" +#include "llvm/ADT/STLExtras.h" +#include "llvm/ADT/ScopeExit.h" +#include "llvm/ADT/SmallPtrSet.h" +#include "llvm/ADT/SmallVector.h" +#include "llvm/ADT/Statistic.h" +#include "llvm/ADT/Twine.h" +#include "llvm/Analysis/AssumptionCache.h" +#include "llvm/Analysis/BasicAliasAnalysis.h" +#include "llvm/Analysis/CGSCCPassManager.h" +#include "llvm/Analysis/CallGraph.h" +#include "llvm/Analysis/CallGraphSCCPass.h" +#include "llvm/Analysis/LazyCallGraph.h" +#include "llvm/Analysis/Loads.h" +#include "llvm/Analysis/MemoryLocation.h" +#include "llvm/Analysis/TargetLibraryInfo.h" +#include "llvm/Analysis/TargetTransformInfo.h" +#include "llvm/IR/Argument.h" +#include "llvm/IR/Attributes.h" +#include "llvm/IR/BasicBlock.h" +#include "llvm/IR/CFG.h" +#include "llvm/IR/Constants.h" +#include "llvm/IR/DataLayout.h" +#include "llvm/IR/DerivedTypes.h" +#include "llvm/IR/Function.h" +#include "llvm/IR/IRBuilder.h" +#include "llvm/IR/InstrTypes.h" +#include "llvm/IR/Instruction.h" +#include "llvm/IR/Instructions.h" +#include "llvm/IR/Metadata.h" +#include "llvm/IR/Module.h" +#include "llvm/IR/NoFolder.h" +#include "llvm/IR/PassManager.h" +#include "llvm/IR/Type.h" +#include "llvm/IR/Use.h" +#include "llvm/IR/User.h" +#include "llvm/IR/Value.h" +#include "llvm/InitializePasses.h" +#include "llvm/Pass.h" +#include "llvm/Support/Casting.h" +#include "llvm/Support/Debug.h" +#include "llvm/Support/FormatVariadic.h" +#include "llvm/Support/raw_ostream.h" +#include "llvm/Transforms/IPO.h" +#include <algorithm> +#include <cassert> +#include <cstdint> +#include <functional> +#include <iterator> +#include <map> +#include <set> +#include <string> +#include <utility> +#include <vector> + +using namespace llvm; + +#define DEBUG_TYPE "argpromotion" + +STATISTIC(NumArgumentsPromoted, "Number of pointer arguments promoted"); +STATISTIC(NumAggregatesPromoted, "Number of aggregate arguments promoted"); +STATISTIC(NumByValArgsPromoted, "Number of byval arguments promoted"); +STATISTIC(NumArgumentsDead, "Number of dead pointer args eliminated"); + +/// A vector used to hold the indices of a single GEP instruction +using IndicesVector = std::vector<uint64_t>; + +/// DoPromotion - This method actually performs the promotion of the specified +/// arguments, and returns the new function. At this point, we know that it's +/// safe to do so. +static Function * +doPromotion(Function *F, SmallPtrSetImpl<Argument *> &ArgsToPromote, + SmallPtrSetImpl<Argument *> &ByValArgsToTransform, + Optional<function_ref<void(CallBase &OldCS, CallBase &NewCS)>> + ReplaceCallSite) { + // Start by computing a new prototype for the function, which is the same as + // the old function, but has modified arguments. + FunctionType *FTy = F->getFunctionType(); + std::vector<Type *> Params; + + using ScalarizeTable = std::set<std::pair<Type *, IndicesVector>>; + + // ScalarizedElements - If we are promoting a pointer that has elements + // accessed out of it, keep track of which elements are accessed so that we + // can add one argument for each. + // + // Arguments that are directly loaded will have a zero element value here, to + // handle cases where there are both a direct load and GEP accesses. + std::map<Argument *, ScalarizeTable> ScalarizedElements; + + // OriginalLoads - Keep track of a representative load instruction from the + // original function so that we can tell the alias analysis implementation + // what the new GEP/Load instructions we are inserting look like. + // We need to keep the original loads for each argument and the elements + // of the argument that are accessed. + std::map<std::pair<Argument *, IndicesVector>, LoadInst *> OriginalLoads; + + // Attribute - Keep track of the parameter attributes for the arguments + // that we are *not* promoting. For the ones that we do promote, the parameter + // attributes are lost + SmallVector<AttributeSet, 8> ArgAttrVec; + AttributeList PAL = F->getAttributes(); + + // First, determine the new argument list + unsigned ArgNo = 0; + for (Function::arg_iterator I = F->arg_begin(), E = F->arg_end(); I != E; + ++I, ++ArgNo) { + if (ByValArgsToTransform.count(&*I)) { + // Simple byval argument? Just add all the struct element types. + Type *AgTy = cast<PointerType>(I->getType())->getElementType(); + StructType *STy = cast<StructType>(AgTy); + llvm::append_range(Params, STy->elements()); + ArgAttrVec.insert(ArgAttrVec.end(), STy->getNumElements(), + AttributeSet()); + ++NumByValArgsPromoted; + } else if (!ArgsToPromote.count(&*I)) { + // Unchanged argument + Params.push_back(I->getType()); + ArgAttrVec.push_back(PAL.getParamAttributes(ArgNo)); + } else if (I->use_empty()) { + // Dead argument (which are always marked as promotable) + ++NumArgumentsDead; + } else { + // Okay, this is being promoted. This means that the only uses are loads + // or GEPs which are only used by loads + + // In this table, we will track which indices are loaded from the argument + // (where direct loads are tracked as no indices). + ScalarizeTable &ArgIndices = ScalarizedElements[&*I]; + for (User *U : make_early_inc_range(I->users())) { + Instruction *UI = cast<Instruction>(U); + Type *SrcTy; + if (LoadInst *L = dyn_cast<LoadInst>(UI)) + SrcTy = L->getType(); + else + SrcTy = cast<GetElementPtrInst>(UI)->getSourceElementType(); + // Skip dead GEPs and remove them. + if (isa<GetElementPtrInst>(UI) && UI->use_empty()) { + UI->eraseFromParent(); + continue; + } + + IndicesVector Indices; + Indices.reserve(UI->getNumOperands() - 1); + // Since loads will only have a single operand, and GEPs only a single + // non-index operand, this will record direct loads without any indices, + // and gep+loads with the GEP indices. + for (User::op_iterator II = UI->op_begin() + 1, IE = UI->op_end(); + II != IE; ++II) + Indices.push_back(cast<ConstantInt>(*II)->getSExtValue()); + // GEPs with a single 0 index can be merged with direct loads + if (Indices.size() == 1 && Indices.front() == 0) + Indices.clear(); + ArgIndices.insert(std::make_pair(SrcTy, Indices)); + LoadInst *OrigLoad; + if (LoadInst *L = dyn_cast<LoadInst>(UI)) + OrigLoad = L; + else + // Take any load, we will use it only to update Alias Analysis + OrigLoad = cast<LoadInst>(UI->user_back()); + OriginalLoads[std::make_pair(&*I, Indices)] = OrigLoad; + } + + // Add a parameter to the function for each element passed in. + for (const auto &ArgIndex : ArgIndices) { + // not allowed to dereference ->begin() if size() is 0 + Params.push_back(GetElementPtrInst::getIndexedType( + cast<PointerType>(I->getType())->getElementType(), + ArgIndex.second)); + ArgAttrVec.push_back(AttributeSet()); + assert(Params.back()); + } + + if (ArgIndices.size() == 1 && ArgIndices.begin()->second.empty()) + ++NumArgumentsPromoted; + else + ++NumAggregatesPromoted; + } + } + + Type *RetTy = FTy->getReturnType(); + + // Construct the new function type using the new arguments. + FunctionType *NFTy = FunctionType::get(RetTy, Params, FTy->isVarArg()); + + // Create the new function body and insert it into the module. + Function *NF = Function::Create(NFTy, F->getLinkage(), F->getAddressSpace(), + F->getName()); + NF->copyAttributesFrom(F); + NF->copyMetadata(F, 0); + + // The new function will have the !dbg metadata copied from the original + // function. The original function may not be deleted, and dbg metadata need + // to be unique so we need to drop it. + F->setSubprogram(nullptr); + + LLVM_DEBUG(dbgs() << "ARG PROMOTION: Promoting to:" << *NF << "\n" + << "From: " << *F); + + // Recompute the parameter attributes list based on the new arguments for + // the function. + NF->setAttributes(AttributeList::get(F->getContext(), PAL.getFnAttributes(), + PAL.getRetAttributes(), ArgAttrVec)); + ArgAttrVec.clear(); + + F->getParent()->getFunctionList().insert(F->getIterator(), NF); + NF->takeName(F); + + // Loop over all of the callers of the function, transforming the call sites + // to pass in the loaded pointers. + // + SmallVector<Value *, 16> Args; + while (!F->use_empty()) { + CallBase &CB = cast<CallBase>(*F->user_back()); + assert(CB.getCalledFunction() == F); + const AttributeList &CallPAL = CB.getAttributes(); + IRBuilder<NoFolder> IRB(&CB); + + // Loop over the operands, inserting GEP and loads in the caller as + // appropriate. + auto AI = CB.arg_begin(); + ArgNo = 0; + for (Function::arg_iterator I = F->arg_begin(), E = F->arg_end(); I != E; + ++I, ++AI, ++ArgNo) + if (!ArgsToPromote.count(&*I) && !ByValArgsToTransform.count(&*I)) { + Args.push_back(*AI); // Unmodified argument + ArgAttrVec.push_back(CallPAL.getParamAttributes(ArgNo)); + } else if (ByValArgsToTransform.count(&*I)) { + // Emit a GEP and load for each element of the struct. + Type *AgTy = cast<PointerType>(I->getType())->getElementType(); + StructType *STy = cast<StructType>(AgTy); + Value *Idxs[2] = { + ConstantInt::get(Type::getInt32Ty(F->getContext()), 0), nullptr}; + for (unsigned i = 0, e = STy->getNumElements(); i != e; ++i) { + Idxs[1] = ConstantInt::get(Type::getInt32Ty(F->getContext()), i); + auto *Idx = + IRB.CreateGEP(STy, *AI, Idxs, (*AI)->getName() + "." + Twine(i)); + // TODO: Tell AA about the new values? + Args.push_back(IRB.CreateLoad(STy->getElementType(i), Idx, + Idx->getName() + ".val")); + ArgAttrVec.push_back(AttributeSet()); + } + } else if (!I->use_empty()) { + // Non-dead argument: insert GEPs and loads as appropriate. + ScalarizeTable &ArgIndices = ScalarizedElements[&*I]; + // Store the Value* version of the indices in here, but declare it now + // for reuse. + std::vector<Value *> Ops; + for (const auto &ArgIndex : ArgIndices) { + Value *V = *AI; + LoadInst *OrigLoad = + OriginalLoads[std::make_pair(&*I, ArgIndex.second)]; + if (!ArgIndex.second.empty()) { + Ops.reserve(ArgIndex.second.size()); + Type *ElTy = V->getType(); + for (auto II : ArgIndex.second) { + // Use i32 to index structs, and i64 for others (pointers/arrays). + // This satisfies GEP constraints. + Type *IdxTy = + (ElTy->isStructTy() ? Type::getInt32Ty(F->getContext()) + : Type::getInt64Ty(F->getContext())); + Ops.push_back(ConstantInt::get(IdxTy, II)); + // Keep track of the type we're currently indexing. + if (auto *ElPTy = dyn_cast<PointerType>(ElTy)) + ElTy = ElPTy->getElementType(); + else + ElTy = GetElementPtrInst::getTypeAtIndex(ElTy, II); + } + // And create a GEP to extract those indices. + V = IRB.CreateGEP(ArgIndex.first, V, Ops, V->getName() + ".idx"); + Ops.clear(); + } + // Since we're replacing a load make sure we take the alignment + // of the previous load. + LoadInst *newLoad = + IRB.CreateLoad(OrigLoad->getType(), V, V->getName() + ".val"); + newLoad->setAlignment(OrigLoad->getAlign()); + // Transfer the AA info too. + AAMDNodes AAInfo; + OrigLoad->getAAMetadata(AAInfo); + newLoad->setAAMetadata(AAInfo); + + Args.push_back(newLoad); + ArgAttrVec.push_back(AttributeSet()); + } + } + + // Push any varargs arguments on the list. + for (; AI != CB.arg_end(); ++AI, ++ArgNo) { + Args.push_back(*AI); + ArgAttrVec.push_back(CallPAL.getParamAttributes(ArgNo)); + } + + SmallVector<OperandBundleDef, 1> OpBundles; + CB.getOperandBundlesAsDefs(OpBundles); + + CallBase *NewCS = nullptr; + if (InvokeInst *II = dyn_cast<InvokeInst>(&CB)) { + NewCS = InvokeInst::Create(NF, II->getNormalDest(), II->getUnwindDest(), + Args, OpBundles, "", &CB); + } else { + auto *NewCall = CallInst::Create(NF, Args, OpBundles, "", &CB); + NewCall->setTailCallKind(cast<CallInst>(&CB)->getTailCallKind()); + NewCS = NewCall; + } + NewCS->setCallingConv(CB.getCallingConv()); + NewCS->setAttributes( + AttributeList::get(F->getContext(), CallPAL.getFnAttributes(), + CallPAL.getRetAttributes(), ArgAttrVec)); + NewCS->copyMetadata(CB, {LLVMContext::MD_prof, LLVMContext::MD_dbg}); + Args.clear(); + ArgAttrVec.clear(); + + // Update the callgraph to know that the callsite has been transformed. + if (ReplaceCallSite) + (*ReplaceCallSite)(CB, *NewCS); + + if (!CB.use_empty()) { + CB.replaceAllUsesWith(NewCS); + NewCS->takeName(&CB); + } + + // Finally, remove the old call from the program, reducing the use-count of + // F. + CB.eraseFromParent(); + } + + const DataLayout &DL = F->getParent()->getDataLayout(); + + // Since we have now created the new function, splice the body of the old + // function right into the new function, leaving the old rotting hulk of the + // function empty. + NF->getBasicBlockList().splice(NF->begin(), F->getBasicBlockList()); + + // Loop over the argument list, transferring uses of the old arguments over to + // the new arguments, also transferring over the names as well. + for (Function::arg_iterator I = F->arg_begin(), E = F->arg_end(), + I2 = NF->arg_begin(); + I != E; ++I) { + if (!ArgsToPromote.count(&*I) && !ByValArgsToTransform.count(&*I)) { + // If this is an unmodified argument, move the name and users over to the + // new version. + I->replaceAllUsesWith(&*I2); + I2->takeName(&*I); + ++I2; + continue; + } + + if (ByValArgsToTransform.count(&*I)) { + // In the callee, we create an alloca, and store each of the new incoming + // arguments into the alloca. + Instruction *InsertPt = &NF->begin()->front(); + + // Just add all the struct element types. + Type *AgTy = cast<PointerType>(I->getType())->getElementType(); + Value *TheAlloca = new AllocaInst( + AgTy, DL.getAllocaAddrSpace(), nullptr, + I->getParamAlign().getValueOr(DL.getPrefTypeAlign(AgTy)), "", + InsertPt); + StructType *STy = cast<StructType>(AgTy); + Value *Idxs[2] = {ConstantInt::get(Type::getInt32Ty(F->getContext()), 0), + nullptr}; + + for (unsigned i = 0, e = STy->getNumElements(); i != e; ++i) { + Idxs[1] = ConstantInt::get(Type::getInt32Ty(F->getContext()), i); + Value *Idx = GetElementPtrInst::Create( + AgTy, TheAlloca, Idxs, TheAlloca->getName() + "." + Twine(i), + InsertPt); + I2->setName(I->getName() + "." + Twine(i)); + new StoreInst(&*I2++, Idx, InsertPt); + } + + // Anything that used the arg should now use the alloca. + I->replaceAllUsesWith(TheAlloca); + TheAlloca->takeName(&*I); + + // If the alloca is used in a call, we must clear the tail flag since + // the callee now uses an alloca from the caller. + for (User *U : TheAlloca->users()) { + CallInst *Call = dyn_cast<CallInst>(U); + if (!Call) + continue; + Call->setTailCall(false); + } + continue; + } + + // There potentially are metadata uses for things like llvm.dbg.value. + // Replace them with undef, after handling the other regular uses. + auto RauwUndefMetadata = make_scope_exit( + [&]() { I->replaceAllUsesWith(UndefValue::get(I->getType())); }); + + if (I->use_empty()) + continue; + + // Otherwise, if we promoted this argument, then all users are load + // instructions (or GEPs with only load users), and all loads should be + // using the new argument that we added. + ScalarizeTable &ArgIndices = ScalarizedElements[&*I]; + + while (!I->use_empty()) { + if (LoadInst *LI = dyn_cast<LoadInst>(I->user_back())) { + assert(ArgIndices.begin()->second.empty() && + "Load element should sort to front!"); + I2->setName(I->getName() + ".val"); + LI->replaceAllUsesWith(&*I2); + LI->eraseFromParent(); + LLVM_DEBUG(dbgs() << "*** Promoted load of argument '" << I->getName() + << "' in function '" << F->getName() << "'\n"); + } else { + GetElementPtrInst *GEP = cast<GetElementPtrInst>(I->user_back()); + assert(!GEP->use_empty() && + "GEPs without uses should be cleaned up already"); + IndicesVector Operands; + Operands.reserve(GEP->getNumIndices()); + for (User::op_iterator II = GEP->idx_begin(), IE = GEP->idx_end(); + II != IE; ++II) + Operands.push_back(cast<ConstantInt>(*II)->getSExtValue()); + + // GEPs with a single 0 index can be merged with direct loads + if (Operands.size() == 1 && Operands.front() == 0) + Operands.clear(); + + Function::arg_iterator TheArg = I2; + for (ScalarizeTable::iterator It = ArgIndices.begin(); + It->second != Operands; ++It, ++TheArg) { + assert(It != ArgIndices.end() && "GEP not handled??"); + } + + TheArg->setName(formatv("{0}.{1:$[.]}.val", I->getName(), + make_range(Operands.begin(), Operands.end()))); + + LLVM_DEBUG(dbgs() << "*** Promoted agg argument '" << TheArg->getName() + << "' of function '" << NF->getName() << "'\n"); + + // All of the uses must be load instructions. Replace them all with + // the argument specified by ArgNo. + while (!GEP->use_empty()) { + LoadInst *L = cast<LoadInst>(GEP->user_back()); + L->replaceAllUsesWith(&*TheArg); + L->eraseFromParent(); + } + GEP->eraseFromParent(); + } + } + // Increment I2 past all of the arguments added for this promoted pointer. + std::advance(I2, ArgIndices.size()); + } + + return NF; +} + +/// Return true if we can prove that all callees pass in a valid pointer for the +/// specified function argument. +static bool allCallersPassValidPointerForArgument(Argument *Arg, Type *Ty) { + Function *Callee = Arg->getParent(); + const DataLayout &DL = Callee->getParent()->getDataLayout(); + + unsigned ArgNo = Arg->getArgNo(); + + // Look at all call sites of the function. At this point we know we only have + // direct callees. + for (User *U : Callee->users()) { + CallBase &CB = cast<CallBase>(*U); + + if (!isDereferenceablePointer(CB.getArgOperand(ArgNo), Ty, DL)) + return false; + } + return true; +} + +/// Returns true if Prefix is a prefix of longer. That means, Longer has a size +/// that is greater than or equal to the size of prefix, and each of the +/// elements in Prefix is the same as the corresponding elements in Longer. +/// +/// This means it also returns true when Prefix and Longer are equal! +static bool isPrefix(const IndicesVector &Prefix, const IndicesVector &Longer) { + if (Prefix.size() > Longer.size()) + return false; + return std::equal(Prefix.begin(), Prefix.end(), Longer.begin()); +} + +/// Checks if Indices, or a prefix of Indices, is in Set. +static bool prefixIn(const IndicesVector &Indices, + std::set<IndicesVector> &Set) { + std::set<IndicesVector>::iterator Low; + Low = Set.upper_bound(Indices); + if (Low != Set.begin()) + Low--; + // Low is now the last element smaller than or equal to Indices. This means + // it points to a prefix of Indices (possibly Indices itself), if such + // prefix exists. + // + // This load is safe if any prefix of its operands is safe to load. + return Low != Set.end() && isPrefix(*Low, Indices); +} + +/// Mark the given indices (ToMark) as safe in the given set of indices +/// (Safe). Marking safe usually means adding ToMark to Safe. However, if there +/// is already a prefix of Indices in Safe, Indices are implicitely marked safe +/// already. Furthermore, any indices that Indices is itself a prefix of, are +/// removed from Safe (since they are implicitely safe because of Indices now). +static void markIndicesSafe(const IndicesVector &ToMark, + std::set<IndicesVector> &Safe) { + std::set<IndicesVector>::iterator Low; + Low = Safe.upper_bound(ToMark); + // Guard against the case where Safe is empty + if (Low != Safe.begin()) + Low--; + // Low is now the last element smaller than or equal to Indices. This + // means it points to a prefix of Indices (possibly Indices itself), if + // such prefix exists. + if (Low != Safe.end()) { + if (isPrefix(*Low, ToMark)) + // If there is already a prefix of these indices (or exactly these + // indices) marked a safe, don't bother adding these indices + return; + + // Increment Low, so we can use it as a "insert before" hint + ++Low; + } + // Insert + Low = Safe.insert(Low, ToMark); + ++Low; + // If there we're a prefix of longer index list(s), remove those + std::set<IndicesVector>::iterator End = Safe.end(); + while (Low != End && isPrefix(ToMark, *Low)) { + std::set<IndicesVector>::iterator Remove = Low; + ++Low; + Safe.erase(Remove); + } +} + +/// isSafeToPromoteArgument - As you might guess from the name of this method, +/// it checks to see if it is both safe and useful to promote the argument. +/// This method limits promotion of aggregates to only promote up to three +/// elements of the aggregate in order to avoid exploding the number of +/// arguments passed in. +static bool isSafeToPromoteArgument(Argument *Arg, Type *ByValTy, AAResults &AAR, + unsigned MaxElements) { + using GEPIndicesSet = std::set<IndicesVector>; + + // Quick exit for unused arguments + if (Arg->use_empty()) + return true; + + // We can only promote this argument if all of the uses are loads, or are GEP + // instructions (with constant indices) that are subsequently loaded. + // + // Promoting the argument causes it to be loaded in the caller + // unconditionally. This is only safe if we can prove that either the load + // would have happened in the callee anyway (ie, there is a load in the entry + // block) or the pointer passed in at every call site is guaranteed to be + // valid. + // In the former case, invalid loads can happen, but would have happened + // anyway, in the latter case, invalid loads won't happen. This prevents us + // from introducing an invalid load that wouldn't have happened in the + // original code. + // + // This set will contain all sets of indices that are loaded in the entry + // block, and thus are safe to unconditionally load in the caller. + GEPIndicesSet SafeToUnconditionallyLoad; + + // This set contains all the sets of indices that we are planning to promote. + // This makes it possible to limit the number of arguments added. + GEPIndicesSet ToPromote; + + // If the pointer is always valid, any load with first index 0 is valid. + + if (ByValTy) + SafeToUnconditionallyLoad.insert(IndicesVector(1, 0)); + + // Whenever a new underlying type for the operand is found, make sure it's + // consistent with the GEPs and loads we've already seen and, if necessary, + // use it to see if all incoming pointers are valid (which implies the 0-index + // is safe). + Type *BaseTy = ByValTy; + auto UpdateBaseTy = [&](Type *NewBaseTy) { + if (BaseTy) + return BaseTy == NewBaseTy; + + BaseTy = NewBaseTy; + if (allCallersPassValidPointerForArgument(Arg, BaseTy)) { + assert(SafeToUnconditionallyLoad.empty()); + SafeToUnconditionallyLoad.insert(IndicesVector(1, 0)); + } + + return true; + }; + + // First, iterate the entry block and mark loads of (geps of) arguments as + // safe. + BasicBlock &EntryBlock = Arg->getParent()->front(); + // Declare this here so we can reuse it + IndicesVector Indices; + for (Instruction &I : EntryBlock) + if (LoadInst *LI = dyn_cast<LoadInst>(&I)) { + Value *V = LI->getPointerOperand(); + if (GetElementPtrInst *GEP = dyn_cast<GetElementPtrInst>(V)) { + V = GEP->getPointerOperand(); + if (V == Arg) { + // This load actually loads (part of) Arg? Check the indices then. + Indices.reserve(GEP->getNumIndices()); + for (User::op_iterator II = GEP->idx_begin(), IE = GEP->idx_end(); + II != IE; ++II) + if (ConstantInt *CI = dyn_cast<ConstantInt>(*II)) + Indices.push_back(CI->getSExtValue()); + else + // We found a non-constant GEP index for this argument? Bail out + // right away, can't promote this argument at all. + return false; + + if (!UpdateBaseTy(GEP->getSourceElementType())) + return false; + + // Indices checked out, mark them as safe + markIndicesSafe(Indices, SafeToUnconditionallyLoad); + Indices.clear(); + } + } else if (V == Arg) { + // Direct loads are equivalent to a GEP with a single 0 index. + markIndicesSafe(IndicesVector(1, 0), SafeToUnconditionallyLoad); + + if (BaseTy && LI->getType() != BaseTy) + return false; + + BaseTy = LI->getType(); + } + } + + // Now, iterate all uses of the argument to see if there are any uses that are + // not (GEP+)loads, or any (GEP+)loads that are not safe to promote. + SmallVector<LoadInst *, 16> Loads; + IndicesVector Operands; + for (Use &U : Arg->uses()) { + User *UR = U.getUser(); + Operands.clear(); + if (LoadInst *LI = dyn_cast<LoadInst>(UR)) { + // Don't hack volatile/atomic loads + if (!LI->isSimple()) + return false; + Loads.push_back(LI); + // Direct loads are equivalent to a GEP with a zero index and then a load. + Operands.push_back(0); + + if (!UpdateBaseTy(LI->getType())) + return false; + } else if (GetElementPtrInst *GEP = dyn_cast<GetElementPtrInst>(UR)) { + if (GEP->use_empty()) { + // Dead GEP's cause trouble later. Just remove them if we run into + // them. + continue; + } + + if (!UpdateBaseTy(GEP->getSourceElementType())) + return false; + + // Ensure that all of the indices are constants. + for (User::op_iterator i = GEP->idx_begin(), e = GEP->idx_end(); i != e; + ++i) + if (ConstantInt *C = dyn_cast<ConstantInt>(*i)) + Operands.push_back(C->getSExtValue()); + else + return false; // Not a constant operand GEP! + + // Ensure that the only users of the GEP are load instructions. + for (User *GEPU : GEP->users()) + if (LoadInst *LI = dyn_cast<LoadInst>(GEPU)) { + // Don't hack volatile/atomic loads + if (!LI->isSimple()) + return false; + Loads.push_back(LI); + } else { + // Other uses than load? + return false; + } + } else { + return false; // Not a load or a GEP. + } + + // Now, see if it is safe to promote this load / loads of this GEP. Loading + // is safe if Operands, or a prefix of Operands, is marked as safe. + if (!prefixIn(Operands, SafeToUnconditionallyLoad)) + return false; + + // See if we are already promoting a load with these indices. If not, check + // to make sure that we aren't promoting too many elements. If so, nothing + // to do. + if (ToPromote.find(Operands) == ToPromote.end()) { + if (MaxElements > 0 && ToPromote.size() == MaxElements) { + LLVM_DEBUG(dbgs() << "argpromotion not promoting argument '" + << Arg->getName() + << "' because it would require adding more " + << "than " << MaxElements + << " arguments to the function.\n"); + // We limit aggregate promotion to only promoting up to a fixed number + // of elements of the aggregate. + return false; + } + ToPromote.insert(std::move(Operands)); + } + } + + if (Loads.empty()) + return true; // No users, this is a dead argument. + + // Okay, now we know that the argument is only used by load instructions and + // it is safe to unconditionally perform all of them. Use alias analysis to + // check to see if the pointer is guaranteed to not be modified from entry of + // the function to each of the load instructions. + + // Because there could be several/many load instructions, remember which + // blocks we know to be transparent to the load. + df_iterator_default_set<BasicBlock *, 16> TranspBlocks; + + for (LoadInst *Load : Loads) { + // Check to see if the load is invalidated from the start of the block to + // the load itself. + BasicBlock *BB = Load->getParent(); + + MemoryLocation Loc = MemoryLocation::get(Load); + if (AAR.canInstructionRangeModRef(BB->front(), *Load, Loc, ModRefInfo::Mod)) + return false; // Pointer is invalidated! + + // Now check every path from the entry block to the load for transparency. + // To do this, we perform a depth first search on the inverse CFG from the + // loading block. + for (BasicBlock *P : predecessors(BB)) { + for (BasicBlock *TranspBB : inverse_depth_first_ext(P, TranspBlocks)) + if (AAR.canBasicBlockModify(*TranspBB, Loc)) + return false; + } + } + + // If the path from the entry of the function to each load is free of + // instructions that potentially invalidate the load, we can make the + // transformation! + return true; +} + +bool ArgumentPromotionPass::isDenselyPacked(Type *type, const DataLayout &DL) { + // There is no size information, so be conservative. + if (!type->isSized()) + return false; + + // If the alloc size is not equal to the storage size, then there are padding + // bytes. For x86_fp80 on x86-64, size: 80 alloc size: 128. + if (DL.getTypeSizeInBits(type) != DL.getTypeAllocSizeInBits(type)) + return false; + + // FIXME: This isn't the right way to check for padding in vectors with + // non-byte-size elements. + if (VectorType *seqTy = dyn_cast<VectorType>(type)) + return isDenselyPacked(seqTy->getElementType(), DL); + + // For array types, check for padding within members. + if (ArrayType *seqTy = dyn_cast<ArrayType>(type)) + return isDenselyPacked(seqTy->getElementType(), DL); + + if (!isa<StructType>(type)) + return true; + + // Check for padding within and between elements of a struct. + StructType *StructTy = cast<StructType>(type); + const StructLayout *Layout = DL.getStructLayout(StructTy); + uint64_t StartPos = 0; + for (unsigned i = 0, E = StructTy->getNumElements(); i < E; ++i) { + Type *ElTy = StructTy->getElementType(i); + if (!isDenselyPacked(ElTy, DL)) + return false; + if (StartPos != Layout->getElementOffsetInBits(i)) + return false; + StartPos += DL.getTypeAllocSizeInBits(ElTy); + } + + return true; +} + +/// Checks if the padding bytes of an argument could be accessed. +static bool canPaddingBeAccessed(Argument *arg) { + assert(arg->hasByValAttr()); + + // Track all the pointers to the argument to make sure they are not captured. + SmallPtrSet<Value *, 16> PtrValues; + PtrValues.insert(arg); + + // Track all of the stores. + SmallVector<StoreInst *, 16> Stores; + + // Scan through the uses recursively to make sure the pointer is always used + // sanely. + SmallVector<Value *, 16> WorkList(arg->users()); + while (!WorkList.empty()) { + Value *V = WorkList.pop_back_val(); + if (isa<GetElementPtrInst>(V) || isa<PHINode>(V)) { + if (PtrValues.insert(V).second) + llvm::append_range(WorkList, V->users()); + } else if (StoreInst *Store = dyn_cast<StoreInst>(V)) { + Stores.push_back(Store); + } else if (!isa<LoadInst>(V)) { + return true; + } + } + + // Check to make sure the pointers aren't captured + for (StoreInst *Store : Stores) + if (PtrValues.count(Store->getValueOperand())) + return true; + + return false; +} + +bool ArgumentPromotionPass::areFunctionArgsABICompatible( + const Function &F, const TargetTransformInfo &TTI, + SmallPtrSetImpl<Argument *> &ArgsToPromote, + SmallPtrSetImpl<Argument *> &ByValArgsToTransform) { + for (const Use &U : F.uses()) { + CallBase *CB = dyn_cast<CallBase>(U.getUser()); + if (!CB) + return false; + const Function *Caller = CB->getCaller(); + const Function *Callee = CB->getCalledFunction(); + if (!TTI.areFunctionArgsABICompatible(Caller, Callee, ArgsToPromote) || + !TTI.areFunctionArgsABICompatible(Caller, Callee, ByValArgsToTransform)) + return false; + } + return true; +} + +/// PromoteArguments - This method checks the specified function to see if there +/// are any promotable arguments and if it is safe to promote the function (for +/// example, all callers are direct). If safe to promote some arguments, it +/// calls the DoPromotion method. +static Function * +promoteArguments(Function *F, function_ref<AAResults &(Function &F)> AARGetter, + unsigned MaxElements, + Optional<function_ref<void(CallBase &OldCS, CallBase &NewCS)>> + ReplaceCallSite, + const TargetTransformInfo &TTI) { + // Don't perform argument promotion for naked functions; otherwise we can end + // up removing parameters that are seemingly 'not used' as they are referred + // to in the assembly. + if(F->hasFnAttribute(Attribute::Naked)) + return nullptr; + + // Make sure that it is local to this module. + if (!F->hasLocalLinkage()) + return nullptr; + + // Don't promote arguments for variadic functions. Adding, removing, or + // changing non-pack parameters can change the classification of pack + // parameters. Frontends encode that classification at the call site in the + // IR, while in the callee the classification is determined dynamically based + // on the number of registers consumed so far. + if (F->isVarArg()) + return nullptr; + + // Don't transform functions that receive inallocas, as the transformation may + // not be safe depending on calling convention. + if (F->getAttributes().hasAttrSomewhere(Attribute::InAlloca)) + return nullptr; + + // First check: see if there are any pointer arguments! If not, quick exit. + SmallVector<Argument *, 16> PointerArgs; + for (Argument &I : F->args()) + if (I.getType()->isPointerTy()) + PointerArgs.push_back(&I); + if (PointerArgs.empty()) + return nullptr; + + // Second check: make sure that all callers are direct callers. We can't + // transform functions that have indirect callers. Also see if the function + // is self-recursive and check that target features are compatible. + bool isSelfRecursive = false; + for (Use &U : F->uses()) { + CallBase *CB = dyn_cast<CallBase>(U.getUser()); + // Must be a direct call. + if (CB == nullptr || !CB->isCallee(&U)) + return nullptr; + + // Can't change signature of musttail callee + if (CB->isMustTailCall()) + return nullptr; + + if (CB->getParent()->getParent() == F) + isSelfRecursive = true; + } + + // Can't change signature of musttail caller + // FIXME: Support promoting whole chain of musttail functions + for (BasicBlock &BB : *F) + if (BB.getTerminatingMustTailCall()) + return nullptr; + + const DataLayout &DL = F->getParent()->getDataLayout(); + + AAResults &AAR = AARGetter(*F); + + // Check to see which arguments are promotable. If an argument is promotable, + // add it to ArgsToPromote. + SmallPtrSet<Argument *, 8> ArgsToPromote; + SmallPtrSet<Argument *, 8> ByValArgsToTransform; + for (Argument *PtrArg : PointerArgs) { + Type *AgTy = cast<PointerType>(PtrArg->getType())->getElementType(); + + // Replace sret attribute with noalias. This reduces register pressure by + // avoiding a register copy. + if (PtrArg->hasStructRetAttr()) { + unsigned ArgNo = PtrArg->getArgNo(); + F->removeParamAttr(ArgNo, Attribute::StructRet); + F->addParamAttr(ArgNo, Attribute::NoAlias); + for (Use &U : F->uses()) { + CallBase &CB = cast<CallBase>(*U.getUser()); + CB.removeParamAttr(ArgNo, Attribute::StructRet); + CB.addParamAttr(ArgNo, Attribute::NoAlias); + } + } + + // If this is a byval argument, and if the aggregate type is small, just + // pass the elements, which is always safe, if the passed value is densely + // packed or if we can prove the padding bytes are never accessed. + bool isSafeToPromote = PtrArg->hasByValAttr() && + (ArgumentPromotionPass::isDenselyPacked(AgTy, DL) || + !canPaddingBeAccessed(PtrArg)); + if (isSafeToPromote) { + if (StructType *STy = dyn_cast<StructType>(AgTy)) { + if (MaxElements > 0 && STy->getNumElements() > MaxElements) { + LLVM_DEBUG(dbgs() << "argpromotion disable promoting argument '" + << PtrArg->getName() + << "' because it would require adding more" + << " than " << MaxElements + << " arguments to the function.\n"); + continue; + } + + // If all the elements are single-value types, we can promote it. + bool AllSimple = true; + for (const auto *EltTy : STy->elements()) { + if (!EltTy->isSingleValueType()) { + AllSimple = false; + break; + } + } + + // Safe to transform, don't even bother trying to "promote" it. + // Passing the elements as a scalar will allow sroa to hack on + // the new alloca we introduce. + if (AllSimple) { + ByValArgsToTransform.insert(PtrArg); + continue; + } + } + } + + // If the argument is a recursive type and we're in a recursive + // function, we could end up infinitely peeling the function argument. + if (isSelfRecursive) { + if (StructType *STy = dyn_cast<StructType>(AgTy)) { + bool RecursiveType = false; + for (const auto *EltTy : STy->elements()) { + if (EltTy == PtrArg->getType()) { + RecursiveType = true; + break; + } + } + if (RecursiveType) + continue; + } + } + + // Otherwise, see if we can promote the pointer to its value. + Type *ByValTy = + PtrArg->hasByValAttr() ? PtrArg->getParamByValType() : nullptr; + if (isSafeToPromoteArgument(PtrArg, ByValTy, AAR, MaxElements)) + ArgsToPromote.insert(PtrArg); + } + + // No promotable pointer arguments. + if (ArgsToPromote.empty() && ByValArgsToTransform.empty()) + return nullptr; + + if (!ArgumentPromotionPass::areFunctionArgsABICompatible( + *F, TTI, ArgsToPromote, ByValArgsToTransform)) + return nullptr; + + return doPromotion(F, ArgsToPromote, ByValArgsToTransform, ReplaceCallSite); +} + +PreservedAnalyses ArgumentPromotionPass::run(LazyCallGraph::SCC &C, + CGSCCAnalysisManager &AM, + LazyCallGraph &CG, + CGSCCUpdateResult &UR) { + bool Changed = false, LocalChange; + + // Iterate until we stop promoting from this SCC. + do { + LocalChange = false; + + for (LazyCallGraph::Node &N : C) { + Function &OldF = N.getFunction(); + + FunctionAnalysisManager &FAM = + AM.getResult<FunctionAnalysisManagerCGSCCProxy>(C, CG).getManager(); + // FIXME: This lambda must only be used with this function. We should + // skip the lambda and just get the AA results directly. + auto AARGetter = [&](Function &F) -> AAResults & { + assert(&F == &OldF && "Called with an unexpected function!"); + return FAM.getResult<AAManager>(F); + }; + + const TargetTransformInfo &TTI = FAM.getResult<TargetIRAnalysis>(OldF); + Function *NewF = + promoteArguments(&OldF, AARGetter, MaxElements, None, TTI); + if (!NewF) + continue; + LocalChange = true; + + // Directly substitute the functions in the call graph. Note that this + // requires the old function to be completely dead and completely + // replaced by the new function. It does no call graph updates, it merely + // swaps out the particular function mapped to a particular node in the + // graph. + C.getOuterRefSCC().replaceNodeFunction(N, *NewF); + OldF.eraseFromParent(); + } + + Changed |= LocalChange; + } while (LocalChange); + + if (!Changed) + return PreservedAnalyses::all(); + + return PreservedAnalyses::none(); +} + +namespace { + +/// ArgPromotion - The 'by reference' to 'by value' argument promotion pass. +struct ArgPromotion : public CallGraphSCCPass { + // Pass identification, replacement for typeid + static char ID; + + explicit ArgPromotion(unsigned MaxElements = 3) + : CallGraphSCCPass(ID), MaxElements(MaxElements) { + initializeArgPromotionPass(*PassRegistry::getPassRegistry()); + } + + void getAnalysisUsage(AnalysisUsage &AU) const override { + AU.addRequired<AssumptionCacheTracker>(); + AU.addRequired<TargetLibraryInfoWrapperPass>(); + AU.addRequired<TargetTransformInfoWrapperPass>(); + getAAResultsAnalysisUsage(AU); + CallGraphSCCPass::getAnalysisUsage(AU); + } + + bool runOnSCC(CallGraphSCC &SCC) override; + +private: + using llvm::Pass::doInitialization; + + bool doInitialization(CallGraph &CG) override; + + /// The maximum number of elements to expand, or 0 for unlimited. + unsigned MaxElements; +}; + +} // end anonymous namespace + +char ArgPromotion::ID = 0; + +INITIALIZE_PASS_BEGIN(ArgPromotion, "argpromotion", + "Promote 'by reference' arguments to scalars", false, + false) +INITIALIZE_PASS_DEPENDENCY(AssumptionCacheTracker) +INITIALIZE_PASS_DEPENDENCY(CallGraphWrapperPass) +INITIALIZE_PASS_DEPENDENCY(TargetLibraryInfoWrapperPass) +INITIALIZE_PASS_DEPENDENCY(TargetTransformInfoWrapperPass) +INITIALIZE_PASS_END(ArgPromotion, "argpromotion", + "Promote 'by reference' arguments to scalars", false, false) + +Pass *llvm::createArgumentPromotionPass(unsigned MaxElements) { + return new ArgPromotion(MaxElements); +} + +bool ArgPromotion::runOnSCC(CallGraphSCC &SCC) { + if (skipSCC(SCC)) + return false; + + // Get the callgraph information that we need to update to reflect our + // changes. + CallGraph &CG = getAnalysis<CallGraphWrapperPass>().getCallGraph(); + + LegacyAARGetter AARGetter(*this); + + bool Changed = false, LocalChange; + + // Iterate until we stop promoting from this SCC. + do { + LocalChange = false; + // Attempt to promote arguments from all functions in this SCC. + for (CallGraphNode *OldNode : SCC) { + Function *OldF = OldNode->getFunction(); + if (!OldF) + continue; + + auto ReplaceCallSite = [&](CallBase &OldCS, CallBase &NewCS) { + Function *Caller = OldCS.getParent()->getParent(); + CallGraphNode *NewCalleeNode = + CG.getOrInsertFunction(NewCS.getCalledFunction()); + CallGraphNode *CallerNode = CG[Caller]; + CallerNode->replaceCallEdge(cast<CallBase>(OldCS), + cast<CallBase>(NewCS), NewCalleeNode); + }; + + const TargetTransformInfo &TTI = + getAnalysis<TargetTransformInfoWrapperPass>().getTTI(*OldF); + if (Function *NewF = promoteArguments(OldF, AARGetter, MaxElements, + {ReplaceCallSite}, TTI)) { + LocalChange = true; + + // Update the call graph for the newly promoted function. + CallGraphNode *NewNode = CG.getOrInsertFunction(NewF); + NewNode->stealCalledFunctionsFrom(OldNode); + if (OldNode->getNumReferences() == 0) + delete CG.removeFunctionFromModule(OldNode); + else + OldF->setLinkage(Function::ExternalLinkage); + + // And updat ethe SCC we're iterating as well. + SCC.ReplaceNode(OldNode, NewNode); + } + } + // Remember that we changed something. + Changed |= LocalChange; + } while (LocalChange); + + return Changed; +} + +bool ArgPromotion::doInitialization(CallGraph &CG) { + return CallGraphSCCPass::doInitialization(CG); +} |