summaryrefslogtreecommitdiff
path: root/llvm/lib/Transforms/IPO/PartialInlining.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'llvm/lib/Transforms/IPO/PartialInlining.cpp')
-rw-r--r--llvm/lib/Transforms/IPO/PartialInlining.cpp165
1 files changed, 82 insertions, 83 deletions
diff --git a/llvm/lib/Transforms/IPO/PartialInlining.cpp b/llvm/lib/Transforms/IPO/PartialInlining.cpp
index cd3701e903080..5d863f1330a44 100644
--- a/llvm/lib/Transforms/IPO/PartialInlining.cpp
+++ b/llvm/lib/Transforms/IPO/PartialInlining.cpp
@@ -30,7 +30,6 @@
#include "llvm/IR/Attributes.h"
#include "llvm/IR/BasicBlock.h"
#include "llvm/IR/CFG.h"
-#include "llvm/IR/CallSite.h"
#include "llvm/IR/DebugLoc.h"
#include "llvm/IR/DiagnosticInfo.h"
#include "llvm/IR/Dominators.h"
@@ -199,13 +198,14 @@ struct FunctionOutliningMultiRegionInfo {
struct PartialInlinerImpl {
PartialInlinerImpl(
- std::function<AssumptionCache &(Function &)> *GetAC,
+ function_ref<AssumptionCache &(Function &)> GetAC,
function_ref<AssumptionCache *(Function &)> LookupAC,
- std::function<TargetTransformInfo &(Function &)> *GTTI,
- Optional<function_ref<BlockFrequencyInfo &(Function &)>> GBFI,
- ProfileSummaryInfo *ProfSI)
+ function_ref<TargetTransformInfo &(Function &)> GTTI,
+ function_ref<const TargetLibraryInfo &(Function &)> GTLI,
+ ProfileSummaryInfo &ProfSI,
+ function_ref<BlockFrequencyInfo &(Function &)> GBFI = nullptr)
: GetAssumptionCache(GetAC), LookupAssumptionCache(LookupAC),
- GetTTI(GTTI), GetBFI(GBFI), PSI(ProfSI) {}
+ GetTTI(GTTI), GetBFI(GBFI), GetTLI(GTLI), PSI(ProfSI) {}
bool run(Module &M);
// Main part of the transformation that calls helper functions to find
@@ -270,11 +270,12 @@ struct PartialInlinerImpl {
private:
int NumPartialInlining = 0;
- std::function<AssumptionCache &(Function &)> *GetAssumptionCache;
+ function_ref<AssumptionCache &(Function &)> GetAssumptionCache;
function_ref<AssumptionCache *(Function &)> LookupAssumptionCache;
- std::function<TargetTransformInfo &(Function &)> *GetTTI;
- Optional<function_ref<BlockFrequencyInfo &(Function &)>> GetBFI;
- ProfileSummaryInfo *PSI;
+ function_ref<TargetTransformInfo &(Function &)> GetTTI;
+ function_ref<BlockFrequencyInfo &(Function &)> GetBFI;
+ function_ref<const TargetLibraryInfo &(Function &)> GetTLI;
+ ProfileSummaryInfo &PSI;
// Return the frequency of the OutlininingBB relative to F's entry point.
// The result is no larger than 1 and is represented using BP.
@@ -282,9 +283,9 @@ private:
// edges from the guarding entry blocks).
BranchProbability getOutliningCallBBRelativeFreq(FunctionCloner &Cloner);
- // Return true if the callee of CS should be partially inlined with
+ // Return true if the callee of CB should be partially inlined with
// profit.
- bool shouldPartialInline(CallSite CS, FunctionCloner &Cloner,
+ bool shouldPartialInline(CallBase &CB, FunctionCloner &Cloner,
BlockFrequency WeightedOutliningRcost,
OptimizationRemarkEmitter &ORE);
@@ -303,26 +304,22 @@ private:
NumPartialInlining >= MaxNumPartialInlining);
}
- static CallSite getCallSite(User *U) {
- CallSite CS;
- if (CallInst *CI = dyn_cast<CallInst>(U))
- CS = CallSite(CI);
- else if (InvokeInst *II = dyn_cast<InvokeInst>(U))
- CS = CallSite(II);
- else
- llvm_unreachable("All uses must be calls");
- return CS;
+ static CallBase *getSupportedCallBase(User *U) {
+ if (isa<CallInst>(U) || isa<InvokeInst>(U))
+ return cast<CallBase>(U);
+ llvm_unreachable("All uses must be calls");
+ return nullptr;
}
- static CallSite getOneCallSiteTo(Function *F) {
+ static CallBase *getOneCallSiteTo(Function *F) {
User *User = *F->user_begin();
- return getCallSite(User);
+ return getSupportedCallBase(User);
}
std::tuple<DebugLoc, BasicBlock *> getOneDebugLoc(Function *F) {
- CallSite CS = getOneCallSiteTo(F);
- DebugLoc DLoc = CS.getInstruction()->getDebugLoc();
- BasicBlock *Block = CS.getParent();
+ CallBase *CB = getOneCallSiteTo(F);
+ DebugLoc DLoc = CB->getDebugLoc();
+ BasicBlock *Block = CB->getParent();
return std::make_tuple(DLoc, Block);
}
@@ -355,6 +352,7 @@ struct PartialInlinerLegacyPass : public ModulePass {
AU.addRequired<AssumptionCacheTracker>();
AU.addRequired<ProfileSummaryInfoWrapperPass>();
AU.addRequired<TargetTransformInfoWrapperPass>();
+ AU.addRequired<TargetLibraryInfoWrapperPass>();
}
bool runOnModule(Module &M) override {
@@ -364,11 +362,10 @@ struct PartialInlinerLegacyPass : public ModulePass {
AssumptionCacheTracker *ACT = &getAnalysis<AssumptionCacheTracker>();
TargetTransformInfoWrapperPass *TTIWP =
&getAnalysis<TargetTransformInfoWrapperPass>();
- ProfileSummaryInfo *PSI =
- &getAnalysis<ProfileSummaryInfoWrapperPass>().getPSI();
+ ProfileSummaryInfo &PSI =
+ getAnalysis<ProfileSummaryInfoWrapperPass>().getPSI();
- std::function<AssumptionCache &(Function &)> GetAssumptionCache =
- [&ACT](Function &F) -> AssumptionCache & {
+ auto GetAssumptionCache = [&ACT](Function &F) -> AssumptionCache & {
return ACT->getAssumptionCache(F);
};
@@ -376,13 +373,16 @@ struct PartialInlinerLegacyPass : public ModulePass {
return ACT->lookupAssumptionCache(F);
};
- std::function<TargetTransformInfo &(Function &)> GetTTI =
- [&TTIWP](Function &F) -> TargetTransformInfo & {
+ auto GetTTI = [&TTIWP](Function &F) -> TargetTransformInfo & {
return TTIWP->getTTI(F);
};
- return PartialInlinerImpl(&GetAssumptionCache, LookupAssumptionCache,
- &GetTTI, NoneType::None, PSI)
+ auto GetTLI = [this](Function &F) -> TargetLibraryInfo & {
+ return this->getAnalysis<TargetLibraryInfoWrapperPass>().getTLI(F);
+ };
+
+ return PartialInlinerImpl(GetAssumptionCache, LookupAssumptionCache, GetTTI,
+ GetTLI, PSI)
.run(M);
}
};
@@ -403,10 +403,10 @@ PartialInlinerImpl::computeOutliningColdRegionsInfo(Function *F,
ScopedBFI.reset(new BlockFrequencyInfo(*F, BPI, LI));
BFI = ScopedBFI.get();
} else
- BFI = &(*GetBFI)(*F);
+ BFI = &(GetBFI(*F));
// Return if we don't have profiling information.
- if (!PSI->hasInstrumentationProfile())
+ if (!PSI.hasInstrumentationProfile())
return std::unique_ptr<FunctionOutliningMultiRegionInfo>();
std::unique_ptr<FunctionOutliningMultiRegionInfo> OutliningInfo =
@@ -479,7 +479,7 @@ PartialInlinerImpl::computeOutliningColdRegionsInfo(Function *F,
// Only consider regions with predecessor blocks that are considered
// not-cold (default: part of the top 99.99% of all block counters)
// AND greater than our minimum block execution count (default: 100).
- if (PSI->isColdBlock(thisBB, BFI) ||
+ if (PSI.isColdBlock(thisBB, BFI) ||
BBProfileCount(thisBB) < MinBlockCounterExecution)
continue;
for (auto SI = succ_begin(thisBB); SI != succ_end(thisBB); ++SI) {
@@ -759,31 +759,28 @@ PartialInlinerImpl::getOutliningCallBBRelativeFreq(FunctionCloner &Cloner) {
}
bool PartialInlinerImpl::shouldPartialInline(
- CallSite CS, FunctionCloner &Cloner,
- BlockFrequency WeightedOutliningRcost,
+ CallBase &CB, FunctionCloner &Cloner, BlockFrequency WeightedOutliningRcost,
OptimizationRemarkEmitter &ORE) {
using namespace ore;
- Instruction *Call = CS.getInstruction();
- Function *Callee = CS.getCalledFunction();
+ Function *Callee = CB.getCalledFunction();
assert(Callee == Cloner.ClonedFunc);
if (SkipCostAnalysis)
- return isInlineViable(*Callee);
+ return isInlineViable(*Callee).isSuccess();
- Function *Caller = CS.getCaller();
- auto &CalleeTTI = (*GetTTI)(*Callee);
+ Function *Caller = CB.getCaller();
+ auto &CalleeTTI = GetTTI(*Callee);
bool RemarksEnabled =
Callee->getContext().getDiagHandlerPtr()->isMissedOptRemarkEnabled(
DEBUG_TYPE);
- assert(Call && "invalid callsite for partial inline");
- InlineCost IC = getInlineCost(cast<CallBase>(*Call), getInlineParams(),
- CalleeTTI, *GetAssumptionCache, GetBFI, PSI,
- RemarksEnabled ? &ORE : nullptr);
+ InlineCost IC =
+ getInlineCost(CB, getInlineParams(), CalleeTTI, GetAssumptionCache,
+ GetTLI, GetBFI, &PSI, RemarksEnabled ? &ORE : nullptr);
if (IC.isAlways()) {
ORE.emit([&]() {
- return OptimizationRemarkAnalysis(DEBUG_TYPE, "AlwaysInline", Call)
+ return OptimizationRemarkAnalysis(DEBUG_TYPE, "AlwaysInline", &CB)
<< NV("Callee", Cloner.OrigFunc)
<< " should always be fully inlined, not partially";
});
@@ -792,7 +789,7 @@ bool PartialInlinerImpl::shouldPartialInline(
if (IC.isNever()) {
ORE.emit([&]() {
- return OptimizationRemarkMissed(DEBUG_TYPE, "NeverInline", Call)
+ return OptimizationRemarkMissed(DEBUG_TYPE, "NeverInline", &CB)
<< NV("Callee", Cloner.OrigFunc) << " not partially inlined into "
<< NV("Caller", Caller)
<< " because it should never be inlined (cost=never)";
@@ -802,7 +799,7 @@ bool PartialInlinerImpl::shouldPartialInline(
if (!IC) {
ORE.emit([&]() {
- return OptimizationRemarkAnalysis(DEBUG_TYPE, "TooCostly", Call)
+ return OptimizationRemarkAnalysis(DEBUG_TYPE, "TooCostly", &CB)
<< NV("Callee", Cloner.OrigFunc) << " not partially inlined into "
<< NV("Caller", Caller) << " because too costly to inline (cost="
<< NV("Cost", IC.getCost()) << ", threshold="
@@ -813,14 +810,14 @@ bool PartialInlinerImpl::shouldPartialInline(
const DataLayout &DL = Caller->getParent()->getDataLayout();
// The savings of eliminating the call:
- int NonWeightedSavings = getCallsiteCost(cast<CallBase>(*Call), DL);
+ int NonWeightedSavings = getCallsiteCost(CB, DL);
BlockFrequency NormWeightedSavings(NonWeightedSavings);
// Weighted saving is smaller than weighted cost, return false
if (NormWeightedSavings < WeightedOutliningRcost) {
ORE.emit([&]() {
return OptimizationRemarkAnalysis(DEBUG_TYPE, "OutliningCallcostTooHigh",
- Call)
+ &CB)
<< NV("Callee", Cloner.OrigFunc) << " not partially inlined into "
<< NV("Caller", Caller) << " runtime overhead (overhead="
<< NV("Overhead", (unsigned)WeightedOutliningRcost.getFrequency())
@@ -834,7 +831,7 @@ bool PartialInlinerImpl::shouldPartialInline(
}
ORE.emit([&]() {
- return OptimizationRemarkAnalysis(DEBUG_TYPE, "CanBePartiallyInlined", Call)
+ return OptimizationRemarkAnalysis(DEBUG_TYPE, "CanBePartiallyInlined", &CB)
<< NV("Callee", Cloner.OrigFunc) << " can be partially inlined into "
<< NV("Caller", Caller) << " with cost=" << NV("Cost", IC.getCost())
<< " (threshold="
@@ -941,20 +938,20 @@ void PartialInlinerImpl::computeCallsiteToProfCountMap(
CurrentCallerBFI = TempBFI.get();
} else {
// New pass manager:
- CurrentCallerBFI = &(*GetBFI)(*Caller);
+ CurrentCallerBFI = &(GetBFI(*Caller));
}
};
for (User *User : Users) {
- CallSite CS = getCallSite(User);
- Function *Caller = CS.getCaller();
+ CallBase *CB = getSupportedCallBase(User);
+ Function *Caller = CB->getCaller();
if (CurrentCaller != Caller) {
CurrentCaller = Caller;
ComputeCurrBFI(Caller);
} else {
assert(CurrentCallerBFI && "CallerBFI is not set");
}
- BasicBlock *CallBB = CS.getInstruction()->getParent();
+ BasicBlock *CallBB = CB->getParent();
auto Count = CurrentCallerBFI->getBlockProfileCount(CallBB);
if (Count)
CallSiteToProfCountMap[User] = *Count;
@@ -1155,8 +1152,8 @@ bool PartialInlinerImpl::FunctionCloner::doMultiRegionFunctionOutlining() {
Function *OutlinedFunc = CE.extractCodeRegion(CEAC);
if (OutlinedFunc) {
- CallSite OCS = PartialInlinerImpl::getOneCallSiteTo(OutlinedFunc);
- BasicBlock *OutliningCallBB = OCS.getInstruction()->getParent();
+ CallBase *OCS = PartialInlinerImpl::getOneCallSiteTo(OutlinedFunc);
+ BasicBlock *OutliningCallBB = OCS->getParent();
assert(OutliningCallBB->getParent() == ClonedFunc);
OutlinedFunctions.push_back(std::make_pair(OutlinedFunc,OutliningCallBB));
NumColdRegionsOutlined++;
@@ -1164,7 +1161,7 @@ bool PartialInlinerImpl::FunctionCloner::doMultiRegionFunctionOutlining() {
if (MarkOutlinedColdCC) {
OutlinedFunc->setCallingConv(CallingConv::Cold);
- OCS.setCallingConv(CallingConv::Cold);
+ OCS->setCallingConv(CallingConv::Cold);
}
} else
ORE.emit([&]() {
@@ -1224,7 +1221,6 @@ PartialInlinerImpl::FunctionCloner::doSingleRegionFunctionOutlining() {
if (OutlinedFunc) {
BasicBlock *OutliningCallBB =
PartialInlinerImpl::getOneCallSiteTo(OutlinedFunc)
- .getInstruction()
->getParent();
assert(OutliningCallBB->getParent() == ClonedFunc);
OutlinedFunctions.push_back(std::make_pair(OutlinedFunc, OutliningCallBB));
@@ -1266,7 +1262,7 @@ std::pair<bool, Function *> PartialInlinerImpl::unswitchFunction(Function *F) {
if (F->hasFnAttribute(Attribute::NoInline))
return {false, nullptr};
- if (PSI->isFunctionEntryCold(F))
+ if (PSI.isFunctionEntryCold(F))
return {false, nullptr};
if (F->users().empty())
@@ -1276,7 +1272,7 @@ std::pair<bool, Function *> PartialInlinerImpl::unswitchFunction(Function *F) {
// Only try to outline cold regions if we have a profile summary, which
// implies we have profiling information.
- if (PSI->hasProfileSummary() && F->hasProfileData() &&
+ if (PSI.hasProfileSummary() && F->hasProfileData() &&
!DisableMultiRegionPartialInline) {
std::unique_ptr<FunctionOutliningMultiRegionInfo> OMRI =
computeOutliningColdRegionsInfo(F, ORE);
@@ -1285,8 +1281,8 @@ std::pair<bool, Function *> PartialInlinerImpl::unswitchFunction(Function *F) {
#ifndef NDEBUG
if (TracePartialInlining) {
- dbgs() << "HotCountThreshold = " << PSI->getHotCountThreshold() << "\n";
- dbgs() << "ColdCountThreshold = " << PSI->getColdCountThreshold()
+ dbgs() << "HotCountThreshold = " << PSI.getHotCountThreshold() << "\n";
+ dbgs() << "ColdCountThreshold = " << PSI.getColdCountThreshold()
<< "\n";
}
#endif
@@ -1391,27 +1387,28 @@ bool PartialInlinerImpl::tryPartialInline(FunctionCloner &Cloner) {
bool AnyInline = false;
for (User *User : Users) {
- CallSite CS = getCallSite(User);
+ CallBase *CB = getSupportedCallBase(User);
if (IsLimitReached())
continue;
- OptimizationRemarkEmitter CallerORE(CS.getCaller());
- if (!shouldPartialInline(CS, Cloner, WeightedRcost, CallerORE))
+ OptimizationRemarkEmitter CallerORE(CB->getCaller());
+ if (!shouldPartialInline(*CB, Cloner, WeightedRcost, CallerORE))
continue;
// Construct remark before doing the inlining, as after successful inlining
// the callsite is removed.
- OptimizationRemark OR(DEBUG_TYPE, "PartiallyInlined", CS.getInstruction());
+ OptimizationRemark OR(DEBUG_TYPE, "PartiallyInlined", CB);
OR << ore::NV("Callee", Cloner.OrigFunc) << " partially inlined into "
- << ore::NV("Caller", CS.getCaller());
+ << ore::NV("Caller", CB->getCaller());
- InlineFunctionInfo IFI(nullptr, GetAssumptionCache, PSI);
+ InlineFunctionInfo IFI(nullptr, GetAssumptionCache, &PSI);
// We can only forward varargs when we outlined a single region, else we
// bail on vararg functions.
- if (!InlineFunction(CS, IFI, nullptr, true,
+ if (!InlineFunction(*CB, IFI, nullptr, true,
(Cloner.ClonedOI ? Cloner.OutlinedFunctions.back().first
- : nullptr)))
+ : nullptr))
+ .isSuccess())
continue;
CallerORE.emit(OR);
@@ -1492,6 +1489,7 @@ INITIALIZE_PASS_BEGIN(PartialInlinerLegacyPass, "partial-inliner",
INITIALIZE_PASS_DEPENDENCY(AssumptionCacheTracker)
INITIALIZE_PASS_DEPENDENCY(ProfileSummaryInfoWrapperPass)
INITIALIZE_PASS_DEPENDENCY(TargetTransformInfoWrapperPass)
+INITIALIZE_PASS_DEPENDENCY(TargetLibraryInfoWrapperPass)
INITIALIZE_PASS_END(PartialInlinerLegacyPass, "partial-inliner",
"Partial Inliner", false, false)
@@ -1503,8 +1501,7 @@ PreservedAnalyses PartialInlinerPass::run(Module &M,
ModuleAnalysisManager &AM) {
auto &FAM = AM.getResult<FunctionAnalysisManagerModuleProxy>(M).getManager();
- std::function<AssumptionCache &(Function &)> GetAssumptionCache =
- [&FAM](Function &F) -> AssumptionCache & {
+ auto GetAssumptionCache = [&FAM](Function &F) -> AssumptionCache & {
return FAM.getResult<AssumptionAnalysis>(F);
};
@@ -1512,20 +1509,22 @@ PreservedAnalyses PartialInlinerPass::run(Module &M,
return FAM.getCachedResult<AssumptionAnalysis>(F);
};
- std::function<BlockFrequencyInfo &(Function &)> GetBFI =
- [&FAM](Function &F) -> BlockFrequencyInfo & {
+ auto GetBFI = [&FAM](Function &F) -> BlockFrequencyInfo & {
return FAM.getResult<BlockFrequencyAnalysis>(F);
};
- std::function<TargetTransformInfo &(Function &)> GetTTI =
- [&FAM](Function &F) -> TargetTransformInfo & {
+ auto GetTTI = [&FAM](Function &F) -> TargetTransformInfo & {
return FAM.getResult<TargetIRAnalysis>(F);
};
- ProfileSummaryInfo *PSI = &AM.getResult<ProfileSummaryAnalysis>(M);
+ auto GetTLI = [&FAM](Function &F) -> TargetLibraryInfo & {
+ return FAM.getResult<TargetLibraryAnalysis>(F);
+ };
+
+ ProfileSummaryInfo &PSI = AM.getResult<ProfileSummaryAnalysis>(M);
- if (PartialInlinerImpl(&GetAssumptionCache, LookupAssumptionCache, &GetTTI,
- {GetBFI}, PSI)
+ if (PartialInlinerImpl(GetAssumptionCache, LookupAssumptionCache, GetTTI,
+ GetTLI, PSI, GetBFI)
.run(M))
return PreservedAnalyses::none();
return PreservedAnalyses::all();