diff options
| author | Dimitry Andric <dim@FreeBSD.org> | 2015-05-27 20:44:45 +0000 | 
|---|---|---|
| committer | Dimitry Andric <dim@FreeBSD.org> | 2015-05-27 20:44:45 +0000 | 
| commit | 33956c43007dfb106f401e3c14abc011a4b1d4ca (patch) | |
| tree | 50a603f7e1932cd42f58e26687ce907933014db0 /contrib/llvm/tools/clang/lib/CodeGen/CodeGenPGO.cpp | |
| parent | ff0cc061ecf297f1556e906d229826fd709f37d6 (diff) | |
| parent | 5e20cdd81c44a443562a09007668ffdf76c455af (diff) | |
Notes
Diffstat (limited to 'contrib/llvm/tools/clang/lib/CodeGen/CodeGenPGO.cpp')
| -rw-r--r-- | contrib/llvm/tools/clang/lib/CodeGen/CodeGenPGO.cpp | 910 | 
1 files changed, 450 insertions, 460 deletions
diff --git a/contrib/llvm/tools/clang/lib/CodeGen/CodeGenPGO.cpp b/contrib/llvm/tools/clang/lib/CodeGen/CodeGenPGO.cpp index 24b035d67598..c97244328d37 100644 --- a/contrib/llvm/tools/clang/lib/CodeGen/CodeGenPGO.cpp +++ b/contrib/llvm/tools/clang/lib/CodeGen/CodeGenPGO.cpp @@ -58,12 +58,16 @@ void CodeGenPGO::setFuncName(llvm::Function *Fn) {  }  void CodeGenPGO::createFuncNameVar(llvm::GlobalValue::LinkageTypes Linkage) { -  // Usually, we want to match the function's linkage, but -  // available_externally and extern_weak both have the wrong semantics. +  // We generally want to match the function's linkage, but available_externally +  // and extern_weak both have the wrong semantics, and anything that doesn't +  // need to link across compilation units doesn't need to be visible at all.    if (Linkage == llvm::GlobalValue::ExternalWeakLinkage)      Linkage = llvm::GlobalValue::LinkOnceAnyLinkage;    else if (Linkage == llvm::GlobalValue::AvailableExternallyLinkage)      Linkage = llvm::GlobalValue::LinkOnceODRLinkage; +  else if (Linkage == llvm::GlobalValue::InternalLinkage || +           Linkage == llvm::GlobalValue::ExternalLinkage) +    Linkage = llvm::GlobalValue::PrivateLinkage;    auto *Value =        llvm::ConstantDataArray::getString(CGM.getLLVMContext(), FuncName, false); @@ -138,482 +142,469 @@ const int PGOHash::NumBitsPerType;  const unsigned PGOHash::NumTypesPerWord;  const unsigned PGOHash::TooBig; -  /// A RecursiveASTVisitor that fills a map of statements to PGO counters. -  struct MapRegionCounters : public RecursiveASTVisitor<MapRegionCounters> { -    /// The next counter value to assign. -    unsigned NextCounter; -    /// The function hash. -    PGOHash Hash; -    /// The map of statements to counters. -    llvm::DenseMap<const Stmt *, unsigned> &CounterMap; - -    MapRegionCounters(llvm::DenseMap<const Stmt *, unsigned> &CounterMap) -        : NextCounter(0), CounterMap(CounterMap) {} - -    // Blocks and lambdas are handled as separate functions, so we need not -    // traverse them in the parent context. -    bool TraverseBlockExpr(BlockExpr *BE) { return true; } -    bool TraverseLambdaBody(LambdaExpr *LE) { return true; } -    bool TraverseCapturedStmt(CapturedStmt *CS) { return true; } - -    bool VisitDecl(const Decl *D) { -      switch (D->getKind()) { -      default: -        break; -      case Decl::Function: -      case Decl::CXXMethod: -      case Decl::CXXConstructor: -      case Decl::CXXDestructor: -      case Decl::CXXConversion: -      case Decl::ObjCMethod: -      case Decl::Block: -      case Decl::Captured: -        CounterMap[D->getBody()] = NextCounter++; -        break; -      } -      return true; +/// A RecursiveASTVisitor that fills a map of statements to PGO counters. +struct MapRegionCounters : public RecursiveASTVisitor<MapRegionCounters> { +  /// The next counter value to assign. +  unsigned NextCounter; +  /// The function hash. +  PGOHash Hash; +  /// The map of statements to counters. +  llvm::DenseMap<const Stmt *, unsigned> &CounterMap; + +  MapRegionCounters(llvm::DenseMap<const Stmt *, unsigned> &CounterMap) +      : NextCounter(0), CounterMap(CounterMap) {} + +  // Blocks and lambdas are handled as separate functions, so we need not +  // traverse them in the parent context. +  bool TraverseBlockExpr(BlockExpr *BE) { return true; } +  bool TraverseLambdaBody(LambdaExpr *LE) { return true; } +  bool TraverseCapturedStmt(CapturedStmt *CS) { return true; } + +  bool VisitDecl(const Decl *D) { +    switch (D->getKind()) { +    default: +      break; +    case Decl::Function: +    case Decl::CXXMethod: +    case Decl::CXXConstructor: +    case Decl::CXXDestructor: +    case Decl::CXXConversion: +    case Decl::ObjCMethod: +    case Decl::Block: +    case Decl::Captured: +      CounterMap[D->getBody()] = NextCounter++; +      break;      } +    return true; +  } -    bool VisitStmt(const Stmt *S) { -      auto Type = getHashType(S); -      if (Type == PGOHash::None) -        return true; - -      CounterMap[S] = NextCounter++; -      Hash.combine(Type); +  bool VisitStmt(const Stmt *S) { +    auto Type = getHashType(S); +    if (Type == PGOHash::None)        return true; + +    CounterMap[S] = NextCounter++; +    Hash.combine(Type); +    return true; +  } +  PGOHash::HashType getHashType(const Stmt *S) { +    switch (S->getStmtClass()) { +    default: +      break; +    case Stmt::LabelStmtClass: +      return PGOHash::LabelStmt; +    case Stmt::WhileStmtClass: +      return PGOHash::WhileStmt; +    case Stmt::DoStmtClass: +      return PGOHash::DoStmt; +    case Stmt::ForStmtClass: +      return PGOHash::ForStmt; +    case Stmt::CXXForRangeStmtClass: +      return PGOHash::CXXForRangeStmt; +    case Stmt::ObjCForCollectionStmtClass: +      return PGOHash::ObjCForCollectionStmt; +    case Stmt::SwitchStmtClass: +      return PGOHash::SwitchStmt; +    case Stmt::CaseStmtClass: +      return PGOHash::CaseStmt; +    case Stmt::DefaultStmtClass: +      return PGOHash::DefaultStmt; +    case Stmt::IfStmtClass: +      return PGOHash::IfStmt; +    case Stmt::CXXTryStmtClass: +      return PGOHash::CXXTryStmt; +    case Stmt::CXXCatchStmtClass: +      return PGOHash::CXXCatchStmt; +    case Stmt::ConditionalOperatorClass: +      return PGOHash::ConditionalOperator; +    case Stmt::BinaryConditionalOperatorClass: +      return PGOHash::BinaryConditionalOperator; +    case Stmt::BinaryOperatorClass: { +      const BinaryOperator *BO = cast<BinaryOperator>(S); +      if (BO->getOpcode() == BO_LAnd) +        return PGOHash::BinaryOperatorLAnd; +      if (BO->getOpcode() == BO_LOr) +        return PGOHash::BinaryOperatorLOr; +      break;      } -    PGOHash::HashType getHashType(const Stmt *S) { -      switch (S->getStmtClass()) { -      default: -        break; -      case Stmt::LabelStmtClass: -        return PGOHash::LabelStmt; -      case Stmt::WhileStmtClass: -        return PGOHash::WhileStmt; -      case Stmt::DoStmtClass: -        return PGOHash::DoStmt; -      case Stmt::ForStmtClass: -        return PGOHash::ForStmt; -      case Stmt::CXXForRangeStmtClass: -        return PGOHash::CXXForRangeStmt; -      case Stmt::ObjCForCollectionStmtClass: -        return PGOHash::ObjCForCollectionStmt; -      case Stmt::SwitchStmtClass: -        return PGOHash::SwitchStmt; -      case Stmt::CaseStmtClass: -        return PGOHash::CaseStmt; -      case Stmt::DefaultStmtClass: -        return PGOHash::DefaultStmt; -      case Stmt::IfStmtClass: -        return PGOHash::IfStmt; -      case Stmt::CXXTryStmtClass: -        return PGOHash::CXXTryStmt; -      case Stmt::CXXCatchStmtClass: -        return PGOHash::CXXCatchStmt; -      case Stmt::ConditionalOperatorClass: -        return PGOHash::ConditionalOperator; -      case Stmt::BinaryConditionalOperatorClass: -        return PGOHash::BinaryConditionalOperator; -      case Stmt::BinaryOperatorClass: { -        const BinaryOperator *BO = cast<BinaryOperator>(S); -        if (BO->getOpcode() == BO_LAnd) -          return PGOHash::BinaryOperatorLAnd; -        if (BO->getOpcode() == BO_LOr) -          return PGOHash::BinaryOperatorLOr; -        break; -      } -      } -      return PGOHash::None;      } +    return PGOHash::None; +  } +}; + +/// A StmtVisitor that propagates the raw counts through the AST and +/// records the count at statements where the value may change. +struct ComputeRegionCounts : public ConstStmtVisitor<ComputeRegionCounts> { +  /// PGO state. +  CodeGenPGO &PGO; + +  /// A flag that is set when the current count should be recorded on the +  /// next statement, such as at the exit of a loop. +  bool RecordNextStmtCount; + +  /// The count at the current location in the traversal. +  uint64_t CurrentCount; + +  /// The map of statements to count values. +  llvm::DenseMap<const Stmt *, uint64_t> &CountMap; + +  /// BreakContinueStack - Keep counts of breaks and continues inside loops. +  struct BreakContinue { +    uint64_t BreakCount; +    uint64_t ContinueCount; +    BreakContinue() : BreakCount(0), ContinueCount(0) {}    }; +  SmallVector<BreakContinue, 8> BreakContinueStack; -  /// A StmtVisitor that propagates the raw counts through the AST and -  /// records the count at statements where the value may change. -  struct ComputeRegionCounts : public ConstStmtVisitor<ComputeRegionCounts> { -    /// PGO state. -    CodeGenPGO &PGO; - -    /// A flag that is set when the current count should be recorded on the -    /// next statement, such as at the exit of a loop. -    bool RecordNextStmtCount; - -    /// The map of statements to count values. -    llvm::DenseMap<const Stmt *, uint64_t> &CountMap; - -    /// BreakContinueStack - Keep counts of breaks and continues inside loops. -    struct BreakContinue { -      uint64_t BreakCount; -      uint64_t ContinueCount; -      BreakContinue() : BreakCount(0), ContinueCount(0) {} -    }; -    SmallVector<BreakContinue, 8> BreakContinueStack; - -    ComputeRegionCounts(llvm::DenseMap<const Stmt *, uint64_t> &CountMap, -                        CodeGenPGO &PGO) -        : PGO(PGO), RecordNextStmtCount(false), CountMap(CountMap) {} - -    void RecordStmtCount(const Stmt *S) { -      if (RecordNextStmtCount) { -        CountMap[S] = PGO.getCurrentRegionCount(); -        RecordNextStmtCount = false; -      } -    } +  ComputeRegionCounts(llvm::DenseMap<const Stmt *, uint64_t> &CountMap, +                      CodeGenPGO &PGO) +      : PGO(PGO), RecordNextStmtCount(false), CountMap(CountMap) {} -    void VisitStmt(const Stmt *S) { -      RecordStmtCount(S); -      for (Stmt::const_child_range I = S->children(); I; ++I) { -        if (*I) -         this->Visit(*I); -      } +  void RecordStmtCount(const Stmt *S) { +    if (RecordNextStmtCount) { +      CountMap[S] = CurrentCount; +      RecordNextStmtCount = false;      } +  } -    void VisitFunctionDecl(const FunctionDecl *D) { -      // Counter tracks entry to the function body. -      RegionCounter Cnt(PGO, D->getBody()); -      Cnt.beginRegion(); -      CountMap[D->getBody()] = PGO.getCurrentRegionCount(); -      Visit(D->getBody()); -    } +  /// Set and return the current count. +  uint64_t setCount(uint64_t Count) { +    CurrentCount = Count; +    return Count; +  } -    // Skip lambda expressions. We visit these as FunctionDecls when we're -    // generating them and aren't interested in the body when generating a -    // parent context. -    void VisitLambdaExpr(const LambdaExpr *LE) {} - -    void VisitCapturedDecl(const CapturedDecl *D) { -      // Counter tracks entry to the capture body. -      RegionCounter Cnt(PGO, D->getBody()); -      Cnt.beginRegion(); -      CountMap[D->getBody()] = PGO.getCurrentRegionCount(); -      Visit(D->getBody()); +  void VisitStmt(const Stmt *S) { +    RecordStmtCount(S); +    for (Stmt::const_child_range I = S->children(); I; ++I) { +      if (*I) +        this->Visit(*I);      } +  } -    void VisitObjCMethodDecl(const ObjCMethodDecl *D) { -      // Counter tracks entry to the method body. -      RegionCounter Cnt(PGO, D->getBody()); -      Cnt.beginRegion(); -      CountMap[D->getBody()] = PGO.getCurrentRegionCount(); -      Visit(D->getBody()); -    } +  void VisitFunctionDecl(const FunctionDecl *D) { +    // Counter tracks entry to the function body. +    uint64_t BodyCount = setCount(PGO.getRegionCount(D->getBody())); +    CountMap[D->getBody()] = BodyCount; +    Visit(D->getBody()); +  } -    void VisitBlockDecl(const BlockDecl *D) { -      // Counter tracks entry to the block body. -      RegionCounter Cnt(PGO, D->getBody()); -      Cnt.beginRegion(); -      CountMap[D->getBody()] = PGO.getCurrentRegionCount(); -      Visit(D->getBody()); -    } +  // Skip lambda expressions. We visit these as FunctionDecls when we're +  // generating them and aren't interested in the body when generating a +  // parent context. +  void VisitLambdaExpr(const LambdaExpr *LE) {} -    void VisitReturnStmt(const ReturnStmt *S) { -      RecordStmtCount(S); -      if (S->getRetValue()) -        Visit(S->getRetValue()); -      PGO.setCurrentRegionUnreachable(); -      RecordNextStmtCount = true; -    } +  void VisitCapturedDecl(const CapturedDecl *D) { +    // Counter tracks entry to the capture body. +    uint64_t BodyCount = setCount(PGO.getRegionCount(D->getBody())); +    CountMap[D->getBody()] = BodyCount; +    Visit(D->getBody()); +  } -    void VisitGotoStmt(const GotoStmt *S) { -      RecordStmtCount(S); -      PGO.setCurrentRegionUnreachable(); -      RecordNextStmtCount = true; -    } +  void VisitObjCMethodDecl(const ObjCMethodDecl *D) { +    // Counter tracks entry to the method body. +    uint64_t BodyCount = setCount(PGO.getRegionCount(D->getBody())); +    CountMap[D->getBody()] = BodyCount; +    Visit(D->getBody()); +  } -    void VisitLabelStmt(const LabelStmt *S) { -      RecordNextStmtCount = false; -      // Counter tracks the block following the label. -      RegionCounter Cnt(PGO, S); -      Cnt.beginRegion(); -      CountMap[S] = PGO.getCurrentRegionCount(); -      Visit(S->getSubStmt()); -    } +  void VisitBlockDecl(const BlockDecl *D) { +    // Counter tracks entry to the block body. +    uint64_t BodyCount = setCount(PGO.getRegionCount(D->getBody())); +    CountMap[D->getBody()] = BodyCount; +    Visit(D->getBody()); +  } -    void VisitBreakStmt(const BreakStmt *S) { -      RecordStmtCount(S); -      assert(!BreakContinueStack.empty() && "break not in a loop or switch!"); -      BreakContinueStack.back().BreakCount += PGO.getCurrentRegionCount(); -      PGO.setCurrentRegionUnreachable(); -      RecordNextStmtCount = true; -    } +  void VisitReturnStmt(const ReturnStmt *S) { +    RecordStmtCount(S); +    if (S->getRetValue()) +      Visit(S->getRetValue()); +    CurrentCount = 0; +    RecordNextStmtCount = true; +  } -    void VisitContinueStmt(const ContinueStmt *S) { -      RecordStmtCount(S); -      assert(!BreakContinueStack.empty() && "continue stmt not in a loop!"); -      BreakContinueStack.back().ContinueCount += PGO.getCurrentRegionCount(); -      PGO.setCurrentRegionUnreachable(); -      RecordNextStmtCount = true; -    } +  void VisitCXXThrowExpr(const CXXThrowExpr *E) { +    RecordStmtCount(E); +    if (E->getSubExpr()) +      Visit(E->getSubExpr()); +    CurrentCount = 0; +    RecordNextStmtCount = true; +  } -    void VisitWhileStmt(const WhileStmt *S) { -      RecordStmtCount(S); -      // Counter tracks the body of the loop. -      RegionCounter Cnt(PGO, S); -      BreakContinueStack.push_back(BreakContinue()); -      // Visit the body region first so the break/continue adjustments can be -      // included when visiting the condition. -      Cnt.beginRegion(); -      CountMap[S->getBody()] = PGO.getCurrentRegionCount(); -      Visit(S->getBody()); -      Cnt.adjustForControlFlow(); - -      // ...then go back and propagate counts through the condition. The count -      // at the start of the condition is the sum of the incoming edges, -      // the backedge from the end of the loop body, and the edges from -      // continue statements. -      BreakContinue BC = BreakContinueStack.pop_back_val(); -      Cnt.setCurrentRegionCount(Cnt.getParentCount() + -                                Cnt.getAdjustedCount() + BC.ContinueCount); -      CountMap[S->getCond()] = PGO.getCurrentRegionCount(); -      Visit(S->getCond()); -      Cnt.adjustForControlFlow(); -      Cnt.applyAdjustmentsToRegion(BC.BreakCount + BC.ContinueCount); -      RecordNextStmtCount = true; -    } +  void VisitGotoStmt(const GotoStmt *S) { +    RecordStmtCount(S); +    CurrentCount = 0; +    RecordNextStmtCount = true; +  } -    void VisitDoStmt(const DoStmt *S) { -      RecordStmtCount(S); -      // Counter tracks the body of the loop. -      RegionCounter Cnt(PGO, S); -      BreakContinueStack.push_back(BreakContinue()); -      Cnt.beginRegion(/*AddIncomingFallThrough=*/true); -      CountMap[S->getBody()] = PGO.getCurrentRegionCount(); -      Visit(S->getBody()); -      Cnt.adjustForControlFlow(); - -      BreakContinue BC = BreakContinueStack.pop_back_val(); -      // The count at the start of the condition is equal to the count at the -      // end of the body. The adjusted count does not include either the -      // fall-through count coming into the loop or the continue count, so add -      // both of those separately. This is coincidentally the same equation as -      // with while loops but for different reasons. -      Cnt.setCurrentRegionCount(Cnt.getParentCount() + -                                Cnt.getAdjustedCount() + BC.ContinueCount); -      CountMap[S->getCond()] = PGO.getCurrentRegionCount(); -      Visit(S->getCond()); -      Cnt.adjustForControlFlow(); -      Cnt.applyAdjustmentsToRegion(BC.BreakCount + BC.ContinueCount); -      RecordNextStmtCount = true; -    } +  void VisitLabelStmt(const LabelStmt *S) { +    RecordNextStmtCount = false; +    // Counter tracks the block following the label. +    uint64_t BlockCount = setCount(PGO.getRegionCount(S)); +    CountMap[S] = BlockCount; +    Visit(S->getSubStmt()); +  } -    void VisitForStmt(const ForStmt *S) { -      RecordStmtCount(S); -      if (S->getInit()) -        Visit(S->getInit()); -      // Counter tracks the body of the loop. -      RegionCounter Cnt(PGO, S); -      BreakContinueStack.push_back(BreakContinue()); -      // Visit the body region first. (This is basically the same as a while -      // loop; see further comments in VisitWhileStmt.) -      Cnt.beginRegion(); -      CountMap[S->getBody()] = PGO.getCurrentRegionCount(); -      Visit(S->getBody()); -      Cnt.adjustForControlFlow(); - -      // The increment is essentially part of the body but it needs to include -      // the count for all the continue statements. -      if (S->getInc()) { -        Cnt.setCurrentRegionCount(PGO.getCurrentRegionCount() + -                                  BreakContinueStack.back().ContinueCount); -        CountMap[S->getInc()] = PGO.getCurrentRegionCount(); -        Visit(S->getInc()); -        Cnt.adjustForControlFlow(); -      } - -      BreakContinue BC = BreakContinueStack.pop_back_val(); - -      // ...then go back and propagate counts through the condition. -      if (S->getCond()) { -        Cnt.setCurrentRegionCount(Cnt.getParentCount() + -                                  Cnt.getAdjustedCount() + -                                  BC.ContinueCount); -        CountMap[S->getCond()] = PGO.getCurrentRegionCount(); -        Visit(S->getCond()); -        Cnt.adjustForControlFlow(); -      } -      Cnt.applyAdjustmentsToRegion(BC.BreakCount + BC.ContinueCount); -      RecordNextStmtCount = true; -    } +  void VisitBreakStmt(const BreakStmt *S) { +    RecordStmtCount(S); +    assert(!BreakContinueStack.empty() && "break not in a loop or switch!"); +    BreakContinueStack.back().BreakCount += CurrentCount; +    CurrentCount = 0; +    RecordNextStmtCount = true; +  } -    void VisitCXXForRangeStmt(const CXXForRangeStmt *S) { -      RecordStmtCount(S); -      Visit(S->getRangeStmt()); -      Visit(S->getBeginEndStmt()); -      // Counter tracks the body of the loop. -      RegionCounter Cnt(PGO, S); -      BreakContinueStack.push_back(BreakContinue()); -      // Visit the body region first. (This is basically the same as a while -      // loop; see further comments in VisitWhileStmt.) -      Cnt.beginRegion(); -      CountMap[S->getLoopVarStmt()] = PGO.getCurrentRegionCount(); -      Visit(S->getLoopVarStmt()); -      Visit(S->getBody()); -      Cnt.adjustForControlFlow(); - -      // The increment is essentially part of the body but it needs to include -      // the count for all the continue statements. -      Cnt.setCurrentRegionCount(PGO.getCurrentRegionCount() + -                                BreakContinueStack.back().ContinueCount); -      CountMap[S->getInc()] = PGO.getCurrentRegionCount(); -      Visit(S->getInc()); -      Cnt.adjustForControlFlow(); +  void VisitContinueStmt(const ContinueStmt *S) { +    RecordStmtCount(S); +    assert(!BreakContinueStack.empty() && "continue stmt not in a loop!"); +    BreakContinueStack.back().ContinueCount += CurrentCount; +    CurrentCount = 0; +    RecordNextStmtCount = true; +  } -      BreakContinue BC = BreakContinueStack.pop_back_val(); +  void VisitWhileStmt(const WhileStmt *S) { +    RecordStmtCount(S); +    uint64_t ParentCount = CurrentCount; + +    BreakContinueStack.push_back(BreakContinue()); +    // Visit the body region first so the break/continue adjustments can be +    // included when visiting the condition. +    uint64_t BodyCount = setCount(PGO.getRegionCount(S)); +    CountMap[S->getBody()] = CurrentCount; +    Visit(S->getBody()); +    uint64_t BackedgeCount = CurrentCount; + +    // ...then go back and propagate counts through the condition. The count +    // at the start of the condition is the sum of the incoming edges, +    // the backedge from the end of the loop body, and the edges from +    // continue statements. +    BreakContinue BC = BreakContinueStack.pop_back_val(); +    uint64_t CondCount = +        setCount(ParentCount + BackedgeCount + BC.ContinueCount); +    CountMap[S->getCond()] = CondCount; +    Visit(S->getCond()); +    setCount(BC.BreakCount + CondCount - BodyCount); +    RecordNextStmtCount = true; +  } -      // ...then go back and propagate counts through the condition. -      Cnt.setCurrentRegionCount(Cnt.getParentCount() + -                                Cnt.getAdjustedCount() + -                                BC.ContinueCount); -      CountMap[S->getCond()] = PGO.getCurrentRegionCount(); -      Visit(S->getCond()); -      Cnt.adjustForControlFlow(); -      Cnt.applyAdjustmentsToRegion(BC.BreakCount + BC.ContinueCount); -      RecordNextStmtCount = true; -    } +  void VisitDoStmt(const DoStmt *S) { +    RecordStmtCount(S); +    uint64_t LoopCount = PGO.getRegionCount(S); + +    BreakContinueStack.push_back(BreakContinue()); +    // The count doesn't include the fallthrough from the parent scope. Add it. +    uint64_t BodyCount = setCount(LoopCount + CurrentCount); +    CountMap[S->getBody()] = BodyCount; +    Visit(S->getBody()); +    uint64_t BackedgeCount = CurrentCount; + +    BreakContinue BC = BreakContinueStack.pop_back_val(); +    // The count at the start of the condition is equal to the count at the +    // end of the body, plus any continues. +    uint64_t CondCount = setCount(BackedgeCount + BC.ContinueCount); +    CountMap[S->getCond()] = CondCount; +    Visit(S->getCond()); +    setCount(BC.BreakCount + CondCount - LoopCount); +    RecordNextStmtCount = true; +  } -    void VisitObjCForCollectionStmt(const ObjCForCollectionStmt *S) { -      RecordStmtCount(S); -      Visit(S->getElement()); -      // Counter tracks the body of the loop. -      RegionCounter Cnt(PGO, S); -      BreakContinueStack.push_back(BreakContinue()); -      Cnt.beginRegion(); -      CountMap[S->getBody()] = PGO.getCurrentRegionCount(); -      Visit(S->getBody()); -      BreakContinue BC = BreakContinueStack.pop_back_val(); -      Cnt.adjustForControlFlow(); -      Cnt.applyAdjustmentsToRegion(BC.BreakCount + BC.ContinueCount); -      RecordNextStmtCount = true; +  void VisitForStmt(const ForStmt *S) { +    RecordStmtCount(S); +    if (S->getInit()) +      Visit(S->getInit()); + +    uint64_t ParentCount = CurrentCount; + +    BreakContinueStack.push_back(BreakContinue()); +    // Visit the body region first. (This is basically the same as a while +    // loop; see further comments in VisitWhileStmt.) +    uint64_t BodyCount = setCount(PGO.getRegionCount(S)); +    CountMap[S->getBody()] = BodyCount; +    Visit(S->getBody()); +    uint64_t BackedgeCount = CurrentCount; +    BreakContinue BC = BreakContinueStack.pop_back_val(); + +    // The increment is essentially part of the body but it needs to include +    // the count for all the continue statements. +    if (S->getInc()) { +      uint64_t IncCount = setCount(BackedgeCount + BC.ContinueCount); +      CountMap[S->getInc()] = IncCount; +      Visit(S->getInc());      } -    void VisitSwitchStmt(const SwitchStmt *S) { -      RecordStmtCount(S); +    // ...then go back and propagate counts through the condition. +    uint64_t CondCount = +        setCount(ParentCount + BackedgeCount + BC.ContinueCount); +    if (S->getCond()) { +      CountMap[S->getCond()] = CondCount;        Visit(S->getCond()); -      PGO.setCurrentRegionUnreachable(); -      BreakContinueStack.push_back(BreakContinue()); -      Visit(S->getBody()); -      // If the switch is inside a loop, add the continue counts. -      BreakContinue BC = BreakContinueStack.pop_back_val(); -      if (!BreakContinueStack.empty()) -        BreakContinueStack.back().ContinueCount += BC.ContinueCount; -      // Counter tracks the exit block of the switch. -      RegionCounter ExitCnt(PGO, S); -      ExitCnt.beginRegion(); -      RecordNextStmtCount = true;      } +    setCount(BC.BreakCount + CondCount - BodyCount); +    RecordNextStmtCount = true; +  } -    void VisitCaseStmt(const CaseStmt *S) { -      RecordNextStmtCount = false; -      // Counter for this particular case. This counts only jumps from the -      // switch header and does not include fallthrough from the case before -      // this one. -      RegionCounter Cnt(PGO, S); -      Cnt.beginRegion(/*AddIncomingFallThrough=*/true); -      CountMap[S] = Cnt.getCount(); -      RecordNextStmtCount = true; -      Visit(S->getSubStmt()); -    } +  void VisitCXXForRangeStmt(const CXXForRangeStmt *S) { +    RecordStmtCount(S); +    Visit(S->getLoopVarStmt()); +    Visit(S->getRangeStmt()); +    Visit(S->getBeginEndStmt()); + +    uint64_t ParentCount = CurrentCount; +    BreakContinueStack.push_back(BreakContinue()); +    // Visit the body region first. (This is basically the same as a while +    // loop; see further comments in VisitWhileStmt.) +    uint64_t BodyCount = setCount(PGO.getRegionCount(S)); +    CountMap[S->getBody()] = BodyCount; +    Visit(S->getBody()); +    uint64_t BackedgeCount = CurrentCount; +    BreakContinue BC = BreakContinueStack.pop_back_val(); + +    // The increment is essentially part of the body but it needs to include +    // the count for all the continue statements. +    uint64_t IncCount = setCount(BackedgeCount + BC.ContinueCount); +    CountMap[S->getInc()] = IncCount; +    Visit(S->getInc()); + +    // ...then go back and propagate counts through the condition. +    uint64_t CondCount = +        setCount(ParentCount + BackedgeCount + BC.ContinueCount); +    CountMap[S->getCond()] = CondCount; +    Visit(S->getCond()); +    setCount(BC.BreakCount + CondCount - BodyCount); +    RecordNextStmtCount = true; +  } -    void VisitDefaultStmt(const DefaultStmt *S) { -      RecordNextStmtCount = false; -      // Counter for this default case. This does not include fallthrough from -      // the previous case. -      RegionCounter Cnt(PGO, S); -      Cnt.beginRegion(/*AddIncomingFallThrough=*/true); -      CountMap[S] = Cnt.getCount(); -      RecordNextStmtCount = true; -      Visit(S->getSubStmt()); -    } +  void VisitObjCForCollectionStmt(const ObjCForCollectionStmt *S) { +    RecordStmtCount(S); +    Visit(S->getElement()); +    uint64_t ParentCount = CurrentCount; +    BreakContinueStack.push_back(BreakContinue()); +    // Counter tracks the body of the loop. +    uint64_t BodyCount = setCount(PGO.getRegionCount(S)); +    CountMap[S->getBody()] = BodyCount; +    Visit(S->getBody()); +    uint64_t BackedgeCount = CurrentCount; +    BreakContinue BC = BreakContinueStack.pop_back_val(); + +    setCount(BC.BreakCount + ParentCount + BackedgeCount + BC.ContinueCount - +             BodyCount); +    RecordNextStmtCount = true; +  } -    void VisitIfStmt(const IfStmt *S) { -      RecordStmtCount(S); -      // Counter tracks the "then" part of an if statement. The count for -      // the "else" part, if it exists, will be calculated from this counter. -      RegionCounter Cnt(PGO, S); -      Visit(S->getCond()); +  void VisitSwitchStmt(const SwitchStmt *S) { +    RecordStmtCount(S); +    Visit(S->getCond()); +    CurrentCount = 0; +    BreakContinueStack.push_back(BreakContinue()); +    Visit(S->getBody()); +    // If the switch is inside a loop, add the continue counts. +    BreakContinue BC = BreakContinueStack.pop_back_val(); +    if (!BreakContinueStack.empty()) +      BreakContinueStack.back().ContinueCount += BC.ContinueCount; +    // Counter tracks the exit block of the switch. +    setCount(PGO.getRegionCount(S)); +    RecordNextStmtCount = true; +  } -      Cnt.beginRegion(); -      CountMap[S->getThen()] = PGO.getCurrentRegionCount(); -      Visit(S->getThen()); -      Cnt.adjustForControlFlow(); - -      if (S->getElse()) { -        Cnt.beginElseRegion(); -        CountMap[S->getElse()] = PGO.getCurrentRegionCount(); -        Visit(S->getElse()); -        Cnt.adjustForControlFlow(); -      } -      Cnt.applyAdjustmentsToRegion(0); -      RecordNextStmtCount = true; -    } +  void VisitSwitchCase(const SwitchCase *S) { +    RecordNextStmtCount = false; +    // Counter for this particular case. This counts only jumps from the +    // switch header and does not include fallthrough from the case before +    // this one. +    uint64_t CaseCount = PGO.getRegionCount(S); +    setCount(CurrentCount + CaseCount); +    // We need the count without fallthrough in the mapping, so it's more useful +    // for branch probabilities. +    CountMap[S] = CaseCount; +    RecordNextStmtCount = true; +    Visit(S->getSubStmt()); +  } -    void VisitCXXTryStmt(const CXXTryStmt *S) { -      RecordStmtCount(S); -      Visit(S->getTryBlock()); -      for (unsigned I = 0, E = S->getNumHandlers(); I < E; ++I) -        Visit(S->getHandler(I)); -      // Counter tracks the continuation block of the try statement. -      RegionCounter Cnt(PGO, S); -      Cnt.beginRegion(); -      RecordNextStmtCount = true; -    } +  void VisitIfStmt(const IfStmt *S) { +    RecordStmtCount(S); +    uint64_t ParentCount = CurrentCount; +    Visit(S->getCond()); + +    // Counter tracks the "then" part of an if statement. The count for +    // the "else" part, if it exists, will be calculated from this counter. +    uint64_t ThenCount = setCount(PGO.getRegionCount(S)); +    CountMap[S->getThen()] = ThenCount; +    Visit(S->getThen()); +    uint64_t OutCount = CurrentCount; + +    uint64_t ElseCount = ParentCount - ThenCount; +    if (S->getElse()) { +      setCount(ElseCount); +      CountMap[S->getElse()] = ElseCount; +      Visit(S->getElse()); +      OutCount += CurrentCount; +    } else +      OutCount += ElseCount; +    setCount(OutCount); +    RecordNextStmtCount = true; +  } -    void VisitCXXCatchStmt(const CXXCatchStmt *S) { -      RecordNextStmtCount = false; -      // Counter tracks the catch statement's handler block. -      RegionCounter Cnt(PGO, S); -      Cnt.beginRegion(); -      CountMap[S] = PGO.getCurrentRegionCount(); -      Visit(S->getHandlerBlock()); -    } +  void VisitCXXTryStmt(const CXXTryStmt *S) { +    RecordStmtCount(S); +    Visit(S->getTryBlock()); +    for (unsigned I = 0, E = S->getNumHandlers(); I < E; ++I) +      Visit(S->getHandler(I)); +    // Counter tracks the continuation block of the try statement. +    setCount(PGO.getRegionCount(S)); +    RecordNextStmtCount = true; +  } -    void VisitAbstractConditionalOperator( -        const AbstractConditionalOperator *E) { -      RecordStmtCount(E); -      // Counter tracks the "true" part of a conditional operator. The -      // count in the "false" part will be calculated from this counter. -      RegionCounter Cnt(PGO, E); -      Visit(E->getCond()); - -      Cnt.beginRegion(); -      CountMap[E->getTrueExpr()] = PGO.getCurrentRegionCount(); -      Visit(E->getTrueExpr()); -      Cnt.adjustForControlFlow(); - -      Cnt.beginElseRegion(); -      CountMap[E->getFalseExpr()] = PGO.getCurrentRegionCount(); -      Visit(E->getFalseExpr()); -      Cnt.adjustForControlFlow(); - -      Cnt.applyAdjustmentsToRegion(0); -      RecordNextStmtCount = true; -    } +  void VisitCXXCatchStmt(const CXXCatchStmt *S) { +    RecordNextStmtCount = false; +    // Counter tracks the catch statement's handler block. +    uint64_t CatchCount = setCount(PGO.getRegionCount(S)); +    CountMap[S] = CatchCount; +    Visit(S->getHandlerBlock()); +  } -    void VisitBinLAnd(const BinaryOperator *E) { -      RecordStmtCount(E); -      // Counter tracks the right hand side of a logical and operator. -      RegionCounter Cnt(PGO, E); -      Visit(E->getLHS()); -      Cnt.beginRegion(); -      CountMap[E->getRHS()] = PGO.getCurrentRegionCount(); -      Visit(E->getRHS()); -      Cnt.adjustForControlFlow(); -      Cnt.applyAdjustmentsToRegion(0); -      RecordNextStmtCount = true; -    } +  void VisitAbstractConditionalOperator(const AbstractConditionalOperator *E) { +    RecordStmtCount(E); +    uint64_t ParentCount = CurrentCount; +    Visit(E->getCond()); + +    // Counter tracks the "true" part of a conditional operator. The +    // count in the "false" part will be calculated from this counter. +    uint64_t TrueCount = setCount(PGO.getRegionCount(E)); +    CountMap[E->getTrueExpr()] = TrueCount; +    Visit(E->getTrueExpr()); +    uint64_t OutCount = CurrentCount; + +    uint64_t FalseCount = setCount(ParentCount - TrueCount); +    CountMap[E->getFalseExpr()] = FalseCount; +    Visit(E->getFalseExpr()); +    OutCount += CurrentCount; + +    setCount(OutCount); +    RecordNextStmtCount = true; +  } -    void VisitBinLOr(const BinaryOperator *E) { -      RecordStmtCount(E); -      // Counter tracks the right hand side of a logical or operator. -      RegionCounter Cnt(PGO, E); -      Visit(E->getLHS()); -      Cnt.beginRegion(); -      CountMap[E->getRHS()] = PGO.getCurrentRegionCount(); -      Visit(E->getRHS()); -      Cnt.adjustForControlFlow(); -      Cnt.applyAdjustmentsToRegion(0); -      RecordNextStmtCount = true; -    } -  }; +  void VisitBinLAnd(const BinaryOperator *E) { +    RecordStmtCount(E); +    uint64_t ParentCount = CurrentCount; +    Visit(E->getLHS()); +    // Counter tracks the right hand side of a logical and operator. +    uint64_t RHSCount = setCount(PGO.getRegionCount(E)); +    CountMap[E->getRHS()] = RHSCount; +    Visit(E->getRHS()); +    setCount(ParentCount + RHSCount - CurrentCount); +    RecordNextStmtCount = true; +  } + +  void VisitBinLOr(const BinaryOperator *E) { +    RecordStmtCount(E); +    uint64_t ParentCount = CurrentCount; +    Visit(E->getLHS()); +    // Counter tracks the right hand side of a logical or operator. +    uint64_t RHSCount = setCount(PGO.getRegionCount(E)); +    CountMap[E->getRHS()] = RHSCount; +    Visit(E->getRHS()); +    setCount(ParentCount + RHSCount - CurrentCount); +    RecordNextStmtCount = true; +  } +};  }  void PGOHash::combine(HashType Type) { @@ -728,12 +719,10 @@ void CodeGenPGO::emitCounterRegionMapping(const Decl *D) {  }  void -CodeGenPGO::emitEmptyCounterMapping(const Decl *D, StringRef FuncName, +CodeGenPGO::emitEmptyCounterMapping(const Decl *D, StringRef Name,                                      llvm::GlobalValue::LinkageTypes Linkage) {    if (SkipCoverageMapping)      return; -  setFuncName(FuncName, Linkage); -    // Don't map the functions inside the system headers    auto Loc = D->getBody()->getLocStart();    if (CGM.getContext().getSourceManager().isInSystemHeader(Loc)) @@ -750,6 +739,7 @@ CodeGenPGO::emitEmptyCounterMapping(const Decl *D, StringRef FuncName,    if (CoverageMapping.empty())      return; +  setFuncName(Name, Linkage);    CGM.getCoverageMapping()->addFunctionMappingRecord(        FuncNameVar, FuncName, FunctionHash, CoverageMapping);  } @@ -785,17 +775,19 @@ CodeGenPGO::applyFunctionAttributes(llvm::IndexedInstrProfReader *PGOReader,      Fn->addFnAttr(llvm::Attribute::Cold);  } -void CodeGenPGO::emitCounterIncrement(CGBuilderTy &Builder, unsigned Counter) { +void CodeGenPGO::emitCounterIncrement(CGBuilderTy &Builder, const Stmt *S) {    if (!CGM.getCodeGenOpts().ProfileInstrGenerate || !RegionCounterMap)      return;    if (!Builder.GetInsertPoint())      return; + +  unsigned Counter = (*RegionCounterMap)[S];    auto *I8PtrTy = llvm::Type::getInt8PtrTy(CGM.getLLVMContext()); -  Builder.CreateCall4(CGM.getIntrinsic(llvm::Intrinsic::instrprof_increment), -                      llvm::ConstantExpr::getBitCast(FuncNameVar, I8PtrTy), +  Builder.CreateCall(CGM.getIntrinsic(llvm::Intrinsic::instrprof_increment), +                     {llvm::ConstantExpr::getBitCast(FuncNameVar, I8PtrTy),                        Builder.getInt64(FunctionHash),                        Builder.getInt32(NumRegionCounters), -                      Builder.getInt32(Counter)); +                      Builder.getInt32(Counter)});  }  void CodeGenPGO::loadRegionCounts(llvm::IndexedInstrProfReader *PGOReader, @@ -839,8 +831,8 @@ static uint32_t scaleBranchWeight(uint64_t Weight, uint64_t Scale) {    return Scaled;  } -llvm::MDNode *CodeGenPGO::createBranchWeights(uint64_t TrueCount, -                                              uint64_t FalseCount) { +llvm::MDNode *CodeGenFunction::createProfileWeights(uint64_t TrueCount, +                                                    uint64_t FalseCount) {    // Check for empty weights.    if (!TrueCount && !FalseCount)      return nullptr; @@ -853,7 +845,8 @@ llvm::MDNode *CodeGenPGO::createBranchWeights(uint64_t TrueCount,                                        scaleBranchWeight(FalseCount, Scale));  } -llvm::MDNode *CodeGenPGO::createBranchWeights(ArrayRef<uint64_t> Weights) { +llvm::MDNode * +CodeGenFunction::createProfileWeights(ArrayRef<uint64_t> Weights) {    // We need at least two elements to create meaningful weights.    if (Weights.size() < 2)      return nullptr; @@ -875,17 +868,14 @@ llvm::MDNode *CodeGenPGO::createBranchWeights(ArrayRef<uint64_t> Weights) {    return MDHelper.createBranchWeights(ScaledWeights);  } -llvm::MDNode *CodeGenPGO::createLoopWeights(const Stmt *Cond, -                                            RegionCounter &Cnt) { -  if (!haveRegionCounts()) +llvm::MDNode *CodeGenFunction::createProfileWeightsForLoop(const Stmt *Cond, +                                                           uint64_t LoopCount) { +  if (!PGO.haveRegionCounts())      return nullptr; -  uint64_t LoopCount = Cnt.getCount(); -  uint64_t CondCount = 0; -  bool Found = getStmtCount(Cond, CondCount); -  assert(Found && "missing expected loop condition count"); -  (void)Found; -  if (CondCount == 0) +  Optional<uint64_t> CondCount = PGO.getStmtCount(Cond); +  assert(CondCount.hasValue() && "missing expected loop condition count"); +  if (*CondCount == 0)      return nullptr; -  return createBranchWeights(LoopCount, -                             std::max(CondCount, LoopCount) - LoopCount); +  return createProfileWeights(LoopCount, +                              std::max(*CondCount, LoopCount) - LoopCount);  }  | 
