diff options
Diffstat (limited to 'contrib/llvm/lib/Analysis/BranchProbabilityInfo.cpp')
| -rw-r--r-- | contrib/llvm/lib/Analysis/BranchProbabilityInfo.cpp | 128 | 
1 files changed, 114 insertions, 14 deletions
diff --git a/contrib/llvm/lib/Analysis/BranchProbabilityInfo.cpp b/contrib/llvm/lib/Analysis/BranchProbabilityInfo.cpp index a329e5ad48c9..58ccad89d508 100644 --- a/contrib/llvm/lib/Analysis/BranchProbabilityInfo.cpp +++ b/contrib/llvm/lib/Analysis/BranchProbabilityInfo.cpp @@ -1,4 +1,4 @@ -//===-- BranchProbabilityInfo.cpp - Branch Probability Analysis -----------===// +//===- BranchProbabilityInfo.cpp - Branch Probability Analysis ------------===//  //  //                     The LLVM Compiler Infrastructure  // @@ -13,21 +13,47 @@  #include "llvm/Analysis/BranchProbabilityInfo.h"  #include "llvm/ADT/PostOrderIterator.h" +#include "llvm/ADT/SCCIterator.h" +#include "llvm/ADT/STLExtras.h" +#include "llvm/ADT/SmallVector.h"  #include "llvm/Analysis/LoopInfo.h"  #include "llvm/Analysis/TargetLibraryInfo.h" +#include "llvm/IR/Attributes.h" +#include "llvm/IR/BasicBlock.h"  #include "llvm/IR/CFG.h"  #include "llvm/IR/Constants.h"  #include "llvm/IR/Function.h" +#include "llvm/IR/InstrTypes.h" +#include "llvm/IR/Instruction.h"  #include "llvm/IR/Instructions.h"  #include "llvm/IR/LLVMContext.h"  #include "llvm/IR/Metadata.h" +#include "llvm/IR/PassManager.h" +#include "llvm/IR/Type.h" +#include "llvm/IR/Value.h" +#include "llvm/Pass.h" +#include "llvm/Support/BranchProbability.h" +#include "llvm/Support/Casting.h"  #include "llvm/Support/Debug.h"  #include "llvm/Support/raw_ostream.h" +#include <cassert> +#include <cstdint> +#include <iterator> +#include <utility>  using namespace llvm;  #define DEBUG_TYPE "branch-prob" +static cl::opt<bool> PrintBranchProb( +    "print-bpi", cl::init(false), cl::Hidden, +    cl::desc("Print the branch probability info.")); + +cl::opt<std::string> PrintBranchProbFuncName( +    "print-bpi-func-name", cl::Hidden, +    cl::desc("The option to specify the name of the function " +             "whose branch probability info is printed.")); +  INITIALIZE_PASS_BEGIN(BranchProbabilityInfoWrapperPass, "branch-prob",                        "Branch Probability Analysis", false, true)  INITIALIZE_PASS_DEPENDENCY(LoopInfoWrapperPass) @@ -221,7 +247,7 @@ bool BranchProbabilityInfo::calcUnreachableHeuristics(const BasicBlock *BB) {  bool BranchProbabilityInfo::calcMetadataWeights(const BasicBlock *BB) {    const TerminatorInst *TI = BB->getTerminator();    assert(TI->getNumSuccessors() > 1 && "expected more than one successor!"); -  if (!isa<BranchInst>(TI) && !isa<SwitchInst>(TI)) +  if (!(isa<BranchInst>(TI) || isa<SwitchInst>(TI) || isa<IndirectBrInst>(TI)))      return false;    MDNode *WeightsNode = TI->getMetadata(LLVMContext::MD_prof); @@ -399,25 +425,73 @@ bool BranchProbabilityInfo::calcPointerHeuristics(const BasicBlock *BB) {    return true;  } +static int getSCCNum(const BasicBlock *BB, +                     const BranchProbabilityInfo::SccInfo &SccI) { +  auto SccIt = SccI.SccNums.find(BB); +  if (SccIt == SccI.SccNums.end()) +    return -1; +  return SccIt->second; +} + +// Consider any block that is an entry point to the SCC as a header. +static bool isSCCHeader(const BasicBlock *BB, int SccNum, +                        BranchProbabilityInfo::SccInfo &SccI) { +  assert(getSCCNum(BB, SccI) == SccNum); + +  // Lazily compute the set of headers for a given SCC and cache the results +  // in the SccHeaderMap. +  if (SccI.SccHeaders.size() <= static_cast<unsigned>(SccNum)) +    SccI.SccHeaders.resize(SccNum + 1); +  auto &HeaderMap = SccI.SccHeaders[SccNum]; +  bool Inserted; +  BranchProbabilityInfo::SccHeaderMap::iterator HeaderMapIt; +  std::tie(HeaderMapIt, Inserted) = HeaderMap.insert(std::make_pair(BB, false)); +  if (Inserted) { +    bool IsHeader = llvm::any_of(make_range(pred_begin(BB), pred_end(BB)), +                                 [&](const BasicBlock *Pred) { +                                   return getSCCNum(Pred, SccI) != SccNum; +                                 }); +    HeaderMapIt->second = IsHeader; +    return IsHeader; +  } else +    return HeaderMapIt->second; +} +  // Calculate Edge Weights using "Loop Branch Heuristics". Predict backedges  // as taken, exiting edges as not-taken.  bool BranchProbabilityInfo::calcLoopBranchHeuristics(const BasicBlock *BB, -                                                     const LoopInfo &LI) { +                                                     const LoopInfo &LI, +                                                     SccInfo &SccI) { +  int SccNum;    Loop *L = LI.getLoopFor(BB); -  if (!L) -    return false; +  if (!L) { +    SccNum = getSCCNum(BB, SccI); +    if (SccNum < 0) +      return false; +  }    SmallVector<unsigned, 8> BackEdges;    SmallVector<unsigned, 8> ExitingEdges;    SmallVector<unsigned, 8> InEdges; // Edges from header to the loop.    for (succ_const_iterator I = succ_begin(BB), E = succ_end(BB); I != E; ++I) { -    if (!L->contains(*I)) -      ExitingEdges.push_back(I.getSuccessorIndex()); -    else if (L->getHeader() == *I) -      BackEdges.push_back(I.getSuccessorIndex()); -    else -      InEdges.push_back(I.getSuccessorIndex()); +    // Use LoopInfo if we have it, otherwise fall-back to SCC info to catch +    // irreducible loops. +    if (L) { +      if (!L->contains(*I)) +        ExitingEdges.push_back(I.getSuccessorIndex()); +      else if (L->getHeader() == *I) +        BackEdges.push_back(I.getSuccessorIndex()); +      else +        InEdges.push_back(I.getSuccessorIndex()); +    } else { +      if (getSCCNum(*I, SccI) != SccNum) +        ExitingEdges.push_back(I.getSuccessorIndex()); +      else if (isSCCHeader(*I, SccNum, SccI)) +        BackEdges.push_back(I.getSuccessorIndex()); +      else +        InEdges.push_back(I.getSuccessorIndex()); +    }    }    if (BackEdges.empty() && ExitingEdges.empty()) @@ -480,7 +554,7 @@ bool BranchProbabilityInfo::calcZeroHeuristics(const BasicBlock *BB,    if (Instruction *LHS = dyn_cast<Instruction>(CI->getOperand(0)))      if (LHS->getOpcode() == Instruction::And)        if (ConstantInt *AndRHS = dyn_cast<ConstantInt>(LHS->getOperand(1))) -        if (AndRHS->getUniqueInteger().isPowerOf2()) +        if (AndRHS->getValue().isPowerOf2())            return false;    // Check if the LHS is the return value of a library function @@ -722,7 +796,6 @@ raw_ostream &  BranchProbabilityInfo::printEdgeProbability(raw_ostream &OS,                                              const BasicBlock *Src,                                              const BasicBlock *Dst) const { -    const BranchProbability Prob = getEdgeProbability(Src, Dst);    OS << "edge " << Src->getName() << " -> " << Dst->getName()       << " probability is " << Prob @@ -747,6 +820,27 @@ void BranchProbabilityInfo::calculate(const Function &F, const LoopInfo &LI,    assert(PostDominatedByUnreachable.empty());    assert(PostDominatedByColdCall.empty()); +  // Record SCC numbers of blocks in the CFG to identify irreducible loops. +  // FIXME: We could only calculate this if the CFG is known to be irreducible +  // (perhaps cache this info in LoopInfo if we can easily calculate it there?). +  int SccNum = 0; +  SccInfo SccI; +  for (scc_iterator<const Function *> It = scc_begin(&F); !It.isAtEnd(); +       ++It, ++SccNum) { +    // Ignore single-block SCCs since they either aren't loops or LoopInfo will +    // catch them. +    const std::vector<const BasicBlock *> &Scc = *It; +    if (Scc.size() == 1) +      continue; + +    DEBUG(dbgs() << "BPI: SCC " << SccNum << ":"); +    for (auto *BB : Scc) { +      DEBUG(dbgs() << " " << BB->getName()); +      SccI.SccNums[BB] = SccNum; +    } +    DEBUG(dbgs() << "\n"); +  } +    // Walk the basic blocks in post-order so that we can build up state about    // the successors of a block iteratively.    for (auto BB : post_order(&F.getEntryBlock())) { @@ -762,7 +856,7 @@ void BranchProbabilityInfo::calculate(const Function &F, const LoopInfo &LI,        continue;      if (calcColdCallHeuristics(BB))        continue; -    if (calcLoopBranchHeuristics(BB, LI)) +    if (calcLoopBranchHeuristics(BB, LI, SccI))        continue;      if (calcPointerHeuristics(BB))        continue; @@ -775,6 +869,12 @@ void BranchProbabilityInfo::calculate(const Function &F, const LoopInfo &LI,    PostDominatedByUnreachable.clear();    PostDominatedByColdCall.clear(); + +  if (PrintBranchProb && +      (PrintBranchProbFuncName.empty() || +       F.getName().equals(PrintBranchProbFuncName))) { +    print(dbgs()); +  }  }  void BranchProbabilityInfoWrapperPass::getAnalysisUsage(  | 
