Diffstat (limited to 'lib/Transforms')
142 files changed, 17793 insertions, 10977 deletions
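A recurring change in the diffs below is replacing calls to the removed `Function::getArgumentList()` accessor with the iterator-based argument API (`F.args()`, `F.arg_begin()`). A minimal sketch of the new idiom, assuming a post-4.0 LLVM where these accessors exist; the helper name `collectPointerArgs` is invented for illustration and is not part of the patch:

```cpp
#include "llvm/ADT/SmallVector.h"
#include "llvm/IR/Argument.h"
#include "llvm/IR/Function.h"

using namespace llvm;

// Collect the pointer-typed arguments of F, mirroring the loops in the patch
// that were rewritten from getArgumentList() to the args() range.
static SmallVector<Argument *, 16> collectPointerArgs(Function &F) {
  SmallVector<Argument *, 16> PointerArgs;
  for (Argument &A : F.args()) // previously: F.getArgumentList()
    if (A.getType()->isPointerTy())
      PointerArgs.push_back(&A);
  return PointerArgs;
}
```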
diff --git a/lib/Transforms/Coroutines/CoroElide.cpp b/lib/Transforms/Coroutines/CoroElide.cpp
index 99974d8da64c..c6ac3f614ff7 100644
--- a/lib/Transforms/Coroutines/CoroElide.cpp
+++ b/lib/Transforms/Coroutines/CoroElide.cpp
@@ -92,7 +92,7 @@ static void removeTailCallAttribute(AllocaInst *Frame, AAResults &AA) {
 
 // Given a resume function @f.resume(%f.frame* %frame), returns %f.frame type.
 static Type *getFrameType(Function *Resume) {
-  auto *ArgType = Resume->getArgumentList().front().getType();
+  auto *ArgType = Resume->arg_begin()->getType();
   return cast<PointerType>(ArgType)->getElementType();
 }
 
@@ -127,7 +127,8 @@ void Lowerer::elideHeapAllocations(Function *F, Type *FrameTy, AAResults &AA) {
   // is spilled into the coroutine frame and recreate the alignment information
   // here. Possibly we will need to do a mini SROA here and break the coroutine
   // frame into individual AllocaInst recreating the original alignment.
-  auto *Frame = new AllocaInst(FrameTy, "", InsertPt);
+  const DataLayout &DL = F->getParent()->getDataLayout();
+  auto *Frame = new AllocaInst(FrameTy, DL.getAllocaAddrSpace(), "", InsertPt);
   auto *FrameVoidPtr =
       new BitCastInst(Frame, Type::getInt8PtrTy(C), "vFrame", InsertPt);
 
diff --git a/lib/Transforms/Coroutines/CoroFrame.cpp b/lib/Transforms/Coroutines/CoroFrame.cpp
index bb28558a29e2..19e6789dfa74 100644
--- a/lib/Transforms/Coroutines/CoroFrame.cpp
+++ b/lib/Transforms/Coroutines/CoroFrame.cpp
@@ -133,6 +133,7 @@ struct SuspendCrossingInfo {
 };
 } // end anonymous namespace
 
+#if !defined(NDEBUG) || defined(LLVM_ENABLE_DUMP)
 LLVM_DUMP_METHOD void SuspendCrossingInfo::dump(StringRef Label,
                                                 BitVector const &BV) const {
   dbgs() << Label << ":";
@@ -151,6 +152,7 @@ LLVM_DUMP_METHOD void SuspendCrossingInfo::dump() const {
   }
   dbgs() << "\n";
 }
+#endif
 
 SuspendCrossingInfo::SuspendCrossingInfo(Function &F, coro::Shape &Shape)
     : Mapping(F) {
@@ -420,15 +422,31 @@ static Instruction *insertSpills(SpillInfo &Spills, coro::Shape &Shape) {
           report_fatal_error("Coroutines cannot handle non static allocas yet");
       } else {
         // Otherwise, create a store instruction storing the value into the
-        // coroutine frame. For, argument, we will place the store instruction
-        // right after the coroutine frame pointer instruction, i.e. bitcase of
-        // coro.begin from i8* to %f.frame*. For all other values, the spill is
-        // placed immediately after the definition.
-        Builder.SetInsertPoint(
-            isa<Argument>(CurrentValue)
-                ? FramePtr->getNextNode()
-                : dyn_cast<Instruction>(E.def())->getNextNode());
+        // coroutine frame.
+
+        Instruction *InsertPt = nullptr;
+        if (isa<Argument>(CurrentValue)) {
+          // For arguments, we will place the store instruction right after
+          // the coroutine frame pointer instruction, i.e. bitcast of
+          // coro.begin from i8* to %f.frame*.
+          InsertPt = FramePtr->getNextNode();
+        } else if (auto *II = dyn_cast<InvokeInst>(CurrentValue)) {
+          // If we are spilling the result of the invoke instruction, split the
+          // normal edge and insert the spill in the new block.
+          auto NewBB = SplitEdge(II->getParent(), II->getNormalDest());
+          InsertPt = NewBB->getTerminator();
+        } else if (dyn_cast<PHINode>(CurrentValue)) {
+          // Skip the PHINodes and EH pads instructions.
+          InsertPt =
+              &*cast<Instruction>(E.def())->getParent()->getFirstInsertionPt();
+        } else {
+          // For all other values, the spill is placed immediately after
+          // the definition.
+          assert(!isa<TerminatorInst>(E.def()) && "unexpected terminator");
+          InsertPt = cast<Instruction>(E.def())->getNextNode();
+        }
+        Builder.SetInsertPoint(InsertPt);
         auto *G = Builder.CreateConstInBoundsGEP2_32(
             FrameTy, FramePtr, 0, Index,
             CurrentValue->getName() + Twine(".spill.addr"));
@@ -484,7 +502,7 @@ static void rewritePHIs(BasicBlock &BB) {
   // loop:
   //    %n.val = phi i32[%n, %entry], [%inc, %loop]
   //
-  // It will create:
+  // It will create:
   //
   // loop.from.entry:
   //    %n.loop.pre = phi i32 [%n, %entry]
@@ -687,13 +705,12 @@ void coro::buildCoroutineFrame(Function &F, Shape &Shape) {
           Spills.emplace_back(&I, U);
 
   // Rewrite materializable instructions to be materialized at the use point.
-  std::sort(Spills.begin(), Spills.end());
   DEBUG(dump("Materializations", Spills));
   rewriteMaterializableInstructions(Builder, Spills);
 
   // Collect the spills for arguments and other not-materializable values.
   Spills.clear();
-  for (Argument &A : F.getArgumentList())
+  for (Argument &A : F.args())
     for (User *U : A.users())
       if (Checker.isDefinitionAcrossSuspend(A, U))
         Spills.emplace_back(&A, U);
@@ -719,7 +736,6 @@ void coro::buildCoroutineFrame(Function &F, Shape &Shape) {
         Spills.emplace_back(&I, U);
       }
   }
-  std::sort(Spills.begin(), Spills.end());
   DEBUG(dump("Spills", Spills));
   moveSpillUsesAfterCoroBegin(F, Spills, Shape.CoroBegin);
   Shape.FrameTy = buildFrameType(F, Shape, Spills);
diff --git a/lib/Transforms/Coroutines/CoroInstr.h b/lib/Transforms/Coroutines/CoroInstr.h
index e03cef4bfc46..5c666bdfea1f 100644
--- a/lib/Transforms/Coroutines/CoroInstr.h
+++ b/lib/Transforms/Coroutines/CoroInstr.h
@@ -23,6 +23,9 @@
 // the Coroutine library.
 //===----------------------------------------------------------------------===//
 
+#ifndef LLVM_LIB_TRANSFORMS_COROUTINES_COROINSTR_H
+#define LLVM_LIB_TRANSFORMS_COROUTINES_COROINSTR_H
+
 #include "llvm/IR/GlobalVariable.h"
 #include "llvm/IR/IntrinsicInst.h"
 
@@ -316,3 +319,5 @@ public:
 };
 
 } // End namespace llvm.
+
+#endif
diff --git a/lib/Transforms/Coroutines/CoroSplit.cpp b/lib/Transforms/Coroutines/CoroSplit.cpp
index 7a3f4f60bae9..ab648f884c5b 100644
--- a/lib/Transforms/Coroutines/CoroSplit.cpp
+++ b/lib/Transforms/Coroutines/CoroSplit.cpp
@@ -22,6 +22,7 @@
 #include "CoroInternal.h"
 #include "llvm/Analysis/CallGraphSCCPass.h"
 #include "llvm/IR/DebugInfoMetadata.h"
+#include "llvm/IR/InstIterator.h"
 #include "llvm/IR/IRBuilder.h"
 #include "llvm/IR/LegacyPassManager.h"
 #include "llvm/IR/Verifier.h"
@@ -144,6 +145,33 @@ static void replaceFallthroughCoroEnd(IntrinsicInst *End,
   BB->getTerminator()->eraseFromParent();
 }
 
+// In Resumers, we replace unwind coro.end with True to force the immediate
+// unwind to caller.
+static void replaceUnwindCoroEnds(coro::Shape &Shape, ValueToValueMapTy &VMap) {
+  if (Shape.CoroEnds.empty())
+    return;
+
+  LLVMContext &Context = Shape.CoroEnds.front()->getContext();
+  auto *True = ConstantInt::getTrue(Context);
+  for (CoroEndInst *CE : Shape.CoroEnds) {
+    if (!CE->isUnwind())
+      continue;
+
+    auto *NewCE = cast<IntrinsicInst>(VMap[CE]);
+
+    // If coro.end has an associated bundle, add cleanupret instruction.
+    if (auto Bundle = NewCE->getOperandBundle(LLVMContext::OB_funclet)) { +      Value *FromPad = Bundle->Inputs[0]; +      auto *CleanupRet = CleanupReturnInst::Create(FromPad, nullptr, NewCE); +      NewCE->getParent()->splitBasicBlock(NewCE); +      CleanupRet->getParent()->getTerminator()->eraseFromParent(); +    } + +    NewCE->replaceAllUsesWith(True); +    NewCE->eraseFromParent(); +  } +} +  // Rewrite final suspend point handling. We do not use suspend index to  // represent the final suspend point. Instead we zero-out ResumeFnAddr in the  // coroutine frame, since it is undefined behavior to resume a coroutine @@ -157,9 +185,9 @@ static void handleFinalSuspend(IRBuilder<> &Builder, Value *FramePtr,                                 coro::Shape &Shape, SwitchInst *Switch,                                 bool IsDestroy) {    assert(Shape.HasFinalSuspend); -  auto FinalCase = --Switch->case_end(); -  BasicBlock *ResumeBB = FinalCase.getCaseSuccessor(); -  Switch->removeCase(FinalCase); +  auto FinalCaseIt = std::prev(Switch->case_end()); +  BasicBlock *ResumeBB = FinalCaseIt->getCaseSuccessor(); +  Switch->removeCase(FinalCaseIt);    if (IsDestroy) {      BasicBlock *OldSwitchBB = Switch->getParent();      auto *NewSwitchBB = OldSwitchBB->splitBasicBlock(Switch, "Switch"); @@ -195,7 +223,7 @@ static Function *createClone(Function &F, Twine Suffix, coro::Shape &Shape,    // Replace all args with undefs. The buildCoroutineFrame algorithm already    // rewritten access to the args that occurs after suspend points with loads    // and stores to/from the coroutine frame. -  for (Argument &A : F.getArgumentList()) +  for (Argument &A : F.args())      VMap[&A] = UndefValue::get(A.getType());    SmallVector<ReturnInst *, 4> Returns; @@ -216,9 +244,9 @@ static Function *createClone(Function &F, Twine Suffix, coro::Shape &Shape,    // Remove old return attributes.    NewF->removeAttributes( -      AttributeSet::ReturnIndex, -      AttributeSet::get( -          NewF->getContext(), AttributeSet::ReturnIndex, +      AttributeList::ReturnIndex, +      AttributeList::get( +          NewF->getContext(), AttributeList::ReturnIndex,            AttributeFuncs::typeIncompatible(NewF->getReturnType())));    // Make AllocaSpillBlock the new entry block. @@ -236,7 +264,7 @@ static Function *createClone(Function &F, Twine Suffix, coro::Shape &Shape,    IRBuilder<> Builder(&NewF->getEntryBlock().front());    // Remap frame pointer. -  Argument *NewFramePtr = &NewF->getArgumentList().front(); +  Argument *NewFramePtr = &*NewF->arg_begin();    Value *OldFramePtr = cast<Value>(VMap[Shape.FramePtr]);    NewFramePtr->takeName(OldFramePtr);    OldFramePtr->replaceAllUsesWith(NewFramePtr); @@ -270,9 +298,7 @@ static Function *createClone(Function &F, Twine Suffix, coro::Shape &Shape,    // Remove coro.end intrinsics.    replaceFallthroughCoroEnd(Shape.CoroEnds.front(), VMap); -  // FIXME: coming in upcoming patches: -  // replaceUnwindCoroEnds(Shape.CoroEnds, VMap); - +  replaceUnwindCoroEnds(Shape, VMap);    // Eliminate coro.free from the clones, replacing it with 'null' in cleanup,    // to suppress deallocation code.    
coro::replaceCoroFree(cast<CoroIdInst>(VMap[Shape.CoroBegin->getId()]), @@ -284,8 +310,16 @@ static Function *createClone(Function &F, Twine Suffix, coro::Shape &Shape,  }  static void removeCoroEnds(coro::Shape &Shape) { -  for (CoroEndInst *CE : Shape.CoroEnds) +  if (Shape.CoroEnds.empty()) +    return; + +  LLVMContext &Context = Shape.CoroEnds.front()->getContext(); +  auto *False = ConstantInt::getFalse(Context); + +  for (CoroEndInst *CE : Shape.CoroEnds) { +    CE->replaceAllUsesWith(False);      CE->eraseFromParent(); +  }  }  static void replaceFrameSize(coro::Shape &Shape) { diff --git a/lib/Transforms/Coroutines/Coroutines.cpp b/lib/Transforms/Coroutines/Coroutines.cpp index 877ec34b4d3b..ea48043f9381 100644 --- a/lib/Transforms/Coroutines/Coroutines.cpp +++ b/lib/Transforms/Coroutines/Coroutines.cpp @@ -245,9 +245,9 @@ void coro::Shape::buildFrom(Function &F) {            if (CoroBegin)              report_fatal_error(                  "coroutine should have exactly one defining @llvm.coro.begin"); -          CB->addAttribute(AttributeSet::ReturnIndex, Attribute::NonNull); -          CB->addAttribute(AttributeSet::ReturnIndex, Attribute::NoAlias); -          CB->removeAttribute(AttributeSet::FunctionIndex, +          CB->addAttribute(AttributeList::ReturnIndex, Attribute::NonNull); +          CB->addAttribute(AttributeList::ReturnIndex, Attribute::NoAlias); +          CB->removeAttribute(AttributeList::FunctionIndex,                                Attribute::NoDuplicate);            CoroBegin = CB;          } diff --git a/lib/Transforms/IPO/ArgumentPromotion.cpp b/lib/Transforms/IPO/ArgumentPromotion.cpp index 65b7bad3b1ed..a2c8a32dfe86 100644 --- a/lib/Transforms/IPO/ArgumentPromotion.cpp +++ b/lib/Transforms/IPO/ArgumentPromotion.cpp @@ -29,8 +29,9 @@  //  //===----------------------------------------------------------------------===// -#include "llvm/Transforms/IPO.h" +#include "llvm/Transforms/IPO/ArgumentPromotion.h"  #include "llvm/ADT/DepthFirstIterator.h" +#include "llvm/ADT/Optional.h"  #include "llvm/ADT/Statistic.h"  #include "llvm/ADT/StringExtras.h"  #include "llvm/Analysis/AliasAnalysis.h" @@ -38,6 +39,7 @@  #include "llvm/Analysis/BasicAliasAnalysis.h"  #include "llvm/Analysis/CallGraph.h"  #include "llvm/Analysis/CallGraphSCCPass.h" +#include "llvm/Analysis/LazyCallGraph.h"  #include "llvm/Analysis/Loads.h"  #include "llvm/Analysis/TargetLibraryInfo.h"  #include "llvm/IR/CFG.h" @@ -51,323 +53,400 @@  #include "llvm/IR/Module.h"  #include "llvm/Support/Debug.h"  #include "llvm/Support/raw_ostream.h" +#include "llvm/Transforms/IPO.h"  #include <set>  using namespace llvm;  #define DEBUG_TYPE "argpromotion" -STATISTIC(NumArgumentsPromoted , "Number of pointer arguments promoted"); +STATISTIC(NumArgumentsPromoted, "Number of pointer arguments promoted");  STATISTIC(NumAggregatesPromoted, "Number of aggregate arguments promoted"); -STATISTIC(NumByValArgsPromoted , "Number of byval arguments promoted"); -STATISTIC(NumArgumentsDead     , "Number of dead pointer args eliminated"); +STATISTIC(NumByValArgsPromoted, "Number of byval arguments promoted"); +STATISTIC(NumArgumentsDead, "Number of dead pointer args eliminated"); -namespace { -  /// ArgPromotion - The 'by reference' to 'by value' argument promotion pass. 
-  /// -  struct ArgPromotion : public CallGraphSCCPass { -    void getAnalysisUsage(AnalysisUsage &AU) const override { -      AU.addRequired<AssumptionCacheTracker>(); -      AU.addRequired<TargetLibraryInfoWrapperPass>(); -      getAAResultsAnalysisUsage(AU); -      CallGraphSCCPass::getAnalysisUsage(AU); -    } +/// A vector used to hold the indices of a single GEP instruction +typedef std::vector<uint64_t> IndicesVector; -    bool runOnSCC(CallGraphSCC &SCC) override; -    static char ID; // Pass identification, replacement for typeid -    explicit ArgPromotion(unsigned maxElements = 3) -        : CallGraphSCCPass(ID), maxElements(maxElements) { -      initializeArgPromotionPass(*PassRegistry::getPassRegistry()); -    } +/// DoPromotion - This method actually performs the promotion of the specified +/// arguments, and returns the new function.  At this point, we know that it's +/// safe to do so. +static Function * +doPromotion(Function *F, SmallPtrSetImpl<Argument *> &ArgsToPromote, +            SmallPtrSetImpl<Argument *> &ByValArgsToTransform, +            Optional<function_ref<void(CallSite OldCS, CallSite NewCS)>> +                ReplaceCallSite) { -  private: +  // Start by computing a new prototype for the function, which is the same as +  // the old function, but has modified arguments. +  FunctionType *FTy = F->getFunctionType(); +  std::vector<Type *> Params; -    using llvm::Pass::doInitialization; -    bool doInitialization(CallGraph &CG) override; -    /// The maximum number of elements to expand, or 0 for unlimited. -    unsigned maxElements; -  }; -} +  typedef std::set<std::pair<Type *, IndicesVector>> ScalarizeTable; -/// A vector used to hold the indices of a single GEP instruction -typedef std::vector<uint64_t> IndicesVector; +  // ScalarizedElements - If we are promoting a pointer that has elements +  // accessed out of it, keep track of which elements are accessed so that we +  // can add one argument for each. +  // +  // Arguments that are directly loaded will have a zero element value here, to +  // handle cases where there are both a direct load and GEP accesses. +  // +  std::map<Argument *, ScalarizeTable> ScalarizedElements; -static CallGraphNode * -PromoteArguments(CallGraphNode *CGN, CallGraph &CG, -                 function_ref<AAResults &(Function &F)> AARGetter, -                 unsigned MaxElements); -static bool isDenselyPacked(Type *type, const DataLayout &DL); -static bool canPaddingBeAccessed(Argument *Arg); -static bool isSafeToPromoteArgument(Argument *Arg, bool isByVal, AAResults &AAR, -                                    unsigned MaxElements); -static CallGraphNode * -DoPromotion(Function *F, SmallPtrSetImpl<Argument *> &ArgsToPromote, -            SmallPtrSetImpl<Argument *> &ByValArgsToTransform, CallGraph &CG); +  // OriginalLoads - Keep track of a representative load instruction from the +  // original function so that we can tell the alias analysis implementation +  // what the new GEP/Load instructions we are inserting look like. +  // We need to keep the original loads for each argument and the elements +  // of the argument that are accessed. 
+  std::map<std::pair<Argument *, IndicesVector>, LoadInst *> OriginalLoads; -char ArgPromotion::ID = 0; -INITIALIZE_PASS_BEGIN(ArgPromotion, "argpromotion", -                "Promote 'by reference' arguments to scalars", false, false) -INITIALIZE_PASS_DEPENDENCY(AssumptionCacheTracker) -INITIALIZE_PASS_DEPENDENCY(CallGraphWrapperPass) -INITIALIZE_PASS_DEPENDENCY(TargetLibraryInfoWrapperPass) -INITIALIZE_PASS_END(ArgPromotion, "argpromotion", -                "Promote 'by reference' arguments to scalars", false, false) +  // Attribute - Keep track of the parameter attributes for the arguments +  // that we are *not* promoting. For the ones that we do promote, the parameter +  // attributes are lost +  SmallVector<AttributeSet, 8> ArgAttrVec; +  AttributeList PAL = F->getAttributes(); -Pass *llvm::createArgumentPromotionPass(unsigned maxElements) { -  return new ArgPromotion(maxElements); -} +  // First, determine the new argument list +  unsigned ArgIndex = 0; +  for (Function::arg_iterator I = F->arg_begin(), E = F->arg_end(); I != E; +       ++I, ++ArgIndex) { +    if (ByValArgsToTransform.count(&*I)) { +      // Simple byval argument? Just add all the struct element types. +      Type *AgTy = cast<PointerType>(I->getType())->getElementType(); +      StructType *STy = cast<StructType>(AgTy); +      Params.insert(Params.end(), STy->element_begin(), STy->element_end()); +      ArgAttrVec.insert(ArgAttrVec.end(), STy->getNumElements(), +                        AttributeSet()); +      ++NumByValArgsPromoted; +    } else if (!ArgsToPromote.count(&*I)) { +      // Unchanged argument +      Params.push_back(I->getType()); +      ArgAttrVec.push_back(PAL.getParamAttributes(ArgIndex)); +    } else if (I->use_empty()) { +      // Dead argument (which are always marked as promotable) +      ++NumArgumentsDead; +    } else { +      // Okay, this is being promoted. This means that the only uses are loads +      // or GEPs which are only used by loads -static bool runImpl(CallGraphSCC &SCC, CallGraph &CG, -                    function_ref<AAResults &(Function &F)> AARGetter, -                    unsigned MaxElements) { -  bool Changed = false, LocalChange; +      // In this table, we will track which indices are loaded from the argument +      // (where direct loads are tracked as no indices). +      ScalarizeTable &ArgIndices = ScalarizedElements[&*I]; +      for (User *U : I->users()) { +        Instruction *UI = cast<Instruction>(U); +        Type *SrcTy; +        if (LoadInst *L = dyn_cast<LoadInst>(UI)) +          SrcTy = L->getType(); +        else +          SrcTy = cast<GetElementPtrInst>(UI)->getSourceElementType(); +        IndicesVector Indices; +        Indices.reserve(UI->getNumOperands() - 1); +        // Since loads will only have a single operand, and GEPs only a single +        // non-index operand, this will record direct loads without any indices, +        // and gep+loads with the GEP indices. 
+        for (User::op_iterator II = UI->op_begin() + 1, IE = UI->op_end(); +             II != IE; ++II) +          Indices.push_back(cast<ConstantInt>(*II)->getSExtValue()); +        // GEPs with a single 0 index can be merged with direct loads +        if (Indices.size() == 1 && Indices.front() == 0) +          Indices.clear(); +        ArgIndices.insert(std::make_pair(SrcTy, Indices)); +        LoadInst *OrigLoad; +        if (LoadInst *L = dyn_cast<LoadInst>(UI)) +          OrigLoad = L; +        else +          // Take any load, we will use it only to update Alias Analysis +          OrigLoad = cast<LoadInst>(UI->user_back()); +        OriginalLoads[std::make_pair(&*I, Indices)] = OrigLoad; +      } -  do {  // Iterate until we stop promoting from this SCC. -    LocalChange = false; -    // Attempt to promote arguments from all functions in this SCC. -    for (CallGraphNode *OldNode : SCC) { -      if (CallGraphNode *NewNode = -              PromoteArguments(OldNode, CG, AARGetter, MaxElements)) { -        LocalChange = true; -        SCC.ReplaceNode(OldNode, NewNode); +      // Add a parameter to the function for each element passed in. +      for (const auto &ArgIndex : ArgIndices) { +        // not allowed to dereference ->begin() if size() is 0 +        Params.push_back(GetElementPtrInst::getIndexedType( +            cast<PointerType>(I->getType()->getScalarType())->getElementType(), +            ArgIndex.second)); +        ArgAttrVec.push_back(AttributeSet()); +        assert(Params.back());        } + +      if (ArgIndices.size() == 1 && ArgIndices.begin()->second.empty()) +        ++NumArgumentsPromoted; +      else +        ++NumAggregatesPromoted;      } -    Changed |= LocalChange;               // Remember that we changed something. -  } while (LocalChange); -   -  return Changed; -} +  } -bool ArgPromotion::runOnSCC(CallGraphSCC &SCC) { -  if (skipSCC(SCC)) -    return false; +  Type *RetTy = FTy->getReturnType(); -  // Get the callgraph information that we need to update to reflect our -  // changes. -  CallGraph &CG = getAnalysis<CallGraphWrapperPass>().getCallGraph(); +  // Construct the new function type using the new arguments. +  FunctionType *NFTy = FunctionType::get(RetTy, Params, FTy->isVarArg()); -  // We compute dedicated AA results for each function in the SCC as needed. We -  // use a lambda referencing external objects so that they live long enough to -  // be queried, but we re-use them each time. -  Optional<BasicAAResult> BAR; -  Optional<AAResults> AAR; -  auto AARGetter = [&](Function &F) -> AAResults & { -    BAR.emplace(createLegacyPMBasicAAResult(*this, F)); -    AAR.emplace(createLegacyPMAAResults(*this, F, *BAR)); -    return *AAR; -  }; - -  return runImpl(SCC, CG, AARGetter, maxElements); -} +  // Create the new function body and insert it into the module. +  Function *NF = Function::Create(NFTy, F->getLinkage(), F->getName()); +  NF->copyAttributesFrom(F); -/// \brief Checks if a type could have padding bytes. -static bool isDenselyPacked(Type *type, const DataLayout &DL) { +  // Patch the pointer to LLVM function in debug info descriptor. +  NF->setSubprogram(F->getSubprogram()); +  F->setSubprogram(nullptr); -  // There is no size information, so be conservative. -  if (!type->isSized()) -    return false; +  DEBUG(dbgs() << "ARG PROMOTION:  Promoting to:" << *NF << "\n" +               << "From: " << *F); -  // If the alloc size is not equal to the storage size, then there are padding -  // bytes. 
For x86_fp80 on x86-64, size: 80 alloc size: 128. -  if (DL.getTypeSizeInBits(type) != DL.getTypeAllocSizeInBits(type)) -    return false; +  // Recompute the parameter attributes list based on the new arguments for +  // the function. +  NF->setAttributes(AttributeList::get(F->getContext(), PAL.getFnAttributes(), +                                       PAL.getRetAttributes(), ArgAttrVec)); +  ArgAttrVec.clear(); -  if (!isa<CompositeType>(type)) -    return true; +  F->getParent()->getFunctionList().insert(F->getIterator(), NF); +  NF->takeName(F); -  // For homogenous sequential types, check for padding within members. -  if (SequentialType *seqTy = dyn_cast<SequentialType>(type)) -    return isDenselyPacked(seqTy->getElementType(), DL); +  // Loop over all of the callers of the function, transforming the call sites +  // to pass in the loaded pointers. +  // +  SmallVector<Value *, 16> Args; +  while (!F->use_empty()) { +    CallSite CS(F->user_back()); +    assert(CS.getCalledFunction() == F); +    Instruction *Call = CS.getInstruction(); +    const AttributeList &CallPAL = CS.getAttributes(); -  // Check for padding within and between elements of a struct. -  StructType *StructTy = cast<StructType>(type); -  const StructLayout *Layout = DL.getStructLayout(StructTy); -  uint64_t StartPos = 0; -  for (unsigned i = 0, E = StructTy->getNumElements(); i < E; ++i) { -    Type *ElTy = StructTy->getElementType(i); -    if (!isDenselyPacked(ElTy, DL)) -      return false; -    if (StartPos != Layout->getElementOffsetInBits(i)) -      return false; -    StartPos += DL.getTypeAllocSizeInBits(ElTy); -  } +    // Loop over the operands, inserting GEP and loads in the caller as +    // appropriate. +    CallSite::arg_iterator AI = CS.arg_begin(); +    ArgIndex = 1; +    for (Function::arg_iterator I = F->arg_begin(), E = F->arg_end(); I != E; +         ++I, ++AI, ++ArgIndex) +      if (!ArgsToPromote.count(&*I) && !ByValArgsToTransform.count(&*I)) { +        Args.push_back(*AI); // Unmodified argument +        ArgAttrVec.push_back(CallPAL.getAttributes(ArgIndex)); +      } else if (ByValArgsToTransform.count(&*I)) { +        // Emit a GEP and load for each element of the struct. +        Type *AgTy = cast<PointerType>(I->getType())->getElementType(); +        StructType *STy = cast<StructType>(AgTy); +        Value *Idxs[2] = { +            ConstantInt::get(Type::getInt32Ty(F->getContext()), 0), nullptr}; +        for (unsigned i = 0, e = STy->getNumElements(); i != e; ++i) { +          Idxs[1] = ConstantInt::get(Type::getInt32Ty(F->getContext()), i); +          Value *Idx = GetElementPtrInst::Create( +              STy, *AI, Idxs, (*AI)->getName() + "." + Twine(i), Call); +          // TODO: Tell AA about the new values? +          Args.push_back(new LoadInst(Idx, Idx->getName() + ".val", Call)); +          ArgAttrVec.push_back(AttributeSet()); +        } +      } else if (!I->use_empty()) { +        // Non-dead argument: insert GEPs and loads as appropriate. +        ScalarizeTable &ArgIndices = ScalarizedElements[&*I]; +        // Store the Value* version of the indices in here, but declare it now +        // for reuse. 
+        std::vector<Value *> Ops; +        for (const auto &ArgIndex : ArgIndices) { +          Value *V = *AI; +          LoadInst *OrigLoad = +              OriginalLoads[std::make_pair(&*I, ArgIndex.second)]; +          if (!ArgIndex.second.empty()) { +            Ops.reserve(ArgIndex.second.size()); +            Type *ElTy = V->getType(); +            for (unsigned long II : ArgIndex.second) { +              // Use i32 to index structs, and i64 for others (pointers/arrays). +              // This satisfies GEP constraints. +              Type *IdxTy = +                  (ElTy->isStructTy() ? Type::getInt32Ty(F->getContext()) +                                      : Type::getInt64Ty(F->getContext())); +              Ops.push_back(ConstantInt::get(IdxTy, II)); +              // Keep track of the type we're currently indexing. +              if (auto *ElPTy = dyn_cast<PointerType>(ElTy)) +                ElTy = ElPTy->getElementType(); +              else +                ElTy = cast<CompositeType>(ElTy)->getTypeAtIndex(II); +            } +            // And create a GEP to extract those indices. +            V = GetElementPtrInst::Create(ArgIndex.first, V, Ops, +                                          V->getName() + ".idx", Call); +            Ops.clear(); +          } +          // Since we're replacing a load make sure we take the alignment +          // of the previous load. +          LoadInst *newLoad = new LoadInst(V, V->getName() + ".val", Call); +          newLoad->setAlignment(OrigLoad->getAlignment()); +          // Transfer the AA info too. +          AAMDNodes AAInfo; +          OrigLoad->getAAMetadata(AAInfo); +          newLoad->setAAMetadata(AAInfo); -  return true; -} +          Args.push_back(newLoad); +          ArgAttrVec.push_back(AttributeSet()); +        } +      } -/// \brief Checks if the padding bytes of an argument could be accessed. -static bool canPaddingBeAccessed(Argument *arg) { +    // Push any varargs arguments on the list. +    for (; AI != CS.arg_end(); ++AI, ++ArgIndex) { +      Args.push_back(*AI); +      ArgAttrVec.push_back(CallPAL.getAttributes(ArgIndex)); +    } -  assert(arg->hasByValAttr()); +    SmallVector<OperandBundleDef, 1> OpBundles; +    CS.getOperandBundlesAsDefs(OpBundles); -  // Track all the pointers to the argument to make sure they are not captured. -  SmallPtrSet<Value *, 16> PtrValues; -  PtrValues.insert(arg); +    CallSite NewCS; +    if (InvokeInst *II = dyn_cast<InvokeInst>(Call)) { +      NewCS = InvokeInst::Create(NF, II->getNormalDest(), II->getUnwindDest(), +                                 Args, OpBundles, "", Call); +    } else { +      auto *NewCall = CallInst::Create(NF, Args, OpBundles, "", Call); +      NewCall->setTailCallKind(cast<CallInst>(Call)->getTailCallKind()); +      NewCS = NewCall; +    } +    NewCS.setCallingConv(CS.getCallingConv()); +    NewCS.setAttributes( +        AttributeList::get(F->getContext(), CallPAL.getFnAttributes(), +                           CallPAL.getRetAttributes(), ArgAttrVec)); +    NewCS->setDebugLoc(Call->getDebugLoc()); +    uint64_t W; +    if (Call->extractProfTotalWeight(W)) +      NewCS->setProfWeight(W); +    Args.clear(); +    ArgAttrVec.clear(); -  // Track all of the stores. -  SmallVector<StoreInst *, 16> Stores; +    // Update the callgraph to know that the callsite has been transformed. +    if (ReplaceCallSite) +      (*ReplaceCallSite)(CS, NewCS); -  // Scan through the uses recursively to make sure the pointer is always used -  // sanely. 
-  SmallVector<Value *, 16> WorkList; -  WorkList.insert(WorkList.end(), arg->user_begin(), arg->user_end()); -  while (!WorkList.empty()) { -    Value *V = WorkList.back(); -    WorkList.pop_back(); -    if (isa<GetElementPtrInst>(V) || isa<PHINode>(V)) { -      if (PtrValues.insert(V).second) -        WorkList.insert(WorkList.end(), V->user_begin(), V->user_end()); -    } else if (StoreInst *Store = dyn_cast<StoreInst>(V)) { -      Stores.push_back(Store); -    } else if (!isa<LoadInst>(V)) { -      return true; +    if (!Call->use_empty()) { +      Call->replaceAllUsesWith(NewCS.getInstruction()); +      NewCS->takeName(Call);      } -  } -// Check to make sure the pointers aren't captured -  for (StoreInst *Store : Stores) -    if (PtrValues.count(Store->getValueOperand())) -      return true; - -  return false; -} +    // Finally, remove the old call from the program, reducing the use-count of +    // F. +    Call->eraseFromParent(); +  } -/// PromoteArguments - This method checks the specified function to see if there -/// are any promotable arguments and if it is safe to promote the function (for -/// example, all callers are direct).  If safe to promote some arguments, it -/// calls the DoPromotion method. -/// -static CallGraphNode * -PromoteArguments(CallGraphNode *CGN, CallGraph &CG, -                 function_ref<AAResults &(Function &F)> AARGetter, -                 unsigned MaxElements) { -  Function *F = CGN->getFunction(); +  const DataLayout &DL = F->getParent()->getDataLayout(); -  // Make sure that it is local to this module. -  if (!F || !F->hasLocalLinkage()) return nullptr; +  // Since we have now created the new function, splice the body of the old +  // function right into the new function, leaving the old rotting hulk of the +  // function empty. +  NF->getBasicBlockList().splice(NF->begin(), F->getBasicBlockList()); -  // Don't promote arguments for variadic functions. Adding, removing, or -  // changing non-pack parameters can change the classification of pack -  // parameters. Frontends encode that classification at the call site in the -  // IR, while in the callee the classification is determined dynamically based -  // on the number of registers consumed so far. -  if (F->isVarArg()) return nullptr; +  // Loop over the argument list, transferring uses of the old arguments over to +  // the new arguments, also transferring over the names as well. +  // +  for (Function::arg_iterator I = F->arg_begin(), E = F->arg_end(), +                              I2 = NF->arg_begin(); +       I != E; ++I) { +    if (!ArgsToPromote.count(&*I) && !ByValArgsToTransform.count(&*I)) { +      // If this is an unmodified argument, move the name and users over to the +      // new version. +      I->replaceAllUsesWith(&*I2); +      I2->takeName(&*I); +      ++I2; +      continue; +    } -  // First check: see if there are any pointer arguments!  If not, quick exit. -  SmallVector<Argument*, 16> PointerArgs; -  for (Argument &I : F->args()) -    if (I.getType()->isPointerTy()) -      PointerArgs.push_back(&I); -  if (PointerArgs.empty()) return nullptr; +    if (ByValArgsToTransform.count(&*I)) { +      // In the callee, we create an alloca, and store each of the new incoming +      // arguments into the alloca. +      Instruction *InsertPt = &NF->begin()->front(); -  // Second check: make sure that all callers are direct callers.  We can't -  // transform functions that have indirect callers.  Also see if the function -  // is self-recursive. 
-  bool isSelfRecursive = false; -  for (Use &U : F->uses()) { -    CallSite CS(U.getUser()); -    // Must be a direct call. -    if (CS.getInstruction() == nullptr || !CS.isCallee(&U)) return nullptr; -     -    if (CS.getInstruction()->getParent()->getParent() == F) -      isSelfRecursive = true; -  } -   -  const DataLayout &DL = F->getParent()->getDataLayout(); +      // Just add all the struct element types. +      Type *AgTy = cast<PointerType>(I->getType())->getElementType(); +      Value *TheAlloca = new AllocaInst(AgTy, DL.getAllocaAddrSpace(), nullptr, +                                        "", InsertPt); +      StructType *STy = cast<StructType>(AgTy); +      Value *Idxs[2] = {ConstantInt::get(Type::getInt32Ty(F->getContext()), 0), +                        nullptr}; -  AAResults &AAR = AARGetter(*F); +      for (unsigned i = 0, e = STy->getNumElements(); i != e; ++i) { +        Idxs[1] = ConstantInt::get(Type::getInt32Ty(F->getContext()), i); +        Value *Idx = GetElementPtrInst::Create( +            AgTy, TheAlloca, Idxs, TheAlloca->getName() + "." + Twine(i), +            InsertPt); +        I2->setName(I->getName() + "." + Twine(i)); +        new StoreInst(&*I2++, Idx, InsertPt); +      } -  // Check to see which arguments are promotable.  If an argument is promotable, -  // add it to ArgsToPromote. -  SmallPtrSet<Argument*, 8> ArgsToPromote; -  SmallPtrSet<Argument*, 8> ByValArgsToTransform; -  for (Argument *PtrArg : PointerArgs) { -    Type *AgTy = cast<PointerType>(PtrArg->getType())->getElementType(); +      // Anything that used the arg should now use the alloca. +      I->replaceAllUsesWith(TheAlloca); +      TheAlloca->takeName(&*I); -    // Replace sret attribute with noalias. This reduces register pressure by -    // avoiding a register copy. -    if (PtrArg->hasStructRetAttr()) { -      unsigned ArgNo = PtrArg->getArgNo(); -      F->setAttributes( -          F->getAttributes() -              .removeAttribute(F->getContext(), ArgNo + 1, Attribute::StructRet) -              .addAttribute(F->getContext(), ArgNo + 1, Attribute::NoAlias)); -      for (Use &U : F->uses()) { -        CallSite CS(U.getUser()); -        CS.setAttributes( -            CS.getAttributes() -                .removeAttribute(F->getContext(), ArgNo + 1, -                                 Attribute::StructRet) -                .addAttribute(F->getContext(), ArgNo + 1, Attribute::NoAlias)); +      // If the alloca is used in a call, we must clear the tail flag since +      // the callee now uses an alloca from the caller. +      for (User *U : TheAlloca->users()) { +        CallInst *Call = dyn_cast<CallInst>(U); +        if (!Call) +          continue; +        Call->setTailCall(false);        } +      continue;      } -    // If this is a byval argument, and if the aggregate type is small, just -    // pass the elements, which is always safe, if the passed value is densely -    // packed or if we can prove the padding bytes are never accessed. This does -    // not apply to inalloca. 
-    bool isSafeToPromote = -        PtrArg->hasByValAttr() && -        (isDenselyPacked(AgTy, DL) || !canPaddingBeAccessed(PtrArg)); -    if (isSafeToPromote) { -      if (StructType *STy = dyn_cast<StructType>(AgTy)) { -        if (MaxElements > 0 && STy->getNumElements() > MaxElements) { -          DEBUG(dbgs() << "argpromotion disable promoting argument '" -                << PtrArg->getName() << "' because it would require adding more" -                << " than " << MaxElements << " arguments to the function.\n"); -          continue; -        } -         -        // If all the elements are single-value types, we can promote it. -        bool AllSimple = true; -        for (const auto *EltTy : STy->elements()) { -          if (!EltTy->isSingleValueType()) { -            AllSimple = false; -            break; -          } +    if (I->use_empty()) +      continue; + +    // Otherwise, if we promoted this argument, then all users are load +    // instructions (or GEPs with only load users), and all loads should be +    // using the new argument that we added. +    ScalarizeTable &ArgIndices = ScalarizedElements[&*I]; + +    while (!I->use_empty()) { +      if (LoadInst *LI = dyn_cast<LoadInst>(I->user_back())) { +        assert(ArgIndices.begin()->second.empty() && +               "Load element should sort to front!"); +        I2->setName(I->getName() + ".val"); +        LI->replaceAllUsesWith(&*I2); +        LI->eraseFromParent(); +        DEBUG(dbgs() << "*** Promoted load of argument '" << I->getName() +                     << "' in function '" << F->getName() << "'\n"); +      } else { +        GetElementPtrInst *GEP = cast<GetElementPtrInst>(I->user_back()); +        IndicesVector Operands; +        Operands.reserve(GEP->getNumIndices()); +        for (User::op_iterator II = GEP->idx_begin(), IE = GEP->idx_end(); +             II != IE; ++II) +          Operands.push_back(cast<ConstantInt>(*II)->getSExtValue()); + +        // GEPs with a single 0 index can be merged with direct loads +        if (Operands.size() == 1 && Operands.front() == 0) +          Operands.clear(); + +        Function::arg_iterator TheArg = I2; +        for (ScalarizeTable::iterator It = ArgIndices.begin(); +             It->second != Operands; ++It, ++TheArg) { +          assert(It != ArgIndices.end() && "GEP not handled??");          } -        // Safe to transform, don't even bother trying to "promote" it. -        // Passing the elements as a scalar will allow sroa to hack on -        // the new alloca we introduce. -        if (AllSimple) { -          ByValArgsToTransform.insert(PtrArg); -          continue; +        std::string NewName = I->getName(); +        for (unsigned i = 0, e = Operands.size(); i != e; ++i) { +          NewName += "." + utostr(Operands[i]);          } -      } -    } +        NewName += ".val"; +        TheArg->setName(NewName); -    // If the argument is a recursive type and we're in a recursive -    // function, we could end up infinitely peeling the function argument. -    if (isSelfRecursive) { -      if (StructType *STy = dyn_cast<StructType>(AgTy)) { -        bool RecursiveType = false; -        for (const auto *EltTy : STy->elements()) { -          if (EltTy == PtrArg->getType()) { -            RecursiveType = true; -            break; -          } +        DEBUG(dbgs() << "*** Promoted agg argument '" << TheArg->getName() +                     << "' of function '" << NF->getName() << "'\n"); + +        // All of the uses must be load instructions.  
Replace them all with +        // the argument specified by ArgNo. +        while (!GEP->use_empty()) { +          LoadInst *L = cast<LoadInst>(GEP->user_back()); +          L->replaceAllUsesWith(&*TheArg); +          L->eraseFromParent();          } -        if (RecursiveType) -          continue; +        GEP->eraseFromParent();        }      } -     -    // Otherwise, see if we can promote the pointer to its value. -    if (isSafeToPromoteArgument(PtrArg, PtrArg->hasByValOrInAllocaAttr(), AAR, -                                MaxElements)) -      ArgsToPromote.insert(PtrArg); -  } -  // No promotable pointer arguments. -  if (ArgsToPromote.empty() && ByValArgsToTransform.empty())  -    return nullptr; +    // Increment I2 past all of the arguments added for this promoted pointer. +    std::advance(I2, ArgIndices.size()); +  } -  return DoPromotion(F, ArgsToPromote, ByValArgsToTransform, CG); +  return NF;  }  /// AllCallersPassInValidPointerForArgument - Return true if we can prove that  /// all callees pass in a valid pointer for the specified function argument. -static bool AllCallersPassInValidPointerForArgument(Argument *Arg) { +static bool allCallersPassInValidPointerForArgument(Argument *Arg) {    Function *Callee = Arg->getParent();    const DataLayout &DL = Callee->getParent()->getDataLayout(); @@ -390,26 +469,25 @@ static bool AllCallersPassInValidPointerForArgument(Argument *Arg) {  /// elements in Prefix is the same as the corresponding elements in Longer.  ///  /// This means it also returns true when Prefix and Longer are equal! -static bool IsPrefix(const IndicesVector &Prefix, const IndicesVector &Longer) { +static bool isPrefix(const IndicesVector &Prefix, const IndicesVector &Longer) {    if (Prefix.size() > Longer.size())      return false;    return std::equal(Prefix.begin(), Prefix.end(), Longer.begin());  } -  /// Checks if Indices, or a prefix of Indices, is in Set. -static bool PrefixIn(const IndicesVector &Indices, +static bool prefixIn(const IndicesVector &Indices,                       std::set<IndicesVector> &Set) { -    std::set<IndicesVector>::iterator Low; -    Low = Set.upper_bound(Indices); -    if (Low != Set.begin()) -      Low--; -    // Low is now the last element smaller than or equal to Indices. This means -    // it points to a prefix of Indices (possibly Indices itself), if such -    // prefix exists. -    // -    // This load is safe if any prefix of its operands is safe to load. -    return Low != Set.end() && IsPrefix(*Low, Indices); +  std::set<IndicesVector>::iterator Low; +  Low = Set.upper_bound(Indices); +  if (Low != Set.begin()) +    Low--; +  // Low is now the last element smaller than or equal to Indices. This means +  // it points to a prefix of Indices (possibly Indices itself), if such +  // prefix exists. +  // +  // This load is safe if any prefix of its operands is safe to load. +  return Low != Set.end() && isPrefix(*Low, Indices);  }  /// Mark the given indices (ToMark) as safe in the given set of indices @@ -417,7 +495,7 @@ static bool PrefixIn(const IndicesVector &Indices,  /// is already a prefix of Indices in Safe, Indices are implicitely marked safe  /// already. Furthermore, any indices that Indices is itself a prefix of, are  /// removed from Safe (since they are implicitely safe because of Indices now). 
-static void MarkIndicesSafe(const IndicesVector &ToMark, +static void markIndicesSafe(const IndicesVector &ToMark,                              std::set<IndicesVector> &Safe) {    std::set<IndicesVector>::iterator Low;    Low = Safe.upper_bound(ToMark); @@ -428,7 +506,7 @@ static void MarkIndicesSafe(const IndicesVector &ToMark,    // means it points to a prefix of Indices (possibly Indices itself), if    // such prefix exists.    if (Low != Safe.end()) { -    if (IsPrefix(*Low, ToMark)) +    if (isPrefix(*Low, ToMark))        // If there is already a prefix of these indices (or exactly these        // indices) marked a safe, don't bother adding these indices        return; @@ -441,7 +519,7 @@ static void MarkIndicesSafe(const IndicesVector &ToMark,    ++Low;    // If there we're a prefix of longer index list(s), remove those    std::set<IndicesVector>::iterator End = Safe.end(); -  while (Low != End && IsPrefix(ToMark, *Low)) { +  while (Low != End && isPrefix(ToMark, *Low)) {      std::set<IndicesVector>::iterator Remove = Low;      ++Low;      Safe.erase(Remove); @@ -486,7 +564,7 @@ static bool isSafeToPromoteArgument(Argument *Arg, bool isByValOrInAlloca,    GEPIndicesSet ToPromote;    // If the pointer is always valid, any load with first index 0 is valid. -  if (isByValOrInAlloca || AllCallersPassInValidPointerForArgument(Arg)) +  if (isByValOrInAlloca || allCallersPassInValidPointerForArgument(Arg))      SafeToUnconditionallyLoad.insert(IndicesVector(1, 0));    // First, iterate the entry block and mark loads of (geps of) arguments as @@ -512,25 +590,26 @@ static bool isSafeToPromoteArgument(Argument *Arg, bool isByValOrInAlloca,                return false;            // Indices checked out, mark them as safe -          MarkIndicesSafe(Indices, SafeToUnconditionallyLoad); +          markIndicesSafe(Indices, SafeToUnconditionallyLoad);            Indices.clear();          }        } else if (V == Arg) {          // Direct loads are equivalent to a GEP with a single 0 index. -        MarkIndicesSafe(IndicesVector(1, 0), SafeToUnconditionallyLoad); +        markIndicesSafe(IndicesVector(1, 0), SafeToUnconditionallyLoad);        }      }    // Now, iterate all uses of the argument to see if there are any uses that are    // not (GEP+)loads, or any (GEP+)loads that are not safe to promote. -  SmallVector<LoadInst*, 16> Loads; +  SmallVector<LoadInst *, 16> Loads;    IndicesVector Operands;    for (Use &U : Arg->uses()) {      User *UR = U.getUser();      Operands.clear();      if (LoadInst *LI = dyn_cast<LoadInst>(UR)) {        // Don't hack volatile/atomic loads -      if (!LI->isSimple()) return false; +      if (!LI->isSimple()) +        return false;        Loads.push_back(LI);        // Direct loads are equivalent to a GEP with a zero index and then a load.        Operands.push_back(0); @@ -547,30 +626,31 @@ static bool isSafeToPromoteArgument(Argument *Arg, bool isByValOrInAlloca,        }        // Ensure that all of the indices are constants. -      for (User::op_iterator i = GEP->idx_begin(), e = GEP->idx_end(); -        i != e; ++i) +      for (User::op_iterator i = GEP->idx_begin(), e = GEP->idx_end(); i != e; +           ++i)          if (ConstantInt *C = dyn_cast<ConstantInt>(*i))            Operands.push_back(C->getSExtValue());          else -          return false;  // Not a constant operand GEP! +          return false; // Not a constant operand GEP!        // Ensure that the only users of the GEP are load instructions.        
for (User *GEPU : GEP->users())          if (LoadInst *LI = dyn_cast<LoadInst>(GEPU)) {            // Don't hack volatile/atomic loads -          if (!LI->isSimple()) return false; +          if (!LI->isSimple()) +            return false;            Loads.push_back(LI);          } else {            // Other uses than load?            return false;          }      } else { -      return false;  // Not a load or a GEP. +      return false; // Not a load or a GEP.      }      // Now, see if it is safe to promote this load / loads of this GEP. Loading      // is safe if Operands, or a prefix of Operands, is marked as safe. -    if (!PrefixIn(Operands, SafeToUnconditionallyLoad)) +    if (!prefixIn(Operands, SafeToUnconditionallyLoad))        return false;      // See if we are already promoting a load with these indices. If not, check @@ -579,8 +659,10 @@ static bool isSafeToPromoteArgument(Argument *Arg, bool isByValOrInAlloca,      if (ToPromote.find(Operands) == ToPromote.end()) {        if (MaxElements > 0 && ToPromote.size() == MaxElements) {          DEBUG(dbgs() << "argpromotion not promoting argument '" -              << Arg->getName() << "' because it would require adding more " -              << "than " << MaxElements << " arguments to the function.\n"); +                     << Arg->getName() +                     << "' because it would require adding more " +                     << "than " << MaxElements +                     << " arguments to the function.\n");          // We limit aggregate promotion to only promoting up to a fixed number          // of elements of the aggregate.          return false; @@ -589,7 +671,8 @@ static bool isSafeToPromoteArgument(Argument *Arg, bool isByValOrInAlloca,      }    } -  if (Loads.empty()) return true;  // No users, this is a dead argument. +  if (Loads.empty()) +    return true; // No users, this is a dead argument.    // Okay, now we know that the argument is only used by load instructions and    // it is safe to unconditionally perform all of them. Use alias analysis to @@ -598,7 +681,7 @@ static bool isSafeToPromoteArgument(Argument *Arg, bool isByValOrInAlloca,    // Because there could be several/many load instructions, remember which    // blocks we know to be transparent to the load. -  df_iterator_default_set<BasicBlock*, 16> TranspBlocks; +  df_iterator_default_set<BasicBlock *, 16> TranspBlocks;    for (LoadInst *Load : Loads) {      // Check to see if the load is invalidated from the start of the block to @@ -607,7 +690,7 @@ static bool isSafeToPromoteArgument(Argument *Arg, bool isByValOrInAlloca,      MemoryLocation Loc = MemoryLocation::get(Load);      if (AAR.canInstructionRangeModRef(BB->front(), *Load, Loc, MRI_Mod)) -      return false;  // Pointer is invalidated! +      return false; // Pointer is invalidated!      // Now check every path from the entry block to the load for transparency.      // To do this, we perform a depth first search on the inverse CFG from the @@ -625,416 +708,352 @@ static bool isSafeToPromoteArgument(Argument *Arg, bool isByValOrInAlloca,    return true;  } -/// DoPromotion - This method actually performs the promotion of the specified -/// arguments, and returns the new function.  At this point, we know that it's -/// safe to do so. -static CallGraphNode * -DoPromotion(Function *F, SmallPtrSetImpl<Argument *> &ArgsToPromote, -            SmallPtrSetImpl<Argument *> &ByValArgsToTransform, CallGraph &CG) { +/// \brief Checks if a type could have padding bytes. 
+static bool isDenselyPacked(Type *type, const DataLayout &DL) { -  // Start by computing a new prototype for the function, which is the same as -  // the old function, but has modified arguments. -  FunctionType *FTy = F->getFunctionType(); -  std::vector<Type*> Params; +  // There is no size information, so be conservative. +  if (!type->isSized()) +    return false; -  typedef std::set<std::pair<Type *, IndicesVector>> ScalarizeTable; +  // If the alloc size is not equal to the storage size, then there are padding +  // bytes. For x86_fp80 on x86-64, size: 80 alloc size: 128. +  if (DL.getTypeSizeInBits(type) != DL.getTypeAllocSizeInBits(type)) +    return false; -  // ScalarizedElements - If we are promoting a pointer that has elements -  // accessed out of it, keep track of which elements are accessed so that we -  // can add one argument for each. -  // -  // Arguments that are directly loaded will have a zero element value here, to -  // handle cases where there are both a direct load and GEP accesses. -  // -  std::map<Argument*, ScalarizeTable> ScalarizedElements; +  if (!isa<CompositeType>(type)) +    return true; -  // OriginalLoads - Keep track of a representative load instruction from the -  // original function so that we can tell the alias analysis implementation -  // what the new GEP/Load instructions we are inserting look like. -  // We need to keep the original loads for each argument and the elements -  // of the argument that are accessed. -  std::map<std::pair<Argument*, IndicesVector>, LoadInst*> OriginalLoads; +  // For homogenous sequential types, check for padding within members. +  if (SequentialType *seqTy = dyn_cast<SequentialType>(type)) +    return isDenselyPacked(seqTy->getElementType(), DL); -  // Attribute - Keep track of the parameter attributes for the arguments -  // that we are *not* promoting. For the ones that we do promote, the parameter -  // attributes are lost -  SmallVector<AttributeSet, 8> AttributesVec; -  const AttributeSet &PAL = F->getAttributes(); +  // Check for padding within and between elements of a struct. +  StructType *StructTy = cast<StructType>(type); +  const StructLayout *Layout = DL.getStructLayout(StructTy); +  uint64_t StartPos = 0; +  for (unsigned i = 0, E = StructTy->getNumElements(); i < E; ++i) { +    Type *ElTy = StructTy->getElementType(i); +    if (!isDenselyPacked(ElTy, DL)) +      return false; +    if (StartPos != Layout->getElementOffsetInBits(i)) +      return false; +    StartPos += DL.getTypeAllocSizeInBits(ElTy); +  } -  // Add any return attributes. -  if (PAL.hasAttributes(AttributeSet::ReturnIndex)) -    AttributesVec.push_back(AttributeSet::get(F->getContext(), -                                              PAL.getRetAttributes())); +  return true; +} -  // First, determine the new argument list -  unsigned ArgIndex = 1; -  for (Function::arg_iterator I = F->arg_begin(), E = F->arg_end(); I != E; -       ++I, ++ArgIndex) { -    if (ByValArgsToTransform.count(&*I)) { -      // Simple byval argument? Just add all the struct element types. 
-      Type *AgTy = cast<PointerType>(I->getType())->getElementType(); -      StructType *STy = cast<StructType>(AgTy); -      Params.insert(Params.end(), STy->element_begin(), STy->element_end()); -      ++NumByValArgsPromoted; -    } else if (!ArgsToPromote.count(&*I)) { -      // Unchanged argument -      Params.push_back(I->getType()); -      AttributeSet attrs = PAL.getParamAttributes(ArgIndex); -      if (attrs.hasAttributes(ArgIndex)) { -        AttrBuilder B(attrs, ArgIndex); -        AttributesVec. -          push_back(AttributeSet::get(F->getContext(), Params.size(), B)); -      } -    } else if (I->use_empty()) { -      // Dead argument (which are always marked as promotable) -      ++NumArgumentsDead; -    } else { -      // Okay, this is being promoted. This means that the only uses are loads -      // or GEPs which are only used by loads +/// \brief Checks if the padding bytes of an argument could be accessed. +static bool canPaddingBeAccessed(Argument *arg) { -      // In this table, we will track which indices are loaded from the argument -      // (where direct loads are tracked as no indices). -      ScalarizeTable &ArgIndices = ScalarizedElements[&*I]; -      for (User *U : I->users()) { -        Instruction *UI = cast<Instruction>(U); -        Type *SrcTy; -        if (LoadInst *L = dyn_cast<LoadInst>(UI)) -          SrcTy = L->getType(); -        else -          SrcTy = cast<GetElementPtrInst>(UI)->getSourceElementType(); -        IndicesVector Indices; -        Indices.reserve(UI->getNumOperands() - 1); -        // Since loads will only have a single operand, and GEPs only a single -        // non-index operand, this will record direct loads without any indices, -        // and gep+loads with the GEP indices. -        for (User::op_iterator II = UI->op_begin() + 1, IE = UI->op_end(); -             II != IE; ++II) -          Indices.push_back(cast<ConstantInt>(*II)->getSExtValue()); -        // GEPs with a single 0 index can be merged with direct loads -        if (Indices.size() == 1 && Indices.front() == 0) -          Indices.clear(); -        ArgIndices.insert(std::make_pair(SrcTy, Indices)); -        LoadInst *OrigLoad; -        if (LoadInst *L = dyn_cast<LoadInst>(UI)) -          OrigLoad = L; -        else -          // Take any load, we will use it only to update Alias Analysis -          OrigLoad = cast<LoadInst>(UI->user_back()); -        OriginalLoads[std::make_pair(&*I, Indices)] = OrigLoad; -      } +  assert(arg->hasByValAttr()); -      // Add a parameter to the function for each element passed in. -      for (const auto &ArgIndex : ArgIndices) { -        // not allowed to dereference ->begin() if size() is 0 -        Params.push_back(GetElementPtrInst::getIndexedType( -            cast<PointerType>(I->getType()->getScalarType())->getElementType(), -            ArgIndex.second)); -        assert(Params.back()); -      } +  // Track all the pointers to the argument to make sure they are not captured. +  SmallPtrSet<Value *, 16> PtrValues; +  PtrValues.insert(arg); -      if (ArgIndices.size() == 1 && ArgIndices.begin()->second.empty()) -        ++NumArgumentsPromoted; -      else -        ++NumAggregatesPromoted; +  // Track all of the stores. +  SmallVector<StoreInst *, 16> Stores; + +  // Scan through the uses recursively to make sure the pointer is always used +  // sanely. 
+  SmallVector<Value *, 16> WorkList; +  WorkList.insert(WorkList.end(), arg->user_begin(), arg->user_end()); +  while (!WorkList.empty()) { +    Value *V = WorkList.back(); +    WorkList.pop_back(); +    if (isa<GetElementPtrInst>(V) || isa<PHINode>(V)) { +      if (PtrValues.insert(V).second) +        WorkList.insert(WorkList.end(), V->user_begin(), V->user_end()); +    } else if (StoreInst *Store = dyn_cast<StoreInst>(V)) { +      Stores.push_back(Store); +    } else if (!isa<LoadInst>(V)) { +      return true;      }    } -  // Add any function attributes. -  if (PAL.hasAttributes(AttributeSet::FunctionIndex)) -    AttributesVec.push_back(AttributeSet::get(FTy->getContext(), -                                              PAL.getFnAttributes())); +  // Check to make sure the pointers aren't captured +  for (StoreInst *Store : Stores) +    if (PtrValues.count(Store->getValueOperand())) +      return true; -  Type *RetTy = FTy->getReturnType(); +  return false; +} -  // Construct the new function type using the new arguments. -  FunctionType *NFTy = FunctionType::get(RetTy, Params, FTy->isVarArg()); +/// PromoteArguments - This method checks the specified function to see if there +/// are any promotable arguments and if it is safe to promote the function (for +/// example, all callers are direct).  If safe to promote some arguments, it +/// calls the DoPromotion method. +/// +static Function * +promoteArguments(Function *F, function_ref<AAResults &(Function &F)> AARGetter, +                 unsigned MaxElements, +                 Optional<function_ref<void(CallSite OldCS, CallSite NewCS)>> +                     ReplaceCallSite) { +  // Make sure that it is local to this module. +  if (!F->hasLocalLinkage()) +    return nullptr; -  // Create the new function body and insert it into the module. -  Function *NF = Function::Create(NFTy, F->getLinkage(), F->getName()); -  NF->copyAttributesFrom(F); +  // Don't promote arguments for variadic functions. Adding, removing, or +  // changing non-pack parameters can change the classification of pack +  // parameters. Frontends encode that classification at the call site in the +  // IR, while in the callee the classification is determined dynamically based +  // on the number of registers consumed so far. +  if (F->isVarArg()) +    return nullptr; -  // Patch the pointer to LLVM function in debug info descriptor. -  NF->setSubprogram(F->getSubprogram()); -  F->setSubprogram(nullptr); +  // First check: see if there are any pointer arguments!  If not, quick exit. +  SmallVector<Argument *, 16> PointerArgs; +  for (Argument &I : F->args()) +    if (I.getType()->isPointerTy()) +      PointerArgs.push_back(&I); +  if (PointerArgs.empty()) +    return nullptr; -  DEBUG(dbgs() << "ARG PROMOTION:  Promoting to:" << *NF << "\n" -        << "From: " << *F); -   -  // Recompute the parameter attributes list based on the new arguments for -  // the function. -  NF->setAttributes(AttributeSet::get(F->getContext(), AttributesVec)); -  AttributesVec.clear(); +  // Second check: make sure that all callers are direct callers.  We can't +  // transform functions that have indirect callers.  Also see if the function +  // is self-recursive. +  bool isSelfRecursive = false; +  for (Use &U : F->uses()) { +    CallSite CS(U.getUser()); +    // Must be a direct call. 
+    if (CS.getInstruction() == nullptr || !CS.isCallee(&U)) +      return nullptr; -  F->getParent()->getFunctionList().insert(F->getIterator(), NF); -  NF->takeName(F); +    if (CS.getInstruction()->getParent()->getParent() == F) +      isSelfRecursive = true; +  } -  // Get a new callgraph node for NF. -  CallGraphNode *NF_CGN = CG.getOrInsertFunction(NF); +  const DataLayout &DL = F->getParent()->getDataLayout(); -  // Loop over all of the callers of the function, transforming the call sites -  // to pass in the loaded pointers. -  // -  SmallVector<Value*, 16> Args; -  while (!F->use_empty()) { -    CallSite CS(F->user_back()); -    assert(CS.getCalledFunction() == F); -    Instruction *Call = CS.getInstruction(); -    const AttributeSet &CallPAL = CS.getAttributes(); +  AAResults &AAR = AARGetter(*F); -    // Add any return attributes. -    if (CallPAL.hasAttributes(AttributeSet::ReturnIndex)) -      AttributesVec.push_back(AttributeSet::get(F->getContext(), -                                                CallPAL.getRetAttributes())); +  // Check to see which arguments are promotable.  If an argument is promotable, +  // add it to ArgsToPromote. +  SmallPtrSet<Argument *, 8> ArgsToPromote; +  SmallPtrSet<Argument *, 8> ByValArgsToTransform; +  for (Argument *PtrArg : PointerArgs) { +    Type *AgTy = cast<PointerType>(PtrArg->getType())->getElementType(); -    // Loop over the operands, inserting GEP and loads in the caller as -    // appropriate. -    CallSite::arg_iterator AI = CS.arg_begin(); -    ArgIndex = 1; -    for (Function::arg_iterator I = F->arg_begin(), E = F->arg_end(); -         I != E; ++I, ++AI, ++ArgIndex) -      if (!ArgsToPromote.count(&*I) && !ByValArgsToTransform.count(&*I)) { -        Args.push_back(*AI);          // Unmodified argument +    // Replace sret attribute with noalias. This reduces register pressure by +    // avoiding a register copy. +    if (PtrArg->hasStructRetAttr()) { +      unsigned ArgNo = PtrArg->getArgNo(); +      F->setAttributes( +          F->getAttributes() +              .removeAttribute(F->getContext(), ArgNo + 1, Attribute::StructRet) +              .addAttribute(F->getContext(), ArgNo + 1, Attribute::NoAlias)); +      for (Use &U : F->uses()) { +        CallSite CS(U.getUser()); +        CS.setAttributes( +            CS.getAttributes() +                .removeAttribute(F->getContext(), ArgNo + 1, +                                 Attribute::StructRet) +                .addAttribute(F->getContext(), ArgNo + 1, Attribute::NoAlias)); +      } +    } -        if (CallPAL.hasAttributes(ArgIndex)) { -          AttrBuilder B(CallPAL, ArgIndex); -          AttributesVec. -            push_back(AttributeSet::get(F->getContext(), Args.size(), B)); -        } -      } else if (ByValArgsToTransform.count(&*I)) { -        // Emit a GEP and load for each element of the struct. -        Type *AgTy = cast<PointerType>(I->getType())->getElementType(); -        StructType *STy = cast<StructType>(AgTy); -        Value *Idxs[2] = { -              ConstantInt::get(Type::getInt32Ty(F->getContext()), 0), nullptr }; -        for (unsigned i = 0, e = STy->getNumElements(); i != e; ++i) { -          Idxs[1] = ConstantInt::get(Type::getInt32Ty(F->getContext()), i); -          Value *Idx = GetElementPtrInst::Create( -              STy, *AI, Idxs, (*AI)->getName() + "." + Twine(i), Call); -          // TODO: Tell AA about the new values? 
-          Args.push_back(new LoadInst(Idx, Idx->getName()+".val", Call)); +    // If this is a byval argument, and if the aggregate type is small, just +    // pass the elements, which is always safe, if the passed value is densely +    // packed or if we can prove the padding bytes are never accessed. This does +    // not apply to inalloca. +    bool isSafeToPromote = +        PtrArg->hasByValAttr() && +        (isDenselyPacked(AgTy, DL) || !canPaddingBeAccessed(PtrArg)); +    if (isSafeToPromote) { +      if (StructType *STy = dyn_cast<StructType>(AgTy)) { +        if (MaxElements > 0 && STy->getNumElements() > MaxElements) { +          DEBUG(dbgs() << "argpromotion disable promoting argument '" +                       << PtrArg->getName() +                       << "' because it would require adding more" +                       << " than " << MaxElements +                       << " arguments to the function.\n"); +          continue;          } -      } else if (!I->use_empty()) { -        // Non-dead argument: insert GEPs and loads as appropriate. -        ScalarizeTable &ArgIndices = ScalarizedElements[&*I]; -        // Store the Value* version of the indices in here, but declare it now -        // for reuse. -        std::vector<Value*> Ops; -        for (const auto &ArgIndex : ArgIndices) { -          Value *V = *AI; -          LoadInst *OrigLoad = -              OriginalLoads[std::make_pair(&*I, ArgIndex.second)]; -          if (!ArgIndex.second.empty()) { -            Ops.reserve(ArgIndex.second.size()); -            Type *ElTy = V->getType(); -            for (unsigned long II : ArgIndex.second) { -              // Use i32 to index structs, and i64 for others (pointers/arrays). -              // This satisfies GEP constraints. -              Type *IdxTy = (ElTy->isStructTy() ? -                    Type::getInt32Ty(F->getContext()) :  -                    Type::getInt64Ty(F->getContext())); -              Ops.push_back(ConstantInt::get(IdxTy, II)); -              // Keep track of the type we're currently indexing. -              if (auto *ElPTy = dyn_cast<PointerType>(ElTy)) -                ElTy = ElPTy->getElementType(); -              else -                ElTy = cast<CompositeType>(ElTy)->getTypeAtIndex(II); -            } -            // And create a GEP to extract those indices. -            V = GetElementPtrInst::Create(ArgIndex.first, V, Ops, -                                          V->getName() + ".idx", Call); -            Ops.clear(); + +        // If all the elements are single-value types, we can promote it. +        bool AllSimple = true; +        for (const auto *EltTy : STy->elements()) { +          if (!EltTy->isSingleValueType()) { +            AllSimple = false; +            break;            } -          // Since we're replacing a load make sure we take the alignment -          // of the previous load. -          LoadInst *newLoad = new LoadInst(V, V->getName()+".val", Call); -          newLoad->setAlignment(OrigLoad->getAlignment()); -          // Transfer the AA info too. -          AAMDNodes AAInfo; -          OrigLoad->getAAMetadata(AAInfo); -          newLoad->setAAMetadata(AAInfo); +        } -          Args.push_back(newLoad); +        // Safe to transform, don't even bother trying to "promote" it. +        // Passing the elements as a scalar will allow sroa to hack on +        // the new alloca we introduce. 
+        if (AllSimple) { +          ByValArgsToTransform.insert(PtrArg); +          continue;          }        } +    } -    // Push any varargs arguments on the list. -    for (; AI != CS.arg_end(); ++AI, ++ArgIndex) { -      Args.push_back(*AI); -      if (CallPAL.hasAttributes(ArgIndex)) { -        AttrBuilder B(CallPAL, ArgIndex); -        AttributesVec. -          push_back(AttributeSet::get(F->getContext(), Args.size(), B)); +    // If the argument is a recursive type and we're in a recursive +    // function, we could end up infinitely peeling the function argument. +    if (isSelfRecursive) { +      if (StructType *STy = dyn_cast<StructType>(AgTy)) { +        bool RecursiveType = false; +        for (const auto *EltTy : STy->elements()) { +          if (EltTy == PtrArg->getType()) { +            RecursiveType = true; +            break; +          } +        } +        if (RecursiveType) +          continue;        }      } -    // Add any function attributes. -    if (CallPAL.hasAttributes(AttributeSet::FunctionIndex)) -      AttributesVec.push_back(AttributeSet::get(Call->getContext(), -                                                CallPAL.getFnAttributes())); +    // Otherwise, see if we can promote the pointer to its value. +    if (isSafeToPromoteArgument(PtrArg, PtrArg->hasByValOrInAllocaAttr(), AAR, +                                MaxElements)) +      ArgsToPromote.insert(PtrArg); +  } + +  // No promotable pointer arguments. +  if (ArgsToPromote.empty() && ByValArgsToTransform.empty()) +    return nullptr; -    SmallVector<OperandBundleDef, 1> OpBundles; -    CS.getOperandBundlesAsDefs(OpBundles); +  return doPromotion(F, ArgsToPromote, ByValArgsToTransform, ReplaceCallSite); +} -    Instruction *New; -    if (InvokeInst *II = dyn_cast<InvokeInst>(Call)) { -      New = InvokeInst::Create(NF, II->getNormalDest(), II->getUnwindDest(), -                               Args, OpBundles, "", Call); -      cast<InvokeInst>(New)->setCallingConv(CS.getCallingConv()); -      cast<InvokeInst>(New)->setAttributes(AttributeSet::get(II->getContext(), -                                                            AttributesVec)); -    } else { -      New = CallInst::Create(NF, Args, OpBundles, "", Call); -      cast<CallInst>(New)->setCallingConv(CS.getCallingConv()); -      cast<CallInst>(New)->setAttributes(AttributeSet::get(New->getContext(), -                                                          AttributesVec)); -      cast<CallInst>(New)->setTailCallKind( -          cast<CallInst>(Call)->getTailCallKind()); -    } -    New->setDebugLoc(Call->getDebugLoc()); -    Args.clear(); -    AttributesVec.clear(); +PreservedAnalyses ArgumentPromotionPass::run(LazyCallGraph::SCC &C, +                                             CGSCCAnalysisManager &AM, +                                             LazyCallGraph &CG, +                                             CGSCCUpdateResult &UR) { +  bool Changed = false, LocalChange; -    // Update the callgraph to know that the callsite has been transformed. -    CallGraphNode *CalleeNode = CG[Call->getParent()->getParent()]; -    CalleeNode->replaceCallEdge(CS, CallSite(New), NF_CGN); +  // Iterate until we stop promoting from this SCC. 
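// Illustration only, not part of the patch: a hypothetical source-level analogue
// of the byval transformation selected above. A small, densely packed aggregate
// whose elements are all single-value types is passed as individual scalars
// instead of through a pointer-to-struct, which lets SROA and the register
// allocator work on the pieces directly.
struct Pair { int A; int B; };         // no padding: every byte is element data

static int sumByVal(Pair P) {          // callers may lower this as a byval pointer,
  return P.A + P.B;                    // depending on the target ABI
}

static int sumExpanded(int A, int B) { // after the transform: one parameter per element
  return A + B;
}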
+  do { +    LocalChange = false; -    if (!Call->use_empty()) { -      Call->replaceAllUsesWith(New); -      New->takeName(Call); +    for (LazyCallGraph::Node &N : C) { +      Function &OldF = N.getFunction(); + +      FunctionAnalysisManager &FAM = +          AM.getResult<FunctionAnalysisManagerCGSCCProxy>(C, CG).getManager(); +      // FIXME: This lambda must only be used with this function. We should +      // skip the lambda and just get the AA results directly. +      auto AARGetter = [&](Function &F) -> AAResults & { +        assert(&F == &OldF && "Called with an unexpected function!"); +        return FAM.getResult<AAManager>(F); +      }; + +      Function *NewF = promoteArguments(&OldF, AARGetter, 3u, None); +      if (!NewF) +        continue; +      LocalChange = true; + +      // Directly substitute the functions in the call graph. Note that this +      // requires the old function to be completely dead and completely +      // replaced by the new function. It does no call graph updates, it merely +      // swaps out the particular function mapped to a particular node in the +      // graph. +      C.getOuterRefSCC().replaceNodeFunction(N, *NewF); +      OldF.eraseFromParent();      } -    // Finally, remove the old call from the program, reducing the use-count of -    // F. -    Call->eraseFromParent(); -  } - -  // Since we have now created the new function, splice the body of the old -  // function right into the new function, leaving the old rotting hulk of the -  // function empty. -  NF->getBasicBlockList().splice(NF->begin(), F->getBasicBlockList()); +    Changed |= LocalChange; +  } while (LocalChange); -  // Loop over the argument list, transferring uses of the old arguments over to -  // the new arguments, also transferring over the names as well. -  // -  for (Function::arg_iterator I = F->arg_begin(), E = F->arg_end(), -       I2 = NF->arg_begin(); I != E; ++I) { -    if (!ArgsToPromote.count(&*I) && !ByValArgsToTransform.count(&*I)) { -      // If this is an unmodified argument, move the name and users over to the -      // new version. -      I->replaceAllUsesWith(&*I2); -      I2->takeName(&*I); -      ++I2; -      continue; -    } +  if (!Changed) +    return PreservedAnalyses::all(); -    if (ByValArgsToTransform.count(&*I)) { -      // In the callee, we create an alloca, and store each of the new incoming -      // arguments into the alloca. -      Instruction *InsertPt = &NF->begin()->front(); +  return PreservedAnalyses::none(); +} -      // Just add all the struct element types. -      Type *AgTy = cast<PointerType>(I->getType())->getElementType(); -      Value *TheAlloca = new AllocaInst(AgTy, nullptr, "", InsertPt); -      StructType *STy = cast<StructType>(AgTy); -      Value *Idxs[2] = { -            ConstantInt::get(Type::getInt32Ty(F->getContext()), 0), nullptr }; +namespace { +/// ArgPromotion - The 'by reference' to 'by value' argument promotion pass. +/// +struct ArgPromotion : public CallGraphSCCPass { +  void getAnalysisUsage(AnalysisUsage &AU) const override { +    AU.addRequired<AssumptionCacheTracker>(); +    AU.addRequired<TargetLibraryInfoWrapperPass>(); +    getAAResultsAnalysisUsage(AU); +    CallGraphSCCPass::getAnalysisUsage(AU); +  } -      for (unsigned i = 0, e = STy->getNumElements(); i != e; ++i) { -        Idxs[1] = ConstantInt::get(Type::getInt32Ty(F->getContext()), i); -        Value *Idx = GetElementPtrInst::Create( -            AgTy, TheAlloca, Idxs, TheAlloca->getName() + "." 
+ Twine(i), -            InsertPt); -        I2->setName(I->getName()+"."+Twine(i)); -        new StoreInst(&*I2++, Idx, InsertPt); -      } +  bool runOnSCC(CallGraphSCC &SCC) override; +  static char ID; // Pass identification, replacement for typeid +  explicit ArgPromotion(unsigned MaxElements = 3) +      : CallGraphSCCPass(ID), MaxElements(MaxElements) { +    initializeArgPromotionPass(*PassRegistry::getPassRegistry()); +  } -      // Anything that used the arg should now use the alloca. -      I->replaceAllUsesWith(TheAlloca); -      TheAlloca->takeName(&*I); +private: +  using llvm::Pass::doInitialization; +  bool doInitialization(CallGraph &CG) override; +  /// The maximum number of elements to expand, or 0 for unlimited. +  unsigned MaxElements; +}; +} -      // If the alloca is used in a call, we must clear the tail flag since -      // the callee now uses an alloca from the caller. -      for (User *U : TheAlloca->users()) { -        CallInst *Call = dyn_cast<CallInst>(U); -        if (!Call) -          continue; -        Call->setTailCall(false); -      } -      continue; -    } +char ArgPromotion::ID = 0; +INITIALIZE_PASS_BEGIN(ArgPromotion, "argpromotion", +                      "Promote 'by reference' arguments to scalars", false, +                      false) +INITIALIZE_PASS_DEPENDENCY(AssumptionCacheTracker) +INITIALIZE_PASS_DEPENDENCY(CallGraphWrapperPass) +INITIALIZE_PASS_DEPENDENCY(TargetLibraryInfoWrapperPass) +INITIALIZE_PASS_END(ArgPromotion, "argpromotion", +                    "Promote 'by reference' arguments to scalars", false, false) -    if (I->use_empty()) -      continue; +Pass *llvm::createArgumentPromotionPass(unsigned MaxElements) { +  return new ArgPromotion(MaxElements); +} -    // Otherwise, if we promoted this argument, then all users are load -    // instructions (or GEPs with only load users), and all loads should be -    // using the new argument that we added. -    ScalarizeTable &ArgIndices = ScalarizedElements[&*I]; +bool ArgPromotion::runOnSCC(CallGraphSCC &SCC) { +  if (skipSCC(SCC)) +    return false; -    while (!I->use_empty()) { -      if (LoadInst *LI = dyn_cast<LoadInst>(I->user_back())) { -        assert(ArgIndices.begin()->second.empty() && -               "Load element should sort to front!"); -        I2->setName(I->getName()+".val"); -        LI->replaceAllUsesWith(&*I2); -        LI->eraseFromParent(); -        DEBUG(dbgs() << "*** Promoted load of argument '" << I->getName() -              << "' in function '" << F->getName() << "'\n"); -      } else { -        GetElementPtrInst *GEP = cast<GetElementPtrInst>(I->user_back()); -        IndicesVector Operands; -        Operands.reserve(GEP->getNumIndices()); -        for (User::op_iterator II = GEP->idx_begin(), IE = GEP->idx_end(); -             II != IE; ++II) -          Operands.push_back(cast<ConstantInt>(*II)->getSExtValue()); +  // Get the callgraph information that we need to update to reflect our +  // changes. 
+  CallGraph &CG = getAnalysis<CallGraphWrapperPass>().getCallGraph(); -        // GEPs with a single 0 index can be merged with direct loads -        if (Operands.size() == 1 && Operands.front() == 0) -          Operands.clear(); +  LegacyAARGetter AARGetter(*this); -        Function::arg_iterator TheArg = I2; -        for (ScalarizeTable::iterator It = ArgIndices.begin(); -             It->second != Operands; ++It, ++TheArg) { -          assert(It != ArgIndices.end() && "GEP not handled??"); -        } +  bool Changed = false, LocalChange; -        std::string NewName = I->getName(); -        for (unsigned i = 0, e = Operands.size(); i != e; ++i) { -            NewName += "." + utostr(Operands[i]); -        } -        NewName += ".val"; -        TheArg->setName(NewName); +  // Iterate until we stop promoting from this SCC. +  do { +    LocalChange = false; +    // Attempt to promote arguments from all functions in this SCC. +    for (CallGraphNode *OldNode : SCC) { +      Function *OldF = OldNode->getFunction(); +      if (!OldF) +        continue; + +      auto ReplaceCallSite = [&](CallSite OldCS, CallSite NewCS) { +        Function *Caller = OldCS.getInstruction()->getParent()->getParent(); +        CallGraphNode *NewCalleeNode = +            CG.getOrInsertFunction(NewCS.getCalledFunction()); +        CallGraphNode *CallerNode = CG[Caller]; +        CallerNode->replaceCallEdge(OldCS, NewCS, NewCalleeNode); +      }; + +      if (Function *NewF = promoteArguments(OldF, AARGetter, MaxElements, +                                            {ReplaceCallSite})) { +        LocalChange = true; -        DEBUG(dbgs() << "*** Promoted agg argument '" << TheArg->getName() -              << "' of function '" << NF->getName() << "'\n"); +        // Update the call graph for the newly promoted function. +        CallGraphNode *NewNode = CG.getOrInsertFunction(NewF); +        NewNode->stealCalledFunctionsFrom(OldNode); +        if (OldNode->getNumReferences() == 0) +          delete CG.removeFunctionFromModule(OldNode); +        else +          OldF->setLinkage(Function::ExternalLinkage); -        // All of the uses must be load instructions.  Replace them all with -        // the argument specified by ArgNo. -        while (!GEP->use_empty()) { -          LoadInst *L = cast<LoadInst>(GEP->user_back()); -          L->replaceAllUsesWith(&*TheArg); -          L->eraseFromParent(); -        } -        GEP->eraseFromParent(); +        // And update the SCC we're iterating as well. +        SCC.ReplaceNode(OldNode, NewNode);        }      } +    // Remember that we changed something. +    Changed |= LocalChange; +  } while (LocalChange); -    // Increment I2 past all of the arguments added for this promoted pointer. -    std::advance(I2, ArgIndices.size()); -  } - -  NF_CGN->stealCalledFunctionsFrom(CG[F]); -   -  // Now that the old function is dead, delete it.  If there is a dangling -  // reference to the CallgraphNode, just leave the dead function around for -  // someone else to nuke.
-  CallGraphNode *CGN = CG[F]; -  if (CGN->getNumReferences() == 0) -    delete CG.removeFunctionFromModule(CGN); -  else -    F->setLinkage(Function::ExternalLinkage); -   -  return NF_CGN; +  return Changed;  }  bool ArgPromotion::doInitialization(CallGraph &CG) { diff --git a/lib/Transforms/IPO/ConstantMerge.cpp b/lib/Transforms/IPO/ConstantMerge.cpp index d75ed206ad23..62b5a9c9ba26 100644 --- a/lib/Transforms/IPO/ConstantMerge.cpp +++ b/lib/Transforms/IPO/ConstantMerge.cpp @@ -60,6 +60,23 @@ static bool IsBetterCanonical(const GlobalVariable &A,    return A.hasGlobalUnnamedAddr();  } +static bool hasMetadataOtherThanDebugLoc(const GlobalVariable *GV) { +  SmallVector<std::pair<unsigned, MDNode *>, 4> MDs; +  GV->getAllMetadata(MDs); +  for (const auto &V : MDs) +    if (V.first != LLVMContext::MD_dbg) +      return true; +  return false; +} + +static void copyDebugLocMetadata(const GlobalVariable *From, +                                 GlobalVariable *To) { +  SmallVector<DIGlobalVariableExpression *, 1> MDs; +  From->getDebugInfo(MDs); +  for (auto MD : MDs) +    To->addDebugInfo(MD); +} +  static unsigned getAlignment(GlobalVariable *GV) {    unsigned Align = GV->getAlignment();    if (Align) @@ -113,6 +130,10 @@ static bool mergeConstants(Module &M) {        if (GV->isWeakForLinker())          continue; +      // Don't touch globals with metadata other than !dbg. +      if (hasMetadataOtherThanDebugLoc(GV)) +        continue; +        Constant *Init = GV->getInitializer();        // Check to see if the initializer is already known. @@ -155,6 +176,9 @@ static bool mergeConstants(Module &M) {        if (!Slot->hasGlobalUnnamedAddr() && !GV->hasGlobalUnnamedAddr())          continue; +      if (hasMetadataOtherThanDebugLoc(GV)) +        continue; +        if (!GV->hasGlobalUnnamedAddr())          Slot->setUnnamedAddr(GlobalValue::UnnamedAddr::None); @@ -178,6 +202,8 @@ static bool mergeConstants(Module &M) {                       getAlignment(Replacements[i].second)));        } +      copyDebugLocMetadata(Replacements[i].first, Replacements[i].second); +        // Eliminate any uses of the dead global.        Replacements[i].first->replaceAllUsesWith(Replacements[i].second); diff --git a/lib/Transforms/IPO/CrossDSOCFI.cpp b/lib/Transforms/IPO/CrossDSOCFI.cpp index ba2e60dee3bc..1b111de06157 100644 --- a/lib/Transforms/IPO/CrossDSOCFI.cpp +++ b/lib/Transforms/IPO/CrossDSOCFI.cpp @@ -98,8 +98,11 @@ void CrossDSOCFI::buildCFICheck(Module &M) {    LLVMContext &Ctx = M.getContext();    Constant *C = M.getOrInsertFunction(        "__cfi_check", Type::getVoidTy(Ctx), Type::getInt64Ty(Ctx), -      Type::getInt8PtrTy(Ctx), Type::getInt8PtrTy(Ctx), nullptr); +      Type::getInt8PtrTy(Ctx), Type::getInt8PtrTy(Ctx));    Function *F = dyn_cast<Function>(C); +  // Take over the existing function. The frontend emits a weak stub so that the +  // linker knows about the symbol; this pass replaces the function body.
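// Illustration only, not part of the patch: the variadic-template overload of
// Module::getOrInsertFunction used above takes the return type followed by the
// parameter types and no longer needs (or accepts) a trailing nullptr sentinel.
// A minimal sketch reusing the surrounding Module &M and LLVMContext &Ctx;
// "my_check_hook" is a hypothetical symbol name.
Constant *Hook = M.getOrInsertFunction("my_check_hook",
                                       Type::getVoidTy(Ctx),      // return type
                                       Type::getInt64Ty(Ctx),     // first parameter
                                       Type::getInt8PtrTy(Ctx));  // second parameter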
+  F->deleteBody();    F->setAlignment(4096);    auto args = F->arg_begin();    Value &CallSiteTypeId = *(args++); @@ -117,7 +120,7 @@ void CrossDSOCFI::buildCFICheck(Module &M) {    IRBuilder<> IRBFail(TrapBB);    Constant *CFICheckFailFn = M.getOrInsertFunction(        "__cfi_check_fail", Type::getVoidTy(Ctx), Type::getInt8PtrTy(Ctx), -      Type::getInt8PtrTy(Ctx), nullptr); +      Type::getInt8PtrTy(Ctx));    IRBFail.CreateCall(CFICheckFailFn, {&CFICheckFailData, &Addr});    IRBFail.CreateBr(ExitBB); diff --git a/lib/Transforms/IPO/DeadArgumentElimination.cpp b/lib/Transforms/IPO/DeadArgumentElimination.cpp index 1a5ed4692211..375b74c494d9 100644 --- a/lib/Transforms/IPO/DeadArgumentElimination.cpp +++ b/lib/Transforms/IPO/DeadArgumentElimination.cpp @@ -166,41 +166,43 @@ bool DeadArgumentEliminationPass::DeleteDeadVarargs(Function &Fn) {      Args.assign(CS.arg_begin(), CS.arg_begin() + NumArgs);      // Drop any attributes that were on the vararg arguments. -    AttributeSet PAL = CS.getAttributes(); +    AttributeList PAL = CS.getAttributes();      if (!PAL.isEmpty() && PAL.getSlotIndex(PAL.getNumSlots() - 1) > NumArgs) { -      SmallVector<AttributeSet, 8> AttributesVec; +      SmallVector<AttributeList, 8> AttributesVec;        for (unsigned i = 0; PAL.getSlotIndex(i) <= NumArgs; ++i)          AttributesVec.push_back(PAL.getSlotAttributes(i)); -      if (PAL.hasAttributes(AttributeSet::FunctionIndex)) -        AttributesVec.push_back(AttributeSet::get(Fn.getContext(), -                                                  PAL.getFnAttributes())); -      PAL = AttributeSet::get(Fn.getContext(), AttributesVec); +      if (PAL.hasAttributes(AttributeList::FunctionIndex)) +        AttributesVec.push_back(AttributeList::get(Fn.getContext(), +                                                   AttributeList::FunctionIndex, +                                                   PAL.getFnAttributes())); +      PAL = AttributeList::get(Fn.getContext(), AttributesVec);      }      SmallVector<OperandBundleDef, 1> OpBundles;      CS.getOperandBundlesAsDefs(OpBundles); -    Instruction *New; +    CallSite NewCS;      if (InvokeInst *II = dyn_cast<InvokeInst>(Call)) { -      New = InvokeInst::Create(NF, II->getNormalDest(), II->getUnwindDest(), -                               Args, OpBundles, "", Call); -      cast<InvokeInst>(New)->setCallingConv(CS.getCallingConv()); -      cast<InvokeInst>(New)->setAttributes(PAL); +      NewCS = InvokeInst::Create(NF, II->getNormalDest(), II->getUnwindDest(), +                                 Args, OpBundles, "", Call);      } else { -      New = CallInst::Create(NF, Args, OpBundles, "", Call); -      cast<CallInst>(New)->setCallingConv(CS.getCallingConv()); -      cast<CallInst>(New)->setAttributes(PAL); -      cast<CallInst>(New)->setTailCallKind( -          cast<CallInst>(Call)->getTailCallKind()); +      NewCS = CallInst::Create(NF, Args, OpBundles, "", Call); +      cast<CallInst>(NewCS.getInstruction()) +          ->setTailCallKind(cast<CallInst>(Call)->getTailCallKind());      } -    New->setDebugLoc(Call->getDebugLoc()); +    NewCS.setCallingConv(CS.getCallingConv()); +    NewCS.setAttributes(PAL); +    NewCS->setDebugLoc(Call->getDebugLoc()); +    uint64_t W; +    if (Call->extractProfTotalWeight(W)) +      NewCS->setProfWeight(W);      Args.clear();      if (!Call->use_empty()) -      Call->replaceAllUsesWith(New); +      Call->replaceAllUsesWith(NewCS.getInstruction()); -    New->takeName(Call); +    NewCS->takeName(Call);      // Finally, remove the 
old call from the program, reducing the use-count of      // F. @@ -681,8 +683,8 @@ bool DeadArgumentEliminationPass::RemoveDeadStuffFromFunction(Function *F) {    bool HasLiveReturnedArg = false;    // Set up to build a new list of parameter attributes. -  SmallVector<AttributeSet, 8> AttributesVec; -  const AttributeSet &PAL = F->getAttributes(); +  SmallVector<AttributeSet, 8> ArgAttrVec; +  const AttributeList &PAL = F->getAttributes();    // Remember which arguments are still alive.    SmallVector<bool, 10> ArgAlive(FTy->getNumParams(), false); @@ -696,16 +698,8 @@ bool DeadArgumentEliminationPass::RemoveDeadStuffFromFunction(Function *F) {      if (LiveValues.erase(Arg)) {        Params.push_back(I->getType());        ArgAlive[i] = true; - -      // Get the original parameter attributes (skipping the first one, that is -      // for the return value. -      if (PAL.hasAttributes(i + 1)) { -        AttrBuilder B(PAL, i + 1); -        if (B.contains(Attribute::Returned)) -          HasLiveReturnedArg = true; -        AttributesVec. -          push_back(AttributeSet::get(F->getContext(), Params.size(), B)); -      } +      ArgAttrVec.push_back(PAL.getParamAttributes(i)); +      HasLiveReturnedArg |= PAL.hasParamAttribute(i, Attribute::Returned);      } else {        ++NumArgumentsEliminated;        DEBUG(dbgs() << "DeadArgumentEliminationPass - Removing argument " << i @@ -779,30 +773,24 @@ bool DeadArgumentEliminationPass::RemoveDeadStuffFromFunction(Function *F) {    assert(NRetTy && "No new return type found?");    // The existing function return attributes. -  AttributeSet RAttrs = PAL.getRetAttributes(); +  AttrBuilder RAttrs(PAL.getRetAttributes());    // Remove any incompatible attributes, but only if we removed all return    // values. Otherwise, ensure that we don't have any conflicting attributes    // here. Currently, this should not be possible, but special handling might be    // required when new return value attributes are added.    if (NRetTy->isVoidTy()) -    RAttrs = RAttrs.removeAttributes(NRetTy->getContext(), -                                     AttributeSet::ReturnIndex, -                                     AttributeFuncs::typeIncompatible(NRetTy)); +    RAttrs.remove(AttributeFuncs::typeIncompatible(NRetTy));    else -    assert(!AttrBuilder(RAttrs, AttributeSet::ReturnIndex). -             overlaps(AttributeFuncs::typeIncompatible(NRetTy)) && +    assert(!RAttrs.overlaps(AttributeFuncs::typeIncompatible(NRetTy)) &&             "Return attributes no longer compatible?"); -  if (RAttrs.hasAttributes(AttributeSet::ReturnIndex)) -    AttributesVec.push_back(AttributeSet::get(NRetTy->getContext(), RAttrs)); - -  if (PAL.hasAttributes(AttributeSet::FunctionIndex)) -    AttributesVec.push_back(AttributeSet::get(F->getContext(), -                                              PAL.getFnAttributes())); +  AttributeSet RetAttrs = AttributeSet::get(F->getContext(), RAttrs);    // Reconstruct the AttributesList based on the vector we constructed. -  AttributeSet NewPAL = AttributeSet::get(F->getContext(), AttributesVec); +  assert(ArgAttrVec.size() == Params.size()); +  AttributeList NewPAL = AttributeList::get( +      F->getContext(), PAL.getFnAttributes(), RetAttrs, ArgAttrVec);    // Create the new function type based on the recomputed parameters.    
FunctionType *NFTy = FunctionType::get(NRetTy, Params, FTy->isVarArg()); @@ -829,18 +817,14 @@ bool DeadArgumentEliminationPass::RemoveDeadStuffFromFunction(Function *F) {      CallSite CS(F->user_back());      Instruction *Call = CS.getInstruction(); -    AttributesVec.clear(); -    const AttributeSet &CallPAL = CS.getAttributes(); - -    // The call return attributes. -    AttributeSet RAttrs = CallPAL.getRetAttributes(); +    ArgAttrVec.clear(); +    const AttributeList &CallPAL = CS.getAttributes(); -    // Adjust in case the function was changed to return void. -    RAttrs = RAttrs.removeAttributes(NRetTy->getContext(), -                                     AttributeSet::ReturnIndex, -                        AttributeFuncs::typeIncompatible(NF->getReturnType())); -    if (RAttrs.hasAttributes(AttributeSet::ReturnIndex)) -      AttributesVec.push_back(AttributeSet::get(NF->getContext(), RAttrs)); +    // Adjust the call return attributes in case the function was changed to +    // return void. +    AttrBuilder RAttrs(CallPAL.getRetAttributes()); +    RAttrs.remove(AttributeFuncs::typeIncompatible(NRetTy)); +    AttributeSet RetAttrs = AttributeSet::get(F->getContext(), RAttrs);      // Declare these outside of the loops, so we can reuse them for the second      // loop, which loops the varargs. @@ -852,57 +836,55 @@ bool DeadArgumentEliminationPass::RemoveDeadStuffFromFunction(Function *F) {        if (ArgAlive[i]) {          Args.push_back(*I);          // Get original parameter attributes, but skip return attributes. -        if (CallPAL.hasAttributes(i + 1)) { -          AttrBuilder B(CallPAL, i + 1); +        AttributeSet Attrs = CallPAL.getParamAttributes(i); +        if (NRetTy != RetTy && Attrs.hasAttribute(Attribute::Returned)) {            // If the return type has changed, then get rid of 'returned' on the            // call site. The alternative is to make all 'returned' attributes on            // call sites keep the return value alive just like 'returned' -          // attributes on function declaration but it's less clearly a win -          // and this is not an expected case anyway -          if (NRetTy != RetTy && B.contains(Attribute::Returned)) -            B.removeAttribute(Attribute::Returned); -          AttributesVec. -            push_back(AttributeSet::get(F->getContext(), Args.size(), B)); +          // attributes on function declaration but it's less clearly a win and +          // this is not an expected case anyway +          ArgAttrVec.push_back(AttributeSet::get( +              F->getContext(), +              AttrBuilder(Attrs).removeAttribute(Attribute::Returned))); +        } else { +          // Otherwise, use the original attributes. +          ArgAttrVec.push_back(Attrs);          }        }      // Push any varargs arguments on the list. Don't forget their attributes.      for (CallSite::arg_iterator E = CS.arg_end(); I != E; ++I, ++i) {        Args.push_back(*I); -      if (CallPAL.hasAttributes(i + 1)) { -        AttrBuilder B(CallPAL, i + 1); -        AttributesVec. -          push_back(AttributeSet::get(F->getContext(), Args.size(), B)); -      } +      ArgAttrVec.push_back(CallPAL.getParamAttributes(i));      } -    if (CallPAL.hasAttributes(AttributeSet::FunctionIndex)) -      AttributesVec.push_back(AttributeSet::get(Call->getContext(), -                                                CallPAL.getFnAttributes())); -      // Reconstruct the AttributesList based on the vector we constructed. 
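// Illustration only, not part of the patch: with AttributeList the call-site
// attributes are rebuilt from three pieces (function attributes, return
// attributes, and one AttributeSet per remaining argument), so the per-argument
// vector has to stay parallel to the argument list. A minimal sketch reusing the
// surrounding F, CallPAL, RetAttrs, ArgAttrVec and Args:
AttributeList Rebuilt = AttributeList::get(F->getContext(),
                                           CallPAL.getFnAttributes(), // function-level attributes
                                           RetAttrs,                  // return-value attributes
                                           ArgAttrVec);               // one entry per argument
assert(ArgAttrVec.size() == Args.size() && "argument attributes out of sync");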
-    AttributeSet NewCallPAL = AttributeSet::get(F->getContext(), AttributesVec); +    assert(ArgAttrVec.size() == Args.size()); +    AttributeList NewCallPAL = AttributeList::get( +        F->getContext(), CallPAL.getFnAttributes(), RetAttrs, ArgAttrVec);      SmallVector<OperandBundleDef, 1> OpBundles;      CS.getOperandBundlesAsDefs(OpBundles); -    Instruction *New; +    CallSite NewCS;      if (InvokeInst *II = dyn_cast<InvokeInst>(Call)) { -      New = InvokeInst::Create(NF, II->getNormalDest(), II->getUnwindDest(), -                               Args, OpBundles, "", Call->getParent()); -      cast<InvokeInst>(New)->setCallingConv(CS.getCallingConv()); -      cast<InvokeInst>(New)->setAttributes(NewCallPAL); +      NewCS = InvokeInst::Create(NF, II->getNormalDest(), II->getUnwindDest(), +                                 Args, OpBundles, "", Call->getParent());      } else { -      New = CallInst::Create(NF, Args, OpBundles, "", Call); -      cast<CallInst>(New)->setCallingConv(CS.getCallingConv()); -      cast<CallInst>(New)->setAttributes(NewCallPAL); -      cast<CallInst>(New)->setTailCallKind( -          cast<CallInst>(Call)->getTailCallKind()); +      NewCS = CallInst::Create(NF, Args, OpBundles, "", Call); +      cast<CallInst>(NewCS.getInstruction()) +          ->setTailCallKind(cast<CallInst>(Call)->getTailCallKind());      } -    New->setDebugLoc(Call->getDebugLoc()); - +    NewCS.setCallingConv(CS.getCallingConv()); +    NewCS.setAttributes(NewCallPAL); +    NewCS->setDebugLoc(Call->getDebugLoc()); +    uint64_t W; +    if (Call->extractProfTotalWeight(W)) +      NewCS->setProfWeight(W);      Args.clear(); +    ArgAttrVec.clear(); +    Instruction *New = NewCS.getInstruction();      if (!Call->use_empty()) {        if (New->getType() == Call->getType()) {          // Return type not changed? Just replace users then. diff --git a/lib/Transforms/IPO/FunctionAttrs.cpp b/lib/Transforms/IPO/FunctionAttrs.cpp index 402a66552c24..4d13b3f40688 100644 --- a/lib/Transforms/IPO/FunctionAttrs.cpp +++ b/lib/Transforms/IPO/FunctionAttrs.cpp @@ -49,31 +49,35 @@ STATISTIC(NumNoAlias, "Number of function returns marked noalias");  STATISTIC(NumNonNullReturn, "Number of function returns marked nonnull");  STATISTIC(NumNoRecurse, "Number of functions marked as norecurse"); -namespace { -typedef SmallSetVector<Function *, 8> SCCNodeSet; -} +// FIXME: This is disabled by default to avoid exposing security vulnerabilities +// in C/C++ code compiled by clang: +// http://lists.llvm.org/pipermail/cfe-dev/2017-January/052066.html +static cl::opt<bool> EnableNonnullArgPropagation( +    "enable-nonnull-arg-prop", cl::Hidden, +    cl::desc("Try to propagate nonnull argument attributes from callsites to " +             "caller functions."));  namespace { -/// The three kinds of memory access relevant to 'readonly' and -/// 'readnone' attributes. -enum MemoryAccessKind { -  MAK_ReadNone = 0, -  MAK_ReadOnly = 1, -  MAK_MayWrite = 2 -}; +typedef SmallSetVector<Function *, 8> SCCNodeSet;  } -static MemoryAccessKind checkFunctionMemoryAccess(Function &F, AAResults &AAR, +/// Returns the memory access attribute for function F using AAR for AA results, +/// where SCCNodes is the current SCC. +/// +/// If ThisBody is true, this function may examine the function body and will +/// return a result pertaining to this copy of the function. 
If it is false, the +/// result will be based only on AA results for the function declaration; it +/// will be assumed that some other (perhaps less optimized) version of the +/// function may be selected at link time. +static MemoryAccessKind checkFunctionMemoryAccess(Function &F, bool ThisBody, +                                                  AAResults &AAR,                                                    const SCCNodeSet &SCCNodes) {    FunctionModRefBehavior MRB = AAR.getModRefBehavior(&F);    if (MRB == FMRB_DoesNotAccessMemory)      // Already perfect!      return MAK_ReadNone; -  // Non-exact function definitions may not be selected at link time, and an -  // alternative version that writes to memory may be selected.  See the comment -  // on GlobalValue::isDefinitionExact for more details. -  if (!F.hasExactDefinition()) { +  if (!ThisBody) {      if (AliasAnalysis::onlyReadsMemory(MRB))        return MAK_ReadOnly; @@ -172,9 +176,14 @@ static MemoryAccessKind checkFunctionMemoryAccess(Function &F, AAResults &AAR,    return ReadsMemory ? MAK_ReadOnly : MAK_ReadNone;  } +MemoryAccessKind llvm::computeFunctionBodyMemoryAccess(Function &F, +                                                       AAResults &AAR) { +  return checkFunctionMemoryAccess(F, /*ThisBody=*/true, AAR, {}); +} +  /// Deduce readonly/readnone attributes for the SCC.  template <typename AARGetterT> -static bool addReadAttrs(const SCCNodeSet &SCCNodes, AARGetterT AARGetter) { +static bool addReadAttrs(const SCCNodeSet &SCCNodes, AARGetterT &&AARGetter) {    // Check if any of the functions in the SCC read or write memory.  If they    // write memory then they can't be marked readnone or readonly.    bool ReadsMemory = false; @@ -182,7 +191,11 @@ static bool addReadAttrs(const SCCNodeSet &SCCNodes, AARGetterT AARGetter) {      // Call the callable parameter to look up AA results for this function.      AAResults &AAR = AARGetter(*F); -    switch (checkFunctionMemoryAccess(*F, AAR, SCCNodes)) { +    // Non-exact function definitions may not be selected at link time, and an +    // alternative version that writes to memory may be selected.  See the +    // comment on GlobalValue::isDefinitionExact for more details. +    switch (checkFunctionMemoryAccess(*F, F->hasExactDefinition(), +                                      AAR, SCCNodes)) {      case MAK_MayWrite:        return false;      case MAK_ReadOnly: @@ -212,11 +225,11 @@ static bool addReadAttrs(const SCCNodeSet &SCCNodes, AARGetterT AARGetter) {      AttrBuilder B;      B.addAttribute(Attribute::ReadOnly).addAttribute(Attribute::ReadNone);      F->removeAttributes( -        AttributeSet::FunctionIndex, -        AttributeSet::get(F->getContext(), AttributeSet::FunctionIndex, B)); +        AttributeList::FunctionIndex, +        AttributeList::get(F->getContext(), AttributeList::FunctionIndex, B));      // Add in the new attribute. -    F->addAttribute(AttributeSet::FunctionIndex, +    F->addAttribute(AttributeList::FunctionIndex,                      ReadsMemory ? 
Attribute::ReadOnly : Attribute::ReadNone);      if (ReadsMemory) @@ -522,7 +535,7 @@ static bool addArgumentReturnedAttrs(const SCCNodeSet &SCCNodes) {      if (Value *RetArg = FindRetArg()) {        auto *A = cast<Argument>(RetArg); -      A->addAttr(AttributeSet::get(F->getContext(), A->getArgNo() + 1, B)); +      A->addAttr(AttributeList::get(F->getContext(), A->getArgNo() + 1, B));        ++NumReturned;        Changed = true;      } @@ -531,6 +544,49 @@ static bool addArgumentReturnedAttrs(const SCCNodeSet &SCCNodes) {    return Changed;  } +/// If a callsite has arguments that are also arguments to the parent function, +/// try to propagate attributes from the callsite's arguments to the parent's +/// arguments. This may be important because inlining can cause information loss +/// when attribute knowledge disappears with the inlined call. +static bool addArgumentAttrsFromCallsites(Function &F) { +  if (!EnableNonnullArgPropagation) +    return false; + +  bool Changed = false; + +  // For an argument attribute to transfer from a callsite to the parent, the +  // call must be guaranteed to execute every time the parent is called. +  // Conservatively, just check for calls in the entry block that are guaranteed +  // to execute. +  // TODO: This could be enhanced by testing if the callsite post-dominates the +  // entry block or by doing simple forward walks or backward walks to the +  // callsite. +  BasicBlock &Entry = F.getEntryBlock(); +  for (Instruction &I : Entry) { +    if (auto CS = CallSite(&I)) { +      if (auto *CalledFunc = CS.getCalledFunction()) { +        for (auto &CSArg : CalledFunc->args()) { +          if (!CSArg.hasNonNullAttr()) +            continue; + +          // If the non-null callsite argument operand is an argument to 'F' +          // (the caller) and the call is guaranteed to execute, then the value +          // must be non-null throughout 'F'. +          auto *FArg = dyn_cast<Argument>(CS.getArgOperand(CSArg.getArgNo())); +          if (FArg && !FArg->hasNonNullAttr()) { +            FArg->addAttr(Attribute::NonNull); +            Changed = true; +          } +        } +      } +    } +    if (!isGuaranteedToTransferExecutionToSuccessor(&I)) +      break; +  } +   +  return Changed; +} +  /// Deduce nocapture attributes for the SCC.  static bool addArgumentAttrs(const SCCNodeSet &SCCNodes) {    bool Changed = false; @@ -549,6 +605,8 @@ static bool addArgumentAttrs(const SCCNodeSet &SCCNodes) {      if (!F->hasExactDefinition())        continue; +    Changed |= addArgumentAttrsFromCallsites(*F); +      // Functions that are readonly (or readnone) and nounwind and don't return      // a value can't capture arguments. Don't analyze them.      if (F->onlyReadsMemory() && F->doesNotThrow() && @@ -556,7 +614,7 @@ static bool addArgumentAttrs(const SCCNodeSet &SCCNodes) {        for (Function::arg_iterator A = F->arg_begin(), E = F->arg_end(); A != E;             ++A) {          if (A->getType()->isPointerTy() && !A->hasNoCaptureAttr()) { -          A->addAttr(AttributeSet::get(F->getContext(), A->getArgNo() + 1, B)); +          A->addAttr(AttributeList::get(F->getContext(), A->getArgNo() + 1, B));            ++NumNoCapture;            Changed = true;          } @@ -576,7 +634,7 @@ static bool addArgumentAttrs(const SCCNodeSet &SCCNodes) {            if (Tracker.Uses.empty()) {              // If it's trivially not captured, mark it nocapture now.              
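// Illustration only, not part of the patch: a hypothetical C-level view of the
// deduction done by addArgumentAttrsFromCallsites above. Because the call to
// callee() sits in caller()'s entry block and always executes, P must already be
// nonnull on entry to caller(), so the attribute can be copied to the caller's
// own parameter.
extern void callee(int *P) __attribute__((nonnull));

void caller(int *P) { // after the deduction, P can carry nonnull here as well
  callee(P);          // guaranteed to execute whenever caller() runs
}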
A->addAttr( -                AttributeSet::get(F->getContext(), A->getArgNo() + 1, B)); +                AttributeList::get(F->getContext(), A->getArgNo() + 1, B));              ++NumNoCapture;              Changed = true;            } else { @@ -604,7 +662,7 @@ static bool addArgumentAttrs(const SCCNodeSet &SCCNodes) {          if (R != Attribute::None) {            AttrBuilder B;            B.addAttribute(R); -          A->addAttr(AttributeSet::get(A->getContext(), A->getArgNo() + 1, B)); +          A->addAttr(AttributeList::get(A->getContext(), A->getArgNo() + 1, B));            Changed = true;            R == Attribute::ReadOnly ? ++NumReadOnlyArg : ++NumReadNoneArg;          } @@ -629,7 +687,7 @@ static bool addArgumentAttrs(const SCCNodeSet &SCCNodes) {        if (ArgumentSCC[0]->Uses.size() == 1 &&            ArgumentSCC[0]->Uses[0] == ArgumentSCC[0]) {          Argument *A = ArgumentSCC[0]->Definition; -        A->addAttr(AttributeSet::get(A->getContext(), A->getArgNo() + 1, B)); +        A->addAttr(AttributeList::get(A->getContext(), A->getArgNo() + 1, B));          ++NumNoCapture;          Changed = true;        } @@ -671,7 +729,7 @@ static bool addArgumentAttrs(const SCCNodeSet &SCCNodes) {      for (unsigned i = 0, e = ArgumentSCC.size(); i != e; ++i) {        Argument *A = ArgumentSCC[i]->Definition; -      A->addAttr(AttributeSet::get(A->getContext(), A->getArgNo() + 1, B)); +      A->addAttr(AttributeList::get(A->getContext(), A->getArgNo() + 1, B));        ++NumNoCapture;        Changed = true;      } @@ -708,8 +766,9 @@ static bool addArgumentAttrs(const SCCNodeSet &SCCNodes) {        for (unsigned i = 0, e = ArgumentSCC.size(); i != e; ++i) {          Argument *A = ArgumentSCC[i]->Definition;          // Clear out existing readonly/readnone attributes -        A->removeAttr(AttributeSet::get(A->getContext(), A->getArgNo() + 1, R)); -        A->addAttr(AttributeSet::get(A->getContext(), A->getArgNo() + 1, B)); +        A->removeAttr( +            AttributeList::get(A->getContext(), A->getArgNo() + 1, R)); +        A->addAttr(AttributeList::get(A->getContext(), A->getArgNo() + 1, B));          ReadAttr == Attribute::ReadOnly ? ++NumReadOnlyArg : ++NumReadNoneArg;          Changed = true;        } @@ -769,7 +828,7 @@ static bool isFunctionMallocLike(Function *F, const SCCNodeSet &SCCNodes) {        case Instruction::Call:        case Instruction::Invoke: {          CallSite CS(RVI); -        if (CS.paramHasAttr(0, Attribute::NoAlias)) +        if (CS.hasRetAttr(Attribute::NoAlias))            break;          if (CS.getCalledFunction() && SCCNodes.count(CS.getCalledFunction()))            break; @@ -905,7 +964,7 @@ static bool addNonNullAttrs(const SCCNodeSet &SCCNodes) {    // pointers.    for (Function *F : SCCNodes) {      // Already nonnull. 
-    if (F->getAttributes().hasAttribute(AttributeSet::ReturnIndex, +    if (F->getAttributes().hasAttribute(AttributeList::ReturnIndex,                                          Attribute::NonNull))        continue; @@ -926,7 +985,7 @@ static bool addNonNullAttrs(const SCCNodeSet &SCCNodes) {          // Mark the function eagerly since we may discover a function          // which prevents us from speculating about the entire SCC          DEBUG(dbgs() << "Eagerly marking " << F->getName() << " as nonnull\n"); -        F->addAttribute(AttributeSet::ReturnIndex, Attribute::NonNull); +        F->addAttribute(AttributeList::ReturnIndex, Attribute::NonNull);          ++NumNonNullReturn;          MadeChange = true;        } @@ -939,13 +998,13 @@ static bool addNonNullAttrs(const SCCNodeSet &SCCNodes) {    if (SCCReturnsNonNull) {      for (Function *F : SCCNodes) { -      if (F->getAttributes().hasAttribute(AttributeSet::ReturnIndex, +      if (F->getAttributes().hasAttribute(AttributeList::ReturnIndex,                                            Attribute::NonNull) ||            !F->getReturnType()->isPointerTy())          continue;        DEBUG(dbgs() << "SCC marking " << F->getName() << " as nonnull\n"); -      F->addAttribute(AttributeSet::ReturnIndex, Attribute::NonNull); +      F->addAttribute(AttributeList::ReturnIndex, Attribute::NonNull);        ++NumNonNullReturn;        MadeChange = true;      } @@ -1163,19 +1222,7 @@ static bool runImpl(CallGraphSCC &SCC, AARGetterT AARGetter) {  bool PostOrderFunctionAttrsLegacyPass::runOnSCC(CallGraphSCC &SCC) {    if (skipSCC(SCC))      return false; - -  // We compute dedicated AA results for each function in the SCC as needed. We -  // use a lambda referencing external objects so that they live long enough to -  // be queried, but we re-use them each time. -  Optional<BasicAAResult> BAR; -  Optional<AAResults> AAR; -  auto AARGetter = [&](Function &F) -> AAResults & { -    BAR.emplace(createLegacyPMBasicAAResult(*this, F)); -    AAR.emplace(createLegacyPMAAResults(*this, F, *BAR)); -    return *AAR; -  }; - -  return runImpl(SCC, AARGetter); +  return runImpl(SCC, LegacyAARGetter(*this));  }  namespace { @@ -1275,16 +1322,9 @@ PreservedAnalyses  ReversePostOrderFunctionAttrsPass::run(Module &M, ModuleAnalysisManager &AM) {    auto &CG = AM.getResult<CallGraphAnalysis>(M); -  bool Changed = deduceFunctionAttributeInRPO(M, CG); - -  // CallGraphAnalysis holds AssertingVH and must be invalidated eagerly so -  // that other passes don't delete stuff from under it. -  // FIXME: We need to invalidate this to avoid PR28400. Is there a better -  // solution? -  AM.invalidate<CallGraphAnalysis>(M); - -  if (!Changed) +  if (!deduceFunctionAttributeInRPO(M, CG))      return PreservedAnalyses::all(); +    PreservedAnalyses PA;    PA.preserve<CallGraphAnalysis>();    return PA; diff --git a/lib/Transforms/IPO/FunctionImport.cpp b/lib/Transforms/IPO/FunctionImport.cpp index 6b32f6c31f72..d66411f04cc4 100644 --- a/lib/Transforms/IPO/FunctionImport.cpp +++ b/lib/Transforms/IPO/FunctionImport.cpp @@ -75,12 +75,6 @@ static cl::opt<bool> PrintImports("print-imports", cl::init(false), cl::Hidden,  static cl::opt<bool> ComputeDead("compute-dead", cl::init(true), cl::Hidden,                                   cl::desc("Compute dead symbols")); -// Temporary allows the function import pass to disable always linking -// referenced discardable symbols. 
-static cl::opt<bool> -    DontForceImportReferencedDiscardableSymbols("disable-force-link-odr", -                                                cl::init(false), cl::Hidden); -  static cl::opt<bool> EnableImportMetadata(      "enable-import-metadata", cl::init(  #if !defined(NDEBUG) @@ -124,7 +118,7 @@ namespace {  static const GlobalValueSummary *  selectCallee(const ModuleSummaryIndex &Index,               const GlobalValueSummaryList &CalleeSummaryList, -             unsigned Threshold) { +             unsigned Threshold, StringRef CallerModulePath) {    auto It = llvm::find_if(        CalleeSummaryList,        [&](const std::unique_ptr<GlobalValueSummary> &SummaryPtr) { @@ -145,6 +139,21 @@ selectCallee(const ModuleSummaryIndex &Index,          auto *Summary = cast<FunctionSummary>(GVSummary); +        // If this is a local function, make sure we import the copy +        // in the caller's module. The only time a local function can +        // share an entry in the index is if there is a local with the same name +        // in another module that had the same source file name (in a different +        // directory), where each was compiled in its own directory so there +        // was no distinguishing path. +        // However, do the import from another module if there is only one +        // entry in the list - in that case this must be a reference due +        // to indirect call profile data, since a function pointer can point to +        // a local in another module. +        if (GlobalValue::isLocalLinkage(Summary->linkage()) && +            CalleeSummaryList.size() > 1 && +            Summary->modulePath() != CallerModulePath) +          return false; +          if (Summary->instCount() > Threshold)            return false; @@ -163,11 +172,13 @@ selectCallee(const ModuleSummaryIndex &Index,  /// null if there's no match.  static const GlobalValueSummary *selectCallee(GlobalValue::GUID GUID,                                                unsigned Threshold, -                                              const ModuleSummaryIndex &Index) { +                                              const ModuleSummaryIndex &Index, +                                              StringRef CallerModulePath) {    auto CalleeSummaryList = Index.findGlobalValueSummaryList(GUID);    if (CalleeSummaryList == Index.end())      return nullptr; // This function does not have a summary -  return selectCallee(Index, CalleeSummaryList->second, Threshold); +  return selectCallee(Index, CalleeSummaryList->second, Threshold, +                      CallerModulePath);  }  using EdgeInfo = std::tuple<const FunctionSummary *, unsigned /* Threshold */, @@ -186,6 +197,15 @@ static void computeImportForFunction(      auto GUID = Edge.first.getGUID();      DEBUG(dbgs() << " edge -> " << GUID << " Threshold:" << Threshold << "\n"); +    if (Index.findGlobalValueSummaryList(GUID) == Index.end()) { +      // For SamplePGO, the indirect call targets for local functions will +      // have their original names annotated in the profile. We try to find the +      // corresponding PGOFuncName as the GUID. +      GUID = Index.getGUIDFromOriginalID(GUID); +      if (GUID == 0) +        continue; +    } +      if (DefinedGVSummaries.count(GUID)) {        DEBUG(dbgs() << "ignored!
Target already in destination module.\n");        continue; @@ -202,7 +222,8 @@ static void computeImportForFunction(      const auto NewThreshold =          Threshold * GetBonusMultiplier(Edge.second.Hotness); -    auto *CalleeSummary = selectCallee(GUID, NewThreshold, Index); +    auto *CalleeSummary = +        selectCallee(GUID, NewThreshold, Index, Summary.modulePath());      if (!CalleeSummary) {        DEBUG(dbgs() << "ignored! No qualifying callee with summary found.\n");        continue; @@ -522,6 +543,23 @@ llvm::EmitImportsFiles(StringRef ModulePath, StringRef OutputFilename,  /// Fixup WeakForLinker linkages in \p TheModule based on summary analysis.  void llvm::thinLTOResolveWeakForLinkerModule(      Module &TheModule, const GVSummaryMapTy &DefinedGlobals) { +  auto ConvertToDeclaration = [](GlobalValue &GV) { +    DEBUG(dbgs() << "Converting to a declaration: `" << GV.getName() << "\n"); +    if (Function *F = dyn_cast<Function>(&GV)) { +      F->deleteBody(); +      F->clearMetadata(); +    } else if (GlobalVariable *V = dyn_cast<GlobalVariable>(&GV)) { +      V->setInitializer(nullptr); +      V->setLinkage(GlobalValue::ExternalLinkage); +      V->clearMetadata(); +    } else +      // For now we don't resolve or drop aliases. Once we do we'll +      // need to add support here for creating either a function or +      // variable declaration, and return the new GlobalValue* for +      // the caller to use. +      llvm_unreachable("Expected function or variable"); +  }; +    auto updateLinkage = [&](GlobalValue &GV) {      if (!GlobalValue::isWeakForLinker(GV.getLinkage()))        return; @@ -532,18 +570,25 @@ void llvm::thinLTOResolveWeakForLinkerModule(      auto NewLinkage = GS->second->linkage();      if (NewLinkage == GV.getLinkage())        return; -    DEBUG(dbgs() << "ODR fixing up linkage for `" << GV.getName() << "` from " -                 << GV.getLinkage() << " to " << NewLinkage << "\n"); -    GV.setLinkage(NewLinkage); -    // Remove functions converted to available_externally from comdats, +    // Check for a non-prevailing def that has interposable linkage +    // (e.g. non-odr weak or linkonce). In that case we can't simply +    // convert to available_externally, since it would lose the +    // interposable property and possibly get inlined. Simply drop +    // the definition in that case. +    if (GlobalValue::isAvailableExternallyLinkage(NewLinkage) && +        GlobalValue::isInterposableLinkage(GV.getLinkage())) +      ConvertToDeclaration(GV); +    else { +      DEBUG(dbgs() << "ODR fixing up linkage for `" << GV.getName() << "` from " +                   << GV.getLinkage() << " to " << NewLinkage << "\n"); +      GV.setLinkage(NewLinkage); +    } +    // Remove declarations from comdats, including available_externally      // as this is a declaration for the linker, and will be dropped eventually.      // It is illegal for comdats to contain declarations.      auto *GO = dyn_cast_or_null<GlobalObject>(&GV); -    if (GO && GO->isDeclarationForLinker() && GO->hasComdat()) { -      assert(GO->hasAvailableExternallyLinkage() && -             "Expected comdat on definition (possibly available external)"); +    if (GO && GO->isDeclarationForLinker() && GO->hasComdat())        GO->setComdat(nullptr); -    }    };    // Process functions and global now @@ -562,7 +607,7 @@ void llvm::thinLTOInternalizeModule(Module &TheModule,    // the current module.    
StringSet<> AsmUndefinedRefs;    ModuleSymbolTable::CollectAsmSymbols( -      Triple(TheModule.getTargetTriple()), TheModule.getModuleInlineAsm(), +      TheModule,        [&AsmUndefinedRefs](StringRef Name, object::BasicSymbolRef::Flags Flags) {          if (Flags & object::BasicSymbolRef::SF_Undefined)            AsmUndefinedRefs.insert(Name); @@ -617,14 +662,12 @@ void llvm::thinLTOInternalizeModule(Module &TheModule,  // index.  //  Expected<bool> FunctionImporter::importFunctions( -    Module &DestModule, const FunctionImporter::ImportMapTy &ImportList, -    bool ForceImportReferencedDiscardableSymbols) { +    Module &DestModule, const FunctionImporter::ImportMapTy &ImportList) {    DEBUG(dbgs() << "Starting import for Module "                 << DestModule.getModuleIdentifier() << "\n");    unsigned ImportedCount = 0; -  // Linker that will be used for importing function -  Linker TheLinker(DestModule); +  IRMover Mover(DestModule);    // Do the actual import of functions now, one Module at a time    std::set<StringRef> ModuleNameOrderedList;    for (auto &FunctionsToImportPerModule : ImportList) { @@ -648,7 +691,7 @@ Expected<bool> FunctionImporter::importFunctions(      auto &ImportGUIDs = FunctionsToImportPerModule->second;      // Find the globals to import -    DenseSet<const GlobalValue *> GlobalsToImport; +    SetVector<GlobalValue *> GlobalsToImport;      for (Function &F : *SrcModule) {        if (!F.hasName())          continue; @@ -687,6 +730,13 @@ Expected<bool> FunctionImporter::importFunctions(        }      }      for (GlobalAlias &GA : SrcModule->aliases()) { +      // FIXME: This should eventually be controlled entirely by the summary. +      if (FunctionImportGlobalProcessing::doImportAsDefinition( +              &GA, &GlobalsToImport)) { +        GlobalsToImport.insert(&GA); +        continue; +      } +        if (!GA.hasName())          continue;        auto GUID = GA.getGUID(); @@ -731,12 +781,9 @@ Expected<bool> FunctionImporter::importFunctions(                 << " from " << SrcModule->getSourceFileName() << "\n";      } -    // Instruct the linker that the client will take care of linkonce resolution -    unsigned Flags = Linker::Flags::None; -    if (!ForceImportReferencedDiscardableSymbols) -      Flags |= Linker::Flags::DontForceLinkLinkonceODR; - -    if (TheLinker.linkInModule(std::move(SrcModule), Flags, &GlobalsToImport)) +    if (Mover.move(std::move(SrcModule), GlobalsToImport.getArrayRef(), +                   [](GlobalValue &, IRMover::ValueAdder) {}, +                   /*IsPerformingImport=*/true))        report_fatal_error("Function Import: link error");      ImportedCount += GlobalsToImport.size(); @@ -796,8 +843,7 @@ static bool doImportingForModule(Module &M) {      return loadFile(Identifier, M.getContext());    };    FunctionImporter Importer(*Index, ModuleLoader); -  Expected<bool> Result = Importer.importFunctions( -      M, ImportList, !DontForceImportReferencedDiscardableSymbols); +  Expected<bool> Result = Importer.importFunctions(M, ImportList);    // FIXME: Probably need to propagate Errors through the pass manager.    
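// Illustration only, not part of the patch: one way a caller that cannot yet
// propagate llvm::Error upward (as the FIXME above notes) might consume the
// Expected<bool> returned by FunctionImporter::importFunctions. Assumes a local
// bool Changed and llvm/Support/Error.h for logAllUnhandledErrors().
bool Changed = false;
if (Expected<bool> Result = Importer.importFunctions(M, ImportList))
  Changed = *Result;
else
  logAllUnhandledErrors(Result.takeError(), errs(),
                        "Error importing module: ");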
if (!Result) { diff --git a/lib/Transforms/IPO/GlobalDCE.cpp b/lib/Transforms/IPO/GlobalDCE.cpp index 7a04de3d12db..c91e8b454927 100644 --- a/lib/Transforms/IPO/GlobalDCE.cpp +++ b/lib/Transforms/IPO/GlobalDCE.cpp @@ -25,7 +25,7 @@  #include "llvm/Transforms/IPO.h"  #include "llvm/Transforms/Utils/CtorUtils.h"  #include "llvm/Transforms/Utils/GlobalStatus.h" -#include <unordered_map> +  using namespace llvm;  #define DEBUG_TYPE "globaldce" @@ -50,7 +50,14 @@ namespace {        if (skipModule(M))          return false; +      // We need a minimally functional dummy module analysis manager. It needs +      // to at least know about the possibility of proxying a function analysis +      // manager. +      FunctionAnalysisManager DummyFAM;        ModuleAnalysisManager DummyMAM; +      DummyMAM.registerPass( +          [&] { return FunctionAnalysisManagerModuleProxy(DummyFAM); }); +        auto PA = Impl.run(M, DummyMAM);        return !PA.areAllPreserved();      } @@ -78,9 +85,67 @@ static bool isEmptyFunction(Function *F) {    return RI.getReturnValue() == nullptr;  } -PreservedAnalyses GlobalDCEPass::run(Module &M, ModuleAnalysisManager &) { +/// Compute the set of GlobalValues that depend on V. +/// The recursion stops as soon as a GlobalValue is met. +void GlobalDCEPass::ComputeDependencies(Value *V, +                                        SmallPtrSetImpl<GlobalValue *> &Deps) { +  if (auto *I = dyn_cast<Instruction>(V)) { +    Function *Parent = I->getParent()->getParent(); +    Deps.insert(Parent); +  } else if (auto *GV = dyn_cast<GlobalValue>(V)) { +    Deps.insert(GV); +  } else if (auto *CE = dyn_cast<Constant>(V)) { +    // Avoid walking the whole tree of a big ConstantExpr multiple times. +    auto Where = ConstantDependenciesCache.find(CE); +    if (Where != ConstantDependenciesCache.end()) { +      auto const &K = Where->second; +      Deps.insert(K.begin(), K.end()); +    } else { +      SmallPtrSetImpl<GlobalValue *> &LocalDeps = ConstantDependenciesCache[CE]; +      for (User *CEUser : CE->users()) +        ComputeDependencies(CEUser, LocalDeps); +      Deps.insert(LocalDeps.begin(), LocalDeps.end()); +    } +  } +} + +void GlobalDCEPass::UpdateGVDependencies(GlobalValue &GV) { +  SmallPtrSet<GlobalValue *, 8> Deps; +  for (User *User : GV.users()) +    ComputeDependencies(User, Deps); +  Deps.erase(&GV); // Remove self-reference. +  for (GlobalValue *GVU : Deps) { +    GVDependencies.insert(std::make_pair(GVU, &GV)); +  } +} + +/// Mark a GlobalValue as live. +void GlobalDCEPass::MarkLive(GlobalValue &GV, +                             SmallVectorImpl<GlobalValue *> *Updates) { +  auto const Ret = AliveGlobals.insert(&GV); +  if (!Ret.second) +    return; + +  if (Updates) +    Updates->push_back(&GV); +  if (Comdat *C = GV.getComdat()) { +    for (auto &&CM : make_range(ComdatMembers.equal_range(C))) +      MarkLive(*CM.second, Updates); // Recursion depth is only two because only +                                     // globals in the same comdat are visited. +  } +} + +PreservedAnalyses GlobalDCEPass::run(Module &M, ModuleAnalysisManager &MAM) {    bool Changed = false; +  // The algorithm first computes the set L of global variables that are +  // trivially live.  Then it walks the initialization of these variables to +  // compute the globals used to initialize them, which effectively builds a +  // directed graph where nodes are global variables, and an edge from A to B +  // means B is used to initialize A.  
Finally, it propagates the liveness +  // information through the graph starting from the nodes in L. Nodes not +  // marked as alive are discarded. +    // Remove empty functions from the global ctors list.    Changed |= optimizeGlobalCtorsList(M, isEmptyFunction); @@ -103,21 +168,39 @@ PreservedAnalyses GlobalDCEPass::run(Module &M, ModuleAnalysisManager &) {      // initializer.      if (!GO.isDeclaration() && !GO.hasAvailableExternallyLinkage())        if (!GO.isDiscardableIfUnused()) -        GlobalIsNeeded(&GO); +        MarkLive(GO); + +    UpdateGVDependencies(GO);    } +  // Compute direct dependencies of aliases.    for (GlobalAlias &GA : M.aliases()) {      Changed |= RemoveUnusedGlobalValue(GA);      // Externally visible aliases are needed.      if (!GA.isDiscardableIfUnused()) -      GlobalIsNeeded(&GA); +      MarkLive(GA); + +    UpdateGVDependencies(GA);    } +  // Compute direct dependencies of ifuncs.    for (GlobalIFunc &GIF : M.ifuncs()) {      Changed |= RemoveUnusedGlobalValue(GIF);      // Externally visible ifuncs are needed.      if (!GIF.isDiscardableIfUnused()) -      GlobalIsNeeded(&GIF); +      MarkLive(GIF); + +    UpdateGVDependencies(GIF); +  } + +  // Propagate liveness from collected Global Values through the computed +  // dependencies. +  SmallVector<GlobalValue *, 8> NewLiveGVs{AliveGlobals.begin(), +                                           AliveGlobals.end()}; +  while (!NewLiveGVs.empty()) { +    GlobalValue *LGV = NewLiveGVs.pop_back_val(); +    for (auto &&GVD : make_range(GVDependencies.equal_range(LGV))) +      MarkLive(*GVD.second, &NewLiveGVs);    }    // Now that all globals which are needed are in the AliveGlobals set, we loop @@ -154,7 +237,7 @@ PreservedAnalyses GlobalDCEPass::run(Module &M, ModuleAnalysisManager &) {        GA.setAliasee(nullptr);      } -  // The third pass drops targets of ifuncs which are dead... +  // The fourth pass drops targets of ifuncs which are dead...    std::vector<GlobalIFunc*> DeadIFuncs;    for (GlobalIFunc &GIF : M.ifuncs())      if (!AliveGlobals.count(&GIF)) { @@ -188,7 +271,8 @@ PreservedAnalyses GlobalDCEPass::run(Module &M, ModuleAnalysisManager &) {    // Make sure that all memory is released    AliveGlobals.clear(); -  SeenConstants.clear(); +  ConstantDependenciesCache.clear(); +  GVDependencies.clear();    ComdatMembers.clear();    if (Changed) @@ -196,60 +280,6 @@ PreservedAnalyses GlobalDCEPass::run(Module &M, ModuleAnalysisManager &) {    return PreservedAnalyses::all();  } -/// GlobalIsNeeded - the specific global value as needed, and -/// recursively mark anything that it uses as also needed. -void GlobalDCEPass::GlobalIsNeeded(GlobalValue *G) { -  // If the global is already in the set, no need to reprocess it. -  if (!AliveGlobals.insert(G).second) -    return; - -  if (Comdat *C = G->getComdat()) { -    for (auto &&CM : make_range(ComdatMembers.equal_range(C))) -      GlobalIsNeeded(CM.second); -  } - -  if (GlobalVariable *GV = dyn_cast<GlobalVariable>(G)) { -    // If this is a global variable, we must make sure to add any global values -    // referenced by the initializer to the alive set. -    if (GV->hasInitializer()) -      MarkUsedGlobalsAsNeeded(GV->getInitializer()); -  } else if (GlobalIndirectSymbol *GIS = dyn_cast<GlobalIndirectSymbol>(G)) { -    // The target of a global alias or ifunc is needed. -    MarkUsedGlobalsAsNeeded(GIS->getIndirectSymbol()); -  } else { -    // Otherwise this must be a function object.  
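A minimal standalone sketch of the mark-and-propagate scheme the new GlobalDCE comments above describe, using plain integers and standard containers rather than the pass's GlobalValue maps (all names here are illustrative only, not from the patch):

#include <unordered_map>
#include <unordered_set>
#include <vector>

// Deps maps a value to the values it uses; anything used by a live value is
// itself live. Roots are the trivially live values (externally visible, etc.).
static std::unordered_set<int>
propagateLiveness(const std::vector<int> &Roots,
                  const std::unordered_multimap<int, int> &Deps) {
  std::unordered_set<int> Alive(Roots.begin(), Roots.end());
  std::vector<int> Worklist(Roots.begin(), Roots.end());
  while (!Worklist.empty()) {
    int Live = Worklist.back();
    Worklist.pop_back();
    auto Range = Deps.equal_range(Live);
    for (auto It = Range.first; It != Range.second; ++It)
      if (Alive.insert(It->second).second) // newly discovered live value
        Worklist.push_back(It->second);
  }
  return Alive; // values not in this set can be discarded
}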
We have to scan the body of -    // the function looking for constants and global values which are used as -    // operands.  Any operands of these types must be processed to ensure that -    // any globals used will be marked as needed. -    Function *F = cast<Function>(G); - -    for (Use &U : F->operands()) -      MarkUsedGlobalsAsNeeded(cast<Constant>(U.get())); - -    for (BasicBlock &BB : *F) -      for (Instruction &I : BB) -        for (Use &U : I.operands()) -          if (GlobalValue *GV = dyn_cast<GlobalValue>(U)) -            GlobalIsNeeded(GV); -          else if (Constant *C = dyn_cast<Constant>(U)) -            MarkUsedGlobalsAsNeeded(C); -  } -} - -void GlobalDCEPass::MarkUsedGlobalsAsNeeded(Constant *C) { -  if (GlobalValue *GV = dyn_cast<GlobalValue>(C)) -    return GlobalIsNeeded(GV); - -  // Loop over all of the operands of the constant, adding any globals they -  // use to the list of needed globals. -  for (Use &U : C->operands()) { -    // If we've already processed this constant there's no need to do it again. -    Constant *Op = dyn_cast<Constant>(U); -    if (Op && SeenConstants.insert(Op).second) -      MarkUsedGlobalsAsNeeded(Op); -  } -} -  // RemoveUnusedGlobalValue - Loop over all of the uses of the specified  // GlobalValue, looking for the constant pointer ref that may be pointing to it.  // If found, check to see if the constant pointer ref is safe to destroy, and if diff --git a/lib/Transforms/IPO/GlobalOpt.cpp b/lib/Transforms/IPO/GlobalOpt.cpp index 5b0d5e3bc01e..ade4f21ceb52 100644 --- a/lib/Transforms/IPO/GlobalOpt.cpp +++ b/lib/Transforms/IPO/GlobalOpt.cpp @@ -1819,12 +1819,14 @@ static bool processInternalGlobal(        GS.AccessingFunction->doesNotRecurse() &&        isPointerValueDeadOnEntryToFunction(GS.AccessingFunction, GV,                                            LookupDomTree)) { +    const DataLayout &DL = GV->getParent()->getDataLayout(); +      DEBUG(dbgs() << "LOCALIZING GLOBAL: " << *GV << "\n");      Instruction &FirstI = const_cast<Instruction&>(*GS.AccessingFunction                                                     ->getEntryBlock().begin());      Type *ElemTy = GV->getValueType();      // FIXME: Pass Global's alignment when globals have alignment -    AllocaInst *Alloca = new AllocaInst(ElemTy, nullptr, +    AllocaInst *Alloca = new AllocaInst(ElemTy, DL.getAllocaAddrSpace(), nullptr,                                          GV->getName(), &FirstI);      if (!isa<UndefValue>(GV->getInitializer()))        new StoreInst(GV->getInitializer(), Alloca, &FirstI); @@ -1977,7 +1979,7 @@ static void ChangeCalleesToFastCall(Function *F) {    }  } -static AttributeSet StripNest(LLVMContext &C, const AttributeSet &Attrs) { +static AttributeList StripNest(LLVMContext &C, const AttributeList &Attrs) {    for (unsigned i = 0, e = Attrs.getNumSlots(); i != e; ++i) {      unsigned Index = Attrs.getSlotIndex(i);      if (!Attrs.getSlotAttributes(i).hasAttribute(Index, Attribute::Nest)) @@ -2387,7 +2389,7 @@ OptimizeGlobalAliases(Module &M,  }  static Function *FindCXAAtExit(Module &M, TargetLibraryInfo *TLI) { -  LibFunc::Func F = LibFunc::cxa_atexit; +  LibFunc F = LibFunc_cxa_atexit;    if (!TLI->has(F))      return nullptr; @@ -2396,7 +2398,7 @@ static Function *FindCXAAtExit(Module &M, TargetLibraryInfo *TLI) {      return nullptr;    // Make sure that the function has the correct prototype. 
-  if (!TLI->getLibFunc(*Fn, F) || F != LibFunc::cxa_atexit) +  if (!TLI->getLibFunc(*Fn, F) || F != LibFunc_cxa_atexit)      return nullptr;    return Fn; diff --git a/lib/Transforms/IPO/GlobalSplit.cpp b/lib/Transforms/IPO/GlobalSplit.cpp index bbbd096e89c0..4705ebe265ae 100644 --- a/lib/Transforms/IPO/GlobalSplit.cpp +++ b/lib/Transforms/IPO/GlobalSplit.cpp @@ -85,7 +85,16 @@ bool splitGlobal(GlobalVariable &GV) {        uint64_t ByteOffset = cast<ConstantInt>(                cast<ConstantAsMetadata>(Type->getOperand(0))->getValue())                ->getZExtValue(); -      if (ByteOffset < SplitBegin || ByteOffset >= SplitEnd) +      // Type metadata may be attached one byte after the end of the vtable, for +      // classes without virtual methods in Itanium ABI. AFAIK, it is never +      // attached to the first byte of a vtable. Subtract one to get the right +      // slice. +      // This is making an assumption that vtable groups are the only kinds of +      // global variables that !type metadata can be attached to, and that they +      // are either Itanium ABI vtable groups or contain a single vtable (i.e. +      // Microsoft ABI vtables). +      uint64_t AttachedTo = (ByteOffset == 0) ? ByteOffset : ByteOffset - 1; +      if (AttachedTo < SplitBegin || AttachedTo >= SplitEnd)          continue;        SplitGV->addMetadata(            LLVMContext::MD_type, diff --git a/lib/Transforms/IPO/IPConstantPropagation.cpp b/lib/Transforms/IPO/IPConstantPropagation.cpp index 916135e33cd5..349807496dc2 100644 --- a/lib/Transforms/IPO/IPConstantPropagation.cpp +++ b/lib/Transforms/IPO/IPConstantPropagation.cpp @@ -136,7 +136,13 @@ static bool PropagateConstantReturn(Function &F) {    // For more details, see GlobalValue::mayBeDerefined.    if (!F.isDefinitionExact())      return false; -     + +  // Don't touch naked functions. They may contain asm returning +  // a value we don't see, so we may end up interprocedurally propagating +  // the return value incorrectly. +  if (F.hasFnAttribute(Attribute::Naked)) +    return false; +    // Check to see if this function returns a constant.    
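The naked-function bail-out added to IPConstantPropagation above can be motivated by a small example (hypothetical and target-specific, not taken from the patch): the value actually returned is produced by inline asm, so the IR-visible return value must not be propagated to callers.

// Hypothetical C example for x86-64: the function "returns" 5 via %eax,
// but nothing in the IR body records that value.
__attribute__((naked)) int always_five(void) {
  __asm__ volatile("movl $5, %eax\n\t"
                   "ret");
}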
SmallVector<Value *,4> RetVals;    StructType *STy = dyn_cast<StructType>(F.getReturnType()); diff --git a/lib/Transforms/IPO/InlineSimple.cpp b/lib/Transforms/IPO/InlineSimple.cpp index 1770445b413f..50e7cc89a3b3 100644 --- a/lib/Transforms/IPO/InlineSimple.cpp +++ b/lib/Transforms/IPO/InlineSimple.cpp @@ -48,7 +48,7 @@ public:    }    explicit SimpleInliner(InlineParams Params) -      : LegacyInlinerBase(ID), Params(Params) { +      : LegacyInlinerBase(ID), Params(std::move(Params)) {      initializeSimpleInlinerPass(*PassRegistry::getPassRegistry());    } @@ -61,7 +61,8 @@ public:          [&](Function &F) -> AssumptionCache & {        return ACT->getAssumptionCache(F);      }; -    return llvm::getInlineCost(CS, Params, TTI, GetAssumptionCache, PSI); +    return llvm::getInlineCost(CS, Params, TTI, GetAssumptionCache, +                               /*GetBFI=*/None, PSI);    }    bool runOnSCC(CallGraphSCC &SCC) override; @@ -92,8 +93,12 @@ Pass *llvm::createFunctionInliningPass(int Threshold) {  }  Pass *llvm::createFunctionInliningPass(unsigned OptLevel, -                                       unsigned SizeOptLevel) { -  return new SimpleInliner(llvm::getInlineParams(OptLevel, SizeOptLevel)); +                                       unsigned SizeOptLevel, +                                       bool DisableInlineHotCallSite) { +  auto Param = llvm::getInlineParams(OptLevel, SizeOptLevel); +  if (DisableInlineHotCallSite) +    Param.HotCallSiteThreshold = 0; +  return new SimpleInliner(Param);  }  Pass *llvm::createFunctionInliningPass(InlineParams &Params) { diff --git a/lib/Transforms/IPO/Inliner.cpp b/lib/Transforms/IPO/Inliner.cpp index 3f4731c937d1..6c83c99ae3be 100644 --- a/lib/Transforms/IPO/Inliner.cpp +++ b/lib/Transforms/IPO/Inliner.cpp @@ -19,6 +19,7 @@  #include "llvm/Analysis/AliasAnalysis.h"  #include "llvm/Analysis/AssumptionCache.h"  #include "llvm/Analysis/BasicAliasAnalysis.h" +#include "llvm/Analysis/BlockFrequencyInfo.h"  #include "llvm/Analysis/CallGraph.h"  #include "llvm/Analysis/InlineCost.h"  #include "llvm/Analysis/OptimizationDiagnosticInfo.h" @@ -260,8 +261,8 @@ static bool InlineCallIfPossible(  /// Return true if inlining of CS can block the caller from being  /// inlined which is proved to be more beneficial. \p IC is the  /// estimated inline cost associated with callsite \p CS. -/// \p TotalAltCost will be set to the estimated cost of inlining the caller -/// if \p CS is suppressed for inlining. +/// \p TotalSecondaryCost will be set to the estimated cost of inlining the +/// caller if \p CS is suppressed for inlining.  static bool  shouldBeDeferred(Function *Caller, CallSite CS, InlineCost IC,                   int &TotalSecondaryCost, @@ -288,7 +289,7 @@ shouldBeDeferred(Function *Caller, CallSite CS, InlineCost IC,    // treating them as truly abstract units etc.    TotalSecondaryCost = 0;    // The candidate cost to be imposed upon the current function. -  int CandidateCost = IC.getCost() - (InlineConstants::CallPenalty + 1); +  int CandidateCost = IC.getCost() - 1;    // This bool tracks what happens if we do NOT inline C into B.    bool callerWillBeRemoved = Caller->hasLocalLinkage();    // This bool tracks what happens if we DO inline C into B. @@ -325,7 +326,7 @@ shouldBeDeferred(Function *Caller, CallSite CS, InlineCost IC,    // one is set very low by getInlineCost, in anticipation that Caller will    // be removed entirely.  We did not account for this above unless there    // is only one caller of Caller. 
-  if (callerWillBeRemoved && !Caller->use_empty()) +  if (callerWillBeRemoved && !Caller->hasOneUse())      TotalSecondaryCost -= InlineConstants::LastCallToStaticBonus;    if (inliningPreventsSomeOuterInline && TotalSecondaryCost < IC.getCost()) @@ -342,6 +343,7 @@ static bool shouldInline(CallSite CS,    InlineCost IC = GetInlineCost(CS);    Instruction *Call = CS.getInstruction();    Function *Callee = CS.getCalledFunction(); +  Function *Caller = CS.getCaller();    if (IC.isAlways()) {      DEBUG(dbgs() << "    Inlining: cost=always" @@ -355,19 +357,20 @@ static bool shouldInline(CallSite CS,    if (IC.isNever()) {      DEBUG(dbgs() << "    NOT Inlining: cost=never"                   << ", Call: " << *CS.getInstruction() << "\n"); -    ORE.emit(OptimizationRemarkAnalysis(DEBUG_TYPE, "NeverInline", Call) -             << NV("Callee", Callee) -             << " should never be inlined (cost=never)"); +    ORE.emit(OptimizationRemarkMissed(DEBUG_TYPE, "NeverInline", Call) +             << NV("Callee", Callee) << " not inlined into " +             << NV("Caller", Caller) +             << " because it should never be inlined (cost=never)");      return false;    } -  Function *Caller = CS.getCaller();    if (!IC) {      DEBUG(dbgs() << "    NOT Inlining: cost=" << IC.getCost()                   << ", thres=" << (IC.getCostDelta() + IC.getCost())                   << ", Call: " << *CS.getInstruction() << "\n"); -    ORE.emit(OptimizationRemarkAnalysis(DEBUG_TYPE, "TooCostly", Call) -             << NV("Callee", Callee) << " too costly to inline (cost=" +    ORE.emit(OptimizationRemarkMissed(DEBUG_TYPE, "TooCostly", Call) +             << NV("Callee", Callee) << " not inlined into " +             << NV("Caller", Caller) << " because too costly to inline (cost="               << NV("Cost", IC.getCost()) << ", threshold="               << NV("Threshold", IC.getCostDelta() + IC.getCost()) << ")");      return false; @@ -378,8 +381,8 @@ static bool shouldInline(CallSite CS,      DEBUG(dbgs() << "    NOT Inlining: " << *CS.getInstruction()                   << " Cost = " << IC.getCost()                   << ", outer Cost = " << TotalSecondaryCost << '\n'); -    ORE.emit(OptimizationRemarkAnalysis(DEBUG_TYPE, -                                        "IncreaseCostInOtherContexts", Call) +    ORE.emit(OptimizationRemarkMissed(DEBUG_TYPE, "IncreaseCostInOtherContexts", +                                      Call)               << "Not inlining. Cost of inlining " << NV("Callee", Callee)               << " increases the cost of inlining " << NV("Caller", Caller)               << " in other contexts"); @@ -552,16 +555,11 @@ inlineCallsImpl(CallGraphSCC &SCC, CallGraph &CG,          // If the policy determines that we should inline this function,          // try to do so. -        using namespace ore; -        if (!shouldInline(CS, GetInlineCost, ORE)) { -          ORE.emit( -              OptimizationRemarkMissed(DEBUG_TYPE, "NotInlined", DLoc, Block) -              << NV("Callee", Callee) << " will not be inlined into " -              << NV("Caller", Caller)); +        if (!shouldInline(CS, GetInlineCost, ORE))            continue; -        }          // Attempt to inline the function. 
+        using namespace ore;          if (!InlineCallIfPossible(CS, InlineInfo, InlinedArrayAllocas,                                    InlineHistoryID, InsertLifetime, AARGetter,                                    ImportedFunctionsStats)) { @@ -638,22 +636,12 @@ bool LegacyInlinerBase::inlineCalls(CallGraphSCC &SCC) {    ACT = &getAnalysis<AssumptionCacheTracker>();    PSI = getAnalysis<ProfileSummaryInfoWrapperPass>().getPSI();    auto &TLI = getAnalysis<TargetLibraryInfoWrapperPass>().getTLI(); -  // We compute dedicated AA results for each function in the SCC as needed. We -  // use a lambda referencing external objects so that they live long enough to -  // be queried, but we re-use them each time. -  Optional<BasicAAResult> BAR; -  Optional<AAResults> AAR; -  auto AARGetter = [&](Function &F) -> AAResults & { -    BAR.emplace(createLegacyPMBasicAAResult(*this, F)); -    AAR.emplace(createLegacyPMAAResults(*this, F, *BAR)); -    return *AAR; -  };    auto GetAssumptionCache = [&](Function &F) -> AssumptionCache & {      return ACT->getAssumptionCache(F);    };    return inlineCallsImpl(SCC, CG, GetAssumptionCache, PSI, TLI, InsertLifetime,                           [this](CallSite CS) { return getInlineCost(CS); }, -                         AARGetter, ImportedFunctionsStats); +                         LegacyAARGetter(*this), ImportedFunctionsStats);  }  /// Remove now-dead linkonce functions at the end of @@ -750,9 +738,6 @@ bool LegacyInlinerBase::removeDeadFunctions(CallGraph &CG,  PreservedAnalyses InlinerPass::run(LazyCallGraph::SCC &InitialC,                                     CGSCCAnalysisManager &AM, LazyCallGraph &CG,                                     CGSCCUpdateResult &UR) { -  FunctionAnalysisManager &FAM = -      AM.getResult<FunctionAnalysisManagerCGSCCProxy>(InitialC, CG) -          .getManager();    const ModuleAnalysisManager &MAM =        AM.getResult<ModuleAnalysisManagerCGSCCProxy>(InitialC, CG).getManager();    bool Changed = false; @@ -761,35 +746,52 @@ PreservedAnalyses InlinerPass::run(LazyCallGraph::SCC &InitialC,    Module &M = *InitialC.begin()->getFunction().getParent();    ProfileSummaryInfo *PSI = MAM.getCachedResult<ProfileSummaryAnalysis>(M); -  std::function<AssumptionCache &(Function &)> GetAssumptionCache = -      [&](Function &F) -> AssumptionCache & { -    return FAM.getResult<AssumptionAnalysis>(F); -  }; - -  // Setup the data structure used to plumb customization into the -  // `InlineFunction` routine. -  InlineFunctionInfo IFI(/*cg=*/nullptr, &GetAssumptionCache); +  // We use a single common worklist for calls across the entire SCC. We +  // process these in-order and append new calls introduced during inlining to +  // the end. +  // +  // Note that this particular order of processing is actually critical to +  // avoid very bad behaviors. Consider *highly connected* call graphs where +  // each function contains a small amount of code and a couple of calls to +  // other functions. Because the LLVM inliner is fundamentally a bottom-up +  // inliner, it can handle gracefully the fact that these all appear to be +  // reasonable inlining candidates as it will flatten things until they become +  // too big to inline, and then move on and flatten another batch. +  // +  // However, when processing call edges *within* an SCC we cannot rely on this +  // bottom-up behavior. 
As a consequence, with heavily connected *SCCs* of +  // functions we can end up incrementally inlining N calls into each of +  // N functions because each incremental inlining decision looks good and we +  // don't have a topological ordering to prevent explosions. +  // +  // To compensate for this, we don't process transitive edges made immediate +  // by inlining until we've done one pass of inlining across the entire SCC. +  // Large, highly connected SCCs still lead to some amount of code bloat in +  // this model, but it is uniformly spread across all the functions in the SCC +  // and eventually they all become too large to inline, rather than +  // incrementally making a single function grow in a super linear fashion. +  SmallVector<std::pair<CallSite, int>, 16> Calls; -  auto GetInlineCost = [&](CallSite CS) { -    Function &Callee = *CS.getCalledFunction(); -    auto &CalleeTTI = FAM.getResult<TargetIRAnalysis>(Callee); -    return getInlineCost(CS, Params, CalleeTTI, GetAssumptionCache, PSI); -  }; +  // Populate the initial list of calls in this SCC. +  for (auto &N : InitialC) { +    // We want to generally process call sites top-down in order for +    // simplifications stemming from replacing the call with the returned value +    // after inlining to be visible to subsequent inlining decisions. +    // FIXME: Using instructions sequence is a really bad way to do this. +    // Instead we should do an actual RPO walk of the function body. +    for (Instruction &I : instructions(N.getFunction())) +      if (auto CS = CallSite(&I)) +        if (Function *Callee = CS.getCalledFunction()) +          if (!Callee->isDeclaration()) +            Calls.push_back({CS, -1}); +  } +  if (Calls.empty()) +    return PreservedAnalyses::all(); -  // We use a worklist of nodes to process so that we can handle if the SCC -  // structure changes and some nodes are no longer part of the current SCC. We -  // also need to use an updatable pointer for the SCC as a consequence. -  SmallVector<LazyCallGraph::Node *, 16> Nodes; -  for (auto &N : InitialC) -    Nodes.push_back(&N); +  // Capture updatable variables for the current SCC and RefSCC.    auto *C = &InitialC;    auto *RC = &C->getOuterRefSCC(); -  // We also use a secondary worklist of call sites within a particular node to -  // allow quickly continuing to inline through newly inlined call sites where -  // possible. -  SmallVector<std::pair<CallSite, int>, 16> Calls; -    // When inlining a callee produces new call sites, we want to keep track of    // the fact that they were inlined from the callee.  This allows us to avoid    // infinite inlining in some obscure cases.  To represent this, we use an @@ -805,34 +807,58 @@ PreservedAnalyses InlinerPass::run(LazyCallGraph::SCC &InitialC,    // defer deleting these to make it easier to handle the call graph updates.    SmallVector<Function *, 4> DeadFunctions; -  do { -    auto &N = *Nodes.pop_back_val(); +  // Loop forward over all of the calls. Note that we cannot cache the size as +  // inlining can introduce new calls that need to be processed. +  for (int i = 0; i < (int)Calls.size(); ++i) { +    // We expect the calls to typically be batched with sequences of calls that +    // have the same caller, so we first set up some shared infrastructure for +    // this caller. We also do any pruning we can at this layer on the caller +    // alone. 
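A toy model of the shared-worklist ordering described in the comments above (plain ints stand in for call sites; illustrative only, not from the patch): the size is re-read every iteration and newly exposed calls are appended, so they are visited only after the current batch, which spreads growth across the SCC instead of repeatedly growing one function.

#include <cstdio>
#include <vector>

int main() {
  std::vector<int> Calls = {0, 1, 2}; // initial calls collected for the SCC
  for (int i = 0; i < (int)Calls.size(); ++i) { // size re-read each iteration
    if (Calls[i] == 1) {
      // "Inlining" call 1 exposes two new calls; they are appended rather
      // than processed immediately, deferring the transitive edges.
      Calls.push_back(3);
      Calls.push_back(4);
    }
    std::printf("visit %d\n", Calls[i]); // prints 0 1 2 3 4
  }
  return 0;
}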
+    Function &F = *Calls[i].first.getCaller(); +    LazyCallGraph::Node &N = *CG.lookup(F);      if (CG.lookupSCC(N) != C)        continue; -    Function &F = N.getFunction();      if (F.hasFnAttribute(Attribute::OptimizeNone))        continue; +    DEBUG(dbgs() << "Inlining calls in: " << F.getName() << "\n"); + +    // Get a FunctionAnalysisManager via a proxy for this particular node. We +    // do this each time we visit a node as the SCC may have changed and as +    // we're going to mutate this particular function we want to make sure the +    // proxy is in place to forward any invalidation events. We can use the +    // manager we get here for looking up results for functions other than this +    // node however because those functions aren't going to be mutated by this +    // pass. +    FunctionAnalysisManager &FAM = +        AM.getResult<FunctionAnalysisManagerCGSCCProxy>(*C, CG) +            .getManager(); +    std::function<AssumptionCache &(Function &)> GetAssumptionCache = +        [&](Function &F) -> AssumptionCache & { +      return FAM.getResult<AssumptionAnalysis>(F); +    }; +    auto GetBFI = [&](Function &F) -> BlockFrequencyInfo & { +      return FAM.getResult<BlockFrequencyAnalysis>(F); +    }; + +    auto GetInlineCost = [&](CallSite CS) { +      Function &Callee = *CS.getCalledFunction(); +      auto &CalleeTTI = FAM.getResult<TargetIRAnalysis>(Callee); +      return getInlineCost(CS, Params, CalleeTTI, GetAssumptionCache, {GetBFI}, +                           PSI); +    }; +      // Get the remarks emission analysis for the caller.      auto &ORE = FAM.getResult<OptimizationRemarkEmitterAnalysis>(F); -    // We want to generally process call sites top-down in order for -    // simplifications stemming from replacing the call with the returned value -    // after inlining to be visible to subsequent inlining decisions. So we -    // walk the function backwards and then process the back of the vector. -    // FIXME: Using reverse is a really bad way to do this. Instead we should -    // do an actual PO walk of the function body. -    for (Instruction &I : reverse(instructions(F))) -      if (auto CS = CallSite(&I)) -        if (Function *Callee = CS.getCalledFunction()) -          if (!Callee->isDeclaration()) -            Calls.push_back({CS, -1}); - +    // Now process as many calls as we have within this caller in the sequence. +    // We bail out as soon as the caller has to change so we can update the +    // call graph and prepare the context of that new caller.      bool DidInline = false; -    while (!Calls.empty()) { +    for (; i < (int)Calls.size() && Calls[i].first.getCaller() == &F; ++i) {        int InlineHistoryID;        CallSite CS; -      std::tie(CS, InlineHistoryID) = Calls.pop_back_val(); +      std::tie(CS, InlineHistoryID) = Calls[i];        Function &Callee = *CS.getCalledFunction();        if (InlineHistoryID != -1 && @@ -843,6 +869,13 @@ PreservedAnalyses InlinerPass::run(LazyCallGraph::SCC &InitialC,        if (!shouldInline(CS, GetInlineCost, ORE))          continue; +      // Setup the data structure used to plumb customization into the +      // `InlineFunction` routine. 
+      InlineFunctionInfo IFI( +          /*cg=*/nullptr, &GetAssumptionCache, +          &FAM.getResult<BlockFrequencyAnalysis>(*(CS.getCaller())), +          &FAM.getResult<BlockFrequencyAnalysis>(Callee)); +        if (!InlineFunction(CS, IFI))          continue;        DidInline = true; @@ -870,6 +903,12 @@ PreservedAnalyses InlinerPass::run(LazyCallGraph::SCC &InitialC,          // made dead by this operation on other functions).          Callee.removeDeadConstantUsers();          if (Callee.use_empty()) { +          Calls.erase( +              std::remove_if(Calls.begin() + i + 1, Calls.end(), +                             [&Callee](const std::pair<CallSite, int> &Call) { +                               return Call.first.getCaller() == &Callee; +                             }), +              Calls.end());            // Clear the body and queue the function itself for deletion when we            // finish inlining and call graph updates.            // Note that after this point, it is an error to do anything other @@ -882,6 +921,10 @@ PreservedAnalyses InlinerPass::run(LazyCallGraph::SCC &InitialC,        }      } +    // Back the call index up by one to put us in a good position to go around +    // the outer loop. +    --i; +      if (!DidInline)        continue;      Changed = true; @@ -896,8 +939,8 @@ PreservedAnalyses InlinerPass::run(LazyCallGraph::SCC &InitialC,      // below.      for (Function *InlinedCallee : InlinedCallees) {        LazyCallGraph::Node &CalleeN = *CG.lookup(*InlinedCallee); -      for (LazyCallGraph::Edge &E : CalleeN) -        RC->insertTrivialRefEdge(N, *E.getNode()); +      for (LazyCallGraph::Edge &E : *CalleeN) +        RC->insertTrivialRefEdge(N, E.getNode());      }      InlinedCallees.clear(); @@ -908,8 +951,9 @@ PreservedAnalyses InlinerPass::run(LazyCallGraph::SCC &InitialC,      // re-use the exact same logic for updating the call graph to reflect the      // change..      C = &updateCGAndAnalysisManagerForFunctionPass(CG, *C, N, AM, UR); +    DEBUG(dbgs() << "Updated inlining SCC: " << *C << "\n");      RC = &C->getOuterRefSCC(); -  } while (!Nodes.empty()); +  }    // Now that we've finished inlining all of the calls across this SCC, delete    // all of the trivially dead functions, updating the call graph and the CGSCC @@ -920,8 +964,13 @@ PreservedAnalyses InlinerPass::run(LazyCallGraph::SCC &InitialC,    // sets.    for (Function *DeadF : DeadFunctions) {      // Get the necessary information out of the call graph and nuke the -    // function there. +    // function there. Also, clear out any cached analyses.      
auto &DeadC = *CG.lookupSCC(*CG.lookup(*DeadF)); +    FunctionAnalysisManager &FAM = +        AM.getResult<FunctionAnalysisManagerCGSCCProxy>(DeadC, CG) +            .getManager(); +    FAM.clear(*DeadF); +    AM.clear(DeadC);      auto &DeadRC = DeadC.getOuterRefSCC();      CG.removeDeadFunction(*DeadF); diff --git a/lib/Transforms/IPO/LowerTypeTests.cpp b/lib/Transforms/IPO/LowerTypeTests.cpp index deb7e819480b..785207efbe5c 100644 --- a/lib/Transforms/IPO/LowerTypeTests.cpp +++ b/lib/Transforms/IPO/LowerTypeTests.cpp @@ -42,8 +42,6 @@  using namespace llvm;  using namespace lowertypetests; -using SummaryAction = LowerTypeTestsSummaryAction; -  #define DEBUG_TYPE "lowertypetests"  STATISTIC(ByteArraySizeBits, "Byte array size in bits"); @@ -57,13 +55,13 @@ static cl::opt<bool> AvoidReuse(      cl::desc("Try to avoid reuse of byte array addresses using aliases"),      cl::Hidden, cl::init(true)); -static cl::opt<SummaryAction> ClSummaryAction( +static cl::opt<PassSummaryAction> ClSummaryAction(      "lowertypetests-summary-action",      cl::desc("What to do with the summary when running this pass"), -    cl::values(clEnumValN(SummaryAction::None, "none", "Do nothing"), -               clEnumValN(SummaryAction::Import, "import", +    cl::values(clEnumValN(PassSummaryAction::None, "none", "Do nothing"), +               clEnumValN(PassSummaryAction::Import, "import",                            "Import typeid resolutions from summary and globals"), -               clEnumValN(SummaryAction::Export, "export", +               clEnumValN(PassSummaryAction::Export, "export",                            "Export typeid resolutions to summary and globals")),      cl::Hidden); @@ -234,8 +232,8 @@ public:  class LowerTypeTestsModule {    Module &M; -  SummaryAction Action; -  ModuleSummaryIndex *Summary; +  ModuleSummaryIndex *ExportSummary; +  const ModuleSummaryIndex *ImportSummary;    bool LinkerSubsectionsViaSymbols;    Triple::ArchType Arch; @@ -253,15 +251,21 @@ class LowerTypeTestsModule {    // Indirect function call index assignment counter for WebAssembly    uint64_t IndirectIndex = 1; -  // Mapping from type identifiers to the call sites that test them. -  DenseMap<Metadata *, std::vector<CallInst *>> TypeTestCallSites; +  // Mapping from type identifiers to the call sites that test them, as well as +  // whether the type identifier needs to be exported to ThinLTO backends as +  // part of the regular LTO phase of the ThinLTO pipeline (see exportTypeId). +  struct TypeIdUserInfo { +    std::vector<CallInst *> CallSites; +    bool IsExported = false; +  }; +  DenseMap<Metadata *, TypeIdUserInfo> TypeIdUsers;    /// This structure describes how to lower type tests for a particular type    /// identifier. It is either built directly from the global analysis (during    /// regular LTO or the regular LTO phase of ThinLTO), or indirectly using type    /// identifier summaries and external symbol references (in ThinLTO backends).    struct TypeIdLowering { -    TypeTestResolution::Kind TheKind; +    TypeTestResolution::Kind TheKind = TypeTestResolution::Unsat;      /// All except Unsat: the start address within the combined global.      Constant *OffsetedGlobal; @@ -274,9 +278,6 @@ class LowerTypeTestsModule {      /// covering members of this type identifier as a multiple of 2^AlignLog2.      Constant *SizeM1; -    /// ByteArray, Inline, AllOnes: range of SizeM1 expressed as a bit width. -    unsigned SizeM1BitWidth; -      /// ByteArray: the byte array to test the address against.      
Constant *TheByteArray; @@ -291,6 +292,10 @@ class LowerTypeTestsModule {    Function *WeakInitializerFn = nullptr; +  void exportTypeId(StringRef TypeId, const TypeIdLowering &TIL); +  TypeIdLowering importTypeId(StringRef TypeId); +  void importTypeTest(CallInst *CI); +    BitSetInfo    buildBitSet(Metadata *TypeId,                const DenseMap<GlobalTypeMember *, uint64_t> &GlobalLayout); @@ -327,8 +332,8 @@ class LowerTypeTestsModule {    void createJumpTable(Function *F, ArrayRef<GlobalTypeMember *> Functions);  public: -  LowerTypeTestsModule(Module &M, SummaryAction Action, -                       ModuleSummaryIndex *Summary); +  LowerTypeTestsModule(Module &M, ModuleSummaryIndex *ExportSummary, +                       const ModuleSummaryIndex *ImportSummary);    bool lower();    // Lower the module using the action and summary passed as command line @@ -341,15 +346,17 @@ struct LowerTypeTests : public ModulePass {    bool UseCommandLine = false; -  SummaryAction Action; -  ModuleSummaryIndex *Summary; +  ModuleSummaryIndex *ExportSummary; +  const ModuleSummaryIndex *ImportSummary;    LowerTypeTests() : ModulePass(ID), UseCommandLine(true) {      initializeLowerTypeTestsPass(*PassRegistry::getPassRegistry());    } -  LowerTypeTests(SummaryAction Action, ModuleSummaryIndex *Summary) -      : ModulePass(ID), Action(Action), Summary(Summary) { +  LowerTypeTests(ModuleSummaryIndex *ExportSummary, +                 const ModuleSummaryIndex *ImportSummary) +      : ModulePass(ID), ExportSummary(ExportSummary), +        ImportSummary(ImportSummary) {      initializeLowerTypeTestsPass(*PassRegistry::getPassRegistry());    } @@ -358,7 +365,7 @@ struct LowerTypeTests : public ModulePass {        return false;      if (UseCommandLine)        return LowerTypeTestsModule::runForTesting(M); -    return LowerTypeTestsModule(M, Action, Summary).lower(); +    return LowerTypeTestsModule(M, ExportSummary, ImportSummary).lower();    }  }; @@ -368,9 +375,10 @@ INITIALIZE_PASS(LowerTypeTests, "lowertypetests", "Lower type metadata", false,                  false)  char LowerTypeTests::ID = 0; -ModulePass *llvm::createLowerTypeTestsPass(SummaryAction Action, -                                           ModuleSummaryIndex *Summary) { -  return new LowerTypeTests(Action, Summary); +ModulePass * +llvm::createLowerTypeTestsPass(ModuleSummaryIndex *ExportSummary, +                               const ModuleSummaryIndex *ImportSummary) { +  return new LowerTypeTests(ExportSummary, ImportSummary);  }  /// Build a bit set for TypeId using the object layouts in @@ -494,10 +502,11 @@ Value *LowerTypeTestsModule::createBitSetTest(IRBuilder<> &B,      return createMaskedBitTest(B, TIL.InlineBits, BitOffset);    } else {      Constant *ByteArray = TIL.TheByteArray; -    if (!LinkerSubsectionsViaSymbols && AvoidReuse) { +    if (!LinkerSubsectionsViaSymbols && AvoidReuse && !ImportSummary) {        // Each use of the byte array uses a different alias. This makes the        // backend less likely to reuse previously computed byte array addresses,        // improving the security of the CFI mechanism based on this pass. +      // This won't work when importing because TheByteArray is external.        
ByteArray = GlobalAlias::create(Int8Ty, 0, GlobalValue::PrivateLinkage,                                        "bits_use", ByteArray, &M);      } @@ -593,8 +602,7 @@ Value *LowerTypeTestsModule::lowerTypeTestCall(Metadata *TypeId, CallInst *CI,                       IntPtrTy));    Value *BitOffset = B.CreateOr(OffsetSHR, OffsetSHL); -  Constant *BitSizeConst = ConstantExpr::getZExt(TIL.SizeM1, IntPtrTy); -  Value *OffsetInRange = B.CreateICmpULE(BitOffset, BitSizeConst); +  Value *OffsetInRange = B.CreateICmpULE(BitOffset, TIL.SizeM1);    // If the bit set is all ones, testing against it is unnecessary.    if (TIL.TheKind == TypeTestResolution::AllOnes) @@ -687,6 +695,123 @@ void LowerTypeTestsModule::buildBitSetsFromGlobalVariables(    }  } +/// Export the given type identifier so that ThinLTO backends may import it. +/// Type identifiers are exported by adding coarse-grained information about how +/// to test the type identifier to the summary, and creating symbols in the +/// object file (aliases and absolute symbols) containing fine-grained +/// information about the type identifier. +void LowerTypeTestsModule::exportTypeId(StringRef TypeId, +                                        const TypeIdLowering &TIL) { +  TypeTestResolution &TTRes = +      ExportSummary->getOrInsertTypeIdSummary(TypeId).TTRes; +  TTRes.TheKind = TIL.TheKind; + +  auto ExportGlobal = [&](StringRef Name, Constant *C) { +    GlobalAlias *GA = +        GlobalAlias::create(Int8Ty, 0, GlobalValue::ExternalLinkage, +                            "__typeid_" + TypeId + "_" + Name, C, &M); +    GA->setVisibility(GlobalValue::HiddenVisibility); +  }; + +  if (TIL.TheKind != TypeTestResolution::Unsat) +    ExportGlobal("global_addr", TIL.OffsetedGlobal); + +  if (TIL.TheKind == TypeTestResolution::ByteArray || +      TIL.TheKind == TypeTestResolution::Inline || +      TIL.TheKind == TypeTestResolution::AllOnes) { +    ExportGlobal("align", ConstantExpr::getIntToPtr(TIL.AlignLog2, Int8PtrTy)); +    ExportGlobal("size_m1", ConstantExpr::getIntToPtr(TIL.SizeM1, Int8PtrTy)); + +    uint64_t BitSize = cast<ConstantInt>(TIL.SizeM1)->getZExtValue() + 1; +    if (TIL.TheKind == TypeTestResolution::Inline) +      TTRes.SizeM1BitWidth = (BitSize <= 32) ? 5 : 6; +    else +      TTRes.SizeM1BitWidth = (BitSize <= 128) ? 7 : 32; +  } + +  if (TIL.TheKind == TypeTestResolution::ByteArray) { +    ExportGlobal("byte_array", TIL.TheByteArray); +    ExportGlobal("bit_mask", TIL.BitMask); +  } + +  if (TIL.TheKind == TypeTestResolution::Inline) +    ExportGlobal("inline_bits", +                 ConstantExpr::getIntToPtr(TIL.InlineBits, Int8PtrTy)); +} + +LowerTypeTestsModule::TypeIdLowering +LowerTypeTestsModule::importTypeId(StringRef TypeId) { +  const TypeIdSummary *TidSummary = ImportSummary->getTypeIdSummary(TypeId); +  if (!TidSummary) +    return {}; // Unsat: no globals match this type id. +  const TypeTestResolution &TTRes = TidSummary->TTRes; + +  TypeIdLowering TIL; +  TIL.TheKind = TTRes.TheKind; + +  auto ImportGlobal = [&](StringRef Name, unsigned AbsWidth) { +    Constant *C = +        M.getOrInsertGlobal(("__typeid_" + TypeId + "_" + Name).str(), Int8Ty); +    auto *GV = dyn_cast<GlobalVariable>(C); +    // We only need to set metadata if the global is newly created, in which +    // case it would not have hidden visibility. 
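The export/import handshake above communicates through hidden symbols whose names follow the "__typeid_" + TypeId + "_" + Field pattern visible in both exportTypeId and importTypeId. A small sketch of that convention (std::string instead of Twine, purely for illustration):

#include <string>

// Illustrative only: the pass builds these names with Twine and attaches them
// to aliases on export and to external globals on import.
static std::string typeIdSymbolName(const std::string &TypeId,
                                    const std::string &Field) {
  return "__typeid_" + TypeId + "_" + Field;
}
// e.g. typeIdSymbolName("_ZTS1A", "align") yields "__typeid__ZTS1A_align"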
+    if (!GV || GV->getVisibility() == GlobalValue::HiddenVisibility) +      return C; + +    GV->setVisibility(GlobalValue::HiddenVisibility); +    auto SetAbsRange = [&](uint64_t Min, uint64_t Max) { +      auto *MinC = ConstantAsMetadata::get(ConstantInt::get(IntPtrTy, Min)); +      auto *MaxC = ConstantAsMetadata::get(ConstantInt::get(IntPtrTy, Max)); +      GV->setMetadata(LLVMContext::MD_absolute_symbol, +                      MDNode::get(M.getContext(), {MinC, MaxC})); +    }; +    if (AbsWidth == IntPtrTy->getBitWidth()) +      SetAbsRange(~0ull, ~0ull); // Full set. +    else if (AbsWidth) +      SetAbsRange(0, 1ull << AbsWidth); +    return C; +  }; + +  if (TIL.TheKind != TypeTestResolution::Unsat) +    TIL.OffsetedGlobal = ImportGlobal("global_addr", 0); + +  if (TIL.TheKind == TypeTestResolution::ByteArray || +      TIL.TheKind == TypeTestResolution::Inline || +      TIL.TheKind == TypeTestResolution::AllOnes) { +    TIL.AlignLog2 = ConstantExpr::getPtrToInt(ImportGlobal("align", 8), Int8Ty); +    TIL.SizeM1 = ConstantExpr::getPtrToInt( +        ImportGlobal("size_m1", TTRes.SizeM1BitWidth), IntPtrTy); +  } + +  if (TIL.TheKind == TypeTestResolution::ByteArray) { +    TIL.TheByteArray = ImportGlobal("byte_array", 0); +    TIL.BitMask = ImportGlobal("bit_mask", 8); +  } + +  if (TIL.TheKind == TypeTestResolution::Inline) +    TIL.InlineBits = ConstantExpr::getPtrToInt( +        ImportGlobal("inline_bits", 1 << TTRes.SizeM1BitWidth), +        TTRes.SizeM1BitWidth <= 5 ? Int32Ty : Int64Ty); + +  return TIL; +} + +void LowerTypeTestsModule::importTypeTest(CallInst *CI) { +  auto TypeIdMDVal = dyn_cast<MetadataAsValue>(CI->getArgOperand(1)); +  if (!TypeIdMDVal) +    report_fatal_error("Second argument of llvm.type.test must be metadata"); + +  auto TypeIdStr = dyn_cast<MDString>(TypeIdMDVal->getMetadata()); +  if (!TypeIdStr) +    report_fatal_error( +        "Second argument of llvm.type.test must be a metadata string"); + +  TypeIdLowering TIL = importTypeId(TypeIdStr->getString()); +  Value *Lowered = lowerTypeTestCall(TypeIdStr, CI, TIL); +  CI->replaceAllUsesWith(Lowered); +  CI->eraseFromParent(); +} +  void LowerTypeTestsModule::lowerTypeTestCalls(      ArrayRef<Metadata *> TypeIds, Constant *CombinedGlobalAddr,      const DenseMap<GlobalTypeMember *, uint64_t> &GlobalLayout) { @@ -708,16 +833,12 @@ void LowerTypeTestsModule::lowerTypeTestCalls(      TIL.OffsetedGlobal = ConstantExpr::getGetElementPtr(          Int8Ty, CombinedGlobalAddr, ConstantInt::get(IntPtrTy, BSI.ByteOffset)),      TIL.AlignLog2 = ConstantInt::get(Int8Ty, BSI.AlignLog2); +    TIL.SizeM1 = ConstantInt::get(IntPtrTy, BSI.BitSize - 1);      if (BSI.isAllOnes()) {        TIL.TheKind = (BSI.BitSize == 1) ? TypeTestResolution::Single                                         : TypeTestResolution::AllOnes; -      TIL.SizeM1BitWidth = (BSI.BitSize <= 128) ? 7 : 32; -      TIL.SizeM1 = ConstantInt::get((BSI.BitSize <= 128) ? Int8Ty : Int32Ty, -                                    BSI.BitSize - 1);      } else if (BSI.BitSize <= 64) {        TIL.TheKind = TypeTestResolution::Inline; -      TIL.SizeM1BitWidth = (BSI.BitSize <= 32) ? 5 : 6; -      TIL.SizeM1 = ConstantInt::get(Int8Ty, BSI.BitSize - 1);        uint64_t InlineBits = 0;        for (auto Bit : BSI.Bits)          InlineBits |= uint64_t(1) << Bit; @@ -728,17 +849,19 @@ void LowerTypeTestsModule::lowerTypeTestCalls(              (BSI.BitSize <= 32) ? 
Int32Ty : Int64Ty, InlineBits);      } else {        TIL.TheKind = TypeTestResolution::ByteArray; -      TIL.SizeM1BitWidth = (BSI.BitSize <= 128) ? 7 : 32; -      TIL.SizeM1 = ConstantInt::get((BSI.BitSize <= 128) ? Int8Ty : Int32Ty, -                                    BSI.BitSize - 1);        ++NumByteArraysCreated;        ByteArrayInfo *BAI = createByteArray(BSI);        TIL.TheByteArray = BAI->ByteArray;        TIL.BitMask = BAI->MaskGlobal;      } +    TypeIdUserInfo &TIUI = TypeIdUsers[TypeId]; + +    if (TIUI.IsExported) +      exportTypeId(cast<MDString>(TypeId)->getString(), TIL); +      // Lower each call to llvm.type.test for this type identifier. -    for (CallInst *CI : TypeTestCallSites[TypeId]) { +    for (CallInst *CI : TIUI.CallSites) {        ++NumTypeTestCallsLowered;        Value *Lowered = lowerTypeTestCall(TypeId, CI, TIL);        CI->replaceAllUsesWith(Lowered); @@ -757,9 +880,9 @@ void LowerTypeTestsModule::verifyTypeMDNode(GlobalObject *GO, MDNode *Type) {      report_fatal_error(          "A member of a type identifier may not have an explicit section"); -  if (isa<GlobalVariable>(GO) && GO->isDeclarationForLinker()) -    report_fatal_error( -        "A global var member of a type identifier must be a definition"); +  // FIXME: We previously checked that global var member of a type identifier +  // must be a definition, but the IR linker may leave type metadata on +  // declarations. We should restore this check after fixing PR31759.    auto OffsetConstMD = dyn_cast<ConstantAsMetadata>(Type->getOperand(0));    if (!OffsetConstMD) @@ -1173,13 +1296,11 @@ void LowerTypeTestsModule::buildBitSetsFromDisjointSet(  }  /// Lower all type tests in this module. -LowerTypeTestsModule::LowerTypeTestsModule(Module &M, SummaryAction Action, -                                           ModuleSummaryIndex *Summary) -    : M(M), Action(Action), Summary(Summary) { -  // FIXME: Use these fields. -  (void)this->Action; -  (void)this->Summary; - +LowerTypeTestsModule::LowerTypeTestsModule( +    Module &M, ModuleSummaryIndex *ExportSummary, +    const ModuleSummaryIndex *ImportSummary) +    : M(M), ExportSummary(ExportSummary), ImportSummary(ImportSummary) { +  assert(!(ExportSummary && ImportSummary));    Triple TargetTriple(M.getTargetTriple());    LinkerSubsectionsViaSymbols = TargetTriple.isMacOSX();    Arch = TargetTriple.getArch(); @@ -1203,7 +1324,11 @@ bool LowerTypeTestsModule::runForTesting(Module &M) {      ExitOnErr(errorCodeToError(In.error()));    } -  bool Changed = LowerTypeTestsModule(M, ClSummaryAction, &Summary).lower(); +  bool Changed = +      LowerTypeTestsModule( +          M, ClSummaryAction == PassSummaryAction::Export ? &Summary : nullptr, +          ClSummaryAction == PassSummaryAction::Import ? 
&Summary : nullptr) +          .lower();    if (!ClWriteSummary.empty()) {      ExitOnError ExitOnErr("-lowertypetests-write-summary: " + ClWriteSummary + @@ -1222,9 +1347,18 @@ bool LowerTypeTestsModule::runForTesting(Module &M) {  bool LowerTypeTestsModule::lower() {    Function *TypeTestFunc =        M.getFunction(Intrinsic::getName(Intrinsic::type_test)); -  if (!TypeTestFunc || TypeTestFunc->use_empty()) +  if ((!TypeTestFunc || TypeTestFunc->use_empty()) && !ExportSummary)      return false; +  if (ImportSummary) { +    for (auto UI = TypeTestFunc->use_begin(), UE = TypeTestFunc->use_end(); +         UI != UE;) { +      auto *CI = cast<CallInst>((*UI++).getUser()); +      importTypeTest(CI); +    } +    return true; +  } +    // Equivalence class set containing type identifiers and the globals that    // reference them. This is used to partition the set of type identifiers in    // the module into disjoint sets. @@ -1248,6 +1382,9 @@ bool LowerTypeTestsModule::lower() {    unsigned I = 0;    SmallVector<MDNode *, 2> Types;    for (GlobalObject &GO : M.global_objects()) { +    if (isa<GlobalVariable>(GO) && GO.isDeclarationForLinker()) +      continue; +      Types.clear();      GO.getMetadata(LLVMContext::MD_type, Types);      if (Types.empty()) @@ -1262,33 +1399,57 @@ bool LowerTypeTestsModule::lower() {      }    } -  for (const Use &U : TypeTestFunc->uses()) { -    auto CI = cast<CallInst>(U.getUser()); +  auto AddTypeIdUse = [&](Metadata *TypeId) -> TypeIdUserInfo & { +    // Add the call site to the list of call sites for this type identifier. We +    // also use TypeIdUsers to keep track of whether we have seen this type +    // identifier before. If we have, we don't need to re-add the referenced +    // globals to the equivalence class. +    auto Ins = TypeIdUsers.insert({TypeId, {}}); +    if (Ins.second) { +      // Add the type identifier to the equivalence class. +      GlobalClassesTy::iterator GCI = GlobalClasses.insert(TypeId); +      GlobalClassesTy::member_iterator CurSet = GlobalClasses.findLeader(GCI); + +      // Add the referenced globals to the type identifier's equivalence class. +      for (GlobalTypeMember *GTM : TypeIdInfo[TypeId].RefGlobals) +        CurSet = GlobalClasses.unionSets( +            CurSet, GlobalClasses.findLeader(GlobalClasses.insert(GTM))); +    } + +    return Ins.first->second; +  }; -    auto BitSetMDVal = dyn_cast<MetadataAsValue>(CI->getArgOperand(1)); -    if (!BitSetMDVal) -      report_fatal_error("Second argument of llvm.type.test must be metadata"); -    auto BitSet = BitSetMDVal->getMetadata(); +  if (TypeTestFunc) { +    for (const Use &U : TypeTestFunc->uses()) { +      auto CI = cast<CallInst>(U.getUser()); -    // Add the call site to the list of call sites for this type identifier. We -    // also use TypeTestCallSites to keep track of whether we have seen this -    // type identifier before. If we have, we don't need to re-add the -    // referenced globals to the equivalence class. 
-    std::pair<DenseMap<Metadata *, std::vector<CallInst *>>::iterator, bool> -        Ins = TypeTestCallSites.insert( -            std::make_pair(BitSet, std::vector<CallInst *>())); -    Ins.first->second.push_back(CI); -    if (!Ins.second) -      continue; +      auto TypeIdMDVal = dyn_cast<MetadataAsValue>(CI->getArgOperand(1)); +      if (!TypeIdMDVal) +        report_fatal_error("Second argument of llvm.type.test must be metadata"); +      auto TypeId = TypeIdMDVal->getMetadata(); +      AddTypeIdUse(TypeId).CallSites.push_back(CI); +    } +  } -    // Add the type identifier to the equivalence class. -    GlobalClassesTy::iterator GCI = GlobalClasses.insert(BitSet); -    GlobalClassesTy::member_iterator CurSet = GlobalClasses.findLeader(GCI); +  if (ExportSummary) { +    DenseMap<GlobalValue::GUID, TinyPtrVector<Metadata *>> MetadataByGUID; +    for (auto &P : TypeIdInfo) { +      if (auto *TypeId = dyn_cast<MDString>(P.first)) +        MetadataByGUID[GlobalValue::getGUID(TypeId->getString())].push_back( +            TypeId); +    } -    // Add the referenced globals to the type identifier's equivalence class. -    for (GlobalTypeMember *GTM : TypeIdInfo[BitSet].RefGlobals) -      CurSet = GlobalClasses.unionSets( -          CurSet, GlobalClasses.findLeader(GlobalClasses.insert(GTM))); +    for (auto &P : *ExportSummary) { +      for (auto &S : P.second) { +        auto *FS = dyn_cast<FunctionSummary>(S.get()); +        if (!FS) +          continue; +        // FIXME: Only add live functions. +        for (GlobalValue::GUID G : FS->type_tests()) +          for (Metadata *MD : MetadataByGUID[G]) +            AddTypeIdUse(MD).IsExported = true; +      } +    }    }    if (GlobalClasses.empty()) @@ -1349,8 +1510,9 @@ bool LowerTypeTestsModule::lower() {  PreservedAnalyses LowerTypeTestsPass::run(Module &M,                                            ModuleAnalysisManager &AM) { -  bool Changed = -      LowerTypeTestsModule(M, SummaryAction::None, /*Summary=*/nullptr).lower(); +  bool Changed = LowerTypeTestsModule(M, /*ExportSummary=*/nullptr, +                                      /*ImportSummary=*/nullptr) +                     .lower();    if (!Changed)      return PreservedAnalyses::all();    return PreservedAnalyses::none(); diff --git a/lib/Transforms/IPO/MergeFunctions.cpp b/lib/Transforms/IPO/MergeFunctions.cpp index e0bb0eb42b59..771770ddc060 100644 --- a/lib/Transforms/IPO/MergeFunctions.cpp +++ b/lib/Transforms/IPO/MergeFunctions.cpp @@ -96,8 +96,10 @@  #include "llvm/IR/CallSite.h"  #include "llvm/IR/Constants.h"  #include "llvm/IR/DataLayout.h" +#include "llvm/IR/DebugInfo.h"  #include "llvm/IR/IRBuilder.h"  #include "llvm/IR/Instructions.h" +#include "llvm/IR/IntrinsicInst.h"  #include "llvm/IR/LLVMContext.h"  #include "llvm/IR/Module.h"  #include "llvm/IR/ValueHandle.h" @@ -127,6 +129,26 @@ static cl::opt<unsigned> NumFunctionsForSanityCheck(               "'0' disables this check. Works only with '-debug' key."),      cl::init(0), cl::Hidden); +// Under option -mergefunc-preserve-debug-info we: +// - Do not create a new function for a thunk. +// - Retain the debug info for a thunk's parameters (and associated +//   instructions for the debug info) from the entry block. +//   Note: -debug will display the algorithm at work. +// - Create debug-info for the call (to the shared implementation) made by +//   a thunk and its return value. +// - Erase the rest of the function, retaining the (minimally sized) entry +//   block to create a thunk. 
+// - Preserve a thunk's call site to point to the thunk even when both occur +//   within the same translation unit, to aid debugability. Note that this +//   behaviour differs from the underlying -mergefunc implementation which +//   modifies the thunk's call site to point to the shared implementation +//   when both occur within the same translation unit. +static cl::opt<bool> +    MergeFunctionsPDI("mergefunc-preserve-debug-info", cl::Hidden, +                      cl::init(false), +                      cl::desc("Preserve debug info in thunk when mergefunc " +                               "transformations are made.")); +  namespace {  class FunctionNode { @@ -215,8 +237,21 @@ private:    /// Replace G with a thunk or an alias to F. Deletes G.    void writeThunkOrAlias(Function *F, Function *G); -  /// Replace G with a simple tail call to bitcast(F). Also replace direct uses -  /// of G with bitcast(F). Deletes G. +  /// Fill PDIUnrelatedWL with instructions from the entry block that are +  /// unrelated to parameter related debug info. +  void filterInstsUnrelatedToPDI(BasicBlock *GEntryBlock, +                                 std::vector<Instruction *> &PDIUnrelatedWL); + +  /// Erase the rest of the CFG (i.e. barring the entry block). +  void eraseTail(Function *G); + +  /// Erase the instructions in PDIUnrelatedWL as they are unrelated to the +  /// parameter debug info, from the entry block. +  void eraseInstsUnrelatedToPDI(std::vector<Instruction *> &PDIUnrelatedWL); + +  /// Replace G with a simple tail call to bitcast(F). Also (unless +  /// MergeFunctionsPDI holds) replace direct uses of G with bitcast(F), +  /// delete G.    void writeThunk(Function *F, Function *G);    /// Replace G with an alias to F. Deletes G. @@ -269,8 +304,7 @@ bool MergeFunctions::doSanityCheck(std::vector<WeakVH> &Worklist) {          if (Res1 != -Res2) {            dbgs() << "MERGEFUNC-SANITY: Non-symmetric; triple: " << TripleNumber                   << "\n"; -          F1->dump(); -          F2->dump(); +          dbgs() << *F1 << '\n' << *F2 << '\n';            Valid = false;          } @@ -305,9 +339,7 @@ bool MergeFunctions::doSanityCheck(std::vector<WeakVH> &Worklist) {                     << TripleNumber << "\n";              dbgs() << "Res1, Res3, Res4: " << Res1 << ", " << Res3 << ", "                     << Res4 << "\n"; -            F1->dump(); -            F2->dump(); -            F3->dump(); +            dbgs() << *F1 << '\n' << *F2 << '\n' << *F3 << '\n';              Valid = false;            }          } @@ -400,19 +432,15 @@ void MergeFunctions::replaceDirectCallers(Function *Old, Function *New) {        // Transferring other attributes may help other optimizations, but that        // should be done uniformly and not in this ad-hoc way.        
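For reference, the -mergefunc-preserve-debug-info option introduced above is meant to be used together with the existing legacy pass name; a typical invocation might look like the following (the exact command line is an assumption, not spelled out in the patch):

  opt -S -mergefunc -mergefunc-preserve-debug-info -debug input.ll -o output.ll

As the option's comment notes, adding -debug prints the decisions the preserve-debug-info path makes while building the thunk.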
auto &Context = New->getContext(); -      auto NewFuncAttrs = New->getAttributes(); -      auto CallSiteAttrs = CS.getAttributes(); - -      CallSiteAttrs = CallSiteAttrs.addAttributes( -          Context, AttributeSet::ReturnIndex, NewFuncAttrs.getRetAttributes()); - -      for (unsigned argIdx = 0; argIdx < CS.arg_size(); argIdx++) { -        AttributeSet Attrs = NewFuncAttrs.getParamAttributes(argIdx); -        if (Attrs.getNumSlots()) -          CallSiteAttrs = CallSiteAttrs.addAttributes(Context, argIdx, Attrs); -      } - -      CS.setAttributes(CallSiteAttrs); +      auto NewPAL = New->getAttributes(); +      SmallVector<AttributeSet, 4> NewArgAttrs; +      for (unsigned argIdx = 0; argIdx < CS.arg_size(); argIdx++) +        NewArgAttrs.push_back(NewPAL.getParamAttributes(argIdx)); +      // Don't transfer attributes from the function to the callee. Function +      // attributes typically aren't relevant to the calling convention or ABI. +      CS.setAttributes(AttributeList::get(Context, /*FnAttrs=*/AttributeSet(), +                                          NewPAL.getRetAttributes(), +                                          NewArgAttrs));        remove(CS.getInstruction()->getParent()->getParent());        U->set(BitcastNew); @@ -461,51 +489,242 @@ static Value *createCast(IRBuilder<> &Builder, Value *V, Type *DestTy) {      return Builder.CreateBitCast(V, DestTy);  } -// Replace G with a simple tail call to bitcast(F). Also replace direct uses -// of G with bitcast(F). Deletes G. +// Erase the instructions in PDIUnrelatedWL as they are unrelated to the +// parameter debug info, from the entry block. +void MergeFunctions::eraseInstsUnrelatedToPDI( +    std::vector<Instruction *> &PDIUnrelatedWL) { + +  DEBUG(dbgs() << " Erasing instructions (in reverse order of appearance in " +                  "entry block) unrelated to parameter debug info from entry " +                  "block: {\n"); +  while (!PDIUnrelatedWL.empty()) { +    Instruction *I = PDIUnrelatedWL.back(); +    DEBUG(dbgs() << "  Deleting Instruction: "); +    DEBUG(I->print(dbgs())); +    DEBUG(dbgs() << "\n"); +    I->eraseFromParent(); +    PDIUnrelatedWL.pop_back(); +  } +  DEBUG(dbgs() << " } // Done erasing instructions unrelated to parameter " +                  "debug info from entry block. \n"); +} + +// Reduce G to its entry block. +void MergeFunctions::eraseTail(Function *G) { + +  std::vector<BasicBlock *> WorklistBB; +  for (Function::iterator BBI = std::next(G->begin()), BBE = G->end(); +       BBI != BBE; ++BBI) { +    BBI->dropAllReferences(); +    WorklistBB.push_back(&*BBI); +  } +  while (!WorklistBB.empty()) { +    BasicBlock *BB = WorklistBB.back(); +    BB->eraseFromParent(); +    WorklistBB.pop_back(); +  } +} + +// We are interested in the following instructions from the entry block as being +// related to parameter debug info: +// - @llvm.dbg.declare +// - stores from the incoming parameters to locations on the stack-frame +// - allocas that create these locations on the stack-frame +// - @llvm.dbg.value +// - the entry block's terminator +// The rest are unrelated to debug info for the parameters; fill up +// PDIUnrelatedWL with such instructions. 
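For intuition only (an illustrative source-level example, not part of this patch): compiled at -O0 with debug info, a function like the one below typically begins its entry block with exactly the parameter-related instruction kinds listed above (an alloca per parameter, a store of the incoming argument into that alloca, and an @llvm.dbg.declare describing it), followed by the lowered body, which is what filterInstsUnrelatedToPDI queues in PDIUnrelatedWL for deletion.

  // Hypothetical input, used only to show which entry-block instructions
  // count as "parameter related" under -mergefunc-preserve-debug-info.
  int add(int a, int b) {
    // The loads and the add lowered from the line below are PDI-unrelated;
    // the block terminator itself is kept.
    return a + b;
  }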
+void MergeFunctions::filterInstsUnrelatedToPDI( +    BasicBlock *GEntryBlock, std::vector<Instruction *> &PDIUnrelatedWL) { + +  std::set<Instruction *> PDIRelated; +  for (BasicBlock::iterator BI = GEntryBlock->begin(), BIE = GEntryBlock->end(); +       BI != BIE; ++BI) { +    if (auto *DVI = dyn_cast<DbgValueInst>(&*BI)) { +      DEBUG(dbgs() << " Deciding: "); +      DEBUG(BI->print(dbgs())); +      DEBUG(dbgs() << "\n"); +      DILocalVariable *DILocVar = DVI->getVariable(); +      if (DILocVar->isParameter()) { +        DEBUG(dbgs() << "  Include (parameter): "); +        DEBUG(BI->print(dbgs())); +        DEBUG(dbgs() << "\n"); +        PDIRelated.insert(&*BI); +      } else { +        DEBUG(dbgs() << "  Delete (!parameter): "); +        DEBUG(BI->print(dbgs())); +        DEBUG(dbgs() << "\n"); +      } +    } else if (auto *DDI = dyn_cast<DbgDeclareInst>(&*BI)) { +      DEBUG(dbgs() << " Deciding: "); +      DEBUG(BI->print(dbgs())); +      DEBUG(dbgs() << "\n"); +      DILocalVariable *DILocVar = DDI->getVariable(); +      if (DILocVar->isParameter()) { +        DEBUG(dbgs() << "  Parameter: "); +        DEBUG(DILocVar->print(dbgs())); +        AllocaInst *AI = dyn_cast_or_null<AllocaInst>(DDI->getAddress()); +        if (AI) { +          DEBUG(dbgs() << "  Processing alloca users: "); +          DEBUG(dbgs() << "\n"); +          for (User *U : AI->users()) { +            if (StoreInst *SI = dyn_cast<StoreInst>(U)) { +              if (Value *Arg = SI->getValueOperand()) { +                if (dyn_cast<Argument>(Arg)) { +                  DEBUG(dbgs() << "  Include: "); +                  DEBUG(AI->print(dbgs())); +                  DEBUG(dbgs() << "\n"); +                  PDIRelated.insert(AI); +                  DEBUG(dbgs() << "   Include (parameter): "); +                  DEBUG(SI->print(dbgs())); +                  DEBUG(dbgs() << "\n"); +                  PDIRelated.insert(SI); +                  DEBUG(dbgs() << "  Include: "); +                  DEBUG(BI->print(dbgs())); +                  DEBUG(dbgs() << "\n"); +                  PDIRelated.insert(&*BI); +                } else { +                  DEBUG(dbgs() << "   Delete (!parameter): "); +                  DEBUG(SI->print(dbgs())); +                  DEBUG(dbgs() << "\n"); +                } +              } +            } else { +              DEBUG(dbgs() << "   Defer: "); +              DEBUG(U->print(dbgs())); +              DEBUG(dbgs() << "\n"); +            } +          } +        } else { +          DEBUG(dbgs() << "  Delete (alloca NULL): "); +          DEBUG(BI->print(dbgs())); +          DEBUG(dbgs() << "\n"); +        } +      } else { +        DEBUG(dbgs() << "  Delete (!parameter): "); +        DEBUG(BI->print(dbgs())); +        DEBUG(dbgs() << "\n"); +      } +    } else if (dyn_cast<TerminatorInst>(BI) == GEntryBlock->getTerminator()) { +      DEBUG(dbgs() << " Will Include Terminator: "); +      DEBUG(BI->print(dbgs())); +      DEBUG(dbgs() << "\n"); +      PDIRelated.insert(&*BI); +    } else { +      DEBUG(dbgs() << " Defer: "); +      DEBUG(BI->print(dbgs())); +      DEBUG(dbgs() << "\n"); +    } +  } +  DEBUG(dbgs() +        << " Report parameter debug info related/related instructions: {\n"); +  for (BasicBlock::iterator BI = GEntryBlock->begin(), BE = GEntryBlock->end(); +       BI != BE; ++BI) { + +    Instruction *I = &*BI; +    if (PDIRelated.find(I) == PDIRelated.end()) { +      DEBUG(dbgs() << "  !PDIRelated: "); +      DEBUG(I->print(dbgs())); +      DEBUG(dbgs() << "\n"); +      
PDIUnrelatedWL.push_back(I); +    } else { +      DEBUG(dbgs() << "   PDIRelated: "); +      DEBUG(I->print(dbgs())); +      DEBUG(dbgs() << "\n"); +    } +  } +  DEBUG(dbgs() << " }\n"); +} + +// Replace G with a simple tail call to bitcast(F). Also (unless +// MergeFunctionsPDI holds) replace direct uses of G with bitcast(F), +// delete G. Under MergeFunctionsPDI, we use G itself for creating +// the thunk as we preserve the debug info (and associated instructions) +// from G's entry block pertaining to G's incoming arguments which are +// passed on as corresponding arguments in the call that G makes to F. +// For better debugability, under MergeFunctionsPDI, we do not modify G's +// call sites to point to F even when within the same translation unit.  void MergeFunctions::writeThunk(Function *F, Function *G) { -  if (!G->isInterposable()) { -    // Redirect direct callers of G to F. +  if (!G->isInterposable() && !MergeFunctionsPDI) { +    // Redirect direct callers of G to F. (See note on MergeFunctionsPDI +    // above).      replaceDirectCallers(G, F);    }    // If G was internal then we may have replaced all uses of G with F. If so, -  // stop here and delete G. There's no need for a thunk. -  if (G->hasLocalLinkage() && G->use_empty()) { +  // stop here and delete G. There's no need for a thunk. (See note on +  // MergeFunctionsPDI above). +  if (G->hasLocalLinkage() && G->use_empty() && !MergeFunctionsPDI) {      G->eraseFromParent();      return;    } -  Function *NewG = Function::Create(G->getFunctionType(), G->getLinkage(), "", -                                    G->getParent()); -  BasicBlock *BB = BasicBlock::Create(F->getContext(), "", NewG); -  IRBuilder<> Builder(BB); +  BasicBlock *GEntryBlock = nullptr; +  std::vector<Instruction *> PDIUnrelatedWL; +  BasicBlock *BB = nullptr; +  Function *NewG = nullptr; +  if (MergeFunctionsPDI) { +    DEBUG(dbgs() << "writeThunk: (MergeFunctionsPDI) Do not create a new " +                    "function as thunk; retain original: " +                 << G->getName() << "()\n"); +    GEntryBlock = &G->getEntryBlock(); +    DEBUG(dbgs() << "writeThunk: (MergeFunctionsPDI) filter parameter related " +                    "debug info for " +                 << G->getName() << "() {\n"); +    filterInstsUnrelatedToPDI(GEntryBlock, PDIUnrelatedWL); +    GEntryBlock->getTerminator()->eraseFromParent(); +    BB = GEntryBlock; +  } else { +    NewG = Function::Create(G->getFunctionType(), G->getLinkage(), "", +                            G->getParent()); +    BB = BasicBlock::Create(F->getContext(), "", NewG); +  } +  IRBuilder<> Builder(BB); +  Function *H = MergeFunctionsPDI ? 
G : NewG;    SmallVector<Value *, 16> Args;    unsigned i = 0;    FunctionType *FFTy = F->getFunctionType(); -  for (Argument & AI : NewG->args()) { +  for (Argument & AI : H->args()) {      Args.push_back(createCast(Builder, &AI, FFTy->getParamType(i)));      ++i;    }    CallInst *CI = Builder.CreateCall(F, Args); +  ReturnInst *RI = nullptr;    CI->setTailCall();    CI->setCallingConv(F->getCallingConv());    CI->setAttributes(F->getAttributes()); -  if (NewG->getReturnType()->isVoidTy()) { -    Builder.CreateRetVoid(); +  if (H->getReturnType()->isVoidTy()) { +    RI = Builder.CreateRetVoid();    } else { -    Builder.CreateRet(createCast(Builder, CI, NewG->getReturnType())); +    RI = Builder.CreateRet(createCast(Builder, CI, H->getReturnType()));    } -  NewG->copyAttributesFrom(G); -  NewG->takeName(G); -  removeUsers(G); -  G->replaceAllUsesWith(NewG); -  G->eraseFromParent(); +  if (MergeFunctionsPDI) { +    DISubprogram *DIS = G->getSubprogram(); +    if (DIS) { +      DebugLoc CIDbgLoc = DebugLoc::get(DIS->getScopeLine(), 0, DIS); +      DebugLoc RIDbgLoc = DebugLoc::get(DIS->getScopeLine(), 0, DIS); +      CI->setDebugLoc(CIDbgLoc); +      RI->setDebugLoc(RIDbgLoc); +    } else { +      DEBUG(dbgs() << "writeThunk: (MergeFunctionsPDI) No DISubprogram for " +                   << G->getName() << "()\n"); +    } +    eraseTail(G); +    eraseInstsUnrelatedToPDI(PDIUnrelatedWL); +    DEBUG(dbgs() << "} // End of parameter related debug info filtering for: " +                 << G->getName() << "()\n"); +  } else { +    NewG->copyAttributesFrom(G); +    NewG->takeName(G); +    removeUsers(G); +    G->replaceAllUsesWith(NewG); +    G->eraseFromParent(); +  } -  DEBUG(dbgs() << "writeThunk: " << NewG->getName() << '\n'); +  DEBUG(dbgs() << "writeThunk: " << H->getName() << '\n');    ++NumThunksWritten;  } diff --git a/lib/Transforms/IPO/PartialInlining.cpp b/lib/Transforms/IPO/PartialInlining.cpp index 7ef3fc1fc2a7..a2f6e5639d9d 100644 --- a/lib/Transforms/IPO/PartialInlining.cpp +++ b/lib/Transforms/IPO/PartialInlining.cpp @@ -33,7 +33,7 @@ STATISTIC(NumPartialInlined, "Number of functions partially inlined");  namespace {  struct PartialInlinerImpl { -  PartialInlinerImpl(InlineFunctionInfo IFI) : IFI(IFI) {} +  PartialInlinerImpl(InlineFunctionInfo IFI) : IFI(std::move(IFI)) {}    bool run(Module &M);    Function *unswitchFunction(Function *F); diff --git a/lib/Transforms/IPO/PassManagerBuilder.cpp b/lib/Transforms/IPO/PassManagerBuilder.cpp index 941efb210d1c..f11b58d1adc4 100644 --- a/lib/Transforms/IPO/PassManagerBuilder.cpp +++ b/lib/Transforms/IPO/PassManagerBuilder.cpp @@ -93,10 +93,6 @@ static cl::opt<CFLAAType>                          clEnumValN(CFLAAType::Both, "both",                                     "Enable both variants of CFL-AA"))); -static cl::opt<bool> -EnableMLSM("mlsm", cl::init(true), cl::Hidden, -           cl::desc("Enable motion of merged load and store")); -  static cl::opt<bool> EnableLoopInterchange(      "enable-loopinterchange", cl::init(false), cl::Hidden,      cl::desc("Enable the new, experimental LoopInterchange Pass")); @@ -141,8 +137,8 @@ static cl::opt<int> PreInlineThreshold(               "(default = 75)"));  static cl::opt<bool> EnableGVNHoist( -    "enable-gvn-hoist", cl::init(false), cl::Hidden, -    cl::desc("Enable the GVN hoisting pass")); +    "enable-gvn-hoist", cl::init(true), cl::Hidden, +    cl::desc("Enable the GVN hoisting pass (default = on)"));  static cl::opt<bool>      DisableLibCallsShrinkWrap("disable-libcalls-shrinkwrap", 
cl::init(false), @@ -172,6 +168,7 @@ PassManagerBuilder::PassManagerBuilder() {      PGOInstrUse = RunPGOInstrUse;      PrepareForThinLTO = EnablePrepareForThinLTO;      PerformThinLTO = false; +    DivergentTarget = false;  }  PassManagerBuilder::~PassManagerBuilder() { @@ -248,8 +245,6 @@ void PassManagerBuilder::populateFunctionPassManager(    FPM.add(createCFGSimplificationPass());    FPM.add(createSROAPass());    FPM.add(createEarlyCSEPass()); -  if(EnableGVNHoist) -    FPM.add(createGVNHoistPass());    FPM.add(createLowerExpectIntrinsicPass());  } @@ -294,6 +289,8 @@ void PassManagerBuilder::addFunctionSimplificationPasses(    // Break up aggregate allocas, using SSAUpdater.    MPM.add(createSROAPass());    MPM.add(createEarlyCSEPass());              // Catch trivial redundancies +  if (EnableGVNHoist) +    MPM.add(createGVNHoistPass());    // Speculative execution if the target has divergent branches; otherwise nop.    MPM.add(createSpeculativeExecutionIfHasBranchDivergencePass());    MPM.add(createJumpThreadingPass());         // Thread jumps. @@ -305,29 +302,34 @@ void PassManagerBuilder::addFunctionSimplificationPasses(      MPM.add(createLibCallsShrinkWrapPass());    addExtensionsToPM(EP_Peephole, MPM); +  // Optimize memory intrinsic calls based on the profiled size information. +  if (SizeLevel == 0) +    MPM.add(createPGOMemOPSizeOptLegacyPass()); +    MPM.add(createTailCallEliminationPass()); // Eliminate tail calls    MPM.add(createCFGSimplificationPass());     // Merge & remove BBs    MPM.add(createReassociatePass());           // Reassociate expressions    // Rotate Loop - disable header duplication at -Oz    MPM.add(createLoopRotatePass(SizeLevel == 2 ? 0 : -1));    MPM.add(createLICMPass());                  // Hoist loop invariants -  MPM.add(createLoopUnswitchPass(SizeLevel || OptLevel < 3)); +  MPM.add(createLoopUnswitchPass(SizeLevel || OptLevel < 3, DivergentTarget));    MPM.add(createCFGSimplificationPass());    addInstructionCombiningPass(MPM);    MPM.add(createIndVarSimplifyPass());        // Canonicalize indvars    MPM.add(createLoopIdiomPass());             // Recognize idioms like memset. +  addExtensionsToPM(EP_LateLoopOptimizations, MPM);    MPM.add(createLoopDeletionPass());          // Delete dead loops +    if (EnableLoopInterchange) {      MPM.add(createLoopInterchangePass()); // Interchange loops      MPM.add(createCFGSimplificationPass());    }    if (!DisableUnrollLoops) -    MPM.add(createSimpleLoopUnrollPass());    // Unroll small loops +    MPM.add(createSimpleLoopUnrollPass(OptLevel));    // Unroll small loops    addExtensionsToPM(EP_LoopOptimizerEnd, MPM);    if (OptLevel > 1) { -    if (EnableMLSM) -      MPM.add(createMergedLoadStoreMotionPass()); // Merge ld/st in diamonds +    MPM.add(createMergedLoadStoreMotionPass()); // Merge ld/st in diamonds      MPM.add(NewGVN ? createNewGVNPass()                     : createGVNPass(DisableGVNLoadPRE)); // Remove redundancies    } @@ -369,7 +371,7 @@ void PassManagerBuilder::addFunctionSimplificationPasses(        // BBVectorize may have significantly shortened a loop body; unroll again.        if (!DisableUnrollLoops) -        MPM.add(createLoopUnrollPass()); +        MPM.add(createLoopUnrollPass(OptLevel));      }    } @@ -434,7 +436,16 @@ void PassManagerBuilder::populateModulePassManager(    // earlier in the pass pipeline, here before globalopt. Otherwise imported    // available_externally functions look unreferenced and are removed.    
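A rough driver-side sketch (not taken from this patch, which only touches the builder itself): how a frontend might configure PassManagerBuilder, including the new DivergentTarget member initialized above. The member and function names are the ones visible in this file; the wrapper function and its parameters are hypothetical.

  #include "llvm/IR/LegacyPassManager.h"
  #include "llvm/Transforms/IPO/PassManagerBuilder.h"

  void buildPipelines(llvm::legacy::PassManager &MPM,
                      llvm::legacy::FunctionPassManager &FPM,
                      bool HasDivergentBranches) {
    llvm::PassManagerBuilder Builder;
    Builder.OptLevel = 2;
    Builder.SizeLevel = 0;
    // Added in this change and threaded into createLoopUnswitchPass above so
    // that unswitching can account for branch divergence (e.g. on GPUs).
    Builder.DivergentTarget = HasDivergentBranches;
    Builder.populateFunctionPassManager(FPM);
    Builder.populateModulePassManager(MPM);
  }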
if (PerformThinLTO) -    MPM.add(createPGOIndirectCallPromotionLegacyPass(/*InLTO = */ true)); +    MPM.add(createPGOIndirectCallPromotionLegacyPass(/*InLTO = */ true, +                                                     !PGOSampleUse.empty())); + +  // For SamplePGO in ThinLTO compile phase, we do not want to unroll loops +  // as it will change the CFG too much to make the 2nd profile annotation +  // in backend more difficult. +  bool PrepareForThinLTOUsingPGOSampleProfile = +      PrepareForThinLTO && !PGOSampleUse.empty(); +  if (PrepareForThinLTOUsingPGOSampleProfile) +    DisableUnrollLoops = true;    if (!DisableUnitAtATime) {      // Infer attributes about declarations if possible. @@ -454,14 +465,18 @@ void PassManagerBuilder::populateModulePassManager(      MPM.add(createCFGSimplificationPass()); // Clean up after IPCP & DAE    } -  if (!PerformThinLTO) { +  // For SamplePGO in ThinLTO compile phase, we do not want to do indirect +  // call promotion as it will change the CFG too much to make the 2nd +  // profile annotation in backend more difficult. +  if (!PerformThinLTO && !PrepareForThinLTOUsingPGOSampleProfile) {      /// PGO instrumentation is added during the compile phase for ThinLTO, do      /// not run it a second time      addPGOInstrPasses(MPM);      // Indirect call promotion that promotes intra-module targets only.      // For ThinLTO this is done earlier due to interactions with globalopt      // for imported functions. -    MPM.add(createPGOIndirectCallPromotionLegacyPass()); +    MPM.add( +        createPGOIndirectCallPromotionLegacyPass(false, !PGOSampleUse.empty()));    }    if (EnableNonLTOGlobalsModRef) @@ -589,7 +604,7 @@ void PassManagerBuilder::populateModulePassManager(      MPM.add(createCorrelatedValuePropagationPass());      addInstructionCombiningPass(MPM);      MPM.add(createLICMPass()); -    MPM.add(createLoopUnswitchPass(SizeLevel || OptLevel < 3)); +    MPM.add(createLoopUnswitchPass(SizeLevel || OptLevel < 3, DivergentTarget));      MPM.add(createCFGSimplificationPass());      addInstructionCombiningPass(MPM);    } @@ -615,16 +630,16 @@ void PassManagerBuilder::populateModulePassManager(        // BBVectorize may have significantly shortened a loop body; unroll again.        if (!DisableUnrollLoops) -        MPM.add(createLoopUnrollPass()); +        MPM.add(createLoopUnrollPass(OptLevel));      }    }    addExtensionsToPM(EP_Peephole, MPM); -  MPM.add(createCFGSimplificationPass()); +  MPM.add(createLateCFGSimplificationPass()); // Switches to lookup tables    addInstructionCombiningPass(MPM);    if (!DisableUnrollLoops) { -    MPM.add(createLoopUnrollPass());    // Unroll small loops +    MPM.add(createLoopUnrollPass(OptLevel));    // Unroll small loops      // LoopUnroll may generate some redundency to cleanup.      addInstructionCombiningPass(MPM); @@ -684,7 +699,8 @@ void PassManagerBuilder::addLTOOptimizationPasses(legacy::PassManagerBase &PM) {      // left by the earlier promotion pass that promotes intra-module targets.      // This two-step promotion is to save the compile time. For LTO, it should      // produce the same result as if we only do promotion here. -    PM.add(createPGOIndirectCallPromotionLegacyPass(true)); +    PM.add( +        createPGOIndirectCallPromotionLegacyPass(true, !PGOSampleUse.empty()));      // Propagate constants at call sites into the functions they call.  
This      // opens opportunities for globalopt (and inlining) by substituting function @@ -703,7 +719,7 @@ void PassManagerBuilder::addLTOOptimizationPasses(legacy::PassManagerBase &PM) {    PM.add(createGlobalSplitPass());    // Apply whole-program devirtualization and virtual constant propagation. -  PM.add(createWholeProgramDevirtPass()); +  PM.add(createWholeProgramDevirtPass(ExportSummary, nullptr));    // That's all we need at opt level 1.    if (OptLevel == 1) @@ -759,8 +775,7 @@ void PassManagerBuilder::addLTOOptimizationPasses(legacy::PassManagerBase &PM) {    PM.add(createGlobalsAAWrapperPass()); // IP alias analysis.    PM.add(createLICMPass());                 // Hoist loop invariants. -  if (EnableMLSM) -    PM.add(createMergedLoadStoreMotionPass()); // Merge ld/st in diamonds. +  PM.add(createMergedLoadStoreMotionPass()); // Merge ld/st in diamonds.    PM.add(NewGVN ? createNewGVNPass()                  : createGVNPass(DisableGVNLoadPRE)); // Remove redundancies.    PM.add(createMemCpyOptPass());            // Remove dead memcpys. @@ -775,11 +790,11 @@ void PassManagerBuilder::addLTOOptimizationPasses(legacy::PassManagerBase &PM) {      PM.add(createLoopInterchangePass());    if (!DisableUnrollLoops) -    PM.add(createSimpleLoopUnrollPass());   // Unroll small loops +    PM.add(createSimpleLoopUnrollPass(OptLevel));   // Unroll small loops    PM.add(createLoopVectorizePass(true, LoopVectorize));    // The vectorizer may have significantly shortened a loop body; unroll again.    if (!DisableUnrollLoops) -    PM.add(createLoopUnrollPass()); +    PM.add(createLoopUnrollPass(OptLevel));    // Now that we've optimized loops (in particular loop induction variables),    // we may have exposed more scalar opportunities. Run parts of the scalar @@ -833,6 +848,23 @@ void PassManagerBuilder::populateThinLTOPassManager(    if (VerifyInput)      PM.add(createVerifierPass()); +  if (ImportSummary) { +    // These passes import type identifier resolutions for whole-program +    // devirtualization and CFI. They must run early because other passes may +    // disturb the specific instruction patterns that these passes look for, +    // creating dependencies on resolutions that may not appear in the summary. +    // +    // For example, GVN may transform the pattern assume(type.test) appearing in +    // two basic blocks into assume(phi(type.test, type.test)), which would +    // transform a dependency on a WPD resolution into a dependency on a type +    // identifier resolution for CFI. +    // +    // Also, WPD has access to more precise information than ICP and can +    // devirtualize more effectively, so it should operate on the IR first. +    PM.add(createWholeProgramDevirtPass(nullptr, ImportSummary)); +    PM.add(createLowerTypeTestsPass(nullptr, ImportSummary)); +  } +    populateModulePassManager(PM);    if (VerifyOutput) @@ -857,8 +889,7 @@ void PassManagerBuilder::populateLTOPassManager(legacy::PassManagerBase &PM) {    // Lower type metadata and the type.test intrinsic. This pass supports Clang's    // control flow integrity mechanisms (-fsanitize=cfi*) and needs to run at    // link time if CFI is enabled. The pass does nothing if CFI is disabled. 
-  PM.add(createLowerTypeTestsPass(LowerTypeTestsSummaryAction::None, -                                  /*Summary=*/nullptr)); +  PM.add(createLowerTypeTestsPass(ExportSummary, nullptr));    if (OptLevel != 0)      addLateLTOOptimizationPasses(PM); diff --git a/lib/Transforms/IPO/SampleProfile.cpp b/lib/Transforms/IPO/SampleProfile.cpp index 6a43f8dbac48..3371de6e3d14 100644 --- a/lib/Transforms/IPO/SampleProfile.cpp +++ b/lib/Transforms/IPO/SampleProfile.cpp @@ -35,6 +35,7 @@  #include "llvm/IR/DiagnosticInfo.h"  #include "llvm/IR/Dominators.h"  #include "llvm/IR/Function.h" +#include "llvm/IR/GlobalValue.h"  #include "llvm/IR/InstIterator.h"  #include "llvm/IR/Instructions.h"  #include "llvm/IR/IntrinsicInst.h" @@ -43,6 +44,7 @@  #include "llvm/IR/Metadata.h"  #include "llvm/IR/Module.h"  #include "llvm/Pass.h" +#include "llvm/ProfileData/InstrProf.h"  #include "llvm/ProfileData/SampleProfReader.h"  #include "llvm/Support/CommandLine.h"  #include "llvm/Support/Debug.h" @@ -50,6 +52,7 @@  #include "llvm/Support/Format.h"  #include "llvm/Support/raw_ostream.h"  #include "llvm/Transforms/IPO.h" +#include "llvm/Transforms/Instrumentation.h"  #include "llvm/Transforms/Utils/Cloning.h"  #include <cctype> @@ -159,8 +162,11 @@ protected:    ErrorOr<uint64_t> getInstWeight(const Instruction &I);    ErrorOr<uint64_t> getBlockWeight(const BasicBlock *BB);    const FunctionSamples *findCalleeFunctionSamples(const Instruction &I) const; +  std::vector<const FunctionSamples *> +  findIndirectCallFunctionSamples(const Instruction &I) const;    const FunctionSamples *findFunctionSamples(const Instruction &I) const; -  bool inlineHotFunctions(Function &F); +  bool inlineHotFunctions(Function &F, +                          DenseSet<GlobalValue::GUID> &ImportGUIDs);    void printEdgeWeight(raw_ostream &OS, Edge E);    void printBlockWeight(raw_ostream &OS, const BasicBlock *BB) const;    void printBlockEquivalence(raw_ostream &OS, const BasicBlock *BB); @@ -173,7 +179,7 @@ protected:    void buildEdges(Function &F);    bool propagateThroughEdges(Function &F, bool UpdateBlockCount);    void computeDominanceAndLoopInfo(Function &F); -  unsigned getOffset(unsigned L, unsigned H) const; +  unsigned getOffset(const DILocation *DIL) const;    void clearFunctionData();    /// \brief Map basic blocks to their computed weights. @@ -326,11 +332,12 @@ SampleCoverageTracker::countUsedRecords(const FunctionSamples *FS) const {    // If there are inlined callsites in this function, count the samples found    // in the respective bodies. However, do not bother counting callees with 0    // total samples, these are callees that were never invoked at runtime. -  for (const auto &I : FS->getCallsiteSamples()) { -    const FunctionSamples *CalleeSamples = &I.second; -    if (callsiteIsHot(FS, CalleeSamples)) -      Count += countUsedRecords(CalleeSamples); -  } +  for (const auto &I : FS->getCallsiteSamples()) +    for (const auto &J : I.second) { +      const FunctionSamples *CalleeSamples = &J.second; +      if (callsiteIsHot(FS, CalleeSamples)) +        Count += countUsedRecords(CalleeSamples); +    }    return Count;  } @@ -343,11 +350,12 @@ SampleCoverageTracker::countBodyRecords(const FunctionSamples *FS) const {    unsigned Count = FS->getBodySamples().size();    // Only count records in hot callsites. 
-  for (const auto &I : FS->getCallsiteSamples()) { -    const FunctionSamples *CalleeSamples = &I.second; -    if (callsiteIsHot(FS, CalleeSamples)) -      Count += countBodyRecords(CalleeSamples); -  } +  for (const auto &I : FS->getCallsiteSamples()) +    for (const auto &J : I.second) { +      const FunctionSamples *CalleeSamples = &J.second; +      if (callsiteIsHot(FS, CalleeSamples)) +        Count += countBodyRecords(CalleeSamples); +    }    return Count;  } @@ -362,11 +370,12 @@ SampleCoverageTracker::countBodySamples(const FunctionSamples *FS) const {      Total += I.second.getSamples();    // Only count samples in hot callsites. -  for (const auto &I : FS->getCallsiteSamples()) { -    const FunctionSamples *CalleeSamples = &I.second; -    if (callsiteIsHot(FS, CalleeSamples)) -      Total += countBodySamples(CalleeSamples); -  } +  for (const auto &I : FS->getCallsiteSamples()) +    for (const auto &J : I.second) { +      const FunctionSamples *CalleeSamples = &J.second; +      if (callsiteIsHot(FS, CalleeSamples)) +        Total += countBodySamples(CalleeSamples); +    }    return Total;  } @@ -398,15 +407,11 @@ void SampleProfileLoader::clearFunctionData() {    CoverageTracker.clear();  } -/// \brief Returns the offset of lineno \p L to head_lineno \p H -/// -/// \param L  Lineno -/// \param H  Header lineno of the function -/// -/// \returns offset to the header lineno. 16 bits are used to represent offset. +/// Returns the line offset to the start line of the subprogram.  /// We assume that a single function will not exceed 65535 LOC. -unsigned SampleProfileLoader::getOffset(unsigned L, unsigned H) const { -  return (L - H) & 0xffff; +unsigned SampleProfileLoader::getOffset(const DILocation *DIL) const { +  return (DIL->getLine() - DIL->getScope()->getSubprogram()->getLine()) & +         0xffff;  }  /// \brief Print the weight of edge \p E on stream \p OS. @@ -451,8 +456,7 @@ void SampleProfileLoader::printBlockWeight(raw_ostream &OS,  /// \param Inst Instruction to query.  ///  /// \returns the weight of \p Inst. -ErrorOr<uint64_t> -SampleProfileLoader::getInstWeight(const Instruction &Inst) { +ErrorOr<uint64_t> SampleProfileLoader::getInstWeight(const Instruction &Inst) {    const DebugLoc &DLoc = Inst.getDebugLoc();    if (!DLoc)      return std::error_code(); @@ -470,19 +474,14 @@ SampleProfileLoader::getInstWeight(const Instruction &Inst) {    // If a call/invoke instruction is inlined in profile, but not inlined here,    // it means that the inlined callsite has no sample, thus the call    // instruction should have 0 count. -  bool IsCall = isa<CallInst>(Inst) || isa<InvokeInst>(Inst); -  if (IsCall && findCalleeFunctionSamples(Inst)) +  if ((isa<CallInst>(Inst) || isa<InvokeInst>(Inst)) && +      findCalleeFunctionSamples(Inst))      return 0;    const DILocation *DIL = DLoc; -  unsigned Lineno = DLoc.getLine(); -  unsigned HeaderLineno = DIL->getScope()->getSubprogram()->getLine(); - -  uint32_t LineOffset = getOffset(Lineno, HeaderLineno); -  uint32_t Discriminator = DIL->getDiscriminator(); -  ErrorOr<uint64_t> R = IsCall -                            ? 
FS->findCallSamplesAt(LineOffset, Discriminator) -                            : FS->findSamplesAt(LineOffset, Discriminator); +  uint32_t LineOffset = getOffset(DIL); +  uint32_t Discriminator = DIL->getBaseDiscriminator(); +  ErrorOr<uint64_t> R = FS->findSamplesAt(LineOffset, Discriminator);    if (R) {      bool FirstMark =          CoverageTracker.markSamplesUsed(FS, LineOffset, Discriminator, R.get()); @@ -491,13 +490,14 @@ SampleProfileLoader::getInstWeight(const Instruction &Inst) {        LLVMContext &Ctx = F->getContext();        emitOptimizationRemark(            Ctx, DEBUG_TYPE, *F, DLoc, -          Twine("Applied ") + Twine(*R) + " samples from profile (offset: " + -              Twine(LineOffset) + +          Twine("Applied ") + Twine(*R) + +              " samples from profile (offset: " + Twine(LineOffset) +                ((Discriminator) ? Twine(".") + Twine(Discriminator) : "") + ")");      } -    DEBUG(dbgs() << "    " << Lineno << "." << DIL->getDiscriminator() << ":" -                 << Inst << " (line offset: " << Lineno - HeaderLineno << "." -                 << DIL->getDiscriminator() << " - weight: " << R.get() +    DEBUG(dbgs() << "    " << DLoc.getLine() << "." +                 << DIL->getBaseDiscriminator() << ":" << Inst +                 << " (line offset: " << LineOffset << "." +                 << DIL->getBaseDiscriminator() << " - weight: " << R.get()                   << ")\n");    }    return R; @@ -511,8 +511,7 @@ SampleProfileLoader::getInstWeight(const Instruction &Inst) {  /// \param BB The basic block to query.  ///  /// \returns the weight for \p BB. -ErrorOr<uint64_t> -SampleProfileLoader::getBlockWeight(const BasicBlock *BB) { +ErrorOr<uint64_t> SampleProfileLoader::getBlockWeight(const BasicBlock *BB) {    uint64_t Max = 0;    bool HasWeight = false;    for (auto &I : BB->getInstList()) { @@ -565,16 +564,49 @@ SampleProfileLoader::findCalleeFunctionSamples(const Instruction &Inst) const {    if (!DIL) {      return nullptr;    } -  DISubprogram *SP = DIL->getScope()->getSubprogram(); -  if (!SP) -    return nullptr; + +  StringRef CalleeName; +  if (const CallInst *CI = dyn_cast<CallInst>(&Inst)) +    if (Function *Callee = CI->getCalledFunction()) +      CalleeName = Callee->getName();    const FunctionSamples *FS = findFunctionSamples(Inst);    if (FS == nullptr)      return nullptr; -  return FS->findFunctionSamplesAt(LineLocation( -      getOffset(DIL->getLine(), SP->getLine()), DIL->getDiscriminator())); +  return FS->findFunctionSamplesAt( +      LineLocation(getOffset(DIL), DIL->getBaseDiscriminator()), CalleeName); +} + +/// Returns a vector of FunctionSamples that are the indirect call targets +/// of \p Inst. The vector is sorted by the total number of samples. 
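For intuition (illustrative source, not part of the patch): the helper declared below is concerned with call sites like the following, where the callee is known only at run time and the value profile may record several targets; returning their FunctionSamples sorted by total samples lets the hottest candidate be tried first for promotion and inlining.

  // An indirect call through a function pointer; its profiled targets are
  // what findIndirectCallFunctionSamples returns, hottest first.
  int dispatch(int (*Handler)(int), int X) {
    return Handler(X);
  }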
+std::vector<const FunctionSamples *> +SampleProfileLoader::findIndirectCallFunctionSamples( +    const Instruction &Inst) const { +  const DILocation *DIL = Inst.getDebugLoc(); +  std::vector<const FunctionSamples *> R; + +  if (!DIL) { +    return R; +  } + +  const FunctionSamples *FS = findFunctionSamples(Inst); +  if (FS == nullptr) +    return R; + +  if (const FunctionSamplesMap *M = FS->findFunctionSamplesMapAt( +          LineLocation(getOffset(DIL), DIL->getBaseDiscriminator()))) { +    if (M->size() == 0) +      return R; +    for (const auto &NameFS : *M) { +      R.push_back(&NameFS.second); +    } +    std::sort(R.begin(), R.end(), +              [](const FunctionSamples *L, const FunctionSamples *R) { +                return L->getTotalSamples() > R->getTotalSamples(); +              }); +  } +  return R;  }  /// \brief Get the FunctionSamples for an instruction. @@ -588,23 +620,23 @@ SampleProfileLoader::findCalleeFunctionSamples(const Instruction &Inst) const {  /// \returns the FunctionSamples pointer to the inlined instance.  const FunctionSamples *  SampleProfileLoader::findFunctionSamples(const Instruction &Inst) const { -  SmallVector<LineLocation, 10> S; +  SmallVector<std::pair<LineLocation, StringRef>, 10> S;    const DILocation *DIL = Inst.getDebugLoc(); -  if (!DIL) { +  if (!DIL)      return Samples; -  } + +  const DILocation *PrevDIL = DIL;    for (DIL = DIL->getInlinedAt(); DIL; DIL = DIL->getInlinedAt()) { -    DISubprogram *SP = DIL->getScope()->getSubprogram(); -    if (!SP) -      return nullptr; -    S.push_back(LineLocation(getOffset(DIL->getLine(), SP->getLine()), -                             DIL->getDiscriminator())); +    S.push_back(std::make_pair( +        LineLocation(getOffset(DIL), DIL->getBaseDiscriminator()), +        PrevDIL->getScope()->getSubprogram()->getLinkageName())); +    PrevDIL = DIL;    }    if (S.size() == 0)      return Samples;    const FunctionSamples *FS = Samples;    for (int i = S.size() - 1; i >= 0 && FS != nullptr; i--) { -    FS = FS->findFunctionSamplesAt(S[i]); +    FS = FS->findFunctionSamplesAt(S[i].first, S[i].second);    }    return FS;  } @@ -614,14 +646,17 @@ SampleProfileLoader::findFunctionSamples(const Instruction &Inst) const {  /// Iteratively traverse all callsites of the function \p F, and find if  /// the corresponding inlined instance exists and is hot in profile. If  /// it is hot enough, inline the callsites and adds new callsites of the -/// callee into the caller. -/// -/// TODO: investigate the possibility of not invoking InlineFunction directly. +/// callee into the caller. If the call is an indirect call, first promote +/// it to direct call. Each indirect call is limited with a single target.  ///  /// \param F function to perform iterative inlining. +/// \param ImportGUIDs a set to be updated to include all GUIDs that come +///     from a different module but inlined in the profiled binary.  ///  /// \returns True if there is any inline happened. -bool SampleProfileLoader::inlineHotFunctions(Function &F) { +bool SampleProfileLoader::inlineHotFunctions( +    Function &F, DenseSet<GlobalValue::GUID> &ImportGUIDs) { +  DenseSet<Instruction *> PromotedInsns;    bool Changed = false;    LLVMContext &Ctx = F.getContext();    std::function<AssumptionCache &(Function &)> GetAssumptionCache = [&]( @@ -647,18 +682,42 @@ bool SampleProfileLoader::inlineHotFunctions(Function &F) {      }      for (auto I : CIS) {        InlineFunctionInfo IFI(nullptr, ACT ? 
&GetAssumptionCache : nullptr); -      CallSite CS(I); -      Function *CalledFunction = CS.getCalledFunction(); -      if (!CalledFunction || !CalledFunction->getSubprogram()) +      Function *CalledFunction = CallSite(I).getCalledFunction(); +      Instruction *DI = I; +      if (!CalledFunction && !PromotedInsns.count(I) && +          CallSite(I).isIndirectCall()) +        for (const auto *FS : findIndirectCallFunctionSamples(*I)) { +          auto CalleeFunctionName = FS->getName(); +          const char *Reason = "Callee function not available"; +          CalledFunction = F.getParent()->getFunction(CalleeFunctionName); +          if (CalledFunction && isLegalToPromote(I, CalledFunction, &Reason)) { +            // The indirect target was promoted and inlined in the profile, as a +            // result, we do not have profile info for the branch probability. +            // We set the probability to 80% taken to indicate that the static +            // call is likely taken. +            DI = dyn_cast<Instruction>( +                promoteIndirectCall(I, CalledFunction, 80, 100, false) +                    ->stripPointerCasts()); +            PromotedInsns.insert(I); +          } else { +            DEBUG(dbgs() << "\nFailed to promote indirect call to " +                         << CalleeFunctionName << " because " << Reason +                         << "\n"); +            continue; +          } +        } +      if (!CalledFunction || !CalledFunction->getSubprogram()) { +        findCalleeFunctionSamples(*I)->findImportedFunctions( +            ImportGUIDs, F.getParent(), +            Samples->getTotalSamples() * SampleProfileHotThreshold / 100);          continue; +      }        DebugLoc DLoc = I->getDebugLoc(); -      uint64_t NumSamples = findCalleeFunctionSamples(*I)->getTotalSamples(); -      if (InlineFunction(CS, IFI)) { +      if (InlineFunction(CallSite(DI), IFI)) {          LocalChanged = true;          emitOptimizationRemark(Ctx, DEBUG_TYPE, F, DLoc,                                 Twine("inlined hot callee '") + -                                   CalledFunction->getName() + "' with " + -                                   Twine(NumSamples) + " samples into '" + +                                   CalledFunction->getName() + "' into '" +                                     F.getName() + "'");        }      } @@ -994,6 +1053,26 @@ void SampleProfileLoader::buildEdges(Function &F) {    }  } +/// Sorts the CallTargetMap \p M by count in descending order and stores the +/// sorted result in \p Sorted. Returns the total counts. 
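Restated on plain standard-library types so it can be compiled and tested in isolation (the real helper below operates on SampleRecord::CallTargetMap and fills InstrProfValueData entries), the ordering is by count descending with ties broken by GUID descending, and the sum of all counts is returned:

  #include <algorithm>
  #include <cstdint>
  #include <utility>
  #include <vector>

  // Each entry is (target GUID, sample count).
  uint64_t sortTargets(std::vector<std::pair<uint64_t, uint64_t>> &Targets) {
    uint64_t Sum = 0;
    for (const auto &T : Targets)
      Sum += T.second;
    std::sort(Targets.begin(), Targets.end(),
              [](const std::pair<uint64_t, uint64_t> &L,
                 const std::pair<uint64_t, uint64_t> &R) {
                if (L.second == R.second)
                  return L.first > R.first;   // tie-break: larger GUID first
                return L.second > R.second;   // primary: larger count first
              });
    return Sum;
  }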
+static uint64_t SortCallTargets(SmallVector<InstrProfValueData, 2> &Sorted, +                                const SampleRecord::CallTargetMap &M) { +  Sorted.clear(); +  uint64_t Sum = 0; +  for (auto I = M.begin(); I != M.end(); ++I) { +    Sum += I->getValue(); +    Sorted.push_back({Function::getGUID(I->getKey()), I->getValue()}); +  } +  std::sort(Sorted.begin(), Sorted.end(), +            [](const InstrProfValueData &L, const InstrProfValueData &R) { +              if (L.Count == R.Count) +                return L.Value > R.Value; +              else +                return L.Count > R.Count; +            }); +  return Sum; +} +  /// \brief Propagate weights into edges  ///  /// The following rules are applied to every block BB in the CFG: @@ -1015,10 +1094,6 @@ void SampleProfileLoader::propagateWeights(Function &F) {    bool Changed = true;    unsigned I = 0; -  // Add an entry count to the function using the samples gathered -  // at the function entry. -  F.setEntryCount(Samples->getHeadSamples() + 1); -    // If BB weight is larger than its corresponding loop's header BB weight,    // use the BB weight to replace the loop header BB weight.    for (auto &BI : F) { @@ -1071,13 +1146,32 @@ void SampleProfileLoader::propagateWeights(Function &F) {      if (BlockWeights[BB]) {        for (auto &I : BB->getInstList()) { -        if (CallInst *CI = dyn_cast<CallInst>(&I)) { -          if (!dyn_cast<IntrinsicInst>(&I)) { -            SmallVector<uint32_t, 1> Weights; -            Weights.push_back(BlockWeights[BB]); -            CI->setMetadata(LLVMContext::MD_prof, -                            MDB.createBranchWeights(Weights)); -          } +        if (!isa<CallInst>(I) && !isa<InvokeInst>(I)) +          continue; +        CallSite CS(&I); +        if (!CS.getCalledFunction()) { +          const DebugLoc &DLoc = I.getDebugLoc(); +          if (!DLoc) +            continue; +          const DILocation *DIL = DLoc; +          uint32_t LineOffset = getOffset(DIL); +          uint32_t Discriminator = DIL->getBaseDiscriminator(); + +          const FunctionSamples *FS = findFunctionSamples(I); +          if (!FS) +            continue; +          auto T = FS->findCallTargetMapAt(LineOffset, Discriminator); +          if (!T || T.get().size() == 0) +            continue; +          SmallVector<InstrProfValueData, 2> SortedCallTargets; +          uint64_t Sum = SortCallTargets(SortedCallTargets, T.get()); +          annotateValueSite(*I.getParent()->getParent()->getParent(), I, +                            SortedCallTargets, Sum, IPVK_IndirectCallTarget, +                            SortedCallTargets.size()); +        } else if (!dyn_cast<IntrinsicInst>(&I)) { +          SmallVector<uint32_t, 1> Weights; +          Weights.push_back(BlockWeights[BB]); +          I.setMetadata(LLVMContext::MD_prof, MDB.createBranchWeights(Weights));          }        }      } @@ -1115,9 +1209,13 @@ void SampleProfileLoader::propagateWeights(Function &F) {        }      } +    uint64_t TempWeight;      // Only set weights if there is at least one non-zero weight.      // In any other case, let the analyzer set weights. -    if (MaxWeight > 0) { +    // Do not set weights if the weights are present. In ThinLTO, the profile +    // annotation is done twice. If the first annotation already set the +    // weights, the second pass does not need to set it. +    if (MaxWeight > 0 && !TI->extractProfTotalWeight(TempWeight)) {        DEBUG(dbgs() << "SUCCESS. 
Found non-zero weights.\n");        TI->setMetadata(llvm::LLVMContext::MD_prof,                        MDB.createBranchWeights(Weights)); @@ -1228,12 +1326,19 @@ bool SampleProfileLoader::emitAnnotations(Function &F) {    DEBUG(dbgs() << "Line number for the first instruction in " << F.getName()                 << ": " << getFunctionLoc(F) << "\n"); -  Changed |= inlineHotFunctions(F); +  DenseSet<GlobalValue::GUID> ImportGUIDs; +  Changed |= inlineHotFunctions(F, ImportGUIDs);    // Compute basic block weights.    Changed |= computeBlockWeights(F);    if (Changed) { +    // Add an entry count to the function using the samples gathered at the +    // function entry. Also sets the GUIDs that comes from a different +    // module but inlined in the profiled binary. This is aiming at making +    // the IR match the profiled binary before annotation. +    F.setEntryCount(Samples->getHeadSamples() + 1, &ImportGUIDs); +      // Compute dominance and loop info needed for propagation.      computeDominanceAndLoopInfo(F); @@ -1329,7 +1434,7 @@ bool SampleProfileLoaderLegacyPass::runOnModule(Module &M) {  bool SampleProfileLoader::runOnFunction(Function &F) {    F.setEntryCount(0);    Samples = Reader->getSamplesFor(F); -  if (!Samples->empty()) +  if (Samples && !Samples->empty())      return emitAnnotations(F);    return false;  } diff --git a/lib/Transforms/IPO/StripSymbols.cpp b/lib/Transforms/IPO/StripSymbols.cpp index 8f6f161428e8..fb64367eef91 100644 --- a/lib/Transforms/IPO/StripSymbols.cpp +++ b/lib/Transforms/IPO/StripSymbols.cpp @@ -323,6 +323,14 @@ bool StripDeadDebugInfo::runOnModule(Module &M) {        LiveGVs.insert(GVE);    } +  std::set<DICompileUnit *> LiveCUs; +  // Any CU referenced from a subprogram is live. +  for (DISubprogram *SP : F.subprograms()) { +    if (SP->getUnit()) +      LiveCUs.insert(SP->getUnit()); +  } + +  bool HasDeadCUs = false;    for (DICompileUnit *DIC : F.compile_units()) {      // Create our live global variable list.      bool GlobalVariableChange = false; @@ -341,6 +349,11 @@ bool StripDeadDebugInfo::runOnModule(Module &M) {          GlobalVariableChange = true;      } +    if (!LiveGlobalVariables.empty()) +      LiveCUs.insert(DIC); +    else if (!LiveCUs.count(DIC)) +      HasDeadCUs = true; +      // If we found dead global variables, replace the current global      // variable list with our new live global variable list.      
if (GlobalVariableChange) { @@ -352,5 +365,16 @@ bool StripDeadDebugInfo::runOnModule(Module &M) {      LiveGlobalVariables.clear();    } +  if (HasDeadCUs) { +    // Delete the old node and replace it with a new one +    NamedMDNode *NMD = M.getOrInsertNamedMetadata("llvm.dbg.cu"); +    NMD->clearOperands(); +    if (!LiveCUs.empty()) { +      for (DICompileUnit *CU : LiveCUs) +        NMD->addOperand(CU); +    } +    Changed = true; +  } +    return Changed;  } diff --git a/lib/Transforms/IPO/ThinLTOBitcodeWriter.cpp b/lib/Transforms/IPO/ThinLTOBitcodeWriter.cpp index 3680cfc813a1..65deb82cd2a5 100644 --- a/lib/Transforms/IPO/ThinLTOBitcodeWriter.cpp +++ b/lib/Transforms/IPO/ThinLTOBitcodeWriter.cpp @@ -14,16 +14,21 @@  //  //===----------------------------------------------------------------------===// -#include "llvm/Transforms/IPO.h" +#include "llvm/Analysis/BasicAliasAnalysis.h"  #include "llvm/Analysis/ModuleSummaryAnalysis.h"  #include "llvm/Analysis/TypeMetadataUtils.h"  #include "llvm/Bitcode/BitcodeWriter.h"  #include "llvm/IR/Constants.h" +#include "llvm/IR/DebugInfo.h"  #include "llvm/IR/Intrinsics.h"  #include "llvm/IR/Module.h"  #include "llvm/IR/PassManager.h"  #include "llvm/Pass.h" +#include "llvm/Support/FileSystem.h"  #include "llvm/Support/ScopedPrinter.h" +#include "llvm/Support/raw_ostream.h" +#include "llvm/Transforms/IPO.h" +#include "llvm/Transforms/IPO/FunctionAttrs.h"  #include "llvm/Transforms/Utils/Cloning.h"  using namespace llvm; @@ -41,23 +46,14 @@ namespace {  std::string getModuleId(Module *M) {    MD5 Md5;    bool ExportsSymbols = false; -  auto AddGlobal = [&](GlobalValue &GV) { +  for (auto &GV : M->global_values()) {      if (GV.isDeclaration() || GV.getName().startswith("llvm.") ||          !GV.hasExternalLinkage()) -      return; +      continue;      ExportsSymbols = true;      Md5.update(GV.getName());      Md5.update(ArrayRef<uint8_t>{0}); -  }; - -  for (auto &F : *M) -    AddGlobal(F); -  for (auto &GV : M->globals()) -    AddGlobal(GV); -  for (auto &GA : M->aliases()) -    AddGlobal(GA); -  for (auto &IF : M->ifuncs()) -    AddGlobal(IF); +  }    if (!ExportsSymbols)      return ""; @@ -73,15 +69,21 @@ std::string getModuleId(Module *M) {  // Promote each local-linkage entity defined by ExportM and used by ImportM by  // changing visibility and appending the given ModuleId.  
void promoteInternals(Module &ExportM, Module &ImportM, StringRef ModuleId) { -  auto PromoteInternal = [&](GlobalValue &ExportGV) { +  DenseMap<const Comdat *, Comdat *> RenamedComdats; +  for (auto &ExportGV : ExportM.global_values()) {      if (!ExportGV.hasLocalLinkage()) -      return; +      continue; -    GlobalValue *ImportGV = ImportM.getNamedValue(ExportGV.getName()); +    auto Name = ExportGV.getName(); +    GlobalValue *ImportGV = ImportM.getNamedValue(Name);      if (!ImportGV || ImportGV->use_empty()) -      return; +      continue; + +    std::string NewName = (Name + ModuleId).str(); -    std::string NewName = (ExportGV.getName() + ModuleId).str(); +    if (const auto *C = ExportGV.getComdat()) +      if (C->getName() == Name) +        RenamedComdats.try_emplace(C, ExportM.getOrInsertComdat(NewName));      ExportGV.setName(NewName);      ExportGV.setLinkage(GlobalValue::ExternalLinkage); @@ -89,16 +91,15 @@ void promoteInternals(Module &ExportM, Module &ImportM, StringRef ModuleId) {      ImportGV->setName(NewName);      ImportGV->setVisibility(GlobalValue::HiddenVisibility); -  }; +  } -  for (auto &F : ExportM) -    PromoteInternal(F); -  for (auto &GV : ExportM.globals()) -    PromoteInternal(GV); -  for (auto &GA : ExportM.aliases()) -    PromoteInternal(GA); -  for (auto &IF : ExportM.ifuncs()) -    PromoteInternal(IF); +  if (!RenamedComdats.empty()) +    for (auto &GO : ExportM.global_objects()) +      if (auto *C = GO.getComdat()) { +        auto Replacement = RenamedComdats.find(C); +        if (Replacement != RenamedComdats.end()) +          GO.setComdat(Replacement->second); +      }  }  // Promote all internal (i.e. distinct) type ids used by the module by replacing @@ -194,24 +195,7 @@ void simplifyExternals(Module &M) {  }  void filterModule( -    Module *M, std::function<bool(const GlobalValue *)> ShouldKeepDefinition) { -  for (Function &F : *M) { -    if (ShouldKeepDefinition(&F)) -      continue; - -    F.deleteBody(); -    F.clearMetadata(); -  } - -  for (GlobalVariable &GV : M->globals()) { -    if (ShouldKeepDefinition(&GV)) -      continue; - -    GV.setInitializer(nullptr); -    GV.setLinkage(GlobalValue::ExternalLinkage); -    GV.clearMetadata(); -  } - +    Module *M, function_ref<bool(const GlobalValue *)> ShouldKeepDefinition) {    for (Module::alias_iterator I = M->alias_begin(), E = M->alias_end();         I != E;) {      GlobalAlias *GA = &*I++; @@ -219,7 +203,7 @@ void filterModule(        continue;      GlobalObject *GO; -    if (I->getValueType()->isFunctionTy()) +    if (GA->getValueType()->isFunctionTy())        GO = Function::Create(cast<FunctionType>(GA->getValueType()),                              GlobalValue::ExternalLinkage, "", M);      else @@ -231,53 +215,168 @@ void filterModule(      GA->replaceAllUsesWith(GO);      GA->eraseFromParent();    } + +  for (Function &F : *M) { +    if (ShouldKeepDefinition(&F)) +      continue; + +    F.deleteBody(); +    F.setComdat(nullptr); +    F.clearMetadata(); +  } + +  for (GlobalVariable &GV : M->globals()) { +    if (ShouldKeepDefinition(&GV)) +      continue; + +    GV.setInitializer(nullptr); +    GV.setLinkage(GlobalValue::ExternalLinkage); +    GV.setComdat(nullptr); +    GV.clearMetadata(); +  } +} + +void forEachVirtualFunction(Constant *C, function_ref<void(Function *)> Fn) { +  if (auto *F = dyn_cast<Function>(C)) +    return Fn(F); +  if (isa<GlobalValue>(C)) +    return; +  for (Value *Op : C->operands()) +    forEachVirtualFunction(cast<Constant>(Op), Fn);  }  // If it's 
possible to split M into regular and thin LTO parts, do so and write  // a multi-module bitcode file with the two parts to OS. Otherwise, write only a  // regular LTO bitcode file to OS. -void splitAndWriteThinLTOBitcode(raw_ostream &OS, Module &M) { +void splitAndWriteThinLTOBitcode( +    raw_ostream &OS, raw_ostream *ThinLinkOS, +    function_ref<AAResults &(Function &)> AARGetter, Module &M) {    std::string ModuleId = getModuleId(&M);    if (ModuleId.empty()) {      // We couldn't generate a module ID for this module, just write it out as a      // regular LTO module.      WriteBitcodeToFile(&M, OS); +    if (ThinLinkOS) +      // We don't have a ThinLTO part, but still write the module to the +      // ThinLinkOS if requested so that the expected output file is produced. +      WriteBitcodeToFile(&M, *ThinLinkOS);      return;    }    promoteTypeIds(M, ModuleId); -  auto IsInMergedM = [&](const GlobalValue *GV) { -    auto *GVar = dyn_cast<GlobalVariable>(GV->getBaseObject()); -    if (!GVar) -      return false; - +  // Returns whether a global has attached type metadata. Such globals may +  // participate in CFI or whole-program devirtualization, so they need to +  // appear in the merged module instead of the thin LTO module. +  auto HasTypeMetadata = [&](const GlobalObject *GO) {      SmallVector<MDNode *, 1> MDs; -    GVar->getMetadata(LLVMContext::MD_type, MDs); +    GO->getMetadata(LLVMContext::MD_type, MDs);      return !MDs.empty();    }; +  // Collect the set of virtual functions that are eligible for virtual constant +  // propagation. Each eligible function must not access memory, must return +  // an integer of width <=64 bits, must take at least one argument, must not +  // use its first argument (assumed to be "this") and all arguments other than +  // the first one must be of <=64 bit integer type. +  // +  // Note that we test whether this copy of the function is readnone, rather +  // than testing function attributes, which must hold for any copy of the +  // function, even a less optimized version substituted at link time. This is +  // sound because the virtual constant propagation optimizations effectively +  // inline all implementations of the virtual function into each call site, +  // rather than using function attributes to perform local optimization. +  std::set<const Function *> EligibleVirtualFns; +  // If any member of a comdat lives in MergedM, put all members of that +  // comdat in MergedM to keep the comdat together. 
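To make the eligibility rules spelled out above concrete (an illustrative C++ example, not part of the patch; it assumes the program is built with CFI or whole-program vtable optimization so the vtables carry !type metadata): getKind below qualifies for virtual constant propagation, while getName does not.

  // getKind: accesses no memory, returns a <=64-bit integer, and never uses
  // its only argument (the implicit 'this'), so it is eligible.
  // getName: returns a pointer rather than an integer, so it is not.
  struct Shape {
    virtual unsigned getKind() const { return 0; }
    virtual const char *getName() const { return "shape"; }
  };
  struct Circle : Shape {
    unsigned getKind() const override { return 1; }
    const char *getName() const override { return "circle"; }
  };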
+  DenseSet<const Comdat *> MergedMComdats; +  for (GlobalVariable &GV : M.globals()) +    if (HasTypeMetadata(&GV)) { +      if (const auto *C = GV.getComdat()) +        MergedMComdats.insert(C); +      forEachVirtualFunction(GV.getInitializer(), [&](Function *F) { +        auto *RT = dyn_cast<IntegerType>(F->getReturnType()); +        if (!RT || RT->getBitWidth() > 64 || F->arg_empty() || +            !F->arg_begin()->use_empty()) +          return; +        for (auto &Arg : make_range(std::next(F->arg_begin()), F->arg_end())) { +          auto *ArgT = dyn_cast<IntegerType>(Arg.getType()); +          if (!ArgT || ArgT->getBitWidth() > 64) +            return; +        } +        if (computeFunctionBodyMemoryAccess(*F, AARGetter(*F)) == MAK_ReadNone) +          EligibleVirtualFns.insert(F); +      }); +    } +    ValueToValueMapTy VMap; -  std::unique_ptr<Module> MergedM(CloneModule(&M, VMap, IsInMergedM)); +  std::unique_ptr<Module> MergedM( +      CloneModule(&M, VMap, [&](const GlobalValue *GV) -> bool { +        if (const auto *C = GV->getComdat()) +          if (MergedMComdats.count(C)) +            return true; +        if (auto *F = dyn_cast<Function>(GV)) +          return EligibleVirtualFns.count(F); +        if (auto *GVar = dyn_cast_or_null<GlobalVariable>(GV->getBaseObject())) +          return HasTypeMetadata(GVar); +        return false; +      })); +  StripDebugInfo(*MergedM); + +  for (Function &F : *MergedM) +    if (!F.isDeclaration()) { +      // Reset the linkage of all functions eligible for virtual constant +      // propagation. The canonical definitions live in the thin LTO module so +      // that they can be imported. +      F.setLinkage(GlobalValue::AvailableExternallyLinkage); +      F.setComdat(nullptr); +    } -  filterModule(&M, [&](const GlobalValue *GV) { return !IsInMergedM(GV); }); +  // Remove all globals with type metadata, globals with comdats that live in +  // MergedM, and aliases pointing to such globals from the thin LTO module. +  filterModule(&M, [&](const GlobalValue *GV) { +    if (auto *GVar = dyn_cast_or_null<GlobalVariable>(GV->getBaseObject())) +      if (HasTypeMetadata(GVar)) +        return false; +    if (const auto *C = GV->getComdat()) +      if (MergedMComdats.count(C)) +        return false; +    return true; +  });    promoteInternals(*MergedM, M, ModuleId);    promoteInternals(M, *MergedM, ModuleId);    simplifyExternals(*MergedM); -  SmallVector<char, 0> Buffer; -  BitcodeWriter W(Buffer);    // FIXME: Try to re-use BSI and PFI from the original module here.    ModuleSummaryIndex Index = buildModuleSummaryIndex(M, nullptr, nullptr); -  W.writeModule(&M, /*ShouldPreserveUseListOrder=*/false, &Index, -                /*GenerateHash=*/true); -  W.writeModule(MergedM.get()); +  SmallVector<char, 0> Buffer; +  BitcodeWriter W(Buffer); +  // Save the module hash produced for the full bitcode, which will +  // be used in the backends, and use that in the minimized bitcode +  // produced for the full link. +  ModuleHash ModHash = {{0}}; +  W.writeModule(&M, /*ShouldPreserveUseListOrder=*/false, &Index, +                /*GenerateHash=*/true, &ModHash); +  W.writeModule(MergedM.get());    OS << Buffer; + +  // If a minimized bitcode module was requested for the thin link, +  // strip the debug info (the merged module was already stripped above) +  // and write it to the given OS. 
+  if (ThinLinkOS) { +    Buffer.clear(); +    BitcodeWriter W2(Buffer); +    StripDebugInfo(M); +    W2.writeModule(&M, /*ShouldPreserveUseListOrder=*/false, &Index, +                   /*GenerateHash=*/false, &ModHash); +    W2.writeModule(MergedM.get()); +    *ThinLinkOS << Buffer; +  }  }  // Returns whether this module needs to be split because it uses type metadata. @@ -292,28 +391,45 @@ bool requiresSplit(Module &M) {    return false;  } -void writeThinLTOBitcode(raw_ostream &OS, Module &M, -                         const ModuleSummaryIndex *Index) { +void writeThinLTOBitcode(raw_ostream &OS, raw_ostream *ThinLinkOS, +                         function_ref<AAResults &(Function &)> AARGetter, +                         Module &M, const ModuleSummaryIndex *Index) {    // See if this module has any type metadata. If so, we need to split it.    if (requiresSplit(M)) -    return splitAndWriteThinLTOBitcode(OS, M); +    return splitAndWriteThinLTOBitcode(OS, ThinLinkOS, AARGetter, M);    // Otherwise we can just write it out as a regular module. + +  // Save the module hash produced for the full bitcode, which will +  // be used in the backends, and use that in the minimized bitcode +  // produced for the full link. +  ModuleHash ModHash = {{0}};    WriteBitcodeToFile(&M, OS, /*ShouldPreserveUseListOrder=*/false, Index, -                     /*GenerateHash=*/true); +                     /*GenerateHash=*/true, &ModHash); +  // If a minimized bitcode module was requested for the thin link, +  // strip the debug info and write it to the given OS. +  if (ThinLinkOS) { +    StripDebugInfo(M); +    WriteBitcodeToFile(&M, *ThinLinkOS, /*ShouldPreserveUseListOrder=*/false, +                       Index, +                       /*GenerateHash=*/false, &ModHash); +  }  }  class WriteThinLTOBitcode : public ModulePass {    raw_ostream &OS; // raw_ostream to print on +  // The output stream on which to emit a minimized module for use +  // just in the thin link, if requested. 
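A rough sketch of how a driver could request the minimized thin-link output in addition to the regular ThinLTO bitcode (illustrative only: the second parameter of createWriteThinLTOBitcodePass is the one added at the bottom of this file, while the wrapper function, file paths, and error handling here are hypothetical):

  #include "llvm/ADT/StringRef.h"
  #include "llvm/IR/LegacyPassManager.h"
  #include "llvm/IR/Module.h"
  #include "llvm/Support/FileSystem.h"
  #include "llvm/Support/raw_ostream.h"
  #include "llvm/Transforms/IPO.h"

  void emitThinLTOBitcode(llvm::Module &M, llvm::StringRef BitcodePath,
                          llvm::StringRef ThinLinkPath) {
    std::error_code EC;   // error handling elided for brevity
    llvm::raw_fd_ostream OS(BitcodePath, EC, llvm::sys::fs::F_None);
    llvm::raw_fd_ostream ThinLinkOS(ThinLinkPath, EC, llvm::sys::fs::F_None);
    llvm::legacy::PassManager PM;
    // A non-null second stream asks the pass to also emit the
    // debug-info-stripped module intended for the thin link.
    PM.add(llvm::createWriteThinLTOBitcodePass(OS, &ThinLinkOS));
    PM.run(M);
  }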
+  raw_ostream *ThinLinkOS;
 
 public:
   static char ID; // Pass identification, replacement for typeid
 
-  WriteThinLTOBitcode() : ModulePass(ID), OS(dbgs()) {
+  WriteThinLTOBitcode() : ModulePass(ID), OS(dbgs()), ThinLinkOS(nullptr) {
     initializeWriteThinLTOBitcodePass(*PassRegistry::getPassRegistry());
   }
 
-  explicit WriteThinLTOBitcode(raw_ostream &o)
-      : ModulePass(ID), OS(o) {
+  explicit WriteThinLTOBitcode(raw_ostream &o, raw_ostream *ThinLinkOS)
+      : ModulePass(ID), OS(o), ThinLinkOS(ThinLinkOS) {
     initializeWriteThinLTOBitcodePass(*PassRegistry::getPassRegistry());
   }
 
@@ -322,12 +438,14 @@ public:
   bool runOnModule(Module &M) override {
     const ModuleSummaryIndex *Index =
         &(getAnalysis<ModuleSummaryIndexWrapperPass>().getIndex());
-    writeThinLTOBitcode(OS, M, Index);
+    writeThinLTOBitcode(OS, ThinLinkOS, LegacyAARGetter(*this), M, Index);
     return true;
   }
   void getAnalysisUsage(AnalysisUsage &AU) const override {
     AU.setPreservesAll();
+    AU.addRequired<AssumptionCacheTracker>();
     AU.addRequired<ModuleSummaryIndexWrapperPass>();
+    AU.addRequired<TargetLibraryInfoWrapperPass>();
   }
 };
 } // anonymous namespace
@@ -335,10 +453,13 @@ public:
 char WriteThinLTOBitcode::ID = 0;
 INITIALIZE_PASS_BEGIN(WriteThinLTOBitcode, "write-thinlto-bitcode",
                       "Write ThinLTO Bitcode", false, true)
+INITIALIZE_PASS_DEPENDENCY(AssumptionCacheTracker)
 INITIALIZE_PASS_DEPENDENCY(ModuleSummaryIndexWrapperPass)
+INITIALIZE_PASS_DEPENDENCY(TargetLibraryInfoWrapperPass)
 INITIALIZE_PASS_END(WriteThinLTOBitcode, "write-thinlto-bitcode",
                     "Write ThinLTO Bitcode", false, true)
 
-ModulePass *llvm::createWriteThinLTOBitcodePass(raw_ostream &Str) {
-  return new WriteThinLTOBitcode(Str);
+ModulePass *llvm::createWriteThinLTOBitcodePass(raw_ostream &Str,
+                                                raw_ostream *ThinLinkOS) {
+  return new WriteThinLTOBitcode(Str, ThinLinkOS);
 }
diff --git a/lib/Transforms/IPO/WholeProgramDevirt.cpp b/lib/Transforms/IPO/WholeProgramDevirt.cpp
index 844cc0f70eed..cb7d487b68b0 100644
--- a/lib/Transforms/IPO/WholeProgramDevirt.cpp
+++ b/lib/Transforms/IPO/WholeProgramDevirt.cpp
@@ -25,6 +25,20 @@
 //   returns 0, or a single vtable's function returns 1, replace each virtual
 //   call with a comparison of the vptr against that vtable's address.
 //
+// This pass is intended to be used during the regular and thin LTO pipelines.
+// During regular LTO, the pass determines the best optimization for each
+// virtual call and applies the resolutions directly to virtual calls that are
+// eligible for virtual call optimization (i.e. calls that use either of the
+// llvm.assume(llvm.type.test) or llvm.type.checked.load intrinsics). During
+// ThinLTO, the pass operates in two phases:
+// - Export phase: this is run during the thin link over a single merged module
+//   that contains all vtables with !type metadata that participate in the link.
+//   The pass computes a resolution for each virtual call and stores it in the
+//   type identifier summary.
+// - Import phase: this is run during the thin backends over the individual
+//   modules. The pass applies the resolutions previously computed during the
+//   export phase to each eligible virtual call.
+//  //===----------------------------------------------------------------------===//  #include "llvm/Transforms/IPO/WholeProgramDevirt.h" @@ -35,6 +49,8 @@  #include "llvm/ADT/iterator_range.h"  #include "llvm/ADT/MapVector.h"  #include "llvm/ADT/SmallVector.h" +#include "llvm/Analysis/AliasAnalysis.h" +#include "llvm/Analysis/BasicAliasAnalysis.h"  #include "llvm/Analysis/TypeMetadataUtils.h"  #include "llvm/IR/CallSite.h"  #include "llvm/IR/Constants.h" @@ -54,12 +70,16 @@  #include "llvm/IR/LLVMContext.h"  #include "llvm/IR/Metadata.h"  #include "llvm/IR/Module.h" +#include "llvm/IR/ModuleSummaryIndexYAML.h"  #include "llvm/Pass.h"  #include "llvm/PassRegistry.h"  #include "llvm/PassSupport.h"  #include "llvm/Support/Casting.h" +#include "llvm/Support/Error.h" +#include "llvm/Support/FileSystem.h"  #include "llvm/Support/MathExtras.h"  #include "llvm/Transforms/IPO.h" +#include "llvm/Transforms/IPO/FunctionAttrs.h"  #include "llvm/Transforms/Utils/Evaluator.h"  #include <algorithm>  #include <cstddef> @@ -72,6 +92,26 @@ using namespace wholeprogramdevirt;  #define DEBUG_TYPE "wholeprogramdevirt" +static cl::opt<PassSummaryAction> ClSummaryAction( +    "wholeprogramdevirt-summary-action", +    cl::desc("What to do with the summary when running this pass"), +    cl::values(clEnumValN(PassSummaryAction::None, "none", "Do nothing"), +               clEnumValN(PassSummaryAction::Import, "import", +                          "Import typeid resolutions from summary and globals"), +               clEnumValN(PassSummaryAction::Export, "export", +                          "Export typeid resolutions to summary and globals")), +    cl::Hidden); + +static cl::opt<std::string> ClReadSummary( +    "wholeprogramdevirt-read-summary", +    cl::desc("Read summary from given YAML file before running pass"), +    cl::Hidden); + +static cl::opt<std::string> ClWriteSummary( +    "wholeprogramdevirt-write-summary", +    cl::desc("Write summary to given YAML file after running pass"), +    cl::Hidden); +  // Find the minimum offset that we may store a value of size Size bits at. If  // IsAfter is set, look for an offset before the object, otherwise look for an  // offset after the object. @@ -259,15 +299,92 @@ struct VirtualCallSite {    }  }; +// Call site information collected for a specific VTableSlot and possibly a list +// of constant integer arguments. The grouping by arguments is handled by the +// VTableSlotInfo class. +struct CallSiteInfo { +  /// The set of call sites for this slot. Used during regular LTO and the +  /// import phase of ThinLTO (as well as the export phase of ThinLTO for any +  /// call sites that appear in the merged module itself); in each of these +  /// cases we are directly operating on the call sites at the IR level. +  std::vector<VirtualCallSite> CallSites; + +  // These fields are used during the export phase of ThinLTO and reflect +  // information collected from function summaries. + +  /// Whether any function summary contains an llvm.assume(llvm.type.test) for +  /// this slot. +  bool SummaryHasTypeTestAssumeUsers; + +  /// CFI-specific: a vector containing the list of function summaries that use +  /// the llvm.type.checked.load intrinsic and therefore will require +  /// resolutions for llvm.type.test in order to implement CFI checks if +  /// devirtualization was unsuccessful. If devirtualization was successful, the +  /// pass will clear this vector by calling markDevirt(). 
If at the end of the +  /// pass the vector is non-empty, we will need to add a use of llvm.type.test +  /// to each of the function summaries in the vector. +  std::vector<FunctionSummary *> SummaryTypeCheckedLoadUsers; + +  bool isExported() const { +    return SummaryHasTypeTestAssumeUsers || +           !SummaryTypeCheckedLoadUsers.empty(); +  } + +  /// As explained in the comment for SummaryTypeCheckedLoadUsers. +  void markDevirt() { SummaryTypeCheckedLoadUsers.clear(); } +}; + +// Call site information collected for a specific VTableSlot. +struct VTableSlotInfo { +  // The set of call sites which do not have all constant integer arguments +  // (excluding "this"). +  CallSiteInfo CSInfo; + +  // The set of call sites with all constant integer arguments (excluding +  // "this"), grouped by argument list. +  std::map<std::vector<uint64_t>, CallSiteInfo> ConstCSInfo; + +  void addCallSite(Value *VTable, CallSite CS, unsigned *NumUnsafeUses); + +private: +  CallSiteInfo &findCallSiteInfo(CallSite CS); +}; + +CallSiteInfo &VTableSlotInfo::findCallSiteInfo(CallSite CS) { +  std::vector<uint64_t> Args; +  auto *CI = dyn_cast<IntegerType>(CS.getType()); +  if (!CI || CI->getBitWidth() > 64 || CS.arg_empty()) +    return CSInfo; +  for (auto &&Arg : make_range(CS.arg_begin() + 1, CS.arg_end())) { +    auto *CI = dyn_cast<ConstantInt>(Arg); +    if (!CI || CI->getBitWidth() > 64) +      return CSInfo; +    Args.push_back(CI->getZExtValue()); +  } +  return ConstCSInfo[Args]; +} + +void VTableSlotInfo::addCallSite(Value *VTable, CallSite CS, +                                 unsigned *NumUnsafeUses) { +  findCallSiteInfo(CS).CallSites.push_back({VTable, CS, NumUnsafeUses}); +} +  struct DevirtModule {    Module &M; +  function_ref<AAResults &(Function &)> AARGetter; + +  ModuleSummaryIndex *ExportSummary; +  const ModuleSummaryIndex *ImportSummary; +    IntegerType *Int8Ty;    PointerType *Int8PtrTy;    IntegerType *Int32Ty; +  IntegerType *Int64Ty; +  IntegerType *IntPtrTy;    bool RemarksEnabled; -  MapVector<VTableSlot, std::vector<VirtualCallSite>> CallSlots; +  MapVector<VTableSlot, VTableSlotInfo> CallSlots;    // This map keeps track of the number of "unsafe" uses of a loaded function    // pointer. The key is the associated llvm.type.test intrinsic call generated @@ -279,11 +396,18 @@ struct DevirtModule {    // true.    
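The grouping that findCallSiteInfo performs above can be pictured with ordinary containers; this is only a sketch with made-up names, not the pass's own types. Each call site is keyed by the constant integer arguments after "this"; if the return type or any argument is not a small integer constant, the call lands in the catch-all bucket.

    #include <cstdint>
    #include <map>
    #include <optional>
    #include <vector>

    // Key: the constant integer arguments after 'this', or std::nullopt when the
    // call does not qualify (non-integer return, wide or non-constant argument).
    using ConstArgKey = std::optional<std::vector<uint64_t>>;

    struct SlotBuckets {
      std::vector<int> CatchAll;                                // call-site ids
      std::map<std::vector<uint64_t>, std::vector<int>> ByArgs; // grouped by args

      void addCall(int CallId, const ConstArgKey &Key) {
        if (!Key)
          CatchAll.push_back(CallId);
        else
          ByArgs[*Key].push_back(CallId);
      }
    };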
std::map<CallInst *, unsigned> NumUnsafeUsesForTypeTest; -  DevirtModule(Module &M) -      : M(M), Int8Ty(Type::getInt8Ty(M.getContext())), +  DevirtModule(Module &M, function_ref<AAResults &(Function &)> AARGetter, +               ModuleSummaryIndex *ExportSummary, +               const ModuleSummaryIndex *ImportSummary) +      : M(M), AARGetter(AARGetter), ExportSummary(ExportSummary), +        ImportSummary(ImportSummary), Int8Ty(Type::getInt8Ty(M.getContext())),          Int8PtrTy(Type::getInt8PtrTy(M.getContext())),          Int32Ty(Type::getInt32Ty(M.getContext())), -        RemarksEnabled(areRemarksEnabled()) {} +        Int64Ty(Type::getInt64Ty(M.getContext())), +        IntPtrTy(M.getDataLayout().getIntPtrType(M.getContext(), 0)), +        RemarksEnabled(areRemarksEnabled()) { +    assert(!(ExportSummary && ImportSummary)); +  }    bool areRemarksEnabled(); @@ -298,57 +422,169 @@ struct DevirtModule {    tryFindVirtualCallTargets(std::vector<VirtualCallTarget> &TargetsForSlot,                              const std::set<TypeMemberInfo> &TypeMemberInfos,                              uint64_t ByteOffset); + +  void applySingleImplDevirt(VTableSlotInfo &SlotInfo, Constant *TheFn, +                             bool &IsExported);    bool trySingleImplDevirt(MutableArrayRef<VirtualCallTarget> TargetsForSlot, -                           MutableArrayRef<VirtualCallSite> CallSites); +                           VTableSlotInfo &SlotInfo, +                           WholeProgramDevirtResolution *Res); +    bool tryEvaluateFunctionsWithArgs(        MutableArrayRef<VirtualCallTarget> TargetsForSlot, -      ArrayRef<ConstantInt *> Args); -  bool tryUniformRetValOpt(IntegerType *RetType, -                           MutableArrayRef<VirtualCallTarget> TargetsForSlot, -                           MutableArrayRef<VirtualCallSite> CallSites); +      ArrayRef<uint64_t> Args); + +  void applyUniformRetValOpt(CallSiteInfo &CSInfo, StringRef FnName, +                             uint64_t TheRetVal); +  bool tryUniformRetValOpt(MutableArrayRef<VirtualCallTarget> TargetsForSlot, +                           CallSiteInfo &CSInfo, +                           WholeProgramDevirtResolution::ByArg *Res); + +  // Returns the global symbol name that is used to export information about the +  // given vtable slot and list of arguments. +  std::string getGlobalName(VTableSlot Slot, ArrayRef<uint64_t> Args, +                            StringRef Name); + +  // This function is called during the export phase to create a symbol +  // definition containing information about the given vtable slot and list of +  // arguments. +  void exportGlobal(VTableSlot Slot, ArrayRef<uint64_t> Args, StringRef Name, +                    Constant *C); + +  // This function is called during the import phase to create a reference to +  // the symbol definition created during the export phase. 
+  Constant *importGlobal(VTableSlot Slot, ArrayRef<uint64_t> Args, +                         StringRef Name, unsigned AbsWidth = 0); + +  void applyUniqueRetValOpt(CallSiteInfo &CSInfo, StringRef FnName, bool IsOne, +                            Constant *UniqueMemberAddr);    bool tryUniqueRetValOpt(unsigned BitWidth,                            MutableArrayRef<VirtualCallTarget> TargetsForSlot, -                          MutableArrayRef<VirtualCallSite> CallSites); +                          CallSiteInfo &CSInfo, +                          WholeProgramDevirtResolution::ByArg *Res, +                          VTableSlot Slot, ArrayRef<uint64_t> Args); + +  void applyVirtualConstProp(CallSiteInfo &CSInfo, StringRef FnName, +                             Constant *Byte, Constant *Bit);    bool tryVirtualConstProp(MutableArrayRef<VirtualCallTarget> TargetsForSlot, -                           ArrayRef<VirtualCallSite> CallSites); +                           VTableSlotInfo &SlotInfo, +                           WholeProgramDevirtResolution *Res, VTableSlot Slot);    void rebuildGlobal(VTableBits &B); +  // Apply the summary resolution for Slot to all virtual calls in SlotInfo. +  void importResolution(VTableSlot Slot, VTableSlotInfo &SlotInfo); + +  // If we were able to eliminate all unsafe uses for a type checked load, +  // eliminate the associated type tests by replacing them with true. +  void removeRedundantTypeTests(); +    bool run(); + +  // Lower the module using the action and summary passed as command line +  // arguments. For testing purposes only. +  static bool runForTesting(Module &M, +                            function_ref<AAResults &(Function &)> AARGetter);  };  struct WholeProgramDevirt : public ModulePass {    static char ID; -  WholeProgramDevirt() : ModulePass(ID) { +  bool UseCommandLine = false; + +  ModuleSummaryIndex *ExportSummary; +  const ModuleSummaryIndex *ImportSummary; + +  WholeProgramDevirt() : ModulePass(ID), UseCommandLine(true) { +    initializeWholeProgramDevirtPass(*PassRegistry::getPassRegistry()); +  } + +  WholeProgramDevirt(ModuleSummaryIndex *ExportSummary, +                     const ModuleSummaryIndex *ImportSummary) +      : ModulePass(ID), ExportSummary(ExportSummary), +        ImportSummary(ImportSummary) {      initializeWholeProgramDevirtPass(*PassRegistry::getPassRegistry());    }    bool runOnModule(Module &M) override {      if (skipModule(M))        return false; +    if (UseCommandLine) +      return DevirtModule::runForTesting(M, LegacyAARGetter(*this)); +    return DevirtModule(M, LegacyAARGetter(*this), ExportSummary, ImportSummary) +        .run(); +  } -    return DevirtModule(M).run(); +  void getAnalysisUsage(AnalysisUsage &AU) const override { +    AU.addRequired<AssumptionCacheTracker>(); +    AU.addRequired<TargetLibraryInfoWrapperPass>();    }  };  } // end anonymous namespace -INITIALIZE_PASS(WholeProgramDevirt, "wholeprogramdevirt", -                "Whole program devirtualization", false, false) +INITIALIZE_PASS_BEGIN(WholeProgramDevirt, "wholeprogramdevirt", +                      "Whole program devirtualization", false, false) +INITIALIZE_PASS_DEPENDENCY(AssumptionCacheTracker) +INITIALIZE_PASS_DEPENDENCY(TargetLibraryInfoWrapperPass) +INITIALIZE_PASS_END(WholeProgramDevirt, "wholeprogramdevirt", +                    "Whole program devirtualization", false, false)  char WholeProgramDevirt::ID = 0; -ModulePass *llvm::createWholeProgramDevirtPass() { -  return new WholeProgramDevirt; +ModulePass * 
+llvm::createWholeProgramDevirtPass(ModuleSummaryIndex *ExportSummary, +                                   const ModuleSummaryIndex *ImportSummary) { +  return new WholeProgramDevirt(ExportSummary, ImportSummary);  }  PreservedAnalyses WholeProgramDevirtPass::run(Module &M, -                                              ModuleAnalysisManager &) { -  if (!DevirtModule(M).run()) +                                              ModuleAnalysisManager &AM) { +  auto &FAM = AM.getResult<FunctionAnalysisManagerModuleProxy>(M).getManager(); +  auto AARGetter = [&](Function &F) -> AAResults & { +    return FAM.getResult<AAManager>(F); +  }; +  if (!DevirtModule(M, AARGetter, nullptr, nullptr).run())      return PreservedAnalyses::all();    return PreservedAnalyses::none();  } +bool DevirtModule::runForTesting( +    Module &M, function_ref<AAResults &(Function &)> AARGetter) { +  ModuleSummaryIndex Summary; + +  // Handle the command-line summary arguments. This code is for testing +  // purposes only, so we handle errors directly. +  if (!ClReadSummary.empty()) { +    ExitOnError ExitOnErr("-wholeprogramdevirt-read-summary: " + ClReadSummary + +                          ": "); +    auto ReadSummaryFile = +        ExitOnErr(errorOrToExpected(MemoryBuffer::getFile(ClReadSummary))); + +    yaml::Input In(ReadSummaryFile->getBuffer()); +    In >> Summary; +    ExitOnErr(errorCodeToError(In.error())); +  } + +  bool Changed = +      DevirtModule( +          M, AARGetter, +          ClSummaryAction == PassSummaryAction::Export ? &Summary : nullptr, +          ClSummaryAction == PassSummaryAction::Import ? &Summary : nullptr) +          .run(); + +  if (!ClWriteSummary.empty()) { +    ExitOnError ExitOnErr( +        "-wholeprogramdevirt-write-summary: " + ClWriteSummary + ": "); +    std::error_code EC; +    raw_fd_ostream OS(ClWriteSummary, EC, sys::fs::F_Text); +    ExitOnErr(errorCodeToError(EC)); + +    yaml::Output Out(OS); +    Out << Summary; +  } + +  return Changed; +} +  void DevirtModule::buildTypeIdentifierMap(      std::vector<VTableBits> &Bits,      DenseMap<Metadata *, std::set<TypeMemberInfo>> &TypeIdMap) { @@ -443,9 +679,31 @@ bool DevirtModule::tryFindVirtualCallTargets(    return !TargetsForSlot.empty();  } +void DevirtModule::applySingleImplDevirt(VTableSlotInfo &SlotInfo, +                                         Constant *TheFn, bool &IsExported) { +  auto Apply = [&](CallSiteInfo &CSInfo) { +    for (auto &&VCallSite : CSInfo.CallSites) { +      if (RemarksEnabled) +        VCallSite.emitRemark("single-impl", TheFn->getName()); +      VCallSite.CS.setCalledFunction(ConstantExpr::getBitCast( +          TheFn, VCallSite.CS.getCalledValue()->getType())); +      // This use is no longer unsafe. +      if (VCallSite.NumUnsafeUses) +        --*VCallSite.NumUnsafeUses; +    } +    if (CSInfo.isExported()) { +      IsExported = true; +      CSInfo.markDevirt(); +    } +  }; +  Apply(SlotInfo.CSInfo); +  for (auto &P : SlotInfo.ConstCSInfo) +    Apply(P.second); +} +  bool DevirtModule::trySingleImplDevirt(      MutableArrayRef<VirtualCallTarget> TargetsForSlot, -    MutableArrayRef<VirtualCallSite> CallSites) { +    VTableSlotInfo &SlotInfo, WholeProgramDevirtResolution *Res) {    // See if the program contains a single implementation of this virtual    // function.    Function *TheFn = TargetsForSlot[0].Fn; @@ -453,39 +711,51 @@ bool DevirtModule::trySingleImplDevirt(      if (TheFn != Target.Fn)        return false; +  // If so, update each call site to call that implementation directly.    
if (RemarksEnabled)      TargetsForSlot[0].WasDevirt = true; -  // If so, update each call site to call that implementation directly. -  for (auto &&VCallSite : CallSites) { -    if (RemarksEnabled) -      VCallSite.emitRemark("single-impl", TheFn->getName()); -    VCallSite.CS.setCalledFunction(ConstantExpr::getBitCast( -        TheFn, VCallSite.CS.getCalledValue()->getType())); -    // This use is no longer unsafe. -    if (VCallSite.NumUnsafeUses) -      --*VCallSite.NumUnsafeUses; + +  bool IsExported = false; +  applySingleImplDevirt(SlotInfo, TheFn, IsExported); +  if (!IsExported) +    return false; + +  // If the only implementation has local linkage, we must promote to external +  // to make it visible to thin LTO objects. We can only get here during the +  // ThinLTO export phase. +  if (TheFn->hasLocalLinkage()) { +    TheFn->setLinkage(GlobalValue::ExternalLinkage); +    TheFn->setVisibility(GlobalValue::HiddenVisibility); +    TheFn->setName(TheFn->getName() + "$merged");    } + +  Res->TheKind = WholeProgramDevirtResolution::SingleImpl; +  Res->SingleImplName = TheFn->getName(); +    return true;  }  bool DevirtModule::tryEvaluateFunctionsWithArgs(      MutableArrayRef<VirtualCallTarget> TargetsForSlot, -    ArrayRef<ConstantInt *> Args) { +    ArrayRef<uint64_t> Args) {    // Evaluate each function and store the result in each target's RetVal    // field.    for (VirtualCallTarget &Target : TargetsForSlot) {      if (Target.Fn->arg_size() != Args.size() + 1)        return false; -    for (unsigned I = 0; I != Args.size(); ++I) -      if (Target.Fn->getFunctionType()->getParamType(I + 1) != -          Args[I]->getType()) -        return false;      Evaluator Eval(M.getDataLayout(), nullptr);      SmallVector<Constant *, 2> EvalArgs;      EvalArgs.push_back(          Constant::getNullValue(Target.Fn->getFunctionType()->getParamType(0))); -    EvalArgs.insert(EvalArgs.end(), Args.begin(), Args.end()); +    for (unsigned I = 0; I != Args.size(); ++I) { +      auto *ArgTy = dyn_cast<IntegerType>( +          Target.Fn->getFunctionType()->getParamType(I + 1)); +      if (!ArgTy) +        return false; +      EvalArgs.push_back(ConstantInt::get(ArgTy, Args[I])); +    } +      Constant *RetVal;      if (!Eval.EvaluateFunction(Target.Fn, RetVal, EvalArgs) ||          !isa<ConstantInt>(RetVal)) @@ -495,9 +765,18 @@ bool DevirtModule::tryEvaluateFunctionsWithArgs(    return true;  } +void DevirtModule::applyUniformRetValOpt(CallSiteInfo &CSInfo, StringRef FnName, +                                         uint64_t TheRetVal) { +  for (auto Call : CSInfo.CallSites) +    Call.replaceAndErase( +        "uniform-ret-val", FnName, RemarksEnabled, +        ConstantInt::get(cast<IntegerType>(Call.CS.getType()), TheRetVal)); +  CSInfo.markDevirt(); +} +  bool DevirtModule::tryUniformRetValOpt( -    IntegerType *RetType, MutableArrayRef<VirtualCallTarget> TargetsForSlot, -    MutableArrayRef<VirtualCallSite> CallSites) { +    MutableArrayRef<VirtualCallTarget> TargetsForSlot, CallSiteInfo &CSInfo, +    WholeProgramDevirtResolution::ByArg *Res) {    // Uniform return value optimization. If all functions return the same    // constant, replace all calls with that constant.    
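In source terms, the single-implementation case handled above amounts to the following; Base and OnlyImpl are invented example classes used purely for illustration. Once whole-program analysis shows that one function is the only possible target of a vtable slot, the indirect call becomes a direct call, and a target with local linkage must first be promoted (hence the renaming with a "$merged" suffix above) so that importing modules can reference it.

    struct Base {
      virtual int f(int) = 0;
    };
    struct OnlyImpl final : Base {
      int f(int x) override { return x + 1; }
    };

    // Before: an indirect call through the vtable.
    int callBefore(Base *B) { return B->f(3); }

    // After single-impl devirtualization: the call is bound directly to the one
    // implementation that exists anywhere in the program.
    int callAfter(Base *B) {
      return static_cast<OnlyImpl *>(B)->OnlyImpl::f(3);
    }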
uint64_t TheRetVal = TargetsForSlot[0].RetVal; @@ -505,19 +784,77 @@ bool DevirtModule::tryUniformRetValOpt(      if (Target.RetVal != TheRetVal)        return false; -  auto TheRetValConst = ConstantInt::get(RetType, TheRetVal); -  for (auto Call : CallSites) -    Call.replaceAndErase("uniform-ret-val", TargetsForSlot[0].Fn->getName(), -                         RemarksEnabled, TheRetValConst); +  if (CSInfo.isExported()) { +    Res->TheKind = WholeProgramDevirtResolution::ByArg::UniformRetVal; +    Res->Info = TheRetVal; +  } + +  applyUniformRetValOpt(CSInfo, TargetsForSlot[0].Fn->getName(), TheRetVal);    if (RemarksEnabled)      for (auto &&Target : TargetsForSlot)        Target.WasDevirt = true;    return true;  } +std::string DevirtModule::getGlobalName(VTableSlot Slot, +                                        ArrayRef<uint64_t> Args, +                                        StringRef Name) { +  std::string FullName = "__typeid_"; +  raw_string_ostream OS(FullName); +  OS << cast<MDString>(Slot.TypeID)->getString() << '_' << Slot.ByteOffset; +  for (uint64_t Arg : Args) +    OS << '_' << Arg; +  OS << '_' << Name; +  return OS.str(); +} + +void DevirtModule::exportGlobal(VTableSlot Slot, ArrayRef<uint64_t> Args, +                                StringRef Name, Constant *C) { +  GlobalAlias *GA = GlobalAlias::create(Int8Ty, 0, GlobalValue::ExternalLinkage, +                                        getGlobalName(Slot, Args, Name), C, &M); +  GA->setVisibility(GlobalValue::HiddenVisibility); +} + +Constant *DevirtModule::importGlobal(VTableSlot Slot, ArrayRef<uint64_t> Args, +                                     StringRef Name, unsigned AbsWidth) { +  Constant *C = M.getOrInsertGlobal(getGlobalName(Slot, Args, Name), Int8Ty); +  auto *GV = dyn_cast<GlobalVariable>(C); +  // We only need to set metadata if the global is newly created, in which +  // case it would not have hidden visibility. +  if (!GV || GV->getVisibility() == GlobalValue::HiddenVisibility) +    return C; + +  GV->setVisibility(GlobalValue::HiddenVisibility); +  auto SetAbsRange = [&](uint64_t Min, uint64_t Max) { +    auto *MinC = ConstantAsMetadata::get(ConstantInt::get(IntPtrTy, Min)); +    auto *MaxC = ConstantAsMetadata::get(ConstantInt::get(IntPtrTy, Max)); +    GV->setMetadata(LLVMContext::MD_absolute_symbol, +                    MDNode::get(M.getContext(), {MinC, MaxC})); +  }; +  if (AbsWidth == IntPtrTy->getBitWidth()) +    SetAbsRange(~0ull, ~0ull); // Full set. +  else if (AbsWidth) +    SetAbsRange(0, 1ull << AbsWidth); +  return GV; +} + +void DevirtModule::applyUniqueRetValOpt(CallSiteInfo &CSInfo, StringRef FnName, +                                        bool IsOne, +                                        Constant *UniqueMemberAddr) { +  for (auto &&Call : CSInfo.CallSites) { +    IRBuilder<> B(Call.CS.getInstruction()); +    Value *Cmp = B.CreateICmp(IsOne ? ICmpInst::ICMP_EQ : ICmpInst::ICMP_NE, +                              Call.VTable, UniqueMemberAddr); +    Cmp = B.CreateZExt(Cmp, Call.CS->getType()); +    Call.replaceAndErase("unique-ret-val", FnName, RemarksEnabled, Cmp); +  } +  CSInfo.markDevirt(); +} +  bool DevirtModule::tryUniqueRetValOpt(      unsigned BitWidth, MutableArrayRef<VirtualCallTarget> TargetsForSlot, -    MutableArrayRef<VirtualCallSite> CallSites) { +    CallSiteInfo &CSInfo, WholeProgramDevirtResolution::ByArg *Res, +    VTableSlot Slot, ArrayRef<uint64_t> Args) {    // IsOne controls whether we look for a 0 or a 1.    
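The getGlobalName() scheme defined above can be reproduced with plain standard-library code; typeidSymbolName below is an invented helper, shown only to make the format visible. For type id _ZTS1A, byte offset 8, constant arguments {1, 2} and the field name "byte", it yields __typeid__ZTS1A_8_1_2_byte, which is the symbol the import phase later looks up.

    #include <cstdint>
    #include <sstream>
    #include <string>
    #include <vector>

    std::string typeidSymbolName(const std::string &TypeId, uint64_t ByteOffset,
                                 const std::vector<uint64_t> &Args,
                                 const std::string &Name) {
      std::ostringstream OS;
      OS << "__typeid_" << TypeId << '_' << ByteOffset;
      for (uint64_t Arg : Args)
        OS << '_' << Arg;
      OS << '_' << Name;
      return OS.str();
    }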
auto tryUniqueRetValOptFor = [&](bool IsOne) {      const TypeMemberInfo *UniqueMember = nullptr; @@ -533,16 +870,23 @@ bool DevirtModule::tryUniqueRetValOpt(      // checked for a uniform return value in tryUniformRetValOpt.      assert(UniqueMember); -    // Replace each call with the comparison. -    for (auto &&Call : CallSites) { -      IRBuilder<> B(Call.CS.getInstruction()); -      Value *OneAddr = B.CreateBitCast(UniqueMember->Bits->GV, Int8PtrTy); -      OneAddr = B.CreateConstGEP1_64(OneAddr, UniqueMember->Offset); -      Value *Cmp = B.CreateICmp(IsOne ? ICmpInst::ICMP_EQ : ICmpInst::ICMP_NE, -                                Call.VTable, OneAddr); -      Call.replaceAndErase("unique-ret-val", TargetsForSlot[0].Fn->getName(), -                           RemarksEnabled, Cmp); +    Constant *UniqueMemberAddr = +        ConstantExpr::getBitCast(UniqueMember->Bits->GV, Int8PtrTy); +    UniqueMemberAddr = ConstantExpr::getGetElementPtr( +        Int8Ty, UniqueMemberAddr, +        ConstantInt::get(Int64Ty, UniqueMember->Offset)); + +    if (CSInfo.isExported()) { +      Res->TheKind = WholeProgramDevirtResolution::ByArg::UniqueRetVal; +      Res->Info = IsOne; + +      exportGlobal(Slot, Args, "unique_member", UniqueMemberAddr);      } + +    // Replace each call with the comparison. +    applyUniqueRetValOpt(CSInfo, TargetsForSlot[0].Fn->getName(), IsOne, +                         UniqueMemberAddr); +      // Update devirtualization statistics for targets.      if (RemarksEnabled)        for (auto &&Target : TargetsForSlot) @@ -560,9 +904,30 @@ bool DevirtModule::tryUniqueRetValOpt(    return false;  } +void DevirtModule::applyVirtualConstProp(CallSiteInfo &CSInfo, StringRef FnName, +                                         Constant *Byte, Constant *Bit) { +  for (auto Call : CSInfo.CallSites) { +    auto *RetType = cast<IntegerType>(Call.CS.getType()); +    IRBuilder<> B(Call.CS.getInstruction()); +    Value *Addr = B.CreateGEP(Int8Ty, Call.VTable, Byte); +    if (RetType->getBitWidth() == 1) { +      Value *Bits = B.CreateLoad(Addr); +      Value *BitsAndBit = B.CreateAnd(Bits, Bit); +      auto IsBitSet = B.CreateICmpNE(BitsAndBit, ConstantInt::get(Int8Ty, 0)); +      Call.replaceAndErase("virtual-const-prop-1-bit", FnName, RemarksEnabled, +                           IsBitSet); +    } else { +      Value *ValAddr = B.CreateBitCast(Addr, RetType->getPointerTo()); +      Value *Val = B.CreateLoad(RetType, ValAddr); +      Call.replaceAndErase("virtual-const-prop", FnName, RemarksEnabled, Val); +    } +  } +  CSInfo.markDevirt(); +} +  bool DevirtModule::tryVirtualConstProp( -    MutableArrayRef<VirtualCallTarget> TargetsForSlot, -    ArrayRef<VirtualCallSite> CallSites) { +    MutableArrayRef<VirtualCallTarget> TargetsForSlot, VTableSlotInfo &SlotInfo, +    WholeProgramDevirtResolution *Res, VTableSlot Slot) {    // This only works if the function returns an integer.    auto RetType = dyn_cast<IntegerType>(TargetsForSlot[0].Fn->getReturnType());    if (!RetType) @@ -571,55 +936,38 @@ bool DevirtModule::tryVirtualConstProp(    if (BitWidth > 64)      return false; -  // Make sure that each function does not access memory, takes at least one -  // argument, does not use its first argument (which we assume is 'this'), -  // and has the same return type. +  // Make sure that each function is defined, does not access memory, takes at +  // least one argument, does not use its first argument (which we assume is +  // 'this'), and has the same return type. 
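Taken together, the apply* helpers above rewrite a virtual call such as B->f(7) into one of three shapes, sketched here in plain C++ with invented function names; this is an analogy to the IR-level rewrites, not the pass itself: a constant when every implementation returns the same value for those arguments, a pointer comparison when exactly one vtable's implementation returns true, or a load at a fixed offset from the vptr when the per-vtable results have been laid out next to the vtables.

    #include <cstdint>
    #include <cstring>

    // 1. Uniform return value: the call collapses to a constant.
    int uniformRetVal() { return 42; }

    // 2. Unique return value (i1 results): the call becomes a comparison of the
    //    vptr against the address of the single vtable member that returns true.
    bool uniqueRetVal(const void *VPtr, const void *UniqueMemberAddr) {
      return VPtr == UniqueMemberAddr;
    }

    // 3. Virtual constant propagation: the precomputed result is loaded from a
    //    fixed byte offset relative to the vptr.
    uint32_t virtualConstProp(const unsigned char *VPtr, long OffsetByte) {
      uint32_t Val;
      std::memcpy(&Val, VPtr + OffsetByte, sizeof(Val));
      return Val;
    }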
+  // +  // Note that we test whether this copy of the function is readnone, rather +  // than testing function attributes, which must hold for any copy of the +  // function, even a less optimized version substituted at link time. This is +  // sound because the virtual constant propagation optimizations effectively +  // inline all implementations of the virtual function into each call site, +  // rather than using function attributes to perform local optimization.    for (VirtualCallTarget &Target : TargetsForSlot) { -    if (!Target.Fn->doesNotAccessMemory() || Target.Fn->arg_empty() || -        !Target.Fn->arg_begin()->use_empty() || +    if (Target.Fn->isDeclaration() || +        computeFunctionBodyMemoryAccess(*Target.Fn, AARGetter(*Target.Fn)) != +            MAK_ReadNone || +        Target.Fn->arg_empty() || !Target.Fn->arg_begin()->use_empty() ||          Target.Fn->getReturnType() != RetType)        return false;    } -  // Group call sites by the list of constant arguments they pass. -  // The comparator ensures deterministic ordering. -  struct ByAPIntValue { -    bool operator()(const std::vector<ConstantInt *> &A, -                    const std::vector<ConstantInt *> &B) const { -      return std::lexicographical_compare( -          A.begin(), A.end(), B.begin(), B.end(), -          [](ConstantInt *AI, ConstantInt *BI) { -            return AI->getValue().ult(BI->getValue()); -          }); -    } -  }; -  std::map<std::vector<ConstantInt *>, std::vector<VirtualCallSite>, -           ByAPIntValue> -      VCallSitesByConstantArg; -  for (auto &&VCallSite : CallSites) { -    std::vector<ConstantInt *> Args; -    if (VCallSite.CS.getType() != RetType) -      continue; -    for (auto &&Arg : -         make_range(VCallSite.CS.arg_begin() + 1, VCallSite.CS.arg_end())) { -      if (!isa<ConstantInt>(Arg)) -        break; -      Args.push_back(cast<ConstantInt>(&Arg)); -    } -    if (Args.size() + 1 != VCallSite.CS.arg_size()) -      continue; - -    VCallSitesByConstantArg[Args].push_back(VCallSite); -  } - -  for (auto &&CSByConstantArg : VCallSitesByConstantArg) { +  for (auto &&CSByConstantArg : SlotInfo.ConstCSInfo) {      if (!tryEvaluateFunctionsWithArgs(TargetsForSlot, CSByConstantArg.first))        continue; -    if (tryUniformRetValOpt(RetType, TargetsForSlot, CSByConstantArg.second)) +    WholeProgramDevirtResolution::ByArg *ResByArg = nullptr; +    if (Res) +      ResByArg = &Res->ResByArg[CSByConstantArg.first]; + +    if (tryUniformRetValOpt(TargetsForSlot, CSByConstantArg.second, ResByArg))        continue; -    if (tryUniqueRetValOpt(BitWidth, TargetsForSlot, CSByConstantArg.second)) +    if (tryUniqueRetValOpt(BitWidth, TargetsForSlot, CSByConstantArg.second, +                           ResByArg, Slot, CSByConstantArg.first))        continue;      // Find an allocation offset in bits in all vtables associated with the @@ -659,26 +1007,20 @@ bool DevirtModule::tryVirtualConstProp(        for (auto &&Target : TargetsForSlot)          Target.WasDevirt = true; -    // Rewrite each call to a load from OffsetByte/OffsetBit. 
-    for (auto Call : CSByConstantArg.second) { -      IRBuilder<> B(Call.CS.getInstruction()); -      Value *Addr = B.CreateConstGEP1_64(Call.VTable, OffsetByte); -      if (BitWidth == 1) { -        Value *Bits = B.CreateLoad(Addr); -        Value *Bit = ConstantInt::get(Int8Ty, 1ULL << OffsetBit); -        Value *BitsAndBit = B.CreateAnd(Bits, Bit); -        auto IsBitSet = B.CreateICmpNE(BitsAndBit, ConstantInt::get(Int8Ty, 0)); -        Call.replaceAndErase("virtual-const-prop-1-bit", -                             TargetsForSlot[0].Fn->getName(), -                             RemarksEnabled, IsBitSet); -      } else { -        Value *ValAddr = B.CreateBitCast(Addr, RetType->getPointerTo()); -        Value *Val = B.CreateLoad(RetType, ValAddr); -        Call.replaceAndErase("virtual-const-prop", -                             TargetsForSlot[0].Fn->getName(), -                             RemarksEnabled, Val); -      } +    Constant *ByteConst = ConstantInt::get(Int32Ty, OffsetByte); +    Constant *BitConst = ConstantInt::get(Int8Ty, 1ULL << OffsetBit); + +    if (CSByConstantArg.second.isExported()) { +      ResByArg->TheKind = WholeProgramDevirtResolution::ByArg::VirtualConstProp; +      exportGlobal(Slot, CSByConstantArg.first, "byte", +                   ConstantExpr::getIntToPtr(ByteConst, Int8PtrTy)); +      exportGlobal(Slot, CSByConstantArg.first, "bit", +                   ConstantExpr::getIntToPtr(BitConst, Int8PtrTy));      } + +    // Rewrite each call to a load from OffsetByte/OffsetBit. +    applyVirtualConstProp(CSByConstantArg.second, +                          TargetsForSlot[0].Fn->getName(), ByteConst, BitConst);    }    return true;  } @@ -733,7 +1075,11 @@ bool DevirtModule::areRemarksEnabled() {    if (FL.empty())      return false;    const Function &Fn = FL.front(); -  auto DI = OptimizationRemark(DEBUG_TYPE, Fn, DebugLoc(), ""); + +  const auto &BBL = Fn.getBasicBlockList(); +  if (BBL.empty()) +    return false; +  auto DI = OptimizationRemark(DEBUG_TYPE, "", DebugLoc(), &BBL.front());    return DI.isEnabled();  } @@ -766,8 +1112,8 @@ void DevirtModule::scanTypeTestUsers(Function *TypeTestFunc,        Value *Ptr = CI->getArgOperand(0)->stripPointerCasts();        if (SeenPtrs.insert(Ptr).second) {          for (DevirtCallSite Call : DevirtCalls) { -          CallSlots[{TypeId, Call.Offset}].push_back( -              {CI->getArgOperand(0), Call.CS, nullptr}); +          CallSlots[{TypeId, Call.Offset}].addCallSite(CI->getArgOperand(0), +                                                       Call.CS, nullptr);          }        }      } @@ -853,14 +1199,79 @@ void DevirtModule::scanTypeCheckedLoadUsers(Function *TypeCheckedLoadFunc) {      if (HasNonCallUses)        ++NumUnsafeUses;      for (DevirtCallSite Call : DevirtCalls) { -      CallSlots[{TypeId, Call.Offset}].push_back( -          {Ptr, Call.CS, &NumUnsafeUses}); +      CallSlots[{TypeId, Call.Offset}].addCallSite(Ptr, Call.CS, +                                                   &NumUnsafeUses);      }      CI->eraseFromParent();    }  } +void DevirtModule::importResolution(VTableSlot Slot, VTableSlotInfo &SlotInfo) { +  const TypeIdSummary *TidSummary = +      ImportSummary->getTypeIdSummary(cast<MDString>(Slot.TypeID)->getString()); +  if (!TidSummary) +    return; +  auto ResI = TidSummary->WPDRes.find(Slot.ByteOffset); +  if (ResI == TidSummary->WPDRes.end()) +    return; +  const WholeProgramDevirtResolution &Res = ResI->second; + +  if (Res.TheKind == WholeProgramDevirtResolution::SingleImpl) { +    
// The type of the function in the declaration is irrelevant because every +    // call site will cast it to the correct type. +    auto *SingleImpl = M.getOrInsertFunction( +        Res.SingleImplName, Type::getVoidTy(M.getContext())); + +    // This is the import phase so we should not be exporting anything. +    bool IsExported = false; +    applySingleImplDevirt(SlotInfo, SingleImpl, IsExported); +    assert(!IsExported); +  } + +  for (auto &CSByConstantArg : SlotInfo.ConstCSInfo) { +    auto I = Res.ResByArg.find(CSByConstantArg.first); +    if (I == Res.ResByArg.end()) +      continue; +    auto &ResByArg = I->second; +    // FIXME: We should figure out what to do about the "function name" argument +    // to the apply* functions, as the function names are unavailable during the +    // importing phase. For now we just pass the empty string. This does not +    // impact correctness because the function names are just used for remarks. +    switch (ResByArg.TheKind) { +    case WholeProgramDevirtResolution::ByArg::UniformRetVal: +      applyUniformRetValOpt(CSByConstantArg.second, "", ResByArg.Info); +      break; +    case WholeProgramDevirtResolution::ByArg::UniqueRetVal: { +      Constant *UniqueMemberAddr = +          importGlobal(Slot, CSByConstantArg.first, "unique_member"); +      applyUniqueRetValOpt(CSByConstantArg.second, "", ResByArg.Info, +                           UniqueMemberAddr); +      break; +    } +    case WholeProgramDevirtResolution::ByArg::VirtualConstProp: { +      Constant *Byte = importGlobal(Slot, CSByConstantArg.first, "byte", 32); +      Byte = ConstantExpr::getPtrToInt(Byte, Int32Ty); +      Constant *Bit = importGlobal(Slot, CSByConstantArg.first, "bit", 8); +      Bit = ConstantExpr::getPtrToInt(Bit, Int8Ty); +      applyVirtualConstProp(CSByConstantArg.second, "", Byte, Bit); +    } +    default: +      break; +    } +  } +} + +void DevirtModule::removeRedundantTypeTests() { +  auto True = ConstantInt::getTrue(M.getContext()); +  for (auto &&U : NumUnsafeUsesForTypeTest) { +    if (U.second == 0) { +      U.first->replaceAllUsesWith(True); +      U.first->eraseFromParent(); +    } +  } +} +  bool DevirtModule::run() {    Function *TypeTestFunc =        M.getFunction(Intrinsic::getName(Intrinsic::type_test)); @@ -868,7 +1279,11 @@ bool DevirtModule::run() {        M.getFunction(Intrinsic::getName(Intrinsic::type_checked_load));    Function *AssumeFunc = M.getFunction(Intrinsic::getName(Intrinsic::assume)); -  if ((!TypeTestFunc || TypeTestFunc->use_empty() || !AssumeFunc || +  // Normally if there are no users of the devirtualization intrinsics in the +  // module, this pass has nothing to do. But if we are exporting, we also need +  // to handle any users that appear only in the function summaries. +  if (!ExportSummary && +      (!TypeTestFunc || TypeTestFunc->use_empty() || !AssumeFunc ||         AssumeFunc->use_empty()) &&        (!TypeCheckedLoadFunc || TypeCheckedLoadFunc->use_empty()))      return false; @@ -879,6 +1294,17 @@ bool DevirtModule::run() {    if (TypeCheckedLoadFunc)      scanTypeCheckedLoadUsers(TypeCheckedLoadFunc); +  if (ImportSummary) { +    for (auto &S : CallSlots) +      importResolution(S.first, S.second); + +    removeRedundantTypeTests(); + +    // The rest of the code is only necessary when exporting or during regular +    // LTO, so we are done. +    return true; +  } +    // Rebuild type metadata into a map for easy lookup.    
std::vector<VTableBits> Bits;    DenseMap<Metadata *, std::set<TypeMemberInfo>> TypeIdMap; @@ -886,6 +1312,53 @@ bool DevirtModule::run() {    if (TypeIdMap.empty())      return true; +  // Collect information from summary about which calls to try to devirtualize. +  if (ExportSummary) { +    DenseMap<GlobalValue::GUID, TinyPtrVector<Metadata *>> MetadataByGUID; +    for (auto &P : TypeIdMap) { +      if (auto *TypeId = dyn_cast<MDString>(P.first)) +        MetadataByGUID[GlobalValue::getGUID(TypeId->getString())].push_back( +            TypeId); +    } + +    for (auto &P : *ExportSummary) { +      for (auto &S : P.second) { +        auto *FS = dyn_cast<FunctionSummary>(S.get()); +        if (!FS) +          continue; +        // FIXME: Only add live functions. +        for (FunctionSummary::VFuncId VF : FS->type_test_assume_vcalls()) { +          for (Metadata *MD : MetadataByGUID[VF.GUID]) { +            CallSlots[{MD, VF.Offset}].CSInfo.SummaryHasTypeTestAssumeUsers = +                true; +          } +        } +        for (FunctionSummary::VFuncId VF : FS->type_checked_load_vcalls()) { +          for (Metadata *MD : MetadataByGUID[VF.GUID]) { +            CallSlots[{MD, VF.Offset}] +                .CSInfo.SummaryTypeCheckedLoadUsers.push_back(FS); +          } +        } +        for (const FunctionSummary::ConstVCall &VC : +             FS->type_test_assume_const_vcalls()) { +          for (Metadata *MD : MetadataByGUID[VC.VFunc.GUID]) { +            CallSlots[{MD, VC.VFunc.Offset}] +                .ConstCSInfo[VC.Args] +                .SummaryHasTypeTestAssumeUsers = true; +          } +        } +        for (const FunctionSummary::ConstVCall &VC : +             FS->type_checked_load_const_vcalls()) { +          for (Metadata *MD : MetadataByGUID[VC.VFunc.GUID]) { +            CallSlots[{MD, VC.VFunc.Offset}] +                .ConstCSInfo[VC.Args] +                .SummaryTypeCheckedLoadUsers.push_back(FS); +          } +        } +      } +    } +  } +    // For each (type, offset) pair:    bool DidVirtualConstProp = false;    std::map<std::string, Function*> DevirtTargets; @@ -894,19 +1367,39 @@ bool DevirtModule::run() {      // function implementation at offset S.first.ByteOffset, and add to      // TargetsForSlot.      std::vector<VirtualCallTarget> TargetsForSlot; -    if (!tryFindVirtualCallTargets(TargetsForSlot, TypeIdMap[S.first.TypeID], -                                   S.first.ByteOffset)) -      continue; - -    if (!trySingleImplDevirt(TargetsForSlot, S.second) && -        tryVirtualConstProp(TargetsForSlot, S.second)) +    if (tryFindVirtualCallTargets(TargetsForSlot, TypeIdMap[S.first.TypeID], +                                  S.first.ByteOffset)) { +      WholeProgramDevirtResolution *Res = nullptr; +      if (ExportSummary && isa<MDString>(S.first.TypeID)) +        Res = &ExportSummary +                   ->getOrInsertTypeIdSummary( +                       cast<MDString>(S.first.TypeID)->getString()) +                   .WPDRes[S.first.ByteOffset]; + +      if (!trySingleImplDevirt(TargetsForSlot, S.second, Res) && +          tryVirtualConstProp(TargetsForSlot, S.second, Res, S.first))          DidVirtualConstProp = true; -    // Collect functions devirtualized at least for one call site for stats. -    if (RemarksEnabled) -      for (const auto &T : TargetsForSlot) -        if (T.WasDevirt) -          DevirtTargets[T.Fn->getName()] = T.Fn; +      // Collect functions devirtualized at least for one call site for stats. 
+      if (RemarksEnabled) +        for (const auto &T : TargetsForSlot) +          if (T.WasDevirt) +            DevirtTargets[T.Fn->getName()] = T.Fn; +    } + +    // CFI-specific: if we are exporting and any llvm.type.checked.load +    // intrinsics were *not* devirtualized, we need to add the resulting +    // llvm.type.test intrinsics to the function summaries so that the +    // LowerTypeTests pass will export them. +    if (ExportSummary && isa<MDString>(S.first.TypeID)) { +      auto GUID = +          GlobalValue::getGUID(cast<MDString>(S.first.TypeID)->getString()); +      for (auto FS : S.second.CSInfo.SummaryTypeCheckedLoadUsers) +        FS->addTypeTest(GUID); +      for (auto &CCS : S.second.ConstCSInfo) +        for (auto FS : CCS.second.SummaryTypeCheckedLoadUsers) +          FS->addTypeTest(GUID); +    }    }    if (RemarksEnabled) { @@ -914,23 +1407,12 @@ bool DevirtModule::run() {      for (const auto &DT : DevirtTargets) {        Function *F = DT.second;        DISubprogram *SP = F->getSubprogram(); -      DebugLoc DL = SP ? DebugLoc::get(SP->getScopeLine(), 0, SP) : DebugLoc(); -      emitOptimizationRemark(F->getContext(), DEBUG_TYPE, *F, DL, +      emitOptimizationRemark(F->getContext(), DEBUG_TYPE, *F, SP,                               Twine("devirtualized ") + F->getName());      }    } -  // If we were able to eliminate all unsafe uses for a type checked load, -  // eliminate the type test by replacing it with true. -  if (TypeCheckedLoadFunc) { -    auto True = ConstantInt::getTrue(M.getContext()); -    for (auto &&U : NumUnsafeUsesForTypeTest) { -      if (U.second == 0) { -        U.first->replaceAllUsesWith(True); -        U.first->eraseFromParent(); -      } -    } -  } +  removeRedundantTypeTests();    // Rebuild each global we touched as part of virtual constant propagation to    // include the before and after bytes. diff --git a/lib/Transforms/InstCombine/InstCombineAddSub.cpp b/lib/Transforms/InstCombine/InstCombineAddSub.cpp index 2d34c1cc74bd..174ec8036274 100644 --- a/lib/Transforms/InstCombine/InstCombineAddSub.cpp +++ b/lib/Transforms/InstCombine/InstCombineAddSub.cpp @@ -902,7 +902,7 @@ bool InstCombiner::WillNotOverflowSignedAdd(Value *LHS, Value *RHS,    APInt RHSKnownOne(BitWidth, 0);    computeKnownBits(RHS, RHSKnownZero, RHSKnownOne, 0, &CxtI); -  // Addition of two 2's compliment numbers having opposite signs will never +  // Addition of two 2's complement numbers having opposite signs will never    // overflow.    if ((LHSKnownOne[BitWidth - 1] && RHSKnownZero[BitWidth - 1]) ||        (LHSKnownZero[BitWidth - 1] && RHSKnownOne[BitWidth - 1])) @@ -939,7 +939,7 @@ bool InstCombiner::WillNotOverflowSignedSub(Value *LHS, Value *RHS,    APInt RHSKnownOne(BitWidth, 0);    computeKnownBits(RHS, RHSKnownZero, RHSKnownOne, 0, &CxtI); -  // Subtraction of two 2's compliment numbers having identical signs will +  // Subtraction of two 2's complement numbers having identical signs will    // never overflow.    
if ((LHSKnownOne[BitWidth - 1] && RHSKnownOne[BitWidth - 1]) ||        (LHSKnownZero[BitWidth - 1] && RHSKnownZero[BitWidth - 1])) @@ -1042,43 +1042,42 @@ Instruction *InstCombiner::visitAdd(BinaryOperator &I) {    if (Value *V = SimplifyUsingDistributiveLaws(I))      return replaceInstUsesWith(I, V); -  const APInt *Val; -  if (match(RHS, m_APInt(Val))) { -    // X + (signbit) --> X ^ signbit -    if (Val->isSignBit()) +  const APInt *RHSC; +  if (match(RHS, m_APInt(RHSC))) { +    if (RHSC->isSignBit()) { +      // If wrapping is not allowed, then the addition must set the sign bit: +      // X + (signbit) --> X | signbit +      if (I.hasNoSignedWrap() || I.hasNoUnsignedWrap()) +        return BinaryOperator::CreateOr(LHS, RHS); + +      // If wrapping is allowed, then the addition flips the sign bit of LHS: +      // X + (signbit) --> X ^ signbit        return BinaryOperator::CreateXor(LHS, RHS); +    }      // Is this add the last step in a convoluted sext?      Value *X;      const APInt *C;      if (match(LHS, m_ZExt(m_Xor(m_Value(X), m_APInt(C)))) &&          C->isMinSignedValue() && -        C->sext(LHS->getType()->getScalarSizeInBits()) == *Val) { +        C->sext(LHS->getType()->getScalarSizeInBits()) == *RHSC) {        // add(zext(xor i16 X, -32768), -32768) --> sext X        return CastInst::Create(Instruction::SExt, X, LHS->getType());      } -    if (Val->isNegative() && +    if (RHSC->isNegative() &&          match(LHS, m_ZExt(m_NUWAdd(m_Value(X), m_APInt(C)))) && -        Val->sge(-C->sext(Val->getBitWidth()))) { +        RHSC->sge(-C->sext(RHSC->getBitWidth()))) {        // (add (zext (add nuw X, C)), Val) -> (zext (add nuw X, C+Val)) -      return CastInst::Create( -          Instruction::ZExt, -          Builder->CreateNUWAdd( -              X, Constant::getIntegerValue(X->getType(), -                                           *C + Val->trunc(C->getBitWidth()))), -          I.getType()); +      Constant *NewC = +          ConstantInt::get(X->getType(), *C + RHSC->trunc(C->getBitWidth())); +      return new ZExtInst(Builder->CreateNUWAdd(X, NewC), I.getType());      }    }    // FIXME: Use the match above instead of dyn_cast to allow these transforms    // for splat vectors.    if (ConstantInt *CI = dyn_cast<ConstantInt>(RHS)) { -    // See if SimplifyDemandedBits can simplify this.  This handles stuff like -    // (X & 254)+1 -> (X&254)|1 -    if (SimplifyDemandedInstructionBits(I)) -      return &I; -      // zext(bool) + C -> bool ? C + 1 : C      if (ZExtInst *ZI = dyn_cast<ZExtInst>(LHS))        if (ZI->getSrcTy()->isIntegerTy(1)) @@ -1129,8 +1128,8 @@ Instruction *InstCombiner::visitAdd(BinaryOperator &I) {      }    } -  if (isa<Constant>(RHS) && isa<PHINode>(LHS)) -    if (Instruction *NV = FoldOpIntoPhi(I)) +  if (isa<Constant>(RHS)) +    if (Instruction *NV = foldOpWithConstantIntoOperand(I))        return NV;    if (I.getType()->getScalarType()->isIntegerTy(1)) @@ -1201,11 +1200,6 @@ Instruction *InstCombiner::visitAdd(BinaryOperator &I) {          return BinaryOperator::CreateAnd(NewAdd, C2);        }      } - -    // Try to fold constant add into select arguments. 
-    if (SelectInst *SI = dyn_cast<SelectInst>(LHS)) -      if (Instruction *R = FoldOpIntoSelect(I, SI)) -        return R;    }    // add (select X 0 (sub n A)) A  -->  select X A n @@ -1253,7 +1247,7 @@ Instruction *InstCombiner::visitAdd(BinaryOperator &I) {      // (add (sext x), (sext y)) --> (sext (add int x, y))      if (SExtInst *RHSConv = dyn_cast<SExtInst>(RHS)) { -      // Only do this if x/y have the same type, if at last one of them has a +      // Only do this if x/y have the same type, if at least one of them has a        // single use (so we don't increase the number of sexts), and if the        // integer add will not overflow.        if (LHSConv->getOperand(0)->getType() == @@ -1290,7 +1284,7 @@ Instruction *InstCombiner::visitAdd(BinaryOperator &I) {      // (add (zext x), (zext y)) --> (zext (add int x, y))      if (auto *RHSConv = dyn_cast<ZExtInst>(RHS)) { -      // Only do this if x/y have the same type, if at last one of them has a +      // Only do this if x/y have the same type, if at least one of them has a        // single use (so we don't increase the number of zexts), and if the        // integer add will not overflow.        if (LHSConv->getOperand(0)->getType() == @@ -1311,13 +1305,11 @@ Instruction *InstCombiner::visitAdd(BinaryOperator &I) {    {      Value *A = nullptr, *B = nullptr;      if (match(RHS, m_Xor(m_Value(A), m_Value(B))) && -        (match(LHS, m_And(m_Specific(A), m_Specific(B))) || -         match(LHS, m_And(m_Specific(B), m_Specific(A))))) +        match(LHS, m_c_And(m_Specific(A), m_Specific(B))))        return BinaryOperator::CreateOr(A, B);      if (match(LHS, m_Xor(m_Value(A), m_Value(B))) && -        (match(RHS, m_And(m_Specific(A), m_Specific(B))) || -         match(RHS, m_And(m_Specific(B), m_Specific(A))))) +        match(RHS, m_c_And(m_Specific(A), m_Specific(B))))        return BinaryOperator::CreateOr(A, B);    } @@ -1325,8 +1317,7 @@ Instruction *InstCombiner::visitAdd(BinaryOperator &I) {    {      Value *A = nullptr, *B = nullptr;      if (match(RHS, m_Or(m_Value(A), m_Value(B))) && -        (match(LHS, m_And(m_Specific(A), m_Specific(B))) || -         match(LHS, m_And(m_Specific(B), m_Specific(A))))) { +        match(LHS, m_c_And(m_Specific(A), m_Specific(B)))) {        auto *New = BinaryOperator::CreateAdd(A, B);        New->setHasNoSignedWrap(I.hasNoSignedWrap());        New->setHasNoUnsignedWrap(I.hasNoUnsignedWrap()); @@ -1334,8 +1325,7 @@ Instruction *InstCombiner::visitAdd(BinaryOperator &I) {      }      if (match(LHS, m_Or(m_Value(A), m_Value(B))) && -        (match(RHS, m_And(m_Specific(A), m_Specific(B))) || -         match(RHS, m_And(m_Specific(B), m_Specific(A))))) { +        match(RHS, m_c_And(m_Specific(A), m_Specific(B)))) {        auto *New = BinaryOperator::CreateAdd(A, B);        New->setHasNoSignedWrap(I.hasNoSignedWrap());        New->setHasNoUnsignedWrap(I.hasNoUnsignedWrap()); @@ -1394,6 +1384,8 @@ Instruction *InstCombiner::visitFAdd(BinaryOperator &I) {    // Check for (fadd double (sitofp x), y), see if we can merge this into an    // integer add followed by a promotion.    if (SIToFPInst *LHSConv = dyn_cast<SIToFPInst>(LHS)) { +    Value *LHSIntVal = LHSConv->getOperand(0); +      // (fadd double (sitofp x), fpcst) --> (sitofp (add int x, intcst))      // ... if the constant fits in the integer value.  
This is useful for things      // like (double)(x & 1234) + 4.0 -> (double)((X & 1234)+4) which no longer @@ -1401,12 +1393,12 @@ Instruction *InstCombiner::visitFAdd(BinaryOperator &I) {      // instcombined.      if (ConstantFP *CFP = dyn_cast<ConstantFP>(RHS)) {        Constant *CI = -      ConstantExpr::getFPToSI(CFP, LHSConv->getOperand(0)->getType()); +      ConstantExpr::getFPToSI(CFP, LHSIntVal->getType());        if (LHSConv->hasOneUse() &&            ConstantExpr::getSIToFP(CI, I.getType()) == CFP && -          WillNotOverflowSignedAdd(LHSConv->getOperand(0), CI, I)) { +          WillNotOverflowSignedAdd(LHSIntVal, CI, I)) {          // Insert the new integer add. -        Value *NewAdd = Builder->CreateNSWAdd(LHSConv->getOperand(0), +        Value *NewAdd = Builder->CreateNSWAdd(LHSIntVal,                                                CI, "addconv");          return new SIToFPInst(NewAdd, I.getType());        } @@ -1414,17 +1406,17 @@ Instruction *InstCombiner::visitFAdd(BinaryOperator &I) {      // (fadd double (sitofp x), (sitofp y)) --> (sitofp (add int x, y))      if (SIToFPInst *RHSConv = dyn_cast<SIToFPInst>(RHS)) { -      // Only do this if x/y have the same type, if at last one of them has a +      Value *RHSIntVal = RHSConv->getOperand(0); + +      // Only do this if x/y have the same type, if at least one of them has a        // single use (so we don't increase the number of int->fp conversions),        // and if the integer add will not overflow. -      if (LHSConv->getOperand(0)->getType() == -              RHSConv->getOperand(0)->getType() && +      if (LHSIntVal->getType() == RHSIntVal->getType() &&            (LHSConv->hasOneUse() || RHSConv->hasOneUse()) && -          WillNotOverflowSignedAdd(LHSConv->getOperand(0), -                                   RHSConv->getOperand(0), I)) { +          WillNotOverflowSignedAdd(LHSIntVal, RHSIntVal, I)) {          // Insert the new integer add. -        Value *NewAdd = Builder->CreateNSWAdd(LHSConv->getOperand(0), -                                              RHSConv->getOperand(0),"addconv"); +        Value *NewAdd = Builder->CreateNSWAdd(LHSIntVal, +                                              RHSIntVal, "addconv");          return new SIToFPInst(NewAdd, I.getType());        }      } @@ -1562,7 +1554,7 @@ Instruction *InstCombiner::visitSub(BinaryOperator &I) {      return Res;    } -  if (I.getType()->isIntegerTy(1)) +  if (I.getType()->getScalarType()->isIntegerTy(1))      return BinaryOperator::CreateXor(Op0, Op1);    // Replace (-1 - A) with (~A). @@ -1580,14 +1572,16 @@ Instruction *InstCombiner::visitSub(BinaryOperator &I) {        if (Instruction *R = FoldOpIntoSelect(I, SI))          return R; +    // Try to fold constant sub into PHI values. +    if (PHINode *PN = dyn_cast<PHINode>(Op1)) +      if (Instruction *R = foldOpIntoPhi(I, PN)) +        return R; +      // C-(X+C2) --> (C-C2)-X      Constant *C2;      if (match(Op1, m_Add(m_Value(X), m_Constant(C2))))        return BinaryOperator::CreateSub(ConstantExpr::getSub(C, C2), X); -    if (SimplifyDemandedInstructionBits(I)) -      return &I; -      // Fold (sub 0, (zext bool to B)) --> (sext bool to B)      if (C->isNullValue() && match(Op1, m_ZExt(m_Value(X))))        if (X->getType()->getScalarType()->isIntegerTy(1)) @@ -1622,11 +1616,11 @@ Instruction *InstCombiner::visitSub(BinaryOperator &I) {      // Turn this into a xor if LHS is 2^n-1 and the remaining bits are known      // zero. 
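The (A & B) + (A ^ B) and (A | B) + (A & B) rewrites in this hunk rest on standard two's-complement identities, which a brief exhaustive check over 8-bit values confirms; checkAddIdentities is an illustrative helper, not part of the patch.

    #include <cstdint>

    bool checkAddIdentities() {
      for (unsigned A = 0; A < 256; ++A)
        for (unsigned B = 0; B < 256; ++B) {
          uint8_t a = static_cast<uint8_t>(A), b = static_cast<uint8_t>(B);
          // (A & B) + (A ^ B) == (A | B): the two addends have disjoint bits.
          if (static_cast<uint8_t>((a & b) + (a ^ b)) != static_cast<uint8_t>(a | b))
            return false;
          // (A | B) + (A & B) == A + B, modulo the bit width.
          if (static_cast<uint8_t>((a | b) + (a & b)) != static_cast<uint8_t>(a + b))
            return false;
        }
      return true;
    }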
-    if ((*Op0C + 1).isPowerOf2()) { -      APInt KnownZero(BitWidth, 0); -      APInt KnownOne(BitWidth, 0); -      computeKnownBits(&I, KnownZero, KnownOne, 0, &I); -      if ((*Op0C | KnownZero).isAllOnesValue()) +    if (Op0C->isMask()) { +      APInt RHSKnownZero(BitWidth, 0); +      APInt RHSKnownOne(BitWidth, 0); +      computeKnownBits(Op1, RHSKnownZero, RHSKnownOne, 0, &I); +      if ((*Op0C | RHSKnownZero).isAllOnesValue())          return BinaryOperator::CreateXor(Op1, Op0);      }    } @@ -1634,8 +1628,7 @@ Instruction *InstCombiner::visitSub(BinaryOperator &I) {    {      Value *Y;      // X-(X+Y) == -Y    X-(Y+X) == -Y -    if (match(Op1, m_Add(m_Specific(Op0), m_Value(Y))) || -        match(Op1, m_Add(m_Value(Y), m_Specific(Op0)))) +    if (match(Op1, m_c_Add(m_Specific(Op0), m_Value(Y))))        return BinaryOperator::CreateNeg(Y);      // (X-Y)-X == -Y @@ -1645,18 +1638,16 @@ Instruction *InstCombiner::visitSub(BinaryOperator &I) {    // (sub (or A, B) (xor A, B)) --> (and A, B)    { -    Value *A = nullptr, *B = nullptr; +    Value *A, *B;      if (match(Op1, m_Xor(m_Value(A), m_Value(B))) && -        (match(Op0, m_Or(m_Specific(A), m_Specific(B))) || -         match(Op0, m_Or(m_Specific(B), m_Specific(A))))) +        match(Op0, m_c_Or(m_Specific(A), m_Specific(B))))        return BinaryOperator::CreateAnd(A, B);    } -  if (Op0->hasOneUse()) { -    Value *Y = nullptr; +  { +    Value *Y;      // ((X | Y) - X) --> (~X & Y) -    if (match(Op0, m_Or(m_Value(Y), m_Specific(Op1))) || -        match(Op0, m_Or(m_Specific(Op1), m_Value(Y)))) +    if (match(Op0, m_OneUse(m_c_Or(m_Value(Y), m_Specific(Op1)))))        return BinaryOperator::CreateAnd(            Y, Builder->CreateNot(Op1, Op1->getName() + ".not"));    } @@ -1664,7 +1655,6 @@ Instruction *InstCombiner::visitSub(BinaryOperator &I) {    if (Op1->hasOneUse()) {      Value *X = nullptr, *Y = nullptr, *Z = nullptr;      Constant *C = nullptr; -    Constant *CI = nullptr;      // (X - (Y - Z))  -->  (X + (Z - Y)).      if (match(Op1, m_Sub(m_Value(Y), m_Value(Z)))) @@ -1673,8 +1663,7 @@ Instruction *InstCombiner::visitSub(BinaryOperator &I) {      // (X - (X & Y))   -->   (X & ~Y)      // -    if (match(Op1, m_And(m_Value(Y), m_Specific(Op0))) || -        match(Op1, m_And(m_Specific(Op0), m_Value(Y)))) +    if (match(Op1, m_c_And(m_Value(Y), m_Specific(Op0))))        return BinaryOperator::CreateAnd(Op0,                                    Builder->CreateNot(Y, Y->getName() + ".not")); @@ -1702,14 +1691,14 @@ Instruction *InstCombiner::visitSub(BinaryOperator &I) {      // X - A*-B -> X + A*B      // X - -A*B -> X + A*B      Value *A, *B; -    if (match(Op1, m_Mul(m_Value(A), m_Neg(m_Value(B)))) || -        match(Op1, m_Mul(m_Neg(m_Value(A)), m_Value(B)))) +    Constant *CI; +    if (match(Op1, m_c_Mul(m_Value(A), m_Neg(m_Value(B)))))        return BinaryOperator::CreateAdd(Op0, Builder->CreateMul(A, B));      // X - A*CI -> X + A*-CI -    // X - CI*A -> X + A*-CI -    if (match(Op1, m_Mul(m_Value(A), m_Constant(CI))) || -        match(Op1, m_Mul(m_Constant(CI), m_Value(A)))) { +    // No need to handle commuted multiply because multiply handling will +    // ensure constant will be move to the right hand side. 
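The subtraction folds in this hunk follow from the same kind of bit-level reasoning; a short exhaustive check over 8-bit values covers the ones rewritten above, namely (A | B) - (A ^ B) == A & B, (X | Y) - X == ~X & Y, and X - (X & Y) == X & ~Y. checkSubIdentities is an illustrative helper only.

    #include <cstdint>

    bool checkSubIdentities() {
      for (unsigned X = 0; X < 256; ++X)
        for (unsigned Y = 0; Y < 256; ++Y) {
          uint8_t x = static_cast<uint8_t>(X), y = static_cast<uint8_t>(Y);
          if (static_cast<uint8_t>((x | y) - (x ^ y)) != static_cast<uint8_t>(x & y))
            return false;
          if (static_cast<uint8_t>((x | y) - x) != static_cast<uint8_t>(~x & y))
            return false;
          if (static_cast<uint8_t>(x - (x & y)) != static_cast<uint8_t>(x & ~y))
            return false;
        }
      return true;
    }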
+    if (match(Op1, m_Mul(m_Value(A), m_Constant(CI)))) {        Value *NewMul = Builder->CreateMul(A, ConstantExpr::getNeg(CI));        return BinaryOperator::CreateAdd(Op0, NewMul);      } diff --git a/lib/Transforms/InstCombine/InstCombineAndOrXor.cpp b/lib/Transforms/InstCombine/InstCombineAndOrXor.cpp index da5384a86aac..b2a41c699202 100644 --- a/lib/Transforms/InstCombine/InstCombineAndOrXor.cpp +++ b/lib/Transforms/InstCombine/InstCombineAndOrXor.cpp @@ -137,9 +137,8 @@ Value *InstCombiner::SimplifyBSwap(BinaryOperator &I) {  }  /// This handles expressions of the form ((val OP C1) & C2).  Where -/// the Op parameter is 'OP', OpRHS is 'C1', and AndRHS is 'C2'.  Op is -/// guaranteed to be a binary operator. -Instruction *InstCombiner::OptAndOp(Instruction *Op, +/// the Op parameter is 'OP', OpRHS is 'C1', and AndRHS is 'C2'. +Instruction *InstCombiner::OptAndOp(BinaryOperator *Op,                                      ConstantInt *OpRHS,                                      ConstantInt *AndRHS,                                      BinaryOperator &TheAnd) { @@ -149,6 +148,7 @@ Instruction *InstCombiner::OptAndOp(Instruction *Op,      Together = ConstantExpr::getAnd(AndRHS, OpRHS);    switch (Op->getOpcode()) { +  default: break;    case Instruction::Xor:      if (Op->hasOneUse()) {        // (X ^ C1) & C2 --> (X & C2) ^ (C1&C2) @@ -159,13 +159,6 @@ Instruction *InstCombiner::OptAndOp(Instruction *Op,      break;    case Instruction::Or:      if (Op->hasOneUse()){ -      if (Together != OpRHS) { -        // (X | C1) & C2 --> (X | (C1&C2)) & C2 -        Value *Or = Builder->CreateOr(X, Together); -        Or->takeName(Op); -        return BinaryOperator::CreateAnd(Or, AndRHS); -      } -        ConstantInt *TogetherCI = dyn_cast<ConstantInt>(Together);        if (TogetherCI && !TogetherCI->isZero()){          // (X | C1) & C2 --> (X & (C2^(C1&C2))) | C1 @@ -302,178 +295,91 @@ Value *InstCombiner::insertRangeTest(Value *V, const APInt &Lo, const APInt &Hi,    return Builder->CreateICmp(Pred, VMinusLo, HiMinusLo);  } -/// Returns true iff Val consists of one contiguous run of 1s with any number -/// of 0s on either side.  The 1s are allowed to wrap from LSB to MSB, -/// so 0x000FFF0, 0x0000FFFF, and 0xFF0000FF are all runs.  0x0F0F0000 is -/// not, since all 1s are not contiguous. -static bool isRunOfOnes(ConstantInt *Val, uint32_t &MB, uint32_t &ME) { -  const APInt& V = Val->getValue(); -  uint32_t BitWidth = Val->getType()->getBitWidth(); -  if (!APIntOps::isShiftedMask(BitWidth, V)) return false; - -  // look for the first zero bit after the run of ones -  MB = BitWidth - ((V - 1) ^ V).countLeadingZeros(); -  // look for the first non-zero bit -  ME = V.getActiveBits(); -  return true; -} - -/// This is part of an expression (LHS +/- RHS) & Mask, where isSub determines -/// whether the operator is a sub. If we can fold one of the following xforms: +/// Classify (icmp eq (A & B), C) and (icmp ne (A & B), C) as matching patterns +/// that can be simplified. +/// One of A and B is considered the mask. The other is the value. This is +/// described as the "AMask" or "BMask" part of the enum. If the enum contains +/// only "Mask", then both A and B can be considered masks. If A is the mask, +/// then it was proven that (A & C) == C. This is trivial if C == A or C == 0. +/// If both A and C are constants, this proof is also easy. +/// For the following explanations, we assume that A is the mask.  
/// -/// ((A & N) +/- B) & Mask -> (A +/- B) & Mask iff N&Mask == Mask -/// ((A | N) +/- B) & Mask -> (A +/- B) & Mask iff N&Mask == 0 -/// ((A ^ N) +/- B) & Mask -> (A +/- B) & Mask iff N&Mask == 0 +/// "AllOnes" declares that the comparison is true only if (A & B) == A or all +/// bits of A are set in B. +///   Example: (icmp eq (A & 3), 3) -> AMask_AllOnes  /// -/// return (A +/- B). +/// "AllZeros" declares that the comparison is true only if (A & B) == 0 or all +/// bits of A are cleared in B. +///   Example: (icmp eq (A & 3), 0) -> Mask_AllZeroes +/// +/// "Mixed" declares that (A & B) == C and C might or might not contain any +/// number of one bits and zero bits. +///   Example: (icmp eq (A & 3), 1) -> AMask_Mixed +/// +/// "Not" means that in above descriptions "==" should be replaced by "!=". +///   Example: (icmp ne (A & 3), 3) -> AMask_NotAllOnes  /// -Value *InstCombiner::FoldLogicalPlusAnd(Value *LHS, Value *RHS, -                                        ConstantInt *Mask, bool isSub, -                                        Instruction &I) { -  Instruction *LHSI = dyn_cast<Instruction>(LHS); -  if (!LHSI || LHSI->getNumOperands() != 2 || -      !isa<ConstantInt>(LHSI->getOperand(1))) return nullptr; - -  ConstantInt *N = cast<ConstantInt>(LHSI->getOperand(1)); - -  switch (LHSI->getOpcode()) { -  default: return nullptr; -  case Instruction::And: -    if (ConstantExpr::getAnd(N, Mask) == Mask) { -      // If the AndRHS is a power of two minus one (0+1+), this is simple. -      if ((Mask->getValue().countLeadingZeros() + -           Mask->getValue().countPopulation()) == -          Mask->getValue().getBitWidth()) -        break; - -      // Otherwise, if Mask is 0+1+0+, and if B is known to have the low 0+ -      // part, we don't need any explicit masks to take them out of A.  If that -      // is all N is, ignore it. -      uint32_t MB = 0, ME = 0; -      if (isRunOfOnes(Mask, MB, ME)) {  // begin/end bit of run, inclusive -        uint32_t BitWidth = cast<IntegerType>(RHS->getType())->getBitWidth(); -        APInt Mask(APInt::getLowBitsSet(BitWidth, MB-1)); -        if (MaskedValueIsZero(RHS, Mask, 0, &I)) -          break; -      } -    } -    return nullptr; -  case Instruction::Or: -  case Instruction::Xor: -    // If the AndRHS is a power of two minus one (0+1+), and N&Mask == 0 -    if ((Mask->getValue().countLeadingZeros() + -         Mask->getValue().countPopulation()) == Mask->getValue().getBitWidth() -        && ConstantExpr::getAnd(N, Mask)->isNullValue()) -      break; -    return nullptr; -  } - -  if (isSub) -    return Builder->CreateSub(LHSI->getOperand(0), RHS, "fold"); -  return Builder->CreateAdd(LHSI->getOperand(0), RHS, "fold"); -} - -/// enum for classifying (icmp eq (A & B), C) and (icmp ne (A & B), C) -/// One of A and B is considered the mask, the other the value. This is -/// described as the "AMask" or "BMask" part of the enum. If the enum -/// contains only "Mask", then both A and B can be considered masks. -/// If A is the mask, then it was proven, that (A & C) == C. This -/// is trivial if C == A, or C == 0. If both A and C are constants, this -/// proof is also easy. -/// For the following explanations we assume that A is the mask. -/// The part "AllOnes" declares, that the comparison is true only -/// if (A & B) == A, or all bits of A are set in B. -///   Example: (icmp eq (A & 3), 3) -> FoldMskICmp_AMask_AllOnes -/// The part "AllZeroes" declares, that the comparison is true only -/// if (A & B) == 0, or all bits of A are cleared in B. 
-///   Example: (icmp eq (A & 3), 0) -> FoldMskICmp_Mask_AllZeroes -/// The part "Mixed" declares, that (A & B) == C and C might or might not -/// contain any number of one bits and zero bits. -///   Example: (icmp eq (A & 3), 1) -> FoldMskICmp_AMask_Mixed -/// The Part "Not" means, that in above descriptions "==" should be replaced -/// by "!=". -///   Example: (icmp ne (A & 3), 3) -> FoldMskICmp_AMask_NotAllOnes  /// If the mask A contains a single bit, then the following is equivalent:  ///    (icmp eq (A & B), A) equals (icmp ne (A & B), 0)  ///    (icmp ne (A & B), A) equals (icmp eq (A & B), 0)  enum MaskedICmpType { -  FoldMskICmp_AMask_AllOnes           =     1, -  FoldMskICmp_AMask_NotAllOnes        =     2, -  FoldMskICmp_BMask_AllOnes           =     4, -  FoldMskICmp_BMask_NotAllOnes        =     8, -  FoldMskICmp_Mask_AllZeroes          =    16, -  FoldMskICmp_Mask_NotAllZeroes       =    32, -  FoldMskICmp_AMask_Mixed             =    64, -  FoldMskICmp_AMask_NotMixed          =   128, -  FoldMskICmp_BMask_Mixed             =   256, -  FoldMskICmp_BMask_NotMixed          =   512 +  AMask_AllOnes           =     1, +  AMask_NotAllOnes        =     2, +  BMask_AllOnes           =     4, +  BMask_NotAllOnes        =     8, +  Mask_AllZeros           =    16, +  Mask_NotAllZeros        =    32, +  AMask_Mixed             =    64, +  AMask_NotMixed          =   128, +  BMask_Mixed             =   256, +  BMask_NotMixed          =   512  }; -/// Return the set of pattern classes (from MaskedICmpType) -/// that (icmp SCC (A & B), C) satisfies. -static unsigned getTypeOfMaskedICmp(Value* A, Value* B, Value* C, -                                    ICmpInst::Predicate SCC) -{ +/// Return the set of patterns (from MaskedICmpType) that (icmp SCC (A & B), C) +/// satisfies. +static unsigned getMaskedICmpType(Value *A, Value *B, Value *C, +                                  ICmpInst::Predicate Pred) {    ConstantInt *ACst = dyn_cast<ConstantInt>(A);    ConstantInt *BCst = dyn_cast<ConstantInt>(B);    ConstantInt *CCst = dyn_cast<ConstantInt>(C); -  bool icmp_eq = (SCC == ICmpInst::ICMP_EQ); -  bool icmp_abit = (ACst && !ACst->isZero() && -                    ACst->getValue().isPowerOf2()); -  bool icmp_bbit = (BCst && !BCst->isZero() && -                    BCst->getValue().isPowerOf2()); -  unsigned result = 0; +  bool IsEq = (Pred == ICmpInst::ICMP_EQ); +  bool IsAPow2 = (ACst && !ACst->isZero() && ACst->getValue().isPowerOf2()); +  bool IsBPow2 = (BCst && !BCst->isZero() && BCst->getValue().isPowerOf2()); +  unsigned MaskVal = 0;    if (CCst && CCst->isZero()) {      // if C is zero, then both A and B qualify as mask -    result |= (icmp_eq ? (FoldMskICmp_Mask_AllZeroes | -                          FoldMskICmp_AMask_Mixed | -                          FoldMskICmp_BMask_Mixed) -                       : (FoldMskICmp_Mask_NotAllZeroes | -                          FoldMskICmp_AMask_NotMixed | -                          FoldMskICmp_BMask_NotMixed)); -    if (icmp_abit) -      result |= (icmp_eq ? (FoldMskICmp_AMask_NotAllOnes | -                            FoldMskICmp_AMask_NotMixed) -                         : (FoldMskICmp_AMask_AllOnes | -                            FoldMskICmp_AMask_Mixed)); -    if (icmp_bbit) -      result |= (icmp_eq ? (FoldMskICmp_BMask_NotAllOnes | -                            FoldMskICmp_BMask_NotMixed) -                         : (FoldMskICmp_BMask_AllOnes | -                            FoldMskICmp_BMask_Mixed)); -    return result; +    MaskVal |= (IsEq ? 
(Mask_AllZeros | AMask_Mixed | BMask_Mixed) +                     : (Mask_NotAllZeros | AMask_NotMixed | BMask_NotMixed)); +    if (IsAPow2) +      MaskVal |= (IsEq ? (AMask_NotAllOnes | AMask_NotMixed) +                       : (AMask_AllOnes | AMask_Mixed)); +    if (IsBPow2) +      MaskVal |= (IsEq ? (BMask_NotAllOnes | BMask_NotMixed) +                       : (BMask_AllOnes | BMask_Mixed)); +    return MaskVal;    } +    if (A == C) { -    result |= (icmp_eq ? (FoldMskICmp_AMask_AllOnes | -                          FoldMskICmp_AMask_Mixed) -                       : (FoldMskICmp_AMask_NotAllOnes | -                          FoldMskICmp_AMask_NotMixed)); -    if (icmp_abit) -      result |= (icmp_eq ? (FoldMskICmp_Mask_NotAllZeroes | -                            FoldMskICmp_AMask_NotMixed) -                         : (FoldMskICmp_Mask_AllZeroes | -                            FoldMskICmp_AMask_Mixed)); -  } else if (ACst && CCst && -             ConstantExpr::getAnd(ACst, CCst) == CCst) { -    result |= (icmp_eq ? FoldMskICmp_AMask_Mixed -                       : FoldMskICmp_AMask_NotMixed); +    MaskVal |= (IsEq ? (AMask_AllOnes | AMask_Mixed) +                     : (AMask_NotAllOnes | AMask_NotMixed)); +    if (IsAPow2) +      MaskVal |= (IsEq ? (Mask_NotAllZeros | AMask_NotMixed) +                       : (Mask_AllZeros | AMask_Mixed)); +  } else if (ACst && CCst && ConstantExpr::getAnd(ACst, CCst) == CCst) { +    MaskVal |= (IsEq ? AMask_Mixed : AMask_NotMixed);    } +    if (B == C) { -    result |= (icmp_eq ? (FoldMskICmp_BMask_AllOnes | -                          FoldMskICmp_BMask_Mixed) -                       : (FoldMskICmp_BMask_NotAllOnes | -                          FoldMskICmp_BMask_NotMixed)); -    if (icmp_bbit) -      result |= (icmp_eq ? (FoldMskICmp_Mask_NotAllZeroes | -                            FoldMskICmp_BMask_NotMixed) -                         : (FoldMskICmp_Mask_AllZeroes | -                            FoldMskICmp_BMask_Mixed)); -  } else if (BCst && CCst && -             ConstantExpr::getAnd(BCst, CCst) == CCst) { -    result |= (icmp_eq ? FoldMskICmp_BMask_Mixed -                       : FoldMskICmp_BMask_NotMixed); -  } -  return result; +    MaskVal |= (IsEq ? (BMask_AllOnes | BMask_Mixed) +                     : (BMask_NotAllOnes | BMask_NotMixed)); +    if (IsBPow2) +      MaskVal |= (IsEq ? (Mask_NotAllZeros | BMask_NotMixed) +                       : (Mask_AllZeros | BMask_Mixed)); +  } else if (BCst && CCst && ConstantExpr::getAnd(BCst, CCst) == CCst) { +    MaskVal |= (IsEq ? BMask_Mixed : BMask_NotMixed); +  } + +  return MaskVal;  }  /// Convert an analysis of a masked ICmp into its equivalent if all boolean @@ -482,32 +388,30 @@ static unsigned getTypeOfMaskedICmp(Value* A, Value* B, Value* C,  /// involves swapping those bits over.  
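Editorial aside, not part of the patch: the conjugation implemented just below works because each "Not" flag sits exactly one bit to the left of its positive counterpart, so a masked shift swaps the two groups. A reduced sketch using a subset of the flag values from the enum above (the helper name conjugate is made up for this illustration):

  #include <cassert>

  enum { AMask_AllOnes = 1, AMask_NotAllOnes = 2,
         Mask_AllZeros = 16, Mask_NotAllZeros = 32 };

  unsigned conjugate(unsigned Mask) {
    unsigned Pos = Mask & (AMask_AllOnes | Mask_AllZeros);       // "==" forms
    unsigned Neg = Mask & (AMask_NotAllOnes | Mask_NotAllZeros); // "!=" forms
    return (Pos << 1) | (Neg >> 1);
  }

  int main() {
    assert(conjugate(AMask_AllOnes) == AMask_NotAllOnes);
    assert(conjugate(Mask_NotAllZeros) == Mask_AllZeros);
    assert(conjugate(AMask_AllOnes | Mask_NotAllZeros) ==
           (AMask_NotAllOnes | Mask_AllZeros));
    return 0;
  }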
static unsigned conjugateICmpMask(unsigned Mask) {    unsigned NewMask; -  NewMask = (Mask & (FoldMskICmp_AMask_AllOnes | FoldMskICmp_BMask_AllOnes | -                     FoldMskICmp_Mask_AllZeroes | FoldMskICmp_AMask_Mixed | -                     FoldMskICmp_BMask_Mixed)) +  NewMask = (Mask & (AMask_AllOnes | BMask_AllOnes | Mask_AllZeros | +                     AMask_Mixed | BMask_Mixed))              << 1; -  NewMask |= -      (Mask & (FoldMskICmp_AMask_NotAllOnes | FoldMskICmp_BMask_NotAllOnes | -               FoldMskICmp_Mask_NotAllZeroes | FoldMskICmp_AMask_NotMixed | -               FoldMskICmp_BMask_NotMixed)) -      >> 1; +  NewMask |= (Mask & (AMask_NotAllOnes | BMask_NotAllOnes | Mask_NotAllZeros | +                      AMask_NotMixed | BMask_NotMixed)) +             >> 1;    return NewMask;  } -/// Handle (icmp(A & B) ==/!= C) &/| (icmp(A & D) ==/!= E) -/// Return the set of pattern classes (from MaskedICmpType) -/// that both LHS and RHS satisfy. -static unsigned foldLogOpOfMaskedICmpsHelper(Value*& A, -                                             Value*& B, Value*& C, -                                             Value*& D, Value*& E, -                                             ICmpInst *LHS, ICmpInst *RHS, -                                             ICmpInst::Predicate &LHSCC, -                                             ICmpInst::Predicate &RHSCC) { -  if (LHS->getOperand(0)->getType() != RHS->getOperand(0)->getType()) return 0; +/// Handle (icmp(A & B) ==/!= C) &/| (icmp(A & D) ==/!= E). +/// Return the set of pattern classes (from MaskedICmpType) that both LHS and +/// RHS satisfy. +static unsigned getMaskedTypeForICmpPair(Value *&A, Value *&B, Value *&C, +                                         Value *&D, Value *&E, ICmpInst *LHS, +                                         ICmpInst *RHS, +                                         ICmpInst::Predicate &PredL, +                                         ICmpInst::Predicate &PredR) { +  if (LHS->getOperand(0)->getType() != RHS->getOperand(0)->getType()) +    return 0;    // vectors are not (yet?) supported -  if (LHS->getOperand(0)->getType()->isVectorTy()) return 0; +  if (LHS->getOperand(0)->getType()->isVectorTy()) +    return 0;    // Here comes the tricky part:    // LHS might be of the form L11 & L12 == X, X == L21 & L22, @@ -517,9 +421,9 @@ static unsigned foldLogOpOfMaskedICmpsHelper(Value*& A,    // above.    Value *L1 = LHS->getOperand(0);    Value *L2 = LHS->getOperand(1); -  Value *L11,*L12,*L21,*L22; +  Value *L11, *L12, *L21, *L22;    // Check whether the icmp can be decomposed into a bit test. -  if (decomposeBitTestICmp(LHS, LHSCC, L11, L12, L2)) { +  if (decomposeBitTestICmp(LHS, PredL, L11, L12, L2)) {      L21 = L22 = L1 = nullptr;    } else {      // Look for ANDs in the LHS icmp. @@ -543,22 +447,26 @@ static unsigned foldLogOpOfMaskedICmpsHelper(Value*& A,    }    // Bail if LHS was a icmp that can't be decomposed into an equality. 
-  if (!ICmpInst::isEquality(LHSCC)) +  if (!ICmpInst::isEquality(PredL))      return 0;    Value *R1 = RHS->getOperand(0);    Value *R2 = RHS->getOperand(1); -  Value *R11,*R12; -  bool ok = false; -  if (decomposeBitTestICmp(RHS, RHSCC, R11, R12, R2)) { +  Value *R11, *R12; +  bool Ok = false; +  if (decomposeBitTestICmp(RHS, PredR, R11, R12, R2)) {      if (R11 == L11 || R11 == L12 || R11 == L21 || R11 == L22) { -      A = R11; D = R12; +      A = R11; +      D = R12;      } else if (R12 == L11 || R12 == L12 || R12 == L21 || R12 == L22) { -      A = R12; D = R11; +      A = R12; +      D = R11;      } else {        return 0;      } -    E = R2; R1 = nullptr; ok = true; +    E = R2; +    R1 = nullptr; +    Ok = true;    } else if (R1->getType()->isIntegerTy()) {      if (!match(R1, m_And(m_Value(R11), m_Value(R12)))) {        // As before, model no mask as a trivial mask if it'll let us do an @@ -568,46 +476,62 @@ static unsigned foldLogOpOfMaskedICmpsHelper(Value*& A,      }      if (R11 == L11 || R11 == L12 || R11 == L21 || R11 == L22) { -      A = R11; D = R12; E = R2; ok = true; +      A = R11; +      D = R12; +      E = R2; +      Ok = true;      } else if (R12 == L11 || R12 == L12 || R12 == L21 || R12 == L22) { -      A = R12; D = R11; E = R2; ok = true; +      A = R12; +      D = R11; +      E = R2; +      Ok = true;      }    }    // Bail if RHS was a icmp that can't be decomposed into an equality. -  if (!ICmpInst::isEquality(RHSCC)) +  if (!ICmpInst::isEquality(PredR))      return 0;    // Look for ANDs on the right side of the RHS icmp. -  if (!ok && R2->getType()->isIntegerTy()) { +  if (!Ok && R2->getType()->isIntegerTy()) {      if (!match(R2, m_And(m_Value(R11), m_Value(R12)))) {        R11 = R2;        R12 = Constant::getAllOnesValue(R2->getType());      }      if (R11 == L11 || R11 == L12 || R11 == L21 || R11 == L22) { -      A = R11; D = R12; E = R1; ok = true; +      A = R11; +      D = R12; +      E = R1; +      Ok = true;      } else if (R12 == L11 || R12 == L12 || R12 == L21 || R12 == L22) { -      A = R12; D = R11; E = R1; ok = true; +      A = R12; +      D = R11; +      E = R1; +      Ok = true;      } else {        return 0;      }    } -  if (!ok) +  if (!Ok)      return 0;    if (L11 == A) { -    B = L12; C = L2; +    B = L12; +    C = L2;    } else if (L12 == A) { -    B = L11; C = L2; +    B = L11; +    C = L2;    } else if (L21 == A) { -    B = L22; C = L1; +    B = L22; +    C = L1;    } else if (L22 == A) { -    B = L21; C = L1; +    B = L21; +    C = L1;    } -  unsigned LeftType = getTypeOfMaskedICmp(A, B, C, LHSCC); -  unsigned RightType = getTypeOfMaskedICmp(A, D, E, RHSCC); +  unsigned LeftType = getMaskedICmpType(A, B, C, PredL); +  unsigned RightType = getMaskedICmpType(A, D, E, PredR);    return LeftType & RightType;  } @@ -616,12 +540,14 @@ static unsigned foldLogOpOfMaskedICmpsHelper(Value*& A,  static Value *foldLogOpOfMaskedICmps(ICmpInst *LHS, ICmpInst *RHS, bool IsAnd,                                       llvm::InstCombiner::BuilderTy *Builder) {    Value *A = nullptr, *B = nullptr, *C = nullptr, *D = nullptr, *E = nullptr; -  ICmpInst::Predicate LHSCC = LHS->getPredicate(), RHSCC = RHS->getPredicate(); -  unsigned Mask = foldLogOpOfMaskedICmpsHelper(A, B, C, D, E, LHS, RHS, -                                               LHSCC, RHSCC); -  if (Mask == 0) return nullptr; -  assert(ICmpInst::isEquality(LHSCC) && ICmpInst::isEquality(RHSCC) && -         "foldLogOpOfMaskedICmpsHelper must return an equality predicate."); +  ICmpInst::Predicate 
PredL = LHS->getPredicate(), PredR = RHS->getPredicate(); +  unsigned Mask = +      getMaskedTypeForICmpPair(A, B, C, D, E, LHS, RHS, PredL, PredR); +  if (Mask == 0) +    return nullptr; + +  assert(ICmpInst::isEquality(PredL) && ICmpInst::isEquality(PredR) && +         "Expected equality predicates for masked type of icmps.");    // In full generality:    //     (icmp (A & B) Op C) | (icmp (A & D) Op E) @@ -642,7 +568,7 @@ static Value *foldLogOpOfMaskedICmps(ICmpInst *LHS, ICmpInst *RHS, bool IsAnd,      Mask = conjugateICmpMask(Mask);    } -  if (Mask & FoldMskICmp_Mask_AllZeroes) { +  if (Mask & Mask_AllZeros) {      // (icmp eq (A & B), 0) & (icmp eq (A & D), 0)      // -> (icmp eq (A & (B|D)), 0)      Value *NewOr = Builder->CreateOr(B, D); @@ -653,14 +579,14 @@ static Value *foldLogOpOfMaskedICmps(ICmpInst *LHS, ICmpInst *RHS, bool IsAnd,      Value *Zero = Constant::getNullValue(A->getType());      return Builder->CreateICmp(NewCC, NewAnd, Zero);    } -  if (Mask & FoldMskICmp_BMask_AllOnes) { +  if (Mask & BMask_AllOnes) {      // (icmp eq (A & B), B) & (icmp eq (A & D), D)      // -> (icmp eq (A & (B|D)), (B|D))      Value *NewOr = Builder->CreateOr(B, D);      Value *NewAnd = Builder->CreateAnd(A, NewOr);      return Builder->CreateICmp(NewCC, NewAnd, NewOr);    } -  if (Mask & FoldMskICmp_AMask_AllOnes) { +  if (Mask & AMask_AllOnes) {      // (icmp eq (A & B), A) & (icmp eq (A & D), A)      // -> (icmp eq (A & (B&D)), A)      Value *NewAnd1 = Builder->CreateAnd(B, D); @@ -672,11 +598,13 @@ static Value *foldLogOpOfMaskedICmps(ICmpInst *LHS, ICmpInst *RHS, bool IsAnd,    // their actual values. This isn't strictly necessary, just a "handle the    // easy cases for now" decision.    ConstantInt *BCst = dyn_cast<ConstantInt>(B); -  if (!BCst) return nullptr; +  if (!BCst) +    return nullptr;    ConstantInt *DCst = dyn_cast<ConstantInt>(D); -  if (!DCst) return nullptr; +  if (!DCst) +    return nullptr; -  if (Mask & (FoldMskICmp_Mask_NotAllZeroes | FoldMskICmp_BMask_NotAllOnes)) { +  if (Mask & (Mask_NotAllZeros | BMask_NotAllOnes)) {      // (icmp ne (A & B), 0) & (icmp ne (A & D), 0) and      // (icmp ne (A & B), B) & (icmp ne (A & D), D)      //     -> (icmp ne (A & B), 0) or (icmp ne (A & D), 0) @@ -689,7 +617,8 @@ static Value *foldLogOpOfMaskedICmps(ICmpInst *LHS, ICmpInst *RHS, bool IsAnd,      else if (NewMask == DCst->getValue())        return RHS;    } -  if (Mask & FoldMskICmp_AMask_NotAllOnes) { + +  if (Mask & AMask_NotAllOnes) {      // (icmp ne (A & B), B) & (icmp ne (A & D), D)      //     -> (icmp ne (A & B), A) or (icmp ne (A & D), A)      // Only valid if one of the masks is a superset of the other (check "B|D" is @@ -701,7 +630,8 @@ static Value *foldLogOpOfMaskedICmps(ICmpInst *LHS, ICmpInst *RHS, bool IsAnd,      else if (NewMask == DCst->getValue())        return RHS;    } -  if (Mask & FoldMskICmp_BMask_Mixed) { + +  if (Mask & BMask_Mixed) {      // (icmp eq (A & B), C) & (icmp eq (A & D), E)      // We already know that B & C == C && D & E == E.      // If we can prove that (B & D) & (C ^ E) == 0, that is, the bits of @@ -713,23 +643,28 @@ static Value *foldLogOpOfMaskedICmps(ICmpInst *LHS, ICmpInst *RHS, bool IsAnd,      //   (icmp ne (A & B), B) & (icmp eq (A & D), D)      // with B and D, having a single bit set.      
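Editorial aside, not part of the patch: the BMask_Mixed combination described in the comment above can be checked exhaustively at a small bit width. Assuming only the stated preconditions (B & C == C, D & E == E, and (B & D) & (C ^ E) == 0), the two-compare form and the fused form agree for all 4-bit values:

  #include <cassert>

  int main() {
    for (unsigned A = 0; A < 16; ++A)
      for (unsigned B = 0; B < 16; ++B)
        for (unsigned C = 0; C < 16; ++C)
          for (unsigned D = 0; D < 16; ++D)
            for (unsigned E = 0; E < 16; ++E) {
              if ((B & C) != C || (D & E) != E || ((B & D) & (C ^ E)) != 0)
                continue;
              bool Split = ((A & B) == C) && ((A & D) == E);
              bool Fused = (A & (B | D)) == (C | E);
              assert(Split == Fused);
            }
    return 0;
  }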
ConstantInt *CCst = dyn_cast<ConstantInt>(C); -    if (!CCst) return nullptr; +    if (!CCst) +      return nullptr;      ConstantInt *ECst = dyn_cast<ConstantInt>(E); -    if (!ECst) return nullptr; -    if (LHSCC != NewCC) +    if (!ECst) +      return nullptr; +    if (PredL != NewCC)        CCst = cast<ConstantInt>(ConstantExpr::getXor(BCst, CCst)); -    if (RHSCC != NewCC) +    if (PredR != NewCC)        ECst = cast<ConstantInt>(ConstantExpr::getXor(DCst, ECst)); +      // If there is a conflict, we should actually return a false for the      // whole construct.      if (((BCst->getValue() & DCst->getValue()) &           (CCst->getValue() ^ ECst->getValue())) != 0)        return ConstantInt::get(LHS->getType(), !IsAnd); +      Value *NewOr1 = Builder->CreateOr(B, D);      Value *NewOr2 = ConstantExpr::getOr(CCst, ECst);      Value *NewAnd = Builder->CreateAnd(A, NewOr1);      return Builder->CreateICmp(NewCC, NewAnd, NewOr2);    } +    return nullptr;  } @@ -789,12 +724,67 @@ Value *InstCombiner::simplifyRangeCheck(ICmpInst *Cmp0, ICmpInst *Cmp1,    return Builder->CreateICmp(NewPred, Input, RangeEnd);  } +static Value * +foldAndOrOfEqualityCmpsWithConstants(ICmpInst *LHS, ICmpInst *RHS, +                                     bool JoinedByAnd, +                                     InstCombiner::BuilderTy *Builder) { +  Value *X = LHS->getOperand(0); +  if (X != RHS->getOperand(0)) +    return nullptr; + +  const APInt *C1, *C2; +  if (!match(LHS->getOperand(1), m_APInt(C1)) || +      !match(RHS->getOperand(1), m_APInt(C2))) +    return nullptr; + +  // We only handle (X != C1 && X != C2) and (X == C1 || X == C2). +  ICmpInst::Predicate Pred = LHS->getPredicate(); +  if (Pred !=  RHS->getPredicate()) +    return nullptr; +  if (JoinedByAnd && Pred != ICmpInst::ICMP_NE) +    return nullptr; +  if (!JoinedByAnd && Pred != ICmpInst::ICMP_EQ) +    return nullptr; + +  // The larger unsigned constant goes on the right. +  if (C1->ugt(*C2)) +    std::swap(C1, C2); + +  APInt Xor = *C1 ^ *C2; +  if (Xor.isPowerOf2()) { +    // If LHSC and RHSC differ by only one bit, then set that bit in X and +    // compare against the larger constant: +    // (X == C1 || X == C2) --> (X | (C1 ^ C2)) == C2 +    // (X != C1 && X != C2) --> (X | (C1 ^ C2)) != C2 +    // We choose an 'or' with a Pow2 constant rather than the inverse mask with +    // 'and' because that may lead to smaller codegen from a smaller constant. +    Value *Or = Builder->CreateOr(X, ConstantInt::get(X->getType(), Xor)); +    return Builder->CreateICmp(Pred, Or, ConstantInt::get(X->getType(), *C2)); +  } + +  // Special case: get the ordering right when the values wrap around zero. +  // Ie, we assumed the constants were unsigned when swapping earlier. +  if (*C1 == 0 && C2->isAllOnesValue()) +    std::swap(C1, C2); + +  if (*C1 == *C2 - 1) { +    // (X == 13 || X == 14) --> X - 13 <=u 1 +    // (X != 13 && X != 14) --> X - 13  >u 1 +    // An 'add' is the canonical IR form, so favor that over a 'sub'. +    Value *Add = Builder->CreateAdd(X, ConstantInt::get(X->getType(), -(*C1))); +    auto NewPred = JoinedByAnd ? ICmpInst::ICMP_UGT : ICmpInst::ICMP_ULE; +    return Builder->CreateICmp(NewPred, Add, ConstantInt::get(X->getType(), 1)); +  } + +  return nullptr; +} +  /// Fold (icmp)&(icmp) if possible.  
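Editorial aside, not part of the patch: the two rewrites performed by foldAndOrOfEqualityCmpsWithConstants above, shown here for the 'or' form with arbitrary 8-bit constants chosen for illustration (5 and 7 differ in a single bit; 13 and 14 are consecutive):

  #include <cassert>
  #include <cstdint>

  int main() {
    for (unsigned v = 0; v < 256; ++v) {
      uint8_t X = uint8_t(v);
      // (X == 5 || X == 7) --> (X | (5 ^ 7)) == 7
      assert((X == 5 || X == 7) == ((X | (5 ^ 7)) == 7));
      // (X == 13 || X == 14) --> (X - 13) <=u 1
      assert((X == 13 || X == 14) == (uint8_t(X - 13) <= 1));
    }
    return 0;
  }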
Value *InstCombiner::FoldAndOfICmps(ICmpInst *LHS, ICmpInst *RHS) { -  ICmpInst::Predicate LHSCC = LHS->getPredicate(), RHSCC = RHS->getPredicate(); +  ICmpInst::Predicate PredL = LHS->getPredicate(), PredR = RHS->getPredicate();    // (icmp1 A, B) & (icmp2 A, B) --> (icmp3 A, B) -  if (PredicatesFoldable(LHSCC, RHSCC)) { +  if (PredicatesFoldable(PredL, PredR)) {      if (LHS->getOperand(0) == RHS->getOperand(1) &&          LHS->getOperand(1) == RHS->getOperand(0))        LHS->swapOperands(); @@ -819,86 +809,90 @@ Value *InstCombiner::FoldAndOfICmps(ICmpInst *LHS, ICmpInst *RHS) {    if (Value *V = simplifyRangeCheck(RHS, LHS, /*Inverted=*/false))      return V; +  if (Value *V = foldAndOrOfEqualityCmpsWithConstants(LHS, RHS, true, Builder)) +    return V; +    // This only handles icmp of constants: (icmp1 A, C1) & (icmp2 B, C2). -  Value *Val = LHS->getOperand(0), *Val2 = RHS->getOperand(0); -  ConstantInt *LHSCst = dyn_cast<ConstantInt>(LHS->getOperand(1)); -  ConstantInt *RHSCst = dyn_cast<ConstantInt>(RHS->getOperand(1)); -  if (!LHSCst || !RHSCst) return nullptr; +  Value *LHS0 = LHS->getOperand(0), *RHS0 = RHS->getOperand(0); +  ConstantInt *LHSC = dyn_cast<ConstantInt>(LHS->getOperand(1)); +  ConstantInt *RHSC = dyn_cast<ConstantInt>(RHS->getOperand(1)); +  if (!LHSC || !RHSC) +    return nullptr; -  if (LHSCst == RHSCst && LHSCC == RHSCC) { +  if (LHSC == RHSC && PredL == PredR) {      // (icmp ult A, C) & (icmp ult B, C) --> (icmp ult (A|B), C)      // where C is a power of 2 or      // (icmp eq A, 0) & (icmp eq B, 0) --> (icmp eq (A|B), 0) -    if ((LHSCC == ICmpInst::ICMP_ULT && LHSCst->getValue().isPowerOf2()) || -        (LHSCC == ICmpInst::ICMP_EQ && LHSCst->isZero())) { -      Value *NewOr = Builder->CreateOr(Val, Val2); -      return Builder->CreateICmp(LHSCC, NewOr, LHSCst); +    if ((PredL == ICmpInst::ICMP_ULT && LHSC->getValue().isPowerOf2()) || +        (PredL == ICmpInst::ICMP_EQ && LHSC->isZero())) { +      Value *NewOr = Builder->CreateOr(LHS0, RHS0); +      return Builder->CreateICmp(PredL, NewOr, LHSC);      }    }    // (trunc x) == C1 & (and x, CA) == C2 -> (and x, CA|CMAX) == C1|C2    // where CMAX is the all ones value for the truncated type,    // iff the lower bits of C2 and CA are zero. 
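Editorial aside, not part of the patch: the trunc/and combine described in the comment above, checked with concrete constants picked only for illustration (x is a 16-bit value truncated to 8 bits; CA = 0x0F00 and C2 = 0x0500 both have their low 8 bits clear; C1 = 0x3C):

  #include <cassert>
  #include <cstdint>

  int main() {
    for (unsigned v = 0; v <= 0xFFFF; ++v) {
      uint16_t x = uint16_t(v);
      bool Split = (uint8_t(x) == 0x3C) && ((x & 0x0F00) == 0x0500);
      bool Fused = (x & (0x0F00 | 0x00FF)) == (0x0500 | 0x3C);
      assert(Split == Fused);
    }
    return 0;
  }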
-  if (LHSCC == ICmpInst::ICMP_EQ && LHSCC == RHSCC && -      LHS->hasOneUse() && RHS->hasOneUse()) { +  if (PredL == ICmpInst::ICMP_EQ && PredL == PredR && LHS->hasOneUse() && +      RHS->hasOneUse()) {      Value *V; -    ConstantInt *AndCst, *SmallCst = nullptr, *BigCst = nullptr; +    ConstantInt *AndC, *SmallC = nullptr, *BigC = nullptr;      // (trunc x) == C1 & (and x, CA) == C2      // (and x, CA) == C2 & (trunc x) == C1 -    if (match(Val2, m_Trunc(m_Value(V))) && -        match(Val, m_And(m_Specific(V), m_ConstantInt(AndCst)))) { -      SmallCst = RHSCst; -      BigCst = LHSCst; -    } else if (match(Val, m_Trunc(m_Value(V))) && -               match(Val2, m_And(m_Specific(V), m_ConstantInt(AndCst)))) { -      SmallCst = LHSCst; -      BigCst = RHSCst; +    if (match(RHS0, m_Trunc(m_Value(V))) && +        match(LHS0, m_And(m_Specific(V), m_ConstantInt(AndC)))) { +      SmallC = RHSC; +      BigC = LHSC; +    } else if (match(LHS0, m_Trunc(m_Value(V))) && +               match(RHS0, m_And(m_Specific(V), m_ConstantInt(AndC)))) { +      SmallC = LHSC; +      BigC = RHSC;      } -    if (SmallCst && BigCst) { -      unsigned BigBitSize = BigCst->getType()->getBitWidth(); -      unsigned SmallBitSize = SmallCst->getType()->getBitWidth(); +    if (SmallC && BigC) { +      unsigned BigBitSize = BigC->getType()->getBitWidth(); +      unsigned SmallBitSize = SmallC->getType()->getBitWidth();        // Check that the low bits are zero.        APInt Low = APInt::getLowBitsSet(BigBitSize, SmallBitSize); -      if ((Low & AndCst->getValue()) == 0 && (Low & BigCst->getValue()) == 0) { -        Value *NewAnd = Builder->CreateAnd(V, Low | AndCst->getValue()); -        APInt N = SmallCst->getValue().zext(BigBitSize) | BigCst->getValue(); -        Value *NewVal = ConstantInt::get(AndCst->getType()->getContext(), N); -        return Builder->CreateICmp(LHSCC, NewAnd, NewVal); +      if ((Low & AndC->getValue()) == 0 && (Low & BigC->getValue()) == 0) { +        Value *NewAnd = Builder->CreateAnd(V, Low | AndC->getValue()); +        APInt N = SmallC->getValue().zext(BigBitSize) | BigC->getValue(); +        Value *NewVal = ConstantInt::get(AndC->getType()->getContext(), N); +        return Builder->CreateICmp(PredL, NewAnd, NewVal);        }      }    }    // From here on, we only handle:    //    (icmp1 A, C1) & (icmp2 A, C2) --> something simpler. -  if (Val != Val2) return nullptr; +  if (LHS0 != RHS0) +    return nullptr; -  // ICMP_[US][GL]E X, CST is folded to ICMP_[US][GL]T elsewhere. -  if (LHSCC == ICmpInst::ICMP_UGE || LHSCC == ICmpInst::ICMP_ULE || -      RHSCC == ICmpInst::ICMP_UGE || RHSCC == ICmpInst::ICMP_ULE || -      LHSCC == ICmpInst::ICMP_SGE || LHSCC == ICmpInst::ICMP_SLE || -      RHSCC == ICmpInst::ICMP_SGE || RHSCC == ICmpInst::ICMP_SLE) +  // ICMP_[US][GL]E X, C is folded to ICMP_[US][GL]T elsewhere. +  if (PredL == ICmpInst::ICMP_UGE || PredL == ICmpInst::ICMP_ULE || +      PredR == ICmpInst::ICMP_UGE || PredR == ICmpInst::ICMP_ULE || +      PredL == ICmpInst::ICMP_SGE || PredL == ICmpInst::ICMP_SLE || +      PredR == ICmpInst::ICMP_SGE || PredR == ICmpInst::ICMP_SLE)      return nullptr;    // We can't fold (ugt x, C) & (sgt x, C2). -  if (!PredicatesFoldable(LHSCC, RHSCC)) +  if (!PredicatesFoldable(PredL, PredR))      return nullptr;    // Ensure that the larger constant is on the RHS.    
bool ShouldSwap; -  if (CmpInst::isSigned(LHSCC) || -      (ICmpInst::isEquality(LHSCC) && -       CmpInst::isSigned(RHSCC))) -    ShouldSwap = LHSCst->getValue().sgt(RHSCst->getValue()); +  if (CmpInst::isSigned(PredL) || +      (ICmpInst::isEquality(PredL) && CmpInst::isSigned(PredR))) +    ShouldSwap = LHSC->getValue().sgt(RHSC->getValue());    else -    ShouldSwap = LHSCst->getValue().ugt(RHSCst->getValue()); +    ShouldSwap = LHSC->getValue().ugt(RHSC->getValue());    if (ShouldSwap) {      std::swap(LHS, RHS); -    std::swap(LHSCst, RHSCst); -    std::swap(LHSCC, RHSCC); +    std::swap(LHSC, RHSC); +    std::swap(PredL, PredR);    }    // At this point, we know we have two icmp instructions @@ -907,113 +901,95 @@ Value *InstCombiner::FoldAndOfICmps(ICmpInst *LHS, ICmpInst *RHS) {    // icmp eq, icmp ne, icmp [su]lt, and icmp [SU]gt here. We also know    // (from the icmp folding check above), that the two constants    // are not equal and that the larger constant is on the RHS -  assert(LHSCst != RHSCst && "Compares not folded above?"); +  assert(LHSC != RHSC && "Compares not folded above?"); -  switch (LHSCC) { -  default: llvm_unreachable("Unknown integer condition code!"); +  switch (PredL) { +  default: +    llvm_unreachable("Unknown integer condition code!");    case ICmpInst::ICMP_EQ: -    switch (RHSCC) { -    default: llvm_unreachable("Unknown integer condition code!"); -    case ICmpInst::ICMP_NE:         // (X == 13 & X != 15) -> X == 13 -    case ICmpInst::ICMP_ULT:        // (X == 13 & X <  15) -> X == 13 -    case ICmpInst::ICMP_SLT:        // (X == 13 & X <  15) -> X == 13 +    switch (PredR) { +    default: +      llvm_unreachable("Unknown integer condition code!"); +    case ICmpInst::ICMP_NE:  // (X == 13 & X != 15) -> X == 13 +    case ICmpInst::ICMP_ULT: // (X == 13 & X <  15) -> X == 13 +    case ICmpInst::ICMP_SLT: // (X == 13 & X <  15) -> X == 13        return LHS;      }    case ICmpInst::ICMP_NE: -    switch (RHSCC) { -    default: llvm_unreachable("Unknown integer condition code!"); +    switch (PredR) { +    default: +      llvm_unreachable("Unknown integer condition code!");      case ICmpInst::ICMP_ULT: -      if (LHSCst == SubOne(RHSCst)) // (X != 13 & X u< 14) -> X < 13 -        return Builder->CreateICmpULT(Val, LHSCst); -      if (LHSCst->isNullValue())    // (X !=  0 & X u< 14) -> X-1 u< 13 -        return insertRangeTest(Val, LHSCst->getValue() + 1, RHSCst->getValue(), +      if (LHSC == SubOne(RHSC)) // (X != 13 & X u< 14) -> X < 13 +        return Builder->CreateICmpULT(LHS0, LHSC); +      if (LHSC->isNullValue()) // (X !=  0 & X u< 14) -> X-1 u< 13 +        return insertRangeTest(LHS0, LHSC->getValue() + 1, RHSC->getValue(),                                 false, true); -      break;                        // (X != 13 & X u< 15) -> no change +      break; // (X != 13 & X u< 15) -> no change      case ICmpInst::ICMP_SLT: -      if (LHSCst == SubOne(RHSCst)) // (X != 13 & X s< 14) -> X < 13 -        return Builder->CreateICmpSLT(Val, LHSCst); -      break;                        // (X != 13 & X s< 15) -> no change -    case ICmpInst::ICMP_EQ:         // (X != 13 & X == 15) -> X == 15 -    case ICmpInst::ICMP_UGT:        // (X != 13 & X u> 15) -> X u> 15 -    case ICmpInst::ICMP_SGT:        // (X != 13 & X s> 15) -> X s> 15 +      if (LHSC == SubOne(RHSC)) // (X != 13 & X s< 14) -> X < 13 +        return Builder->CreateICmpSLT(LHS0, LHSC); +      break;                 // (X != 13 & X s< 15) -> no change +    case ICmpInst::ICMP_EQ:  // (X != 13 & X 
== 15) -> X == 15 +    case ICmpInst::ICMP_UGT: // (X != 13 & X u> 15) -> X u> 15 +    case ICmpInst::ICMP_SGT: // (X != 13 & X s> 15) -> X s> 15        return RHS;      case ICmpInst::ICMP_NE: -      // Special case to get the ordering right when the values wrap around -      // zero. -      if (LHSCst->getValue() == 0 && RHSCst->getValue().isAllOnesValue()) -        std::swap(LHSCst, RHSCst); -      if (LHSCst == SubOne(RHSCst)){// (X != 13 & X != 14) -> X-13 >u 1 -        Constant *AddCST = ConstantExpr::getNeg(LHSCst); -        Value *Add = Builder->CreateAdd(Val, AddCST, Val->getName()+".off"); -        return Builder->CreateICmpUGT(Add, ConstantInt::get(Add->getType(), 1), -                                      Val->getName()+".cmp"); -      } -      break;                        // (X != 13 & X != 15) -> no change +      // Potential folds for this case should already be handled. +      break;      }      break;    case ICmpInst::ICMP_ULT: -    switch (RHSCC) { -    default: llvm_unreachable("Unknown integer condition code!"); -    case ICmpInst::ICMP_EQ:         // (X u< 13 & X == 15) -> false -    case ICmpInst::ICMP_UGT:        // (X u< 13 & X u> 15) -> false +    switch (PredR) { +    default: +      llvm_unreachable("Unknown integer condition code!"); +    case ICmpInst::ICMP_EQ:  // (X u< 13 & X == 15) -> false +    case ICmpInst::ICMP_UGT: // (X u< 13 & X u> 15) -> false        return ConstantInt::get(CmpInst::makeCmpResultType(LHS->getType()), 0); -    case ICmpInst::ICMP_SGT:        // (X u< 13 & X s> 15) -> no change -      break; -    case ICmpInst::ICMP_NE:         // (X u< 13 & X != 15) -> X u< 13 -    case ICmpInst::ICMP_ULT:        // (X u< 13 & X u< 15) -> X u< 13 +    case ICmpInst::ICMP_NE:  // (X u< 13 & X != 15) -> X u< 13 +    case ICmpInst::ICMP_ULT: // (X u< 13 & X u< 15) -> X u< 13        return LHS; -    case ICmpInst::ICMP_SLT:        // (X u< 13 & X s< 15) -> no change -      break;      }      break;    case ICmpInst::ICMP_SLT: -    switch (RHSCC) { -    default: llvm_unreachable("Unknown integer condition code!"); -    case ICmpInst::ICMP_UGT:        // (X s< 13 & X u> 15) -> no change -      break; -    case ICmpInst::ICMP_NE:         // (X s< 13 & X != 15) -> X < 13 -    case ICmpInst::ICMP_SLT:        // (X s< 13 & X s< 15) -> X < 13 +    switch (PredR) { +    default: +      llvm_unreachable("Unknown integer condition code!"); +    case ICmpInst::ICMP_NE:  // (X s< 13 & X != 15) -> X < 13 +    case ICmpInst::ICMP_SLT: // (X s< 13 & X s< 15) -> X < 13        return LHS; -    case ICmpInst::ICMP_ULT:        // (X s< 13 & X u< 15) -> no change -      break;      }      break;    case ICmpInst::ICMP_UGT: -    switch (RHSCC) { -    default: llvm_unreachable("Unknown integer condition code!"); -    case ICmpInst::ICMP_EQ:         // (X u> 13 & X == 15) -> X == 15 -    case ICmpInst::ICMP_UGT:        // (X u> 13 & X u> 15) -> X u> 15 +    switch (PredR) { +    default: +      llvm_unreachable("Unknown integer condition code!"); +    case ICmpInst::ICMP_EQ:  // (X u> 13 & X == 15) -> X == 15 +    case ICmpInst::ICMP_UGT: // (X u> 13 & X u> 15) -> X u> 15        return RHS; -    case ICmpInst::ICMP_SGT:        // (X u> 13 & X s> 15) -> no change -      break;      case ICmpInst::ICMP_NE: -      if (RHSCst == AddOne(LHSCst)) // (X u> 13 & X != 14) -> X u> 14 -        return Builder->CreateICmp(LHSCC, Val, RHSCst); -      break;                        // (X u> 13 & X != 15) -> no change -    case ICmpInst::ICMP_ULT:        // (X u> 13 & X u< 15) -> (X-14) <u 1 
-      return insertRangeTest(Val, LHSCst->getValue() + 1, RHSCst->getValue(), +      if (RHSC == AddOne(LHSC)) // (X u> 13 & X != 14) -> X u> 14 +        return Builder->CreateICmp(PredL, LHS0, RHSC); +      break;                 // (X u> 13 & X != 15) -> no change +    case ICmpInst::ICMP_ULT: // (X u> 13 & X u< 15) -> (X-14) <u 1 +      return insertRangeTest(LHS0, LHSC->getValue() + 1, RHSC->getValue(),                               false, true); -    case ICmpInst::ICMP_SLT:        // (X u> 13 & X s< 15) -> no change -      break;      }      break;    case ICmpInst::ICMP_SGT: -    switch (RHSCC) { -    default: llvm_unreachable("Unknown integer condition code!"); -    case ICmpInst::ICMP_EQ:         // (X s> 13 & X == 15) -> X == 15 -    case ICmpInst::ICMP_SGT:        // (X s> 13 & X s> 15) -> X s> 15 +    switch (PredR) { +    default: +      llvm_unreachable("Unknown integer condition code!"); +    case ICmpInst::ICMP_EQ:  // (X s> 13 & X == 15) -> X == 15 +    case ICmpInst::ICMP_SGT: // (X s> 13 & X s> 15) -> X s> 15        return RHS; -    case ICmpInst::ICMP_UGT:        // (X s> 13 & X u> 15) -> no change -      break;      case ICmpInst::ICMP_NE: -      if (RHSCst == AddOne(LHSCst)) // (X s> 13 & X != 14) -> X s> 14 -        return Builder->CreateICmp(LHSCC, Val, RHSCst); -      break;                        // (X s> 13 & X != 15) -> no change -    case ICmpInst::ICMP_SLT:        // (X s> 13 & X s< 15) -> (X-14) s< 1 -      return insertRangeTest(Val, LHSCst->getValue() + 1, RHSCst->getValue(), -                             true, true); -    case ICmpInst::ICMP_ULT:        // (X s> 13 & X u< 15) -> no change -      break; +      if (RHSC == AddOne(LHSC)) // (X s> 13 & X != 14) -> X s> 14 +        return Builder->CreateICmp(PredL, LHS0, RHSC); +      break;                 // (X s> 13 & X != 15) -> no change +    case ICmpInst::ICMP_SLT: // (X s> 13 & X s< 15) -> (X-14) s< 1 +      return insertRangeTest(LHS0, LHSC->getValue() + 1, RHSC->getValue(), true, +                             true);      }      break;    } @@ -1314,39 +1290,11 @@ Instruction *InstCombiner::visitAnd(BinaryOperator &I) {          break;        } -      case Instruction::Add: -        // ((A & N) + B) & AndRHS -> (A + B) & AndRHS iff N&AndRHS == AndRHS. -        // ((A | N) + B) & AndRHS -> (A + B) & AndRHS iff N&AndRHS == 0 -        // ((A ^ N) + B) & AndRHS -> (A + B) & AndRHS iff N&AndRHS == 0 -        if (Value *V = FoldLogicalPlusAnd(Op0LHS, Op0RHS, AndRHS, false, I)) -          return BinaryOperator::CreateAnd(V, AndRHS); -        if (Value *V = FoldLogicalPlusAnd(Op0RHS, Op0LHS, AndRHS, false, I)) -          return BinaryOperator::CreateAnd(V, AndRHS);  // Add commutes -        break; -        case Instruction::Sub: -        // ((A & N) - B) & AndRHS -> (A - B) & AndRHS iff N&AndRHS == AndRHS. -        // ((A | N) - B) & AndRHS -> (A - B) & AndRHS iff N&AndRHS == 0 -        // ((A ^ N) - B) & AndRHS -> (A - B) & AndRHS iff N&AndRHS == 0 -        if (Value *V = FoldLogicalPlusAnd(Op0LHS, Op0RHS, AndRHS, true, I)) -          return BinaryOperator::CreateAnd(V, AndRHS); -          // -x & 1 -> x & 1          if (AndRHSMask == 1 && match(Op0LHS, m_Zero()))            return BinaryOperator::CreateAnd(Op0RHS, AndRHS); -        // (A - N) & AndRHS -> -N & AndRHS iff A&AndRHS==0 and AndRHS -        // has 1's for all bits that the subtraction with A might affect. 
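Editorial aside, not part of the patch: the retained "-x & 1 -> x & 1" fold a little above rests on negation preserving the low bit; a quick exhaustive 8-bit check (illustrative only):

  #include <cassert>
  #include <cstdint>

  int main() {
    for (unsigned v = 0; v < 256; ++v) {
      uint8_t X = uint8_t(v);
      // Negation flips higher bits but leaves the parity (bit 0) unchanged.
      assert((uint8_t(0 - X) & 1) == (X & 1));
    }
    return 0;
  }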
-        if (Op0I->hasOneUse() && !match(Op0LHS, m_Zero())) { -          uint32_t BitWidth = AndRHSMask.getBitWidth(); -          uint32_t Zeros = AndRHSMask.countLeadingZeros(); -          APInt Mask = APInt::getLowBitsSet(BitWidth, BitWidth - Zeros); - -          if (MaskedValueIsZero(Op0LHS, Mask, 0, &I)) { -            Value *NewNeg = Builder->CreateNeg(Op0RHS); -            return BinaryOperator::CreateAnd(NewNeg, AndRHS); -          } -        }          break;        case Instruction::Shl: @@ -1361,6 +1309,33 @@ Instruction *InstCombiner::visitAnd(BinaryOperator &I) {          break;        } +      // ((C1 OP zext(X)) & C2) -> zext((C1-X) & C2) if C2 fits in the bitwidth +      // of X and OP behaves well when given trunc(C1) and X. +      switch (Op0I->getOpcode()) { +      default: +        break; +      case Instruction::Xor: +      case Instruction::Or: +      case Instruction::Mul: +      case Instruction::Add: +      case Instruction::Sub: +        Value *X; +        ConstantInt *C1; +        if (match(Op0I, m_c_BinOp(m_ZExt(m_Value(X)), m_ConstantInt(C1)))) { +          if (AndRHSMask.isIntN(X->getType()->getScalarSizeInBits())) { +            auto *TruncC1 = ConstantExpr::getTrunc(C1, X->getType()); +            Value *BinOp; +            if (isa<ZExtInst>(Op0LHS)) +              BinOp = Builder->CreateBinOp(Op0I->getOpcode(), X, TruncC1); +            else +              BinOp = Builder->CreateBinOp(Op0I->getOpcode(), TruncC1, X); +            auto *TruncC2 = ConstantExpr::getTrunc(AndRHS, X->getType()); +            auto *And = Builder->CreateAnd(BinOp, TruncC2); +            return new ZExtInst(And, I.getType()); +          } +        } +      } +        if (ConstantInt *Op0CI = dyn_cast<ConstantInt>(Op0I->getOperand(1)))          if (Instruction *Res = OptAndOp(Op0I, Op0CI, AndRHS, I))            return Res; @@ -1381,10 +1356,11 @@ Instruction *InstCombiner::visitAnd(BinaryOperator &I) {          return BinaryOperator::CreateAnd(NewCast, C3);        }      } +  } +  if (isa<Constant>(Op1))      if (Instruction *FoldedLogic = foldOpWithConstantIntoOperand(I))        return FoldedLogic; -  }    if (Instruction *DeMorgan = matchDeMorgansLaws(I, Builder))      return DeMorgan; @@ -1630,15 +1606,15 @@ static Value *matchSelectFromAndOr(Value *A, Value *C, Value *B, Value *D,  /// Fold (icmp)|(icmp) if possible.  Value *InstCombiner::FoldOrOfICmps(ICmpInst *LHS, ICmpInst *RHS,                                     Instruction *CxtI) { -  ICmpInst::Predicate LHSCC = LHS->getPredicate(), RHSCC = RHS->getPredicate(); +  ICmpInst::Predicate PredL = LHS->getPredicate(), PredR = RHS->getPredicate();    // Fold (iszero(A & K1) | iszero(A & K2)) ->  (A & (K1 | K2)) != (K1 | K2)    // if K1 and K2 are a one-bit mask. 
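Editorial aside, not part of the patch: the fold noted in the comment above is De Morgan's law for single-bit masks, i.e. "either masked bit is clear" is "not both bits are set". A check with the arbitrary one-bit masks K1 = 0x04 and K2 = 0x20:

  #include <cassert>
  #include <cstdint>

  int main() {
    for (unsigned v = 0; v < 256; ++v) {
      uint8_t A = uint8_t(v);
      bool Split = ((A & 0x04) == 0) || ((A & 0x20) == 0);
      bool Fused = (A & (0x04 | 0x20)) != (0x04 | 0x20);
      assert(Split == Fused);
    }
    return 0;
  }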
-  ConstantInt *LHSCst = dyn_cast<ConstantInt>(LHS->getOperand(1)); -  ConstantInt *RHSCst = dyn_cast<ConstantInt>(RHS->getOperand(1)); +  ConstantInt *LHSC = dyn_cast<ConstantInt>(LHS->getOperand(1)); +  ConstantInt *RHSC = dyn_cast<ConstantInt>(RHS->getOperand(1)); -  if (LHS->getPredicate() == ICmpInst::ICMP_EQ && LHSCst && LHSCst->isZero() && -      RHS->getPredicate() == ICmpInst::ICMP_EQ && RHSCst && RHSCst->isZero()) { +  if (LHS->getPredicate() == ICmpInst::ICMP_EQ && LHSC && LHSC->isZero() && +      RHS->getPredicate() == ICmpInst::ICMP_EQ && RHSC && RHSC->isZero()) {      BinaryOperator *LAnd = dyn_cast<BinaryOperator>(LHS->getOperand(0));      BinaryOperator *RAnd = dyn_cast<BinaryOperator>(RHS->getOperand(0)); @@ -1680,52 +1656,52 @@ Value *InstCombiner::FoldOrOfICmps(ICmpInst *LHS, ICmpInst *RHS,    // 4) LowRange1 ^ LowRange2 and HighRange1 ^ HighRange2 are one-bit mask.    // This implies all values in the two ranges differ by exactly one bit. -  if ((LHSCC == ICmpInst::ICMP_ULT || LHSCC == ICmpInst::ICMP_ULE) && -      LHSCC == RHSCC && LHSCst && RHSCst && LHS->hasOneUse() && -      RHS->hasOneUse() && LHSCst->getType() == RHSCst->getType() && -      LHSCst->getValue() == (RHSCst->getValue())) { +  if ((PredL == ICmpInst::ICMP_ULT || PredL == ICmpInst::ICMP_ULE) && +      PredL == PredR && LHSC && RHSC && LHS->hasOneUse() && RHS->hasOneUse() && +      LHSC->getType() == RHSC->getType() && +      LHSC->getValue() == (RHSC->getValue())) {      Value *LAdd = LHS->getOperand(0);      Value *RAdd = RHS->getOperand(0);      Value *LAddOpnd, *RAddOpnd; -    ConstantInt *LAddCst, *RAddCst; -    if (match(LAdd, m_Add(m_Value(LAddOpnd), m_ConstantInt(LAddCst))) && -        match(RAdd, m_Add(m_Value(RAddOpnd), m_ConstantInt(RAddCst))) && -        LAddCst->getValue().ugt(LHSCst->getValue()) && -        RAddCst->getValue().ugt(LHSCst->getValue())) { - -      APInt DiffCst = LAddCst->getValue() ^ RAddCst->getValue(); -      if (LAddOpnd == RAddOpnd && DiffCst.isPowerOf2()) { -        ConstantInt *MaxAddCst = nullptr; -        if (LAddCst->getValue().ult(RAddCst->getValue())) -          MaxAddCst = RAddCst; +    ConstantInt *LAddC, *RAddC; +    if (match(LAdd, m_Add(m_Value(LAddOpnd), m_ConstantInt(LAddC))) && +        match(RAdd, m_Add(m_Value(RAddOpnd), m_ConstantInt(RAddC))) && +        LAddC->getValue().ugt(LHSC->getValue()) && +        RAddC->getValue().ugt(LHSC->getValue())) { + +      APInt DiffC = LAddC->getValue() ^ RAddC->getValue(); +      if (LAddOpnd == RAddOpnd && DiffC.isPowerOf2()) { +        ConstantInt *MaxAddC = nullptr; +        if (LAddC->getValue().ult(RAddC->getValue())) +          MaxAddC = RAddC;          else -          MaxAddCst = LAddCst; +          MaxAddC = LAddC; -        APInt RRangeLow = -RAddCst->getValue(); -        APInt RRangeHigh = RRangeLow + LHSCst->getValue(); -        APInt LRangeLow = -LAddCst->getValue(); -        APInt LRangeHigh = LRangeLow + LHSCst->getValue(); +        APInt RRangeLow = -RAddC->getValue(); +        APInt RRangeHigh = RRangeLow + LHSC->getValue(); +        APInt LRangeLow = -LAddC->getValue(); +        APInt LRangeHigh = LRangeLow + LHSC->getValue();          APInt LowRangeDiff = RRangeLow ^ LRangeLow;          APInt HighRangeDiff = RRangeHigh ^ LRangeHigh;          APInt RangeDiff = LRangeLow.sgt(RRangeLow) ? 
LRangeLow - RRangeLow                                                     : RRangeLow - LRangeLow;          if (LowRangeDiff.isPowerOf2() && LowRangeDiff == HighRangeDiff && -            RangeDiff.ugt(LHSCst->getValue())) { -          Value *MaskCst = ConstantInt::get(LAddCst->getType(), ~DiffCst); +            RangeDiff.ugt(LHSC->getValue())) { +          Value *MaskC = ConstantInt::get(LAddC->getType(), ~DiffC); -          Value *NewAnd = Builder->CreateAnd(LAddOpnd, MaskCst); -          Value *NewAdd = Builder->CreateAdd(NewAnd, MaxAddCst); -          return (Builder->CreateICmp(LHS->getPredicate(), NewAdd, LHSCst)); +          Value *NewAnd = Builder->CreateAnd(LAddOpnd, MaskC); +          Value *NewAdd = Builder->CreateAdd(NewAnd, MaxAddC); +          return (Builder->CreateICmp(LHS->getPredicate(), NewAdd, LHSC));          }        }      }    }    // (icmp1 A, B) | (icmp2 A, B) --> (icmp3 A, B) -  if (PredicatesFoldable(LHSCC, RHSCC)) { +  if (PredicatesFoldable(PredL, PredR)) {      if (LHS->getOperand(0) == RHS->getOperand(1) &&          LHS->getOperand(1) == RHS->getOperand(0))        LHS->swapOperands(); @@ -1743,25 +1719,25 @@ Value *InstCombiner::FoldOrOfICmps(ICmpInst *LHS, ICmpInst *RHS,    if (Value *V = foldLogOpOfMaskedICmps(LHS, RHS, false, Builder))      return V; -  Value *Val = LHS->getOperand(0), *Val2 = RHS->getOperand(0); +  Value *LHS0 = LHS->getOperand(0), *RHS0 = RHS->getOperand(0);    if (LHS->hasOneUse() || RHS->hasOneUse()) {      // (icmp eq B, 0) | (icmp ult A, B) -> (icmp ule A, B-1)      // (icmp eq B, 0) | (icmp ugt B, A) -> (icmp ule A, B-1)      Value *A = nullptr, *B = nullptr; -    if (LHSCC == ICmpInst::ICMP_EQ && LHSCst && LHSCst->isZero()) { -      B = Val; -      if (RHSCC == ICmpInst::ICMP_ULT && Val == RHS->getOperand(1)) -        A = Val2; -      else if (RHSCC == ICmpInst::ICMP_UGT && Val == Val2) +    if (PredL == ICmpInst::ICMP_EQ && LHSC && LHSC->isZero()) { +      B = LHS0; +      if (PredR == ICmpInst::ICMP_ULT && LHS0 == RHS->getOperand(1)) +        A = RHS0; +      else if (PredR == ICmpInst::ICMP_UGT && LHS0 == RHS0)          A = RHS->getOperand(1);      }      // (icmp ult A, B) | (icmp eq B, 0) -> (icmp ule A, B-1)      // (icmp ugt B, A) | (icmp eq B, 0) -> (icmp ule A, B-1) -    else if (RHSCC == ICmpInst::ICMP_EQ && RHSCst && RHSCst->isZero()) { -      B = Val2; -      if (LHSCC == ICmpInst::ICMP_ULT && Val2 == LHS->getOperand(1)) -        A = Val; -      else if (LHSCC == ICmpInst::ICMP_UGT && Val2 == Val) +    else if (PredR == ICmpInst::ICMP_EQ && RHSC && RHSC->isZero()) { +      B = RHS0; +      if (PredL == ICmpInst::ICMP_ULT && RHS0 == LHS->getOperand(1)) +        A = LHS0; +      else if (PredL == ICmpInst::ICMP_UGT && LHS0 == RHS0)          A = LHS->getOperand(1);      }      if (A && B) @@ -1778,54 +1754,58 @@ Value *InstCombiner::FoldOrOfICmps(ICmpInst *LHS, ICmpInst *RHS,    if (Value *V = simplifyRangeCheck(RHS, LHS, /*Inverted=*/true))      return V; +  if (Value *V = foldAndOrOfEqualityCmpsWithConstants(LHS, RHS, false, Builder)) +    return V; +    // This only handles icmp of constants: (icmp1 A, C1) | (icmp2 B, C2). 
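Editorial aside, not part of the patch: the (icmp eq B, 0) | (icmp ult A, B) -> (icmp ule A, B-1) rewrite above works because B-1 wraps to the all-ones value when B == 0, and every A satisfies A <=u all-ones. An exhaustive 8-bit check (width chosen for illustration):

  #include <cassert>
  #include <cstdint>

  int main() {
    for (unsigned a = 0; a < 256; ++a)
      for (unsigned b = 0; b < 256; ++b) {
        uint8_t A = uint8_t(a), B = uint8_t(b);
        bool Split = (B == 0) || (A < B);
        bool Fused = A <= uint8_t(B - 1);
        assert(Split == Fused);
      }
    return 0;
  }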
-  if (!LHSCst || !RHSCst) return nullptr; +  if (!LHSC || !RHSC) +    return nullptr; -  if (LHSCst == RHSCst && LHSCC == RHSCC) { +  if (LHSC == RHSC && PredL == PredR) {      // (icmp ne A, 0) | (icmp ne B, 0) --> (icmp ne (A|B), 0) -    if (LHSCC == ICmpInst::ICMP_NE && LHSCst->isZero()) { -      Value *NewOr = Builder->CreateOr(Val, Val2); -      return Builder->CreateICmp(LHSCC, NewOr, LHSCst); +    if (PredL == ICmpInst::ICMP_NE && LHSC->isZero()) { +      Value *NewOr = Builder->CreateOr(LHS0, RHS0); +      return Builder->CreateICmp(PredL, NewOr, LHSC);      }    }    // (icmp ult (X + CA), C1) | (icmp eq X, C2) -> (icmp ule (X + CA), C1)    //   iff C2 + CA == C1. -  if (LHSCC == ICmpInst::ICMP_ULT && RHSCC == ICmpInst::ICMP_EQ) { -    ConstantInt *AddCst; -    if (match(Val, m_Add(m_Specific(Val2), m_ConstantInt(AddCst)))) -      if (RHSCst->getValue() + AddCst->getValue() == LHSCst->getValue()) -        return Builder->CreateICmpULE(Val, LHSCst); +  if (PredL == ICmpInst::ICMP_ULT && PredR == ICmpInst::ICMP_EQ) { +    ConstantInt *AddC; +    if (match(LHS0, m_Add(m_Specific(RHS0), m_ConstantInt(AddC)))) +      if (RHSC->getValue() + AddC->getValue() == LHSC->getValue()) +        return Builder->CreateICmpULE(LHS0, LHSC);    }    // From here on, we only handle:    //    (icmp1 A, C1) | (icmp2 A, C2) --> something simpler. -  if (Val != Val2) return nullptr; +  if (LHS0 != RHS0) +    return nullptr; -  // ICMP_[US][GL]E X, CST is folded to ICMP_[US][GL]T elsewhere. -  if (LHSCC == ICmpInst::ICMP_UGE || LHSCC == ICmpInst::ICMP_ULE || -      RHSCC == ICmpInst::ICMP_UGE || RHSCC == ICmpInst::ICMP_ULE || -      LHSCC == ICmpInst::ICMP_SGE || LHSCC == ICmpInst::ICMP_SLE || -      RHSCC == ICmpInst::ICMP_SGE || RHSCC == ICmpInst::ICMP_SLE) +  // ICMP_[US][GL]E X, C is folded to ICMP_[US][GL]T elsewhere. +  if (PredL == ICmpInst::ICMP_UGE || PredL == ICmpInst::ICMP_ULE || +      PredR == ICmpInst::ICMP_UGE || PredR == ICmpInst::ICMP_ULE || +      PredL == ICmpInst::ICMP_SGE || PredL == ICmpInst::ICMP_SLE || +      PredR == ICmpInst::ICMP_SGE || PredR == ICmpInst::ICMP_SLE)      return nullptr;    // We can't fold (ugt x, C) | (sgt x, C2). -  if (!PredicatesFoldable(LHSCC, RHSCC)) +  if (!PredicatesFoldable(PredL, PredR))      return nullptr;    // Ensure that the larger constant is on the RHS.    bool ShouldSwap; -  if (CmpInst::isSigned(LHSCC) || -      (ICmpInst::isEquality(LHSCC) && -       CmpInst::isSigned(RHSCC))) -    ShouldSwap = LHSCst->getValue().sgt(RHSCst->getValue()); +  if (CmpInst::isSigned(PredL) || +      (ICmpInst::isEquality(PredL) && CmpInst::isSigned(PredR))) +    ShouldSwap = LHSC->getValue().sgt(RHSC->getValue());    else -    ShouldSwap = LHSCst->getValue().ugt(RHSCst->getValue()); +    ShouldSwap = LHSC->getValue().ugt(RHSC->getValue());    if (ShouldSwap) {      std::swap(LHS, RHS); -    std::swap(LHSCst, RHSCst); -    std::swap(LHSCC, RHSCC); +    std::swap(LHSC, RHSC); +    std::swap(PredL, PredR);    }    // At this point, we know we have two icmp instructions @@ -1834,127 +1814,98 @@ Value *InstCombiner::FoldOrOfICmps(ICmpInst *LHS, ICmpInst *RHS,    // ICMP_EQ, ICMP_NE, ICMP_LT, and ICMP_GT here. We also know (from the    // icmp folding check above), that the two constants are not    // equal. 
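Editorial aside, not part of the patch: the (icmp ult (X + CA), C1) | (icmp eq X, C2) -> (icmp ule (X + CA), C1) fold above, checked with illustrative constants CA = 5, C1 = 20, C2 = 15 (so C2 + CA == C1):

  #include <cassert>
  #include <cstdint>

  int main() {
    for (unsigned v = 0; v < 256; ++v) {
      uint8_t X = uint8_t(v);
      bool Split = (uint8_t(X + 5) < 20) || (X == 15);
      bool Fused = uint8_t(X + 5) <= 20;
      assert(Split == Fused);
    }
    return 0;
  }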
-  assert(LHSCst != RHSCst && "Compares not folded above?"); +  assert(LHSC != RHSC && "Compares not folded above?"); -  switch (LHSCC) { -  default: llvm_unreachable("Unknown integer condition code!"); +  switch (PredL) { +  default: +    llvm_unreachable("Unknown integer condition code!");    case ICmpInst::ICMP_EQ: -    switch (RHSCC) { -    default: llvm_unreachable("Unknown integer condition code!"); +    switch (PredR) { +    default: +      llvm_unreachable("Unknown integer condition code!");      case ICmpInst::ICMP_EQ: -      if (LHS->getOperand(0) == RHS->getOperand(0)) { -        // if LHSCst and RHSCst differ only by one bit: -        // (A == C1 || A == C2) -> (A | (C1 ^ C2)) == C2 -        assert(LHSCst->getValue().ule(LHSCst->getValue())); - -        APInt Xor = LHSCst->getValue() ^ RHSCst->getValue(); -        if (Xor.isPowerOf2()) { -          Value *Cst = Builder->getInt(Xor); -          Value *Or = Builder->CreateOr(LHS->getOperand(0), Cst); -          return Builder->CreateICmp(ICmpInst::ICMP_EQ, Or, RHSCst); -        } -      } - -      if (LHSCst == SubOne(RHSCst)) { -        // (X == 13 | X == 14) -> X-13 <u 2 -        Constant *AddCST = ConstantExpr::getNeg(LHSCst); -        Value *Add = Builder->CreateAdd(Val, AddCST, Val->getName()+".off"); -        AddCST = ConstantExpr::getSub(AddOne(RHSCst), LHSCst); -        return Builder->CreateICmpULT(Add, AddCST); -      } - -      break;                         // (X == 13 | X == 15) -> no change -    case ICmpInst::ICMP_UGT:         // (X == 13 | X u> 14) -> no change -    case ICmpInst::ICMP_SGT:         // (X == 13 | X s> 14) -> no change +      // Potential folds for this case should already be handled. +      break; +    case ICmpInst::ICMP_UGT: // (X == 13 | X u> 14) -> no change +    case ICmpInst::ICMP_SGT: // (X == 13 | X s> 14) -> no change        break; -    case ICmpInst::ICMP_NE:          // (X == 13 | X != 15) -> X != 15 -    case ICmpInst::ICMP_ULT:         // (X == 13 | X u< 15) -> X u< 15 -    case ICmpInst::ICMP_SLT:         // (X == 13 | X s< 15) -> X s< 15 +    case ICmpInst::ICMP_NE:  // (X == 13 | X != 15) -> X != 15 +    case ICmpInst::ICMP_ULT: // (X == 13 | X u< 15) -> X u< 15 +    case ICmpInst::ICMP_SLT: // (X == 13 | X s< 15) -> X s< 15        return RHS;      }      break;    case ICmpInst::ICMP_NE: -    switch (RHSCC) { -    default: llvm_unreachable("Unknown integer condition code!"); -    case ICmpInst::ICMP_EQ:          // (X != 13 | X == 15) -> X != 13 -    case ICmpInst::ICMP_UGT:         // (X != 13 | X u> 15) -> X != 13 -    case ICmpInst::ICMP_SGT:         // (X != 13 | X s> 15) -> X != 13 +    switch (PredR) { +    default: +      llvm_unreachable("Unknown integer condition code!"); +    case ICmpInst::ICMP_EQ:  // (X != 13 | X == 15) -> X != 13 +    case ICmpInst::ICMP_UGT: // (X != 13 | X u> 15) -> X != 13 +    case ICmpInst::ICMP_SGT: // (X != 13 | X s> 15) -> X != 13        return LHS; -    case ICmpInst::ICMP_NE:          // (X != 13 | X != 15) -> true -    case ICmpInst::ICMP_ULT:         // (X != 13 | X u< 15) -> true -    case ICmpInst::ICMP_SLT:         // (X != 13 | X s< 15) -> true +    case ICmpInst::ICMP_NE:  // (X != 13 | X != 15) -> true +    case ICmpInst::ICMP_ULT: // (X != 13 | X u< 15) -> true +    case ICmpInst::ICMP_SLT: // (X != 13 | X s< 15) -> true        return Builder->getTrue();      }    case ICmpInst::ICMP_ULT: -    switch (RHSCC) { -    default: llvm_unreachable("Unknown integer condition code!"); -    case ICmpInst::ICMP_EQ:         // (X u< 13 | X == 
14) -> no change +    switch (PredR) { +    default: +      llvm_unreachable("Unknown integer condition code!"); +    case ICmpInst::ICMP_EQ: // (X u< 13 | X == 14) -> no change        break; -    case ICmpInst::ICMP_UGT:        // (X u< 13 | X u> 15) -> (X-13) u> 2 -      // If RHSCst is [us]MAXINT, it is always false.  Not handling +    case ICmpInst::ICMP_UGT: // (X u< 13 | X u> 15) -> (X-13) u> 2 +      // If RHSC is [us]MAXINT, it is always false.  Not handling        // this can cause overflow. -      if (RHSCst->isMaxValue(false)) +      if (RHSC->isMaxValue(false))          return LHS; -      return insertRangeTest(Val, LHSCst->getValue(), RHSCst->getValue() + 1, +      return insertRangeTest(LHS0, LHSC->getValue(), RHSC->getValue() + 1,                               false, false); -    case ICmpInst::ICMP_SGT:        // (X u< 13 | X s> 15) -> no change -      break; -    case ICmpInst::ICMP_NE:         // (X u< 13 | X != 15) -> X != 15 -    case ICmpInst::ICMP_ULT:        // (X u< 13 | X u< 15) -> X u< 15 +    case ICmpInst::ICMP_NE:  // (X u< 13 | X != 15) -> X != 15 +    case ICmpInst::ICMP_ULT: // (X u< 13 | X u< 15) -> X u< 15        return RHS; -    case ICmpInst::ICMP_SLT:        // (X u< 13 | X s< 15) -> no change -      break;      }      break;    case ICmpInst::ICMP_SLT: -    switch (RHSCC) { -    default: llvm_unreachable("Unknown integer condition code!"); -    case ICmpInst::ICMP_EQ:         // (X s< 13 | X == 14) -> no change +    switch (PredR) { +    default: +      llvm_unreachable("Unknown integer condition code!"); +    case ICmpInst::ICMP_EQ: // (X s< 13 | X == 14) -> no change        break; -    case ICmpInst::ICMP_SGT:        // (X s< 13 | X s> 15) -> (X-13) s> 2 -      // If RHSCst is [us]MAXINT, it is always false.  Not handling +    case ICmpInst::ICMP_SGT: // (X s< 13 | X s> 15) -> (X-13) s> 2 +      // If RHSC is [us]MAXINT, it is always false.  Not handling        // this can cause overflow. 
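The insertRangeTest call above implements folds like (X u< 13 | X u> 15) -> (X-13) u> 2; a small exhaustive check of that particular instance over 8-bit values (the width is illustrative):

#include <cassert>
#include <cstdint>

int main() {
  for (unsigned X = 0; X < 256; ++X) {
    bool Orig = (uint8_t(X) < 13) || (uint8_t(X) > 15); // icmp ult | icmp ugt
    bool Range = uint8_t(X - 13) > 2;                   // (X - 13) u> 2
    assert(Orig == Range);
  }
}

The only values rejected by both original compares are 13, 14 and 15, and the subtraction maps exactly those onto 0..2.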
-      if (RHSCst->isMaxValue(true)) +      if (RHSC->isMaxValue(true))          return LHS; -      return insertRangeTest(Val, LHSCst->getValue(), RHSCst->getValue() + 1, -                             true, false); -    case ICmpInst::ICMP_UGT:        // (X s< 13 | X u> 15) -> no change -      break; -    case ICmpInst::ICMP_NE:         // (X s< 13 | X != 15) -> X != 15 -    case ICmpInst::ICMP_SLT:        // (X s< 13 | X s< 15) -> X s< 15 +      return insertRangeTest(LHS0, LHSC->getValue(), RHSC->getValue() + 1, true, +                             false); +    case ICmpInst::ICMP_NE:  // (X s< 13 | X != 15) -> X != 15 +    case ICmpInst::ICMP_SLT: // (X s< 13 | X s< 15) -> X s< 15        return RHS; -    case ICmpInst::ICMP_ULT:        // (X s< 13 | X u< 15) -> no change -      break;      }      break;    case ICmpInst::ICMP_UGT: -    switch (RHSCC) { -    default: llvm_unreachable("Unknown integer condition code!"); -    case ICmpInst::ICMP_EQ:         // (X u> 13 | X == 15) -> X u> 13 -    case ICmpInst::ICMP_UGT:        // (X u> 13 | X u> 15) -> X u> 13 +    switch (PredR) { +    default: +      llvm_unreachable("Unknown integer condition code!"); +    case ICmpInst::ICMP_EQ:  // (X u> 13 | X == 15) -> X u> 13 +    case ICmpInst::ICMP_UGT: // (X u> 13 | X u> 15) -> X u> 13        return LHS; -    case ICmpInst::ICMP_SGT:        // (X u> 13 | X s> 15) -> no change -      break; -    case ICmpInst::ICMP_NE:         // (X u> 13 | X != 15) -> true -    case ICmpInst::ICMP_ULT:        // (X u> 13 | X u< 15) -> true +    case ICmpInst::ICMP_NE:  // (X u> 13 | X != 15) -> true +    case ICmpInst::ICMP_ULT: // (X u> 13 | X u< 15) -> true        return Builder->getTrue(); -    case ICmpInst::ICMP_SLT:        // (X u> 13 | X s< 15) -> no change -      break;      }      break;    case ICmpInst::ICMP_SGT: -    switch (RHSCC) { -    default: llvm_unreachable("Unknown integer condition code!"); -    case ICmpInst::ICMP_EQ:         // (X s> 13 | X == 15) -> X > 13 -    case ICmpInst::ICMP_SGT:        // (X s> 13 | X s> 15) -> X > 13 +    switch (PredR) { +    default: +      llvm_unreachable("Unknown integer condition code!"); +    case ICmpInst::ICMP_EQ:  // (X s> 13 | X == 15) -> X > 13 +    case ICmpInst::ICMP_SGT: // (X s> 13 | X s> 15) -> X > 13        return LHS; -    case ICmpInst::ICMP_UGT:        // (X s> 13 | X u> 15) -> no change -      break; -    case ICmpInst::ICMP_NE:         // (X s> 13 | X != 15) -> true -    case ICmpInst::ICMP_SLT:        // (X s> 13 | X s< 15) -> true +    case ICmpInst::ICMP_NE:  // (X s> 13 | X != 15) -> true +    case ICmpInst::ICMP_SLT: // (X s> 13 | X s< 15) -> true        return Builder->getTrue(); -    case ICmpInst::ICMP_ULT:        // (X s> 13 | X u< 15) -> no change -      break;      }      break;    } @@ -2100,17 +2051,6 @@ Instruction *InstCombiner::visitOr(BinaryOperator &I) {    if (ConstantInt *RHS = dyn_cast<ConstantInt>(Op1)) {      ConstantInt *C1 = nullptr; Value *X = nullptr; -    // (X & C1) | C2 --> (X | C2) & (C1|C2) -    // iff (C1 & C2) == 0. 
-    if (match(Op0, m_And(m_Value(X), m_ConstantInt(C1))) && -        (RHS->getValue() & C1->getValue()) != 0 && -        Op0->hasOneUse()) { -      Value *Or = Builder->CreateOr(X, RHS); -      Or->takeName(Op0); -      return BinaryOperator::CreateAnd(Or, -                             Builder->getInt(RHS->getValue() | C1->getValue())); -    } -      // (X ^ C1) | C2 --> (X | C2) ^ (C1&~C2)      if (match(Op0, m_Xor(m_Value(X), m_ConstantInt(C1))) &&          Op0->hasOneUse()) { @@ -2119,45 +2059,51 @@ Instruction *InstCombiner::visitOr(BinaryOperator &I) {        return BinaryOperator::CreateXor(Or,                              Builder->getInt(C1->getValue() & ~RHS->getValue()));      } +  } +  if (isa<Constant>(Op1))      if (Instruction *FoldedLogic = foldOpWithConstantIntoOperand(I))        return FoldedLogic; -  }    // Given an OR instruction, check to see if this is a bswap.    if (Instruction *BSwap = MatchBSwap(I))      return BSwap; -  Value *A = nullptr, *B = nullptr; -  ConstantInt *C1 = nullptr, *C2 = nullptr; +  { +    Value *A; +    const APInt *C; +    // (X^C)|Y -> (X|Y)^C iff Y&C == 0 +    if (match(Op0, m_OneUse(m_Xor(m_Value(A), m_APInt(C)))) && +        MaskedValueIsZero(Op1, *C, 0, &I)) { +      Value *NOr = Builder->CreateOr(A, Op1); +      NOr->takeName(Op0); +      return BinaryOperator::CreateXor(NOr, +                                       cast<Instruction>(Op0)->getOperand(1)); +    } -  // (X^C)|Y -> (X|Y)^C iff Y&C == 0 -  if (Op0->hasOneUse() && -      match(Op0, m_Xor(m_Value(A), m_ConstantInt(C1))) && -      MaskedValueIsZero(Op1, C1->getValue(), 0, &I)) { -    Value *NOr = Builder->CreateOr(A, Op1); -    NOr->takeName(Op0); -    return BinaryOperator::CreateXor(NOr, C1); +    // Y|(X^C) -> (X|Y)^C iff Y&C == 0 +    if (match(Op1, m_OneUse(m_Xor(m_Value(A), m_APInt(C)))) && +        MaskedValueIsZero(Op0, *C, 0, &I)) { +      Value *NOr = Builder->CreateOr(A, Op0); +      NOr->takeName(Op0); +      return BinaryOperator::CreateXor(NOr, +                                       cast<Instruction>(Op1)->getOperand(1)); +    }    } -  // Y|(X^C) -> (X|Y)^C iff Y&C == 0 -  if (Op1->hasOneUse() && -      match(Op1, m_Xor(m_Value(A), m_ConstantInt(C1))) && -      MaskedValueIsZero(Op0, C1->getValue(), 0, &I)) { -    Value *NOr = Builder->CreateOr(A, Op0); -    NOr->takeName(Op0); -    return BinaryOperator::CreateXor(NOr, C1); -  } +  Value *A, *B;    // ((~A & B) | A) -> (A | B) -  if (match(Op0, m_And(m_Not(m_Value(A)), m_Value(B))) && -      match(Op1, m_Specific(A))) -    return BinaryOperator::CreateOr(A, B); +  if (match(Op0, m_c_And(m_Not(m_Specific(Op1)), m_Value(A)))) +    return BinaryOperator::CreateOr(A, Op1); +  if (match(Op1, m_c_And(m_Not(m_Specific(Op0)), m_Value(A)))) +    return BinaryOperator::CreateOr(Op0, A);    // ((A & B) | ~A) -> (~A | B) -  if (match(Op0, m_And(m_Value(A), m_Value(B))) && -      match(Op1, m_Not(m_Specific(A)))) -    return BinaryOperator::CreateOr(Builder->CreateNot(A), B); +  // The NOT is guaranteed to be in the RHS by complexity ordering. 
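The (X^C)|Y -> (X|Y)^C rewrite above is only valid when the bits of C are known to be zero in Y (the MaskedValueIsZero guard); a brute-force check of the underlying bit identity over 8-bit values, as a sketch:

#include <cassert>
#include <cstdint>

int main() {
  for (unsigned X = 0; X < 256; ++X)
    for (unsigned Y = 0; Y < 256; ++Y)
      for (unsigned C = 0; C < 256; ++C) {
        if ((Y & C) != 0)
          continue; // precondition: Y & C == 0 (MaskedValueIsZero)
        assert(uint8_t((X ^ C) | Y) == uint8_t((X | Y) ^ C));
      }
}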
+  if (match(Op1, m_Not(m_Value(A))) && +      match(Op0, m_c_And(m_Specific(A), m_Value(B)))) +    return BinaryOperator::CreateOr(Op1, B);    // (A & ~B) | (A ^ B) -> (A ^ B)    // (~B & A) | (A ^ B) -> (A ^ B) @@ -2177,8 +2123,8 @@ Instruction *InstCombiner::visitOr(BinaryOperator &I) {    if (match(Op0, m_And(m_Value(A), m_Value(C))) &&        match(Op1, m_And(m_Value(B), m_Value(D)))) {      Value *V1 = nullptr, *V2 = nullptr; -    C1 = dyn_cast<ConstantInt>(C); -    C2 = dyn_cast<ConstantInt>(D); +    ConstantInt *C1 = dyn_cast<ConstantInt>(C); +    ConstantInt *C2 = dyn_cast<ConstantInt>(D);      if (C1 && C2) {  // (A & C1)|(B & C2)        if ((C1->getValue() & C2->getValue()) == 0) {          // ((V | N) & C1) | (V & C2) --> (V|N) & (C1|C2) @@ -2403,6 +2349,7 @@ Instruction *InstCombiner::visitOr(BinaryOperator &I) {    // be simplified by a later pass either, so we try swapping the inner/outer    // ORs in the hopes that we'll be able to simplify it this way.    // (X|C) | V --> (X|V) | C +  ConstantInt *C1;    if (Op0->hasOneUse() && !isa<ConstantInt>(Op1) &&        match(Op0, m_Or(m_Value(A), m_ConstantInt(C1)))) {      Value *Inner = Builder->CreateOr(A, Op1); @@ -2493,23 +2440,22 @@ Instruction *InstCombiner::visitXor(BinaryOperator &I) {      }    } -  if (Constant *RHS = dyn_cast<Constant>(Op1)) { -    if (RHS->isAllOnesValue() && Op0->hasOneUse()) -      // xor (cmp A, B), true = not (cmp A, B) = !cmp A, B -      if (CmpInst *CI = dyn_cast<CmpInst>(Op0)) -        return CmpInst::Create(CI->getOpcode(), -                               CI->getInversePredicate(), -                               CI->getOperand(0), CI->getOperand(1)); +  // xor (cmp A, B), true = not (cmp A, B) = !cmp A, B +  ICmpInst::Predicate Pred; +  if (match(Op0, m_OneUse(m_Cmp(Pred, m_Value(), m_Value()))) && +      match(Op1, m_AllOnes())) { +    cast<CmpInst>(Op0)->setPredicate(CmpInst::getInversePredicate(Pred)); +    return replaceInstUsesWith(I, Op0);    } -  if (ConstantInt *RHS = dyn_cast<ConstantInt>(Op1)) { +  if (ConstantInt *RHSC = dyn_cast<ConstantInt>(Op1)) {      // fold (xor(zext(cmp)), 1) and (xor(sext(cmp)), -1) to ext(!cmp).      
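The xor-with-all-ones fold above (and the zext/sext variants handled next) rests on the fact that flipping the boolean result of a compare is the same as inverting its predicate; a tiny scalar illustration, with < and its inverse >= standing in for an arbitrary predicate pair:

#include <cassert>

int main() {
  for (int A = -3; A <= 3; ++A)
    for (int B = -3; B <= 3; ++B) {
      bool Cmp = A < B, Inv = A >= B;        // cmp and its inverse predicate
      assert((Cmp ^ true) == Inv);           // xor(cmp, true)     == !cmp
      assert((int(Cmp) ^ 1) == int(Inv));    // xor(zext(cmp), 1)  == zext(!cmp)
      assert((-int(Cmp) ^ -1) == -int(Inv)); // xor(sext(cmp), -1) == sext(!cmp)
    }
}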
if (CastInst *Op0C = dyn_cast<CastInst>(Op0)) {        if (CmpInst *CI = dyn_cast<CmpInst>(Op0C->getOperand(0))) {          if (CI->hasOneUse() && Op0C->hasOneUse()) {            Instruction::CastOps Opcode = Op0C->getOpcode();            if ((Opcode == Instruction::ZExt || Opcode == Instruction::SExt) && -              (RHS == ConstantExpr::getCast(Opcode, Builder->getTrue(), +              (RHSC == ConstantExpr::getCast(Opcode, Builder->getTrue(),                                              Op0C->getDestTy()))) {              CI->setPredicate(CI->getInversePredicate());              return CastInst::Create(Opcode, CI, Op0C->getType()); @@ -2520,26 +2466,23 @@ Instruction *InstCombiner::visitXor(BinaryOperator &I) {      if (BinaryOperator *Op0I = dyn_cast<BinaryOperator>(Op0)) {        // ~(c-X) == X-c-1 == X+(-c-1) -      if (Op0I->getOpcode() == Instruction::Sub && RHS->isAllOnesValue()) +      if (Op0I->getOpcode() == Instruction::Sub && RHSC->isAllOnesValue())          if (Constant *Op0I0C = dyn_cast<Constant>(Op0I->getOperand(0))) {            Constant *NegOp0I0C = ConstantExpr::getNeg(Op0I0C); -          Constant *ConstantRHS = ConstantExpr::getSub(NegOp0I0C, -                                      ConstantInt::get(I.getType(), 1)); -          return BinaryOperator::CreateAdd(Op0I->getOperand(1), ConstantRHS); +          return BinaryOperator::CreateAdd(Op0I->getOperand(1), +                                           SubOne(NegOp0I0C));          }        if (ConstantInt *Op0CI = dyn_cast<ConstantInt>(Op0I->getOperand(1))) {          if (Op0I->getOpcode() == Instruction::Add) {            // ~(X-c) --> (-c-1)-X -          if (RHS->isAllOnesValue()) { +          if (RHSC->isAllOnesValue()) {              Constant *NegOp0CI = ConstantExpr::getNeg(Op0CI); -            return BinaryOperator::CreateSub( -                           ConstantExpr::getSub(NegOp0CI, -                                      ConstantInt::get(I.getType(), 1)), -                                      Op0I->getOperand(0)); -          } else if (RHS->getValue().isSignBit()) { +            return BinaryOperator::CreateSub(SubOne(NegOp0CI), +                                             Op0I->getOperand(0)); +          } else if (RHSC->getValue().isSignBit()) {              // (X + C) ^ signbit -> (X + C + signbit) -            Constant *C = Builder->getInt(RHS->getValue() + Op0CI->getValue()); +            Constant *C = Builder->getInt(RHSC->getValue() + Op0CI->getValue());              return BinaryOperator::CreateAdd(Op0I->getOperand(0), C);            } @@ -2547,10 +2490,10 @@ Instruction *InstCombiner::visitXor(BinaryOperator &I) {            // (X|C1)^C2 -> X^(C1|C2) iff X&~C1 == 0            if (MaskedValueIsZero(Op0I->getOperand(0), Op0CI->getValue(),                                  0, &I)) { -            Constant *NewRHS = ConstantExpr::getOr(Op0CI, RHS); +            Constant *NewRHS = ConstantExpr::getOr(Op0CI, RHSC);              // Anything in both C1 and C2 is known to be zero, remove it from              // NewRHS. 
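The complement identities used above, ~(c-X) == X+(-c-1) and ~(X+c) == (-c-1)-X, are both instances of ~V == -V - 1 in two's complement; a brief exhaustive 8-bit check (the width is illustrative):

#include <cassert>
#include <cstdint>

int main() {
  for (int X = 0; X < 256; ++X)
    for (int C = 0; C < 256; ++C) {
      assert(uint8_t(~(C - X)) == uint8_t(X + (-C - 1))); // ~(c - X) == X + (-c - 1)
      assert(uint8_t(~(X + C)) == uint8_t((-C - 1) - X)); // ~(X + c) == (-c - 1) - X
    }
}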
-            Constant *CommonBits = ConstantExpr::getAnd(Op0CI, RHS); +            Constant *CommonBits = ConstantExpr::getAnd(Op0CI, RHSC);              NewRHS = ConstantExpr::getAnd(NewRHS,                                         ConstantExpr::getNot(CommonBits));              Worklist.Add(Op0I); @@ -2568,7 +2511,7 @@ Instruction *InstCombiner::visitXor(BinaryOperator &I) {                E1->getOpcode() == Instruction::Xor &&                (C1 = dyn_cast<ConstantInt>(E1->getOperand(1)))) {              // fold (C1 >> C2) ^ C3 -            ConstantInt *C2 = Op0CI, *C3 = RHS; +            ConstantInt *C2 = Op0CI, *C3 = RHSC;              APInt FoldConst = C1->getValue().lshr(C2->getValue());              FoldConst ^= C3->getValue();              // Prepare the two operands. @@ -2582,27 +2525,26 @@ Instruction *InstCombiner::visitXor(BinaryOperator &I) {          }        }      } +  } +  if (isa<Constant>(Op1))      if (Instruction *FoldedLogic = foldOpWithConstantIntoOperand(I))        return FoldedLogic; -  } -  BinaryOperator *Op1I = dyn_cast<BinaryOperator>(Op1); -  if (Op1I) { +  {      Value *A, *B; -    if (match(Op1I, m_Or(m_Value(A), m_Value(B)))) { -      if (A == Op0) {              // B^(B|A) == (A|B)^B -        Op1I->swapOperands(); -        I.swapOperands(); -        std::swap(Op0, Op1); -      } else if (B == Op0) {       // B^(A|B) == (A|B)^B +    if (match(Op1, m_OneUse(m_Or(m_Value(A), m_Value(B))))) { +      if (A == Op0) {                                      // A^(A|B) == A^(B|A) +        cast<BinaryOperator>(Op1)->swapOperands(); +        std::swap(A, B); +      } +      if (B == Op0) {                                      // A^(B|A) == (B|A)^A          I.swapOperands();     // Simplified below.          std::swap(Op0, Op1);        } -    } else if (match(Op1I, m_And(m_Value(A), m_Value(B))) && -               Op1I->hasOneUse()){ +    } else if (match(Op1, m_OneUse(m_And(m_Value(A), m_Value(B))))) {        if (A == Op0) {                                      // A^(A&B) -> A^(B&A) -        Op1I->swapOperands(); +        cast<BinaryOperator>(Op1)->swapOperands();          std::swap(A, B);        }        if (B == Op0) {                                      // A^(B&A) -> (B&A)^A @@ -2612,65 +2554,63 @@ Instruction *InstCombiner::visitXor(BinaryOperator &I) {      }    } -  BinaryOperator *Op0I = dyn_cast<BinaryOperator>(Op0); -  if (Op0I) { +  {      Value *A, *B; -    if (match(Op0I, m_Or(m_Value(A), m_Value(B))) && -        Op0I->hasOneUse()) { +    if (match(Op0, m_OneUse(m_Or(m_Value(A), m_Value(B))))) {        if (A == Op1)                                  // (B|A)^B == (A|B)^B          std::swap(A, B);        if (B == Op1)                                  // (A|B)^B == A & ~B          return BinaryOperator::CreateAnd(A, Builder->CreateNot(Op1)); -    } else if (match(Op0I, m_And(m_Value(A), m_Value(B))) && -               Op0I->hasOneUse()){ +    } else if (match(Op0, m_OneUse(m_And(m_Value(A), m_Value(B))))) {        if (A == Op1)                                        // (A&B)^A -> (B&A)^A          std::swap(A, B); +      const APInt *C;        if (B == Op1 &&                                      // (B&A)^A == ~B & A -          !isa<ConstantInt>(Op1)) {  // Canonical form is (B&C)^C +          !match(Op1, m_APInt(C))) {  // Canonical form is (B&C)^C          return BinaryOperator::CreateAnd(Builder->CreateNot(A), Op1);        }      }    } -  if (Op0I && Op1I) { +  {      Value *A, *B, *C, *D;      // (A & B)^(A | B) -> A ^ B -    if (match(Op0I, 
m_And(m_Value(A), m_Value(B))) && -        match(Op1I, m_Or(m_Value(C), m_Value(D)))) { +    if (match(Op0, m_And(m_Value(A), m_Value(B))) && +        match(Op1, m_Or(m_Value(C), m_Value(D)))) {        if ((A == C && B == D) || (A == D && B == C))          return BinaryOperator::CreateXor(A, B);      }      // (A | B)^(A & B) -> A ^ B -    if (match(Op0I, m_Or(m_Value(A), m_Value(B))) && -        match(Op1I, m_And(m_Value(C), m_Value(D)))) { +    if (match(Op0, m_Or(m_Value(A), m_Value(B))) && +        match(Op1, m_And(m_Value(C), m_Value(D)))) {        if ((A == C && B == D) || (A == D && B == C))          return BinaryOperator::CreateXor(A, B);      }      // (A | ~B) ^ (~A | B) -> A ^ B      // (~B | A) ^ (~A | B) -> A ^ B -    if (match(Op0I, m_c_Or(m_Value(A), m_Not(m_Value(B)))) && -        match(Op1I, m_Or(m_Not(m_Specific(A)), m_Specific(B)))) +    if (match(Op0, m_c_Or(m_Value(A), m_Not(m_Value(B)))) && +        match(Op1, m_Or(m_Not(m_Specific(A)), m_Specific(B))))        return BinaryOperator::CreateXor(A, B);      // (~A | B) ^ (A | ~B) -> A ^ B -    if (match(Op0I, m_Or(m_Not(m_Value(A)), m_Value(B))) && -        match(Op1I, m_Or(m_Specific(A), m_Not(m_Specific(B))))) { +    if (match(Op0, m_Or(m_Not(m_Value(A)), m_Value(B))) && +        match(Op1, m_Or(m_Specific(A), m_Not(m_Specific(B))))) {        return BinaryOperator::CreateXor(A, B);      }      // (A & ~B) ^ (~A & B) -> A ^ B      // (~B & A) ^ (~A & B) -> A ^ B -    if (match(Op0I, m_c_And(m_Value(A), m_Not(m_Value(B)))) && -        match(Op1I, m_And(m_Not(m_Specific(A)), m_Specific(B)))) +    if (match(Op0, m_c_And(m_Value(A), m_Not(m_Value(B)))) && +        match(Op1, m_And(m_Not(m_Specific(A)), m_Specific(B))))        return BinaryOperator::CreateXor(A, B);      // (~A & B) ^ (A & ~B) -> A ^ B -    if (match(Op0I, m_And(m_Not(m_Value(A)), m_Value(B))) && -        match(Op1I, m_And(m_Specific(A), m_Not(m_Specific(B))))) { +    if (match(Op0, m_And(m_Not(m_Value(A)), m_Value(B))) && +        match(Op1, m_And(m_Specific(A), m_Not(m_Specific(B))))) {        return BinaryOperator::CreateXor(A, B);      }      // (A ^ C)^(A | B) -> ((~A) & B) ^ C -    if (match(Op0I, m_Xor(m_Value(D), m_Value(C))) && -        match(Op1I, m_Or(m_Value(A), m_Value(B)))) { +    if (match(Op0, m_Xor(m_Value(D), m_Value(C))) && +        match(Op1, m_Or(m_Value(A), m_Value(B)))) {        if (D == A)          return BinaryOperator::CreateXor(              Builder->CreateAnd(Builder->CreateNot(A), B), C); @@ -2679,8 +2619,8 @@ Instruction *InstCombiner::visitXor(BinaryOperator &I) {              Builder->CreateAnd(Builder->CreateNot(B), A), C);      }      // (A | B)^(A ^ C) -> ((~A) & B) ^ C -    if (match(Op0I, m_Or(m_Value(A), m_Value(B))) && -        match(Op1I, m_Xor(m_Value(D), m_Value(C)))) { +    if (match(Op0, m_Or(m_Value(A), m_Value(B))) && +        match(Op1, m_Xor(m_Value(D), m_Value(C)))) {        if (D == A)          return BinaryOperator::CreateXor(              Builder->CreateAnd(Builder->CreateNot(A), B), C); @@ -2689,12 +2629,12 @@ Instruction *InstCombiner::visitXor(BinaryOperator &I) {              Builder->CreateAnd(Builder->CreateNot(B), A), C);      }      // (A & B) ^ (A ^ B) -> (A | B) -    if (match(Op0I, m_And(m_Value(A), m_Value(B))) && -        match(Op1I, m_Xor(m_Specific(A), m_Specific(B)))) +    if (match(Op0, m_And(m_Value(A), m_Value(B))) && +        match(Op1, m_c_Xor(m_Specific(A), m_Specific(B))))        return BinaryOperator::CreateOr(A, B);      // (A ^ B) ^ (A & B) -> (A | B) -    if (match(Op0I, 
m_Xor(m_Value(A), m_Value(B))) && -        match(Op1I, m_And(m_Specific(A), m_Specific(B)))) +    if (match(Op0, m_Xor(m_Value(A), m_Value(B))) && +        match(Op1, m_c_And(m_Specific(A), m_Specific(B))))        return BinaryOperator::CreateOr(A, B);    } diff --git a/lib/Transforms/InstCombine/InstCombineCalls.cpp b/lib/Transforms/InstCombine/InstCombineCalls.cpp index 2ef82ba3ed8c..69484f47223f 100644 --- a/lib/Transforms/InstCombine/InstCombineCalls.cpp +++ b/lib/Transforms/InstCombine/InstCombineCalls.cpp @@ -60,6 +60,12 @@ using namespace PatternMatch;  STATISTIC(NumSimplified, "Number of library calls simplified"); +static cl::opt<unsigned> UnfoldElementAtomicMemcpyMaxElements( +    "unfold-element-atomic-memcpy-max-elements", +    cl::init(16), +    cl::desc("Maximum number of elements in atomic memcpy the optimizer is " +             "allowed to unfold")); +  /// Return the specified type promoted as it would be to pass though a va_arg  /// area.  static Type *getPromotedType(Type *Ty) { @@ -70,27 +76,6 @@ static Type *getPromotedType(Type *Ty) {    return Ty;  } -/// Given an aggregate type which ultimately holds a single scalar element, -/// like {{{type}}} or [1 x type], return type. -static Type *reduceToSingleValueType(Type *T) { -  while (!T->isSingleValueType()) { -    if (StructType *STy = dyn_cast<StructType>(T)) { -      if (STy->getNumElements() == 1) -        T = STy->getElementType(0); -      else -        break; -    } else if (ArrayType *ATy = dyn_cast<ArrayType>(T)) { -      if (ATy->getNumElements() == 1) -        T = ATy->getElementType(); -      else -        break; -    } else -      break; -  } - -  return T; -} -  /// Return a constant boolean vector that has true elements in all positions  /// where the input constant data vector has an element with the sign bit set.  static Constant *getNegativeIsTrueBoolVec(ConstantDataVector *V) { @@ -108,6 +93,78 @@ static Constant *getNegativeIsTrueBoolVec(ConstantDataVector *V) {    return ConstantVector::get(BoolVec);  } +Instruction * +InstCombiner::SimplifyElementAtomicMemCpy(ElementAtomicMemCpyInst *AMI) { +  // Try to unfold this intrinsic into sequence of explicit atomic loads and +  // stores. +  // First check that number of elements is compile time constant. +  auto *NumElementsCI = dyn_cast<ConstantInt>(AMI->getNumElements()); +  if (!NumElementsCI) +    return nullptr; + +  // Check that there are not too many elements. +  uint64_t NumElements = NumElementsCI->getZExtValue(); +  if (NumElements >= UnfoldElementAtomicMemcpyMaxElements) +    return nullptr; + +  // Don't unfold into illegal integers +  uint64_t ElementSizeInBytes = AMI->getElementSizeInBytes() * 8; +  if (!getDataLayout().isLegalInteger(ElementSizeInBytes)) +    return nullptr; + +  // Cast source and destination to the correct type. Intrinsic input arguments +  // are usually represented as i8*. +  // Often operands will be explicitly casted to i8* and we can just strip +  // those casts instead of inserting new ones. However it's easier to rely on +  // other InstCombine rules which will cover trivial cases anyway. 
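The (A & B)^(A | B) and (A ^ B)^(A & B) folds at the end of the visitXor changes above reduce to simple per-bit identities; a quick exhaustive 8-bit check (illustrative width):

#include <cassert>
#include <cstdint>

int main() {
  for (unsigned A = 0; A < 256; ++A)
    for (unsigned B = 0; B < 256; ++B) {
      assert(uint8_t((A & B) ^ (A | B)) == uint8_t(A ^ B)); // (A & B)^(A | B) -> A ^ B
      assert(uint8_t((A ^ B) ^ (A & B)) == uint8_t(A | B)); // (A ^ B)^(A & B) -> A | B
    }
}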
+  Value *Src = AMI->getRawSource(); +  Value *Dst = AMI->getRawDest(); +  Type *ElementPointerType = Type::getIntNPtrTy( +      AMI->getContext(), ElementSizeInBytes, Src->getType()->getPointerAddressSpace()); + +  Value *SrcCasted = Builder->CreatePointerCast(Src, ElementPointerType, +                                                "memcpy_unfold.src_casted"); +  Value *DstCasted = Builder->CreatePointerCast(Dst, ElementPointerType, +                                                "memcpy_unfold.dst_casted"); + +  for (uint64_t i = 0; i < NumElements; ++i) { +    // Get current element addresses +    ConstantInt *ElementIdxCI = +        ConstantInt::get(AMI->getContext(), APInt(64, i)); +    Value *SrcElementAddr = +        Builder->CreateGEP(SrcCasted, ElementIdxCI, "memcpy_unfold.src_addr"); +    Value *DstElementAddr = +        Builder->CreateGEP(DstCasted, ElementIdxCI, "memcpy_unfold.dst_addr"); + +    // Load from the source. Transfer alignment information and mark load as +    // unordered atomic. +    LoadInst *Load = Builder->CreateLoad(SrcElementAddr, "memcpy_unfold.val"); +    Load->setOrdering(AtomicOrdering::Unordered); +    // We know alignment of the first element. It is also guaranteed by the +    // verifier that element size is less or equal than first element alignment +    // and both of this values are powers of two. +    // This means that all subsequent accesses are at least element size +    // aligned. +    // TODO: We can infer better alignment but there is no evidence that this +    // will matter. +    Load->setAlignment(i == 0 ? AMI->getSrcAlignment() +                              : AMI->getElementSizeInBytes()); +    Load->setDebugLoc(AMI->getDebugLoc()); + +    // Store loaded value via unordered atomic store. +    StoreInst *Store = Builder->CreateStore(Load, DstElementAddr); +    Store->setOrdering(AtomicOrdering::Unordered); +    Store->setAlignment(i == 0 ? AMI->getDstAlignment() +                               : AMI->getElementSizeInBytes()); +    Store->setDebugLoc(AMI->getDebugLoc()); +  } + +  // Set the number of elements of the copy to 0, it will be deleted on the +  // next iteration. +  AMI->setNumElements(Constant::getNullValue(NumElementsCI->getType())); +  return AMI; +} +  Instruction *InstCombiner::SimplifyMemTransfer(MemIntrinsic *MI) {    unsigned DstAlign = getKnownAlignment(MI->getArgOperand(0), DL, MI, &AC, &DT);    unsigned SrcAlign = getKnownAlignment(MI->getArgOperand(1), DL, MI, &AC, &DT); @@ -144,41 +201,19 @@ Instruction *InstCombiner::SimplifyMemTransfer(MemIntrinsic *MI) {    Type *NewSrcPtrTy = PointerType::get(IntType, SrcAddrSp);    Type *NewDstPtrTy = PointerType::get(IntType, DstAddrSp); -  // Memcpy forces the use of i8* for the source and destination.  That means -  // that if you're using memcpy to move one double around, you'll get a cast -  // from double* to i8*.  We'd much rather use a double load+store rather than -  // an i64 load+store, here because this improves the odds that the source or -  // dest address will be promotable.  See if we can find a better type than the -  // integer datatype. -  Value *StrippedDest = MI->getArgOperand(0)->stripPointerCasts(); +  // If the memcpy has metadata describing the members, see if we can get the +  // TBAA tag describing our copy.    
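The element-by-element unfolding performed by SimplifyElementAtomicMemCpy above emits one unordered atomic load and one unordered atomic store per element; roughly, for a 4-element copy of 32-bit elements, the result behaves like the following standalone sketch (relaxed std::atomic accesses stand in for LLVM's unordered atomics; the element count, element width and function name are illustrative):

#include <atomic>
#include <cstdint>

void unfolded_copy(std::atomic<uint32_t> *Dst, const std::atomic<uint32_t> *Src) {
  for (uint64_t I = 0; I != 4; ++I) {
    // Per-element load, no ordering guarantees beyond atomicity ...
    uint32_t V = Src[I].load(std::memory_order_relaxed);
    // ... and the matching per-element store.
    Dst[I].store(V, std::memory_order_relaxed);
  }
}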
MDNode *CopyMD = nullptr; -  if (StrippedDest != MI->getArgOperand(0)) { -    Type *SrcETy = cast<PointerType>(StrippedDest->getType()) -                                    ->getElementType(); -    if (SrcETy->isSized() && DL.getTypeStoreSize(SrcETy) == Size) { -      // The SrcETy might be something like {{{double}}} or [1 x double].  Rip -      // down through these levels if so. -      SrcETy = reduceToSingleValueType(SrcETy); - -      if (SrcETy->isSingleValueType()) { -        NewSrcPtrTy = PointerType::get(SrcETy, SrcAddrSp); -        NewDstPtrTy = PointerType::get(SrcETy, DstAddrSp); - -        // If the memcpy has metadata describing the members, see if we can -        // get the TBAA tag describing our copy. -        if (MDNode *M = MI->getMetadata(LLVMContext::MD_tbaa_struct)) { -          if (M->getNumOperands() == 3 && M->getOperand(0) && -              mdconst::hasa<ConstantInt>(M->getOperand(0)) && -              mdconst::extract<ConstantInt>(M->getOperand(0))->isNullValue() && -              M->getOperand(1) && -              mdconst::hasa<ConstantInt>(M->getOperand(1)) && -              mdconst::extract<ConstantInt>(M->getOperand(1))->getValue() == -                  Size && -              M->getOperand(2) && isa<MDNode>(M->getOperand(2))) -            CopyMD = cast<MDNode>(M->getOperand(2)); -        } -      } -    } +  if (MDNode *M = MI->getMetadata(LLVMContext::MD_tbaa_struct)) { +    if (M->getNumOperands() == 3 && M->getOperand(0) && +        mdconst::hasa<ConstantInt>(M->getOperand(0)) && +        mdconst::extract<ConstantInt>(M->getOperand(0))->isNullValue() && +        M->getOperand(1) && +        mdconst::hasa<ConstantInt>(M->getOperand(1)) && +        mdconst::extract<ConstantInt>(M->getOperand(1))->getValue() == +        Size && +        M->getOperand(2) && isa<MDNode>(M->getOperand(2))) +      CopyMD = cast<MDNode>(M->getOperand(2));    }    // If the memcpy/memmove provides better alignment info than we can @@ -510,6 +545,131 @@ static Value *simplifyX86varShift(const IntrinsicInst &II,    return Builder.CreateAShr(Vec, ShiftVec);  } +static Value *simplifyX86muldq(const IntrinsicInst &II, +                               InstCombiner::BuilderTy &Builder) { +  Value *Arg0 = II.getArgOperand(0); +  Value *Arg1 = II.getArgOperand(1); +  Type *ResTy = II.getType(); +  assert(Arg0->getType()->getScalarSizeInBits() == 32 && +         Arg1->getType()->getScalarSizeInBits() == 32 && +         ResTy->getScalarSizeInBits() == 64 && "Unexpected muldq/muludq types"); + +  // muldq/muludq(undef, undef) -> zero (matches generic mul behavior) +  if (isa<UndefValue>(Arg0) || isa<UndefValue>(Arg1)) +    return ConstantAggregateZero::get(ResTy); + +  // Constant folding. 
+  // PMULDQ  = (mul(vXi64 sext(shuffle<0,2,..>(Arg0)), +  //                vXi64 sext(shuffle<0,2,..>(Arg1)))) +  // PMULUDQ = (mul(vXi64 zext(shuffle<0,2,..>(Arg0)), +  //                vXi64 zext(shuffle<0,2,..>(Arg1)))) +  if (!isa<Constant>(Arg0) || !isa<Constant>(Arg1)) +    return nullptr; + +  unsigned NumElts = ResTy->getVectorNumElements(); +  assert(Arg0->getType()->getVectorNumElements() == (2 * NumElts) && +         Arg1->getType()->getVectorNumElements() == (2 * NumElts) && +         "Unexpected muldq/muludq types"); + +  unsigned IntrinsicID = II.getIntrinsicID(); +  bool IsSigned = (Intrinsic::x86_sse41_pmuldq == IntrinsicID || +                   Intrinsic::x86_avx2_pmul_dq == IntrinsicID || +                   Intrinsic::x86_avx512_pmul_dq_512 == IntrinsicID); + +  SmallVector<unsigned, 16> ShuffleMask; +  for (unsigned i = 0; i != NumElts; ++i) +    ShuffleMask.push_back(i * 2); + +  auto *LHS = Builder.CreateShuffleVector(Arg0, Arg0, ShuffleMask); +  auto *RHS = Builder.CreateShuffleVector(Arg1, Arg1, ShuffleMask); + +  if (IsSigned) { +    LHS = Builder.CreateSExt(LHS, ResTy); +    RHS = Builder.CreateSExt(RHS, ResTy); +  } else { +    LHS = Builder.CreateZExt(LHS, ResTy); +    RHS = Builder.CreateZExt(RHS, ResTy); +  } + +  return Builder.CreateMul(LHS, RHS); +} + +static Value *simplifyX86pack(IntrinsicInst &II, InstCombiner &IC, +                              InstCombiner::BuilderTy &Builder, bool IsSigned) { +  Value *Arg0 = II.getArgOperand(0); +  Value *Arg1 = II.getArgOperand(1); +  Type *ResTy = II.getType(); + +  // Fast all undef handling. +  if (isa<UndefValue>(Arg0) && isa<UndefValue>(Arg1)) +    return UndefValue::get(ResTy); + +  Type *ArgTy = Arg0->getType(); +  unsigned NumLanes = ResTy->getPrimitiveSizeInBits() / 128; +  unsigned NumDstElts = ResTy->getVectorNumElements(); +  unsigned NumSrcElts = ArgTy->getVectorNumElements(); +  assert(NumDstElts == (2 * NumSrcElts) && "Unexpected packing types"); + +  unsigned NumDstEltsPerLane = NumDstElts / NumLanes; +  unsigned NumSrcEltsPerLane = NumSrcElts / NumLanes; +  unsigned DstScalarSizeInBits = ResTy->getScalarSizeInBits(); +  assert(ArgTy->getScalarSizeInBits() == (2 * DstScalarSizeInBits) && +         "Unexpected packing types"); + +  // Constant folding. +  auto *Cst0 = dyn_cast<Constant>(Arg0); +  auto *Cst1 = dyn_cast<Constant>(Arg1); +  if (!Cst0 || !Cst1) +    return nullptr; + +  SmallVector<Constant *, 32> Vals; +  for (unsigned Lane = 0; Lane != NumLanes; ++Lane) { +    for (unsigned Elt = 0; Elt != NumDstEltsPerLane; ++Elt) { +      unsigned SrcIdx = Lane * NumSrcEltsPerLane + Elt % NumSrcEltsPerLane; +      auto *Cst = (Elt >= NumSrcEltsPerLane) ? Cst1 : Cst0; +      auto *COp = Cst->getAggregateElement(SrcIdx); +      if (COp && isa<UndefValue>(COp)) { +        Vals.push_back(UndefValue::get(ResTy->getScalarType())); +        continue; +      } + +      auto *CInt = dyn_cast_or_null<ConstantInt>(COp); +      if (!CInt) +        return nullptr; + +      APInt Val = CInt->getValue(); +      assert(Val.getBitWidth() == ArgTy->getScalarSizeInBits() && +             "Unexpected constant bitwidth"); + +      if (IsSigned) { +        // PACKSS: Truncate signed value with signed saturation. +        // Source values less than dst minint are saturated to minint. +        // Source values greater than dst maxint are saturated to maxint. 
+        if (Val.isSignedIntN(DstScalarSizeInBits)) +          Val = Val.trunc(DstScalarSizeInBits); +        else if (Val.isNegative()) +          Val = APInt::getSignedMinValue(DstScalarSizeInBits); +        else +          Val = APInt::getSignedMaxValue(DstScalarSizeInBits); +      } else { +        // PACKUS: Truncate signed value with unsigned saturation. +        // Source values less than zero are saturated to zero. +        // Source values greater than dst maxuint are saturated to maxuint. +        if (Val.isIntN(DstScalarSizeInBits)) +          Val = Val.trunc(DstScalarSizeInBits); +        else if (Val.isNegative()) +          Val = APInt::getNullValue(DstScalarSizeInBits); +        else +          Val = APInt::getAllOnesValue(DstScalarSizeInBits); +      } + +      Vals.push_back(ConstantInt::get(ResTy->getScalarType(), Val)); +    } +  } + +  return ConstantVector::get(Vals); +} +  static Value *simplifyX86movmsk(const IntrinsicInst &II,                                  InstCombiner::BuilderTy &Builder) {    Value *Arg = II.getArgOperand(0); @@ -1330,6 +1490,27 @@ static bool simplifyX86MaskedStore(IntrinsicInst &II, InstCombiner &IC) {    return true;  } +// Constant fold llvm.amdgcn.fmed3 intrinsics for standard inputs. +// +// A single NaN input is folded to minnum, so we rely on that folding for +// handling NaNs. +static APFloat fmed3AMDGCN(const APFloat &Src0, const APFloat &Src1, +                           const APFloat &Src2) { +  APFloat Max3 = maxnum(maxnum(Src0, Src1), Src2); + +  APFloat::cmpResult Cmp0 = Max3.compare(Src0); +  assert(Cmp0 != APFloat::cmpUnordered && "nans handled separately"); +  if (Cmp0 == APFloat::cmpEqual) +    return maxnum(Src1, Src2); + +  APFloat::cmpResult Cmp1 = Max3.compare(Src1); +  assert(Cmp1 != APFloat::cmpUnordered && "nans handled separately"); +  if (Cmp1 == APFloat::cmpEqual) +    return maxnum(Src0, Src2); + +  return maxnum(Src0, Src1); +} +  // Returns true iff the 2 intrinsics have the same operands, limiting the  // comparison to the first NumOperands.  static bool haveSameOperands(const IntrinsicInst &I, const IntrinsicInst &E, @@ -1373,6 +1554,254 @@ static bool removeTriviallyEmptyRange(IntrinsicInst &I, unsigned StartID,    return false;  } +// Convert NVVM intrinsics to target-generic LLVM code where possible. +static Instruction *SimplifyNVVMIntrinsic(IntrinsicInst *II, InstCombiner &IC) { +  // Each NVVM intrinsic we can simplify can be replaced with one of: +  // +  //  * an LLVM intrinsic, +  //  * an LLVM cast operation, +  //  * an LLVM binary operation, or +  //  * ad-hoc LLVM IR for the particular operation. + +  // Some transformations are only valid when the module's +  // flush-denormals-to-zero (ftz) setting is true/false, whereas other +  // transformations are valid regardless of the module's ftz setting. +  enum FtzRequirementTy { +    FTZ_Any,       // Any ftz setting is ok. +    FTZ_MustBeOn,  // Transformation is valid only if ftz is on. +    FTZ_MustBeOff, // Transformation is valid only if ftz is off. +  }; +  // Classes of NVVM intrinsics that can't be replaced one-to-one with a +  // target-generic intrinsic, cast op, or binary op but that we can nonetheless +  // simplify. +  enum SpecialCase { +    SPC_Reciprocal, +  }; + +  // SimplifyAction is a poor-man's variant (plus an additional flag) that +  // represents how to replace an NVVM intrinsic with target-generic LLVM IR. +  struct SimplifyAction { +    // Invariant: At most one of these Optionals has a value. 
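The saturation performed per element by the simplifyX86pack constant folding above follows the usual PACKSS/PACKUS rules; scalar sketches for the 16-to-8-bit case (the helper names are illustrative):

#include <algorithm>
#include <cstdint>

// PACKSS-style: truncate a signed value with signed saturation to [-128, 127].
int8_t pack_ss(int16_t V) {
  return int8_t(std::min<int>(std::max<int>(V, -128), 127));
}

// PACKUS-style: truncate a signed value with unsigned saturation to [0, 255].
uint8_t pack_us(int16_t V) {
  return uint8_t(std::min<int>(std::max<int>(V, 0), 255));
}

For example, pack_ss(300) yields 127 and pack_us(-5) yields 0, matching the saturate-to-maxint and saturate-to-zero branches above.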
+    Optional<Intrinsic::ID> IID; +    Optional<Instruction::CastOps> CastOp; +    Optional<Instruction::BinaryOps> BinaryOp; +    Optional<SpecialCase> Special; + +    FtzRequirementTy FtzRequirement = FTZ_Any; + +    SimplifyAction() = default; + +    SimplifyAction(Intrinsic::ID IID, FtzRequirementTy FtzReq) +        : IID(IID), FtzRequirement(FtzReq) {} + +    // Cast operations don't have anything to do with FTZ, so we skip that +    // argument. +    SimplifyAction(Instruction::CastOps CastOp) : CastOp(CastOp) {} + +    SimplifyAction(Instruction::BinaryOps BinaryOp, FtzRequirementTy FtzReq) +        : BinaryOp(BinaryOp), FtzRequirement(FtzReq) {} + +    SimplifyAction(SpecialCase Special, FtzRequirementTy FtzReq) +        : Special(Special), FtzRequirement(FtzReq) {} +  }; + +  // Try to generate a SimplifyAction describing how to replace our +  // IntrinsicInstr with target-generic LLVM IR. +  const SimplifyAction Action = [II]() -> SimplifyAction { +    switch (II->getIntrinsicID()) { + +    // NVVM intrinsics that map directly to LLVM intrinsics. +    case Intrinsic::nvvm_ceil_d: +      return {Intrinsic::ceil, FTZ_Any}; +    case Intrinsic::nvvm_ceil_f: +      return {Intrinsic::ceil, FTZ_MustBeOff}; +    case Intrinsic::nvvm_ceil_ftz_f: +      return {Intrinsic::ceil, FTZ_MustBeOn}; +    case Intrinsic::nvvm_fabs_d: +      return {Intrinsic::fabs, FTZ_Any}; +    case Intrinsic::nvvm_fabs_f: +      return {Intrinsic::fabs, FTZ_MustBeOff}; +    case Intrinsic::nvvm_fabs_ftz_f: +      return {Intrinsic::fabs, FTZ_MustBeOn}; +    case Intrinsic::nvvm_floor_d: +      return {Intrinsic::floor, FTZ_Any}; +    case Intrinsic::nvvm_floor_f: +      return {Intrinsic::floor, FTZ_MustBeOff}; +    case Intrinsic::nvvm_floor_ftz_f: +      return {Intrinsic::floor, FTZ_MustBeOn}; +    case Intrinsic::nvvm_fma_rn_d: +      return {Intrinsic::fma, FTZ_Any}; +    case Intrinsic::nvvm_fma_rn_f: +      return {Intrinsic::fma, FTZ_MustBeOff}; +    case Intrinsic::nvvm_fma_rn_ftz_f: +      return {Intrinsic::fma, FTZ_MustBeOn}; +    case Intrinsic::nvvm_fmax_d: +      return {Intrinsic::maxnum, FTZ_Any}; +    case Intrinsic::nvvm_fmax_f: +      return {Intrinsic::maxnum, FTZ_MustBeOff}; +    case Intrinsic::nvvm_fmax_ftz_f: +      return {Intrinsic::maxnum, FTZ_MustBeOn}; +    case Intrinsic::nvvm_fmin_d: +      return {Intrinsic::minnum, FTZ_Any}; +    case Intrinsic::nvvm_fmin_f: +      return {Intrinsic::minnum, FTZ_MustBeOff}; +    case Intrinsic::nvvm_fmin_ftz_f: +      return {Intrinsic::minnum, FTZ_MustBeOn}; +    case Intrinsic::nvvm_round_d: +      return {Intrinsic::round, FTZ_Any}; +    case Intrinsic::nvvm_round_f: +      return {Intrinsic::round, FTZ_MustBeOff}; +    case Intrinsic::nvvm_round_ftz_f: +      return {Intrinsic::round, FTZ_MustBeOn}; +    case Intrinsic::nvvm_sqrt_rn_d: +      return {Intrinsic::sqrt, FTZ_Any}; +    case Intrinsic::nvvm_sqrt_f: +      // nvvm_sqrt_f is a special case.  For  most intrinsics, foo_ftz_f is the +      // ftz version, and foo_f is the non-ftz version.  But nvvm_sqrt_f adopts +      // the ftz-ness of the surrounding code.  sqrt_rn_f and sqrt_rn_ftz_f are +      // the versions with explicit ftz-ness. 
+      return {Intrinsic::sqrt, FTZ_Any}; +    case Intrinsic::nvvm_sqrt_rn_f: +      return {Intrinsic::sqrt, FTZ_MustBeOff}; +    case Intrinsic::nvvm_sqrt_rn_ftz_f: +      return {Intrinsic::sqrt, FTZ_MustBeOn}; +    case Intrinsic::nvvm_trunc_d: +      return {Intrinsic::trunc, FTZ_Any}; +    case Intrinsic::nvvm_trunc_f: +      return {Intrinsic::trunc, FTZ_MustBeOff}; +    case Intrinsic::nvvm_trunc_ftz_f: +      return {Intrinsic::trunc, FTZ_MustBeOn}; + +    // NVVM intrinsics that map to LLVM cast operations. +    // +    // Note that llvm's target-generic conversion operators correspond to the rz +    // (round to zero) versions of the nvvm conversion intrinsics, even though +    // most everything else here uses the rn (round to nearest even) nvvm ops. +    case Intrinsic::nvvm_d2i_rz: +    case Intrinsic::nvvm_f2i_rz: +    case Intrinsic::nvvm_d2ll_rz: +    case Intrinsic::nvvm_f2ll_rz: +      return {Instruction::FPToSI}; +    case Intrinsic::nvvm_d2ui_rz: +    case Intrinsic::nvvm_f2ui_rz: +    case Intrinsic::nvvm_d2ull_rz: +    case Intrinsic::nvvm_f2ull_rz: +      return {Instruction::FPToUI}; +    case Intrinsic::nvvm_i2d_rz: +    case Intrinsic::nvvm_i2f_rz: +    case Intrinsic::nvvm_ll2d_rz: +    case Intrinsic::nvvm_ll2f_rz: +      return {Instruction::SIToFP}; +    case Intrinsic::nvvm_ui2d_rz: +    case Intrinsic::nvvm_ui2f_rz: +    case Intrinsic::nvvm_ull2d_rz: +    case Intrinsic::nvvm_ull2f_rz: +      return {Instruction::UIToFP}; + +    // NVVM intrinsics that map to LLVM binary ops. +    case Intrinsic::nvvm_add_rn_d: +      return {Instruction::FAdd, FTZ_Any}; +    case Intrinsic::nvvm_add_rn_f: +      return {Instruction::FAdd, FTZ_MustBeOff}; +    case Intrinsic::nvvm_add_rn_ftz_f: +      return {Instruction::FAdd, FTZ_MustBeOn}; +    case Intrinsic::nvvm_mul_rn_d: +      return {Instruction::FMul, FTZ_Any}; +    case Intrinsic::nvvm_mul_rn_f: +      return {Instruction::FMul, FTZ_MustBeOff}; +    case Intrinsic::nvvm_mul_rn_ftz_f: +      return {Instruction::FMul, FTZ_MustBeOn}; +    case Intrinsic::nvvm_div_rn_d: +      return {Instruction::FDiv, FTZ_Any}; +    case Intrinsic::nvvm_div_rn_f: +      return {Instruction::FDiv, FTZ_MustBeOff}; +    case Intrinsic::nvvm_div_rn_ftz_f: +      return {Instruction::FDiv, FTZ_MustBeOn}; + +    // The remainder of cases are NVVM intrinsics that map to LLVM idioms, but +    // need special handling. +    // +    // We seem to be mising intrinsics for rcp.approx.{ftz.}f32, which is just +    // as well. +    case Intrinsic::nvvm_rcp_rn_d: +      return {SPC_Reciprocal, FTZ_Any}; +    case Intrinsic::nvvm_rcp_rn_f: +      return {SPC_Reciprocal, FTZ_MustBeOff}; +    case Intrinsic::nvvm_rcp_rn_ftz_f: +      return {SPC_Reciprocal, FTZ_MustBeOn}; + +    // We do not currently simplify intrinsics that give an approximate answer. +    // These include: +    // +    //   - nvvm_cos_approx_{f,ftz_f} +    //   - nvvm_ex2_approx_{d,f,ftz_f} +    //   - nvvm_lg2_approx_{d,f,ftz_f} +    //   - nvvm_sin_approx_{f,ftz_f} +    //   - nvvm_sqrt_approx_{f,ftz_f} +    //   - nvvm_rsqrt_approx_{d,f,ftz_f} +    //   - nvvm_div_approx_{ftz_d,ftz_f,f} +    //   - nvvm_rcp_approx_ftz_d +    // +    // Ideally we'd encode them as e.g. "fast call @llvm.cos", where "fast" +    // means that fastmath is enabled in the intrinsic.  Unfortunately only +    // binary operators (currently) have a fastmath bit in SelectionDAG, so this +    // information gets lost and we can't select on it. 
+    // +    // TODO: div and rcp are lowered to a binary op, so these we could in theory +    // lower them to "fast fdiv". + +    default: +      return {}; +    } +  }(); + +  // If Action.FtzRequirementTy is not satisfied by the module's ftz state, we +  // can bail out now.  (Notice that in the case that IID is not an NVVM +  // intrinsic, we don't have to look up any module metadata, as +  // FtzRequirementTy will be FTZ_Any.) +  if (Action.FtzRequirement != FTZ_Any) { +    bool FtzEnabled = +        II->getFunction()->getFnAttribute("nvptx-f32ftz").getValueAsString() == +        "true"; + +    if (FtzEnabled != (Action.FtzRequirement == FTZ_MustBeOn)) +      return nullptr; +  } + +  // Simplify to target-generic intrinsic. +  if (Action.IID) { +    SmallVector<Value *, 4> Args(II->arg_operands()); +    // All the target-generic intrinsics currently of interest to us have one +    // type argument, equal to that of the nvvm intrinsic's argument. +    Type *Tys[] = {II->getArgOperand(0)->getType()}; +    return CallInst::Create( +        Intrinsic::getDeclaration(II->getModule(), *Action.IID, Tys), Args); +  } + +  // Simplify to target-generic binary op. +  if (Action.BinaryOp) +    return BinaryOperator::Create(*Action.BinaryOp, II->getArgOperand(0), +                                  II->getArgOperand(1), II->getName()); + +  // Simplify to target-generic cast op. +  if (Action.CastOp) +    return CastInst::Create(*Action.CastOp, II->getArgOperand(0), II->getType(), +                            II->getName()); + +  // All that's left are the special cases. +  if (!Action.Special) +    return nullptr; + +  switch (*Action.Special) { +  case SPC_Reciprocal: +    // Simplify reciprocal. +    return BinaryOperator::Create( +        Instruction::FDiv, ConstantFP::get(II->getArgOperand(0)->getType(), 1), +        II->getArgOperand(0), II->getName()); +  } +  llvm_unreachable("All SpecialCase enumerators should be handled in switch."); +} +  Instruction *InstCombiner::visitVAStartInst(VAStartInst &I) {    removeTriviallyEmptyRange(I, Intrinsic::vastart, Intrinsic::vaend, *this);    return nullptr; @@ -1462,6 +1891,18 @@ Instruction *InstCombiner::visitCallInst(CallInst &CI) {      if (Changed) return II;    } +  if (auto *AMI = dyn_cast<ElementAtomicMemCpyInst>(II)) { +    if (Constant *C = dyn_cast<Constant>(AMI->getNumElements())) +      if (C->isNullValue()) +        return eraseInstFromFunction(*AMI); + +    if (Instruction *I = SimplifyElementAtomicMemCpy(AMI)) +      return I; +  } + +  if (Instruction *I = SimplifyNVVMIntrinsic(II, *this)) +    return I; +    auto SimplifyDemandedVectorEltsLow = [this](Value *Op, unsigned Width,                                                unsigned DemandedWidth) {      APInt UndefElts(Width, 0); @@ -1581,8 +2022,21 @@ Instruction *InstCombiner::visitCallInst(CallInst &CI) {        return replaceInstUsesWith(*II, V);      break;    } -  case Intrinsic::fma:    case Intrinsic::fmuladd: { +    // Canonicalize fast fmuladd to the separate fmul + fadd. 
+    if (II->hasUnsafeAlgebra()) { +      BuilderTy::FastMathFlagGuard Guard(*Builder); +      Builder->setFastMathFlags(II->getFastMathFlags()); +      Value *Mul = Builder->CreateFMul(II->getArgOperand(0), +                                       II->getArgOperand(1)); +      Value *Add = Builder->CreateFAdd(Mul, II->getArgOperand(2)); +      Add->takeName(II); +      return replaceInstUsesWith(*II, Add); +    } + +    LLVM_FALLTHROUGH; +  } +  case Intrinsic::fma: {      Value *Src0 = II->getArgOperand(0);      Value *Src1 = II->getArgOperand(1); @@ -1631,6 +2085,26 @@ Instruction *InstCombiner::visitCallInst(CallInst &CI) {        return SelectInst::Create(Cond, Call0, Call1);      } +    LLVM_FALLTHROUGH; +  } +  case Intrinsic::ceil: +  case Intrinsic::floor: +  case Intrinsic::round: +  case Intrinsic::nearbyint: +  case Intrinsic::rint: +  case Intrinsic::trunc: { +    Value *ExtSrc; +    if (match(II->getArgOperand(0), m_FPExt(m_Value(ExtSrc))) && +        II->getArgOperand(0)->hasOneUse()) { +      // fabs (fpext x) -> fpext (fabs x) +      Value *F = Intrinsic::getDeclaration(II->getModule(), II->getIntrinsicID(), +                                           { ExtSrc->getType() }); +      CallInst *NewFabs = Builder->CreateCall(F, ExtSrc); +      NewFabs->copyFastMathFlags(II); +      NewFabs->takeName(II); +      return new FPExtInst(NewFabs, II->getType()); +    } +      break;    }    case Intrinsic::cos: @@ -1863,6 +2337,37 @@ Instruction *InstCombiner::visitCallInst(CallInst &CI) {        return II;      break;    } +  case Intrinsic::x86_avx512_mask_cmp_pd_128: +  case Intrinsic::x86_avx512_mask_cmp_pd_256: +  case Intrinsic::x86_avx512_mask_cmp_pd_512: +  case Intrinsic::x86_avx512_mask_cmp_ps_128: +  case Intrinsic::x86_avx512_mask_cmp_ps_256: +  case Intrinsic::x86_avx512_mask_cmp_ps_512: { +    // Folding cmp(sub(a,b),0) -> cmp(a,b) and cmp(0,sub(a,b)) -> cmp(b,a) +    Value *Arg0 = II->getArgOperand(0); +    Value *Arg1 = II->getArgOperand(1); +    bool Arg0IsZero = match(Arg0, m_Zero()); +    if (Arg0IsZero) +      std::swap(Arg0, Arg1); +    Value *A, *B; +    // This fold requires only the NINF(not +/- inf) since inf minus +    // inf is nan. +    // NSZ(No Signed Zeros) is not needed because zeros of any sign are +    // equal for both compares. +    // NNAN is not needed because nans compare the same for both compares. +    // The compare intrinsic uses the above assumptions and therefore +    // doesn't require additional flags. 
+    if ((match(Arg0, m_OneUse(m_FSub(m_Value(A), m_Value(B)))) && +         match(Arg1, m_Zero()) && +         cast<Instruction>(Arg0)->getFastMathFlags().noInfs())) { +      if (Arg0IsZero) +        std::swap(A, B); +      II->setArgOperand(0, A); +      II->setArgOperand(1, B); +      return II; +    } +    break; +  }    case Intrinsic::x86_avx512_mask_add_ps_512:    case Intrinsic::x86_avx512_mask_div_ps_512: @@ -2130,6 +2635,9 @@ Instruction *InstCombiner::visitCallInst(CallInst &CI) {    case Intrinsic::x86_avx2_pmulu_dq:    case Intrinsic::x86_avx512_pmul_dq_512:    case Intrinsic::x86_avx512_pmulu_dq_512: { +    if (Value *V = simplifyX86muldq(*II, *Builder)) +      return replaceInstUsesWith(*II, V); +      unsigned VWidth = II->getType()->getVectorNumElements();      APInt UndefElts(VWidth, 0);      APInt DemandedElts = APInt::getAllOnesValue(VWidth); @@ -2141,6 +2649,64 @@ Instruction *InstCombiner::visitCallInst(CallInst &CI) {      break;    } +  case Intrinsic::x86_sse2_packssdw_128: +  case Intrinsic::x86_sse2_packsswb_128: +  case Intrinsic::x86_avx2_packssdw: +  case Intrinsic::x86_avx2_packsswb: +  case Intrinsic::x86_avx512_packssdw_512: +  case Intrinsic::x86_avx512_packsswb_512: +    if (Value *V = simplifyX86pack(*II, *this, *Builder, true)) +      return replaceInstUsesWith(*II, V); +    break; + +  case Intrinsic::x86_sse2_packuswb_128: +  case Intrinsic::x86_sse41_packusdw: +  case Intrinsic::x86_avx2_packusdw: +  case Intrinsic::x86_avx2_packuswb: +  case Intrinsic::x86_avx512_packusdw_512: +  case Intrinsic::x86_avx512_packuswb_512: +    if (Value *V = simplifyX86pack(*II, *this, *Builder, false)) +      return replaceInstUsesWith(*II, V); +    break; + +  case Intrinsic::x86_pclmulqdq: { +    if (auto *C = dyn_cast<ConstantInt>(II->getArgOperand(2))) { +      unsigned Imm = C->getZExtValue(); + +      bool MadeChange = false; +      Value *Arg0 = II->getArgOperand(0); +      Value *Arg1 = II->getArgOperand(1); +      unsigned VWidth = Arg0->getType()->getVectorNumElements(); +      APInt DemandedElts(VWidth, 0); + +      APInt UndefElts1(VWidth, 0); +      DemandedElts = (Imm & 0x01) ? 2 : 1; +      if (Value *V = SimplifyDemandedVectorElts(Arg0, DemandedElts, +                                                UndefElts1)) { +        II->setArgOperand(0, V); +        MadeChange = true; +      } + +      APInt UndefElts2(VWidth, 0); +      DemandedElts = (Imm & 0x10) ? 2 : 1; +      if (Value *V = SimplifyDemandedVectorElts(Arg1, DemandedElts, +                                                UndefElts2)) { +        II->setArgOperand(1, V); +        MadeChange = true; +      } + +      // If both input elements are undef, the result is undef. +      if (UndefElts1[(Imm & 0x01) ? 1 : 0] || +          UndefElts2[(Imm & 0x10) ? 1 : 0]) +        return replaceInstUsesWith(*II, +                                   ConstantAggregateZero::get(II->getType())); + +      if (MadeChange) +        return II; +    } +    break; +  } +    case Intrinsic::x86_sse41_insertps:      if (Value *V = simplifyX86insertps(*II, *Builder))        return replaceInstUsesWith(*II, V); @@ -2531,9 +3097,14 @@ Instruction *InstCombiner::visitCallInst(CallInst &CI) {      break;    } -    case Intrinsic::amdgcn_rcp: { -    if (const ConstantFP *C = dyn_cast<ConstantFP>(II->getArgOperand(0))) { +    Value *Src = II->getArgOperand(0); + +    // TODO: Move to ConstantFolding/InstSimplify? 
+    if (isa<UndefValue>(Src)) +      return replaceInstUsesWith(CI, Src); + +    if (const ConstantFP *C = dyn_cast<ConstantFP>(Src)) {        const APFloat &ArgVal = C->getValueAPF();        APFloat Val(ArgVal.getSemantics(), 1.0);        APFloat::opStatus Status = Val.divide(ArgVal, @@ -2546,6 +3117,14 @@ Instruction *InstCombiner::visitCallInst(CallInst &CI) {      break;    } +  case Intrinsic::amdgcn_rsq: { +    Value *Src = II->getArgOperand(0); + +    // TODO: Move to ConstantFolding/InstSimplify? +    if (isa<UndefValue>(Src)) +      return replaceInstUsesWith(CI, Src); +    break; +  }    case Intrinsic::amdgcn_frexp_mant:    case Intrinsic::amdgcn_frexp_exp: {      Value *Src = II->getArgOperand(0); @@ -2650,6 +3229,274 @@ Instruction *InstCombiner::visitCallInst(CallInst &CI) {      return replaceInstUsesWith(*II, ConstantInt::get(II->getType(), Result));    } +  case Intrinsic::amdgcn_cvt_pkrtz: { +    Value *Src0 = II->getArgOperand(0); +    Value *Src1 = II->getArgOperand(1); +    if (const ConstantFP *C0 = dyn_cast<ConstantFP>(Src0)) { +      if (const ConstantFP *C1 = dyn_cast<ConstantFP>(Src1)) { +        const fltSemantics &HalfSem +          = II->getType()->getScalarType()->getFltSemantics(); +        bool LosesInfo; +        APFloat Val0 = C0->getValueAPF(); +        APFloat Val1 = C1->getValueAPF(); +        Val0.convert(HalfSem, APFloat::rmTowardZero, &LosesInfo); +        Val1.convert(HalfSem, APFloat::rmTowardZero, &LosesInfo); + +        Constant *Folded = ConstantVector::get({ +            ConstantFP::get(II->getContext(), Val0), +            ConstantFP::get(II->getContext(), Val1) }); +        return replaceInstUsesWith(*II, Folded); +      } +    } + +    if (isa<UndefValue>(Src0) && isa<UndefValue>(Src1)) +      return replaceInstUsesWith(*II, UndefValue::get(II->getType())); + +    break; +  } +  case Intrinsic::amdgcn_ubfe: +  case Intrinsic::amdgcn_sbfe: { +    // Decompose simple cases into standard shifts. +    Value *Src = II->getArgOperand(0); +    if (isa<UndefValue>(Src)) +      return replaceInstUsesWith(*II, Src); + +    unsigned Width; +    Type *Ty = II->getType(); +    unsigned IntSize = Ty->getIntegerBitWidth(); + +    ConstantInt *CWidth = dyn_cast<ConstantInt>(II->getArgOperand(2)); +    if (CWidth) { +      Width = CWidth->getZExtValue(); +      if ((Width & (IntSize - 1)) == 0) +        return replaceInstUsesWith(*II, ConstantInt::getNullValue(Ty)); + +      if (Width >= IntSize) { +        // Hardware ignores high bits, so remove those. +        II->setArgOperand(2, ConstantInt::get(CWidth->getType(), +                                              Width & (IntSize - 1))); +        return II; +      } +    } + +    unsigned Offset; +    ConstantInt *COffset = dyn_cast<ConstantInt>(II->getArgOperand(1)); +    if (COffset) { +      Offset = COffset->getZExtValue(); +      if (Offset >= IntSize) { +        II->setArgOperand(1, ConstantInt::get(COffset->getType(), +                                              Offset & (IntSize - 1))); +        return II; +      } +    } + +    bool Signed = II->getIntrinsicID() == Intrinsic::amdgcn_sbfe; + +    // TODO: Also emit sub if only width is constant. +    if (!CWidth && COffset && Offset == 0) { +      Constant *KSize = ConstantInt::get(COffset->getType(), IntSize); +      Value *ShiftVal = Builder->CreateSub(KSize, II->getArgOperand(2)); +      ShiftVal = Builder->CreateZExt(ShiftVal, II->getType()); + +      Value *Shl = Builder->CreateShl(Src, ShiftVal); +      Value *RightShift = Signed ? 
+        Builder->CreateAShr(Shl, ShiftVal) : +        Builder->CreateLShr(Shl, ShiftVal); +      RightShift->takeName(II); +      return replaceInstUsesWith(*II, RightShift); +    } + +    if (!CWidth || !COffset) +      break; + +    // TODO: This allows folding to undef when the hardware has specific +    // behavior? +    if (Offset + Width < IntSize) { +      Value *Shl = Builder->CreateShl(Src, IntSize  - Offset - Width); +      Value *RightShift = Signed ? +        Builder->CreateAShr(Shl, IntSize - Width) : +        Builder->CreateLShr(Shl, IntSize - Width); +      RightShift->takeName(II); +      return replaceInstUsesWith(*II, RightShift); +    } + +    Value *RightShift = Signed ? +      Builder->CreateAShr(Src, Offset) : +      Builder->CreateLShr(Src, Offset); + +    RightShift->takeName(II); +    return replaceInstUsesWith(*II, RightShift); +  } +  case Intrinsic::amdgcn_exp: +  case Intrinsic::amdgcn_exp_compr: { +    ConstantInt *En = dyn_cast<ConstantInt>(II->getArgOperand(1)); +    if (!En) // Illegal. +      break; + +    unsigned EnBits = En->getZExtValue(); +    if (EnBits == 0xf) +      break; // All inputs enabled. + +    bool IsCompr = II->getIntrinsicID() == Intrinsic::amdgcn_exp_compr; +    bool Changed = false; +    for (int I = 0; I < (IsCompr ? 2 : 4); ++I) { +      if ((!IsCompr && (EnBits & (1 << I)) == 0) || +          (IsCompr && ((EnBits & (0x3 << (2 * I))) == 0))) { +        Value *Src = II->getArgOperand(I + 2); +        if (!isa<UndefValue>(Src)) { +          II->setArgOperand(I + 2, UndefValue::get(Src->getType())); +          Changed = true; +        } +      } +    } + +    if (Changed) +      return II; + +    break; + +  } +  case Intrinsic::amdgcn_fmed3: { +    // Note this does not preserve proper sNaN behavior if IEEE-mode is enabled +    // for the shader. + +    Value *Src0 = II->getArgOperand(0); +    Value *Src1 = II->getArgOperand(1); +    Value *Src2 = II->getArgOperand(2); + +    bool Swap = false; +    // Canonicalize constants to RHS operands. +    // +    // fmed3(c0, x, c1) -> fmed3(x, c0, c1) +    if (isa<Constant>(Src0) && !isa<Constant>(Src1)) { +      std::swap(Src0, Src1); +      Swap = true; +    } + +    if (isa<Constant>(Src1) && !isa<Constant>(Src2)) { +      std::swap(Src1, Src2); +      Swap = true; +    } + +    if (isa<Constant>(Src0) && !isa<Constant>(Src1)) { +      std::swap(Src0, Src1); +      Swap = true; +    } + +    if (Swap) { +      II->setArgOperand(0, Src0); +      II->setArgOperand(1, Src1); +      II->setArgOperand(2, Src2); +      return II; +    } + +    if (match(Src2, m_NaN()) || isa<UndefValue>(Src2)) { +      CallInst *NewCall = Builder->CreateMinNum(Src0, Src1); +      NewCall->copyFastMathFlags(II); +      NewCall->takeName(II); +      return replaceInstUsesWith(*II, NewCall); +    } + +    if (const ConstantFP *C0 = dyn_cast<ConstantFP>(Src0)) { +      if (const ConstantFP *C1 = dyn_cast<ConstantFP>(Src1)) { +        if (const ConstantFP *C2 = dyn_cast<ConstantFP>(Src2)) { +          APFloat Result = fmed3AMDGCN(C0->getValueAPF(), C1->getValueAPF(), +                                       C2->getValueAPF()); +          return replaceInstUsesWith(*II, +            ConstantFP::get(Builder->getContext(), Result)); +        } +      } +    } + +    break; +  } +  case Intrinsic::amdgcn_icmp: +  case Intrinsic::amdgcn_fcmp: { +    const ConstantInt *CC = dyn_cast<ConstantInt>(II->getArgOperand(2)); +    if (!CC) +      break; + +    // Guard against invalid arguments. 
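The shift-pair decomposition used for amdgcn_ubfe/amdgcn_sbfe above is the standard bitfield extract: shift the field to the top, then shift back down logically (ubfe) or arithmetically (sbfe). A small check against a mask-based reference for the 32-bit, offset + width < 32 path (test values are arbitrary; arithmetic right shift of negative values is assumed, as on common targets):

#include <cassert>
#include <cstdint>

// Reference: extract Width bits starting at Offset, zero-extended.
static uint32_t RefUBFE(uint32_t Src, unsigned Offset, unsigned Width) {
  return (Src >> Offset) & ((1u << Width) - 1u);
}

int main() {
  const unsigned IntSize = 32;
  const uint32_t Tests[] = {0u, 1u, 0x7FFFFFFFu, 0x80000000u, 0xDEADBEEFu};
  for (uint32_t Src : Tests)
    for (unsigned Offset = 0; Offset < IntSize; ++Offset)
      for (unsigned Width = 1; Offset + Width < IntSize; ++Width) {
        uint32_t Shl = Src << (IntSize - Offset - Width);
        uint32_t UBFE = Shl >> (IntSize - Width);         // lshr for ubfe
        int32_t SBFE = int32_t(Shl) >> (IntSize - Width); // ashr for sbfe
        assert(UBFE == RefUBFE(Src, Offset, Width));
        // sbfe additionally sign-extends the field from bit Width-1.
        uint32_t SignBit = 1u << (Width - 1);
        assert(SBFE == int32_t((RefUBFE(Src, Offset, Width) ^ SignBit) - SignBit));
      }
}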
+    int64_t CCVal = CC->getZExtValue(); +    bool IsInteger = II->getIntrinsicID() == Intrinsic::amdgcn_icmp; +    if ((IsInteger && (CCVal < CmpInst::FIRST_ICMP_PREDICATE || +                       CCVal > CmpInst::LAST_ICMP_PREDICATE)) || +        (!IsInteger && (CCVal < CmpInst::FIRST_FCMP_PREDICATE || +                        CCVal > CmpInst::LAST_FCMP_PREDICATE))) +      break; + +    Value *Src0 = II->getArgOperand(0); +    Value *Src1 = II->getArgOperand(1); + +    if (auto *CSrc0 = dyn_cast<Constant>(Src0)) { +      if (auto *CSrc1 = dyn_cast<Constant>(Src1)) { +        Constant *CCmp = ConstantExpr::getCompare(CCVal, CSrc0, CSrc1); +        return replaceInstUsesWith(*II, +                                   ConstantExpr::getSExt(CCmp, II->getType())); +      } + +      // Canonicalize constants to RHS. +      CmpInst::Predicate SwapPred +        = CmpInst::getSwappedPredicate(static_cast<CmpInst::Predicate>(CCVal)); +      II->setArgOperand(0, Src1); +      II->setArgOperand(1, Src0); +      II->setArgOperand(2, ConstantInt::get(CC->getType(), +                                            static_cast<int>(SwapPred))); +      return II; +    } + +    if (CCVal != CmpInst::ICMP_EQ && CCVal != CmpInst::ICMP_NE) +      break; + +    // Canonicalize compare eq with true value to compare != 0 +    // llvm.amdgcn.icmp(zext (i1 x), 1, eq) +    //   -> llvm.amdgcn.icmp(zext (i1 x), 0, ne) +    // llvm.amdgcn.icmp(sext (i1 x), -1, eq) +    //   -> llvm.amdgcn.icmp(sext (i1 x), 0, ne) +    Value *ExtSrc; +    if (CCVal == CmpInst::ICMP_EQ && +        ((match(Src1, m_One()) && match(Src0, m_ZExt(m_Value(ExtSrc)))) || +         (match(Src1, m_AllOnes()) && match(Src0, m_SExt(m_Value(ExtSrc))))) && +        ExtSrc->getType()->isIntegerTy(1)) { +      II->setArgOperand(1, ConstantInt::getNullValue(Src1->getType())); +      II->setArgOperand(2, ConstantInt::get(CC->getType(), CmpInst::ICMP_NE)); +      return II; +    } + +    CmpInst::Predicate SrcPred; +    Value *SrcLHS; +    Value *SrcRHS; + +    // Fold compare eq/ne with 0 from a compare result as the predicate to the +    // intrinsic. The typical use is a wave vote function in the library, which +    // will be fed from a user code condition compared with 0. Fold in the +    // redundant compare. + +    // llvm.amdgcn.icmp([sz]ext ([if]cmp pred a, b), 0, ne) +    //   -> llvm.amdgcn.[if]cmp(a, b, pred) +    // +    // llvm.amdgcn.icmp([sz]ext ([if]cmp pred a, b), 0, eq) +    //   -> llvm.amdgcn.[if]cmp(a, b, inv pred) +    if (match(Src1, m_Zero()) && +        match(Src0, +              m_ZExtOrSExt(m_Cmp(SrcPred, m_Value(SrcLHS), m_Value(SrcRHS))))) { +      if (CCVal == CmpInst::ICMP_EQ) +        SrcPred = CmpInst::getInversePredicate(SrcPred); + +      Intrinsic::ID NewIID = CmpInst::isFPPredicate(SrcPred) ? +        Intrinsic::amdgcn_fcmp : Intrinsic::amdgcn_icmp; + +      Value *NewF = Intrinsic::getDeclaration(II->getModule(), NewIID, +                                              SrcLHS->getType()); +      Value *Args[] = { SrcLHS, SrcRHS, +                        ConstantInt::get(CC->getType(), SrcPred) }; +      CallInst *NewCall = Builder->CreateCall(NewF, Args); +      NewCall->takeName(II); +      return replaceInstUsesWith(*II, NewCall); +    } + +    break; +  }    case Intrinsic::stackrestore: {      // If the save is right next to the restore, remove the restore.  This can      // happen when variable allocas are DCE'd. 
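The "compare eq with true value" canonicalization above rests on the fact that a zero- or sign-extended i1 can only take two values, so comparing against the extended true pattern (1 or -1) is equivalent to comparing against zero with the inverted sense. A small C++ check of that, with B standing in for the i1 value; the variable names are illustrative only:

#include <cassert>
#include <cstdint>

int main() {
  for (bool B : {false, true}) {
    int32_t Z = static_cast<int32_t>(B); // zext i1 -> i32: 0 or 1
    int32_t S = B ? -1 : 0;              // sext i1 -> i32: 0 or -1
    assert((Z == 1) == (Z != 0));
    assert((S == -1) == (S != 0));
  }
  return 0;
}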
@@ -2790,7 +3637,7 @@ Instruction *InstCombiner::visitCallInst(CallInst &CI) {        // isKnownNonNull -> nonnull attribute        if (isKnownNonNullAt(DerivedPtr, II, &DT)) -        II->addAttribute(AttributeSet::ReturnIndex, Attribute::NonNull); +        II->addAttribute(AttributeList::ReturnIndex, Attribute::NonNull);      }      // TODO: bitcast(relocate(p)) -> relocate(bitcast(p)) @@ -2799,11 +3646,38 @@ Instruction *InstCombiner::visitCallInst(CallInst &CI) {      // TODO: relocate((gep p, C, C2, ...)) -> gep(relocate(p), C, C2, ...)      break;    } -  } +  case Intrinsic::experimental_guard: { +    // Is this guard followed by another guard? +    Instruction *NextInst = II->getNextNode(); +    Value *NextCond = nullptr; +    if (match(NextInst, +              m_Intrinsic<Intrinsic::experimental_guard>(m_Value(NextCond)))) { +      Value *CurrCond = II->getArgOperand(0); + +      // Remove a guard that it is immediately preceded by an identical guard. +      if (CurrCond == NextCond) +        return eraseInstFromFunction(*NextInst); + +      // Otherwise canonicalize guard(a); guard(b) -> guard(a & b). +      II->setArgOperand(0, Builder->CreateAnd(CurrCond, NextCond)); +      return eraseInstFromFunction(*NextInst); +    } +    break; +  } +  }    return visitCallSite(II);  } +// Fence instruction simplification +Instruction *InstCombiner::visitFenceInst(FenceInst &FI) { +  // Remove identical consecutive fences. +  if (auto *NFI = dyn_cast<FenceInst>(FI.getNextNode())) +    if (FI.isIdenticalTo(NFI)) +      return eraseInstFromFunction(FI); +  return nullptr; +} +  // InvokeInst simplification  //  Instruction *InstCombiner::visitInvokeInst(InvokeInst &II) { @@ -2950,7 +3824,7 @@ Instruction *InstCombiner::visitCallSite(CallSite CS) {    for (Value *V : CS.args()) {      if (V->getType()->isPointerTy() && -        !CS.paramHasAttr(ArgNo + 1, Attribute::NonNull) && +        !CS.paramHasAttr(ArgNo, Attribute::NonNull) &&          isKnownNonNullAt(V, CS.getInstruction(), &DT))        Indices.push_back(ArgNo + 1);      ArgNo++; @@ -2959,7 +3833,7 @@ Instruction *InstCombiner::visitCallSite(CallSite CS) {    assert(ArgNo == CS.arg_size() && "sanity check");    if (!Indices.empty()) { -    AttributeSet AS = CS.getAttributes(); +    AttributeList AS = CS.getAttributes();      LLVMContext &Ctx = CS.getInstruction()->getContext();      AS = AS.addAttribute(Ctx, Indices,                           Attribute::get(Ctx, Attribute::NonNull)); @@ -3081,7 +3955,7 @@ bool InstCombiner::transformConstExprCastCall(CallSite CS) {      return false;    Instruction *Caller = CS.getInstruction(); -  const AttributeSet &CallerPAL = CS.getAttributes(); +  const AttributeList &CallerPAL = CS.getAttributes();    // Okay, this is a cast from a function to a different type.  Unless doing so    // would cause a type conversion of one of our arguments, change this call to @@ -3108,7 +3982,7 @@ bool InstCombiner::transformConstExprCastCall(CallSite CS) {      }      if (!CallerPAL.isEmpty() && !Caller->use_empty()) { -      AttrBuilder RAttrs(CallerPAL, AttributeSet::ReturnIndex); +      AttrBuilder RAttrs(CallerPAL, AttributeList::ReturnIndex);        if (RAttrs.overlaps(AttributeFuncs::typeIncompatible(NewRetTy)))          return false;   // Attribute not compatible with transformed value.      } @@ -3149,8 +4023,8 @@ bool InstCombiner::transformConstExprCastCall(CallSite CS) {      if (!CastInst::isBitOrNoopPointerCastable(ActTy, ParamTy, DL))        return false;   // Cannot transform this parameter value. 
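The guard(a); guard(b) -> guard(a & b) canonicalization above only touches adjacent guards, so both conditions are already computed before the first guard and merging them with an 'and' does not move any evaluation; the pair deoptimizes exactly when the merged guard does. A trivial C++ model of that equivalence (pairDeopts/mergedDeopts are made-up names for the sketch):

#include <cassert>

// Whether a pair of adjacent guards guard(A); guard(B) triggers a deopt.
static bool pairDeopts(bool A, bool B) { return !A || !B; }
// Whether the merged guard(A & B) triggers a deopt.
static bool mergedDeopts(bool A, bool B) { return !(A && B); }

int main() {
  for (bool A : {false, true})
    for (bool B : {false, true})
      assert(pairDeopts(A, B) == mergedDeopts(A, B));
  return 0;
}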
-    if (AttrBuilder(CallerPAL.getParamAttributes(i + 1), i + 1). -          overlaps(AttributeFuncs::typeIncompatible(ParamTy))) +    if (AttrBuilder(CallerPAL.getParamAttributes(i)) +            .overlaps(AttributeFuncs::typeIncompatible(ParamTy)))        return false;   // Attribute not compatible with transformed value.      if (CS.isInAllocaArgument(i)) @@ -3158,9 +4032,7 @@ bool InstCombiner::transformConstExprCastCall(CallSite CS) {      // If the parameter is passed as a byval argument, then we have to have a      // sized type and the sized type has to have the same size as the old type. -    if (ParamTy != ActTy && -        CallerPAL.getParamAttributes(i + 1).hasAttribute(i + 1, -                                                         Attribute::ByVal)) { +    if (ParamTy != ActTy && CallerPAL.hasParamAttribute(i, Attribute::ByVal)) {        PointerType *ParamPTy = dyn_cast<PointerType>(ParamTy);        if (!ParamPTy || !ParamPTy->getElementType()->isSized())          return false; @@ -3205,7 +4077,7 @@ bool InstCombiner::transformConstExprCastCall(CallSite CS) {          break;        // Check if it has an attribute that's incompatible with varargs. -      AttributeSet PAttrs = CallerPAL.getSlotAttributes(i - 1); +      AttributeList PAttrs = CallerPAL.getSlotAttributes(i - 1);        if (PAttrs.hasAttribute(Index, Attribute::StructRet))          return false;      } @@ -3213,44 +4085,37 @@ bool InstCombiner::transformConstExprCastCall(CallSite CS) {    // Okay, we decided that this is a safe thing to do: go ahead and start    // inserting cast instructions as necessary. -  std::vector<Value*> Args; +  SmallVector<Value *, 8> Args; +  SmallVector<AttributeSet, 8> ArgAttrs;    Args.reserve(NumActualArgs); -  SmallVector<AttributeSet, 8> attrVec; -  attrVec.reserve(NumCommonArgs); +  ArgAttrs.reserve(NumActualArgs);    // Get any return attributes. -  AttrBuilder RAttrs(CallerPAL, AttributeSet::ReturnIndex); +  AttrBuilder RAttrs(CallerPAL, AttributeList::ReturnIndex);    // If the return value is not being used, the type may not be compatible    // with the existing attributes.  Wipe out any problematic attributes.    RAttrs.remove(AttributeFuncs::typeIncompatible(NewRetTy)); -  // Add the new return attributes. -  if (RAttrs.hasAttributes()) -    attrVec.push_back(AttributeSet::get(Caller->getContext(), -                                        AttributeSet::ReturnIndex, RAttrs)); -    AI = CS.arg_begin();    for (unsigned i = 0; i != NumCommonArgs; ++i, ++AI) {      Type *ParamTy = FT->getParamType(i); -    if ((*AI)->getType() == ParamTy) { -      Args.push_back(*AI); -    } else { -      Args.push_back(Builder->CreateBitOrPointerCast(*AI, ParamTy)); -    } +    Value *NewArg = *AI; +    if ((*AI)->getType() != ParamTy) +      NewArg = Builder->CreateBitOrPointerCast(*AI, ParamTy); +    Args.push_back(NewArg);      // Add any parameter attributes. -    AttrBuilder PAttrs(CallerPAL.getParamAttributes(i + 1), i + 1); -    if (PAttrs.hasAttributes()) -      attrVec.push_back(AttributeSet::get(Caller->getContext(), i + 1, -                                          PAttrs)); +    ArgAttrs.push_back(CallerPAL.getParamAttributes(i));    }    // If the function takes more arguments than the call was taking, add them    // now. 
-  for (unsigned i = NumCommonArgs; i != FT->getNumParams(); ++i) +  for (unsigned i = NumCommonArgs; i != FT->getNumParams(); ++i) {      Args.push_back(Constant::getNullValue(FT->getParamType(i))); +    ArgAttrs.push_back(AttributeSet()); +  }    // If we are removing arguments to the function, emit an obnoxious warning.    if (FT->getNumParams() < NumActualArgs) { @@ -3259,54 +4124,56 @@ bool InstCombiner::transformConstExprCastCall(CallSite CS) {        // Add all of the arguments in their promoted form to the arg list.        for (unsigned i = FT->getNumParams(); i != NumActualArgs; ++i, ++AI) {          Type *PTy = getPromotedType((*AI)->getType()); +        Value *NewArg = *AI;          if (PTy != (*AI)->getType()) {            // Must promote to pass through va_arg area!            Instruction::CastOps opcode =              CastInst::getCastOpcode(*AI, false, PTy, false); -          Args.push_back(Builder->CreateCast(opcode, *AI, PTy)); -        } else { -          Args.push_back(*AI); +          NewArg = Builder->CreateCast(opcode, *AI, PTy);          } +        Args.push_back(NewArg);          // Add any parameter attributes. -        AttrBuilder PAttrs(CallerPAL.getParamAttributes(i + 1), i + 1); -        if (PAttrs.hasAttributes()) -          attrVec.push_back(AttributeSet::get(FT->getContext(), i + 1, -                                              PAttrs)); +        ArgAttrs.push_back(CallerPAL.getParamAttributes(i));        }      }    }    AttributeSet FnAttrs = CallerPAL.getFnAttributes(); -  if (CallerPAL.hasAttributes(AttributeSet::FunctionIndex)) -    attrVec.push_back(AttributeSet::get(Callee->getContext(), FnAttrs));    if (NewRetTy->isVoidTy())      Caller->setName("");   // Void type should not have a name. -  const AttributeSet &NewCallerPAL = AttributeSet::get(Callee->getContext(), -                                                       attrVec); +  assert((ArgAttrs.size() == FT->getNumParams() || FT->isVarArg()) && +         "missing argument attributes"); +  LLVMContext &Ctx = Callee->getContext(); +  AttributeList NewCallerPAL = AttributeList::get( +      Ctx, FnAttrs, AttributeSet::get(Ctx, RAttrs), ArgAttrs);    SmallVector<OperandBundleDef, 1> OpBundles;    CS.getOperandBundlesAsDefs(OpBundles); -  Instruction *NC; +  CallSite NewCS;    if (InvokeInst *II = dyn_cast<InvokeInst>(Caller)) { -    NC = Builder->CreateInvoke(Callee, II->getNormalDest(), II->getUnwindDest(), -                               Args, OpBundles); -    NC->takeName(II); -    cast<InvokeInst>(NC)->setCallingConv(II->getCallingConv()); -    cast<InvokeInst>(NC)->setAttributes(NewCallerPAL); +    NewCS = Builder->CreateInvoke(Callee, II->getNormalDest(), +                                  II->getUnwindDest(), Args, OpBundles);    } else { -    CallInst *CI = cast<CallInst>(Caller); -    NC = Builder->CreateCall(Callee, Args, OpBundles); -    NC->takeName(CI); -    cast<CallInst>(NC)->setTailCallKind(CI->getTailCallKind()); -    cast<CallInst>(NC)->setCallingConv(CI->getCallingConv()); -    cast<CallInst>(NC)->setAttributes(NewCallerPAL); +    NewCS = Builder->CreateCall(Callee, Args, OpBundles); +    cast<CallInst>(NewCS.getInstruction()) +        ->setTailCallKind(cast<CallInst>(Caller)->getTailCallKind());    } +  NewCS->takeName(Caller); +  NewCS.setCallingConv(CS.getCallingConv()); +  NewCS.setAttributes(NewCallerPAL); + +  // Preserve the weight metadata for the new call instruction. The metadata +  // is used by SamplePGO to check callsite's hotness. 
+  uint64_t W; +  if (Caller->extractProfTotalWeight(W)) +    NewCS->setProfWeight(W);    // Insert a cast of the return type as necessary. +  Instruction *NC = NewCS.getInstruction();    Value *NV = NC;    if (OldRetTy != NV->getType() && !Caller->use_empty()) {      if (!NV->getType()->isVoidTy()) { @@ -3351,7 +4218,7 @@ InstCombiner::transformCallThroughTrampoline(CallSite CS,    Value *Callee = CS.getCalledValue();    PointerType *PTy = cast<PointerType>(Callee->getType());    FunctionType *FTy = cast<FunctionType>(PTy->getElementType()); -  const AttributeSet &Attrs = CS.getAttributes(); +  AttributeList Attrs = CS.getAttributes();    // If the call already has the 'nest' attribute somewhere then give up -    // otherwise 'nest' would occur twice after splicing in the chain. @@ -3364,50 +4231,46 @@ InstCombiner::transformCallThroughTrampoline(CallSite CS,    Function *NestF =cast<Function>(Tramp->getArgOperand(1)->stripPointerCasts());    FunctionType *NestFTy = cast<FunctionType>(NestF->getValueType()); -  const AttributeSet &NestAttrs = NestF->getAttributes(); +  AttributeList NestAttrs = NestF->getAttributes();    if (!NestAttrs.isEmpty()) { -    unsigned NestIdx = 1; +    unsigned NestArgNo = 0;      Type *NestTy = nullptr;      AttributeSet NestAttr;      // Look for a parameter marked with the 'nest' attribute.      for (FunctionType::param_iterator I = NestFTy->param_begin(), -         E = NestFTy->param_end(); I != E; ++NestIdx, ++I) -      if (NestAttrs.hasAttribute(NestIdx, Attribute::Nest)) { +                                      E = NestFTy->param_end(); +         I != E; ++NestArgNo, ++I) { +      AttributeSet AS = NestAttrs.getParamAttributes(NestArgNo); +      if (AS.hasAttribute(Attribute::Nest)) {          // Record the parameter type and any other attributes.          NestTy = *I; -        NestAttr = NestAttrs.getParamAttributes(NestIdx); +        NestAttr = AS;          break;        } +    }      if (NestTy) {        Instruction *Caller = CS.getInstruction();        std::vector<Value*> NewArgs; +      std::vector<AttributeSet> NewArgAttrs;        NewArgs.reserve(CS.arg_size() + 1); - -      SmallVector<AttributeSet, 8> NewAttrs; -      NewAttrs.reserve(Attrs.getNumSlots() + 1); +      NewArgAttrs.reserve(CS.arg_size());        // Insert the nest argument into the call argument list, which may        // mean appending it.  Likewise for attributes. -      // Add any result attributes. -      if (Attrs.hasAttributes(AttributeSet::ReturnIndex)) -        NewAttrs.push_back(AttributeSet::get(Caller->getContext(), -                                             Attrs.getRetAttributes())); -        { -        unsigned Idx = 1; +        unsigned ArgNo = 0;          CallSite::arg_iterator I = CS.arg_begin(), E = CS.arg_end();          do { -          if (Idx == NestIdx) { +          if (ArgNo == NestArgNo) {              // Add the chain argument and attributes.              Value *NestVal = Tramp->getArgOperand(2);              if (NestVal->getType() != NestTy)                NestVal = Builder->CreateBitCast(NestVal, NestTy, "nest");              NewArgs.push_back(NestVal); -            NewAttrs.push_back(AttributeSet::get(Caller->getContext(), -                                                 NestAttr)); +            NewArgAttrs.push_back(NestAttr);            }            if (I == E) @@ -3415,23 +4278,13 @@ InstCombiner::transformCallThroughTrampoline(CallSite CS,            // Add the original argument and attributes.            
NewArgs.push_back(*I); -          AttributeSet Attr = Attrs.getParamAttributes(Idx); -          if (Attr.hasAttributes(Idx)) { -            AttrBuilder B(Attr, Idx); -            NewAttrs.push_back(AttributeSet::get(Caller->getContext(), -                                                 Idx + (Idx >= NestIdx), B)); -          } +          NewArgAttrs.push_back(Attrs.getParamAttributes(ArgNo)); -          ++Idx; +          ++ArgNo;            ++I;          } while (true);        } -      // Add any function attributes. -      if (Attrs.hasAttributes(AttributeSet::FunctionIndex)) -        NewAttrs.push_back(AttributeSet::get(FTy->getContext(), -                                             Attrs.getFnAttributes())); -        // The trampoline may have been bitcast to a bogus type (FTy).        // Handle this by synthesizing a new function type, equal to FTy        // with the chain parameter inserted. @@ -3442,12 +4295,12 @@ InstCombiner::transformCallThroughTrampoline(CallSite CS,        // Insert the chain's type into the list of parameter types, which may        // mean appending it.        { -        unsigned Idx = 1; +        unsigned ArgNo = 0;          FunctionType::param_iterator I = FTy->param_begin(),            E = FTy->param_end();          do { -          if (Idx == NestIdx) +          if (ArgNo == NestArgNo)              // Add the chain's type.              NewTypes.push_back(NestTy); @@ -3457,7 +4310,7 @@ InstCombiner::transformCallThroughTrampoline(CallSite CS,            // Add the original type.            NewTypes.push_back(*I); -          ++Idx; +          ++ArgNo;            ++I;          } while (true);        } @@ -3470,8 +4323,9 @@ InstCombiner::transformCallThroughTrampoline(CallSite CS,          NestF->getType() == PointerType::getUnqual(NewFTy) ?          NestF : ConstantExpr::getBitCast(NestF,                                           PointerType::getUnqual(NewFTy)); -      const AttributeSet &NewPAL = -          AttributeSet::get(FTy->getContext(), NewAttrs); +      AttributeList NewPAL = +          AttributeList::get(FTy->getContext(), Attrs.getFnAttributes(), +                             Attrs.getRetAttributes(), NewArgAttrs);        SmallVector<OperandBundleDef, 1> OpBundles;        CS.getOperandBundlesAsDefs(OpBundles); diff --git a/lib/Transforms/InstCombine/InstCombineCasts.cpp b/lib/Transforms/InstCombine/InstCombineCasts.cpp index e74b590e2b7c..25683132c786 100644 --- a/lib/Transforms/InstCombine/InstCombineCasts.cpp +++ b/lib/Transforms/InstCombine/InstCombineCasts.cpp @@ -274,12 +274,12 @@ Instruction *InstCombiner::commonCastTransforms(CastInst &CI) {        return NV;    // If we are casting a PHI, then fold the cast into the PHI. -  if (isa<PHINode>(Src)) { +  if (auto *PN = dyn_cast<PHINode>(Src)) {      // Don't do this if it would create a PHI node with an illegal type from a      // legal type.      
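The attribute rebuilding above, in both transformConstExprCastCall and the trampoline rewrite, follows the same recipe: collect one AttributeSet per argument (indexed from 0 rather than through the old 1-based slots) and hand them to AttributeList::get together with separate function and return sets. A minimal compile-only sketch of that shape, assuming the LLVM headers of this era; buildExampleAttrs, NumArgs, and the particular attribute kinds chosen are illustrative, while the AttributeList::get and AttributeSet::get overloads are the ones the patch itself calls:

#include "llvm/ADT/SmallVector.h"
#include "llvm/IR/Attributes.h"
#include "llvm/IR/LLVMContext.h"
using namespace llvm;

// Build an AttributeList for a call with NumArgs arguments (NumArgs >= 1):
// one AttributeSet for the function, one for the return value, and one per
// argument, in argument order.
static AttributeList buildExampleAttrs(LLVMContext &Ctx, unsigned NumArgs) {
  AttrBuilder FnB, RetB, Arg0B;
  FnB.addAttribute(Attribute::NoUnwind);    // function-level attribute
  RetB.addAttribute(Attribute::NonNull);    // return-value attribute
  Arg0B.addAttribute(Attribute::NoCapture); // attribute for argument 0

  SmallVector<AttributeSet, 8> ArgAttrs(NumArgs, AttributeSet());
  ArgAttrs[0] = AttributeSet::get(Ctx, Arg0B);

  return AttributeList::get(Ctx, AttributeSet::get(Ctx, FnB),
                            AttributeSet::get(Ctx, RetB), ArgAttrs);
}

The same three-part get() overload is what the two call-site rewrites above use once their per-argument vectors are filled in.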
if (!Src->getType()->isIntegerTy() || !CI.getType()->isIntegerTy() || -        ShouldChangeType(CI.getType(), Src->getType())) -      if (Instruction *NV = FoldOpIntoPhi(CI)) +        shouldChangeType(CI.getType(), Src->getType())) +      if (Instruction *NV = foldOpIntoPhi(CI, PN))          return NV;    } @@ -447,7 +447,7 @@ static Instruction *foldVecTruncToExtElt(TruncInst &Trunc, InstCombiner &IC,  Instruction *InstCombiner::shrinkBitwiseLogic(TruncInst &Trunc) {    Type *SrcTy = Trunc.getSrcTy();    Type *DestTy = Trunc.getType(); -  if (isa<IntegerType>(SrcTy) && !ShouldChangeType(SrcTy, DestTy)) +  if (isa<IntegerType>(SrcTy) && !shouldChangeType(SrcTy, DestTy))      return nullptr;    BinaryOperator *LogicOp; @@ -463,6 +463,56 @@ Instruction *InstCombiner::shrinkBitwiseLogic(TruncInst &Trunc) {    return BinaryOperator::Create(LogicOp->getOpcode(), NarrowOp0, NarrowC);  } +/// Try to narrow the width of a splat shuffle. This could be generalized to any +/// shuffle with a constant operand, but we limit the transform to avoid +/// creating a shuffle type that targets may not be able to lower effectively. +static Instruction *shrinkSplatShuffle(TruncInst &Trunc, +                                       InstCombiner::BuilderTy &Builder) { +  auto *Shuf = dyn_cast<ShuffleVectorInst>(Trunc.getOperand(0)); +  if (Shuf && Shuf->hasOneUse() && isa<UndefValue>(Shuf->getOperand(1)) && +      Shuf->getMask()->getSplatValue() && +      Shuf->getType() == Shuf->getOperand(0)->getType()) { +    // trunc (shuf X, Undef, SplatMask) --> shuf (trunc X), Undef, SplatMask +    Constant *NarrowUndef = UndefValue::get(Trunc.getType()); +    Value *NarrowOp = Builder.CreateTrunc(Shuf->getOperand(0), Trunc.getType()); +    return new ShuffleVectorInst(NarrowOp, NarrowUndef, Shuf->getMask()); +  } + +  return nullptr; +} + +/// Try to narrow the width of an insert element. This could be generalized for +/// any vector constant, but we limit the transform to insertion into undef to +/// avoid potential backend problems from unsupported insertion widths. This +/// could also be extended to handle the case of inserting a scalar constant +/// into a vector variable. +static Instruction *shrinkInsertElt(CastInst &Trunc, +                                    InstCombiner::BuilderTy &Builder) { +  Instruction::CastOps Opcode = Trunc.getOpcode(); +  assert((Opcode == Instruction::Trunc || Opcode == Instruction::FPTrunc) && +         "Unexpected instruction for shrinking"); + +  auto *InsElt = dyn_cast<InsertElementInst>(Trunc.getOperand(0)); +  if (!InsElt || !InsElt->hasOneUse()) +    return nullptr; + +  Type *DestTy = Trunc.getType(); +  Type *DestScalarTy = DestTy->getScalarType(); +  Value *VecOp = InsElt->getOperand(0); +  Value *ScalarOp = InsElt->getOperand(1); +  Value *Index = InsElt->getOperand(2); + +  if (isa<UndefValue>(VecOp)) { +    // trunc   (inselt undef, X, Index) --> inselt undef,   (trunc X), Index +    // fptrunc (inselt undef, X, Index) --> inselt undef, (fptrunc X), Index +    UndefValue *NarrowUndef = UndefValue::get(DestTy); +    Value *NarrowOp = Builder.CreateCast(Opcode, ScalarOp, DestScalarTy); +    return InsertElementInst::Create(NarrowUndef, NarrowOp, Index); +  } + +  return nullptr; +} +  Instruction *InstCombiner::visitTrunc(TruncInst &CI) {    if (Instruction *Result = commonCastTransforms(CI))      return Result; @@ -488,7 +538,7 @@ Instruction *InstCombiner::visitTrunc(TruncInst &CI) {    // type.   
Only do this if the dest type is a simple type, don't convert the    // expression tree to something weird like i93 unless the source is also    // strange. -  if ((DestTy->isVectorTy() || ShouldChangeType(SrcTy, DestTy)) && +  if ((DestTy->isVectorTy() || shouldChangeType(SrcTy, DestTy)) &&        canEvaluateTruncated(Src, DestTy, *this, &CI)) {      // If this cast is a truncate, evaluting in a different type always @@ -554,8 +604,14 @@ Instruction *InstCombiner::visitTrunc(TruncInst &CI) {    if (Instruction *I = shrinkBitwiseLogic(CI))      return I; +  if (Instruction *I = shrinkSplatShuffle(CI, *Builder)) +    return I; + +  if (Instruction *I = shrinkInsertElt(CI, *Builder)) +    return I; +    if (Src->hasOneUse() && isa<IntegerType>(SrcTy) && -      ShouldChangeType(SrcTy, DestTy)) { +      shouldChangeType(SrcTy, DestTy)) {      // Transform "trunc (shl X, cst)" -> "shl (trunc X), cst" so long as the      // dest type is native and cst < dest size.      if (match(Src, m_Shl(m_Value(A), m_ConstantInt(Cst))) && @@ -838,11 +894,6 @@ Instruction *InstCombiner::visitZExt(ZExtInst &CI) {    if (Instruction *Result = commonCastTransforms(CI))      return Result; -  // See if we can simplify any instructions used by the input whose sole -  // purpose is to compute bits we don't care about. -  if (SimplifyDemandedInstructionBits(CI)) -    return &CI; -    Value *Src = CI.getOperand(0);    Type *SrcTy = Src->getType(), *DestTy = CI.getType(); @@ -851,10 +902,10 @@ Instruction *InstCombiner::visitZExt(ZExtInst &CI) {    // expression tree to something weird like i93 unless the source is also    // strange.    unsigned BitsToClear; -  if ((DestTy->isVectorTy() || ShouldChangeType(SrcTy, DestTy)) && +  if ((DestTy->isVectorTy() || shouldChangeType(SrcTy, DestTy)) &&        canEvaluateZExtd(Src, DestTy, BitsToClear, *this, &CI)) { -    assert(BitsToClear < SrcTy->getScalarSizeInBits() && -           "Unreasonable BitsToClear"); +    assert(BitsToClear <= SrcTy->getScalarSizeInBits() && +           "Can't clear more bits than in SrcTy");      // Okay, we can transform this!  Insert the new expression now.      DEBUG(dbgs() << "ICE: EvaluateInDifferentType converting expression type" @@ -1124,11 +1175,6 @@ Instruction *InstCombiner::visitSExt(SExtInst &CI) {    if (Instruction *I = commonCastTransforms(CI))      return I; -  // See if we can simplify any instructions used by the input whose sole -  // purpose is to compute bits we don't care about. -  if (SimplifyDemandedInstructionBits(CI)) -    return &CI; -    Value *Src = CI.getOperand(0);    Type *SrcTy = Src->getType(), *DestTy = CI.getType(); @@ -1145,7 +1191,7 @@ Instruction *InstCombiner::visitSExt(SExtInst &CI) {    // type.   Only do this if the dest type is a simple type, don't convert the    // expression tree to something weird like i93 unless the source is also    // strange. -  if ((DestTy->isVectorTy() || ShouldChangeType(SrcTy, DestTy)) && +  if ((DestTy->isVectorTy() || shouldChangeType(SrcTy, DestTy)) &&        canEvaluateSExtd(Src, DestTy)) {      // Okay, we can transform this!  Insert the new expression now.      DEBUG(dbgs() << "ICE: EvaluateInDifferentType converting expression type" @@ -1167,18 +1213,16 @@ Instruction *InstCombiner::visitSExt(SExtInst &CI) {                                        ShAmt);    } -  // If this input is a trunc from our destination, then turn sext(trunc(x)) +  // If the input is a trunc from the destination type, then turn sext(trunc(x))    // into shifts. 
-  if (TruncInst *TI = dyn_cast<TruncInst>(Src)) -    if (TI->hasOneUse() && TI->getOperand(0)->getType() == DestTy) { -      uint32_t SrcBitSize = SrcTy->getScalarSizeInBits(); -      uint32_t DestBitSize = DestTy->getScalarSizeInBits(); - -      // We need to emit a shl + ashr to do the sign extend. -      Value *ShAmt = ConstantInt::get(DestTy, DestBitSize-SrcBitSize); -      Value *Res = Builder->CreateShl(TI->getOperand(0), ShAmt, "sext"); -      return BinaryOperator::CreateAShr(Res, ShAmt); -    } +  Value *X; +  if (match(Src, m_OneUse(m_Trunc(m_Value(X)))) && X->getType() == DestTy) { +    // sext(trunc(X)) --> ashr(shl(X, C), C) +    unsigned SrcBitSize = SrcTy->getScalarSizeInBits(); +    unsigned DestBitSize = DestTy->getScalarSizeInBits(); +    Constant *ShAmt = ConstantInt::get(DestTy, DestBitSize - SrcBitSize); +    return BinaryOperator::CreateAShr(Builder->CreateShl(X, ShAmt), ShAmt); +  }    if (ICmpInst *ICI = dyn_cast<ICmpInst>(Src))      return transformSExtICmp(ICI, CI); @@ -1225,17 +1269,15 @@ static Constant *fitsInFPType(ConstantFP *CFP, const fltSemantics &Sem) {    return nullptr;  } -/// If this is a floating-point extension instruction, look -/// through it until we get the source value. +/// Look through floating-point extensions until we get the source value.  static Value *lookThroughFPExtensions(Value *V) { -  if (Instruction *I = dyn_cast<Instruction>(V)) -    if (I->getOpcode() == Instruction::FPExt) -      return lookThroughFPExtensions(I->getOperand(0)); +  while (auto *FPExt = dyn_cast<FPExtInst>(V)) +    V = FPExt->getOperand(0);    // If this value is a constant, return the constant in the smallest FP type    // that can accurately represent it.  This allows us to turn    // (float)((double)X+2.0) into x+2.0f. -  if (ConstantFP *CFP = dyn_cast<ConstantFP>(V)) { +  if (auto *CFP = dyn_cast<ConstantFP>(V)) {      if (CFP->getType() == Type::getPPC_FP128Ty(V->getContext()))        return V;  // No constant folding of this.      // See if the value can be truncated to half and then reextended. @@ -1392,24 +1434,49 @@ Instruction *InstCombiner::visitFPTrunc(FPTruncInst &CI) {    IntrinsicInst *II = dyn_cast<IntrinsicInst>(CI.getOperand(0));    if (II) {      switch (II->getIntrinsicID()) { -      default: break; -      case Intrinsic::fabs: { -        // (fptrunc (fabs x)) -> (fabs (fptrunc x)) -        Value *InnerTrunc = Builder->CreateFPTrunc(II->getArgOperand(0), -                                                   CI.getType()); -        Type *IntrinsicType[] = { CI.getType() }; -        Function *Overload = Intrinsic::getDeclaration( -            CI.getModule(), II->getIntrinsicID(), IntrinsicType); - -        SmallVector<OperandBundleDef, 1> OpBundles; -        II->getOperandBundlesAsDefs(OpBundles); - -        Value *Args[] = { InnerTrunc }; -        return CallInst::Create(Overload, Args, OpBundles, II->getName()); +    default: break; +    case Intrinsic::fabs: +    case Intrinsic::ceil: +    case Intrinsic::floor: +    case Intrinsic::rint: +    case Intrinsic::round: +    case Intrinsic::nearbyint: +    case Intrinsic::trunc: { +      Value *Src = II->getArgOperand(0); +      if (!Src->hasOneUse()) +        break; + +      // Except for fabs, this transformation requires the input of the unary FP +      // operation to be itself an fpext from the type to which we're +      // truncating. 
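Two of the narrowing identities used above can be checked with plain integer arithmetic: sext(trunc(X)) --> ashr(shl(X, C), C) from visitSExt, and the earlier trunc (shl X, cst) --> shl (trunc X), cst from visitTrunc. A standalone C++ check on i32/i8; the names are illustrative, and the sketch assumes the usual two's-complement narrowing and arithmetic right shift for negative values:

#include <cassert>
#include <cstdint>

int main() {
  const uint32_t Samples[] = {0u, 1u, 0x7fu, 0x80u, 0xffu,
                              0x12345678u, 0xdeadbeefu, ~0u};
  for (uint32_t X : Samples) {
    // sext(trunc i32 X to i8) to i32  ==  ashr(shl(X, 24), 24)
    int32_t ViaCast  = static_cast<int8_t>(X);              // trunc + sext
    int32_t ViaShift = static_cast<int32_t>(X << 24) >> 24; // shl + ashr
    assert(ViaCast == ViaShift);

    // trunc(shl X, 3) to i8  ==  shl(trunc X to i8, 3)
    // (the low bits of a left shift depend only on the low bits of X)
    uint8_t NarrowAfter  = static_cast<uint8_t>(X << 3);
    uint8_t NarrowBefore = static_cast<uint8_t>(static_cast<uint8_t>(X) << 3);
    assert(NarrowAfter == NarrowBefore);
  }
  return 0;
}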
+      if (II->getIntrinsicID() != Intrinsic::fabs) { +        FPExtInst *FPExtSrc = dyn_cast<FPExtInst>(Src); +        if (!FPExtSrc || FPExtSrc->getOperand(0)->getType() != CI.getType()) +          break;        } + +      // Do unary FP operation on smaller type. +      // (fptrunc (fabs x)) -> (fabs (fptrunc x)) +      Value *InnerTrunc = Builder->CreateFPTrunc(Src, CI.getType()); +      Type *IntrinsicType[] = { CI.getType() }; +      Function *Overload = Intrinsic::getDeclaration( +        CI.getModule(), II->getIntrinsicID(), IntrinsicType); + +      SmallVector<OperandBundleDef, 1> OpBundles; +      II->getOperandBundlesAsDefs(OpBundles); + +      Value *Args[] = { InnerTrunc }; +      CallInst *NewCI =  CallInst::Create(Overload, Args, +                                          OpBundles, II->getName()); +      NewCI->copyFastMathFlags(II); +      return NewCI; +    }      }    } +  if (Instruction *I = shrinkInsertElt(CI, *Builder)) +    return I; +    return nullptr;  } diff --git a/lib/Transforms/InstCombine/InstCombineCompares.cpp b/lib/Transforms/InstCombine/InstCombineCompares.cpp index 428f94bb5e93..bbafa9e9f468 100644 --- a/lib/Transforms/InstCombine/InstCombineCompares.cpp +++ b/lib/Transforms/InstCombine/InstCombineCompares.cpp @@ -230,7 +230,9 @@ Instruction *InstCombiner::foldCmpLoadFromIndexedGlobal(GetElementPtrInst *GEP,      return nullptr;    uint64_t ArrayElementCount = Init->getType()->getArrayNumElements(); -  if (ArrayElementCount > 1024) return nullptr; // Don't blow up on huge arrays. +  // Don't blow up on huge arrays. +  if (ArrayElementCount > MaxArraySizeForCombine) +    return nullptr;    // There are many forms of this optimization we can handle, for now, just do    // the simple index into a single-dimensional array. @@ -1663,7 +1665,7 @@ Instruction *InstCombiner::foldICmpAndConstConst(ICmpInst &Cmp,        (Cmp.isEquality() || (!C1->isNegative() && !C2->isNegative()))) {      // TODO: Is this a good transform for vectors? Wider types may reduce      // throughput. Should this transform be limited (even for scalars) by using -    // ShouldChangeType()? +    // shouldChangeType()?      if (!Cmp.getType()->isVectorTy()) {        Type *WideType = W->getType();        unsigned WideScalarBits = WideType->getScalarSizeInBits(); @@ -1792,6 +1794,15 @@ Instruction *InstCombiner::foldICmpOrConstant(ICmpInst &Cmp, BinaryOperator *Or,                            ConstantInt::get(V->getType(), 1));    } +  // X | C == C --> X <=u C +  // X | C != C --> X  >u C +  //   iff C+1 is a power of 2 (C is a bitmask of the low bits) +  if (Cmp.isEquality() && Cmp.getOperand(1) == Or->getOperand(1) && +      (*C + 1).isPowerOf2()) { +    Pred = (Pred == CmpInst::ICMP_EQ) ? CmpInst::ICMP_ULE : CmpInst::ICMP_UGT; +    return new ICmpInst(Pred, Or->getOperand(0), Or->getOperand(1)); +  } +    if (!Cmp.isEquality() || *C != 0 || !Or->hasOneUse())      return nullptr; @@ -1914,61 +1925,89 @@ Instruction *InstCombiner::foldICmpShlConstant(ICmpInst &Cmp,    ICmpInst::Predicate Pred = Cmp.getPredicate();    Value *X = Shl->getOperand(0); -  if (Cmp.isEquality()) { -    // If the shift is NUW, then it is just shifting out zeros, no need for an -    // AND. -    Constant *LShrC = ConstantInt::get(Shl->getType(), C->lshr(*ShiftAmt)); -    if (Shl->hasNoUnsignedWrap()) -      return new ICmpInst(Pred, X, LShrC); - -    // If the shift is NSW and we compare to 0, then it is just shifting out -    // sign bits, no need for an AND either. 
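The new X | C == C fold in foldICmpOrConstant above can be sanity-checked exhaustively on i8: when C + 1 is a power of two, C is a mask of the low bits, and OR-ing it in leaves the value unchanged exactly when X has no bits above that mask, i.e. when X <=u C. A brute-force C++ check; isLowBitMask is a made-up helper for the sketch:

#include <cassert>
#include <cstdint>

// True when C + 1 is a power of two, i.e. C is a mask of the low bits.
static bool isLowBitMask(unsigned C) { return ((C + 1u) & C) == 0; }

int main() {
  for (unsigned C = 0; C < 256; ++C) {
    if (!isLowBitMask(C))
      continue;
    for (unsigned X = 0; X < 256; ++X) {
      assert(((X | C) == C) == (X <= C)); // eq --> ule
      assert(((X | C) != C) == (X > C));  // ne --> ugt
    }
  }
  return 0;
}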
-    if (Shl->hasNoSignedWrap() && *C == 0) -      return new ICmpInst(Pred, X, LShrC); - -    if (Shl->hasOneUse()) { -      // Otherwise, strength reduce the shift into an and. -      Constant *Mask = ConstantInt::get(Shl->getType(), -          APInt::getLowBitsSet(TypeBits, TypeBits - ShiftAmt->getZExtValue())); - -      Value *And = Builder->CreateAnd(X, Mask, Shl->getName() + ".mask"); -      return new ICmpInst(Pred, And, LShrC); +  Type *ShType = Shl->getType(); + +  // NSW guarantees that we are only shifting out sign bits from the high bits, +  // so we can ASHR the compare constant without needing a mask and eliminate +  // the shift. +  if (Shl->hasNoSignedWrap()) { +    if (Pred == ICmpInst::ICMP_SGT) { +      // icmp Pred (shl nsw X, ShiftAmt), C --> icmp Pred X, (C >>s ShiftAmt) +      APInt ShiftedC = C->ashr(*ShiftAmt); +      return new ICmpInst(Pred, X, ConstantInt::get(ShType, ShiftedC)); +    } +    if (Pred == ICmpInst::ICMP_EQ || Pred == ICmpInst::ICMP_NE) { +      // This is the same code as the SGT case, but assert the pre-condition +      // that is needed for this to work with equality predicates. +      assert(C->ashr(*ShiftAmt).shl(*ShiftAmt) == *C && +             "Compare known true or false was not folded"); +      APInt ShiftedC = C->ashr(*ShiftAmt); +      return new ICmpInst(Pred, X, ConstantInt::get(ShType, ShiftedC)); +    } +    if (Pred == ICmpInst::ICMP_SLT) { +      // SLE is the same as above, but SLE is canonicalized to SLT, so convert: +      // (X << S) <=s C is equiv to X <=s (C >> S) for all C +      // (X << S) <s (C + 1) is equiv to X <s (C >> S) + 1 if C <s SMAX +      // (X << S) <s C is equiv to X <s ((C - 1) >> S) + 1 if C >s SMIN +      assert(!C->isMinSignedValue() && "Unexpected icmp slt"); +      APInt ShiftedC = (*C - 1).ashr(*ShiftAmt) + 1; +      return new ICmpInst(Pred, X, ConstantInt::get(ShType, ShiftedC)); +    } +    // If this is a signed comparison to 0 and the shift is sign preserving, +    // use the shift LHS operand instead; isSignTest may change 'Pred', so only +    // do that if we're sure to not continue on in this function. +    if (isSignTest(Pred, *C)) +      return new ICmpInst(Pred, X, Constant::getNullValue(ShType)); +  } + +  // NUW guarantees that we are only shifting out zero bits from the high bits, +  // so we can LSHR the compare constant without needing a mask and eliminate +  // the shift. +  if (Shl->hasNoUnsignedWrap()) { +    if (Pred == ICmpInst::ICMP_UGT) { +      // icmp Pred (shl nuw X, ShiftAmt), C --> icmp Pred X, (C >>u ShiftAmt) +      APInt ShiftedC = C->lshr(*ShiftAmt); +      return new ICmpInst(Pred, X, ConstantInt::get(ShType, ShiftedC)); +    } +    if (Pred == ICmpInst::ICMP_EQ || Pred == ICmpInst::ICMP_NE) { +      // This is the same code as the UGT case, but assert the pre-condition +      // that is needed for this to work with equality predicates. 
+      assert(C->lshr(*ShiftAmt).shl(*ShiftAmt) == *C && +             "Compare known true or false was not folded"); +      APInt ShiftedC = C->lshr(*ShiftAmt); +      return new ICmpInst(Pred, X, ConstantInt::get(ShType, ShiftedC)); +    } +    if (Pred == ICmpInst::ICMP_ULT) { +      // ULE is the same as above, but ULE is canonicalized to ULT, so convert: +      // (X << S) <=u C is equiv to X <=u (C >> S) for all C +      // (X << S) <u (C + 1) is equiv to X <u (C >> S) + 1 if C <u ~0u +      // (X << S) <u C is equiv to X <u ((C - 1) >> S) + 1 if C >u 0 +      assert(C->ugt(0) && "ult 0 should have been eliminated"); +      APInt ShiftedC = (*C - 1).lshr(*ShiftAmt) + 1; +      return new ICmpInst(Pred, X, ConstantInt::get(ShType, ShiftedC));      }    } -  // If this is a signed comparison to 0 and the shift is sign preserving, -  // use the shift LHS operand instead; isSignTest may change 'Pred', so only -  // do that if we're sure to not continue on in this function. -  if (Shl->hasNoSignedWrap() && isSignTest(Pred, *C)) -    return new ICmpInst(Pred, X, Constant::getNullValue(X->getType())); +  if (Cmp.isEquality() && Shl->hasOneUse()) { +    // Strength-reduce the shift into an 'and'. +    Constant *Mask = ConstantInt::get( +        ShType, +        APInt::getLowBitsSet(TypeBits, TypeBits - ShiftAmt->getZExtValue())); +    Value *And = Builder->CreateAnd(X, Mask, Shl->getName() + ".mask"); +    Constant *LShrC = ConstantInt::get(ShType, C->lshr(*ShiftAmt)); +    return new ICmpInst(Pred, And, LShrC); +  }    // Otherwise, if this is a comparison of the sign bit, simplify to and/test.    bool TrueIfSigned = false;    if (Shl->hasOneUse() && isSignBitCheck(Pred, *C, TrueIfSigned)) {      // (X << 31) <s 0  --> (X & 1) != 0      Constant *Mask = ConstantInt::get( -        X->getType(), +        ShType,          APInt::getOneBitSet(TypeBits, TypeBits - ShiftAmt->getZExtValue() - 1));      Value *And = Builder->CreateAnd(X, Mask, Shl->getName() + ".mask");      return new ICmpInst(TrueIfSigned ? ICmpInst::ICMP_NE : ICmpInst::ICMP_EQ, -                        And, Constant::getNullValue(And->getType())); -  } - -  // When the shift is nuw and pred is >u or <=u, comparison only really happens -  // in the pre-shifted bits. Since InstSimplify canonicalizes <=u into <u, the -  // <=u case can be further converted to match <u (see below). -  if (Shl->hasNoUnsignedWrap() && -      (Pred == ICmpInst::ICMP_UGT || Pred == ICmpInst::ICMP_ULT)) { -    // Derivation for the ult case: -    // (X << S) <=u C is equiv to X <=u (C >> S) for all C -    // (X << S) <u (C + 1) is equiv to X <u (C >> S) + 1 if C <u ~0u -    // (X << S) <u C is equiv to X <u ((C - 1) >> S) + 1 if C >u 0 -    assert((Pred != ICmpInst::ICMP_ULT || C->ugt(0)) && -           "Encountered `ult 0` that should have been eliminated by " -           "InstSimplify."); -    APInt ShiftedC = Pred == ICmpInst::ICMP_ULT ? 
(*C - 1).lshr(*ShiftAmt) + 1 -                                                : C->lshr(*ShiftAmt); -    return new ICmpInst(Pred, X, ConstantInt::get(X->getType(), ShiftedC)); +                        And, Constant::getNullValue(ShType));    }    // Transform (icmp pred iM (shl iM %v, N), C) @@ -1981,8 +2020,8 @@ Instruction *InstCombiner::foldICmpShlConstant(ICmpInst &Cmp,    if (Shl->hasOneUse() && Amt != 0 && C->countTrailingZeros() >= Amt &&        DL.isLegalInteger(TypeBits - Amt)) {      Type *TruncTy = IntegerType::get(Cmp.getContext(), TypeBits - Amt); -    if (X->getType()->isVectorTy()) -      TruncTy = VectorType::get(TruncTy, X->getType()->getVectorNumElements()); +    if (ShType->isVectorTy()) +      TruncTy = VectorType::get(TruncTy, ShType->getVectorNumElements());      Constant *NewC =          ConstantInt::get(TruncTy, C->ashr(*ShiftAmt).trunc(TypeBits - Amt));      return new ICmpInst(Pred, Builder->CreateTrunc(X, TruncTy), NewC); @@ -2342,8 +2381,24 @@ Instruction *InstCombiner::foldICmpAddConstant(ICmpInst &Cmp,    // Fold icmp pred (add X, C2), C.    Value *X = Add->getOperand(0);    Type *Ty = Add->getType(); -  auto CR = -      ConstantRange::makeExactICmpRegion(Cmp.getPredicate(), *C).subtract(*C2); +  CmpInst::Predicate Pred = Cmp.getPredicate(); + +  // If the add does not wrap, we can always adjust the compare by subtracting +  // the constants. Equality comparisons are handled elsewhere. SGE/SLE are +  // canonicalized to SGT/SLT. +  if (Add->hasNoSignedWrap() && +      (Pred == ICmpInst::ICMP_SGT || Pred == ICmpInst::ICMP_SLT)) { +    bool Overflow; +    APInt NewC = C->ssub_ov(*C2, Overflow); +    // If there is overflow, the result must be true or false. +    // TODO: Can we assert there is no overflow because InstSimplify always +    // handles those cases? +    if (!Overflow) +      // icmp Pred (add nsw X, C2), C --> icmp Pred X, (C - C2) +      return new ICmpInst(Pred, X, ConstantInt::get(Ty, NewC)); +  } + +  auto CR = ConstantRange::makeExactICmpRegion(Pred, *C).subtract(*C2);    const APInt &Upper = CR.getUpper();    const APInt &Lower = CR.getLower();    if (Cmp.isSigned()) { @@ -2364,16 +2419,14 @@ Instruction *InstCombiner::foldICmpAddConstant(ICmpInst &Cmp,    // X+C <u C2 -> (X & -C2) == C    //   iff C & (C2-1) == 0    //       C2 is a power of 2 -  if (Cmp.getPredicate() == ICmpInst::ICMP_ULT && C->isPowerOf2() && -      (*C2 & (*C - 1)) == 0) +  if (Pred == ICmpInst::ICMP_ULT && C->isPowerOf2() && (*C2 & (*C - 1)) == 0)      return new ICmpInst(ICmpInst::ICMP_EQ, Builder->CreateAnd(X, -(*C)),                          ConstantExpr::getNeg(cast<Constant>(Y)));    // X+C >u C2 -> (X & ~C2) != C    //   iff C & C2 == 0    //       C2+1 is a power of 2 -  if (Cmp.getPredicate() == ICmpInst::ICMP_UGT && (*C + 1).isPowerOf2() && -      (*C2 & *C) == 0) +  if (Pred == ICmpInst::ICMP_UGT && (*C + 1).isPowerOf2() && (*C2 & *C) == 0)      return new ICmpInst(ICmpInst::ICMP_NE, Builder->CreateAnd(X, ~(*C)),                          ConstantExpr::getNeg(cast<Constant>(Y))); @@ -2656,7 +2709,7 @@ Instruction *InstCombiner::foldICmpInstWithConstantNotInt(ICmpInst &I) {      // block.  If in the same block, we're encouraging jump threading.  If      // not, we are just pessimizing the code by making an i1 phi.      
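The nsw/nuw shl-versus-constant folds above all reduce to one observation: when the shift cannot wrap, shl is an exact multiply by 2^S, and comparing X * 2^S against C is the same as comparing X against C shifted right (floor division), arithmetically for the signed predicates and logically for the unsigned ones. The same floor-division reasoning is behind the (C - 1) >> S + 1 adjustment for SLT/ULT and the add-nsw constant adjustment in foldICmpAddConstant above. A brute-force C++ check on i8 for the SGT/EQ and UGT cases; all names are illustrative, and the sketch assumes the usual arithmetic right shift for negative values:

#include <cassert>
#include <cstdint>

int main() {
  // Signed: 'shl nsw' on i8 stays in range, so it is an exact multiply by
  // 2^S, and (X << S) >s C  <=>  X >s (C >>s S).
  for (int S = 1; S < 8; ++S)
    for (int X = -128; X < 128; ++X) {
      int Shifted = X * (1 << S);
      if (Shifted < -128 || Shifted > 127)
        continue; // would not be 'shl nsw' on i8
      for (int C = -128; C < 128; ++C) {
        int AshrC = C >> S; // arithmetic shift = floor(C / 2^S)
        assert((Shifted > C) == (X > AshrC));
        // Equality can only hold when the shifted-out bits of C are zero,
        // which is the precondition the combine asserts for EQ/NE.
        assert((Shifted == C) == (AshrC * (1 << S) == C && X == AshrC));
      }
    }

  // Unsigned: 'shl nuw' keeps the value in range, so
  // (X << S) >u C  <=>  X >u (C >>u S).
  for (unsigned S = 1; S < 8; ++S)
    for (unsigned X = 0; X < 256; ++X) {
      unsigned Shifted = X << S;
      if (Shifted > 255)
        continue; // would not be 'shl nuw' on i8
      for (unsigned C = 0; C < 256; ++C)
        assert((Shifted > C) == (X > (C >> S)));
    }
  return 0;
}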
if (LHSI->getParent() == I.getParent()) -      if (Instruction *NV = FoldOpIntoPhi(I)) +      if (Instruction *NV = foldOpIntoPhi(I, cast<PHINode>(LHSI)))          return NV;      break;    case Instruction::Select: { @@ -2767,12 +2820,6 @@ Instruction *InstCombiner::foldICmpBinOp(ICmpInst &I) {      D = BO1->getOperand(1);    } -  // icmp (X+cst) < 0 --> X < -cst -  if (NoOp0WrapProblem && ICmpInst::isSigned(Pred) && match(Op1, m_Zero())) -    if (ConstantInt *RHSC = dyn_cast_or_null<ConstantInt>(B)) -      if (!RHSC->isMinValue(/*isSigned=*/true)) -        return new ICmpInst(Pred, A, ConstantExpr::getNeg(RHSC)); -    // icmp (X+Y), X -> icmp Y, 0 for equalities or if there is no overflow.    if ((A == Op1 || B == Op1) && NoOp0WrapProblem)      return new ICmpInst(Pred, A == Op1 ? B : A, @@ -2847,6 +2894,31 @@ Instruction *InstCombiner::foldICmpBinOp(ICmpInst &I) {    if (C && NoOp1WrapProblem && Pred == CmpInst::ICMP_SLT && match(D, m_One()))      return new ICmpInst(CmpInst::ICMP_SLE, Op0, C); +  // TODO: The subtraction-related identities shown below also hold, but +  // canonicalization from (X -nuw 1) to (X + -1) means that the combinations +  // wouldn't happen even if they were implemented. +  // +  // icmp ult (X - 1), Y -> icmp ule X, Y +  // icmp uge (X - 1), Y -> icmp ugt X, Y +  // icmp ugt X, (Y - 1) -> icmp uge X, Y +  // icmp ule X, (Y - 1) -> icmp ult X, Y + +  // icmp ule (X + 1), Y -> icmp ult X, Y +  if (A && NoOp0WrapProblem && Pred == CmpInst::ICMP_ULE && match(B, m_One())) +    return new ICmpInst(CmpInst::ICMP_ULT, A, Op1); + +  // icmp ugt (X + 1), Y -> icmp uge X, Y +  if (A && NoOp0WrapProblem && Pred == CmpInst::ICMP_UGT && match(B, m_One())) +    return new ICmpInst(CmpInst::ICMP_UGE, A, Op1); + +  // icmp uge X, (Y + 1) -> icmp ugt X, Y +  if (C && NoOp1WrapProblem && Pred == CmpInst::ICMP_UGE && match(D, m_One())) +    return new ICmpInst(CmpInst::ICMP_UGT, Op0, C); + +  // icmp ult X, (Y + 1) -> icmp ule X, Y +  if (C && NoOp1WrapProblem && Pred == CmpInst::ICMP_ULT && match(D, m_One())) +    return new ICmpInst(CmpInst::ICMP_ULE, Op0, C); +    // if C1 has greater magnitude than C2:    //  icmp (X + C1), (Y + C2) -> icmp (X + C3), Y    //  s.t. C3 = C1 - C2 @@ -3738,16 +3810,14 @@ static APInt getDemandedBitsLHSMask(ICmpInst &I, unsigned BitWidth,    // greater than the RHS must differ in a bit higher than these due to carry.    case ICmpInst::ICMP_UGT: {      unsigned trailingOnes = RHS.countTrailingOnes(); -    APInt lowBitsSet = APInt::getLowBitsSet(BitWidth, trailingOnes); -    return ~lowBitsSet; +    return APInt::getBitsSetFrom(BitWidth, trailingOnes);    }    // Similarly, for a ULT comparison, we don't care about the trailing zeros.    // Any value less than the RHS must differ in a higher bit because of carries.    case ICmpInst::ICMP_ULT: {      unsigned trailingZeros = RHS.countTrailingZeros(); -    APInt lowBitsSet = APInt::getLowBitsSet(BitWidth, trailingZeros); -    return ~lowBitsSet; +    return APInt::getBitsSetFrom(BitWidth, trailingZeros);    }    default: @@ -3887,7 +3957,7 @@ bool InstCombiner::replacedSelectWithOperand(SelectInst *SI,    assert((SIOpd == 1 || SIOpd == 2) && "Invalid select operand!");    if (isChainSelectCmpBranch(SI) && Icmp->getPredicate() == ICmpInst::ICMP_EQ) {      BasicBlock *Succ = SI->getParent()->getTerminator()->getSuccessor(1); -    // The check for the unique predecessor is not the best that can be +    // The check for the single predecessor is not the best that can be      // done. 
But it protects efficiently against cases like when SI's      // home block has two successors, Succ and Succ1, and Succ1 predecessor      // of Succ. Then SI can't be replaced by SIOpd because the use that gets @@ -3895,8 +3965,10 @@ bool InstCombiner::replacedSelectWithOperand(SelectInst *SI,      // guarantees that the path all uses of SI (outside SI's parent) are on      // is disjoint from all other paths out of SI. But that information      // is more expensive to compute, and the trade-off here is in favor -    // of compile-time. -    if (Succ->getUniquePredecessor() && dominatesAllUses(SI, Icmp, Succ)) { +    // of compile-time. It should also be noticed that we check for a single +    // predecessor and not only uniqueness. This to handle the situation when +    // Succ and Succ1 points to the same basic block. +    if (Succ->getSinglePredecessor() && dominatesAllUses(SI, Icmp, Succ)) {        NumSel++;        SI->replaceUsesOutsideBlock(SI->getOperand(SIOpd), SI->getParent());        return true; @@ -3932,12 +4004,12 @@ Instruction *InstCombiner::foldICmpUsingKnownBits(ICmpInst &I) {    APInt Op0KnownZero(BitWidth, 0), Op0KnownOne(BitWidth, 0);    APInt Op1KnownZero(BitWidth, 0), Op1KnownOne(BitWidth, 0); -  if (SimplifyDemandedBits(I.getOperandUse(0), +  if (SimplifyDemandedBits(&I, 0,                             getDemandedBitsLHSMask(I, BitWidth, IsSignBit),                             Op0KnownZero, Op0KnownOne, 0))      return &I; -  if (SimplifyDemandedBits(I.getOperandUse(1), APInt::getAllOnesValue(BitWidth), +  if (SimplifyDemandedBits(&I, 1, APInt::getAllOnesValue(BitWidth),                             Op1KnownZero, Op1KnownOne, 0))      return &I; @@ -4801,7 +4873,7 @@ Instruction *InstCombiner::visitFCmpInst(FCmpInst &I) {          // block.  If in the same block, we're encouraging jump threading.  If          // not, we are just pessimizing the code by making an i1 phi.          if (LHSI->getParent() == I.getParent()) -          if (Instruction *NV = FoldOpIntoPhi(I)) +          if (Instruction *NV = foldOpIntoPhi(I, cast<PHINode>(LHSI)))              return NV;          break;        case Instruction::SIToFP: diff --git a/lib/Transforms/InstCombine/InstCombineInternal.h b/lib/Transforms/InstCombine/InstCombineInternal.h index 2847ce858e79..71000063ab3c 100644 --- a/lib/Transforms/InstCombine/InstCombineInternal.h +++ b/lib/Transforms/InstCombine/InstCombineInternal.h @@ -28,6 +28,9 @@  #include "llvm/IR/PatternMatch.h"  #include "llvm/Pass.h"  #include "llvm/Transforms/InstCombine/InstCombineWorklist.h" +#include "llvm/Transforms/Utils/Local.h" +#include "llvm/Support/Dwarf.h" +#include "llvm/IR/DIBuilder.h"  #define DEBUG_TYPE "instcombine" @@ -40,21 +43,29 @@ class DbgDeclareInst;  class MemIntrinsic;  class MemSetInst; -/// \brief Assign a complexity or rank value to LLVM Values. +/// Assign a complexity or rank value to LLVM Values. This is used to reduce +/// the amount of pattern matching needed for compares and commutative +/// instructions. For example, if we have: +///   icmp ugt X, Constant +/// or +///   xor (add X, Constant), cast Z +/// +/// We do not have to consider the commuted variants of these patterns because +/// canonicalization based on complexity guarantees the above ordering.  
///  /// This routine maps IR values to various complexity ranks:  ///   0 -> undef  ///   1 -> Constants  ///   2 -> Other non-instructions  ///   3 -> Arguments -///   3 -> Unary operations -///   4 -> Other instructions +///   4 -> Cast and (f)neg/not instructions +///   5 -> Other instructions  static inline unsigned getComplexity(Value *V) {    if (isa<Instruction>(V)) { -    if (BinaryOperator::isNeg(V) || BinaryOperator::isFNeg(V) || -        BinaryOperator::isNot(V)) -      return 3; -    return 4; +    if (isa<CastInst>(V) || BinaryOperator::isNeg(V) || +        BinaryOperator::isFNeg(V) || BinaryOperator::isNot(V)) +      return 4; +    return 5;    }    if (isa<Argument>(V))      return 3; @@ -289,6 +300,7 @@ public:    Instruction *visitLoadInst(LoadInst &LI);    Instruction *visitStoreInst(StoreInst &SI);    Instruction *visitBranchInst(BranchInst &BI); +  Instruction *visitFenceInst(FenceInst &FI);    Instruction *visitSwitchInst(SwitchInst &SI);    Instruction *visitReturnInst(ReturnInst &RI);    Instruction *visitInsertValueInst(InsertValueInst &IV); @@ -313,9 +325,14 @@ public:    bool replacedSelectWithOperand(SelectInst *SI, const ICmpInst *Icmp,                                   const unsigned SIOpd); +  /// Try to replace instruction \p I with value \p V which are pointers +  /// in different address space. +  /// \return true if successful. +  bool replacePointer(Instruction &I, Value *V); +  private: -  bool ShouldChangeType(unsigned FromBitWidth, unsigned ToBitWidth) const; -  bool ShouldChangeType(Type *From, Type *To) const; +  bool shouldChangeType(unsigned FromBitWidth, unsigned ToBitWidth) const; +  bool shouldChangeType(Type *From, Type *To) const;    Value *dyn_castNegVal(Value *V) const;    Value *dyn_castFNegVal(Value *V, bool NoSignedZero = false) const;    Type *FindElementAtOffset(PointerType *PtrTy, int64_t Offset, @@ -456,8 +473,9 @@ public:    /// methods should return the value returned by this function.    Instruction *eraseInstFromFunction(Instruction &I) {      DEBUG(dbgs() << "IC: ERASE " << I << '\n'); -      assert(I.use_empty() && "Cannot erase instruction that is used!"); +    salvageDebugInfo(I); +      // Make sure that we reprocess all operands now that we reduced their      // use counts.      if (I.getNumOperands() < 8) { @@ -499,6 +517,9 @@ public:      return llvm::computeOverflowForUnsignedAdd(LHS, RHS, DL, &AC, CxtI, &DT);    } +  /// Maximum size of array considered when transforming. +  uint64_t MaxArraySizeForCombine; +  private:    /// \brief Performs a few simplifications for operators which are associative    /// or commutative. @@ -518,8 +539,16 @@ private:    Value *SimplifyDemandedUseBits(Value *V, APInt DemandedMask, APInt &KnownZero,                                   APInt &KnownOne, unsigned Depth,                                   Instruction *CxtI); -  bool SimplifyDemandedBits(Use &U, const APInt &DemandedMask, APInt &KnownZero, +  bool SimplifyDemandedBits(Instruction *I, unsigned Op, +                            const APInt &DemandedMask, APInt &KnownZero,                              APInt &KnownOne, unsigned Depth = 0); +  /// Helper routine of SimplifyDemandedUseBits. It computes KnownZero/KnownOne +  /// bits. It also tries to handle simplifications that can be done based on +  /// DemandedMask, but without modifying the Instruction. 
+  Value *SimplifyMultipleUseDemandedBits(Instruction *I, +                                         const APInt &DemandedMask, +                                         APInt &KnownZero, APInt &KnownOne, +                                         unsigned Depth, Instruction *CxtI);    /// Helper routine of SimplifyDemandedUseBits. It tries to simplify demanded    /// bit for "r1 = shr x, c1; r2 = shl r1, c2" instruction sequence.    Value *SimplifyShrShlDemandedBits(Instruction *Lsr, Instruction *Sftl, @@ -540,7 +569,7 @@ private:    /// Given a binary operator, cast instruction, or select which has a PHI node    /// as operand #0, see if we can fold the instruction into the PHI (which is    /// only possible if all operands to the PHI are constants). -  Instruction *FoldOpIntoPhi(Instruction &I); +  Instruction *foldOpIntoPhi(Instruction &I, PHINode *PN);    /// Given an instruction with a select as one operand and a constant as the    /// other operand, try to fold the binary operator into the select arguments. @@ -549,7 +578,7 @@ private:    Instruction *FoldOpIntoSelect(Instruction &Op, SelectInst *SI);    /// This is a convenience wrapper function for the above two functions. -  Instruction *foldOpWithConstantIntoOperand(Instruction &I); +  Instruction *foldOpWithConstantIntoOperand(BinaryOperator &I);    /// \brief Try to rotate an operation below a PHI node, using PHI nodes for    /// its operands. @@ -628,16 +657,16 @@ private:                              SelectPatternFlavor SPF2, Value *C);    Instruction *foldSelectInstWithICmp(SelectInst &SI, ICmpInst *ICI); -  Instruction *OptAndOp(Instruction *Op, ConstantInt *OpRHS, +  Instruction *OptAndOp(BinaryOperator *Op, ConstantInt *OpRHS,                          ConstantInt *AndRHS, BinaryOperator &TheAnd); -  Value *FoldLogicalPlusAnd(Value *LHS, Value *RHS, ConstantInt *Mask, -                            bool isSub, Instruction &I);    Value *insertRangeTest(Value *V, const APInt &Lo, const APInt &Hi,                           bool isSigned, bool Inside);    Instruction *PromoteCastOfAllocation(BitCastInst &CI, AllocaInst &AI);    Instruction *MatchBSwap(BinaryOperator &I);    bool SimplifyStoreAtEndOfBlock(StoreInst &SI); + +  Instruction *SimplifyElementAtomicMemCpy(ElementAtomicMemCpyInst *AMI);    Instruction *SimplifyMemTransfer(MemIntrinsic *MI);    Instruction *SimplifyMemSet(MemSetInst *MI); diff --git a/lib/Transforms/InstCombine/InstCombineLoadStoreAlloca.cpp b/lib/Transforms/InstCombine/InstCombineLoadStoreAlloca.cpp index 49e516e9c176..6288e054f1bc 100644 --- a/lib/Transforms/InstCombine/InstCombineLoadStoreAlloca.cpp +++ b/lib/Transforms/InstCombine/InstCombineLoadStoreAlloca.cpp @@ -12,13 +12,15 @@  //===----------------------------------------------------------------------===//  #include "InstCombineInternal.h" +#include "llvm/ADT/MapVector.h"  #include "llvm/ADT/SmallString.h"  #include "llvm/ADT/Statistic.h"  #include "llvm/Analysis/Loads.h"  #include "llvm/IR/ConstantRange.h"  #include "llvm/IR/DataLayout.h" -#include "llvm/IR/LLVMContext.h" +#include "llvm/IR/DebugInfo.h"  #include "llvm/IR/IntrinsicInst.h" +#include "llvm/IR/LLVMContext.h"  #include "llvm/IR/MDBuilder.h"  #include "llvm/Transforms/Utils/BasicBlockUtils.h"  #include "llvm/Transforms/Utils/Local.h" @@ -223,6 +225,107 @@ static Instruction *simplifyAllocaArraySize(InstCombiner &IC, AllocaInst &AI) {    return nullptr;  } +namespace { +// If I and V are pointers in different address space, it is not allowed to +// use replaceAllUsesWith since I and 
V have different types. A +// non-target-specific transformation should not use addrspacecast on V since +// the two address space may be disjoint depending on target. +// +// This class chases down uses of the old pointer until reaching the load +// instructions, then replaces the old pointer in the load instructions with +// the new pointer. If during the chasing it sees bitcast or GEP, it will +// create new bitcast or GEP with the new pointer and use them in the load +// instruction. +class PointerReplacer { +public: +  PointerReplacer(InstCombiner &IC) : IC(IC) {} +  void replacePointer(Instruction &I, Value *V); + +private: +  void findLoadAndReplace(Instruction &I); +  void replace(Instruction *I); +  Value *getReplacement(Value *I); + +  SmallVector<Instruction *, 4> Path; +  MapVector<Value *, Value *> WorkMap; +  InstCombiner &IC; +}; +} // end anonymous namespace + +void PointerReplacer::findLoadAndReplace(Instruction &I) { +  for (auto U : I.users()) { +    auto *Inst = dyn_cast<Instruction>(&*U); +    if (!Inst) +      return; +    DEBUG(dbgs() << "Found pointer user: " << *U << '\n'); +    if (isa<LoadInst>(Inst)) { +      for (auto P : Path) +        replace(P); +      replace(Inst); +    } else if (isa<GetElementPtrInst>(Inst) || isa<BitCastInst>(Inst)) { +      Path.push_back(Inst); +      findLoadAndReplace(*Inst); +      Path.pop_back(); +    } else { +      return; +    } +  } +} + +Value *PointerReplacer::getReplacement(Value *V) { +  auto Loc = WorkMap.find(V); +  if (Loc != WorkMap.end()) +    return Loc->second; +  return nullptr; +} + +void PointerReplacer::replace(Instruction *I) { +  if (getReplacement(I)) +    return; + +  if (auto *LT = dyn_cast<LoadInst>(I)) { +    auto *V = getReplacement(LT->getPointerOperand()); +    assert(V && "Operand not replaced"); +    auto *NewI = new LoadInst(V); +    NewI->takeName(LT); +    IC.InsertNewInstWith(NewI, *LT); +    IC.replaceInstUsesWith(*LT, NewI); +    WorkMap[LT] = NewI; +  } else if (auto *GEP = dyn_cast<GetElementPtrInst>(I)) { +    auto *V = getReplacement(GEP->getPointerOperand()); +    assert(V && "Operand not replaced"); +    SmallVector<Value *, 8> Indices; +    Indices.append(GEP->idx_begin(), GEP->idx_end()); +    auto *NewI = GetElementPtrInst::Create( +        V->getType()->getPointerElementType(), V, Indices); +    IC.InsertNewInstWith(NewI, *GEP); +    NewI->takeName(GEP); +    WorkMap[GEP] = NewI; +  } else if (auto *BC = dyn_cast<BitCastInst>(I)) { +    auto *V = getReplacement(BC->getOperand(0)); +    assert(V && "Operand not replaced"); +    auto *NewT = PointerType::get(BC->getType()->getPointerElementType(), +                                  V->getType()->getPointerAddressSpace()); +    auto *NewI = new BitCastInst(V, NewT); +    IC.InsertNewInstWith(NewI, *BC); +    NewI->takeName(BC); +    WorkMap[BC] = NewI; +  } else { +    llvm_unreachable("should never reach here"); +  } +} + +void PointerReplacer::replacePointer(Instruction &I, Value *V) { +#ifndef NDEBUG +  auto *PT = cast<PointerType>(I.getType()); +  auto *NT = cast<PointerType>(V->getType()); +  assert(PT != NT && PT->getElementType() == NT->getElementType() && +         "Invalid usage"); +#endif +  WorkMap[&I] = V; +  findLoadAndReplace(I); +} +  Instruction *InstCombiner::visitAllocaInst(AllocaInst &AI) {    if (auto *I = simplifyAllocaArraySize(*this, AI))      return I; @@ -293,12 +396,22 @@ Instruction *InstCombiner::visitAllocaInst(AllocaInst &AI) {          for (unsigned i = 0, e = ToDelete.size(); i != e; ++i)            
eraseInstFromFunction(*ToDelete[i]);          Constant *TheSrc = cast<Constant>(Copy->getSource()); -        Constant *Cast -          = ConstantExpr::getPointerBitCastOrAddrSpaceCast(TheSrc, AI.getType()); -        Instruction *NewI = replaceInstUsesWith(AI, Cast); -        eraseInstFromFunction(*Copy); -        ++NumGlobalCopies; -        return NewI; +        auto *SrcTy = TheSrc->getType(); +        auto *DestTy = PointerType::get(AI.getType()->getPointerElementType(), +                                        SrcTy->getPointerAddressSpace()); +        Constant *Cast = +            ConstantExpr::getPointerBitCastOrAddrSpaceCast(TheSrc, DestTy); +        if (AI.getType()->getPointerAddressSpace() == +            SrcTy->getPointerAddressSpace()) { +          Instruction *NewI = replaceInstUsesWith(AI, Cast); +          eraseInstFromFunction(*Copy); +          ++NumGlobalCopies; +          return NewI; +        } else { +          PointerReplacer PtrReplacer(*this); +          PtrReplacer.replacePointer(AI, Cast); +          ++NumGlobalCopies; +        }        }      }    } @@ -608,7 +721,7 @@ static Instruction *unpackLoadToAggregate(InstCombiner &IC, LoadInst &LI) {      // arrays of arbitrary size but this has a terrible impact on compile time.      // The threshold here is chosen arbitrarily, maybe needs a little bit of      // tuning. -    if (NumElements > 1024) +    if (NumElements > IC.MaxArraySizeForCombine)        return nullptr;      const DataLayout &DL = IC.getDataLayout(); @@ -1113,7 +1226,7 @@ static bool unpackStoreToAggregate(InstCombiner &IC, StoreInst &SI) {      // arrays of arbitrary size but this has a terrible impact on compile time.      // The threshold here is chosen arbitrarily, maybe needs a little bit of      // tuning. -    if (NumElements > 1024) +    if (NumElements > IC.MaxArraySizeForCombine)        return false;      const DataLayout &DL = IC.getDataLayout(); @@ -1268,8 +1381,8 @@ Instruction *InstCombiner::visitStoreInst(StoreInst &SI) {        break;      } -    // Don't skip over loads or things that can modify memory. -    if (BBI->mayWriteToMemory() || BBI->mayReadFromMemory()) +    // Don't skip over loads, throws or things that can modify memory. +    if (BBI->mayWriteToMemory() || BBI->mayReadFromMemory() || BBI->mayThrow())        break;    } @@ -1392,8 +1505,8 @@ bool InstCombiner::SimplifyStoreAtEndOfBlock(StoreInst &SI) {        }        // If we find something that may be using or overwriting the stored        // value, or if we run out of instructions, we can't do the xform. -      if (BBI->mayReadFromMemory() || BBI->mayWriteToMemory() || -          BBI == OtherBB->begin()) +      if (BBI->mayReadFromMemory() || BBI->mayThrow() || +          BBI->mayWriteToMemory() || BBI == OtherBB->begin())          return false;      } @@ -1402,7 +1515,7 @@ bool InstCombiner::SimplifyStoreAtEndOfBlock(StoreInst &SI) {      // StoreBB.      for (BasicBlock::iterator I = StoreBB->begin(); &*I != &SI; ++I) {        // FIXME: This should really be AA driven. 
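// [Editorial note: illustrative sketch, not part of this patch.] The hunks
// above add mayThrow() to the bail-out checks used when sinking/merging
// stores. A standalone C++ sketch of why that matters: moving a store past a
// call that can throw lets an exception handler observe stale memory. All
// names here (g, mayThrowFn, demo) are hypothetical.
#include <cassert>
#include <stdexcept>

static int g = 0;

static void mayThrowFn(bool doThrow) {
  if (doThrow)
    throw std::runtime_error("boom");
}

static void demo(bool doThrow) {
  g = 1;               // sinking this store below the call would be unsound...
  mayThrowFn(doThrow); // ...because a throw here makes the old value of g
  g = 2;               // visible to the handler in main().
}

int main() {
  try {
    demo(/*doThrow=*/true);
  } catch (const std::runtime_error &) {
    assert(g == 1 && "first store must already have happened");
  }
  return 0;
}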
-      if (I->mayReadFromMemory() || I->mayWriteToMemory()) +      if (I->mayReadFromMemory() || I->mayThrow() || I->mayWriteToMemory())          return false;      }    } @@ -1425,7 +1538,9 @@ bool InstCombiner::SimplifyStoreAtEndOfBlock(StoreInst &SI) {                                     SI.getOrdering(),                                     SI.getSynchScope());    InsertNewInstBefore(NewSI, *BBI); -  NewSI->setDebugLoc(OtherStore->getDebugLoc()); +  // The debug locations of the original instructions might differ; merge them. +  NewSI->setDebugLoc(DILocation::getMergedLocation(SI.getDebugLoc(), +                                                   OtherStore->getDebugLoc()));    // If the two stores had AA tags, merge them.    AAMDNodes AATags; diff --git a/lib/Transforms/InstCombine/InstCombineMulDivRem.cpp b/lib/Transforms/InstCombine/InstCombineMulDivRem.cpp index 45a19fb0f1f2..f1ac82057e6c 100644 --- a/lib/Transforms/InstCombine/InstCombineMulDivRem.cpp +++ b/lib/Transforms/InstCombine/InstCombineMulDivRem.cpp @@ -298,39 +298,33 @@ Instruction *InstCombiner::visitMul(BinaryOperator &I) {    // (X / Y) *  Y = X - (X % Y)    // (X / Y) * -Y = (X % Y) - X    { -    Value *Op1C = Op1; -    BinaryOperator *BO = dyn_cast<BinaryOperator>(Op0); -    if (!BO || -        (BO->getOpcode() != Instruction::UDiv && -         BO->getOpcode() != Instruction::SDiv)) { -      Op1C = Op0; -      BO = dyn_cast<BinaryOperator>(Op1); +    Value *Y = Op1; +    BinaryOperator *Div = dyn_cast<BinaryOperator>(Op0); +    if (!Div || (Div->getOpcode() != Instruction::UDiv && +                 Div->getOpcode() != Instruction::SDiv)) { +      Y = Op0; +      Div = dyn_cast<BinaryOperator>(Op1);      } -    Value *Neg = dyn_castNegVal(Op1C); -    if (BO && BO->hasOneUse() && -        (BO->getOperand(1) == Op1C || BO->getOperand(1) == Neg) && -        (BO->getOpcode() == Instruction::UDiv || -         BO->getOpcode() == Instruction::SDiv)) { -      Value *Op0BO = BO->getOperand(0), *Op1BO = BO->getOperand(1); +    Value *Neg = dyn_castNegVal(Y); +    if (Div && Div->hasOneUse() && +        (Div->getOperand(1) == Y || Div->getOperand(1) == Neg) && +        (Div->getOpcode() == Instruction::UDiv || +         Div->getOpcode() == Instruction::SDiv)) { +      Value *X = Div->getOperand(0), *DivOp1 = Div->getOperand(1);        // If the division is exact, X % Y is zero, so we end up with X or -X. -      if (PossiblyExactOperator *SDiv = dyn_cast<PossiblyExactOperator>(BO)) -        if (SDiv->isExact()) { -          if (Op1BO == Op1C) -            return replaceInstUsesWith(I, Op0BO); -          return BinaryOperator::CreateNeg(Op0BO); -        } - -      Value *Rem; -      if (BO->getOpcode() == Instruction::UDiv) -        Rem = Builder->CreateURem(Op0BO, Op1BO); -      else -        Rem = Builder->CreateSRem(Op0BO, Op1BO); -      Rem->takeName(BO); +      if (Div->isExact()) { +        if (DivOp1 == Y) +          return replaceInstUsesWith(I, X); +        return BinaryOperator::CreateNeg(X); +      } -      if (Op1BO == Op1C) -        return BinaryOperator::CreateSub(Op0BO, Rem); -      return BinaryOperator::CreateSub(Rem, Op0BO); +      auto RemOpc = Div->getOpcode() == Instruction::UDiv ? 
Instruction::URem +                                                          : Instruction::SRem; +      Value *Rem = Builder->CreateBinOp(RemOpc, X, DivOp1); +      if (DivOp1 == Y) +        return BinaryOperator::CreateSub(X, Rem); +      return BinaryOperator::CreateSub(Rem, X);      }    } @@ -1461,16 +1455,16 @@ Instruction *InstCombiner::commonIRemTransforms(BinaryOperator &I) {        if (SelectInst *SI = dyn_cast<SelectInst>(Op0I)) {          if (Instruction *R = FoldOpIntoSelect(I, SI))            return R; -      } else if (isa<PHINode>(Op0I)) { +      } else if (auto *PN = dyn_cast<PHINode>(Op0I)) {          using namespace llvm::PatternMatch;          const APInt *Op1Int;          if (match(Op1, m_APInt(Op1Int)) && !Op1Int->isMinValue() &&              (I.getOpcode() == Instruction::URem ||               !Op1Int->isMinSignedValue())) { -          // FoldOpIntoPhi will speculate instructions to the end of the PHI's +          // foldOpIntoPhi will speculate instructions to the end of the PHI's            // predecessor blocks, so do this only if we know the srem or urem            // will not fault. -          if (Instruction *NV = FoldOpIntoPhi(I)) +          if (Instruction *NV = foldOpIntoPhi(I, PN))              return NV;          }        } diff --git a/lib/Transforms/InstCombine/InstCombinePHI.cpp b/lib/Transforms/InstCombine/InstCombinePHI.cpp index 4cbffe9533b7..85e5b6ba2dc2 100644 --- a/lib/Transforms/InstCombine/InstCombinePHI.cpp +++ b/lib/Transforms/InstCombine/InstCombinePHI.cpp @@ -457,8 +457,8 @@ Instruction *InstCombiner::FoldPHIArgZextsIntoPHI(PHINode &Phi) {    }    // The more common cases of a phi with no constant operands or just one -  // variable operand are handled by FoldPHIArgOpIntoPHI() and FoldOpIntoPhi() -  // respectively. FoldOpIntoPhi() wants to do the opposite transform that is +  // variable operand are handled by FoldPHIArgOpIntoPHI() and foldOpIntoPhi() +  // respectively. foldOpIntoPhi() wants to do the opposite transform that is    // performed here. It tries to replicate a cast in the phi operand's basic    // block to expose other folding opportunities. Thus, InstCombine will    // infinite loop without this check. @@ -507,7 +507,7 @@ Instruction *InstCombiner::FoldPHIArgOpIntoPHI(PHINode &PN) {      // Be careful about transforming integer PHIs.  We don't want to pessimize      // the code by turning an i32 into an i1293.      if (PN.getType()->isIntegerTy() && CastSrcTy->isIntegerTy()) { -      if (!ShouldChangeType(PN.getType(), CastSrcTy)) +      if (!shouldChangeType(PN.getType(), CastSrcTy))          return nullptr;      }    } else if (isa<BinaryOperator>(FirstInst) || isa<CmpInst>(FirstInst)) { diff --git a/lib/Transforms/InstCombine/InstCombineSelect.cpp b/lib/Transforms/InstCombine/InstCombineSelect.cpp index 36644845352e..693b6c95c169 100644 --- a/lib/Transforms/InstCombine/InstCombineSelect.cpp +++ b/lib/Transforms/InstCombine/InstCombineSelect.cpp @@ -120,6 +120,16 @@ static Constant *getSelectFoldableConstant(Instruction *I) {  /// We have (select c, TI, FI), and we know that TI and FI have the same opcode.  Instruction *InstCombiner::foldSelectOpOp(SelectInst &SI, Instruction *TI,                                            Instruction *FI) { +  // Don't break up min/max patterns. The hasOneUse checks below prevent that +  // for most cases, but vector min/max with bitcasts can be transformed. If the +  // one-use restrictions are eased for other patterns, we still don't want to +  // obfuscate min/max. 
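// [Editorial note: illustrative sketch, not part of this patch.] The early
// return added above keeps compare+select min/max patterns intact. This is
// the shape the m_SMin/m_SMax/m_UMin/m_UMax matchers (and later the backends)
// recognize; folding the select's operands would hide it. Function names are
// hypothetical.
#include <cstdint>

static constexpr int32_t sminIdiom(int32_t a, int32_t b) { return a < b ? a : b; }
static constexpr uint32_t umaxIdiom(uint32_t a, uint32_t b) { return a > b ? a : b; }

static_assert(sminIdiom(-3, 7) == -3, "signed min as icmp slt + select");
static_assert(umaxIdiom(3u, 7u) == 7u, "unsigned max as icmp ugt + select");

int main() { return 0; }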
+  if ((match(&SI, m_SMin(m_Value(), m_Value())) || +       match(&SI, m_SMax(m_Value(), m_Value())) || +       match(&SI, m_UMin(m_Value(), m_Value())) || +       match(&SI, m_UMax(m_Value(), m_Value())))) +    return nullptr; +    // If this is a cast from the same type, merge.    if (TI->getNumOperands() == 1 && TI->isCast()) {      Type *FIOpndTy = FI->getOperand(0)->getType(); @@ -364,7 +374,7 @@ static Value *foldSelectICmpAndOr(const SelectInst &SI, Value *TrueVal,  /// into:  ///   %0 = tail call i32 @llvm.cttz.i32(i32 %x, i1 false)  static Value *foldSelectCttzCtlz(ICmpInst *ICI, Value *TrueVal, Value *FalseVal, -                                  InstCombiner::BuilderTy *Builder) { +                                 InstCombiner::BuilderTy *Builder) {    ICmpInst::Predicate Pred = ICI->getPredicate();    Value *CmpLHS = ICI->getOperand(0);    Value *CmpRHS = ICI->getOperand(1); @@ -395,13 +405,12 @@ static Value *foldSelectCttzCtlz(ICmpInst *ICI, Value *TrueVal, Value *FalseVal,    if (match(Count, m_Intrinsic<Intrinsic::cttz>(m_Specific(CmpLHS))) ||        match(Count, m_Intrinsic<Intrinsic::ctlz>(m_Specific(CmpLHS)))) {      IntrinsicInst *II = cast<IntrinsicInst>(Count); -    IRBuilder<> Builder(II);      // Explicitly clear the 'undef_on_zero' flag.      IntrinsicInst *NewI = cast<IntrinsicInst>(II->clone());      Type *Ty = NewI->getArgOperand(1)->getType();      NewI->setArgOperand(1, Constant::getNullValue(Ty)); -    Builder.Insert(NewI); -    return Builder.CreateZExtOrTrunc(NewI, ValueOnZero->getType()); +    Builder->Insert(NewI); +    return Builder->CreateZExtOrTrunc(NewI, ValueOnZero->getType());    }    return nullptr; @@ -500,18 +509,16 @@ static bool adjustMinMax(SelectInst &Sel, ICmpInst &Cmp) {    return true;  } -/// If this is an integer min/max where the select's 'true' operand is a -/// constant, canonicalize that constant to the 'false' operand: -/// select (icmp Pred X, C), C, X --> select (icmp Pred' X, C), X, C +/// If this is an integer min/max (icmp + select) with a constant operand, +/// create the canonical icmp for the min/max operation and canonicalize the +/// constant to the 'false' operand of the select: +/// select (icmp Pred X, C1), C2, X --> select (icmp Pred' X, C2), X, C2 +/// Note: if C1 != C2, this will change the icmp constant to the existing +/// constant operand of the select.  static Instruction *  canonicalizeMinMaxWithConstant(SelectInst &Sel, ICmpInst &Cmp,                                 InstCombiner::BuilderTy &Builder) { -  // TODO: We should also canonicalize min/max when the select has a different -  // constant value than the cmp constant, but we need to fix the backend first. -  if (!Cmp.hasOneUse() || !isa<Constant>(Cmp.getOperand(1)) || -      !isa<Constant>(Sel.getTrueValue()) || -      isa<Constant>(Sel.getFalseValue()) || -      Cmp.getOperand(1) != Sel.getTrueValue()) +  if (!Cmp.hasOneUse() || !isa<Constant>(Cmp.getOperand(1)))      return nullptr;    // Canonicalize the compare predicate based on whether we have min or max. @@ -526,16 +533,25 @@ canonicalizeMinMaxWithConstant(SelectInst &Sel, ICmpInst &Cmp,    default: return nullptr;    } -  // Canonicalize the constant to the right side. -  if (isa<Constant>(LHS)) -    std::swap(LHS, RHS); +  // Is this already canonical? +  if (Cmp.getOperand(0) == LHS && Cmp.getOperand(1) == RHS && +      Cmp.getPredicate() == NewPred) +    return nullptr; + +  // Create the canonical compare and plug it into the select. 
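// [Editorial note: illustrative sketch, not part of this patch.] The rewrite
// below swaps the predicate and, if needed, the select arms. A quick scalar
// check of the underlying equivalence for smax with a constant: putting the
// constant on the 'false' arm with the swapped predicate does not change the
// result. Names are hypothetical.
#include <cassert>
#include <cstdint>

int main() {
  const int32_t C = 42;
  for (int32_t x = -100; x <= 100; ++x) {
    int32_t nonCanonical = (x < C) ? C : x; // constant on the 'true' arm
    int32_t canonical = (x > C) ? x : C;    // icmp sgt; constant on 'false' arm
    assert(nonCanonical == canonical && "same smax(x, C) either way");
  }
  return 0;
}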
+  Sel.setCondition(Builder.CreateICmp(NewPred, LHS, RHS)); -  Value *NewCmp = Builder.CreateICmp(NewPred, LHS, RHS); -  SelectInst *NewSel = SelectInst::Create(NewCmp, LHS, RHS, "", nullptr, &Sel); +  // If the select operands did not change, we're done. +  if (Sel.getTrueValue() == LHS && Sel.getFalseValue() == RHS) +    return &Sel; -  // We swapped the select operands, so swap the metadata too. -  NewSel->swapProfMetadata(); -  return NewSel; +  // If we are swapping the select operands, swap the metadata too. +  assert(Sel.getTrueValue() == RHS && Sel.getFalseValue() == LHS && +         "Unexpected results from matchSelectPattern"); +  Sel.setTrueValue(LHS); +  Sel.setFalseValue(RHS); +  Sel.swapProfMetadata(); +  return &Sel;  }  /// Visit a SelectInst that has an ICmpInst as its first operand. @@ -786,7 +802,9 @@ Instruction *InstCombiner::foldSPFofSPF(Instruction *Inner,    // This transform is performance neutral if we can elide at least one xor from    // the set of three operands, since we'll be tacking on an xor at the very    // end. -  if (IsFreeOrProfitableToInvert(A, NotA, ElidesXor) && +  if (SelectPatternResult::isMinOrMax(SPF1) && +      SelectPatternResult::isMinOrMax(SPF2) && +      IsFreeOrProfitableToInvert(A, NotA, ElidesXor) &&        IsFreeOrProfitableToInvert(B, NotB, ElidesXor) &&        IsFreeOrProfitableToInvert(C, NotC, ElidesXor) && ElidesXor) {      if (!NotA) @@ -1035,8 +1053,10 @@ static Instruction *canonicalizeSelectToShuffle(SelectInst &SI) {        // If the select condition element is false, choose from the 2nd vector.        Mask.push_back(ConstantInt::get(Int32Ty, i + NumElts));      } else if (isa<UndefValue>(Elt)) { -      // If the select condition element is undef, the shuffle mask is undef. -      Mask.push_back(UndefValue::get(Int32Ty)); +      // Undef in a select condition (choose one of the operands) does not mean +      // the same thing as undef in a shuffle mask (any value is acceptable), so +      // give up. +      return nullptr;      } else {        // Bail out on a constant expression.        return nullptr; @@ -1364,11 +1384,11 @@ Instruction *InstCombiner::visitSelectInst(SelectInst &SI) {    }    // See if we can fold the select into a phi node if the condition is a select. -  if (isa<PHINode>(SI.getCondition())) +  if (auto *PN = dyn_cast<PHINode>(SI.getCondition()))      // The true/false values have to be live in the PHI predecessor's blocks.      if (canSelectOperandBeMappingIntoPredBlock(TrueVal, SI) &&          canSelectOperandBeMappingIntoPredBlock(FalseVal, SI)) -      if (Instruction *NV = FoldOpIntoPhi(SI)) +      if (Instruction *NV = foldOpIntoPhi(SI, PN))          return NV;    if (SelectInst *TrueSI = dyn_cast<SelectInst>(TrueVal)) { @@ -1450,6 +1470,20 @@ Instruction *InstCombiner::visitSelectInst(SelectInst &SI) {      }    } +  // If we can compute the condition, there's no need for a select. +  // Like the above fold, we are attempting to reduce compile-time cost by +  // putting this fold here with limitations rather than in InstSimplify. +  // The motivation for this call into value tracking is to take advantage of +  // the assumption cache, so make sure that is populated. 
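// [Editorial note: illustrative sketch, not part of this patch.] The fold that
// follows replaces the select with one of its arms once computeKnownBits()
// proves the i1 condition constant. A toy scalar analogue: the low bit of
// (x | 1) is known-one, so the condition below is provably true for every
// input and the select is just TrueVal. The helper name is hypothetical.
#include <cstdint>

static constexpr int selectWithKnownCond(uint32_t x, int trueVal, int falseVal) {
  return (((x | 1u) & 1u) != 0) ? trueVal : falseVal; // condition is always true
}

static_assert(selectWithKnownCond(0u, 10, 20) == 10, "collapses to TrueVal");
static_assert(selectWithKnownCond(0xFFFFFFFFu, 10, 20) == 10, "for any input");

int main() { return 0; }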
+  if (!CondVal->getType()->isVectorTy() && !AC.assumptions().empty()) { +    APInt KnownOne(1, 0), KnownZero(1, 0); +    computeKnownBits(CondVal, KnownZero, KnownOne, 0, &SI); +    if (KnownOne == 1) +      return replaceInstUsesWith(SI, TrueVal); +    if (KnownZero == 1) +      return replaceInstUsesWith(SI, FalseVal); +  } +    if (Instruction *BitCastSel = foldSelectCmpBitcasts(SI, *Builder))      return BitCastSel; diff --git a/lib/Transforms/InstCombine/InstCombineShifts.cpp b/lib/Transforms/InstCombine/InstCombineShifts.cpp index 4ff9b64ac57c..9aa679c60e47 100644 --- a/lib/Transforms/InstCombine/InstCombineShifts.cpp +++ b/lib/Transforms/InstCombine/InstCombineShifts.cpp @@ -22,8 +22,8 @@ using namespace PatternMatch;  #define DEBUG_TYPE "instcombine"  Instruction *InstCombiner::commonShiftTransforms(BinaryOperator &I) { -  assert(I.getOperand(1)->getType() == I.getOperand(0)->getType());    Value *Op0 = I.getOperand(0), *Op1 = I.getOperand(1); +  assert(Op0->getType() == Op1->getType());    // See if we can fold away this shift.    if (SimplifyDemandedInstructionBits(I)) @@ -65,63 +65,60 @@ Instruction *InstCombiner::commonShiftTransforms(BinaryOperator &I) {  }  /// Return true if we can simplify two logical (either left or right) shifts -/// that have constant shift amounts. -static bool canEvaluateShiftedShift(unsigned FirstShiftAmt, -                                    bool IsFirstShiftLeft, -                                    Instruction *SecondShift, InstCombiner &IC, +/// that have constant shift amounts: OuterShift (InnerShift X, C1), C2. +static bool canEvaluateShiftedShift(unsigned OuterShAmt, bool IsOuterShl, +                                    Instruction *InnerShift, InstCombiner &IC,                                      Instruction *CxtI) { -  assert(SecondShift->isLogicalShift() && "Unexpected instruction type"); +  assert(InnerShift->isLogicalShift() && "Unexpected instruction type"); -  // We need constant shifts. -  auto *SecondShiftConst = dyn_cast<ConstantInt>(SecondShift->getOperand(1)); -  if (!SecondShiftConst) +  // We need constant scalar or constant splat shifts. +  const APInt *InnerShiftConst; +  if (!match(InnerShift->getOperand(1), m_APInt(InnerShiftConst)))      return false; -  unsigned SecondShiftAmt = SecondShiftConst->getZExtValue(); -  bool IsSecondShiftLeft = SecondShift->getOpcode() == Instruction::Shl; - -  // We can always fold  shl(c1) +  shl(c2) ->  shl(c1+c2). -  // We can always fold lshr(c1) + lshr(c2) -> lshr(c1+c2). -  if (IsFirstShiftLeft == IsSecondShiftLeft) +  // Two logical shifts in the same direction: +  // shl (shl X, C1), C2 -->  shl X, C1 + C2 +  // lshr (lshr X, C1), C2 --> lshr X, C1 + C2 +  bool IsInnerShl = InnerShift->getOpcode() == Instruction::Shl; +  if (IsInnerShl == IsOuterShl)      return true; -  // We can always fold lshr(c) +  shl(c) -> and(c2). -  // We can always fold  shl(c) + lshr(c) -> and(c2). 
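// [Editorial note: illustrative sketch, not part of this patch.] The
// same-direction and equal-amount cases described around this point reduce to
// simple unsigned identities, checked here for 32-bit values with in-range
// shift amounts (the oversized composite shift folds to 0 and is handled
// separately).
static_assert(((0x12345678u << 3) << 5) == (0x12345678u << 8),
              "shl (shl X, C1), C2 == shl X, C1 + C2");
static_assert(((0x12345678u >> 3) >> 5) == (0x12345678u >> 8),
              "lshr (lshr X, C1), C2 == lshr X, C1 + C2");
static_assert(((0xDEADBEEFu << 12) >> 12) == (0xDEADBEEFu & (0xFFFFFFFFu >> 12)),
              "lshr (shl X, C), C == and X, low bits");
static_assert(((0xDEADBEEFu >> 12) << 12) == (0xDEADBEEFu & (0xFFFFFFFFu << 12)),
              "shl (lshr X, C), C == and X, high bits");

int main() { return 0; }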
-  if (FirstShiftAmt == SecondShiftAmt) +  // Equal shift amounts in opposite directions become bitwise 'and': +  // lshr (shl X, C), C --> and X, C' +  // shl (lshr X, C), C --> and X, C' +  unsigned InnerShAmt = InnerShiftConst->getZExtValue(); +  if (InnerShAmt == OuterShAmt)      return true; -  unsigned TypeWidth = SecondShift->getType()->getScalarSizeInBits(); -    // If the 2nd shift is bigger than the 1st, we can fold: -  //   lshr(c1) +  shl(c2) ->  shl(c3) + and(c4) or -  //   shl(c1)  + lshr(c2) -> lshr(c3) + and(c4), +  // lshr (shl X, C1), C2 -->  and (shl X, C1 - C2), C3 +  // shl (lshr X, C1), C2 --> and (lshr X, C1 - C2), C3    // but it isn't profitable unless we know the and'd out bits are already zero. -  // Also check that the 2nd shift is valid (less than the type width) or we'll -  // crash trying to produce the bit mask for the 'and'. -  if (SecondShiftAmt > FirstShiftAmt && SecondShiftAmt < TypeWidth) { -    unsigned MaskShift = IsSecondShiftLeft ? TypeWidth - SecondShiftAmt -                                           : SecondShiftAmt - FirstShiftAmt; -    APInt Mask = APInt::getLowBitsSet(TypeWidth, FirstShiftAmt) << MaskShift; -    if (IC.MaskedValueIsZero(SecondShift->getOperand(0), Mask, 0, CxtI)) +  // Also, check that the inner shift is valid (less than the type width) or +  // we'll crash trying to produce the bit mask for the 'and'. +  unsigned TypeWidth = InnerShift->getType()->getScalarSizeInBits(); +  if (InnerShAmt > OuterShAmt && InnerShAmt < TypeWidth) { +    unsigned MaskShift = +        IsInnerShl ? TypeWidth - InnerShAmt : InnerShAmt - OuterShAmt; +    APInt Mask = APInt::getLowBitsSet(TypeWidth, OuterShAmt) << MaskShift; +    if (IC.MaskedValueIsZero(InnerShift->getOperand(0), Mask, 0, CxtI))        return true;    }    return false;  } -/// See if we can compute the specified value, but shifted -/// logically to the left or right by some number of bits.  This should return -/// true if the expression can be computed for the same cost as the current -/// expression tree.  This is used to eliminate extraneous shifting from things -/// like: +/// See if we can compute the specified value, but shifted logically to the left +/// or right by some number of bits. This should return true if the expression +/// can be computed for the same cost as the current expression tree. This is +/// used to eliminate extraneous shifting from things like:  ///      %C = shl i128 %A, 64  ///      %D = shl i128 %B, 96  ///      %E = or i128 %C, %D  ///      %F = lshr i128 %E, 64 -/// where the client will ask if E can be computed shifted right by 64-bits.  If -/// this succeeds, the GetShiftedValue function will be called to produce the -/// value. -static bool CanEvaluateShifted(Value *V, unsigned NumBits, bool IsLeftShift, +/// where the client will ask if E can be computed shifted right by 64-bits. If +/// this succeeds, getShiftedValue() will be called to produce the value. +static bool canEvaluateShifted(Value *V, unsigned NumBits, bool IsLeftShift,                                 InstCombiner &IC, Instruction *CxtI) {    // We can always evaluate constants shifted.    if (isa<Constant>(V)) @@ -165,8 +162,8 @@ static bool CanEvaluateShifted(Value *V, unsigned NumBits, bool IsLeftShift,    case Instruction::Or:    case Instruction::Xor:      // Bitwise operators can all arbitrarily be arbitrarily evaluated shifted. 
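// [Editorial note: illustrative sketch, not part of this patch.] The function
// comment above asks whether an expression can be recomputed already-shifted.
// Scaled down from its i128 example to 64 bits: when the bits that would be
// shifted out are known zero (a fits in 32 bits, b in 16 bits), the trailing
// lshr folds into the or's operands and the shift pair disappears. The helper
// name is hypothetical.
#include <cstdint>

static constexpr bool lshrFoldsIntoOperands(uint64_t a32, uint64_t b16) {
  return (((a32 << 32) | (b16 << 48)) >> 32) == (a32 | (b16 << 16));
}

static_assert(lshrFoldsIntoOperands(0xDEADBEEFull, 0xCAFEull),
              "the or can be computed shifted right by 32");

int main() { return 0; }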
-    return CanEvaluateShifted(I->getOperand(0), NumBits, IsLeftShift, IC, I) && -           CanEvaluateShifted(I->getOperand(1), NumBits, IsLeftShift, IC, I); +    return canEvaluateShifted(I->getOperand(0), NumBits, IsLeftShift, IC, I) && +           canEvaluateShifted(I->getOperand(1), NumBits, IsLeftShift, IC, I);    case Instruction::Shl:    case Instruction::LShr: @@ -176,8 +173,8 @@ static bool CanEvaluateShifted(Value *V, unsigned NumBits, bool IsLeftShift,      SelectInst *SI = cast<SelectInst>(I);      Value *TrueVal = SI->getTrueValue();      Value *FalseVal = SI->getFalseValue(); -    return CanEvaluateShifted(TrueVal, NumBits, IsLeftShift, IC, SI) && -           CanEvaluateShifted(FalseVal, NumBits, IsLeftShift, IC, SI); +    return canEvaluateShifted(TrueVal, NumBits, IsLeftShift, IC, SI) && +           canEvaluateShifted(FalseVal, NumBits, IsLeftShift, IC, SI);    }    case Instruction::PHI: {      // We can change a phi if we can change all operands.  Note that we never @@ -185,16 +182,79 @@ static bool CanEvaluateShifted(Value *V, unsigned NumBits, bool IsLeftShift,      // instructions with a single use.      PHINode *PN = cast<PHINode>(I);      for (Value *IncValue : PN->incoming_values()) -      if (!CanEvaluateShifted(IncValue, NumBits, IsLeftShift, IC, PN)) +      if (!canEvaluateShifted(IncValue, NumBits, IsLeftShift, IC, PN))          return false;      return true;    }    }  } -/// When CanEvaluateShifted returned true for an expression, -/// this value inserts the new computation that produces the shifted value. -static Value *GetShiftedValue(Value *V, unsigned NumBits, bool isLeftShift, +/// Fold OuterShift (InnerShift X, C1), C2. +/// See canEvaluateShiftedShift() for the constraints on these instructions. +static Value *foldShiftedShift(BinaryOperator *InnerShift, unsigned OuterShAmt, +                               bool IsOuterShl, +                               InstCombiner::BuilderTy &Builder) { +  bool IsInnerShl = InnerShift->getOpcode() == Instruction::Shl; +  Type *ShType = InnerShift->getType(); +  unsigned TypeWidth = ShType->getScalarSizeInBits(); + +  // We only accept shifts-by-a-constant in canEvaluateShifted(). +  const APInt *C1; +  match(InnerShift->getOperand(1), m_APInt(C1)); +  unsigned InnerShAmt = C1->getZExtValue(); + +  // Change the shift amount and clear the appropriate IR flags. +  auto NewInnerShift = [&](unsigned ShAmt) { +    InnerShift->setOperand(1, ConstantInt::get(ShType, ShAmt)); +    if (IsInnerShl) { +      InnerShift->setHasNoUnsignedWrap(false); +      InnerShift->setHasNoSignedWrap(false); +    } else { +      InnerShift->setIsExact(false); +    } +    return InnerShift; +  }; + +  // Two logical shifts in the same direction: +  // shl (shl X, C1), C2 -->  shl X, C1 + C2 +  // lshr (lshr X, C1), C2 --> lshr X, C1 + C2 +  if (IsInnerShl == IsOuterShl) { +    // If this is an oversized composite shift, then unsigned shifts get 0. +    if (InnerShAmt + OuterShAmt >= TypeWidth) +      return Constant::getNullValue(ShType); + +    return NewInnerShift(InnerShAmt + OuterShAmt); +  } + +  // Equal shift amounts in opposite directions become bitwise 'and': +  // lshr (shl X, C), C --> and X, C' +  // shl (lshr X, C), C --> and X, C' +  if (InnerShAmt == OuterShAmt) { +    APInt Mask = IsInnerShl +                     ? 
APInt::getLowBitsSet(TypeWidth, TypeWidth - OuterShAmt) +                     : APInt::getHighBitsSet(TypeWidth, TypeWidth - OuterShAmt); +    Value *And = Builder.CreateAnd(InnerShift->getOperand(0), +                                   ConstantInt::get(ShType, Mask)); +    if (auto *AndI = dyn_cast<Instruction>(And)) { +      AndI->moveBefore(InnerShift); +      AndI->takeName(InnerShift); +    } +    return And; +  } + +  assert(InnerShAmt > OuterShAmt && +         "Unexpected opposite direction logical shift pair"); + +  // In general, we would need an 'and' for this transform, but +  // canEvaluateShiftedShift() guarantees that the masked-off bits are not used. +  // lshr (shl X, C1), C2 -->  shl X, C1 - C2 +  // shl (lshr X, C1), C2 --> lshr X, C1 - C2 +  return NewInnerShift(InnerShAmt - OuterShAmt); +} + +/// When canEvaluateShifted() returns true for an expression, this function +/// inserts the new computation that produces the shifted value. +static Value *getShiftedValue(Value *V, unsigned NumBits, bool isLeftShift,                                InstCombiner &IC, const DataLayout &DL) {    // We can always evaluate constants shifted.    if (Constant *C = dyn_cast<Constant>(V)) { @@ -220,100 +280,21 @@ static Value *GetShiftedValue(Value *V, unsigned NumBits, bool isLeftShift,    case Instruction::Xor:      // Bitwise operators can all arbitrarily be arbitrarily evaluated shifted.      I->setOperand( -        0, GetShiftedValue(I->getOperand(0), NumBits, isLeftShift, IC, DL)); +        0, getShiftedValue(I->getOperand(0), NumBits, isLeftShift, IC, DL));      I->setOperand( -        1, GetShiftedValue(I->getOperand(1), NumBits, isLeftShift, IC, DL)); +        1, getShiftedValue(I->getOperand(1), NumBits, isLeftShift, IC, DL));      return I; -  case Instruction::Shl: { -    BinaryOperator *BO = cast<BinaryOperator>(I); -    unsigned TypeWidth = BO->getType()->getScalarSizeInBits(); - -    // We only accept shifts-by-a-constant in CanEvaluateShifted. -    ConstantInt *CI = cast<ConstantInt>(BO->getOperand(1)); - -    // We can always fold shl(c1)+shl(c2) -> shl(c1+c2). -    if (isLeftShift) { -      // If this is oversized composite shift, then unsigned shifts get 0. -      unsigned NewShAmt = NumBits+CI->getZExtValue(); -      if (NewShAmt >= TypeWidth) -        return Constant::getNullValue(I->getType()); - -      BO->setOperand(1, ConstantInt::get(BO->getType(), NewShAmt)); -      BO->setHasNoUnsignedWrap(false); -      BO->setHasNoSignedWrap(false); -      return I; -    } - -    // We turn shl(c)+lshr(c) -> and(c2) if the input doesn't already have -    // zeros. -    if (CI->getValue() == NumBits) { -      APInt Mask(APInt::getLowBitsSet(TypeWidth, TypeWidth - NumBits)); -      V = IC.Builder->CreateAnd(BO->getOperand(0), -                                ConstantInt::get(BO->getContext(), Mask)); -      if (Instruction *VI = dyn_cast<Instruction>(V)) { -        VI->moveBefore(BO); -        VI->takeName(BO); -      } -      return V; -    } - -    // We turn shl(c1)+shr(c2) -> shl(c3)+and(c4), but only when we know that -    // the and won't be needed. -    assert(CI->getZExtValue() > NumBits); -    BO->setOperand(1, ConstantInt::get(BO->getType(), -                                       CI->getZExtValue() - NumBits)); -    BO->setHasNoUnsignedWrap(false); -    BO->setHasNoSignedWrap(false); -    return BO; -  } -  // FIXME: This is almost identical to the SHL case. Refactor both cases into -  // a helper function. 
-  case Instruction::LShr: { -    BinaryOperator *BO = cast<BinaryOperator>(I); -    unsigned TypeWidth = BO->getType()->getScalarSizeInBits(); -    // We only accept shifts-by-a-constant in CanEvaluateShifted. -    ConstantInt *CI = cast<ConstantInt>(BO->getOperand(1)); - -    // We can always fold lshr(c1)+lshr(c2) -> lshr(c1+c2). -    if (!isLeftShift) { -      // If this is oversized composite shift, then unsigned shifts get 0. -      unsigned NewShAmt = NumBits+CI->getZExtValue(); -      if (NewShAmt >= TypeWidth) -        return Constant::getNullValue(BO->getType()); - -      BO->setOperand(1, ConstantInt::get(BO->getType(), NewShAmt)); -      BO->setIsExact(false); -      return I; -    } - -    // We turn lshr(c)+shl(c) -> and(c2) if the input doesn't already have -    // zeros. -    if (CI->getValue() == NumBits) { -      APInt Mask(APInt::getHighBitsSet(TypeWidth, TypeWidth - NumBits)); -      V = IC.Builder->CreateAnd(I->getOperand(0), -                                ConstantInt::get(BO->getContext(), Mask)); -      if (Instruction *VI = dyn_cast<Instruction>(V)) { -        VI->moveBefore(I); -        VI->takeName(I); -      } -      return V; -    } - -    // We turn lshr(c1)+shl(c2) -> lshr(c3)+and(c4), but only when we know that -    // the and won't be needed. -    assert(CI->getZExtValue() > NumBits); -    BO->setOperand(1, ConstantInt::get(BO->getType(), -                                       CI->getZExtValue() - NumBits)); -    BO->setIsExact(false); -    return BO; -  } +  case Instruction::Shl: +  case Instruction::LShr: +    return foldShiftedShift(cast<BinaryOperator>(I), NumBits, isLeftShift, +                            *(IC.Builder));    case Instruction::Select:      I->setOperand( -        1, GetShiftedValue(I->getOperand(1), NumBits, isLeftShift, IC, DL)); +        1, getShiftedValue(I->getOperand(1), NumBits, isLeftShift, IC, DL));      I->setOperand( -        2, GetShiftedValue(I->getOperand(2), NumBits, isLeftShift, IC, DL)); +        2, getShiftedValue(I->getOperand(2), NumBits, isLeftShift, IC, DL));      return I;    case Instruction::PHI: {      // We can change a phi if we can change all operands.  Note that we never @@ -321,215 +302,39 @@ static Value *GetShiftedValue(Value *V, unsigned NumBits, bool isLeftShift,      // instructions with a single use.      PHINode *PN = cast<PHINode>(I);      for (unsigned i = 0, e = PN->getNumIncomingValues(); i != e; ++i) -      PN->setIncomingValue(i, GetShiftedValue(PN->getIncomingValue(i), NumBits, +      PN->setIncomingValue(i, getShiftedValue(PN->getIncomingValue(i), NumBits,                                                isLeftShift, IC, DL));      return PN;    }    }  } -/// Try to fold (X << C1) << C2, where the shifts are some combination of -/// shl/ashr/lshr. -static Instruction * -foldShiftByConstOfShiftByConst(BinaryOperator &I, ConstantInt *COp1, -                               InstCombiner::BuilderTy *Builder) { -  Value *Op0 = I.getOperand(0); -  uint32_t TypeBits = Op0->getType()->getScalarSizeInBits(); - -  // Find out if this is a shift of a shift by a constant. -  BinaryOperator *ShiftOp = dyn_cast<BinaryOperator>(Op0); -  if (ShiftOp && !ShiftOp->isShift()) -    ShiftOp = nullptr; - -  if (ShiftOp && isa<ConstantInt>(ShiftOp->getOperand(1))) { - -    // This is a constant shift of a constant shift. Be careful about hiding -    // shl instructions behind bit masks. 
They are used to represent multiplies -    // by a constant, and it is important that simple arithmetic expressions -    // are still recognizable by scalar evolution. -    // -    // The transforms applied to shl are very similar to the transforms applied -    // to mul by constant. We can be more aggressive about optimizing right -    // shifts. -    // -    // Combinations of right and left shifts will still be optimized in -    // DAGCombine where scalar evolution no longer applies. - -    ConstantInt *ShiftAmt1C = cast<ConstantInt>(ShiftOp->getOperand(1)); -    uint32_t ShiftAmt1 = ShiftAmt1C->getLimitedValue(TypeBits); -    uint32_t ShiftAmt2 = COp1->getLimitedValue(TypeBits); -    assert(ShiftAmt2 != 0 && "Should have been simplified earlier"); -    if (ShiftAmt1 == 0) -      return nullptr; // Will be simplified in the future. -    Value *X = ShiftOp->getOperand(0); - -    IntegerType *Ty = cast<IntegerType>(I.getType()); - -    // Check for (X << c1) << c2  and  (X >> c1) >> c2 -    if (I.getOpcode() == ShiftOp->getOpcode()) { -      uint32_t AmtSum = ShiftAmt1 + ShiftAmt2; // Fold into one big shift. -      // If this is an oversized composite shift, then unsigned shifts become -      // zero (handled in InstSimplify) and ashr saturates. -      if (AmtSum >= TypeBits) { -        if (I.getOpcode() != Instruction::AShr) -          return nullptr; -        AmtSum = TypeBits - 1; // Saturate to 31 for i32 ashr. -      } - -      return BinaryOperator::Create(I.getOpcode(), X, -                                    ConstantInt::get(Ty, AmtSum)); -    } - -    if (ShiftAmt1 == ShiftAmt2) { -      // If we have ((X << C) >>u C), turn this into X & (-1 >>u C). -      if (I.getOpcode() == Instruction::LShr && -          ShiftOp->getOpcode() == Instruction::Shl) { -        APInt Mask(APInt::getLowBitsSet(TypeBits, TypeBits - ShiftAmt1)); -        return BinaryOperator::CreateAnd( -            X, ConstantInt::get(I.getContext(), Mask)); -      } -    } else if (ShiftAmt1 < ShiftAmt2) { -      uint32_t ShiftDiff = ShiftAmt2 - ShiftAmt1; - -      // (X >>?,exact C1) << C2 --> X << (C2-C1) -      // The inexact version is deferred to DAGCombine so we don't hide shl -      // behind a bit mask. 
-      if (I.getOpcode() == Instruction::Shl && -          ShiftOp->getOpcode() != Instruction::Shl && ShiftOp->isExact()) { -        assert(ShiftOp->getOpcode() == Instruction::LShr || -               ShiftOp->getOpcode() == Instruction::AShr); -        ConstantInt *ShiftDiffCst = ConstantInt::get(Ty, ShiftDiff); -        BinaryOperator *NewShl = -            BinaryOperator::Create(Instruction::Shl, X, ShiftDiffCst); -        NewShl->setHasNoUnsignedWrap(I.hasNoUnsignedWrap()); -        NewShl->setHasNoSignedWrap(I.hasNoSignedWrap()); -        return NewShl; -      } - -      // (X << C1) >>u C2  --> X >>u (C2-C1) & (-1 >> C2) -      if (I.getOpcode() == Instruction::LShr && -          ShiftOp->getOpcode() == Instruction::Shl) { -        ConstantInt *ShiftDiffCst = ConstantInt::get(Ty, ShiftDiff); -        // (X <<nuw C1) >>u C2 --> X >>u (C2-C1) -        if (ShiftOp->hasNoUnsignedWrap()) { -          BinaryOperator *NewLShr = -              BinaryOperator::Create(Instruction::LShr, X, ShiftDiffCst); -          NewLShr->setIsExact(I.isExact()); -          return NewLShr; -        } -        Value *Shift = Builder->CreateLShr(X, ShiftDiffCst); - -        APInt Mask(APInt::getLowBitsSet(TypeBits, TypeBits - ShiftAmt2)); -        return BinaryOperator::CreateAnd( -            Shift, ConstantInt::get(I.getContext(), Mask)); -      } - -      // We can't handle (X << C1) >>s C2, it shifts arbitrary bits in. However, -      // we can handle (X <<nsw C1) >>s C2 since it only shifts in sign bits. -      if (I.getOpcode() == Instruction::AShr && -          ShiftOp->getOpcode() == Instruction::Shl) { -        if (ShiftOp->hasNoSignedWrap()) { -          // (X <<nsw C1) >>s C2 --> X >>s (C2-C1) -          ConstantInt *ShiftDiffCst = ConstantInt::get(Ty, ShiftDiff); -          BinaryOperator *NewAShr = -              BinaryOperator::Create(Instruction::AShr, X, ShiftDiffCst); -          NewAShr->setIsExact(I.isExact()); -          return NewAShr; -        } -      } -    } else { -      assert(ShiftAmt2 < ShiftAmt1); -      uint32_t ShiftDiff = ShiftAmt1 - ShiftAmt2; - -      // (X >>?exact C1) << C2 --> X >>?exact (C1-C2) -      // The inexact version is deferred to DAGCombine so we don't hide shl -      // behind a bit mask. -      if (I.getOpcode() == Instruction::Shl && -          ShiftOp->getOpcode() != Instruction::Shl && ShiftOp->isExact()) { -        ConstantInt *ShiftDiffCst = ConstantInt::get(Ty, ShiftDiff); -        BinaryOperator *NewShr = -            BinaryOperator::Create(ShiftOp->getOpcode(), X, ShiftDiffCst); -        NewShr->setIsExact(true); -        return NewShr; -      } - -      // (X << C1) >>u C2  --> X << (C1-C2) & (-1 >> C2) -      if (I.getOpcode() == Instruction::LShr && -          ShiftOp->getOpcode() == Instruction::Shl) { -        ConstantInt *ShiftDiffCst = ConstantInt::get(Ty, ShiftDiff); -        if (ShiftOp->hasNoUnsignedWrap()) { -          // (X <<nuw C1) >>u C2 --> X <<nuw (C1-C2) -          BinaryOperator *NewShl = -              BinaryOperator::Create(Instruction::Shl, X, ShiftDiffCst); -          NewShl->setHasNoUnsignedWrap(true); -          return NewShl; -        } -        Value *Shift = Builder->CreateShl(X, ShiftDiffCst); - -        APInt Mask(APInt::getLowBitsSet(TypeBits, TypeBits - ShiftAmt2)); -        return BinaryOperator::CreateAnd( -            Shift, ConstantInt::get(I.getContext(), Mask)); -      } - -      // We can't handle (X << C1) >>s C2, it shifts arbitrary bits in. 
However, -      // we can handle (X <<nsw C1) >>s C2 since it only shifts in sign bits. -      if (I.getOpcode() == Instruction::AShr && -          ShiftOp->getOpcode() == Instruction::Shl) { -        if (ShiftOp->hasNoSignedWrap()) { -          // (X <<nsw C1) >>s C2 --> X <<nsw (C1-C2) -          ConstantInt *ShiftDiffCst = ConstantInt::get(Ty, ShiftDiff); -          BinaryOperator *NewShl = -              BinaryOperator::Create(Instruction::Shl, X, ShiftDiffCst); -          NewShl->setHasNoSignedWrap(true); -          return NewShl; -        } -      } -    } -  } - -  return nullptr; -} -  Instruction *InstCombiner::FoldShiftByConstant(Value *Op0, Constant *Op1,                                                 BinaryOperator &I) {    bool isLeftShift = I.getOpcode() == Instruction::Shl; -  ConstantInt *COp1 = nullptr; -  if (ConstantDataVector *CV = dyn_cast<ConstantDataVector>(Op1)) -    COp1 = dyn_cast_or_null<ConstantInt>(CV->getSplatValue()); -  else if (ConstantVector *CV = dyn_cast<ConstantVector>(Op1)) -    COp1 = dyn_cast_or_null<ConstantInt>(CV->getSplatValue()); -  else -    COp1 = dyn_cast<ConstantInt>(Op1); - -  if (!COp1) +  const APInt *Op1C; +  if (!match(Op1, m_APInt(Op1C)))      return nullptr;    // See if we can propagate this shift into the input, this covers the trivial    // cast of lshr(shl(x,c1),c2) as well as other more complex cases.    if (I.getOpcode() != Instruction::AShr && -      CanEvaluateShifted(Op0, COp1->getZExtValue(), isLeftShift, *this, &I)) { +      canEvaluateShifted(Op0, Op1C->getZExtValue(), isLeftShift, *this, &I)) {      DEBUG(dbgs() << "ICE: GetShiftedValue propagating shift through expression"                " to eliminate shift:\n  IN: " << *Op0 << "\n  SH: " << I <<"\n");      return replaceInstUsesWith( -        I, GetShiftedValue(Op0, COp1->getZExtValue(), isLeftShift, *this, DL)); +        I, getShiftedValue(Op0, Op1C->getZExtValue(), isLeftShift, *this, DL));    }    // See if we can simplify any instructions used by the instruction whose sole    // purpose is to compute bits we don't care about. -  uint32_t TypeBits = Op0->getType()->getScalarSizeInBits(); +  unsigned TypeBits = Op0->getType()->getScalarSizeInBits(); -  assert(!COp1->uge(TypeBits) && +  assert(!Op1C->uge(TypeBits) &&           "Shift over the type width should have been removed already"); -  // ((X*C1) << C2) == (X * (C1 << C2)) -  if (BinaryOperator *BO = dyn_cast<BinaryOperator>(Op0)) -    if (BO->getOpcode() == Instruction::Mul && isLeftShift) -      if (Constant *BOOp = dyn_cast<Constant>(BO->getOperand(1))) -        return BinaryOperator::CreateMul(BO->getOperand(0), -                                         ConstantExpr::getShl(BOOp, Op1)); -    if (Instruction *FoldedShift = foldOpWithConstantIntoOperand(I))      return FoldedShift; @@ -544,7 +349,8 @@ Instruction *InstCombiner::FoldShiftByConstant(Value *Op0, Constant *Op1,      if (TrOp && I.isLogicalShift() && TrOp->isShift() &&          isa<ConstantInt>(TrOp->getOperand(1))) {        // Okay, we'll do this xform.  Make the shift of shift. -      Constant *ShAmt = ConstantExpr::getZExt(COp1, TrOp->getType()); +      Constant *ShAmt = +          ConstantExpr::getZExt(cast<Constant>(Op1), TrOp->getType());        // (shift2 (shift1 & 0x00FF), c2)        Value *NSh = Builder->CreateBinOp(I.getOpcode(), TrOp, ShAmt,I.getName()); @@ -561,10 +367,10 @@ Instruction *InstCombiner::FoldShiftByConstant(Value *Op0, Constant *Op1,        // shift.  
We know that it is a logical shift by a constant, so adjust the        // mask as appropriate.        if (I.getOpcode() == Instruction::Shl) -        MaskV <<= COp1->getZExtValue(); +        MaskV <<= Op1C->getZExtValue();        else {          assert(I.getOpcode() == Instruction::LShr && "Unknown logical shift"); -        MaskV = MaskV.lshr(COp1->getZExtValue()); +        MaskV = MaskV.lshr(Op1C->getZExtValue());        }        // shift1 & 0x00FF @@ -598,7 +404,7 @@ Instruction *InstCombiner::FoldShiftByConstant(Value *Op0, Constant *Op1,            // (X + (Y << C))            Value *X = Builder->CreateBinOp(Op0BO->getOpcode(), YS, V1,                                            Op0BO->getOperand(1)->getName()); -          uint32_t Op1Val = COp1->getLimitedValue(TypeBits); +          unsigned Op1Val = Op1C->getLimitedValue(TypeBits);            APInt Bits = APInt::getHighBitsSet(TypeBits, TypeBits - Op1Val);            Constant *Mask = ConstantInt::get(I.getContext(), Bits); @@ -634,7 +440,7 @@ Instruction *InstCombiner::FoldShiftByConstant(Value *Op0, Constant *Op1,            // (X + (Y << C))            Value *X = Builder->CreateBinOp(Op0BO->getOpcode(), V1, YS,                                            Op0BO->getOperand(0)->getName()); -          uint32_t Op1Val = COp1->getLimitedValue(TypeBits); +          unsigned Op1Val = Op1C->getLimitedValue(TypeBits);            APInt Bits = APInt::getHighBitsSet(TypeBits, TypeBits - Op1Val);            Constant *Mask = ConstantInt::get(I.getContext(), Bits); @@ -705,9 +511,6 @@ Instruction *InstCombiner::FoldShiftByConstant(Value *Op0, Constant *Op1,      }    } -  if (Instruction *Folded = foldShiftByConstOfShiftByConst(I, COp1, Builder)) -    return Folded; -    return nullptr;  } @@ -715,59 +518,97 @@ Instruction *InstCombiner::visitShl(BinaryOperator &I) {    if (Value *V = SimplifyVectorOp(I))      return replaceInstUsesWith(I, V); -  if (Value *V = -          SimplifyShlInst(I.getOperand(0), I.getOperand(1), I.hasNoSignedWrap(), -                          I.hasNoUnsignedWrap(), DL, &TLI, &DT, &AC)) +  Value *Op0 = I.getOperand(0), *Op1 = I.getOperand(1); +  if (Value *V = SimplifyShlInst(Op0, Op1, I.hasNoSignedWrap(), +                                 I.hasNoUnsignedWrap(), DL, &TLI, &DT, &AC))      return replaceInstUsesWith(I, V);    if (Instruction *V = commonShiftTransforms(I))      return V; -  if (ConstantInt *Op1C = dyn_cast<ConstantInt>(I.getOperand(1))) { -    unsigned ShAmt = Op1C->getZExtValue(); - -    // Turn: -    //  %zext = zext i32 %V to i64 -    //  %res = shl i64 %V, 8 -    // -    // Into: -    //  %shl = shl i32 %V, 8 -    //  %res = zext i32 %shl to i64 -    // -    // This is only valid if %V would have zeros shifted out. -    if (auto *ZI = dyn_cast<ZExtInst>(I.getOperand(0))) { -      unsigned SrcBitWidth = ZI->getSrcTy()->getScalarSizeInBits(); -      if (ShAmt < SrcBitWidth && -          MaskedValueIsZero(ZI->getOperand(0), -                            APInt::getHighBitsSet(SrcBitWidth, ShAmt), 0, &I)) { -        auto *Shl = Builder->CreateShl(ZI->getOperand(0), ShAmt); -        return new ZExtInst(Shl, I.getType()); +  const APInt *ShAmtAPInt; +  if (match(Op1, m_APInt(ShAmtAPInt))) { +    unsigned ShAmt = ShAmtAPInt->getZExtValue(); +    unsigned BitWidth = I.getType()->getScalarSizeInBits(); +    Type *Ty = I.getType(); + +    // shl (zext X), ShAmt --> zext (shl X, ShAmt) +    // This is only valid if X would have zeros shifted out. 
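// [Editorial note: illustrative sketch, not part of this patch.] The fold
// stated in the comment above, written out for i8 -> i32: when the top ShAmt
// bits of X are already zero, shifting before or after the zero-extension
// gives the same value. The helper name is hypothetical.
#include <cstdint>

static constexpr bool zextShlCommutes(uint8_t x, unsigned shAmt) {
  return (static_cast<uint32_t>(x) << shAmt) ==
         static_cast<uint32_t>(static_cast<uint8_t>(x << shAmt));
}

static_assert(zextShlCommutes(0x1F, 3), "0x1F has its top 3 bits clear");
static_assert(zextShlCommutes(0x01, 7), "0x01 has its top 7 bits clear");

int main() { return 0; }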
+    Value *X; +    if (match(Op0, m_ZExt(m_Value(X)))) { +      unsigned SrcWidth = X->getType()->getScalarSizeInBits(); +      if (ShAmt < SrcWidth && +          MaskedValueIsZero(X, APInt::getHighBitsSet(SrcWidth, ShAmt), 0, &I)) +        return new ZExtInst(Builder->CreateShl(X, ShAmt), Ty); +    } + +    // (X >>u C) << C --> X & (-1 << C) +    if (match(Op0, m_LShr(m_Value(X), m_Specific(Op1)))) { +      APInt Mask(APInt::getHighBitsSet(BitWidth, BitWidth - ShAmt)); +      return BinaryOperator::CreateAnd(X, ConstantInt::get(Ty, Mask)); +    } + +    // Be careful about hiding shl instructions behind bit masks. They are used +    // to represent multiplies by a constant, and it is important that simple +    // arithmetic expressions are still recognizable by scalar evolution. +    // The inexact versions are deferred to DAGCombine, so we don't hide shl +    // behind a bit mask. +    const APInt *ShOp1; +    if (match(Op0, m_CombineOr(m_Exact(m_LShr(m_Value(X), m_APInt(ShOp1))), +                               m_Exact(m_AShr(m_Value(X), m_APInt(ShOp1)))))) { +      unsigned ShrAmt = ShOp1->getZExtValue(); +      if (ShrAmt < ShAmt) { +        // If C1 < C2: (X >>?,exact C1) << C2 --> X << (C2 - C1) +        Constant *ShiftDiff = ConstantInt::get(Ty, ShAmt - ShrAmt); +        auto *NewShl = BinaryOperator::CreateShl(X, ShiftDiff); +        NewShl->setHasNoUnsignedWrap(I.hasNoUnsignedWrap()); +        NewShl->setHasNoSignedWrap(I.hasNoSignedWrap()); +        return NewShl;        } +      if (ShrAmt > ShAmt) { +        // If C1 > C2: (X >>?exact C1) << C2 --> X >>?exact (C1 - C2) +        Constant *ShiftDiff = ConstantInt::get(Ty, ShrAmt - ShAmt); +        auto *NewShr = BinaryOperator::Create( +            cast<BinaryOperator>(Op0)->getOpcode(), X, ShiftDiff); +        NewShr->setIsExact(true); +        return NewShr; +      } +    } + +    if (match(Op0, m_Shl(m_Value(X), m_APInt(ShOp1)))) { +      unsigned AmtSum = ShAmt + ShOp1->getZExtValue(); +      // Oversized shifts are simplified to zero in InstSimplify. +      if (AmtSum < BitWidth) +        // (X << C1) << C2 --> X << (C1 + C2) +        return BinaryOperator::CreateShl(X, ConstantInt::get(Ty, AmtSum));      }      // If the shifted-out value is known-zero, then this is a NUW shift.      if (!I.hasNoUnsignedWrap() && -        MaskedValueIsZero(I.getOperand(0), -                          APInt::getHighBitsSet(Op1C->getBitWidth(), ShAmt), 0, -                          &I)) { +        MaskedValueIsZero(Op0, APInt::getHighBitsSet(BitWidth, ShAmt), 0, &I)) {        I.setHasNoUnsignedWrap();        return &I;      } -    // If the shifted out value is all signbits, this is a NSW shift. -    if (!I.hasNoSignedWrap() && -        ComputeNumSignBits(I.getOperand(0), 0, &I) > ShAmt) { +    // If the shifted-out value is all signbits, then this is a NSW shift. 
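// [Editorial note: illustrative sketch, not part of this patch.] The check
// that follows uses ComputeNumSignBits(): if the value has more than ShAmt
// leading sign bits, shifting left by ShAmt cannot wrap, i.e. an arithmetic
// shift right recovers the original value. A 32-bit sanity check; the helper
// name is hypothetical, and two's-complement narrowing plus arithmetic >> of
// negative values are assumed (both guaranteed since C++20, universal in
// practice).
#include <cassert>
#include <cstdint>

static bool shlIsNoSignedWrap(int32_t x, unsigned shAmt) {
  int32_t shifted = static_cast<int32_t>(static_cast<uint32_t>(x) << shAmt);
  return (shifted >> shAmt) == x;
}

int main() {
  assert(shlIsNoSignedWrap(-5, 28));  // -5 has 29 leading sign bits
  assert(!shlIsNoSignedWrap(-5, 29));
  assert(shlIsNoSignedWrap(3, 29));   // 3 has 30 leading (sign) zero bits
  assert(!shlIsNoSignedWrap(3, 30));
  return 0;
}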
+    if (!I.hasNoSignedWrap() && ComputeNumSignBits(Op0, 0, &I) > ShAmt) {        I.setHasNoSignedWrap();        return &I;      }    } -  // (C1 << A) << C2 -> (C1 << C2) << A -  Constant *C1, *C2; -  Value *A; -  if (match(I.getOperand(0), m_OneUse(m_Shl(m_Constant(C1), m_Value(A)))) && -      match(I.getOperand(1), m_Constant(C2))) -    return BinaryOperator::CreateShl(ConstantExpr::getShl(C1, C2), A); +  Constant *C1; +  if (match(Op1, m_Constant(C1))) { +    Constant *C2; +    Value *X; +    // (C2 << X) << C1 --> (C2 << C1) << X +    if (match(Op0, m_OneUse(m_Shl(m_Constant(C2), m_Value(X))))) +      return BinaryOperator::CreateShl(ConstantExpr::getShl(C2, C1), X); + +    // (X * C2) << C1 --> X * (C2 << C1) +    if (match(Op0, m_Mul(m_Value(X), m_Constant(C2)))) +      return BinaryOperator::CreateMul(X, ConstantExpr::getShl(C2, C1)); +  }    return nullptr;  } @@ -776,43 +617,83 @@ Instruction *InstCombiner::visitLShr(BinaryOperator &I) {    if (Value *V = SimplifyVectorOp(I))      return replaceInstUsesWith(I, V); -  if (Value *V = SimplifyLShrInst(I.getOperand(0), I.getOperand(1), I.isExact(), -                                  DL, &TLI, &DT, &AC)) +  Value *Op0 = I.getOperand(0), *Op1 = I.getOperand(1); +  if (Value *V = SimplifyLShrInst(Op0, Op1, I.isExact(), DL, &TLI, &DT, &AC))      return replaceInstUsesWith(I, V);    if (Instruction *R = commonShiftTransforms(I))      return R; -  Value *Op0 = I.getOperand(0), *Op1 = I.getOperand(1); - -  if (ConstantInt *Op1C = dyn_cast<ConstantInt>(Op1)) { -    unsigned ShAmt = Op1C->getZExtValue(); - -    if (IntrinsicInst *II = dyn_cast<IntrinsicInst>(Op0)) { -      unsigned BitWidth = Op0->getType()->getScalarSizeInBits(); +  Type *Ty = I.getType(); +  const APInt *ShAmtAPInt; +  if (match(Op1, m_APInt(ShAmtAPInt))) { +    unsigned ShAmt = ShAmtAPInt->getZExtValue(); +    unsigned BitWidth = Ty->getScalarSizeInBits(); +    auto *II = dyn_cast<IntrinsicInst>(Op0); +    if (II && isPowerOf2_32(BitWidth) && Log2_32(BitWidth) == ShAmt && +        (II->getIntrinsicID() == Intrinsic::ctlz || +         II->getIntrinsicID() == Intrinsic::cttz || +         II->getIntrinsicID() == Intrinsic::ctpop)) {        // ctlz.i32(x)>>5  --> zext(x == 0)        // cttz.i32(x)>>5  --> zext(x == 0)        // ctpop.i32(x)>>5 --> zext(x == -1) -      if ((II->getIntrinsicID() == Intrinsic::ctlz || -           II->getIntrinsicID() == Intrinsic::cttz || -           II->getIntrinsicID() == Intrinsic::ctpop) && -          isPowerOf2_32(BitWidth) && Log2_32(BitWidth) == ShAmt) { -        bool isCtPop = II->getIntrinsicID() == Intrinsic::ctpop; -        Constant *RHS = ConstantInt::getSigned(Op0->getType(), isCtPop ? -1:0); -        Value *Cmp = Builder->CreateICmpEQ(II->getArgOperand(0), RHS); -        return new ZExtInst(Cmp, II->getType()); +      bool IsPop = II->getIntrinsicID() == Intrinsic::ctpop; +      Constant *RHS = ConstantInt::getSigned(Ty, IsPop ? 
-1 : 0); +      Value *Cmp = Builder->CreateICmpEQ(II->getArgOperand(0), RHS); +      return new ZExtInst(Cmp, Ty); +    } + +    Value *X; +    const APInt *ShOp1; +    if (match(Op0, m_Shl(m_Value(X), m_APInt(ShOp1)))) { +      unsigned ShlAmt = ShOp1->getZExtValue(); +      if (ShlAmt < ShAmt) { +        Constant *ShiftDiff = ConstantInt::get(Ty, ShAmt - ShlAmt); +        if (cast<BinaryOperator>(Op0)->hasNoUnsignedWrap()) { +          // (X <<nuw C1) >>u C2 --> X >>u (C2 - C1) +          auto *NewLShr = BinaryOperator::CreateLShr(X, ShiftDiff); +          NewLShr->setIsExact(I.isExact()); +          return NewLShr; +        } +        // (X << C1) >>u C2  --> (X >>u (C2 - C1)) & (-1 >> C2) +        Value *NewLShr = Builder->CreateLShr(X, ShiftDiff, "", I.isExact()); +        APInt Mask(APInt::getLowBitsSet(BitWidth, BitWidth - ShAmt)); +        return BinaryOperator::CreateAnd(NewLShr, ConstantInt::get(Ty, Mask));        } +      if (ShlAmt > ShAmt) { +        Constant *ShiftDiff = ConstantInt::get(Ty, ShlAmt - ShAmt); +        if (cast<BinaryOperator>(Op0)->hasNoUnsignedWrap()) { +          // (X <<nuw C1) >>u C2 --> X <<nuw (C1 - C2) +          auto *NewShl = BinaryOperator::CreateShl(X, ShiftDiff); +          NewShl->setHasNoUnsignedWrap(true); +          return NewShl; +        } +        // (X << C1) >>u C2  --> X << (C1 - C2) & (-1 >> C2) +        Value *NewShl = Builder->CreateShl(X, ShiftDiff); +        APInt Mask(APInt::getLowBitsSet(BitWidth, BitWidth - ShAmt)); +        return BinaryOperator::CreateAnd(NewShl, ConstantInt::get(Ty, Mask)); +      } +      assert(ShlAmt == ShAmt); +      // (X << C) >>u C --> X & (-1 >>u C) +      APInt Mask(APInt::getLowBitsSet(BitWidth, BitWidth - ShAmt)); +      return BinaryOperator::CreateAnd(X, ConstantInt::get(Ty, Mask)); +    } + +    if (match(Op0, m_LShr(m_Value(X), m_APInt(ShOp1)))) { +      unsigned AmtSum = ShAmt + ShOp1->getZExtValue(); +      // Oversized shifts are simplified to zero in InstSimplify. +      if (AmtSum < BitWidth) +        // (X >>u C1) >>u C2 --> X >>u (C1 + C2) +        return BinaryOperator::CreateLShr(X, ConstantInt::get(Ty, AmtSum));      }      // If the shifted-out value is known-zero, then this is an exact shift.      if (!I.isExact() && -        MaskedValueIsZero(Op0, APInt::getLowBitsSet(Op1C->getBitWidth(), ShAmt), -                          0, &I)){ +        MaskedValueIsZero(Op0, APInt::getLowBitsSet(BitWidth, ShAmt), 0, &I)) {        I.setIsExact();        return &I;      }    } -    return nullptr;  } @@ -820,48 +701,66 @@ Instruction *InstCombiner::visitAShr(BinaryOperator &I) {    if (Value *V = SimplifyVectorOp(I))      return replaceInstUsesWith(I, V); -  if (Value *V = SimplifyAShrInst(I.getOperand(0), I.getOperand(1), I.isExact(), -                                  DL, &TLI, &DT, &AC)) +  Value *Op0 = I.getOperand(0), *Op1 = I.getOperand(1); +  if (Value *V = SimplifyAShrInst(Op0, Op1, I.isExact(), DL, &TLI, &DT, &AC))      return replaceInstUsesWith(I, V);    if (Instruction *R = commonShiftTransforms(I))      return R; -  Value *Op0 = I.getOperand(0), *Op1 = I.getOperand(1); - -  if (ConstantInt *Op1C = dyn_cast<ConstantInt>(Op1)) { -    unsigned ShAmt = Op1C->getZExtValue(); +  Type *Ty = I.getType(); +  unsigned BitWidth = Ty->getScalarSizeInBits(); +  const APInt *ShAmtAPInt; +  if (match(Op1, m_APInt(ShAmtAPInt))) { +    unsigned ShAmt = ShAmtAPInt->getZExtValue(); -    // If the input is a SHL by the same constant (ashr (shl X, C), C), then we -    // have a sign-extend idiom. 
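// [Editorial note: illustrative sketch, not part of this patch.] The
// sign-extend idiom named just above (and restated by the new comment below),
// written out for i8 -> i32: zero-extend, shl by 24, then ashr by 24 is a
// plain sext. The helper name is hypothetical; two's-complement narrowing and
// arithmetic >> of negative values are assumed (guaranteed since C++20).
#include <cassert>
#include <cstdint>

static int32_t sextViaShifts(uint8_t x) {
  int32_t widened = static_cast<int32_t>(static_cast<uint32_t>(x) << 24);
  return widened >> 24; // shift amount = DestBits - SrcBits = 32 - 8
}

int main() {
  for (int v = 0; v <= 255; ++v) {
    uint8_t x = static_cast<uint8_t>(v);
    assert(sextViaShifts(x) == static_cast<int32_t>(static_cast<int8_t>(x)));
  }
  return 0;
}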
+    // If the shift amount equals the difference in width of the destination +    // and source scalar types: +    // ashr (shl (zext X), C), C --> sext X      Value *X; -    if (match(Op0, m_Shl(m_Value(X), m_Specific(Op1)))) { -      // If the input is an extension from the shifted amount value, e.g. -      //   %x = zext i8 %A to i32 -      //   %y = shl i32 %x, 24 -      //   %z = ashr %y, 24 -      // then turn this into "z = sext i8 A to i32". -      if (ZExtInst *ZI = dyn_cast<ZExtInst>(X)) { -        uint32_t SrcBits = ZI->getOperand(0)->getType()->getScalarSizeInBits(); -        uint32_t DestBits = ZI->getType()->getScalarSizeInBits(); -        if (Op1C->getZExtValue() == DestBits-SrcBits) -          return new SExtInst(ZI->getOperand(0), ZI->getType()); +    if (match(Op0, m_Shl(m_ZExt(m_Value(X)), m_Specific(Op1))) && +        ShAmt == BitWidth - X->getType()->getScalarSizeInBits()) +      return new SExtInst(X, Ty); + +    // We can't handle (X << C1) >>s C2. It shifts arbitrary bits in. However, +    // we can handle (X <<nsw C1) >>s C2 since it only shifts in sign bits. +    const APInt *ShOp1; +    if (match(Op0, m_NSWShl(m_Value(X), m_APInt(ShOp1)))) { +      unsigned ShlAmt = ShOp1->getZExtValue(); +      if (ShlAmt < ShAmt) { +        // (X <<nsw C1) >>s C2 --> X >>s (C2 - C1) +        Constant *ShiftDiff = ConstantInt::get(Ty, ShAmt - ShlAmt); +        auto *NewAShr = BinaryOperator::CreateAShr(X, ShiftDiff); +        NewAShr->setIsExact(I.isExact()); +        return NewAShr;        } +      if (ShlAmt > ShAmt) { +        // (X <<nsw C1) >>s C2 --> X <<nsw (C1 - C2) +        Constant *ShiftDiff = ConstantInt::get(Ty, ShlAmt - ShAmt); +        auto *NewShl = BinaryOperator::Create(Instruction::Shl, X, ShiftDiff); +        NewShl->setHasNoSignedWrap(true); +        return NewShl; +      } +    } + +    if (match(Op0, m_AShr(m_Value(X), m_APInt(ShOp1)))) { +      unsigned AmtSum = ShAmt + ShOp1->getZExtValue(); +      // Oversized arithmetic shifts replicate the sign bit. +      AmtSum = std::min(AmtSum, BitWidth - 1); +      // (X >>s C1) >>s C2 --> X >>s (C1 + C2) +      return BinaryOperator::CreateAShr(X, ConstantInt::get(Ty, AmtSum));      }      // If the shifted-out value is known-zero, then this is an exact shift.      if (!I.isExact() && -        MaskedValueIsZero(Op0, APInt::getLowBitsSet(Op1C->getBitWidth(), ShAmt), -                          0, &I)) { +        MaskedValueIsZero(Op0, APInt::getLowBitsSet(BitWidth, ShAmt), 0, &I)) {        I.setIsExact();        return &I;      }    }    // See if we can turn a signed shr into an unsigned shr. -  if (MaskedValueIsZero(Op0, -                        APInt::getSignBit(I.getType()->getScalarSizeInBits()), -                        0, &I)) +  if (MaskedValueIsZero(Op0, APInt::getSignBit(BitWidth), 0, &I))      return BinaryOperator::CreateLShr(Op0, Op1);    return nullptr; diff --git a/lib/Transforms/InstCombine/InstCombineSimplifyDemanded.cpp b/lib/Transforms/InstCombine/InstCombineSimplifyDemanded.cpp index 8b930bd95dfe..4e6f02058d83 100644 --- a/lib/Transforms/InstCombine/InstCombineSimplifyDemanded.cpp +++ b/lib/Transforms/InstCombine/InstCombineSimplifyDemanded.cpp @@ -30,18 +30,20 @@ static bool ShrinkDemandedConstant(Instruction *I, unsigned OpNo,    assert(I && "No instruction?");    assert(OpNo < I->getNumOperands() && "Operand index too large"); -  // If the operand is not a constant integer, nothing to do. 
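
The visitAShr changes above encode two facts about arithmetic shifts: shifting a zero-extended value left by the width difference and then ashr'ing by the same amount is a sign extension, and stacked ashr amounts add but saturate at BitWidth - 1. A standalone C++ sketch of both, separate from the patch and assuming two's-complement arithmetic right shifts (guaranteed since C++20):

#include <cassert>
#include <cstdint>

// Arithmetic right shift on two's-complement targets.
static int32_t ashr(int32_t V, unsigned Amt) { return V >> Amt; }

int main() {
  // ashr (shl (zext i8 A to i32), 24), 24  -->  sext i8 A to i32
  for (int A : {0, 1, -1, -100, 127}) {
    uint32_t Z = static_cast<uint8_t>(A);            // zext i8 -> i32
    int32_t Shifted = static_cast<int32_t>(Z << 24); // shl by 32 - 8
    assert(ashr(Shifted, 24) == A);                  // sext i8 -> i32
  }
  // (X >>s C1) >>s C2 --> X >>s min(C1 + C2, BitWidth - 1): once the sign
  // bit has been replicated everywhere, further shifting changes nothing.
  int32_t X = INT32_MIN | 0x1234;
  assert(ashr(ashr(X, 20), 20) == ashr(X, 31));
  return 0;
}
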
-  ConstantInt *OpC = dyn_cast<ConstantInt>(I->getOperand(OpNo)); -  if (!OpC) return false; +  // The operand must be a constant integer or splat integer. +  Value *Op = I->getOperand(OpNo); +  const APInt *C; +  if (!match(Op, m_APInt(C))) +    return false;    // If there are no bits set that aren't demanded, nothing to do. -  Demanded = Demanded.zextOrTrunc(OpC->getValue().getBitWidth()); -  if ((~Demanded & OpC->getValue()) == 0) +  Demanded = Demanded.zextOrTrunc(C->getBitWidth()); +  if ((~Demanded & *C) == 0)      return false;    // This instruction is producing bits that are not demanded. Shrink the RHS. -  Demanded &= OpC->getValue(); -  I->setOperand(OpNo, ConstantInt::get(OpC->getType(), Demanded)); +  Demanded &= *C; +  I->setOperand(OpNo, ConstantInt::get(Op->getType(), Demanded));    return true;  } @@ -66,12 +68,13 @@ bool InstCombiner::SimplifyDemandedInstructionBits(Instruction &Inst) {  /// This form of SimplifyDemandedBits simplifies the specified instruction  /// operand if possible, updating it in place. It returns true if it made any  /// change and false otherwise. -bool InstCombiner::SimplifyDemandedBits(Use &U, const APInt &DemandedMask, +bool InstCombiner::SimplifyDemandedBits(Instruction *I, unsigned OpNo, +                                        const APInt &DemandedMask,                                          APInt &KnownZero, APInt &KnownOne,                                          unsigned Depth) { -  auto *UserI = dyn_cast<Instruction>(U.getUser()); +  Use &U = I->getOperandUse(OpNo);    Value *NewVal = SimplifyDemandedUseBits(U.get(), DemandedMask, KnownZero, -                                          KnownOne, Depth, UserI); +                                          KnownOne, Depth, I);    if (!NewVal) return false;    U = NewVal;    return true; @@ -114,9 +117,10 @@ Value *InstCombiner::SimplifyDemandedUseBits(Value *V, APInt DemandedMask,        KnownOne.getBitWidth() == BitWidth &&        "Value *V, DemandedMask, KnownZero and KnownOne "        "must have same BitWidth"); -  if (ConstantInt *CI = dyn_cast<ConstantInt>(V)) { -    // We know all of the bits for a constant! -    KnownOne = CI->getValue() & DemandedMask; +  const APInt *C; +  if (match(V, m_APInt(C))) { +    // We know all of the bits for a scalar constant or a splat vector constant! +    KnownOne = *C & DemandedMask;      KnownZero = ~KnownOne & DemandedMask;      return nullptr;    } @@ -138,9 +142,6 @@ Value *InstCombiner::SimplifyDemandedUseBits(Value *V, APInt DemandedMask,    if (Depth == 6)        // Limit search depth.      return nullptr; -  APInt LHSKnownZero(BitWidth, 0), LHSKnownOne(BitWidth, 0); -  APInt RHSKnownZero(BitWidth, 0), RHSKnownOne(BitWidth, 0); -    Instruction *I = dyn_cast<Instruction>(V);    if (!I) {      computeKnownBits(V, KnownZero, KnownOne, Depth, CxtI); @@ -151,107 +152,43 @@ Value *InstCombiner::SimplifyDemandedUseBits(Value *V, APInt DemandedMask,    // we can't do any simplifications of the operands, because DemandedMask    // only reflects the bits demanded by *one* of the users.    if (Depth != 0 && !I->hasOneUse()) { -    // Despite the fact that we can't simplify this instruction in all User's -    // context, we can at least compute the knownzero/knownone bits, and we can -    // do simplifications that apply to *just* the one user if we know that -    // this instruction has a simpler value in that context. -    if (I->getOpcode() == Instruction::And) { -      // If either the LHS or the RHS are Zero, the result is zero. 
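
As a side note on ShrinkDemandedConstant above: the legality of dropping undemanded bits from a constant operand is easy to check in isolation. A small standalone C++ sketch (not LLVM code) using an illustrative 'and' with an 8-bit demanded mask:

#include <cassert>
#include <cstdint>

int main() {
  const uint32_t Demanded = 0x000000FF; // the user only looks at the low byte
  const uint32_t C = 0x00000FFF;        // constant operand with undemanded bits set
  const uint32_t Shrunk = C & Demanded; // what the combiner rewrites C to
  for (uint32_t X : {0u, 0xDEADBEEFu, 0xFFFFFFFFu})
    // Every demanded bit of the result is unchanged by the shrink.
    assert(((X & C) & Demanded) == ((X & Shrunk) & Demanded));
  return 0;
}
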
-      computeKnownBits(I->getOperand(1), RHSKnownZero, RHSKnownOne, Depth + 1, -                       CxtI); -      computeKnownBits(I->getOperand(0), LHSKnownZero, LHSKnownOne, Depth + 1, -                       CxtI); - -      // If all of the demanded bits are known 1 on one side, return the other. -      // These bits cannot contribute to the result of the 'and' in this -      // context. -      if ((DemandedMask & ~LHSKnownZero & RHSKnownOne) == -          (DemandedMask & ~LHSKnownZero)) -        return I->getOperand(0); -      if ((DemandedMask & ~RHSKnownZero & LHSKnownOne) == -          (DemandedMask & ~RHSKnownZero)) -        return I->getOperand(1); - -      // If all of the demanded bits in the inputs are known zeros, return zero. -      if ((DemandedMask & (RHSKnownZero|LHSKnownZero)) == DemandedMask) -        return Constant::getNullValue(VTy); - -    } else if (I->getOpcode() == Instruction::Or) { -      // We can simplify (X|Y) -> X or Y in the user's context if we know that -      // only bits from X or Y are demanded. - -      // If either the LHS or the RHS are One, the result is One. -      computeKnownBits(I->getOperand(1), RHSKnownZero, RHSKnownOne, Depth + 1, -                       CxtI); -      computeKnownBits(I->getOperand(0), LHSKnownZero, LHSKnownOne, Depth + 1, -                       CxtI); - -      // If all of the demanded bits are known zero on one side, return the -      // other.  These bits cannot contribute to the result of the 'or' in this -      // context. -      if ((DemandedMask & ~LHSKnownOne & RHSKnownZero) == -          (DemandedMask & ~LHSKnownOne)) -        return I->getOperand(0); -      if ((DemandedMask & ~RHSKnownOne & LHSKnownZero) == -          (DemandedMask & ~RHSKnownOne)) -        return I->getOperand(1); - -      // If all of the potentially set bits on one side are known to be set on -      // the other side, just use the 'other' side. -      if ((DemandedMask & (~RHSKnownZero) & LHSKnownOne) == -          (DemandedMask & (~RHSKnownZero))) -        return I->getOperand(0); -      if ((DemandedMask & (~LHSKnownZero) & RHSKnownOne) == -          (DemandedMask & (~LHSKnownZero))) -        return I->getOperand(1); -    } else if (I->getOpcode() == Instruction::Xor) { -      // We can simplify (X^Y) -> X or Y in the user's context if we know that -      // only bits from X or Y are demanded. - -      computeKnownBits(I->getOperand(1), RHSKnownZero, RHSKnownOne, Depth + 1, -                       CxtI); -      computeKnownBits(I->getOperand(0), LHSKnownZero, LHSKnownOne, Depth + 1, -                       CxtI); - -      // If all of the demanded bits are known zero on one side, return the -      // other. -      if ((DemandedMask & RHSKnownZero) == DemandedMask) -        return I->getOperand(0); -      if ((DemandedMask & LHSKnownZero) == DemandedMask) -        return I->getOperand(1); -    } - -    // Compute the KnownZero/KnownOne bits to simplify things downstream. -    computeKnownBits(I, KnownZero, KnownOne, Depth, CxtI); -    return nullptr; +    return SimplifyMultipleUseDemandedBits(I, DemandedMask, KnownZero, KnownOne, +                                           Depth, CxtI);    } +  APInt LHSKnownZero(BitWidth, 0), LHSKnownOne(BitWidth, 0); +  APInt RHSKnownZero(BitWidth, 0), RHSKnownOne(BitWidth, 0); +    // If this is the root being simplified, allow it to have multiple uses,    // just set the DemandedMask to all bits so that we can try to simplify the    // operands.  
This allows visitTruncInst (for example) to simplify the    // operand of a trunc without duplicating all the logic below.    if (Depth == 0 && !V->hasOneUse()) -    DemandedMask = APInt::getAllOnesValue(BitWidth); +    DemandedMask.setAllBits();    switch (I->getOpcode()) {    default:      computeKnownBits(I, KnownZero, KnownOne, Depth, CxtI);      break; -  case Instruction::And: +  case Instruction::And: {      // If either the LHS or the RHS are Zero, the result is zero. -    if (SimplifyDemandedBits(I->getOperandUse(1), DemandedMask, RHSKnownZero, -                             RHSKnownOne, Depth + 1) || -        SimplifyDemandedBits(I->getOperandUse(0), DemandedMask & ~RHSKnownZero, -                             LHSKnownZero, LHSKnownOne, Depth + 1)) +    if (SimplifyDemandedBits(I, 1, DemandedMask, RHSKnownZero, RHSKnownOne, +                             Depth + 1) || +        SimplifyDemandedBits(I, 0, DemandedMask & ~RHSKnownZero, LHSKnownZero, +                             LHSKnownOne, Depth + 1))        return I;      assert(!(RHSKnownZero & RHSKnownOne) && "Bits known to be one AND zero?");      assert(!(LHSKnownZero & LHSKnownOne) && "Bits known to be one AND zero?"); +    // Output known-0 are known to be clear if zero in either the LHS | RHS. +    APInt IKnownZero = RHSKnownZero | LHSKnownZero; +    // Output known-1 bits are only known if set in both the LHS & RHS. +    APInt IKnownOne = RHSKnownOne & LHSKnownOne; +      // If the client is only demanding bits that we know, return the known      // constant. -    if ((DemandedMask & ((RHSKnownZero | LHSKnownZero)| -                         (RHSKnownOne & LHSKnownOne))) == DemandedMask) -      return Constant::getIntegerValue(VTy, RHSKnownOne & LHSKnownOne); +    if ((DemandedMask & (IKnownZero|IKnownOne)) == DemandedMask) +      return Constant::getIntegerValue(VTy, IKnownOne);      // If all of the demanded bits are known 1 on one side, return the other.      // These bits cannot contribute to the result of the 'and'. @@ -262,34 +199,33 @@ Value *InstCombiner::SimplifyDemandedUseBits(Value *V, APInt DemandedMask,          (DemandedMask & ~RHSKnownZero))        return I->getOperand(1); -    // If all of the demanded bits in the inputs are known zeros, return zero. -    if ((DemandedMask & (RHSKnownZero|LHSKnownZero)) == DemandedMask) -      return Constant::getNullValue(VTy); -      // If the RHS is a constant, see if we can simplify it.      if (ShrinkDemandedConstant(I, 1, DemandedMask & ~LHSKnownZero))        return I; -    // Output known-1 bits are only known if set in both the LHS & RHS. -    KnownOne = RHSKnownOne & LHSKnownOne; -    // Output known-0 are known to be clear if zero in either the LHS | RHS. -    KnownZero = RHSKnownZero | LHSKnownZero; +    KnownZero = std::move(IKnownZero); +    KnownOne  = std::move(IKnownOne);      break; -  case Instruction::Or: +  } +  case Instruction::Or: {      // If either the LHS or the RHS are One, the result is One. 
-    if (SimplifyDemandedBits(I->getOperandUse(1), DemandedMask, RHSKnownZero, -                             RHSKnownOne, Depth + 1) || -        SimplifyDemandedBits(I->getOperandUse(0), DemandedMask & ~RHSKnownOne, -                             LHSKnownZero, LHSKnownOne, Depth + 1)) +    if (SimplifyDemandedBits(I, 1, DemandedMask, RHSKnownZero, RHSKnownOne, +                             Depth + 1) || +        SimplifyDemandedBits(I, 0, DemandedMask & ~RHSKnownOne, LHSKnownZero, +                             LHSKnownOne, Depth + 1))        return I;      assert(!(RHSKnownZero & RHSKnownOne) && "Bits known to be one AND zero?");      assert(!(LHSKnownZero & LHSKnownOne) && "Bits known to be one AND zero?"); +    // Output known-0 bits are only known if clear in both the LHS & RHS. +    APInt IKnownZero = RHSKnownZero & LHSKnownZero; +    // Output known-1 are known to be set if set in either the LHS | RHS. +    APInt IKnownOne = RHSKnownOne | LHSKnownOne; +      // If the client is only demanding bits that we know, return the known      // constant. -    if ((DemandedMask & ((RHSKnownZero & LHSKnownZero)| -                         (RHSKnownOne | LHSKnownOne))) == DemandedMask) -      return Constant::getIntegerValue(VTy, RHSKnownOne | LHSKnownOne); +    if ((DemandedMask & (IKnownZero|IKnownOne)) == DemandedMask) +      return Constant::getIntegerValue(VTy, IKnownOne);      // If all of the demanded bits are known zero on one side, return the other.      // These bits cannot contribute to the result of the 'or'. @@ -313,16 +249,15 @@ Value *InstCombiner::SimplifyDemandedUseBits(Value *V, APInt DemandedMask,      if (ShrinkDemandedConstant(I, 1, DemandedMask))        return I; -    // Output known-0 bits are only known if clear in both the LHS & RHS. -    KnownZero = RHSKnownZero & LHSKnownZero; -    // Output known-1 are known to be set if set in either the LHS | RHS. -    KnownOne = RHSKnownOne | LHSKnownOne; +    KnownZero = std::move(IKnownZero); +    KnownOne  = std::move(IKnownOne);      break; +  }    case Instruction::Xor: { -    if (SimplifyDemandedBits(I->getOperandUse(1), DemandedMask, RHSKnownZero, -                             RHSKnownOne, Depth + 1) || -        SimplifyDemandedBits(I->getOperandUse(0), DemandedMask, LHSKnownZero, -                             LHSKnownOne, Depth + 1)) +    if (SimplifyDemandedBits(I, 1, DemandedMask, RHSKnownZero, RHSKnownOne, +                             Depth + 1) || +        SimplifyDemandedBits(I, 0, DemandedMask, LHSKnownZero, LHSKnownOne, +                             Depth + 1))        return I;      assert(!(RHSKnownZero & RHSKnownOne) && "Bits known to be one AND zero?");      assert(!(LHSKnownZero & LHSKnownOne) && "Bits known to be one AND zero?"); @@ -400,9 +335,9 @@ Value *InstCombiner::SimplifyDemandedUseBits(Value *V, APInt DemandedMask,        }      // Output known-0 bits are known if clear or set in both the LHS & RHS. -    KnownZero= (RHSKnownZero & LHSKnownZero) | (RHSKnownOne & LHSKnownOne); +    KnownZero = std::move(IKnownZero);      // Output known-1 are known to be set if set in only one of the LHS, RHS. 
-    KnownOne = (RHSKnownZero & LHSKnownOne) | (RHSKnownOne & LHSKnownZero); +    KnownOne  = std::move(IKnownOne);      break;    }    case Instruction::Select: @@ -412,10 +347,10 @@ Value *InstCombiner::SimplifyDemandedUseBits(Value *V, APInt DemandedMask,      if (matchSelectPattern(I, LHS, RHS).Flavor != SPF_UNKNOWN)        return nullptr; -    if (SimplifyDemandedBits(I->getOperandUse(2), DemandedMask, RHSKnownZero, -                             RHSKnownOne, Depth + 1) || -        SimplifyDemandedBits(I->getOperandUse(1), DemandedMask, LHSKnownZero, -                             LHSKnownOne, Depth + 1)) +    if (SimplifyDemandedBits(I, 2, DemandedMask, RHSKnownZero, RHSKnownOne, +                             Depth + 1) || +        SimplifyDemandedBits(I, 1, DemandedMask, LHSKnownZero, LHSKnownOne, +                             Depth + 1))        return I;      assert(!(RHSKnownZero & RHSKnownOne) && "Bits known to be one AND zero?");      assert(!(LHSKnownZero & LHSKnownOne) && "Bits known to be one AND zero?"); @@ -434,8 +369,8 @@ Value *InstCombiner::SimplifyDemandedUseBits(Value *V, APInt DemandedMask,      DemandedMask = DemandedMask.zext(truncBf);      KnownZero = KnownZero.zext(truncBf);      KnownOne = KnownOne.zext(truncBf); -    if (SimplifyDemandedBits(I->getOperandUse(0), DemandedMask, KnownZero, -                             KnownOne, Depth + 1)) +    if (SimplifyDemandedBits(I, 0, DemandedMask, KnownZero, KnownOne, +                             Depth + 1))        return I;      DemandedMask = DemandedMask.trunc(BitWidth);      KnownZero = KnownZero.trunc(BitWidth); @@ -460,8 +395,8 @@ Value *InstCombiner::SimplifyDemandedUseBits(Value *V, APInt DemandedMask,        // Don't touch a vector-to-scalar bitcast.        return nullptr; -    if (SimplifyDemandedBits(I->getOperandUse(0), DemandedMask, KnownZero, -                             KnownOne, Depth + 1)) +    if (SimplifyDemandedBits(I, 0, DemandedMask, KnownZero, KnownOne, +                             Depth + 1))        return I;      assert(!(KnownZero & KnownOne) && "Bits known to be one AND zero?");      break; @@ -472,15 +407,15 @@ Value *InstCombiner::SimplifyDemandedUseBits(Value *V, APInt DemandedMask,      DemandedMask = DemandedMask.trunc(SrcBitWidth);      KnownZero = KnownZero.trunc(SrcBitWidth);      KnownOne = KnownOne.trunc(SrcBitWidth); -    if (SimplifyDemandedBits(I->getOperandUse(0), DemandedMask, KnownZero, -                             KnownOne, Depth + 1)) +    if (SimplifyDemandedBits(I, 0, DemandedMask, KnownZero, KnownOne, +                             Depth + 1))        return I;      DemandedMask = DemandedMask.zext(BitWidth);      KnownZero = KnownZero.zext(BitWidth);      KnownOne = KnownOne.zext(BitWidth);      assert(!(KnownZero & KnownOne) && "Bits known to be one AND zero?");      // The top bits are known to be zero. -    KnownZero |= APInt::getHighBitsSet(BitWidth, BitWidth - SrcBitWidth); +    KnownZero.setBitsFrom(SrcBitWidth);      break;    }    case Instruction::SExt: { @@ -490,7 +425,7 @@ Value *InstCombiner::SimplifyDemandedUseBits(Value *V, APInt DemandedMask,      APInt InputDemandedBits = DemandedMask &                                APInt::getLowBitsSet(BitWidth, SrcBitWidth); -    APInt NewBits(APInt::getHighBitsSet(BitWidth, BitWidth - SrcBitWidth)); +    APInt NewBits(APInt::getBitsSetFrom(BitWidth, SrcBitWidth));      // If any of the sign extended bits are demanded, we know that the sign      // bit is demanded.      
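
The IKnownZero/IKnownOne formulas threaded through the and/or/xor cases above follow directly from the truth tables. A standalone C++ check (not LLVM code) on an illustrative pair of partially known 8-bit values, where "known" means the bit is fixed for every concrete value consistent with the known-zero/known-one sets:

#include <cassert>
#include <cstdint>

int main() {
  // LHS: high nibble known 0, bits 1-0 known 1, bits 3-2 unknown.
  const uint8_t LZ = 0b11110000, LO = 0b00000011;
  // RHS: bits 4, 2, 0 known 0, bits 7, 5, 1 known 1, bits 6 and 3 unknown.
  const uint8_t RZ = 0b00010101, RO = 0b10100010;
  const uint8_t LUnk = static_cast<uint8_t>(~(LZ | LO));
  const uint8_t RUnk = static_cast<uint8_t>(~(RZ | RO));

  // Claimed known bits of the results, computed as in the patch.
  const uint8_t AndZ = RZ | LZ,               AndO = RO & LO;
  const uint8_t OrZ  = RZ & LZ,               OrO  = RO | LO;
  const uint8_t XorZ = (RZ & LZ) | (RO & LO), XorO = (RZ & LO) | (RO & LZ);

  // Enumerate every pair of values consistent with the known bits.
  for (unsigned i = 0; i < 256; ++i)
    for (unsigned j = 0; j < 256; ++j) {
      uint8_t L = static_cast<uint8_t>((i & LUnk) | LO);
      uint8_t R = static_cast<uint8_t>((j & RUnk) | RO);
      assert((uint8_t(L & R) & AndZ) == 0 && (uint8_t(L & R) & AndO) == AndO);
      assert((uint8_t(L | R) & OrZ) == 0  && (uint8_t(L | R) & OrO) == OrO);
      assert((uint8_t(L ^ R) & XorZ) == 0 && (uint8_t(L ^ R) & XorO) == XorO);
    }
  return 0;
}
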
if ((NewBits & DemandedMask) != 0) @@ -499,8 +434,8 @@ Value *InstCombiner::SimplifyDemandedUseBits(Value *V, APInt DemandedMask,      InputDemandedBits = InputDemandedBits.trunc(SrcBitWidth);      KnownZero = KnownZero.trunc(SrcBitWidth);      KnownOne = KnownOne.trunc(SrcBitWidth); -    if (SimplifyDemandedBits(I->getOperandUse(0), InputDemandedBits, KnownZero, -                             KnownOne, Depth + 1)) +    if (SimplifyDemandedBits(I, 0, InputDemandedBits, KnownZero, KnownOne, +                             Depth + 1))        return I;      InputDemandedBits = InputDemandedBits.zext(BitWidth);      KnownZero = KnownZero.zext(BitWidth); @@ -530,11 +465,12 @@ Value *InstCombiner::SimplifyDemandedUseBits(Value *V, APInt DemandedMask,        // Right fill the mask of bits for this ADD/SUB to demand the most        // significant bit and all those below it.        APInt DemandedFromOps(APInt::getLowBitsSet(BitWidth, BitWidth-NLZ)); -      if (SimplifyDemandedBits(I->getOperandUse(0), DemandedFromOps, -                               LHSKnownZero, LHSKnownOne, Depth + 1) || +      if (ShrinkDemandedConstant(I, 0, DemandedFromOps) || +          SimplifyDemandedBits(I, 0, DemandedFromOps, LHSKnownZero, LHSKnownOne, +                               Depth + 1) ||            ShrinkDemandedConstant(I, 1, DemandedFromOps) || -          SimplifyDemandedBits(I->getOperandUse(1), DemandedFromOps, -                               LHSKnownZero, LHSKnownOne, Depth + 1)) { +          SimplifyDemandedBits(I, 1, DemandedFromOps, RHSKnownZero, RHSKnownOne, +                               Depth + 1)) {          // Disable the nsw and nuw flags here: We can no longer guarantee that          // we won't wrap after simplification. Removing the nsw/nuw flags is          // legal here because the top bit is not demanded. @@ -543,6 +479,15 @@ Value *InstCombiner::SimplifyDemandedUseBits(Value *V, APInt DemandedMask,          BinOP.setHasNoUnsignedWrap(false);          return I;        } + +      // If we are known to be adding/subtracting zeros to every bit below +      // the highest demanded bit, we just return the other side. +      if ((DemandedFromOps & RHSKnownZero) == DemandedFromOps) +        return I->getOperand(0); +      // We can't do this with the LHS for subtraction. +      if (I->getOpcode() == Instruction::Add && +          (DemandedFromOps & LHSKnownZero) == DemandedFromOps) +        return I->getOperand(1);      }      // Otherwise just hand the add/sub off to computeKnownBits to fill in @@ -569,19 +514,19 @@ Value *InstCombiner::SimplifyDemandedUseBits(Value *V, APInt DemandedMask,        // If the shift is NUW/NSW, then it does demand the high bits.        ShlOperator *IOp = cast<ShlOperator>(I);        if (IOp->hasNoSignedWrap()) -        DemandedMaskIn |= APInt::getHighBitsSet(BitWidth, ShiftAmt+1); +        DemandedMaskIn.setHighBits(ShiftAmt+1);        else if (IOp->hasNoUnsignedWrap()) -        DemandedMaskIn |= APInt::getHighBitsSet(BitWidth, ShiftAmt); +        DemandedMaskIn.setHighBits(ShiftAmt); -      if (SimplifyDemandedBits(I->getOperandUse(0), DemandedMaskIn, KnownZero, -                               KnownOne, Depth + 1)) +      if (SimplifyDemandedBits(I, 0, DemandedMaskIn, KnownZero, KnownOne, +                               Depth + 1))          return I;        assert(!(KnownZero & KnownOne) && "Bits known to be one AND zero?");        KnownZero <<= ShiftAmt;        KnownOne  <<= ShiftAmt;        // low bits known zero.        
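
The new add/sub case above ("adding/subtracting zeros to every bit below the highest demanded bit") is a statement about carry propagation: carries and borrows only travel upward, so zero low bits in one operand cannot disturb the demanded low bits of the result. A standalone C++ sketch with an illustrative mask, not taken from the patch:

#include <cassert>
#include <cstdint>

int main() {
  const uint32_t Demanded = 0x000000FF; // highest demanded bit is bit 7
  const uint32_t RHS = 0xABCD0000;      // known zero in all demanded-from-ops bits
  for (uint32_t X : {0u, 0x12u, 0xFFFFFFFFu}) {
    // Adding or subtracting zeros below the highest demanded bit cannot
    // change those bits, so the whole op can be replaced by X for this user.
    assert(((X + RHS) & Demanded) == (X & Demanded));
    assert(((X - RHS) & Demanded) == (X & Demanded));
  }
  // The other direction is not valid for subtraction: RHS - X flips low bits,
  // which is why the patch only returns operand 1 for Add.
  assert(((RHS - 1u) & Demanded) != (1u & Demanded));
  return 0;
}
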
if (ShiftAmt) -        KnownZero |= APInt::getLowBitsSet(BitWidth, ShiftAmt); +        KnownZero.setLowBits(ShiftAmt);      }      break;    case Instruction::LShr: @@ -595,19 +540,16 @@ Value *InstCombiner::SimplifyDemandedUseBits(Value *V, APInt DemandedMask,        // If the shift is exact, then it does demand the low bits (and knows that        // they are zero).        if (cast<LShrOperator>(I)->isExact()) -        DemandedMaskIn |= APInt::getLowBitsSet(BitWidth, ShiftAmt); +        DemandedMaskIn.setLowBits(ShiftAmt); -      if (SimplifyDemandedBits(I->getOperandUse(0), DemandedMaskIn, KnownZero, -                               KnownOne, Depth + 1)) +      if (SimplifyDemandedBits(I, 0, DemandedMaskIn, KnownZero, KnownOne, +                               Depth + 1))          return I;        assert(!(KnownZero & KnownOne) && "Bits known to be one AND zero?"); -      KnownZero = APIntOps::lshr(KnownZero, ShiftAmt); -      KnownOne  = APIntOps::lshr(KnownOne, ShiftAmt); -      if (ShiftAmt) { -        // Compute the new bits that are at the top now. -        APInt HighBits(APInt::getHighBitsSet(BitWidth, ShiftAmt)); -        KnownZero |= HighBits;  // high bits known zero. -      } +      KnownZero = KnownZero.lshr(ShiftAmt); +      KnownOne  = KnownOne.lshr(ShiftAmt); +      if (ShiftAmt) +        KnownZero.setHighBits(ShiftAmt);  // high bits known zero.      }      break;    case Instruction::AShr: @@ -635,26 +577,26 @@ Value *InstCombiner::SimplifyDemandedUseBits(Value *V, APInt DemandedMask,        // If any of the "high bits" are demanded, we should set the sign bit as        // demanded.        if (DemandedMask.countLeadingZeros() <= ShiftAmt) -        DemandedMaskIn.setBit(BitWidth-1); +        DemandedMaskIn.setSignBit();        // If the shift is exact, then it does demand the low bits (and knows that        // they are zero).        if (cast<AShrOperator>(I)->isExact()) -        DemandedMaskIn |= APInt::getLowBitsSet(BitWidth, ShiftAmt); +        DemandedMaskIn.setLowBits(ShiftAmt); -      if (SimplifyDemandedBits(I->getOperandUse(0), DemandedMaskIn, KnownZero, -                               KnownOne, Depth + 1)) +      if (SimplifyDemandedBits(I, 0, DemandedMaskIn, KnownZero, KnownOne, +                               Depth + 1))          return I;        assert(!(KnownZero & KnownOne) && "Bits known to be one AND zero?");        // Compute the new bits that are at the top now.        APInt HighBits(APInt::getHighBitsSet(BitWidth, ShiftAmt)); -      KnownZero = APIntOps::lshr(KnownZero, ShiftAmt); -      KnownOne  = APIntOps::lshr(KnownOne, ShiftAmt); +      KnownZero = KnownZero.lshr(ShiftAmt); +      KnownOne  = KnownOne.lshr(ShiftAmt);        // Handle the sign bits.        APInt SignBit(APInt::getSignBit(BitWidth));        // Adjust to where it is now in the mask. -      SignBit = APIntOps::lshr(SignBit, ShiftAmt); +      SignBit = SignBit.lshr(ShiftAmt);        // If the input sign bit is known to be zero, or if none of the top bits        // are demanded, turn this into an unsigned shift right. 
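
The AShr handling above ends with the observation that an arithmetic shift of a value whose sign bit is known zero behaves exactly like a logical shift, and the shl/lshr cases record the low/high bits that a shift is guaranteed to clear. A standalone C++ check of both facts (not LLVM code), again assuming two's-complement arithmetic right shifts:

#include <cassert>
#include <cstdint>

int main() {
  for (uint32_t Raw : {0u, 1u, 0x7FFFFFFFu, 0xDEADBEEFu}) {
    int32_t X = static_cast<int32_t>(Raw >> 1); // sign bit known zero
    // ashr X, C == lshr X, C whenever the sign bit is zero.
    assert((X >> 7) == static_cast<int32_t>(static_cast<uint32_t>(X) >> 7));
    // shl by C leaves the low C bits known zero; lshr by C leaves the high
    // C bits known zero, which is what setLowBits/setHighBits record.
    assert(((static_cast<uint32_t>(X) << 7) & 0x7Fu) == 0u);
    assert(((static_cast<uint32_t>(X) >> 7) & 0xFE000000u) == 0u);
  }
  return 0;
}
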
@@ -683,8 +625,8 @@ Value *InstCombiner::SimplifyDemandedUseBits(Value *V, APInt DemandedMask,          APInt LowBits = RA - 1;          APInt Mask2 = LowBits | APInt::getSignBit(BitWidth); -        if (SimplifyDemandedBits(I->getOperandUse(0), Mask2, LHSKnownZero, -                                 LHSKnownOne, Depth + 1)) +        if (SimplifyDemandedBits(I, 0, Mask2, LHSKnownZero, LHSKnownOne, +                                 Depth + 1))            return I;          // The low bits of LHS are unchanged by the srem. @@ -693,12 +635,12 @@ Value *InstCombiner::SimplifyDemandedUseBits(Value *V, APInt DemandedMask,          // If LHS is non-negative or has all low bits zero, then the upper bits          // are all zero. -        if (LHSKnownZero[BitWidth-1] || ((LHSKnownZero & LowBits) == LowBits)) +        if (LHSKnownZero.isNegative() || ((LHSKnownZero & LowBits) == LowBits))            KnownZero |= ~LowBits;          // If LHS is negative and not all low bits are zero, then the upper bits          // are all one. -        if (LHSKnownOne[BitWidth-1] && ((LHSKnownOne & LowBits) != 0)) +        if (LHSKnownOne.isNegative() && ((LHSKnownOne & LowBits) != 0))            KnownOne |= ~LowBits;          assert(!(KnownZero & KnownOne) && "Bits known to be one AND zero?"); @@ -713,21 +655,17 @@ Value *InstCombiner::SimplifyDemandedUseBits(Value *V, APInt DemandedMask,                         CxtI);        // If it's known zero, our sign bit is also zero.        if (LHSKnownZero.isNegative()) -        KnownZero.setBit(KnownZero.getBitWidth() - 1); +        KnownZero.setSignBit();      }      break;    case Instruction::URem: {      APInt KnownZero2(BitWidth, 0), KnownOne2(BitWidth, 0);      APInt AllOnes = APInt::getAllOnesValue(BitWidth); -    if (SimplifyDemandedBits(I->getOperandUse(0), AllOnes, KnownZero2, -                             KnownOne2, Depth + 1) || -        SimplifyDemandedBits(I->getOperandUse(1), AllOnes, KnownZero2, -                             KnownOne2, Depth + 1)) +    if (SimplifyDemandedBits(I, 0, AllOnes, KnownZero2, KnownOne2, Depth + 1) || +        SimplifyDemandedBits(I, 1, AllOnes, KnownZero2, KnownOne2, Depth + 1))        return I;      unsigned Leaders = KnownZero2.countLeadingOnes(); -    Leaders = std::max(Leaders, -                       KnownZero2.countLeadingOnes());      KnownZero = APInt::getHighBitsSet(BitWidth, Leaders) & DemandedMask;      break;    } @@ -792,11 +730,11 @@ Value *InstCombiner::SimplifyDemandedUseBits(Value *V, APInt DemandedMask,            return ConstantInt::getNullValue(VTy);          // We know that the upper bits are set to zero. -        KnownZero = APInt::getHighBitsSet(BitWidth, BitWidth - ArgWidth); +        KnownZero.setBitsFrom(ArgWidth);          return nullptr;        }        case Intrinsic::x86_sse42_crc32_64_64: -        KnownZero = APInt::getHighBitsSet(64, 32); +        KnownZero.setBitsFrom(32);          return nullptr;        }      } @@ -811,6 +749,150 @@ Value *InstCombiner::SimplifyDemandedUseBits(Value *V, APInt DemandedMask,    return nullptr;  } +/// Helper routine of SimplifyDemandedUseBits. It computes KnownZero/KnownOne +/// bits. It also tries to handle simplifications that can be done based on +/// DemandedMask, but without modifying the Instruction. 
+Value *InstCombiner::SimplifyMultipleUseDemandedBits(Instruction *I, +                                                     const APInt &DemandedMask, +                                                     APInt &KnownZero, +                                                     APInt &KnownOne, +                                                     unsigned Depth, +                                                     Instruction *CxtI) { +  unsigned BitWidth = DemandedMask.getBitWidth(); +  Type *ITy = I->getType(); + +  APInt LHSKnownZero(BitWidth, 0), LHSKnownOne(BitWidth, 0); +  APInt RHSKnownZero(BitWidth, 0), RHSKnownOne(BitWidth, 0); + +  // Despite the fact that we can't simplify this instruction in all User's +  // context, we can at least compute the knownzero/knownone bits, and we can +  // do simplifications that apply to *just* the one user if we know that +  // this instruction has a simpler value in that context. +  switch (I->getOpcode()) { +  case Instruction::And: { +    // If either the LHS or the RHS are Zero, the result is zero. +    computeKnownBits(I->getOperand(1), RHSKnownZero, RHSKnownOne, Depth + 1, +                     CxtI); +    computeKnownBits(I->getOperand(0), LHSKnownZero, LHSKnownOne, Depth + 1, +                     CxtI); + +    // Output known-0 are known to be clear if zero in either the LHS | RHS. +    APInt IKnownZero = RHSKnownZero | LHSKnownZero; +    // Output known-1 bits are only known if set in both the LHS & RHS. +    APInt IKnownOne = RHSKnownOne & LHSKnownOne; + +    // If the client is only demanding bits that we know, return the known +    // constant. +    if ((DemandedMask & (IKnownZero|IKnownOne)) == DemandedMask) +      return Constant::getIntegerValue(ITy, IKnownOne); + +    // If all of the demanded bits are known 1 on one side, return the other. +    // These bits cannot contribute to the result of the 'and' in this +    // context. +    if ((DemandedMask & ~LHSKnownZero & RHSKnownOne) == +        (DemandedMask & ~LHSKnownZero)) +      return I->getOperand(0); +    if ((DemandedMask & ~RHSKnownZero & LHSKnownOne) == +        (DemandedMask & ~RHSKnownZero)) +      return I->getOperand(1); + +    KnownZero = std::move(IKnownZero); +    KnownOne  = std::move(IKnownOne); +    break; +  } +  case Instruction::Or: { +    // We can simplify (X|Y) -> X or Y in the user's context if we know that +    // only bits from X or Y are demanded. + +    // If either the LHS or the RHS are One, the result is One. +    computeKnownBits(I->getOperand(1), RHSKnownZero, RHSKnownOne, Depth + 1, +                     CxtI); +    computeKnownBits(I->getOperand(0), LHSKnownZero, LHSKnownOne, Depth + 1, +                     CxtI); + +    // Output known-0 bits are only known if clear in both the LHS & RHS. +    APInt IKnownZero = RHSKnownZero & LHSKnownZero; +    // Output known-1 are known to be set if set in either the LHS | RHS. +    APInt IKnownOne = RHSKnownOne | LHSKnownOne; + +    // If the client is only demanding bits that we know, return the known +    // constant. +    if ((DemandedMask & (IKnownZero|IKnownOne)) == DemandedMask) +      return Constant::getIntegerValue(ITy, IKnownOne); + +    // If all of the demanded bits are known zero on one side, return the +    // other.  These bits cannot contribute to the result of the 'or' in this +    // context. 
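
The multiple-use helper above only substitutes a simpler value for this one user when the demanded bits make the other operand irrelevant; for 'or', one sufficient case is that every demanded bit is known zero in the other operand. A standalone C++ illustration (not LLVM code) with made-up masks:

#include <cassert>
#include <cstdint>

int main() {
  const uint32_t Demanded = 0x0000FFFF;   // this user only reads the low half
  const uint32_t YKnownZero = 0x00FFFFFF; // Y can only set bits in its top byte
  for (uint32_t X : {0u, 0x12345678u, 0xFFFFFFFFu})
    for (uint32_t Y : {0u, 0x01000000u, 0xFF000000u}) {
      assert((Y & YKnownZero) == 0);      // Y is consistent with its known-zero set
      // In this user's context, (X | Y) may be replaced by X outright.
      assert(((X | Y) & Demanded) == (X & Demanded));
    }
  return 0;
}
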
+    if ((DemandedMask & ~LHSKnownOne & RHSKnownZero) == +        (DemandedMask & ~LHSKnownOne)) +      return I->getOperand(0); +    if ((DemandedMask & ~RHSKnownOne & LHSKnownZero) == +        (DemandedMask & ~RHSKnownOne)) +      return I->getOperand(1); + +    // If all of the potentially set bits on one side are known to be set on +    // the other side, just use the 'other' side. +    if ((DemandedMask & (~RHSKnownZero) & LHSKnownOne) == +        (DemandedMask & (~RHSKnownZero))) +      return I->getOperand(0); +    if ((DemandedMask & (~LHSKnownZero) & RHSKnownOne) == +        (DemandedMask & (~LHSKnownZero))) +      return I->getOperand(1); + +    KnownZero = std::move(IKnownZero); +    KnownOne  = std::move(IKnownOne); +    break; +  } +  case Instruction::Xor: { +    // We can simplify (X^Y) -> X or Y in the user's context if we know that +    // only bits from X or Y are demanded. + +    computeKnownBits(I->getOperand(1), RHSKnownZero, RHSKnownOne, Depth + 1, +                     CxtI); +    computeKnownBits(I->getOperand(0), LHSKnownZero, LHSKnownOne, Depth + 1, +                     CxtI); + +    // Output known-0 bits are known if clear or set in both the LHS & RHS. +    APInt IKnownZero = (RHSKnownZero & LHSKnownZero) | +                       (RHSKnownOne & LHSKnownOne); +    // Output known-1 are known to be set if set in only one of the LHS, RHS. +    APInt IKnownOne =  (RHSKnownZero & LHSKnownOne) | +                       (RHSKnownOne & LHSKnownZero); + +    // If the client is only demanding bits that we know, return the known +    // constant. +    if ((DemandedMask & (IKnownZero|IKnownOne)) == DemandedMask) +      return Constant::getIntegerValue(ITy, IKnownOne); + +    // If all of the demanded bits are known zero on one side, return the +    // other. +    if ((DemandedMask & RHSKnownZero) == DemandedMask) +      return I->getOperand(0); +    if ((DemandedMask & LHSKnownZero) == DemandedMask) +      return I->getOperand(1); + +    // Output known-0 bits are known if clear or set in both the LHS & RHS. +    KnownZero = std::move(IKnownZero); +    // Output known-1 are known to be set if set in only one of the LHS, RHS. +    KnownOne  = std::move(IKnownOne); +    break; +  } +  default: +    // Compute the KnownZero/KnownOne bits to simplify things downstream. +    computeKnownBits(I, KnownZero, KnownOne, Depth, CxtI); + +    // If this user is only demanding bits that we know, return the known +    // constant. +    if ((DemandedMask & (KnownZero|KnownOne)) == DemandedMask) +      return Constant::getIntegerValue(ITy, KnownOne); + +    break; +  } + +  return nullptr; +} + +  /// Helper routine of SimplifyDemandedUseBits. 
It tries to simplify  /// "E1 = (X lsr C1) << C2", where the C1 and C2 are constant, into  /// "E2 = X << (C2 - C1)" or "E2 = X >> (C1 - C2)", depending on the sign @@ -849,7 +931,7 @@ Value *InstCombiner::SimplifyShrShlDemandedBits(Instruction *Shr,    unsigned ShrAmt = ShrOp1.getZExtValue();    KnownOne.clearAllBits(); -  KnownZero = APInt::getBitsSet(KnownZero.getBitWidth(), 0, ShlAmt-1); +  KnownZero.setLowBits(ShlAmt - 1);    KnownZero &= DemandedMask;    APInt BitMask1(APInt::getAllOnesValue(BitWidth)); @@ -1472,14 +1554,136 @@ Value *InstCombiner::SimplifyDemandedVectorElts(Value *V, APInt DemandedElts,        break;      } +    case Intrinsic::x86_sse2_packssdw_128: +    case Intrinsic::x86_sse2_packsswb_128: +    case Intrinsic::x86_sse2_packuswb_128: +    case Intrinsic::x86_sse41_packusdw: +    case Intrinsic::x86_avx2_packssdw: +    case Intrinsic::x86_avx2_packsswb: +    case Intrinsic::x86_avx2_packusdw: +    case Intrinsic::x86_avx2_packuswb: +    case Intrinsic::x86_avx512_packssdw_512: +    case Intrinsic::x86_avx512_packsswb_512: +    case Intrinsic::x86_avx512_packusdw_512: +    case Intrinsic::x86_avx512_packuswb_512: { +      auto *Ty0 = II->getArgOperand(0)->getType(); +      unsigned InnerVWidth = Ty0->getVectorNumElements(); +      assert(VWidth == (InnerVWidth * 2) && "Unexpected input size"); + +      unsigned NumLanes = Ty0->getPrimitiveSizeInBits() / 128; +      unsigned VWidthPerLane = VWidth / NumLanes; +      unsigned InnerVWidthPerLane = InnerVWidth / NumLanes; + +      // Per lane, pack the elements of the first input and then the second. +      // e.g. +      // v8i16 PACK(v4i32 X, v4i32 Y) - (X[0..3],Y[0..3]) +      // v32i8 PACK(v16i16 X, v16i16 Y) - (X[0..7],Y[0..7]),(X[8..15],Y[8..15]) +      for (int OpNum = 0; OpNum != 2; ++OpNum) { +        APInt OpDemandedElts(InnerVWidth, 0); +        for (unsigned Lane = 0; Lane != NumLanes; ++Lane) { +          unsigned LaneIdx = Lane * VWidthPerLane; +          for (unsigned Elt = 0; Elt != InnerVWidthPerLane; ++Elt) { +            unsigned Idx = LaneIdx + Elt + InnerVWidthPerLane * OpNum; +            if (DemandedElts[Idx]) +              OpDemandedElts.setBit((Lane * InnerVWidthPerLane) + Elt); +          } +        } + +        // Demand elements from the operand. +        auto *Op = II->getArgOperand(OpNum); +        APInt OpUndefElts(InnerVWidth, 0); +        TmpV = SimplifyDemandedVectorElts(Op, OpDemandedElts, OpUndefElts, +                                          Depth + 1); +        if (TmpV) { +          II->setArgOperand(OpNum, TmpV); +          MadeChange = true; +        } + +        // Pack the operand's UNDEF elements, one lane at a time. 
+        OpUndefElts = OpUndefElts.zext(VWidth); +        for (unsigned Lane = 0; Lane != NumLanes; ++Lane) { +          APInt LaneElts = OpUndefElts.lshr(InnerVWidthPerLane * Lane); +          LaneElts = LaneElts.getLoBits(InnerVWidthPerLane); +          LaneElts = LaneElts.shl(InnerVWidthPerLane * (2 * Lane + OpNum)); +          UndefElts |= LaneElts; +        } +      } +      break; +    } + +    // PSHUFB +    case Intrinsic::x86_ssse3_pshuf_b_128: +    case Intrinsic::x86_avx2_pshuf_b: +    case Intrinsic::x86_avx512_pshuf_b_512: +    // PERMILVAR +    case Intrinsic::x86_avx_vpermilvar_ps: +    case Intrinsic::x86_avx_vpermilvar_ps_256: +    case Intrinsic::x86_avx512_vpermilvar_ps_512: +    case Intrinsic::x86_avx_vpermilvar_pd: +    case Intrinsic::x86_avx_vpermilvar_pd_256: +    case Intrinsic::x86_avx512_vpermilvar_pd_512: +    // PERMV +    case Intrinsic::x86_avx2_permd: +    case Intrinsic::x86_avx2_permps: { +      Value *Op1 = II->getArgOperand(1); +      TmpV = SimplifyDemandedVectorElts(Op1, DemandedElts, UndefElts, +                                        Depth + 1); +      if (TmpV) { II->setArgOperand(1, TmpV); MadeChange = true; } +      break; +    } +      // SSE4A instructions leave the upper 64-bits of the 128-bit result      // in an undefined state.      case Intrinsic::x86_sse4a_extrq:      case Intrinsic::x86_sse4a_extrqi:      case Intrinsic::x86_sse4a_insertq:      case Intrinsic::x86_sse4a_insertqi: -      UndefElts |= APInt::getHighBitsSet(VWidth, VWidth / 2); +      UndefElts.setHighBits(VWidth / 2);        break; +    case Intrinsic::amdgcn_buffer_load: +    case Intrinsic::amdgcn_buffer_load_format: { +      if (VWidth == 1 || !DemandedElts.isMask()) +        return nullptr; + +      // TODO: Handle 3 vectors when supported in code gen. +      unsigned NewNumElts = PowerOf2Ceil(DemandedElts.countTrailingOnes()); +      if (NewNumElts == VWidth) +        return nullptr; + +      Module *M = II->getParent()->getParent()->getParent(); +      Type *EltTy = V->getType()->getVectorElementType(); + +      Type *NewTy = (NewNumElts == 1) ? 
EltTy : +        VectorType::get(EltTy, NewNumElts); + +      Function *NewIntrin = Intrinsic::getDeclaration(M, II->getIntrinsicID(), +                                                      NewTy); + +      SmallVector<Value *, 5> Args; +      for (unsigned I = 0, E = II->getNumArgOperands(); I != E; ++I) +        Args.push_back(II->getArgOperand(I)); + +      IRBuilderBase::InsertPointGuard Guard(*Builder); +      Builder->SetInsertPoint(II); + +      CallInst *NewCall = Builder->CreateCall(NewIntrin, Args); +      NewCall->takeName(II); +      NewCall->copyMetadata(*II); +      if (NewNumElts == 1) { +        return Builder->CreateInsertElement(UndefValue::get(V->getType()), +                                            NewCall, static_cast<uint64_t>(0)); +      } + +      SmallVector<uint32_t, 8> EltMask; +      for (unsigned I = 0; I < VWidth; ++I) +        EltMask.push_back(I); + +      Value *Shuffle = Builder->CreateShuffleVector( +        NewCall, UndefValue::get(NewTy), EltMask); + +      MadeChange = true; +      return Shuffle; +    }      }      break;    } diff --git a/lib/Transforms/InstCombine/InstCombineVectorOps.cpp b/lib/Transforms/InstCombine/InstCombineVectorOps.cpp index b2477f6c8633..e89b400a4afc 100644 --- a/lib/Transforms/InstCombine/InstCombineVectorOps.cpp +++ b/lib/Transforms/InstCombine/InstCombineVectorOps.cpp @@ -645,6 +645,36 @@ static Instruction *foldInsSequenceIntoBroadcast(InsertElementInst &InsElt) {    return new ShuffleVectorInst(InsertFirst, UndefValue::get(VT), ZeroMask);  } +/// If we have an insertelement instruction feeding into another insertelement +/// and the 2nd is inserting a constant into the vector, canonicalize that +/// constant insertion before the insertion of a variable: +/// +/// insertelement (insertelement X, Y, IdxC1), ScalarC, IdxC2 --> +/// insertelement (insertelement X, ScalarC, IdxC2), Y, IdxC1 +/// +/// This has the potential of eliminating the 2nd insertelement instruction +/// via constant folding of the scalar constant into a vector constant. +static Instruction *hoistInsEltConst(InsertElementInst &InsElt2, +                                     InstCombiner::BuilderTy &Builder) { +  auto *InsElt1 = dyn_cast<InsertElementInst>(InsElt2.getOperand(0)); +  if (!InsElt1 || !InsElt1->hasOneUse()) +    return nullptr; + +  Value *X, *Y; +  Constant *ScalarC; +  ConstantInt *IdxC1, *IdxC2; +  if (match(InsElt1->getOperand(0), m_Value(X)) && +      match(InsElt1->getOperand(1), m_Value(Y)) && !isa<Constant>(Y) && +      match(InsElt1->getOperand(2), m_ConstantInt(IdxC1)) && +      match(InsElt2.getOperand(1), m_Constant(ScalarC)) && +      match(InsElt2.getOperand(2), m_ConstantInt(IdxC2)) && IdxC1 != IdxC2) { +    Value *NewInsElt1 = Builder.CreateInsertElement(X, ScalarC, IdxC2); +    return InsertElementInst::Create(NewInsElt1, Y, IdxC1); +  } + +  return nullptr; +} +  /// insertelt (shufflevector X, CVec, Mask|insertelt X, C1, CIndex1), C, CIndex  /// --> shufflevector X, CVec', Mask'  static Instruction *foldConstantInsEltIntoShuffle(InsertElementInst &InsElt) { @@ -806,6 +836,9 @@ Instruction *InstCombiner::visitInsertElementInst(InsertElementInst &IE) {    if (Instruction *Shuf = foldConstantInsEltIntoShuffle(IE))      return Shuf; +  if (Instruction *NewInsElt = hoistInsEltConst(IE, *Builder)) +    return NewInsElt; +    // Turn a sequence of inserts that broadcasts a scalar into a single    // insert + shufflevector.    
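
hoistInsEltConst above relies on the fact that two insertions into distinct lanes commute, so the constant insertion can be moved first and then folded into a vector constant. A standalone C++ sanity check of that commutation (not LLVM code; the array stands in for the vector):

#include <array>
#include <cassert>

int main() {
  std::array<int, 4> X = {10, 20, 30, 40};
  const int Y = 7, ScalarC = 99;
  const unsigned IdxC1 = 1, IdxC2 = 3; // distinct indices, as the fold requires

  std::array<int, 4> A = X; A[IdxC1] = Y; A[IdxC2] = ScalarC; // original order
  std::array<int, 4> B = X; B[IdxC2] = ScalarC; B[IdxC1] = Y; // hoisted order
  assert(A == B);
  return 0;
}
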
if (Instruction *Broadcast = foldInsSequenceIntoBroadcast(IE)) @@ -1107,12 +1140,11 @@ Instruction *InstCombiner::visitShuffleVectorInst(ShuffleVectorInst &SVI) {    SmallVector<int, 16> Mask = SVI.getShuffleMask();    Type *Int32Ty = Type::getInt32Ty(SVI.getContext()); -  bool MadeChange = false; - -  // Undefined shuffle mask -> undefined value. -  if (isa<UndefValue>(SVI.getOperand(2))) -    return replaceInstUsesWith(SVI, UndefValue::get(SVI.getType())); +  if (auto *V = SimplifyShuffleVectorInst(LHS, RHS, SVI.getMask(), +                                          SVI.getType(), DL, &TLI, &DT, &AC)) +    return replaceInstUsesWith(SVI, V); +  bool MadeChange = false;    unsigned VWidth = SVI.getType()->getVectorNumElements();    APInt UndefElts(VWidth, 0); @@ -1209,7 +1241,6 @@ Instruction *InstCombiner::visitShuffleVectorInst(ShuffleVectorInst &SVI) {    if (isShuffleExtractingFromLHS(SVI, Mask)) {      Value *V = LHS;      unsigned MaskElems = Mask.size(); -    unsigned BegIdx = Mask.front();      VectorType *SrcTy = cast<VectorType>(V->getType());      unsigned VecBitWidth = SrcTy->getBitWidth();      unsigned SrcElemBitWidth = DL.getTypeSizeInBits(SrcTy->getElementType()); @@ -1223,6 +1254,7 @@ Instruction *InstCombiner::visitShuffleVectorInst(ShuffleVectorInst &SVI) {            // Only visit bitcasts that weren't previously handled.            BCs.push_back(BC);      for (BitCastInst *BC : BCs) { +      unsigned BegIdx = Mask.front();        Type *TgtTy = BC->getDestTy();        unsigned TgtElemBitWidth = DL.getTypeSizeInBits(TgtTy);        if (!TgtElemBitWidth) diff --git a/lib/Transforms/InstCombine/InstructionCombining.cpp b/lib/Transforms/InstCombine/InstructionCombining.cpp index 27fc34d23175..88ef17bbc8fa 100644 --- a/lib/Transforms/InstCombine/InstructionCombining.cpp +++ b/lib/Transforms/InstCombine/InstructionCombining.cpp @@ -82,18 +82,24 @@ static cl::opt<bool>  EnableExpensiveCombines("expensive-combines",                          cl::desc("Enable expensive instruction combines")); +static cl::opt<unsigned> +MaxArraySize("instcombine-maxarray-size", cl::init(1024), +             cl::desc("Maximum array size considered when doing a combine")); +  Value *InstCombiner::EmitGEPOffset(User *GEP) {    return llvm::EmitGEPOffset(Builder, DL, GEP);  }  /// Return true if it is desirable to convert an integer computation from a  /// given bit width to a new bit width. -/// We don't want to convert from a legal to an illegal type for example or from -/// a smaller to a larger illegal type. -bool InstCombiner::ShouldChangeType(unsigned FromWidth, +/// We don't want to convert from a legal to an illegal type or from a smaller +/// to a larger illegal type. A width of '1' is always treated as a legal type +/// because i1 is a fundamental type in IR, and there are many specialized +/// optimizations for i1 types. +bool InstCombiner::shouldChangeType(unsigned FromWidth,                                      unsigned ToWidth) const { -  bool FromLegal = DL.isLegalInteger(FromWidth); -  bool ToLegal = DL.isLegalInteger(ToWidth); +  bool FromLegal = FromWidth == 1 || DL.isLegalInteger(FromWidth); +  bool ToLegal = ToWidth == 1 || DL.isLegalInteger(ToWidth);    // If this is a legal integer from type, and the result would be an illegal    // type, don't do the transformation. @@ -109,14 +115,16 @@ bool InstCombiner::ShouldChangeType(unsigned FromWidth,  }  /// Return true if it is desirable to convert a computation from 'From' to 'To'. 
-/// We don't want to convert from a legal to an illegal type for example or from -/// a smaller to a larger illegal type. -bool InstCombiner::ShouldChangeType(Type *From, Type *To) const { +/// We don't want to convert from a legal to an illegal type or from a smaller +/// to a larger illegal type. i1 is always treated as a legal type because it is +/// a fundamental type in IR, and there are many specialized optimizations for +/// i1 types. +bool InstCombiner::shouldChangeType(Type *From, Type *To) const {    assert(From->isIntegerTy() && To->isIntegerTy());    unsigned FromWidth = From->getPrimitiveSizeInBits();    unsigned ToWidth = To->getPrimitiveSizeInBits(); -  return ShouldChangeType(FromWidth, ToWidth); +  return shouldChangeType(FromWidth, ToWidth);  }  // Return true, if No Signed Wrap should be maintained for I. @@ -447,16 +455,11 @@ static bool RightDistributesOverLeft(Instruction::BinaryOps LOp,  /// This function returns identity value for given opcode, which can be used to  /// factor patterns like (X * 2) + X ==> (X * 2) + (X * 1) ==> X * (2 + 1). -static Value *getIdentityValue(Instruction::BinaryOps OpCode, Value *V) { +static Value *getIdentityValue(Instruction::BinaryOps Opcode, Value *V) {    if (isa<Constant>(V))      return nullptr; -  if (OpCode == Instruction::Mul) -    return ConstantInt::get(V->getType(), 1); - -  // TODO: We can handle other cases e.g. Instruction::And, Instruction::Or etc. - -  return nullptr; +  return ConstantExpr::getBinOpIdentity(Opcode, V->getType());  }  /// This function factors binary ops which can be combined using distributive @@ -468,8 +471,7 @@ static Value *getIdentityValue(Instruction::BinaryOps OpCode, Value *V) {  static Instruction::BinaryOps  getBinOpsForFactorization(Instruction::BinaryOps TopLevelOpcode,                            BinaryOperator *Op, Value *&LHS, Value *&RHS) { -  if (!Op) -    return Instruction::BinaryOpsEnd; +  assert(Op && "Expected a binary operator");    LHS = Op->getOperand(0);    RHS = Op->getOperand(1); @@ -499,11 +501,7 @@ static Value *tryFactorization(InstCombiner::BuilderTy *Builder,                                 const DataLayout &DL, BinaryOperator &I,                                 Instruction::BinaryOps InnerOpcode, Value *A,                                 Value *B, Value *C, Value *D) { - -  // If any of A, B, C, D are null, we can not factor I, return early. -  // Checking A and C should be enough. 
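
The factorization path above combines the distributive law with an identity element supplied by getIdentityValue, e.g. (X * 2) + X is treated as (X * 2) + (X * 1). A quick standalone C++ check of the underlying algebra (not LLVM code), which holds in wrapping unsigned arithmetic:

#include <cassert>
#include <cstdint>

int main() {
  for (uint32_t X : {0u, 3u, 0xDEADBEEFu}) {
    // (X * 2) + X  ==>  (X * 2) + (X * 1)  ==>  X * (2 + 1)
    assert(X * 2 + X == X * 3);
    // General form "(A op' B) op (A op' D)" ==> "A op' (B op D)", shown here
    // with op' = mul and op = add; unsigned wrap-around is modular, so the
    // identity survives overflow.
    uint32_t A = X, B = 5, D = 11;
    assert(A * B + A * D == A * (B + D));
  }
  return 0;
}
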
-  if (!A || !C || !B || !D) -    return nullptr; +  assert(A && B && C && D && "All values must be provided");    Value *V = nullptr;    Value *SimplifiedInst = nullptr; @@ -564,13 +562,11 @@ static Value *tryFactorization(InstCombiner::BuilderTy *Builder,          if (isa<OverflowingBinaryOperator>(&I))            HasNSW = I.hasNoSignedWrap(); -        if (BinaryOperator *Op0 = dyn_cast<BinaryOperator>(LHS)) -          if (isa<OverflowingBinaryOperator>(Op0)) -            HasNSW &= Op0->hasNoSignedWrap(); +        if (auto *LOBO = dyn_cast<OverflowingBinaryOperator>(LHS)) +          HasNSW &= LOBO->hasNoSignedWrap(); -        if (BinaryOperator *Op1 = dyn_cast<BinaryOperator>(RHS)) -          if (isa<OverflowingBinaryOperator>(Op1)) -            HasNSW &= Op1->hasNoSignedWrap(); +        if (auto *ROBO = dyn_cast<OverflowingBinaryOperator>(RHS)) +          HasNSW &= ROBO->hasNoSignedWrap();          // We can propagate 'nsw' if we know that          //  %Y = mul nsw i16 %X, C @@ -599,31 +595,39 @@ Value *InstCombiner::SimplifyUsingDistributiveLaws(BinaryOperator &I) {    Value *LHS = I.getOperand(0), *RHS = I.getOperand(1);    BinaryOperator *Op0 = dyn_cast<BinaryOperator>(LHS);    BinaryOperator *Op1 = dyn_cast<BinaryOperator>(RHS); +  Instruction::BinaryOps TopLevelOpcode = I.getOpcode(); -  // Factorization. -  Value *A = nullptr, *B = nullptr, *C = nullptr, *D = nullptr; -  auto TopLevelOpcode = I.getOpcode(); -  auto LHSOpcode = getBinOpsForFactorization(TopLevelOpcode, Op0, A, B); -  auto RHSOpcode = getBinOpsForFactorization(TopLevelOpcode, Op1, C, D); - -  // The instruction has the form "(A op' B) op (C op' D)".  Try to factorize -  // a common term. -  if (LHSOpcode == RHSOpcode) { -    if (Value *V = tryFactorization(Builder, DL, I, LHSOpcode, A, B, C, D)) -      return V; -  } - -  // The instruction has the form "(A op' B) op (C)".  Try to factorize common -  // term. -  if (Value *V = tryFactorization(Builder, DL, I, LHSOpcode, A, B, RHS, -                                  getIdentityValue(LHSOpcode, RHS))) -    return V; +  { +    // Factorization. +    Value *A, *B, *C, *D; +    Instruction::BinaryOps LHSOpcode, RHSOpcode; +    if (Op0) +      LHSOpcode = getBinOpsForFactorization(TopLevelOpcode, Op0, A, B); +    if (Op1) +      RHSOpcode = getBinOpsForFactorization(TopLevelOpcode, Op1, C, D); + +    // The instruction has the form "(A op' B) op (C op' D)".  Try to factorize +    // a common term. +    if (Op0 && Op1 && LHSOpcode == RHSOpcode) +      if (Value *V = tryFactorization(Builder, DL, I, LHSOpcode, A, B, C, D)) +        return V; + +    // The instruction has the form "(A op' B) op (C)".  Try to factorize common +    // term. +    if (Op0) +      if (Value *Ident = getIdentityValue(LHSOpcode, RHS)) +        if (Value *V = tryFactorization(Builder, DL, I, LHSOpcode, A, B, RHS, +                                        Ident)) +          return V; -  // The instruction has the form "(B) op (C op' D)".  Try to factorize common -  // term. -  if (Value *V = tryFactorization(Builder, DL, I, RHSOpcode, LHS, -                                  getIdentityValue(RHSOpcode, LHS), C, D)) -    return V; +    // The instruction has the form "(B) op (C op' D)".  Try to factorize common +    // term. +    if (Op1) +      if (Value *Ident = getIdentityValue(RHSOpcode, LHS)) +        if (Value *V = tryFactorization(Builder, DL, I, RHSOpcode, LHS, Ident, +                                        C, D)) +          return V; +  }    // Expansion.    
if (Op0 && RightDistributesOverLeft(Op0->getOpcode(), TopLevelOpcode)) { @@ -720,6 +724,21 @@ Value *InstCombiner::dyn_castNegVal(Value *V) const {      if (C->getType()->getElementType()->isIntegerTy())        return ConstantExpr::getNeg(C); +  if (ConstantVector *CV = dyn_cast<ConstantVector>(V)) { +    for (unsigned i = 0, e = CV->getNumOperands(); i != e; ++i) { +      Constant *Elt = CV->getAggregateElement(i); +      if (!Elt) +        return nullptr; + +      if (isa<UndefValue>(Elt)) +        continue; + +      if (!isa<ConstantInt>(Elt)) +        return nullptr; +    } +    return ConstantExpr::getNeg(CV); +  } +    return nullptr;  } @@ -820,8 +839,29 @@ Instruction *InstCombiner::FoldOpIntoSelect(Instruction &Op, SelectInst *SI) {    return SelectInst::Create(SI->getCondition(), NewTV, NewFV, "", nullptr, SI);  } -Instruction *InstCombiner::FoldOpIntoPhi(Instruction &I) { -  PHINode *PN = cast<PHINode>(I.getOperand(0)); +static Value *foldOperationIntoPhiValue(BinaryOperator *I, Value *InV, +                                        InstCombiner *IC) { +  bool ConstIsRHS = isa<Constant>(I->getOperand(1)); +  Constant *C = cast<Constant>(I->getOperand(ConstIsRHS)); + +  if (auto *InC = dyn_cast<Constant>(InV)) { +    if (ConstIsRHS) +      return ConstantExpr::get(I->getOpcode(), InC, C); +    return ConstantExpr::get(I->getOpcode(), C, InC); +  } + +  Value *Op0 = InV, *Op1 = C; +  if (!ConstIsRHS) +    std::swap(Op0, Op1); + +  Value *RI = IC->Builder->CreateBinOp(I->getOpcode(), Op0, Op1, "phitmp"); +  auto *FPInst = dyn_cast<Instruction>(RI); +  if (FPInst && isa<FPMathOperator>(FPInst)) +    FPInst->copyFastMathFlags(I); +  return RI; +} + +Instruction *InstCombiner::foldOpIntoPhi(Instruction &I, PHINode *PN) {    unsigned NumPHIValues = PN->getNumIncomingValues();    if (NumPHIValues == 0)      return nullptr; @@ -902,7 +942,11 @@ Instruction *InstCombiner::FoldOpIntoPhi(Instruction &I) {        // Beware of ConstantExpr:  it may eventually evaluate to getNullValue,        // even if currently isNullValue gives false.        Constant *InC = dyn_cast<Constant>(PN->getIncomingValue(i)); -      if (InC && !isa<ConstantExpr>(InC)) +      // For vector constants, we cannot use isNullValue to fold into +      // FalseVInPred versus TrueVInPred. When we have individual nonzero +      // elements in the vector, we will incorrectly fold InC to +      // `TrueVInPred`. +      if (InC && !isa<ConstantExpr>(InC) && isa<ConstantInt>(InC))          InV = InC->isNullValue() ? 
FalseVInPred : TrueVInPred;        else          InV = Builder->CreateSelect(PN->getIncomingValue(i), @@ -923,15 +967,9 @@ Instruction *InstCombiner::FoldOpIntoPhi(Instruction &I) {                                    C, "phitmp");        NewPN->addIncoming(InV, PN->getIncomingBlock(i));      } -  } else if (I.getNumOperands() == 2) { -    Constant *C = cast<Constant>(I.getOperand(1)); +  } else if (auto *BO = dyn_cast<BinaryOperator>(&I)) {      for (unsigned i = 0; i != NumPHIValues; ++i) { -      Value *InV = nullptr; -      if (Constant *InC = dyn_cast<Constant>(PN->getIncomingValue(i))) -        InV = ConstantExpr::get(I.getOpcode(), InC, C); -      else -        InV = Builder->CreateBinOp(cast<BinaryOperator>(I).getOpcode(), -                                   PN->getIncomingValue(i), C, "phitmp"); +      Value *InV = foldOperationIntoPhiValue(BO, PN->getIncomingValue(i), this);        NewPN->addIncoming(InV, PN->getIncomingBlock(i));      }    } else { @@ -957,14 +995,14 @@ Instruction *InstCombiner::FoldOpIntoPhi(Instruction &I) {    return replaceInstUsesWith(I, NewPN);  } -Instruction *InstCombiner::foldOpWithConstantIntoOperand(Instruction &I) { +Instruction *InstCombiner::foldOpWithConstantIntoOperand(BinaryOperator &I) {    assert(isa<Constant>(I.getOperand(1)) && "Unexpected operand type");    if (auto *Sel = dyn_cast<SelectInst>(I.getOperand(0))) {      if (Instruction *NewSel = FoldOpIntoSelect(I, Sel))        return NewSel; -  } else if (isa<PHINode>(I.getOperand(0))) { -    if (Instruction *NewPhi = FoldOpIntoPhi(I)) +  } else if (auto *PN = dyn_cast<PHINode>(I.getOperand(0))) { +    if (Instruction *NewPhi = foldOpIntoPhi(I, PN))        return NewPhi;    }    return nullptr; @@ -1315,22 +1353,19 @@ Value *InstCombiner::SimplifyVectorOp(BinaryOperator &Inst) {    assert(cast<VectorType>(LHS->getType())->getNumElements() == VWidth);    assert(cast<VectorType>(RHS->getType())->getNumElements() == VWidth); -  // If both arguments of binary operation are shuffles, which use the same -  // mask and shuffle within a single vector, it is worthwhile to move the -  // shuffle after binary operation: +  // If both arguments of the binary operation are shuffles that use the same +  // mask and shuffle within a single vector, move the shuffle after the binop:    //   Op(shuffle(v1, m), shuffle(v2, m)) -> shuffle(Op(v1, v2), m) -  if (isa<ShuffleVectorInst>(LHS) && isa<ShuffleVectorInst>(RHS)) { -    ShuffleVectorInst *LShuf = cast<ShuffleVectorInst>(LHS); -    ShuffleVectorInst *RShuf = cast<ShuffleVectorInst>(RHS); -    if (isa<UndefValue>(LShuf->getOperand(1)) && -        isa<UndefValue>(RShuf->getOperand(1)) && -        LShuf->getOperand(0)->getType() == RShuf->getOperand(0)->getType() && -        LShuf->getMask() == RShuf->getMask()) { -      Value *NewBO = CreateBinOpAsGiven(Inst, LShuf->getOperand(0), -          RShuf->getOperand(0), Builder); -      return Builder->CreateShuffleVector(NewBO, -          UndefValue::get(NewBO->getType()), LShuf->getMask()); -    } +  auto *LShuf = dyn_cast<ShuffleVectorInst>(LHS); +  auto *RShuf = dyn_cast<ShuffleVectorInst>(RHS); +  if (LShuf && RShuf && LShuf->getMask() == RShuf->getMask() && +      isa<UndefValue>(LShuf->getOperand(1)) && +      isa<UndefValue>(RShuf->getOperand(1)) && +      LShuf->getOperand(0)->getType() == RShuf->getOperand(0)->getType()) { +    Value *NewBO = CreateBinOpAsGiven(Inst, LShuf->getOperand(0), +                                      RShuf->getOperand(0), Builder); +    return Builder->CreateShuffleVector( +     
   NewBO, UndefValue::get(NewBO->getType()), LShuf->getMask());    }    // If one argument is a shuffle within one vector, the other is a constant, @@ -1559,27 +1594,21 @@ Instruction *InstCombiner::visitGetElementPtrInst(GetElementPtrInst &GEP) {        // Replace: gep (gep %P, long B), long A, ...        // With:    T = long A+B; gep %P, T, ...        // -      Value *Sum;        Value *SO1 = Src->getOperand(Src->getNumOperands()-1);        Value *GO1 = GEP.getOperand(1); -      if (SO1 == Constant::getNullValue(SO1->getType())) { -        Sum = GO1; -      } else if (GO1 == Constant::getNullValue(GO1->getType())) { -        Sum = SO1; -      } else { -        // If they aren't the same type, then the input hasn't been processed -        // by the loop above yet (which canonicalizes sequential index types to -        // intptr_t).  Just avoid transforming this until the input has been -        // normalized. -        if (SO1->getType() != GO1->getType()) -          return nullptr; -        // Only do the combine when GO1 and SO1 are both constants. Only in -        // this case, we are sure the cost after the merge is never more than -        // that before the merge. -        if (!isa<Constant>(GO1) || !isa<Constant>(SO1)) -          return nullptr; -        Sum = Builder->CreateAdd(SO1, GO1, PtrOp->getName()+".sum"); -      } + +      // If they aren't the same type, then the input hasn't been processed +      // by the loop above yet (which canonicalizes sequential index types to +      // intptr_t).  Just avoid transforming this until the input has been +      // normalized. +      if (SO1->getType() != GO1->getType()) +        return nullptr; + +      Value* Sum = SimplifyAddInst(GO1, SO1, false, false, DL, &TLI, &DT, &AC); +      // Only do the combine when we are sure the cost after the +      // merge is never more than that before the merge. +      if (Sum == nullptr) +        return nullptr;        // Update the GEP in place if possible.        if (Src->getNumOperands() == 2) { @@ -1654,14 +1683,14 @@ Instruction *InstCombiner::visitGetElementPtrInst(GetElementPtrInst &GEP) {      }    } -  // Handle gep(bitcast x) and gep(gep x, 0, 0, 0). -  Value *StrippedPtr = PtrOp->stripPointerCasts(); -  PointerType *StrippedPtrTy = dyn_cast<PointerType>(StrippedPtr->getType()); -    // We do not handle pointer-vector geps here. -  if (!StrippedPtrTy) +  if (GEP.getType()->isVectorTy())      return nullptr; +  // Handle gep(bitcast x) and gep(gep x, 0, 0, 0). +  Value *StrippedPtr = PtrOp->stripPointerCasts(); +  PointerType *StrippedPtrTy = cast<PointerType>(StrippedPtr->getType()); +    if (StrippedPtr != PtrOp) {      bool HasZeroPointerIndex = false;      if (ConstantInt *C = dyn_cast<ConstantInt>(GEP.getOperand(1))) @@ -2239,11 +2268,11 @@ Instruction *InstCombiner::visitSwitchInst(SwitchInst &SI) {    ConstantInt *AddRHS;    if (match(Cond, m_Add(m_Value(Op0), m_ConstantInt(AddRHS)))) {      // Change 'switch (X+4) case 1:' into 'switch (X) case -3'. 
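A hedged, source-level sketch of the rewrite described in the comment above (illustrative only; the function names and case values are invented). Folding the constant add into each case value leaves the two switches equivalent, which is what the loop below does with ConstantExpr::getSub:

// Illustrative sketch, not part of the patch: 'switch (x + 4)' with cases
// 1 and 9 becomes 'switch (x)' with cases -3 and 5.
int classify_before(int x) {
  switch (x + 4) {
  case 1:  return 10;   // taken when x == -3
  case 9:  return 20;   // taken when x == 5
  default: return 0;
  }
}
int classify_after(int x) {
  switch (x) {
  case -3: return 10;
  case 5:  return 20;
  default: return 0;
  }
}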
-    for (SwitchInst::CaseIt CaseIter : SI.cases()) { -      Constant *NewCase = ConstantExpr::getSub(CaseIter.getCaseValue(), AddRHS); +    for (auto Case : SI.cases()) { +      Constant *NewCase = ConstantExpr::getSub(Case.getCaseValue(), AddRHS);        assert(isa<ConstantInt>(NewCase) &&               "Result of expression should be constant"); -      CaseIter.setValue(cast<ConstantInt>(NewCase)); +      Case.setValue(cast<ConstantInt>(NewCase));      }      SI.setCondition(Op0);      return &SI; @@ -2275,9 +2304,9 @@ Instruction *InstCombiner::visitSwitchInst(SwitchInst &SI) {      Value *NewCond = Builder->CreateTrunc(Cond, Ty, "trunc");      SI.setCondition(NewCond); -    for (SwitchInst::CaseIt CaseIter : SI.cases()) { -      APInt TruncatedCase = CaseIter.getCaseValue()->getValue().trunc(NewWidth); -      CaseIter.setValue(ConstantInt::get(SI.getContext(), TruncatedCase)); +    for (auto Case : SI.cases()) { +      APInt TruncatedCase = Case.getCaseValue()->getValue().trunc(NewWidth); +      Case.setValue(ConstantInt::get(SI.getContext(), TruncatedCase));      }      return &SI;    } @@ -2934,8 +2963,8 @@ bool InstCombiner::run() {          Result->takeName(I);          // Push the new instruction and any users onto the worklist. -        Worklist.Add(Result);          Worklist.AddUsersToWorkList(*Result); +        Worklist.Add(Result);          // Insert the new instruction into the basic block...          BasicBlock *InstParent = I->getParent(); @@ -2958,8 +2987,8 @@ bool InstCombiner::run() {          if (isInstructionTriviallyDead(I, &TLI)) {            eraseInstFromFunction(*I);          } else { -          Worklist.Add(I);            Worklist.AddUsersToWorkList(*I); +          Worklist.Add(I);          }        }        MadeIRChange = true; @@ -3022,12 +3051,11 @@ static bool AddReachableCodeToWorklist(BasicBlock *BB, const DataLayout &DL,          }        // See if we can constant fold its operands. -      for (User::op_iterator i = Inst->op_begin(), e = Inst->op_end(); i != e; -           ++i) { -        if (!isa<ConstantVector>(i) && !isa<ConstantExpr>(i)) +      for (Use &U : Inst->operands()) { +        if (!isa<ConstantVector>(U) && !isa<ConstantExpr>(U))            continue; -        auto *C = cast<Constant>(i); +        auto *C = cast<Constant>(U);          Constant *&FoldRes = FoldedConstants[C];          if (!FoldRes)            FoldRes = ConstantFoldConstant(C, DL, TLI); @@ -3035,7 +3063,10 @@ static bool AddReachableCodeToWorklist(BasicBlock *BB, const DataLayout &DL,            FoldRes = C;          if (FoldRes != C) { -          *i = FoldRes; +          DEBUG(dbgs() << "IC: ConstFold operand of: " << *Inst +                       << "\n    Old = " << *C +                       << "\n    New = " << *FoldRes << '\n'); +          U = FoldRes;            MadeIRChange = true;          }        } @@ -3055,17 +3086,7 @@ static bool AddReachableCodeToWorklist(BasicBlock *BB, const DataLayout &DL,        }      } else if (SwitchInst *SI = dyn_cast<SwitchInst>(TI)) {        if (ConstantInt *Cond = dyn_cast<ConstantInt>(SI->getCondition())) { -        // See if this is an explicit destination. -        for (SwitchInst::CaseIt i = SI->case_begin(), e = SI->case_end(); -             i != e; ++i) -          if (i.getCaseValue() == Cond) { -            BasicBlock *ReachableBB = i.getCaseSuccessor(); -            Worklist.push_back(ReachableBB); -            continue; -          } - -        // Otherwise it is the default destination. 
-        Worklist.push_back(SI->getDefaultDest()); +        Worklist.push_back(SI->findCaseValue(Cond)->getCaseSuccessor());          continue;        }      } @@ -3152,6 +3173,7 @@ combineInstructionsOverFunction(Function &F, InstCombineWorklist &Worklist,      InstCombiner IC(Worklist, &Builder, F.optForMinSize(), ExpensiveCombines,                      AA, AC, TLI, DT, DL, LI); +    IC.MaxArraySizeForCombine = MaxArraySize;      Changed |= IC.run();      if (!Changed) @@ -3176,9 +3198,10 @@ PreservedAnalyses InstCombinePass::run(Function &F,      return PreservedAnalyses::all();    // Mark all the analyses that instcombine updates as preserved. -  // FIXME: This should also 'preserve the CFG'.    PreservedAnalyses PA; -  PA.preserve<DominatorTreeAnalysis>(); +  PA.preserveSet<CFGAnalyses>(); +  PA.preserve<AAManager>(); +  PA.preserve<GlobalsAA>();    return PA;  } diff --git a/lib/Transforms/Instrumentation/AddressSanitizer.cpp b/lib/Transforms/Instrumentation/AddressSanitizer.cpp index f5e9e7dd5a93..94cfc69ed555 100644 --- a/lib/Transforms/Instrumentation/AddressSanitizer.cpp +++ b/lib/Transforms/Instrumentation/AddressSanitizer.cpp @@ -80,6 +80,7 @@ static const uint64_t kMIPS64_ShadowOffset64 = 1ULL << 37;  static const uint64_t kAArch64_ShadowOffset64 = 1ULL << 36;  static const uint64_t kFreeBSD_ShadowOffset32 = 1ULL << 30;  static const uint64_t kFreeBSD_ShadowOffset64 = 1ULL << 46; +static const uint64_t kPS4CPU_ShadowOffset64 = 1ULL << 40;  static const uint64_t kWindowsShadowOffset32 = 3ULL << 28;  // The shadow memory space is dynamically allocated.  static const uint64_t kWindowsShadowOffset64 = kDynamicShadowSentinel; @@ -380,6 +381,7 @@ static ShadowMapping getShadowMapping(Triple &TargetTriple, int LongSize,    bool IsAndroid = TargetTriple.isAndroid();    bool IsIOS = TargetTriple.isiOS() || TargetTriple.isWatchOS();    bool IsFreeBSD = TargetTriple.isOSFreeBSD(); +  bool IsPS4CPU = TargetTriple.isPS4CPU();    bool IsLinux = TargetTriple.isOSLinux();    bool IsPPC64 = TargetTriple.getArch() == llvm::Triple::ppc64 ||                   TargetTriple.getArch() == llvm::Triple::ppc64le; @@ -392,6 +394,7 @@ static ShadowMapping getShadowMapping(Triple &TargetTriple, int LongSize,                    TargetTriple.getArch() == llvm::Triple::mips64el;    bool IsAArch64 = TargetTriple.getArch() == llvm::Triple::aarch64;    bool IsWindows = TargetTriple.isOSWindows(); +  bool IsFuchsia = TargetTriple.isOSFuchsia();    ShadowMapping Mapping; @@ -412,12 +415,18 @@ static ShadowMapping getShadowMapping(Triple &TargetTriple, int LongSize,      else        Mapping.Offset = kDefaultShadowOffset32;    } else {  // LongSize == 64 -    if (IsPPC64) +    // Fuchsia is always PIE, which means that the beginning of the address +    // space is always available. +    if (IsFuchsia) +      Mapping.Offset = 0; +    else if (IsPPC64)        Mapping.Offset = kPPC64_ShadowOffset64;      else if (IsSystemZ)        Mapping.Offset = kSystemZ_ShadowOffset64;      else if (IsFreeBSD)        Mapping.Offset = kFreeBSD_ShadowOffset64; +    else if (IsPS4CPU) +      Mapping.Offset = kPS4CPU_ShadowOffset64;      else if (IsLinux && IsX86_64) {        if (IsKasan)          Mapping.Offset = kLinuxKasan_ShadowOffset64; @@ -456,9 +465,9 @@ static ShadowMapping getShadowMapping(Triple &TargetTriple, int LongSize,    // offset is not necessary 1/8-th of the address space.  On SystemZ,    // we could OR the constant in a single instruction, but it's more    // efficient to load it once and use indexed addressing. 
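For orientation, a hedged sketch (not from the patch) of the mapping this code configures: ASan computes the shadow address as (Addr >> Scale) + Offset, and OrShadowOffset records the cases where the add can be emitted as a cheaper OR because Offset is a lone power of two that cannot overlap the shifted address bits:

// Hedged sketch; Scale = 3 is the default (8 application bytes per shadow byte).
#include <cstdint>
static uint64_t shadowAddress(uint64_t Addr, uint64_t Offset, unsigned Scale = 3) {
  // When Offset has a single set bit above the shifted address range, the '+'
  // can be lowered as '|' -- the property OrShadowOffset tracks.
  return (Addr >> Scale) + Offset;
}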
-  Mapping.OrShadowOffset = !IsAArch64 && !IsPPC64 && !IsSystemZ -                           && !(Mapping.Offset & (Mapping.Offset - 1)) -                           && Mapping.Offset != kDynamicShadowSentinel; +  Mapping.OrShadowOffset = !IsAArch64 && !IsPPC64 && !IsSystemZ && !IsPS4CPU && +                           !(Mapping.Offset & (Mapping.Offset - 1)) && +                           Mapping.Offset != kDynamicShadowSentinel;    return Mapping;  } @@ -567,8 +576,6 @@ struct AddressSanitizer : public FunctionPass {    Type *IntptrTy;    ShadowMapping Mapping;    DominatorTree *DT; -  Function *AsanCtorFunction = nullptr; -  Function *AsanInitFunction = nullptr;    Function *AsanHandleNoReturnFunc;    Function *AsanPtrCmpFunction, *AsanPtrSubFunction;    // This array is indexed by AccessIsWrite, Experiment and log2(AccessSize). @@ -1561,31 +1568,31 @@ void AddressSanitizerModule::initializeCallbacks(Module &M) {    // Declare our poisoning and unpoisoning functions.    AsanPoisonGlobals = checkSanitizerInterfaceFunction(M.getOrInsertFunction( -      kAsanPoisonGlobalsName, IRB.getVoidTy(), IntptrTy, nullptr)); +      kAsanPoisonGlobalsName, IRB.getVoidTy(), IntptrTy));    AsanPoisonGlobals->setLinkage(Function::ExternalLinkage);    AsanUnpoisonGlobals = checkSanitizerInterfaceFunction(M.getOrInsertFunction( -      kAsanUnpoisonGlobalsName, IRB.getVoidTy(), nullptr)); +      kAsanUnpoisonGlobalsName, IRB.getVoidTy()));    AsanUnpoisonGlobals->setLinkage(Function::ExternalLinkage);    // Declare functions that register/unregister globals.    AsanRegisterGlobals = checkSanitizerInterfaceFunction(M.getOrInsertFunction( -      kAsanRegisterGlobalsName, IRB.getVoidTy(), IntptrTy, IntptrTy, nullptr)); +      kAsanRegisterGlobalsName, IRB.getVoidTy(), IntptrTy, IntptrTy));    AsanRegisterGlobals->setLinkage(Function::ExternalLinkage);    AsanUnregisterGlobals = checkSanitizerInterfaceFunction(        M.getOrInsertFunction(kAsanUnregisterGlobalsName, IRB.getVoidTy(), -                            IntptrTy, IntptrTy, nullptr)); +                            IntptrTy, IntptrTy));    AsanUnregisterGlobals->setLinkage(Function::ExternalLinkage);    // Declare the functions that find globals in a shared object and then invoke    // the (un)register function on them.    AsanRegisterImageGlobals =        checkSanitizerInterfaceFunction(M.getOrInsertFunction( -          kAsanRegisterImageGlobalsName, IRB.getVoidTy(), IntptrTy, nullptr)); +          kAsanRegisterImageGlobalsName, IRB.getVoidTy(), IntptrTy));    AsanRegisterImageGlobals->setLinkage(Function::ExternalLinkage);    AsanUnregisterImageGlobals =        checkSanitizerInterfaceFunction(M.getOrInsertFunction( -          kAsanUnregisterImageGlobalsName, IRB.getVoidTy(), IntptrTy, nullptr)); +          kAsanUnregisterImageGlobalsName, IRB.getVoidTy(), IntptrTy));    AsanUnregisterImageGlobals->setLinkage(Function::ExternalLinkage);  } @@ -1618,11 +1625,12 @@ void AddressSanitizerModule::SetComdatForGlobalMetadata(  GlobalVariable *  AddressSanitizerModule::CreateMetadataGlobal(Module &M, Constant *Initializer,                                               StringRef OriginalName) { -  GlobalVariable *Metadata = -      new GlobalVariable(M, Initializer->getType(), false, -                         GlobalVariable::InternalLinkage, Initializer, -                         Twine("__asan_global_") + -                             GlobalValue::getRealLinkageName(OriginalName)); +  auto Linkage = TargetTriple.isOSBinFormatMachO() +                     ? 
GlobalVariable::InternalLinkage +                     : GlobalVariable::PrivateLinkage; +  GlobalVariable *Metadata = new GlobalVariable( +      M, Initializer->getType(), false, Linkage, Initializer, +      Twine("__asan_global_") + GlobalValue::getRealLinkageName(OriginalName));    Metadata->setSection(getGlobalMetadataSection());    return Metadata;  } @@ -1862,7 +1870,8 @@ bool AddressSanitizerModule::InstrumentGlobals(IRBuilder<> &IRB, Module &M) {      GlobalValue *InstrumentedGlobal = NewGlobal;      bool CanUsePrivateAliases = -        TargetTriple.isOSBinFormatELF() || TargetTriple.isOSBinFormatMachO(); +        TargetTriple.isOSBinFormatELF() || TargetTriple.isOSBinFormatMachO() || +        TargetTriple.isOSBinFormatWasm();      if (CanUsePrivateAliases && ClUsePrivateAliasForGlobals) {        // Create local alias for NewGlobal to avoid crash on ODR between        // instrumented and non-instrumented libraries. @@ -1926,13 +1935,19 @@ bool AddressSanitizerModule::runOnModule(Module &M) {    Mapping = getShadowMapping(TargetTriple, LongSize, CompileKernel);    initializeCallbacks(M); -  bool Changed = false; +  if (CompileKernel) +    return false; + +  Function *AsanCtorFunction; +  std::tie(AsanCtorFunction, std::ignore) = createSanitizerCtorAndInitFunctions( +      M, kAsanModuleCtorName, kAsanInitName, /*InitArgTypes=*/{}, +      /*InitArgs=*/{}, kAsanVersionCheckName); +  appendToGlobalCtors(M, AsanCtorFunction, kAsanCtorAndDtorPriority); +  bool Changed = false;    // TODO(glider): temporarily disabled globals instrumentation for KASan. -  if (ClGlobals && !CompileKernel) { -    Function *CtorFunc = M.getFunction(kAsanModuleCtorName); -    assert(CtorFunc); -    IRBuilder<> IRB(CtorFunc->getEntryBlock().getTerminator()); +  if (ClGlobals) { +    IRBuilder<> IRB(AsanCtorFunction->getEntryBlock().getTerminator());      Changed |= InstrumentGlobals(IRB, M);    } @@ -1949,49 +1964,60 @@ void AddressSanitizer::initializeCallbacks(Module &M) {        const std::string ExpStr = Exp ? "exp_" : "";        const std::string SuffixStr = CompileKernel ? "N" : "_n";        const std::string EndingStr = Recover ? "_noabort" : ""; -      Type *ExpType = Exp ? 
Type::getInt32Ty(*C) : nullptr; -      AsanErrorCallbackSized[AccessIsWrite][Exp] = -          checkSanitizerInterfaceFunction(M.getOrInsertFunction( -              kAsanReportErrorTemplate + ExpStr + TypeStr + SuffixStr + EndingStr, -              IRB.getVoidTy(), IntptrTy, IntptrTy, ExpType, nullptr)); -      AsanMemoryAccessCallbackSized[AccessIsWrite][Exp] = -          checkSanitizerInterfaceFunction(M.getOrInsertFunction( -              ClMemoryAccessCallbackPrefix + ExpStr + TypeStr + "N" + EndingStr, -              IRB.getVoidTy(), IntptrTy, IntptrTy, ExpType, nullptr)); -      for (size_t AccessSizeIndex = 0; AccessSizeIndex < kNumberOfAccessSizes; -           AccessSizeIndex++) { -        const std::string Suffix = TypeStr + itostr(1ULL << AccessSizeIndex); -        AsanErrorCallback[AccessIsWrite][Exp][AccessSizeIndex] = -            checkSanitizerInterfaceFunction(M.getOrInsertFunction( -                kAsanReportErrorTemplate + ExpStr + Suffix + EndingStr, -                IRB.getVoidTy(), IntptrTy, ExpType, nullptr)); -        AsanMemoryAccessCallback[AccessIsWrite][Exp][AccessSizeIndex] = -            checkSanitizerInterfaceFunction(M.getOrInsertFunction( -                ClMemoryAccessCallbackPrefix + ExpStr + Suffix + EndingStr, -                IRB.getVoidTy(), IntptrTy, ExpType, nullptr)); + +      SmallVector<Type *, 3> Args2 = {IntptrTy, IntptrTy}; +      SmallVector<Type *, 2> Args1{1, IntptrTy}; +      if (Exp) { +        Type *ExpType = Type::getInt32Ty(*C); +        Args2.push_back(ExpType); +        Args1.push_back(ExpType);        } -    } +	    AsanErrorCallbackSized[AccessIsWrite][Exp] = +	        checkSanitizerInterfaceFunction(M.getOrInsertFunction( +	            kAsanReportErrorTemplate + ExpStr + TypeStr + SuffixStr + +	                EndingStr, +	            FunctionType::get(IRB.getVoidTy(), Args2, false))); + +	    AsanMemoryAccessCallbackSized[AccessIsWrite][Exp] = +	        checkSanitizerInterfaceFunction(M.getOrInsertFunction( +	            ClMemoryAccessCallbackPrefix + ExpStr + TypeStr + "N" + EndingStr, +	            FunctionType::get(IRB.getVoidTy(), Args2, false))); + +	    for (size_t AccessSizeIndex = 0; AccessSizeIndex < kNumberOfAccessSizes; +	         AccessSizeIndex++) { +	      const std::string Suffix = TypeStr + itostr(1ULL << AccessSizeIndex); +	      AsanErrorCallback[AccessIsWrite][Exp][AccessSizeIndex] = +	          checkSanitizerInterfaceFunction(M.getOrInsertFunction( +	              kAsanReportErrorTemplate + ExpStr + Suffix + EndingStr, +	              FunctionType::get(IRB.getVoidTy(), Args1, false))); + +	      AsanMemoryAccessCallback[AccessIsWrite][Exp][AccessSizeIndex] = +	          checkSanitizerInterfaceFunction(M.getOrInsertFunction( +	              ClMemoryAccessCallbackPrefix + ExpStr + Suffix + EndingStr, +	              FunctionType::get(IRB.getVoidTy(), Args1, false))); +	    } +	  }    }    const std::string MemIntrinCallbackPrefix =        CompileKernel ? 
std::string("") : ClMemoryAccessCallbackPrefix;    AsanMemmove = checkSanitizerInterfaceFunction(M.getOrInsertFunction(        MemIntrinCallbackPrefix + "memmove", IRB.getInt8PtrTy(), -      IRB.getInt8PtrTy(), IRB.getInt8PtrTy(), IntptrTy, nullptr)); +      IRB.getInt8PtrTy(), IRB.getInt8PtrTy(), IntptrTy));    AsanMemcpy = checkSanitizerInterfaceFunction(M.getOrInsertFunction(        MemIntrinCallbackPrefix + "memcpy", IRB.getInt8PtrTy(), -      IRB.getInt8PtrTy(), IRB.getInt8PtrTy(), IntptrTy, nullptr)); +      IRB.getInt8PtrTy(), IRB.getInt8PtrTy(), IntptrTy));    AsanMemset = checkSanitizerInterfaceFunction(M.getOrInsertFunction(        MemIntrinCallbackPrefix + "memset", IRB.getInt8PtrTy(), -      IRB.getInt8PtrTy(), IRB.getInt32Ty(), IntptrTy, nullptr)); +      IRB.getInt8PtrTy(), IRB.getInt32Ty(), IntptrTy));    AsanHandleNoReturnFunc = checkSanitizerInterfaceFunction( -      M.getOrInsertFunction(kAsanHandleNoReturnName, IRB.getVoidTy(), nullptr)); +      M.getOrInsertFunction(kAsanHandleNoReturnName, IRB.getVoidTy()));    AsanPtrCmpFunction = checkSanitizerInterfaceFunction(M.getOrInsertFunction( -      kAsanPtrCmp, IRB.getVoidTy(), IntptrTy, IntptrTy, nullptr)); +      kAsanPtrCmp, IRB.getVoidTy(), IntptrTy, IntptrTy));    AsanPtrSubFunction = checkSanitizerInterfaceFunction(M.getOrInsertFunction( -      kAsanPtrSub, IRB.getVoidTy(), IntptrTy, IntptrTy, nullptr)); +      kAsanPtrSub, IRB.getVoidTy(), IntptrTy, IntptrTy));    // We insert an empty inline asm after __asan_report* to avoid callback merge.    EmptyAsm = InlineAsm::get(FunctionType::get(IRB.getVoidTy(), false),                              StringRef(""), StringRef(""), @@ -2001,7 +2027,6 @@ void AddressSanitizer::initializeCallbacks(Module &M) {  // virtual  bool AddressSanitizer::doInitialization(Module &M) {    // Initialize the private fields. No one has accessed them before. -    GlobalsMD.init(M);    C = &(M.getContext()); @@ -2009,13 +2034,6 @@ bool AddressSanitizer::doInitialization(Module &M) {    IntptrTy = Type::getIntNTy(*C, LongSize);    TargetTriple = Triple(M.getTargetTriple()); -  if (!CompileKernel) { -    std::tie(AsanCtorFunction, AsanInitFunction) = -        createSanitizerCtorAndInitFunctions( -            M, kAsanModuleCtorName, kAsanInitName, -            /*InitArgTypes=*/{}, /*InitArgs=*/{}, kAsanVersionCheckName); -    appendToGlobalCtors(M, AsanCtorFunction, kAsanCtorAndDtorPriority); -  }    Mapping = getShadowMapping(TargetTriple, LongSize, CompileKernel);    return true;  } @@ -2034,6 +2052,8 @@ bool AddressSanitizer::maybeInsertAsanInitAtFunctionEntry(Function &F) {    // We cannot just ignore these methods, because they may call other    // instrumented functions.    
if (F.getName().find(" load]") != std::string::npos) { +    Function *AsanInitFunction = +        declareSanitizerInitFunction(*F.getParent(), kAsanInitName, {});      IRBuilder<> IRB(&F.front(), F.front().begin());      IRB.CreateCall(AsanInitFunction, {});      return true; @@ -2081,7 +2101,6 @@ void AddressSanitizer::markEscapedLocalAllocas(Function &F) {  }  bool AddressSanitizer::runOnFunction(Function &F) { -  if (&F == AsanCtorFunction) return false;    if (F.getLinkage() == GlobalValue::AvailableExternallyLinkage) return false;    if (!ClDebugFunc.empty() && ClDebugFunc == F.getName()) return false;    if (F.getName().startswith("__asan_")) return false; @@ -2175,8 +2194,9 @@ bool AddressSanitizer::runOnFunction(Function &F) {        (ClInstrumentationWithCallsThreshold >= 0 &&         ToInstrument.size() > (unsigned)ClInstrumentationWithCallsThreshold);    const DataLayout &DL = F.getParent()->getDataLayout(); -  ObjectSizeOffsetVisitor ObjSizeVis(DL, TLI, F.getContext(), -                                     /*RoundToAlign=*/true); +  ObjectSizeOpts ObjSizeOpts; +  ObjSizeOpts.RoundToAlign = true; +  ObjectSizeOffsetVisitor ObjSizeVis(DL, TLI, F.getContext(), ObjSizeOpts);    // Instrument.    int NumInstrumented = 0; @@ -2234,18 +2254,18 @@ void FunctionStackPoisoner::initializeCallbacks(Module &M) {      std::string Suffix = itostr(i);      AsanStackMallocFunc[i] = checkSanitizerInterfaceFunction(          M.getOrInsertFunction(kAsanStackMallocNameTemplate + Suffix, IntptrTy, -                              IntptrTy, nullptr)); +                              IntptrTy));      AsanStackFreeFunc[i] = checkSanitizerInterfaceFunction(          M.getOrInsertFunction(kAsanStackFreeNameTemplate + Suffix, -                              IRB.getVoidTy(), IntptrTy, IntptrTy, nullptr)); +                              IRB.getVoidTy(), IntptrTy, IntptrTy));    }    if (ASan.UseAfterScope) {      AsanPoisonStackMemoryFunc = checkSanitizerInterfaceFunction(          M.getOrInsertFunction(kAsanPoisonStackMemoryName, IRB.getVoidTy(), -                              IntptrTy, IntptrTy, nullptr)); +                              IntptrTy, IntptrTy));      AsanUnpoisonStackMemoryFunc = checkSanitizerInterfaceFunction(          M.getOrInsertFunction(kAsanUnpoisonStackMemoryName, IRB.getVoidTy(), -                              IntptrTy, IntptrTy, nullptr)); +                              IntptrTy, IntptrTy));    }    for (size_t Val : {0x00, 0xf1, 0xf2, 0xf3, 0xf5, 0xf8}) { @@ -2254,14 +2274,14 @@ void FunctionStackPoisoner::initializeCallbacks(Module &M) {      Name << std::setw(2) << std::setfill('0') << std::hex << Val;      AsanSetShadowFunc[Val] =          checkSanitizerInterfaceFunction(M.getOrInsertFunction( -            Name.str(), IRB.getVoidTy(), IntptrTy, IntptrTy, nullptr)); +            Name.str(), IRB.getVoidTy(), IntptrTy, IntptrTy));    }    AsanAllocaPoisonFunc = checkSanitizerInterfaceFunction(M.getOrInsertFunction( -      kAsanAllocaPoison, IRB.getVoidTy(), IntptrTy, IntptrTy, nullptr)); +      kAsanAllocaPoison, IRB.getVoidTy(), IntptrTy, IntptrTy));    AsanAllocasUnpoisonFunc =        checkSanitizerInterfaceFunction(M.getOrInsertFunction( -          kAsanAllocasUnpoison, IRB.getVoidTy(), IntptrTy, IntptrTy, nullptr)); +          kAsanAllocasUnpoison, IRB.getVoidTy(), IntptrTy, IntptrTy));  }  void FunctionStackPoisoner::copyToShadowInline(ArrayRef<uint8_t> ShadowMask, diff --git a/lib/Transforms/Instrumentation/DataFlowSanitizer.cpp 
b/lib/Transforms/Instrumentation/DataFlowSanitizer.cpp index b34d5b8c45a7..4e454f0c95b6 100644 --- a/lib/Transforms/Instrumentation/DataFlowSanitizer.cpp +++ b/lib/Transforms/Instrumentation/DataFlowSanitizer.cpp @@ -254,7 +254,7 @@ class DataFlowSanitizer : public ModulePass {    MDNode *ColdCallWeights;    DFSanABIList ABIList;    DenseMap<Value *, Function *> UnwrappedFnMap; -  AttributeSet ReadOnlyNoneAttrs; +  AttributeList ReadOnlyNoneAttrs;    bool DFSanRuntimeShadowMask;    Value *getShadowAddress(Value *Addr, Instruction *Pos); @@ -331,6 +331,10 @@ class DFSanVisitor : public InstVisitor<DFSanVisitor> {    DFSanFunction &DFSF;    DFSanVisitor(DFSanFunction &DFSF) : DFSF(DFSF) {} +  const DataLayout &getDataLayout() const { +    return DFSF.F->getParent()->getDataLayout(); +  } +    void visitOperandShadowInst(Instruction &I);    void visitBinaryOperator(BinaryOperator &BO); @@ -539,16 +543,17 @@ DataFlowSanitizer::buildWrapperFunction(Function *F, StringRef NewFName,                                      F->getParent());    NewF->copyAttributesFrom(F);    NewF->removeAttributes( -    AttributeSet::ReturnIndex, -    AttributeSet::get(F->getContext(), AttributeSet::ReturnIndex, -                    AttributeFuncs::typeIncompatible(NewFT->getReturnType()))); +      AttributeList::ReturnIndex, +      AttributeList::get( +          F->getContext(), AttributeList::ReturnIndex, +          AttributeFuncs::typeIncompatible(NewFT->getReturnType())));    BasicBlock *BB = BasicBlock::Create(*Ctx, "entry", NewF);    if (F->isVarArg()) {      NewF->removeAttributes( -        AttributeSet::FunctionIndex, -        AttributeSet().addAttribute(*Ctx, AttributeSet::FunctionIndex, -                                    "split-stack")); +        AttributeList::FunctionIndex, +        AttributeList().addAttribute(*Ctx, AttributeList::FunctionIndex, +                                     "split-stack"));      CallInst::Create(DFSanVarargWrapperFn,                       IRBuilder<>(BB).CreateGlobalStringPtr(F->getName()), "",                       BB); @@ -580,8 +585,7 @@ Constant *DataFlowSanitizer::getOrBuildTrampolineFunction(FunctionType *FT,      Function::arg_iterator AI = F->arg_begin(); ++AI;      for (unsigned N = FT->getNumParams(); N != 0; ++AI, --N)        Args.push_back(&*AI); -    CallInst *CI = -        CallInst::Create(&F->getArgumentList().front(), Args, "", BB); +    CallInst *CI = CallInst::Create(&*F->arg_begin(), Args, "", BB);      ReturnInst *RI;      if (FT->getReturnType()->isVoidTy())        RI = ReturnInst::Create(*Ctx, BB); @@ -595,7 +599,7 @@ Constant *DataFlowSanitizer::getOrBuildTrampolineFunction(FunctionType *FT,      DFSanVisitor(DFSF).visitCallInst(*CI);      if (!FT->getReturnType()->isVoidTy())        new StoreInst(DFSF.getShadow(RI->getReturnValue()), -                    &F->getArgumentList().back(), RI); +                    &*std::prev(F->arg_end()), RI);    }    return C; @@ -622,26 +626,26 @@ bool DataFlowSanitizer::runOnModule(Module &M) {    DFSanUnionFn = Mod->getOrInsertFunction("__dfsan_union", DFSanUnionFnTy);    if (Function *F = dyn_cast<Function>(DFSanUnionFn)) { -    F->addAttribute(AttributeSet::FunctionIndex, Attribute::NoUnwind); -    F->addAttribute(AttributeSet::FunctionIndex, Attribute::ReadNone); -    F->addAttribute(AttributeSet::ReturnIndex, Attribute::ZExt); +    F->addAttribute(AttributeList::FunctionIndex, Attribute::NoUnwind); +    F->addAttribute(AttributeList::FunctionIndex, Attribute::ReadNone); +    
F->addAttribute(AttributeList::ReturnIndex, Attribute::ZExt);      F->addAttribute(1, Attribute::ZExt);      F->addAttribute(2, Attribute::ZExt);    }    DFSanCheckedUnionFn = Mod->getOrInsertFunction("dfsan_union", DFSanUnionFnTy);    if (Function *F = dyn_cast<Function>(DFSanCheckedUnionFn)) { -    F->addAttribute(AttributeSet::FunctionIndex, Attribute::NoUnwind); -    F->addAttribute(AttributeSet::FunctionIndex, Attribute::ReadNone); -    F->addAttribute(AttributeSet::ReturnIndex, Attribute::ZExt); +    F->addAttribute(AttributeList::FunctionIndex, Attribute::NoUnwind); +    F->addAttribute(AttributeList::FunctionIndex, Attribute::ReadNone); +    F->addAttribute(AttributeList::ReturnIndex, Attribute::ZExt);      F->addAttribute(1, Attribute::ZExt);      F->addAttribute(2, Attribute::ZExt);    }    DFSanUnionLoadFn =        Mod->getOrInsertFunction("__dfsan_union_load", DFSanUnionLoadFnTy);    if (Function *F = dyn_cast<Function>(DFSanUnionLoadFn)) { -    F->addAttribute(AttributeSet::FunctionIndex, Attribute::NoUnwind); -    F->addAttribute(AttributeSet::FunctionIndex, Attribute::ReadOnly); -    F->addAttribute(AttributeSet::ReturnIndex, Attribute::ZExt); +    F->addAttribute(AttributeList::FunctionIndex, Attribute::NoUnwind); +    F->addAttribute(AttributeList::FunctionIndex, Attribute::ReadOnly); +    F->addAttribute(AttributeList::ReturnIndex, Attribute::ZExt);    }    DFSanUnimplementedFn =        Mod->getOrInsertFunction("__dfsan_unimplemented", DFSanUnimplementedFnTy); @@ -696,7 +700,7 @@ bool DataFlowSanitizer::runOnModule(Module &M) {    AttrBuilder B;    B.addAttribute(Attribute::ReadOnly).addAttribute(Attribute::ReadNone); -  ReadOnlyNoneAttrs = AttributeSet::get(*Ctx, AttributeSet::FunctionIndex, B); +  ReadOnlyNoneAttrs = AttributeList::get(*Ctx, AttributeList::FunctionIndex, B);    // First, change the ABI of every function in the module.  ABI-listed    // functions keep their original ABI and get a wrapper function. 
@@ -717,9 +721,10 @@ bool DataFlowSanitizer::runOnModule(Module &M) {          Function *NewF = Function::Create(NewFT, F.getLinkage(), "", &M);          NewF->copyAttributesFrom(&F);          NewF->removeAttributes( -          AttributeSet::ReturnIndex, -          AttributeSet::get(NewF->getContext(), AttributeSet::ReturnIndex, -                    AttributeFuncs::typeIncompatible(NewFT->getReturnType()))); +            AttributeList::ReturnIndex, +            AttributeList::get( +                NewF->getContext(), AttributeList::ReturnIndex, +                AttributeFuncs::typeIncompatible(NewFT->getReturnType())));          for (Function::arg_iterator FArg = F.arg_begin(),                                      NewFArg = NewF->arg_begin(),                                      FArgEnd = F.arg_end(); @@ -758,7 +763,7 @@ bool DataFlowSanitizer::runOnModule(Module &M) {            &F, std::string("dfsw$") + std::string(F.getName()),            GlobalValue::LinkOnceODRLinkage, NewFT);        if (getInstrumentedABI() == IA_TLS) -        NewF->removeAttributes(AttributeSet::FunctionIndex, ReadOnlyNoneAttrs); +        NewF->removeAttributes(AttributeList::FunctionIndex, ReadOnlyNoneAttrs);        Value *WrappedFnCst =            ConstantExpr::getBitCast(NewF, PointerType::getUnqual(FT)); @@ -906,7 +911,7 @@ Value *DFSanFunction::getShadow(Value *V) {          break;        }        case DataFlowSanitizer::IA_Args: { -        unsigned ArgIdx = A->getArgNo() + F->getArgumentList().size() / 2; +        unsigned ArgIdx = A->getArgNo() + F->arg_size() / 2;          Function::arg_iterator i = F->arg_begin();          while (ArgIdx--)            ++i; @@ -983,7 +988,7 @@ Value *DFSanFunction::combineShadows(Value *V1, Value *V2, Instruction *Pos) {    IRBuilder<> IRB(Pos);    if (AvoidNewBlocks) {      CallInst *Call = IRB.CreateCall(DFS.DFSanCheckedUnionFn, {V1, V2}); -    Call->addAttribute(AttributeSet::ReturnIndex, Attribute::ZExt); +    Call->addAttribute(AttributeList::ReturnIndex, Attribute::ZExt);      Call->addAttribute(1, Attribute::ZExt);      Call->addAttribute(2, Attribute::ZExt); @@ -996,7 +1001,7 @@ Value *DFSanFunction::combineShadows(Value *V1, Value *V2, Instruction *Pos) {          Ne, Pos, /*Unreachable=*/false, DFS.ColdCallWeights, &DT));      IRBuilder<> ThenIRB(BI);      CallInst *Call = ThenIRB.CreateCall(DFS.DFSanUnionFn, {V1, V2}); -    Call->addAttribute(AttributeSet::ReturnIndex, Attribute::ZExt); +    Call->addAttribute(AttributeList::ReturnIndex, Attribute::ZExt);      Call->addAttribute(1, Attribute::ZExt);      Call->addAttribute(2, Attribute::ZExt); @@ -1099,7 +1104,7 @@ Value *DFSanFunction::loadShadow(Value *Addr, uint64_t Size, uint64_t Align,      CallInst *FallbackCall = FallbackIRB.CreateCall(          DFS.DFSanUnionLoadFn,          {ShadowAddr, ConstantInt::get(DFS.IntptrTy, Size)}); -    FallbackCall->addAttribute(AttributeSet::ReturnIndex, Attribute::ZExt); +    FallbackCall->addAttribute(AttributeList::ReturnIndex, Attribute::ZExt);      // Compare each of the shadows stored in the loaded 64 bits to each other,      // by computing (WideShadow rotl ShadowWidth) == WideShadow. 
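The DataFlowSanitizer changes above are part of the mechanical AttributeSet-to-AttributeList rename applied throughout this patch; the well-known indices keep their meaning. A hedged usage sketch (markRuntimeHelper is an invented name):

// Hedged sketch, not part of the patch.
#include "llvm/IR/Attributes.h"
#include "llvm/IR/Function.h"
using namespace llvm;

static void markRuntimeHelper(Function *F) {
  F->addAttribute(AttributeList::FunctionIndex, Attribute::NoUnwind);
  F->addAttribute(AttributeList::ReturnIndex, Attribute::ZExt);
  F->addAttribute(1, Attribute::ZExt); // index 1 == first parameter
}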
@@ -1156,7 +1161,7 @@ Value *DFSanFunction::loadShadow(Value *Addr, uint64_t Size, uint64_t Align,    IRBuilder<> IRB(Pos);    CallInst *FallbackCall = IRB.CreateCall(        DFS.DFSanUnionLoadFn, {ShadowAddr, ConstantInt::get(DFS.IntptrTy, Size)}); -  FallbackCall->addAttribute(AttributeSet::ReturnIndex, Attribute::ZExt); +  FallbackCall->addAttribute(AttributeList::ReturnIndex, Attribute::ZExt);    return FallbackCall;  } @@ -1446,7 +1451,7 @@ void DFSanVisitor::visitCallSite(CallSite CS) {            // Custom functions returning non-void will write to the return label.            if (!FT->getReturnType()->isVoidTy()) { -            CustomFn->removeAttributes(AttributeSet::FunctionIndex, +            CustomFn->removeAttributes(AttributeList::FunctionIndex,                                         DFSF.DFS.ReadOnlyNoneAttrs);            }          } @@ -1481,7 +1486,8 @@ void DFSanVisitor::visitCallSite(CallSite CS) {            auto *LabelVATy = ArrayType::get(DFSF.DFS.ShadowTy,                                             CS.arg_size() - FT->getNumParams());            auto *LabelVAAlloca = new AllocaInst( -              LabelVATy, "labelva", &DFSF.F->getEntryBlock().front()); +              LabelVATy, getDataLayout().getAllocaAddrSpace(), +              "labelva", &DFSF.F->getEntryBlock().front());            for (unsigned n = 0; i != CS.arg_end(); ++i, ++n) {              auto LabelVAPtr = IRB.CreateStructGEP(LabelVATy, LabelVAAlloca, n); @@ -1494,8 +1500,9 @@ void DFSanVisitor::visitCallSite(CallSite CS) {          if (!FT->getReturnType()->isVoidTy()) {            if (!DFSF.LabelReturnAlloca) {              DFSF.LabelReturnAlloca = -                new AllocaInst(DFSF.DFS.ShadowTy, "labelreturn", -                               &DFSF.F->getEntryBlock().front()); +              new AllocaInst(DFSF.DFS.ShadowTy, +                             getDataLayout().getAllocaAddrSpace(), +                             "labelreturn", &DFSF.F->getEntryBlock().front());            }            Args.push_back(DFSF.LabelReturnAlloca);          } @@ -1574,7 +1581,8 @@ void DFSanVisitor::visitCallSite(CallSite CS) {        unsigned VarArgSize = CS.arg_size() - FT->getNumParams();        ArrayType *VarArgArrayTy = ArrayType::get(DFSF.DFS.ShadowTy, VarArgSize);        AllocaInst *VarArgShadow = -          new AllocaInst(VarArgArrayTy, "", &DFSF.F->getEntryBlock().front()); +        new AllocaInst(VarArgArrayTy, getDataLayout().getAllocaAddrSpace(), +                       "", &DFSF.F->getEntryBlock().front());        Args.push_back(IRB.CreateConstGEP2_32(VarArgArrayTy, VarArgShadow, 0, 0));        for (unsigned n = 0; i != e; ++i, ++n) {          IRB.CreateStore( @@ -1593,7 +1601,7 @@ void DFSanVisitor::visitCallSite(CallSite CS) {      }      NewCS.setCallingConv(CS.getCallingConv());      NewCS.setAttributes(CS.getAttributes().removeAttributes( -        *DFSF.DFS.Ctx, AttributeSet::ReturnIndex, +        *DFSF.DFS.Ctx, AttributeList::ReturnIndex,          AttributeFuncs::typeIncompatible(NewCS.getInstruction()->getType())));      if (Next) { diff --git a/lib/Transforms/Instrumentation/EfficiencySanitizer.cpp b/lib/Transforms/Instrumentation/EfficiencySanitizer.cpp index 05eba6c4dc69..7dea1dee756a 100644 --- a/lib/Transforms/Instrumentation/EfficiencySanitizer.cpp +++ b/lib/Transforms/Instrumentation/EfficiencySanitizer.cpp @@ -267,35 +267,35 @@ void EfficiencySanitizer::initializeCallbacks(Module &M) {      SmallString<32> AlignedLoadName("__esan_aligned_load" + ByteSizeStr);      EsanAlignedLoad[Idx] = 
         checkSanitizerInterfaceFunction(M.getOrInsertFunction( -            AlignedLoadName, IRB.getVoidTy(), IRB.getInt8PtrTy(), nullptr)); +            AlignedLoadName, IRB.getVoidTy(), IRB.getInt8PtrTy()));      SmallString<32> AlignedStoreName("__esan_aligned_store" + ByteSizeStr);      EsanAlignedStore[Idx] =          checkSanitizerInterfaceFunction(M.getOrInsertFunction( -            AlignedStoreName, IRB.getVoidTy(), IRB.getInt8PtrTy(), nullptr)); +            AlignedStoreName, IRB.getVoidTy(), IRB.getInt8PtrTy()));      SmallString<32> UnalignedLoadName("__esan_unaligned_load" + ByteSizeStr);      EsanUnalignedLoad[Idx] =          checkSanitizerInterfaceFunction(M.getOrInsertFunction( -            UnalignedLoadName, IRB.getVoidTy(), IRB.getInt8PtrTy(), nullptr)); +            UnalignedLoadName, IRB.getVoidTy(), IRB.getInt8PtrTy()));      SmallString<32> UnalignedStoreName("__esan_unaligned_store" + ByteSizeStr);      EsanUnalignedStore[Idx] =          checkSanitizerInterfaceFunction(M.getOrInsertFunction( -            UnalignedStoreName, IRB.getVoidTy(), IRB.getInt8PtrTy(), nullptr)); +            UnalignedStoreName, IRB.getVoidTy(), IRB.getInt8PtrTy()));    }    EsanUnalignedLoadN = checkSanitizerInterfaceFunction(        M.getOrInsertFunction("__esan_unaligned_loadN", IRB.getVoidTy(), -                            IRB.getInt8PtrTy(), IntptrTy, nullptr)); +                            IRB.getInt8PtrTy(), IntptrTy));    EsanUnalignedStoreN = checkSanitizerInterfaceFunction(        M.getOrInsertFunction("__esan_unaligned_storeN", IRB.getVoidTy(), -                            IRB.getInt8PtrTy(), IntptrTy, nullptr)); +                            IRB.getInt8PtrTy(), IntptrTy));    MemmoveFn = checkSanitizerInterfaceFunction(        M.getOrInsertFunction("memmove", IRB.getInt8PtrTy(), IRB.getInt8PtrTy(), -                            IRB.getInt8PtrTy(), IntptrTy, nullptr)); +                            IRB.getInt8PtrTy(), IntptrTy));    MemcpyFn = checkSanitizerInterfaceFunction(        M.getOrInsertFunction("memcpy", IRB.getInt8PtrTy(), IRB.getInt8PtrTy(), -                            IRB.getInt8PtrTy(), IntptrTy, nullptr)); +                            IRB.getInt8PtrTy(), IntptrTy));    MemsetFn = checkSanitizerInterfaceFunction(        M.getOrInsertFunction("memset", IRB.getInt8PtrTy(), IRB.getInt8PtrTy(), -                            IRB.getInt32Ty(), IntptrTy, nullptr)); +                            IRB.getInt32Ty(), IntptrTy));  }  bool EfficiencySanitizer::shouldIgnoreStructType(StructType *StructTy) { @@ -533,7 +533,7 @@ void EfficiencySanitizer::createDestructor(Module &M, Constant *ToolInfoArg) {    IRBuilder<> IRB_Dtor(EsanDtorFunction->getEntryBlock().getTerminator());    Function *EsanExit = checkSanitizerInterfaceFunction(        M.getOrInsertFunction(EsanExitName, IRB_Dtor.getVoidTy(), -                            Int8PtrTy, nullptr)); +                            Int8PtrTy));    EsanExit->setLinkage(Function::ExternalLinkage);    IRB_Dtor.CreateCall(EsanExit, {ToolInfoArg});    appendToGlobalDtors(M, EsanDtorFunction, EsanCtorAndDtorPriority); @@ -757,7 +757,7 @@ bool EfficiencySanitizer::instrumentGetElementPtr(Instruction *I, Module &M) {      return false;    }    Type *SourceTy = GepInst->getSourceElementType(); -  StructType *StructTy; +  StructType *StructTy = nullptr;    ConstantInt *Idx;    // Check if GEP calculates address from a struct array.    
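The sanitizer callback declarations updated in the hunks above (ASan, DFSan, EfficiencySanitizer) all follow the same mechanical change: Module::getOrInsertFunction no longer takes a trailing nullptr sentinel. A hedged sketch of the new call shapes, using invented hook names:

// Hedged sketch, not part of the patch; the "__example_hook" names are invented.
#include "llvm/IR/IRBuilder.h"
#include "llvm/IR/Module.h"
using namespace llvm;

static void declareExampleHooks(Module &M, Type *IntptrTy) {
  IRBuilder<> IRB(M.getContext());
  // Argument types are now passed directly, with no nullptr terminator.
  M.getOrInsertFunction("__example_hook", IRB.getVoidTy(), IntptrTy, IntptrTy);
  // When the parameter list is built dynamically, pass an explicit FunctionType.
  M.getOrInsertFunction("__example_hook_n",
                        FunctionType::get(IRB.getVoidTy(), {IntptrTy}, false));
}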
if (isa<StructType>(SourceTy)) { diff --git a/lib/Transforms/Instrumentation/IndirectCallPromotion.cpp b/lib/Transforms/Instrumentation/IndirectCallPromotion.cpp index 1ba13bdfe05a..61d627673c90 100644 --- a/lib/Transforms/Instrumentation/IndirectCallPromotion.cpp +++ b/lib/Transforms/Instrumentation/IndirectCallPromotion.cpp @@ -1,4 +1,4 @@ -//===-- IndirectCallPromotion.cpp - Promote indirect calls to direct calls ===// +//===-- IndirectCallPromotion.cpp - Optimizations based on value profiling ===//  //  //                      The LLVM Compiler Infrastructure  // @@ -17,6 +17,8 @@  #include "llvm/ADT/Statistic.h"  #include "llvm/ADT/StringRef.h"  #include "llvm/ADT/Twine.h" +#include "llvm/Analysis/BlockFrequencyInfo.h" +#include "llvm/Analysis/GlobalsModRef.h"  #include "llvm/Analysis/IndirectCallPromotionAnalysis.h"  #include "llvm/Analysis/IndirectCallSiteVisitor.h"  #include "llvm/IR/BasicBlock.h" @@ -40,6 +42,7 @@  #include "llvm/Support/CommandLine.h"  #include "llvm/Support/Debug.h"  #include "llvm/Support/ErrorHandling.h" +#include "llvm/Support/MathExtras.h"  #include "llvm/Transforms/Instrumentation.h"  #include "llvm/Transforms/PGOInstrumentation.h"  #include "llvm/Transforms/Utils/BasicBlockUtils.h" @@ -53,6 +56,8 @@ using namespace llvm;  STATISTIC(NumOfPGOICallPromotion, "Number of indirect call promotions.");  STATISTIC(NumOfPGOICallsites, "Number of indirect call candidate sites."); +STATISTIC(NumOfPGOMemOPOpt, "Number of memop intrinsics optimized."); +STATISTIC(NumOfPGOMemOPAnnotate, "Number of memop intrinsics annotated.");  // Command line option to disable indirect-call promotion with the default as  // false. This is for debug purpose. @@ -80,6 +85,12 @@ static cl::opt<bool> ICPLTOMode("icp-lto", cl::init(false), cl::Hidden,                                  cl::desc("Run indirect-call promotion in LTO "                                           "mode")); +// Set if the pass is called in SamplePGO mode. The difference for SamplePGO +// mode is it will add prof metadatato the created direct call. +static cl::opt<bool> +    ICPSamplePGOMode("icp-samplepgo", cl::init(false), cl::Hidden, +                     cl::desc("Run indirect-call promotion in SamplePGO mode")); +  // If the option is set to true, only call instructions will be considered for  // transformation -- invoke instructions will be ignored.  static cl::opt<bool> @@ -100,13 +111,51 @@ static cl::opt<bool>      ICPDUMPAFTER("icp-dumpafter", cl::init(false), cl::Hidden,                   cl::desc("Dump IR after transformation happens")); +// The minimum call count to optimize memory intrinsic calls. +static cl::opt<unsigned> +    MemOPCountThreshold("pgo-memop-count-threshold", cl::Hidden, cl::ZeroOrMore, +                        cl::init(1000), +                        cl::desc("The minimum count to optimize memory " +                                 "intrinsic calls")); + +// Command line option to disable memory intrinsic optimization. The default is +// false. This is for debug purpose. +static cl::opt<bool> DisableMemOPOPT("disable-memop-opt", cl::init(false), +                                     cl::Hidden, cl::desc("Disable optimize")); + +// The percent threshold to optimize memory intrinsic calls. 
+static cl::opt<unsigned>
+    MemOPPercentThreshold("pgo-memop-percent-threshold", cl::init(40),
+                          cl::Hidden, cl::ZeroOrMore,
+                          cl::desc("The percentage threshold for the "
+                                   "memory intrinsic calls optimization"));
+
+// Maximum number of versions for optimizing memory intrinsic call.
+static cl::opt<unsigned>
+    MemOPMaxVersion("pgo-memop-max-version", cl::init(3), cl::Hidden,
+                    cl::ZeroOrMore,
+                    cl::desc("The max version for the optimized memory "
+                             " intrinsic calls"));
+
+// Scale the counts from the annotation using the BB count value.
+static cl::opt<bool>
+    MemOPScaleCount("pgo-memop-scale-count", cl::init(true), cl::Hidden,
+                    cl::desc("Scale the memop size counts using the basic "
+                             " block count value"));
+
+// This option sets the range of precise profile memop sizes.
+extern cl::opt<std::string> MemOPSizeRange;
+
+// This option sets the value that groups large memop sizes.
+extern cl::opt<unsigned> MemOPSizeLarge;
+
 namespace {
 class PGOIndirectCallPromotionLegacyPass : public ModulePass {
 public:
   static char ID;
-  PGOIndirectCallPromotionLegacyPass(bool InLTO = false)
-      : ModulePass(ID), InLTO(InLTO) {
+  PGOIndirectCallPromotionLegacyPass(bool InLTO = false, bool SamplePGO = false)
+      : ModulePass(ID), InLTO(InLTO), SamplePGO(SamplePGO) {
     initializePGOIndirectCallPromotionLegacyPassPass(
         *PassRegistry::getPassRegistry());
   }
@@ -119,6 +168,28 @@ private:
   // If this pass is called in LTO. We need to special handling the PGOFuncName
   // for the static variables due to LTO's internalization.
   bool InLTO;
+
+  // If this pass is called in SamplePGO, we need to add the prof metadata to
+  // the promoted direct call. 
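As a hedged aside on the knobs defined above: they are plain cl::opts, so with the legacy pass manager (the pass is registered later in this patch under the name "pgo-memop-opt") an invocation might look like the following; the file names are placeholders.

  opt -pgo-memop-opt -pgo-memop-count-threshold=500 \
      -pgo-memop-percent-threshold=30 input.bc -o output.bc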
+  bool SamplePGO; +}; + +class PGOMemOPSizeOptLegacyPass : public FunctionPass { +public: +  static char ID; + +  PGOMemOPSizeOptLegacyPass() : FunctionPass(ID) { +    initializePGOMemOPSizeOptLegacyPassPass(*PassRegistry::getPassRegistry()); +  } + +  StringRef getPassName() const override { return "PGOMemOPSize"; } + +private: +  bool runOnFunction(Function &F) override; +  void getAnalysisUsage(AnalysisUsage &AU) const override { +    AU.addRequired<BlockFrequencyInfoWrapperPass>(); +    AU.addPreserved<GlobalsAAWrapperPass>(); +  }  };  } // end anonymous namespace @@ -128,8 +199,22 @@ INITIALIZE_PASS(PGOIndirectCallPromotionLegacyPass, "pgo-icall-prom",                  "direct calls.",                  false, false) -ModulePass *llvm::createPGOIndirectCallPromotionLegacyPass(bool InLTO) { -  return new PGOIndirectCallPromotionLegacyPass(InLTO); +ModulePass *llvm::createPGOIndirectCallPromotionLegacyPass(bool InLTO, +                                                           bool SamplePGO) { +  return new PGOIndirectCallPromotionLegacyPass(InLTO, SamplePGO); +} + +char PGOMemOPSizeOptLegacyPass::ID = 0; +INITIALIZE_PASS_BEGIN(PGOMemOPSizeOptLegacyPass, "pgo-memop-opt", +                      "Optimize memory intrinsic using its size value profile", +                      false, false) +INITIALIZE_PASS_DEPENDENCY(BlockFrequencyInfoWrapperPass) +INITIALIZE_PASS_END(PGOMemOPSizeOptLegacyPass, "pgo-memop-opt", +                    "Optimize memory intrinsic using its size value profile", +                    false, false) + +FunctionPass *llvm::createPGOMemOPSizeOptLegacyPass() { +  return new PGOMemOPSizeOptLegacyPass();  }  namespace { @@ -144,17 +229,11 @@ private:    // defines.    InstrProfSymtab *Symtab; -  enum TargetStatus { -    OK,                   // Should be able to promote. -    NotAvailableInModule, // Cannot find the target in current module. -    ReturnTypeMismatch,   // Return type mismatch b/w target and indirect-call. -    NumArgsMismatch,      // Number of arguments does not match. -    ArgTypeMismatch       // Type mismatch in the arguments (cannot bitcast). -  }; +  bool SamplePGO;    // Test if we can legally promote this direct-call of Target. -  TargetStatus isPromotionLegal(Instruction *Inst, uint64_t Target, -                                Function *&F); +  bool isPromotionLegal(Instruction *Inst, uint64_t Target, Function *&F, +                        const char **Reason = nullptr);    // A struct that records the direct target and it's call count.    struct PromotionCandidate { @@ -172,91 +251,77 @@ private:        Instruction *Inst, const ArrayRef<InstrProfValueData> &ValueDataRef,        uint64_t TotalCount, uint32_t NumCandidates); -  // Main function that transforms Inst (either a indirect-call instruction, or -  // an invoke instruction , to a conditional call to F. This is like: -  //     if (Inst.CalledValue == F) -  //        F(...); -  //     else -  //        Inst(...); -  //     end -  // TotalCount is the profile count value that the instruction executes. -  // Count is the profile count value that F is the target function. -  // These two values are being used to update the branch weight. -  void promote(Instruction *Inst, Function *F, uint64_t Count, -               uint64_t TotalCount); -    // Promote a list of targets for one indirect-call callsite. Return    // the number of promotions.    
uint32_t tryToPromote(Instruction *Inst,                          const std::vector<PromotionCandidate> &Candidates,                          uint64_t &TotalCount); -  static const char *StatusToString(const TargetStatus S) { -    switch (S) { -    case OK: -      return "OK to promote"; -    case NotAvailableInModule: -      return "Cannot find the target"; -    case ReturnTypeMismatch: -      return "Return type mismatch"; -    case NumArgsMismatch: -      return "The number of arguments mismatch"; -    case ArgTypeMismatch: -      return "Argument Type mismatch"; -    } -    llvm_unreachable("Should not reach here"); -  } -    // Noncopyable    ICallPromotionFunc(const ICallPromotionFunc &other) = delete;    ICallPromotionFunc &operator=(const ICallPromotionFunc &other) = delete;  public: -  ICallPromotionFunc(Function &Func, Module *Modu, InstrProfSymtab *Symtab) -      : F(Func), M(Modu), Symtab(Symtab) { -  } +  ICallPromotionFunc(Function &Func, Module *Modu, InstrProfSymtab *Symtab, +                     bool SamplePGO) +      : F(Func), M(Modu), Symtab(Symtab), SamplePGO(SamplePGO) {}    bool processFunction();  };  } // end anonymous namespace -ICallPromotionFunc::TargetStatus -ICallPromotionFunc::isPromotionLegal(Instruction *Inst, uint64_t Target, -                                     Function *&TargetFunction) { -  Function *DirectCallee = Symtab->getFunction(Target); -  if (DirectCallee == nullptr) -    return NotAvailableInModule; +bool llvm::isLegalToPromote(Instruction *Inst, Function *F, +                            const char **Reason) {    // Check the return type.    Type *CallRetType = Inst->getType();    if (!CallRetType->isVoidTy()) { -    Type *FuncRetType = DirectCallee->getReturnType(); +    Type *FuncRetType = F->getReturnType();      if (FuncRetType != CallRetType && -        !CastInst::isBitCastable(FuncRetType, CallRetType)) -      return ReturnTypeMismatch; +        !CastInst::isBitCastable(FuncRetType, CallRetType)) { +      if (Reason) +        *Reason = "Return type mismatch"; +      return false; +    }    }    // Check if the arguments are compatible with the parameters -  FunctionType *DirectCalleeType = DirectCallee->getFunctionType(); +  FunctionType *DirectCalleeType = F->getFunctionType();    unsigned ParamNum = DirectCalleeType->getFunctionNumParams();    CallSite CS(Inst);    unsigned ArgNum = CS.arg_size(); -  if (ParamNum != ArgNum && !DirectCalleeType->isVarArg()) -    return NumArgsMismatch; +  if (ParamNum != ArgNum && !DirectCalleeType->isVarArg()) { +    if (Reason) +      *Reason = "The number of arguments mismatch"; +    return false; +  }    for (unsigned I = 0; I < ParamNum; ++I) {      Type *PTy = DirectCalleeType->getFunctionParamType(I);      Type *ATy = CS.getArgument(I)->getType();      if (PTy == ATy)        continue; -    if (!CastInst::castIsValid(Instruction::BitCast, CS.getArgument(I), PTy)) -      return ArgTypeMismatch; +    if (!CastInst::castIsValid(Instruction::BitCast, CS.getArgument(I), PTy)) { +      if (Reason) +        *Reason = "Argument type mismatch"; +      return false; +    }    }    DEBUG(dbgs() << " #" << NumOfPGOICallPromotion << " Promote the icall to " -               << Symtab->getFuncName(Target) << "\n"); -  TargetFunction = DirectCallee; -  return OK; +               << F->getName() << "\n"); +  return true; +} + +bool ICallPromotionFunc::isPromotionLegal(Instruction *Inst, uint64_t Target, +                                          Function *&TargetFunction, +                                          
const char **Reason) { +  TargetFunction = Symtab->getFunction(Target); +  if (TargetFunction == nullptr) { +    *Reason = "Cannot find the target"; +    return false; +  } +  return isLegalToPromote(Inst, TargetFunction, Reason);  }  // Indirect-call promotion heuristic. The direct targets are sorted based on @@ -296,10 +361,9 @@ ICallPromotionFunc::getPromotionCandidatesForCallSite(        break;      }      Function *TargetFunction = nullptr; -    TargetStatus Status = isPromotionLegal(Inst, Target, TargetFunction); -    if (Status != OK) { +    const char *Reason = nullptr; +    if (!isPromotionLegal(Inst, Target, TargetFunction, &Reason)) {        StringRef TargetFuncName = Symtab->getFuncName(Target); -      const char *Reason = StatusToString(Status);        DEBUG(dbgs() << " Not promote: " << Reason << "\n");        emitOptimizationRemarkMissed(            F.getContext(), "pgo-icall-prom", F, Inst->getDebugLoc(), @@ -532,8 +596,14 @@ static void insertCallRetPHI(Instruction *Inst, Instruction *CallResult,  //     Ret = phi(Ret1, Ret2);  // It adds type casts for the args do not match the parameters and the return  // value. Branch weights metadata also updated. -void ICallPromotionFunc::promote(Instruction *Inst, Function *DirectCallee, -                                 uint64_t Count, uint64_t TotalCount) { +// If \p AttachProfToDirectCall is true, a prof metadata is attached to the +// new direct call to contain \p Count. This is used by SamplePGO inliner to +// check callsite hotness. +// Returns the promoted direct call instruction. +Instruction *llvm::promoteIndirectCall(Instruction *Inst, +                                       Function *DirectCallee, uint64_t Count, +                                       uint64_t TotalCount, +                                       bool AttachProfToDirectCall) {    assert(DirectCallee != nullptr);    BasicBlock *BB = Inst->getParent();    // Just to suppress the non-debug build warning. @@ -548,6 +618,14 @@ void ICallPromotionFunc::promote(Instruction *Inst, Function *DirectCallee,    Instruction *NewInst =        createDirectCallInst(Inst, DirectCallee, DirectCallBB, MergeBB); +  if (AttachProfToDirectCall) { +    SmallVector<uint32_t, 1> Weights; +    Weights.push_back(Count); +    MDBuilder MDB(NewInst->getContext()); +    dyn_cast<Instruction>(NewInst->stripPointerCasts()) +        ->setMetadata(LLVMContext::MD_prof, MDB.createBranchWeights(Weights)); +  } +    // Move Inst from MergeBB to IndirectCallBB.    Inst->removeFromParent();    IndirectCallBB->getInstList().insert(IndirectCallBB->getFirstInsertionPt(), @@ -576,9 +654,10 @@ void ICallPromotionFunc::promote(Instruction *Inst, Function *DirectCallee,    DEBUG(dbgs() << *BB << *DirectCallBB << *IndirectCallBB << *MergeBB << "\n");    emitOptimizationRemark( -      F.getContext(), "pgo-icall-prom", F, Inst->getDebugLoc(), +      BB->getContext(), "pgo-icall-prom", *BB->getParent(), Inst->getDebugLoc(),        Twine("Promote indirect call to ") + DirectCallee->getName() +            " with count " + Twine(Count) + " out of " + Twine(TotalCount)); +  return NewInst;  }  // Promote indirect-call to conditional direct-call for one callsite. 
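For readers new to this transform, a hedged, source-level picture of what promoteIndirectCall materializes in IR (names are invented; the real code also inserts bitcasts and return-value PHIs as described above, and sets branch weights from Count and TotalCount):

// Hedged illustration, not part of the patch. hot_target stands for the
// profiled direct callee recorded in the value profile.
using Callback = int (*)(int);

static int hot_target(int V) { return V + 1; }

static int call_site(Callback Fn, int X) {
  if (Fn == &hot_target)
    return hot_target(X); // direct call; SamplePGO mode attaches !prof with Count
  return Fn(X);           // original indirect call kept as the fallback
}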
@@ -589,7 +668,7 @@ uint32_t ICallPromotionFunc::tryToPromote(
   for (auto &C : Candidates) {
     uint64_t Count = C.Count;
-    promote(Inst, C.TargetFunction, Count, TotalCount);
+    promoteIndirectCall(Inst, C.TargetFunction, Count, TotalCount, SamplePGO);
     assert(TotalCount >= Count);
     TotalCount -= Count;
     NumOfPGOICallPromotion++;
   }
@@ -630,7 +709,7 @@ bool ICallPromotionFunc::processFunction() {
 }
 
 // A wrapper function that does the actual work.
-static bool promoteIndirectCalls(Module &M, bool InLTO) {
+static bool promoteIndirectCalls(Module &M, bool InLTO, bool SamplePGO) {
   if (DisableICP)
     return false;
   InstrProfSymtab Symtab;
@@ -641,7 +720,7 @@ static bool promoteIndirectCalls(Module &M, bool InLTO) {
       continue;
     if (F.hasFnAttribute(Attribute::OptimizeNone))
       continue;
-    ICallPromotionFunc ICallPromotion(F, &M, &Symtab);
+    ICallPromotionFunc ICallPromotion(F, &M, &Symtab, SamplePGO);
     bool FuncChanged = ICallPromotion.processFunction();
     if (ICPDUMPAFTER && FuncChanged) {
       DEBUG(dbgs() << "\n== IR Dump After =="; F.print(dbgs()));
@@ -658,12 +737,289 @@ static bool promoteIndirectCalls(Module &M, bool InLTO) {
 bool PGOIndirectCallPromotionLegacyPass::runOnModule(Module &M) {
   // Command-line option has the priority for InLTO.
-  return promoteIndirectCalls(M, InLTO | ICPLTOMode);
+  return promoteIndirectCalls(M, InLTO | ICPLTOMode,
+                              SamplePGO | ICPSamplePGOMode);
 }
 
-PreservedAnalyses PGOIndirectCallPromotion::run(Module &M, ModuleAnalysisManager &AM) {
-  if (!promoteIndirectCalls(M, InLTO | ICPLTOMode))
+PreservedAnalyses PGOIndirectCallPromotion::run(Module &M,
+                                                ModuleAnalysisManager &AM) {
+  if (!promoteIndirectCalls(M, InLTO | ICPLTOMode,
+                            SamplePGO | ICPSamplePGOMode))
     return PreservedAnalyses::all();
 
   return PreservedAnalyses::none();
 }
+
+namespace {
+class MemOPSizeOpt : public InstVisitor<MemOPSizeOpt> {
+public:
+  MemOPSizeOpt(Function &Func, BlockFrequencyInfo &BFI)
+      : Func(Func), BFI(BFI), Changed(false) {
+    ValueDataArray =
+        llvm::make_unique<InstrProfValueData[]>(MemOPMaxVersion + 2);
+    // Get the MemOPSize range information from option MemOPSizeRange.
+    getMemOPSizeRangeFromOption(MemOPSizeRange, PreciseRangeStart,
+                                PreciseRangeLast);
+  }
+  bool isChanged() const { return Changed; }
+  void perform() {
+    WorkList.clear();
+    visit(Func);
+
+    for (auto &MI : WorkList) {
+      ++NumOfPGOMemOPAnnotate;
+      if (perform(MI)) {
+        Changed = true;
+        ++NumOfPGOMemOPOpt;
+        DEBUG(dbgs() << "MemOP calls: " << MI->getCalledFunction()->getName()
                      << " is Transformed.\n");
+      }
+    }
+  }
+
+  void visitMemIntrinsic(MemIntrinsic &MI) {
+    Value *Length = MI.getLength();
+    // Do not perform this on constant length calls.
+    if (dyn_cast<ConstantInt>(Length))
+      return;
+    WorkList.push_back(&MI);
+  }
+
+private:
+  Function &Func;
+  BlockFrequencyInfo &BFI;
+  bool Changed;
+  std::vector<MemIntrinsic *> WorkList;
+  // Start of the precise range.
+  int64_t PreciseRangeStart;
+  // Last value of the precise range.
+  int64_t PreciseRangeLast;
+  // The space to read the profile annotation.
+  std::unique_ptr<InstrProfValueData[]> ValueDataArray;
+  bool perform(MemIntrinsic *MI);
+
+  // This kind shows which group the value falls in. For PreciseValue, we have
+  // the profile count for that value. LargeGroup groups the values that are in
+  // range [LargeValue, +inf). NonLargeGroup groups the rest of values.
+  enum MemOPSizeKind { PreciseValue, NonLargeGroup, LargeGroup };
+
+  MemOPSizeKind getMemOPSizeKind(int64_t Value) const {
+    if (Value == MemOPSizeLarge && MemOPSizeLarge != 0)
+      return LargeGroup;
+    if (Value == PreciseRangeLast + 1)
+      return NonLargeGroup;
+    return PreciseValue;
+  }
+};
+
+static const char *getMIName(const MemIntrinsic *MI) {
+  switch (MI->getIntrinsicID()) {
+  case Intrinsic::memcpy:
+    return "memcpy";
+  case Intrinsic::memmove:
+    return "memmove";
+  case Intrinsic::memset:
+    return "memset";
+  default:
+    return "unknown";
+  }
+}
+
+static bool isProfitable(uint64_t Count, uint64_t TotalCount) {
+  assert(Count <= TotalCount);
+  if (Count < MemOPCountThreshold)
+    return false;
+  if (Count < TotalCount * MemOPPercentThreshold / 100)
+    return false;
+  return true;
+}
+
+static inline uint64_t getScaledCount(uint64_t Count, uint64_t Num,
+                                      uint64_t Denom) {
+  if (!MemOPScaleCount)
+    return Count;
+  bool Overflowed;
+  uint64_t ScaleCount = SaturatingMultiply(Count, Num, &Overflowed);
+  return ScaleCount / Denom;
+}
+
+bool MemOPSizeOpt::perform(MemIntrinsic *MI) {
+  assert(MI);
+  if (MI->getIntrinsicID() == Intrinsic::memmove)
+    return false;
+
+  uint32_t NumVals, MaxNumPromotions = MemOPMaxVersion + 2;
+  uint64_t TotalCount;
+  if (!getValueProfDataFromInst(*MI, IPVK_MemOPSize, MaxNumPromotions,
+                                ValueDataArray.get(), NumVals, TotalCount))
+    return false;
+
+  uint64_t ActualCount = TotalCount;
+  uint64_t SavedTotalCount = TotalCount;
+  if (MemOPScaleCount) {
+    auto BBEdgeCount = BFI.getBlockProfileCount(MI->getParent());
+    if (!BBEdgeCount)
+      return false;
+    ActualCount = *BBEdgeCount;
+  }
+
+  if (ActualCount < MemOPCountThreshold)
+    return false;
+
+  ArrayRef<InstrProfValueData> VDs(ValueDataArray.get(), NumVals);
+  TotalCount = ActualCount;
+  if (MemOPScaleCount)
+    DEBUG(dbgs() << "Scale counts: numerator = " << ActualCount
+                 << " denominator = " << SavedTotalCount << "\n");
+
+  // Keeping track of the count of the default case:
+  uint64_t RemainCount = TotalCount;
+  SmallVector<uint64_t, 16> SizeIds;
+  SmallVector<uint64_t, 16> CaseCounts;
+  uint64_t MaxCount = 0;
+  unsigned Version = 0;
+  // Default case is in the front -- save the slot here.
+  CaseCounts.push_back(0);
+  for (auto &VD : VDs) {
+    int64_t V = VD.Value;
+    uint64_t C = VD.Count;
+    if (MemOPScaleCount)
+      C = getScaledCount(C, ActualCount, SavedTotalCount);
+
+    // Only care about precise values here.
+    if (getMemOPSizeKind(V) != PreciseValue)
+      continue;
+
+    // ValueCounts are sorted on the count. Break at the first un-profitable
+    // value. 
+    if (!isProfitable(C, RemainCount)) +      break; + +    SizeIds.push_back(V); +    CaseCounts.push_back(C); +    if (C > MaxCount) +      MaxCount = C; + +    assert(RemainCount >= C); +    RemainCount -= C; + +    if (++Version > MemOPMaxVersion && MemOPMaxVersion != 0) +      break; +  } + +  if (Version == 0) +    return false; + +  CaseCounts[0] = RemainCount; +  if (RemainCount > MaxCount) +    MaxCount = RemainCount; + +  uint64_t SumForOpt = TotalCount - RemainCount; +  DEBUG(dbgs() << "Read one memory intrinsic profile: " << SumForOpt << " vs " +               << TotalCount << "\n"); +  DEBUG( +      for (auto &VD +           : VDs) { dbgs() << "  (" << VD.Value << "," << VD.Count << ")\n"; }); + +  DEBUG(dbgs() << "Optimize one memory intrinsic call to " << Version +               << " Versions\n"); + +  // mem_op(..., size) +  // ==> +  // switch (size) { +  //   case s1: +  //      mem_op(..., s1); +  //      goto merge_bb; +  //   case s2: +  //      mem_op(..., s2); +  //      goto merge_bb; +  //   ... +  //   default: +  //      mem_op(..., size); +  //      goto merge_bb; +  // } +  // merge_bb: + +  BasicBlock *BB = MI->getParent(); +  DEBUG(dbgs() << "\n\n== Basic Block Before ==\n"); +  DEBUG(dbgs() << *BB << "\n"); + +  BasicBlock *DefaultBB = SplitBlock(BB, MI); +  BasicBlock::iterator It(*MI); +  ++It; +  assert(It != DefaultBB->end()); +  BasicBlock *MergeBB = SplitBlock(DefaultBB, &(*It)); +  DefaultBB->setName("MemOP.Default"); +  MergeBB->setName("MemOP.Merge"); + +  auto &Ctx = Func.getContext(); +  IRBuilder<> IRB(BB); +  BB->getTerminator()->eraseFromParent(); +  Value *SizeVar = MI->getLength(); +  SwitchInst *SI = IRB.CreateSwitch(SizeVar, DefaultBB, SizeIds.size()); + +  // Clear the value profile data. +  MI->setMetadata(LLVMContext::MD_prof, nullptr); + +  DEBUG(dbgs() << "\n\n== Basic Block After==\n"); + +  for (uint64_t SizeId : SizeIds) { +    ConstantInt *CaseSizeId = ConstantInt::get(Type::getInt64Ty(Ctx), SizeId); +    BasicBlock *CaseBB = BasicBlock::Create( +        Ctx, Twine("MemOP.Case.") + Twine(SizeId), &Func, DefaultBB); +    Instruction *NewInst = MI->clone(); +    // Fix the argument. 
+    dyn_cast<MemIntrinsic>(NewInst)->setLength(CaseSizeId); +    CaseBB->getInstList().push_back(NewInst); +    IRBuilder<> IRBCase(CaseBB); +    IRBCase.CreateBr(MergeBB); +    SI->addCase(CaseSizeId, CaseBB); +    DEBUG(dbgs() << *CaseBB << "\n"); +  } +  setProfMetadata(Func.getParent(), SI, CaseCounts, MaxCount); + +  DEBUG(dbgs() << *BB << "\n"); +  DEBUG(dbgs() << *DefaultBB << "\n"); +  DEBUG(dbgs() << *MergeBB << "\n"); + +  emitOptimizationRemark(Func.getContext(), "memop-opt", Func, +                         MI->getDebugLoc(), +                         Twine("optimize ") + getMIName(MI) + " with count " + +                             Twine(SumForOpt) + " out of " + Twine(TotalCount) + +                             " for " + Twine(Version) + " versions"); + +  return true; +} +} // namespace + +static bool PGOMemOPSizeOptImpl(Function &F, BlockFrequencyInfo &BFI) { +  if (DisableMemOPOPT) +    return false; + +  if (F.hasFnAttribute(Attribute::OptimizeForSize)) +    return false; +  MemOPSizeOpt MemOPSizeOpt(F, BFI); +  MemOPSizeOpt.perform(); +  return MemOPSizeOpt.isChanged(); +} + +bool PGOMemOPSizeOptLegacyPass::runOnFunction(Function &F) { +  BlockFrequencyInfo &BFI = +      getAnalysis<BlockFrequencyInfoWrapperPass>().getBFI(); +  return PGOMemOPSizeOptImpl(F, BFI); +} + +namespace llvm { +char &PGOMemOPSizeOptID = PGOMemOPSizeOptLegacyPass::ID; + +PreservedAnalyses PGOMemOPSizeOpt::run(Function &F, +                                       FunctionAnalysisManager &FAM) { +  auto &BFI = FAM.getResult<BlockFrequencyAnalysis>(F); +  bool Changed = PGOMemOPSizeOptImpl(F, BFI); +  if (!Changed) +    return PreservedAnalyses::all(); +  auto  PA = PreservedAnalyses(); +  PA.preserve<GlobalsAA>(); +  return PA; +} +} // namespace llvm diff --git a/lib/Transforms/Instrumentation/InstrProfiling.cpp b/lib/Transforms/Instrumentation/InstrProfiling.cpp index adea7e772447..d91ac6ac7883 100644 --- a/lib/Transforms/Instrumentation/InstrProfiling.cpp +++ b/lib/Transforms/Instrumentation/InstrProfiling.cpp @@ -14,18 +14,58 @@  //===----------------------------------------------------------------------===//  #include "llvm/Transforms/InstrProfiling.h" +#include "llvm/ADT/ArrayRef.h" +#include "llvm/ADT/SmallVector.h" +#include "llvm/ADT/StringRef.h"  #include "llvm/ADT/Triple.h" +#include "llvm/ADT/Twine.h"  #include "llvm/Analysis/TargetLibraryInfo.h" -#include "llvm/IR/IRBuilder.h" +#include "llvm/IR/Attributes.h" +#include "llvm/IR/BasicBlock.h" +#include "llvm/IR/Constant.h" +#include "llvm/IR/Constants.h" +#include "llvm/IR/DerivedTypes.h" +#include "llvm/IR/Function.h" +#include "llvm/IR/GlobalValue.h" +#include "llvm/IR/GlobalVariable.h" +#include "llvm/IR/Instruction.h" +#include "llvm/IR/Instructions.h"  #include "llvm/IR/IntrinsicInst.h" +#include "llvm/IR/IRBuilder.h"  #include "llvm/IR/Module.h" +#include "llvm/IR/Type.h" +#include "llvm/Pass.h"  #include "llvm/ProfileData/InstrProf.h" +#include "llvm/Support/Casting.h" +#include "llvm/Support/CommandLine.h" +#include "llvm/Support/Error.h" +#include "llvm/Support/ErrorHandling.h"  #include "llvm/Transforms/Utils/ModuleUtils.h" +#include <algorithm> +#include <cassert> +#include <cstddef> +#include <cstdint> +#include <string>  using namespace llvm;  #define DEBUG_TYPE "instrprof" +// The start and end values of precise value profile range for memory +// intrinsic sizes +cl::opt<std::string> MemOPSizeRange( +    "memop-size-range", +    cl::desc("Set the range of size in memory intrinsic calls to be profiled " +             
"precisely, in a format of <start_val>:<end_val>"), +    cl::init("")); + +// The value that considered to be large value in  memory intrinsic. +cl::opt<unsigned> MemOPSizeLarge( +    "memop-size-large", +    cl::desc("Set large value thresthold in memory intrinsic size profiling. " +             "Value of 0 disables the large value profiling."), +    cl::init(8192)); +  namespace {  cl::opt<bool> DoNameCompression("enable-name-compression", @@ -41,6 +81,7 @@ cl::opt<bool> ValueProfileStaticAlloc(      "vp-static-alloc",      cl::desc("Do static counter allocation for value profiler"),      cl::init(true)); +  cl::opt<double> NumCountersPerValueSite(      "vp-counters-per-site",      cl::desc("The average number of profile counters allocated " @@ -56,9 +97,11 @@ class InstrProfilingLegacyPass : public ModulePass {  public:    static char ID; -  InstrProfilingLegacyPass() : ModulePass(ID), InstrProf() {} + +  InstrProfilingLegacyPass() : ModulePass(ID) {}    InstrProfilingLegacyPass(const InstrProfOptions &Options)        : ModulePass(ID), InstrProf(Options) {} +    StringRef getPassName() const override {      return "Frontend instrumentation-based coverage lowering";    } @@ -73,7 +116,7 @@ public:    }  }; -} // anonymous namespace +} // end anonymous namespace  PreservedAnalyses InstrProfiling::run(Module &M, ModuleAnalysisManager &AM) {    auto &TLI = AM.getResult<TargetLibraryAnalysis>(M); @@ -97,30 +140,6 @@ llvm::createInstrProfilingLegacyPass(const InstrProfOptions &Options) {    return new InstrProfilingLegacyPass(Options);  } -bool InstrProfiling::isMachO() const { -  return Triple(M->getTargetTriple()).isOSBinFormatMachO(); -} - -/// Get the section name for the counter variables. -StringRef InstrProfiling::getCountersSection() const { -  return getInstrProfCountersSectionName(isMachO()); -} - -/// Get the section name for the name variables. -StringRef InstrProfiling::getNameSection() const { -  return getInstrProfNameSectionName(isMachO()); -} - -/// Get the section name for the profile data variables. -StringRef InstrProfiling::getDataSection() const { -  return getInstrProfDataSectionName(isMachO()); -} - -/// Get the section name for the coverage mapping data. -StringRef InstrProfiling::getCoverageSection() const { -  return getInstrProfCoverageSectionName(isMachO()); -} -  static InstrProfIncrementInst *castToIncrementInst(Instruction *Instr) {    InstrProfIncrementInst *Inc = dyn_cast<InstrProfIncrementInstStep>(Instr);    if (Inc) @@ -137,6 +156,9 @@ bool InstrProfiling::run(Module &M, const TargetLibraryInfo &TLI) {    NamesSize = 0;    ProfileDataMap.clear();    UsedVars.clear(); +  getMemOPSizeRangeFromOption(MemOPSizeRange, MemOPSizeRangeStart, +                              MemOPSizeRangeLast); +  TT = Triple(M.getTargetTriple());    // We did not know how many value sites there would be inside    // the instrumented function. 
This is counting the number of instrumented @@ -189,17 +211,34 @@ bool InstrProfiling::run(Module &M, const TargetLibraryInfo &TLI) {  }  static Constant *getOrInsertValueProfilingCall(Module &M, -                                               const TargetLibraryInfo &TLI) { +                                               const TargetLibraryInfo &TLI, +                                               bool IsRange = false) {    LLVMContext &Ctx = M.getContext();    auto *ReturnTy = Type::getVoidTy(M.getContext()); -  Type *ParamTypes[] = { + +  Constant *Res; +  if (!IsRange) { +    Type *ParamTypes[] = {  #define VALUE_PROF_FUNC_PARAM(ParamType, ParamName, ParamLLVMType) ParamLLVMType  #include "llvm/ProfileData/InstrProfData.inc" -  }; -  auto *ValueProfilingCallTy = -      FunctionType::get(ReturnTy, makeArrayRef(ParamTypes), false); -  Constant *Res = M.getOrInsertFunction(getInstrProfValueProfFuncName(), -                                        ValueProfilingCallTy); +    }; +    auto *ValueProfilingCallTy = +        FunctionType::get(ReturnTy, makeArrayRef(ParamTypes), false); +    Res = M.getOrInsertFunction(getInstrProfValueProfFuncName(), +                                ValueProfilingCallTy); +  } else { +    Type *RangeParamTypes[] = { +#define VALUE_RANGE_PROF 1 +#define VALUE_PROF_FUNC_PARAM(ParamType, ParamName, ParamLLVMType) ParamLLVMType +#include "llvm/ProfileData/InstrProfData.inc" +#undef VALUE_RANGE_PROF +    }; +    auto *ValueRangeProfilingCallTy = +        FunctionType::get(ReturnTy, makeArrayRef(RangeParamTypes), false); +    Res = M.getOrInsertFunction(getInstrProfValueRangeProfFuncName(), +                                ValueRangeProfilingCallTy); +  } +    if (Function *FunRes = dyn_cast<Function>(Res)) {      if (auto AK = TLI.getExtAttrForI32Param(false))        FunRes->addAttribute(3, AK); @@ -208,7 +247,6 @@ static Constant *getOrInsertValueProfilingCall(Module &M,  }  void InstrProfiling::computeNumValueSiteCounts(InstrProfValueProfileInst *Ind) { -    GlobalVariable *Name = Ind->getName();    uint64_t ValueKind = Ind->getValueKind()->getZExtValue();    uint64_t Index = Ind->getIndex()->getZExtValue(); @@ -222,7 +260,6 @@ void InstrProfiling::computeNumValueSiteCounts(InstrProfValueProfileInst *Ind) {  }  void InstrProfiling::lowerValueProfileInst(InstrProfValueProfileInst *Ind) { -    GlobalVariable *Name = Ind->getName();    auto It = ProfileDataMap.find(Name);    assert(It != ProfileDataMap.end() && It->second.DataVar && @@ -235,11 +272,25 @@ void InstrProfiling::lowerValueProfileInst(InstrProfValueProfileInst *Ind) {      Index += It->second.NumValueSites[Kind];    IRBuilder<> Builder(Ind); -  Value *Args[3] = {Ind->getTargetValue(), -                    Builder.CreateBitCast(DataVar, Builder.getInt8PtrTy()), -                    Builder.getInt32(Index)}; -  CallInst *Call = Builder.CreateCall(getOrInsertValueProfilingCall(*M, *TLI), -                                      Args); +  bool IsRange = (Ind->getValueKind()->getZExtValue() == +                  llvm::InstrProfValueKind::IPVK_MemOPSize); +  CallInst *Call = nullptr; +  if (!IsRange) { +    Value *Args[3] = {Ind->getTargetValue(), +                      Builder.CreateBitCast(DataVar, Builder.getInt8PtrTy()), +                      Builder.getInt32(Index)}; +    Call = Builder.CreateCall(getOrInsertValueProfilingCall(*M, *TLI), Args); +  } else { +    Value *Args[6] = { +        Ind->getTargetValue(), +        Builder.CreateBitCast(DataVar, Builder.getInt8PtrTy()), +        Builder.getInt32(Index), 
+        Builder.getInt64(MemOPSizeRangeStart), +        Builder.getInt64(MemOPSizeRangeLast), +        Builder.getInt64(MemOPSizeLarge == 0 ? INT64_MIN : MemOPSizeLarge)}; +    Call = +        Builder.CreateCall(getOrInsertValueProfilingCall(*M, *TLI, true), Args); +  }    if (auto AK = TLI->getExtAttrForI32Param(false))      Call->addAttribute(3, AK);    Ind->replaceAllUsesWith(Call); @@ -259,7 +310,6 @@ void InstrProfiling::lowerIncrement(InstrProfIncrementInst *Inc) {  }  void InstrProfiling::lowerCoverageData(GlobalVariable *CoverageNamesVar) { -    ConstantArray *Names =        cast<ConstantArray>(CoverageNamesVar->getInitializer());    for (unsigned I = 0, E = Names->getNumOperands(); I < E; ++I) { @@ -270,7 +320,9 @@ void InstrProfiling::lowerCoverageData(GlobalVariable *CoverageNamesVar) {      Name->setLinkage(GlobalValue::PrivateLinkage);      ReferencedNames.push_back(Name); +    NC->dropAllReferences();    } +  CoverageNamesVar->eraseFromParent();  }  /// Get the name of a profiling variable for a particular function. @@ -367,7 +419,8 @@ InstrProfiling::getOrCreateRegionCounters(InstrProfIncrementInst *Inc) {                           Constant::getNullValue(CounterTy),                           getVarName(Inc, getInstrProfCountersVarPrefix()));    CounterPtr->setVisibility(NamePtr->getVisibility()); -  CounterPtr->setSection(getCountersSection()); +  CounterPtr->setSection( +      getInstrProfSectionName(IPSK_cnts, TT.getObjectFormat()));    CounterPtr->setAlignment(8);    CounterPtr->setComdat(ProfileVarsComdat); @@ -376,7 +429,6 @@ InstrProfiling::getOrCreateRegionCounters(InstrProfIncrementInst *Inc) {    // the current function.    Constant *ValuesPtrExpr = ConstantPointerNull::get(Int8PtrTy);    if (ValueProfileStaticAlloc && !needsRuntimeRegistrationOfSectionRange(*M)) { -      uint64_t NS = 0;      for (uint32_t Kind = IPVK_First; Kind <= IPVK_Last; ++Kind)        NS += PD.NumValueSites[Kind]; @@ -388,11 +440,12 @@ InstrProfiling::getOrCreateRegionCounters(InstrProfIncrementInst *Inc) {                               Constant::getNullValue(ValuesTy),                               getVarName(Inc, getInstrProfValuesVarPrefix()));        ValuesVar->setVisibility(NamePtr->getVisibility()); -      ValuesVar->setSection(getInstrProfValuesSectionName(isMachO())); +      ValuesVar->setSection( +          getInstrProfSectionName(IPSK_vals, TT.getObjectFormat()));        ValuesVar->setAlignment(8);        ValuesVar->setComdat(ProfileVarsComdat);        ValuesPtrExpr = -          ConstantExpr::getBitCast(ValuesVar, llvm::Type::getInt8PtrTy(Ctx)); +          ConstantExpr::getBitCast(ValuesVar, Type::getInt8PtrTy(Ctx));      }    } @@ -421,7 +474,7 @@ InstrProfiling::getOrCreateRegionCounters(InstrProfIncrementInst *Inc) {                                    ConstantStruct::get(DataTy, DataVals),                                    getVarName(Inc, getInstrProfDataVarPrefix()));    Data->setVisibility(NamePtr->getVisibility()); -  Data->setSection(getDataSection()); +  Data->setSection(getInstrProfSectionName(IPSK_data, TT.getObjectFormat()));    Data->setAlignment(INSTR_PROF_DATA_ALIGNMENT);    Data->setComdat(ProfileVarsComdat); @@ -481,9 +534,10 @@ void InstrProfiling::emitVNodes() {    ArrayType *VNodesTy = ArrayType::get(VNodeTy, NumCounters);    auto *VNodesVar = new GlobalVariable( -      *M, VNodesTy, false, llvm::GlobalValue::PrivateLinkage, +      *M, VNodesTy, false, GlobalValue::PrivateLinkage,        Constant::getNullValue(VNodesTy), getInstrProfVNodesVarName()); -  
VNodesVar->setSection(getInstrProfVNodesSectionName(isMachO())); +  VNodesVar->setSection( +      getInstrProfSectionName(IPSK_vnodes, TT.getObjectFormat()));    UsedVars.push_back(VNodesVar);  } @@ -496,18 +550,22 @@ void InstrProfiling::emitNameData() {    std::string CompressedNameStr;    if (Error E = collectPGOFuncNameStrings(ReferencedNames, CompressedNameStr,                                            DoNameCompression)) { -    llvm::report_fatal_error(toString(std::move(E)), false); +    report_fatal_error(toString(std::move(E)), false);    }    auto &Ctx = M->getContext(); -  auto *NamesVal = llvm::ConstantDataArray::getString( +  auto *NamesVal = ConstantDataArray::getString(        Ctx, StringRef(CompressedNameStr), false); -  NamesVar = new llvm::GlobalVariable(*M, NamesVal->getType(), true, -                                      llvm::GlobalValue::PrivateLinkage, -                                      NamesVal, getInstrProfNamesVarName()); +  NamesVar = new GlobalVariable(*M, NamesVal->getType(), true, +                                GlobalValue::PrivateLinkage, NamesVal, +                                getInstrProfNamesVarName());    NamesSize = CompressedNameStr.size(); -  NamesVar->setSection(getNameSection()); +  NamesVar->setSection( +      getInstrProfSectionName(IPSK_name, TT.getObjectFormat()));    UsedVars.push_back(NamesVar); + +  for (auto *NamePtr : ReferencedNames) +    NamePtr->eraseFromParent();  }  void InstrProfiling::emitRegistration() { @@ -550,7 +608,6 @@ void InstrProfiling::emitRegistration() {  }  void InstrProfiling::emitRuntimeHook() { -    // We expect the linker to be invoked with -u<hook_var> flag for linux,    // for which case there is no need to emit the user function.    if (Triple(M->getTargetTriple()).isOSLinux()) @@ -600,7 +657,6 @@ void InstrProfiling::emitInitialization() {      GlobalVariable *ProfileNameVar = new GlobalVariable(          *M, ProfileNameConst->getType(), true, GlobalValue::WeakAnyLinkage,          ProfileNameConst, INSTR_PROF_QUOTE(INSTR_PROF_PROFILE_NAME_VAR)); -    Triple TT(M->getTargetTriple());      if (TT.supportsCOMDAT()) {        ProfileNameVar->setLinkage(GlobalValue::ExternalLinkage);        ProfileNameVar->setComdat(M->getOrInsertComdat( diff --git a/lib/Transforms/Instrumentation/Instrumentation.cpp b/lib/Transforms/Instrumentation/Instrumentation.cpp index 2963d08752c4..7bb62d2c8455 100644 --- a/lib/Transforms/Instrumentation/Instrumentation.cpp +++ b/lib/Transforms/Instrumentation/Instrumentation.cpp @@ -63,6 +63,7 @@ void llvm::initializeInstrumentation(PassRegistry &Registry) {    initializePGOInstrumentationGenLegacyPassPass(Registry);    initializePGOInstrumentationUseLegacyPassPass(Registry);    initializePGOIndirectCallPromotionLegacyPassPass(Registry); +  initializePGOMemOPSizeOptLegacyPassPass(Registry);    initializeInstrProfilingLegacyPassPass(Registry);    initializeMemorySanitizerPass(Registry);    initializeThreadSanitizerPass(Registry); diff --git a/lib/Transforms/Instrumentation/MemorySanitizer.cpp b/lib/Transforms/Instrumentation/MemorySanitizer.cpp index fafb0fcbd017..190f05db4b0c 100644 --- a/lib/Transforms/Instrumentation/MemorySanitizer.cpp +++ b/lib/Transforms/Instrumentation/MemorySanitizer.cpp @@ -425,7 +425,7 @@ void MemorySanitizer::initializeCallbacks(Module &M) {    // which is not yet implemented.    StringRef WarningFnName = Recover ? 
"__msan_warning"                                      : "__msan_warning_noreturn"; -  WarningFn = M.getOrInsertFunction(WarningFnName, IRB.getVoidTy(), nullptr); +  WarningFn = M.getOrInsertFunction(WarningFnName, IRB.getVoidTy());    for (size_t AccessSizeIndex = 0; AccessSizeIndex < kNumberOfAccessSizes;         AccessSizeIndex++) { @@ -433,31 +433,31 @@ void MemorySanitizer::initializeCallbacks(Module &M) {      std::string FunctionName = "__msan_maybe_warning_" + itostr(AccessSize);      MaybeWarningFn[AccessSizeIndex] = M.getOrInsertFunction(          FunctionName, IRB.getVoidTy(), IRB.getIntNTy(AccessSize * 8), -        IRB.getInt32Ty(), nullptr); +        IRB.getInt32Ty());      FunctionName = "__msan_maybe_store_origin_" + itostr(AccessSize);      MaybeStoreOriginFn[AccessSizeIndex] = M.getOrInsertFunction(          FunctionName, IRB.getVoidTy(), IRB.getIntNTy(AccessSize * 8), -        IRB.getInt8PtrTy(), IRB.getInt32Ty(), nullptr); +        IRB.getInt8PtrTy(), IRB.getInt32Ty());    }    MsanSetAllocaOrigin4Fn = M.getOrInsertFunction(      "__msan_set_alloca_origin4", IRB.getVoidTy(), IRB.getInt8PtrTy(), IntptrTy, -    IRB.getInt8PtrTy(), IntptrTy, nullptr); +    IRB.getInt8PtrTy(), IntptrTy);    MsanPoisonStackFn =        M.getOrInsertFunction("__msan_poison_stack", IRB.getVoidTy(), -                            IRB.getInt8PtrTy(), IntptrTy, nullptr); +                            IRB.getInt8PtrTy(), IntptrTy);    MsanChainOriginFn = M.getOrInsertFunction( -    "__msan_chain_origin", IRB.getInt32Ty(), IRB.getInt32Ty(), nullptr); +    "__msan_chain_origin", IRB.getInt32Ty(), IRB.getInt32Ty());    MemmoveFn = M.getOrInsertFunction(      "__msan_memmove", IRB.getInt8PtrTy(), IRB.getInt8PtrTy(), -    IRB.getInt8PtrTy(), IntptrTy, nullptr); +    IRB.getInt8PtrTy(), IntptrTy);    MemcpyFn = M.getOrInsertFunction(      "__msan_memcpy", IRB.getInt8PtrTy(), IRB.getInt8PtrTy(), IRB.getInt8PtrTy(), -    IntptrTy, nullptr); +    IntptrTy);    MemsetFn = M.getOrInsertFunction(      "__msan_memset", IRB.getInt8PtrTy(), IRB.getInt8PtrTy(), IRB.getInt32Ty(), -    IntptrTy, nullptr); +    IntptrTy);    // Create globals.    RetvalTLS = new GlobalVariable( @@ -1037,15 +1037,19 @@ struct MemorySanitizerVisitor : public InstVisitor<MemorySanitizerVisitor> {      OriginMap[V] = Origin;    } +  Constant *getCleanShadow(Type *OrigTy) { +    Type *ShadowTy = getShadowTy(OrigTy); +    if (!ShadowTy) +      return nullptr; +    return Constant::getNullValue(ShadowTy); +  } +    /// \brief Create a clean shadow value for a given value.    ///    /// Clean shadow (all zeroes) means all bits of the value are defined    /// (initialized).    Constant *getCleanShadow(Value *V) { -    Type *ShadowTy = getShadowTy(V); -    if (!ShadowTy) -      return nullptr; -    return Constant::getNullValue(ShadowTy); +    return getCleanShadow(V->getType());    }    /// \brief Create a dirty shadow of a given shadow type. 
@@ -1942,7 +1946,6 @@ struct MemorySanitizerVisitor : public InstVisitor<MemorySanitizerVisitor> {      if (ClCheckAccessAddress)        insertShadowCheck(Addr, &I); -    // FIXME: use ClStoreCleanOrigin      // FIXME: factor out common code from materializeStores      if (MS.TrackOrigins)        IRB.CreateStore(getOrigin(&I, 1), getOriginPtr(Addr, IRB, 1)); @@ -2325,11 +2328,49 @@ struct MemorySanitizerVisitor : public InstVisitor<MemorySanitizerVisitor> {      setOriginForNaryOp(I);    } +  void handleStmxcsr(IntrinsicInst &I) { +    IRBuilder<> IRB(&I); +    Value* Addr = I.getArgOperand(0); +    Type *Ty = IRB.getInt32Ty(); +    Value *ShadowPtr = getShadowPtr(Addr, Ty, IRB); + +    IRB.CreateStore(getCleanShadow(Ty), +                    IRB.CreatePointerCast(ShadowPtr, Ty->getPointerTo())); + +    if (ClCheckAccessAddress) +      insertShadowCheck(Addr, &I); +  } + +  void handleLdmxcsr(IntrinsicInst &I) { +    if (!InsertChecks) return; + +    IRBuilder<> IRB(&I); +    Value *Addr = I.getArgOperand(0); +    Type *Ty = IRB.getInt32Ty(); +    unsigned Alignment = 1; + +    if (ClCheckAccessAddress) +      insertShadowCheck(Addr, &I); + +    Value *Shadow = IRB.CreateAlignedLoad(getShadowPtr(Addr, Ty, IRB), +                                          Alignment, "_ldmxcsr"); +    Value *Origin = MS.TrackOrigins +                        ? IRB.CreateLoad(getOriginPtr(Addr, IRB, Alignment)) +                        : getCleanOrigin(); +    insertShadowCheck(Shadow, Origin, &I); +  } +    void visitIntrinsicInst(IntrinsicInst &I) {      switch (I.getIntrinsicID()) {      case llvm::Intrinsic::bswap:        handleBswap(I);        break; +    case llvm::Intrinsic::x86_sse_stmxcsr: +      handleStmxcsr(I); +      break; +    case llvm::Intrinsic::x86_sse_ldmxcsr: +      handleLdmxcsr(I); +      break;      case llvm::Intrinsic::x86_avx512_vcvtsd2usi64:      case llvm::Intrinsic::x86_avx512_vcvtsd2usi32:      case llvm::Intrinsic::x86_avx512_vcvtss2usi64: @@ -2566,10 +2607,10 @@ struct MemorySanitizerVisitor : public InstVisitor<MemorySanitizerVisitor> {          AttrBuilder B;          B.addAttribute(Attribute::ReadOnly)            .addAttribute(Attribute::ReadNone); -        Func->removeAttributes(AttributeSet::FunctionIndex, -                               AttributeSet::get(Func->getContext(), -                                                 AttributeSet::FunctionIndex, -                                                 B)); +        Func->removeAttributes(AttributeList::FunctionIndex, +                               AttributeList::get(Func->getContext(), +                                                  AttributeList::FunctionIndex, +                                                  B));        }        maybeMarkSanitizerLibraryCallNoBuiltin(Call, TLI); @@ -2597,7 +2638,7 @@ struct MemorySanitizerVisitor : public InstVisitor<MemorySanitizerVisitor> {              " Shadow: " << *ArgShadow << "\n");        bool ArgIsInitialized = false;        const DataLayout &DL = F.getParent()->getDataLayout(); -      if (CS.paramHasAttr(i + 1, Attribute::ByVal)) { +      if (CS.paramHasAttr(i, Attribute::ByVal)) {          assert(A->getType()->isPointerTy() &&                 "ByVal argument is not a pointer!");          Size = DL.getTypeAllocSize(A->getType()->getPointerElementType()); @@ -2690,7 +2731,6 @@ struct MemorySanitizerVisitor : public InstVisitor<MemorySanitizerVisitor> {      } else {        Value *Shadow = getShadow(RetVal);        IRB.CreateAlignedStore(Shadow, ShadowPtr, 
kShadowTLSAlignment); -      // FIXME: make it conditional if ClStoreCleanOrigin==0        if (MS.TrackOrigins)          IRB.CreateStore(getOrigin(RetVal), getOriginPtrForRetval(IRB));      } @@ -2717,15 +2757,17 @@ struct MemorySanitizerVisitor : public InstVisitor<MemorySanitizerVisitor> {      setOrigin(&I, getCleanOrigin());      IRBuilder<> IRB(I.getNextNode());      const DataLayout &DL = F.getParent()->getDataLayout(); -    uint64_t Size = DL.getTypeAllocSize(I.getAllocatedType()); +    uint64_t TypeSize = DL.getTypeAllocSize(I.getAllocatedType()); +    Value *Len = ConstantInt::get(MS.IntptrTy, TypeSize); +    if (I.isArrayAllocation()) +      Len = IRB.CreateMul(Len, I.getArraySize());      if (PoisonStack && ClPoisonStackWithCall) {        IRB.CreateCall(MS.MsanPoisonStackFn, -                     {IRB.CreatePointerCast(&I, IRB.getInt8PtrTy()), -                      ConstantInt::get(MS.IntptrTy, Size)}); +                     {IRB.CreatePointerCast(&I, IRB.getInt8PtrTy()), Len});      } else {        Value *ShadowBase = getShadowPtr(&I, Type::getInt8PtrTy(*MS.C), IRB);        Value *PoisonValue = IRB.getInt8(PoisonStack ? ClPoisonStackPattern : 0); -      IRB.CreateMemSet(ShadowBase, PoisonValue, Size, I.getAlignment()); +      IRB.CreateMemSet(ShadowBase, PoisonValue, Len, I.getAlignment());      }      if (PoisonStack && MS.TrackOrigins) { @@ -2742,8 +2784,7 @@ struct MemorySanitizerVisitor : public InstVisitor<MemorySanitizerVisitor> {                                                 StackDescription.str());        IRB.CreateCall(MS.MsanSetAllocaOrigin4Fn, -                     {IRB.CreatePointerCast(&I, IRB.getInt8PtrTy()), -                      ConstantInt::get(MS.IntptrTy, Size), +                     {IRB.CreatePointerCast(&I, IRB.getInt8PtrTy()), Len,                        IRB.CreatePointerCast(Descr, IRB.getInt8PtrTy()),                        IRB.CreatePointerCast(&F, MS.IntptrTy)});      } @@ -2935,7 +2976,7 @@ struct VarArgAMD64Helper : public VarArgHelper {        Value *A = *ArgIt;        unsigned ArgNo = CS.getArgumentNo(ArgIt);        bool IsFixed = ArgNo < CS.getFunctionType()->getNumParams(); -      bool IsByVal = CS.paramHasAttr(ArgNo + 1, Attribute::ByVal); +      bool IsByVal = CS.paramHasAttr(ArgNo, Attribute::ByVal);        if (IsByVal) {          // ByVal arguments always go to the overflow area.          
// Fixed arguments passed through the overflow area will be stepped @@ -3456,7 +3497,7 @@ struct VarArgPowerPC64Helper : public VarArgHelper {        Value *A = *ArgIt;        unsigned ArgNo = CS.getArgumentNo(ArgIt);        bool IsFixed = ArgNo < CS.getFunctionType()->getNumParams(); -      bool IsByVal = CS.paramHasAttr(ArgNo + 1, Attribute::ByVal); +      bool IsByVal = CS.paramHasAttr(ArgNo, Attribute::ByVal);        if (IsByVal) {          assert(A->getType()->isPointerTy());          Type *RealTy = A->getType()->getPointerElementType(); @@ -3618,9 +3659,9 @@ bool MemorySanitizer::runOnFunction(Function &F) {    AttrBuilder B;    B.addAttribute(Attribute::ReadOnly)      .addAttribute(Attribute::ReadNone); -  F.removeAttributes(AttributeSet::FunctionIndex, -                     AttributeSet::get(F.getContext(), -                                       AttributeSet::FunctionIndex, B)); +  F.removeAttributes( +      AttributeList::FunctionIndex, +      AttributeList::get(F.getContext(), AttributeList::FunctionIndex, B));    return Visitor.runOnFunction();  } diff --git a/lib/Transforms/Instrumentation/PGOInstrumentation.cpp b/lib/Transforms/Instrumentation/PGOInstrumentation.cpp index 04f9a64bef9f..990bcec109de 100644 --- a/lib/Transforms/Instrumentation/PGOInstrumentation.cpp +++ b/lib/Transforms/Instrumentation/PGOInstrumentation.cpp @@ -58,8 +58,10 @@  #include "llvm/Analysis/BranchProbabilityInfo.h"  #include "llvm/Analysis/CFG.h"  #include "llvm/Analysis/IndirectCallSiteVisitor.h" +#include "llvm/Analysis/LoopInfo.h"  #include "llvm/IR/CallSite.h"  #include "llvm/IR/DiagnosticInfo.h" +#include "llvm/IR/Dominators.h"  #include "llvm/IR/GlobalValue.h"  #include "llvm/IR/IRBuilder.h"  #include "llvm/IR/InstIterator.h" @@ -71,7 +73,9 @@  #include "llvm/ProfileData/InstrProfReader.h"  #include "llvm/ProfileData/ProfileCommon.h"  #include "llvm/Support/BranchProbability.h" +#include "llvm/Support/DOTGraphTraits.h"  #include "llvm/Support/Debug.h" +#include "llvm/Support/GraphWriter.h"  #include "llvm/Support/JamCRC.h"  #include "llvm/Transforms/Instrumentation.h"  #include "llvm/Transforms/Utils/BasicBlockUtils.h" @@ -87,6 +91,7 @@ using namespace llvm;  STATISTIC(NumOfPGOInstrument, "Number of edges instrumented.");  STATISTIC(NumOfPGOSelectInsts, "Number of select instruction instrumented."); +STATISTIC(NumOfPGOMemIntrinsics, "Number of mem intrinsics instrumented.");  STATISTIC(NumOfPGOEdge, "Number of edges.");  STATISTIC(NumOfPGOBB, "Number of basic-blocks.");  STATISTIC(NumOfPGOSplit, "Number of critical edge splits."); @@ -116,6 +121,13 @@ static cl::opt<unsigned> MaxNumAnnotations(      cl::desc("Max number of annotations for a single indirect "               "call callsite")); +// Command line option to set the maximum number of value annotations +// to write to the metadata for a single memop intrinsic. +static cl::opt<unsigned> MaxNumMemOPAnnotations( +    "memop-max-annotations", cl::init(4), cl::Hidden, cl::ZeroOrMore, +    cl::desc("Max number of preicise value annotations for a single memop" +             "intrinsic")); +  // Command line option to control appending FunctionHash to the name of a COMDAT  // function. This is to avoid the hash mismatch caused by the preinliner.  static cl::opt<bool> DoComdatRenaming( @@ -125,24 +137,59 @@ static cl::opt<bool> DoComdatRenaming(  // Command line option to enable/disable the warning about missing profile  // information. 
-static cl::opt<bool> PGOWarnMissing("pgo-warn-missing-function", -                                     cl::init(false), -                                     cl::Hidden); +static cl::opt<bool> +    PGOWarnMissing("pgo-warn-missing-function", cl::init(false), cl::Hidden, +                   cl::desc("Use this option to turn on/off " +                            "warnings about missing profile data for " +                            "functions."));  // Command line option to enable/disable the warning about a hash mismatch in  // the profile data. -static cl::opt<bool> NoPGOWarnMismatch("no-pgo-warn-mismatch", cl::init(false), -                                       cl::Hidden); +static cl::opt<bool> +    NoPGOWarnMismatch("no-pgo-warn-mismatch", cl::init(false), cl::Hidden, +                      cl::desc("Use this option to turn off/on " +                               "warnings about profile cfg mismatch."));  // Command line option to enable/disable the warning about a hash mismatch in  // the profile data for Comdat functions, which often turns out to be false  // positive due to the pre-instrumentation inline. -static cl::opt<bool> NoPGOWarnMismatchComdat("no-pgo-warn-mismatch-comdat", -                                             cl::init(true), cl::Hidden); +static cl::opt<bool> +    NoPGOWarnMismatchComdat("no-pgo-warn-mismatch-comdat", cl::init(true), +                            cl::Hidden, +                            cl::desc("The option is used to turn on/off " +                                     "warnings about hash mismatch for comdat " +                                     "functions."));  // Command line option to enable/disable select instruction instrumentation. -static cl::opt<bool> PGOInstrSelect("pgo-instr-select", cl::init(true), -                                    cl::Hidden); +static cl::opt<bool> +    PGOInstrSelect("pgo-instr-select", cl::init(true), cl::Hidden, +                   cl::desc("Use this option to turn on/off SELECT " +                            "instruction instrumentation. ")); + +// Command line option to turn on CFG dot dump of raw profile counts +static cl::opt<bool> +    PGOViewRawCounts("pgo-view-raw-counts", cl::init(false), cl::Hidden, +                     cl::desc("A boolean option to show CFG dag " +                              "with raw profile counts from " +                              "profile data. See also option " +                              "-pgo-view-counts. To limit graph " +                              "display to only one function, use " +                              "filtering option -view-bfi-func-name.")); + +// Command line option to enable/disable memop intrinsic call.size profiling. +static cl::opt<bool> +    PGOInstrMemOP("pgo-instr-memop", cl::init(true), cl::Hidden, +                  cl::desc("Use this option to turn on/off " +                           "memory instrinsic size profiling.")); + +// Command line option to turn on CFG dot dump after profile annotation. 
+// Defined in Analysis/BlockFrequencyInfo.cpp:  -pgo-view-counts +extern cl::opt<bool> PGOViewCounts; + +// Command line option to specify the name of the function for CFG dump +// Defined in Analysis/BlockFrequencyInfo.cpp:  -view-bfi-func-name= +extern cl::opt<std::string> ViewBlockFreqFuncName; +  namespace {  /// The select instruction visitor plays three roles specified @@ -167,6 +214,7 @@ struct SelectInstVisitor : public InstVisitor<SelectInstVisitor> {    SelectInstVisitor(Function &Func) : F(Func) {}    void countSelects(Function &Func) { +    NSIs = 0;      Mode = VM_counting;      visit(Func);    } @@ -196,9 +244,54 @@ struct SelectInstVisitor : public InstVisitor<SelectInstVisitor> {    void annotateOneSelectInst(SelectInst &SI);    // Visit \p SI instruction and perform tasks according to visit mode.    void visitSelectInst(SelectInst &SI); +  // Return the number of select instructions. This needs be called after +  // countSelects().    unsigned getNumOfSelectInsts() const { return NSIs; }  }; +/// Instruction Visitor class to visit memory intrinsic calls. +struct MemIntrinsicVisitor : public InstVisitor<MemIntrinsicVisitor> { +  Function &F; +  unsigned NMemIs = 0;          // Number of memIntrinsics instrumented. +  VisitMode Mode = VM_counting; // Visiting mode. +  unsigned CurCtrId = 0;        // Current counter index. +  unsigned TotalNumCtrs = 0;    // Total number of counters +  GlobalVariable *FuncNameVar = nullptr; +  uint64_t FuncHash = 0; +  PGOUseFunc *UseFunc = nullptr; +  std::vector<Instruction *> Candidates; + +  MemIntrinsicVisitor(Function &Func) : F(Func) {} + +  void countMemIntrinsics(Function &Func) { +    NMemIs = 0; +    Mode = VM_counting; +    visit(Func); +  } + +  void instrumentMemIntrinsics(Function &Func, unsigned TotalNC, +                               GlobalVariable *FNV, uint64_t FHash) { +    Mode = VM_instrument; +    TotalNumCtrs = TotalNC; +    FuncHash = FHash; +    FuncNameVar = FNV; +    visit(Func); +  } + +  std::vector<Instruction *> findMemIntrinsics(Function &Func) { +    Candidates.clear(); +    Mode = VM_annotate; +    visit(Func); +    return Candidates; +  } + +  // Visit the IR stream and annotate all mem intrinsic call instructions. +  void instrumentOneMemIntrinsic(MemIntrinsic &MI); +  // Visit \p MI instruction and perform tasks according to visit mode. +  void visitMemIntrinsic(MemIntrinsic &SI); +  unsigned getNumOfMemIntrinsics() const { return NMemIs; } +}; +  class PGOInstrumentationGenLegacyPass : public ModulePass {  public:    static char ID; @@ -316,8 +409,9 @@ private:    std::unordered_multimap<Comdat *, GlobalValue *> &ComdatMembers;  public: -  std::vector<Instruction *> IndirectCallSites; +  std::vector<std::vector<Instruction *>> ValueSites;    SelectInstVisitor SIVisitor; +  MemIntrinsicVisitor MIVisitor;    std::string FuncName;    GlobalVariable *FuncNameVar;    // CFG hash value for this function. @@ -347,13 +441,16 @@ public:        std::unordered_multimap<Comdat *, GlobalValue *> &ComdatMembers,        bool CreateGlobalVar = false, BranchProbabilityInfo *BPI = nullptr,        BlockFrequencyInfo *BFI = nullptr) -      : F(Func), ComdatMembers(ComdatMembers), SIVisitor(Func), FunctionHash(0), -        MST(F, BPI, BFI) { +      : F(Func), ComdatMembers(ComdatMembers), ValueSites(IPVK_Last + 1), +        SIVisitor(Func), MIVisitor(Func), FunctionHash(0), MST(F, BPI, BFI) {      // This should be done before CFG hash computation.      
SIVisitor.countSelects(Func); +    MIVisitor.countMemIntrinsics(Func);      NumOfPGOSelectInsts += SIVisitor.getNumOfSelectInsts(); -    IndirectCallSites = findIndirectCallSites(Func); +    NumOfPGOMemIntrinsics += MIVisitor.getNumOfMemIntrinsics(); +    ValueSites[IPVK_IndirectCallTarget] = findIndirectCallSites(Func); +    ValueSites[IPVK_MemOPSize] = MIVisitor.findMemIntrinsics(Func);      FuncName = getPGOFuncName(F);      computeCFGHash(); @@ -405,7 +502,7 @@ void FuncPGOInstrumentation<Edge, BBInfo>::computeCFGHash() {    }    JC.update(Indexes);    FunctionHash = (uint64_t)SIVisitor.getNumOfSelectInsts() << 56 | -                 (uint64_t)IndirectCallSites.size() << 48 | +                 (uint64_t)ValueSites[IPVK_IndirectCallTarget].size() << 48 |                   (uint64_t)MST.AllEdges.size() << 32 | JC.getCRC();  } @@ -552,7 +649,7 @@ static void instrumentOneFunc(      return;    unsigned NumIndirectCallSites = 0; -  for (auto &I : FuncInfo.IndirectCallSites) { +  for (auto &I : FuncInfo.ValueSites[IPVK_IndirectCallTarget]) {      CallSite CS(I);      Value *Callee = CS.getCalledValue();      DEBUG(dbgs() << "Instrument one indirect call: CallSite Index = " @@ -565,10 +662,14 @@          {llvm::ConstantExpr::getBitCast(FuncInfo.FuncNameVar, I8PtrTy),           Builder.getInt64(FuncInfo.FunctionHash),           Builder.CreatePtrToInt(Callee, Builder.getInt64Ty()), -         Builder.getInt32(llvm::InstrProfValueKind::IPVK_IndirectCallTarget), +         Builder.getInt32(IPVK_IndirectCallTarget),           Builder.getInt32(NumIndirectCallSites++)});    }    NumOfPGOICall += NumIndirectCallSites; + +  // Now instrument memop intrinsic calls. +  FuncInfo.MIVisitor.instrumentMemIntrinsics( +      F, NumCounters, FuncInfo.FuncNameVar, FuncInfo.FunctionHash);  }  // This class represents a CFG edge in profile use compilation. @@ -653,8 +754,11 @@ public:    // Set the branch weights based on the count values.    void setBranchWeights(); -  // Annotate the indirect call sites. -  void annotateIndirectCallSites(); +  // Annotate the value profile call sites for all value kinds. +  void annotateValueSites(); + +  // Annotate the value profile call sites for one value kind. +  void annotateValueSites(uint32_t Kind);    // The hotness of the function from the profile count.
enum FuncFreqAttr { FFA_Normal, FFA_Cold, FFA_Hot }; @@ -677,6 +781,8 @@ public:      return FuncInfo.findBBInfo(BB);    } +  Function &getFunc() const { return F; } +  private:    Function &F;    Module *M; @@ -761,7 +867,7 @@ void PGOUseFunc::setInstrumentedCounts(      NewEdge1.InMST = true;      getBBInfo(InstrBB).setBBInfoCount(CountValue);    } -  ProfileCountSize =  CountFromProfile.size(); +  ProfileCountSize = CountFromProfile.size();    CountPosition = I;  } @@ -932,21 +1038,6 @@ void PGOUseFunc::populateCounters() {    DEBUG(FuncInfo.dumpInfo("after reading profile."));  } -static void setProfMetadata(Module *M, Instruction *TI, -                            ArrayRef<uint64_t> EdgeCounts, uint64_t MaxCount) { -  MDBuilder MDB(M->getContext()); -  assert(MaxCount > 0 && "Bad max count"); -  uint64_t Scale = calculateCountScale(MaxCount); -  SmallVector<unsigned, 4> Weights; -  for (const auto &ECI : EdgeCounts) -    Weights.push_back(scaleBranchCount(ECI, Scale)); - -  DEBUG(dbgs() << "Weight is: "; -        for (const auto &W : Weights) { dbgs() << W << " "; }  -        dbgs() << "\n";); -  TI->setMetadata(llvm::LLVMContext::MD_prof, MDB.createBranchWeights(Weights)); -} -  // Assign the scaled count values to the BB with multiple out edges.  void PGOUseFunc::setBranchWeights() {    // Generate MD_prof metadata for every branch instruction. @@ -990,8 +1081,8 @@ void SelectInstVisitor::instrumentOneSelectInst(SelectInst &SI) {    Builder.CreateCall(        Intrinsic::getDeclaration(M, Intrinsic::instrprof_increment_step),        {llvm::ConstantExpr::getBitCast(FuncNameVar, I8PtrTy), -       Builder.getInt64(FuncHash), -       Builder.getInt32(TotalNumCtrs), Builder.getInt32(*CurCtrIdx), Step}); +       Builder.getInt64(FuncHash), Builder.getInt32(TotalNumCtrs), +       Builder.getInt32(*CurCtrIdx), Step});    ++(*CurCtrIdx);  } @@ -1020,9 +1111,9 @@ void SelectInstVisitor::visitSelectInst(SelectInst &SI) {    if (SI.getCondition()->getType()->isVectorTy())      return; -  NSIs++;    switch (Mode) {    case VM_counting: +    NSIs++;      return;    case VM_instrument:      instrumentOneSelectInst(SI); @@ -1035,35 +1126,79 @@ void SelectInstVisitor::visitSelectInst(SelectInst &SI) {    llvm_unreachable("Unknown visiting mode");  } -// Traverse all the indirect callsites and annotate the instructions. -void PGOUseFunc::annotateIndirectCallSites() { +void MemIntrinsicVisitor::instrumentOneMemIntrinsic(MemIntrinsic &MI) { +  Module *M = F.getParent(); +  IRBuilder<> Builder(&MI); +  Type *Int64Ty = Builder.getInt64Ty(); +  Type *I8PtrTy = Builder.getInt8PtrTy(); +  Value *Length = MI.getLength(); +  assert(!dyn_cast<ConstantInt>(Length)); +  Builder.CreateCall( +      Intrinsic::getDeclaration(M, Intrinsic::instrprof_value_profile), +      {llvm::ConstantExpr::getBitCast(FuncNameVar, I8PtrTy), +       Builder.getInt64(FuncHash), Builder.CreatePtrToInt(Length, Int64Ty), +       Builder.getInt32(IPVK_MemOPSize), Builder.getInt32(CurCtrId)}); +  ++CurCtrId; +} + +void MemIntrinsicVisitor::visitMemIntrinsic(MemIntrinsic &MI) { +  if (!PGOInstrMemOP) +    return; +  Value *Length = MI.getLength(); +  // Not instrument constant length calls. 
+  if (dyn_cast<ConstantInt>(Length)) +    return; + +  switch (Mode) { +  case VM_counting: +    NMemIs++; +    return; +  case VM_instrument: +    instrumentOneMemIntrinsic(MI); +    return; +  case VM_annotate: +    Candidates.push_back(&MI); +    return; +  } +  llvm_unreachable("Unknown visiting mode"); +} + +// Traverse all valuesites and annotate the instructions for all value kind. +void PGOUseFunc::annotateValueSites() {    if (DisableValueProfiling)      return;    // Create the PGOFuncName meta data.    createPGOFuncNameMetadata(F, FuncInfo.FuncName); -  unsigned IndirectCallSiteIndex = 0; -  auto &IndirectCallSites = FuncInfo.IndirectCallSites; -  unsigned NumValueSites = -      ProfileRecord.getNumValueSites(IPVK_IndirectCallTarget); -  if (NumValueSites != IndirectCallSites.size()) { -    std::string Msg = -        std::string("Inconsistent number of indirect call sites: ") + -        F.getName().str(); +  for (uint32_t Kind = IPVK_First; Kind <= IPVK_Last; ++Kind) +    annotateValueSites(Kind); +} + +// Annotate the instructions for a specific value kind. +void PGOUseFunc::annotateValueSites(uint32_t Kind) { +  unsigned ValueSiteIndex = 0; +  auto &ValueSites = FuncInfo.ValueSites[Kind]; +  unsigned NumValueSites = ProfileRecord.getNumValueSites(Kind); +  if (NumValueSites != ValueSites.size()) {      auto &Ctx = M->getContext(); -    Ctx.diagnose( -        DiagnosticInfoPGOProfile(M->getName().data(), Msg, DS_Warning)); +    Ctx.diagnose(DiagnosticInfoPGOProfile( +        M->getName().data(), +        Twine("Inconsistent number of value sites for kind = ") + Twine(Kind) + +            " in " + F.getName().str(), +        DS_Warning));      return;    } -  for (auto &I : IndirectCallSites) { -    DEBUG(dbgs() << "Read one indirect call instrumentation: Index=" -                 << IndirectCallSiteIndex << " out of " << NumValueSites -                 << "\n"); -    annotateValueSite(*M, *I, ProfileRecord, IPVK_IndirectCallTarget, -                      IndirectCallSiteIndex, MaxNumAnnotations); -    IndirectCallSiteIndex++; +  for (auto &I : ValueSites) { +    DEBUG(dbgs() << "Read one value site profile (kind = " << Kind +                 << "): Index = " << ValueSiteIndex << " out of " +                 << NumValueSites << "\n"); +    annotateValueSite(*M, *I, ProfileRecord, +                      static_cast<InstrProfValueKind>(Kind), ValueSiteIndex, +                      Kind == IPVK_MemOPSize ? 
MaxNumMemOPAnnotations +                                             : MaxNumAnnotations); +    ValueSiteIndex++;    }  }  } // end anonymous namespace @@ -1196,12 +1331,29 @@ static bool annotateAllFunctions(        continue;      Func.populateCounters();      Func.setBranchWeights(); -    Func.annotateIndirectCallSites(); +    Func.annotateValueSites();      PGOUseFunc::FuncFreqAttr FreqAttr = Func.getFuncFreqAttr();      if (FreqAttr == PGOUseFunc::FFA_Cold)        ColdFunctions.push_back(&F);      else if (FreqAttr == PGOUseFunc::FFA_Hot)        HotFunctions.push_back(&F); +    if (PGOViewCounts && (ViewBlockFreqFuncName.empty() || +                          F.getName().equals(ViewBlockFreqFuncName))) { +      LoopInfo LI{DominatorTree(F)}; +      std::unique_ptr<BranchProbabilityInfo> NewBPI = +          llvm::make_unique<BranchProbabilityInfo>(F, LI); +      std::unique_ptr<BlockFrequencyInfo> NewBFI = +          llvm::make_unique<BlockFrequencyInfo>(F, *NewBPI, LI); + +      NewBFI->view(); +    } +    if (PGOViewRawCounts && (ViewBlockFreqFuncName.empty() || +                             F.getName().equals(ViewBlockFreqFuncName))) { +      if (ViewBlockFreqFuncName.empty()) +        WriteGraph(&Func, Twine("PGORawCounts_") + Func.getFunc().getName()); +      else +        ViewGraph(&Func, Twine("PGORawCounts_") + Func.getFunc().getName()); +    }    }    M.setProfileSummary(PGOReader->getSummary().getMD(M.getContext()));    // Set function hotness attribute from the profile. @@ -1257,3 +1409,90 @@ bool PGOInstrumentationUseLegacyPass::runOnModule(Module &M) {    return annotateAllFunctions(M, ProfileFileName, LookupBPI, LookupBFI);  } + +namespace llvm { +void setProfMetadata(Module *M, Instruction *TI, ArrayRef<uint64_t> EdgeCounts, +                     uint64_t MaxCount) { +  MDBuilder MDB(M->getContext()); +  assert(MaxCount > 0 && "Bad max count"); +  uint64_t Scale = calculateCountScale(MaxCount); +  SmallVector<unsigned, 4> Weights; +  for (const auto &ECI : EdgeCounts) +    Weights.push_back(scaleBranchCount(ECI, Scale)); + +  DEBUG(dbgs() << "Weight is: "; +        for (const auto &W : Weights) { dbgs() << W << " "; } +        dbgs() << "\n";); +  TI->setMetadata(llvm::LLVMContext::MD_prof, MDB.createBranchWeights(Weights)); +} + +template <> struct GraphTraits<PGOUseFunc *> { +  typedef const BasicBlock *NodeRef; +  typedef succ_const_iterator ChildIteratorType; +  typedef pointer_iterator<Function::const_iterator> nodes_iterator; + +  static NodeRef getEntryNode(const PGOUseFunc *G) { +    return &G->getFunc().front(); +  } +  static ChildIteratorType child_begin(const NodeRef N) { +    return succ_begin(N); +  } +  static ChildIteratorType child_end(const NodeRef N) { return succ_end(N); } +  static nodes_iterator nodes_begin(const PGOUseFunc *G) { +    return nodes_iterator(G->getFunc().begin()); +  } +  static nodes_iterator nodes_end(const PGOUseFunc *G) { +    return nodes_iterator(G->getFunc().end()); +  } +}; + +static std::string getSimpleNodeName(const BasicBlock *Node) { +  if (!Node->getName().empty()) +    return Node->getName(); + +  std::string SimpleNodeName; +  raw_string_ostream OS(SimpleNodeName); +  Node->printAsOperand(OS, false); +  return OS.str(); +} + +template <> struct DOTGraphTraits<PGOUseFunc *> : DefaultDOTGraphTraits { +  explicit DOTGraphTraits(bool isSimple = false) +      : DefaultDOTGraphTraits(isSimple) {} + +  static std::string getGraphName(const PGOUseFunc *G) { +    return G->getFunc().getName(); +  } + +  std::string 
getNodeLabel(const BasicBlock *Node, const PGOUseFunc *Graph) { +    std::string Result; +    raw_string_ostream OS(Result); + +    OS << getSimpleNodeName(Node) << ":\\l"; +    UseBBInfo *BI = Graph->findBBInfo(Node); +    OS << "Count : "; +    if (BI && BI->CountValid) +      OS << BI->CountValue << "\\l"; +    else +      OS << "Unknown\\l"; + +    if (!PGOInstrSelect) +      return Result; + +    for (auto BI = Node->begin(); BI != Node->end(); ++BI) { +      auto *I = &*BI; +      if (!isa<SelectInst>(I)) +        continue; +      // Display scaled counts for SELECT instruction: +      OS << "SELECT : { T = "; +      uint64_t TC, FC; +      bool HasProf = I->extractProfMetadata(TC, FC); +      if (!HasProf) +        OS << "Unknown, F = Unknown }\\l"; +      else +        OS << TC << ", F = " << FC << " }\\l"; +    } +    return Result; +  } +}; +} // namespace llvm diff --git a/lib/Transforms/Instrumentation/SanitizerCoverage.cpp b/lib/Transforms/Instrumentation/SanitizerCoverage.cpp index 5b4b1fb77134..fa0c7cc5a4c5 100644 --- a/lib/Transforms/Instrumentation/SanitizerCoverage.cpp +++ b/lib/Transforms/Instrumentation/SanitizerCoverage.cpp @@ -78,7 +78,6 @@ static const char *const SanCovTraceSwitchName = "__sanitizer_cov_trace_switch";  static const char *const SanCovModuleCtorName = "sancov.module_ctor";  static const uint64_t SanCtorAndDtorPriority = 2; -static const char *const SanCovTracePCGuardSection = "__sancov_guards";  static const char *const SanCovTracePCGuardName =      "__sanitizer_cov_trace_pc_guard";  static const char *const SanCovTracePCGuardInitName = @@ -95,7 +94,7 @@ static cl::opt<unsigned> ClCoverageBlockThreshold(      "sanitizer-coverage-block-threshold",      cl::desc("Use a callback with a guard check inside it if there are"               " more than this number of blocks."), -    cl::Hidden, cl::init(500)); +    cl::Hidden, cl::init(0));  static cl::opt<bool>      ClExperimentalTracing("sanitizer-coverage-experimental-tracing", @@ -216,6 +215,9 @@ private:             SanCovWithCheckFunction->getNumUses() + SanCovTraceBB->getNumUses() +             SanCovTraceEnter->getNumUses();    } +  StringRef getSanCovTracePCGuardSection() const; +  StringRef getSanCovTracePCGuardSectionStart() const; +  StringRef getSanCovTracePCGuardSectionEnd() const;    Function *SanCovFunction;    Function *SanCovWithCheckFunction;    Function *SanCovIndirCallFunction, *SanCovTracePCIndir; @@ -227,6 +229,7 @@ private:    InlineAsm *EmptyAsm;    Type *IntptrTy, *IntptrPtrTy, *Int64Ty, *Int64PtrTy, *Int32Ty, *Int32PtrTy;    Module *CurModule; +  Triple TargetTriple;    LLVMContext *C;    const DataLayout *DL; @@ -246,6 +249,7 @@ bool SanitizerCoverageModule::runOnModule(Module &M) {    C = &(M.getContext());    DL = &M.getDataLayout();    CurModule = &M; +  TargetTriple = Triple(M.getTargetTriple());    HasSancovGuardsSection = false;    IntptrTy = Type::getIntNTy(*C, DL->getPointerSizeInBits());    IntptrPtrTy = PointerType::getUnqual(IntptrTy); @@ -258,39 +262,39 @@ bool SanitizerCoverageModule::runOnModule(Module &M) {    Int32Ty = IRB.getInt32Ty();    SanCovFunction = checkSanitizerInterfaceFunction( -      M.getOrInsertFunction(SanCovName, VoidTy, Int32PtrTy, nullptr)); +      M.getOrInsertFunction(SanCovName, VoidTy, Int32PtrTy));    SanCovWithCheckFunction = checkSanitizerInterfaceFunction( -      M.getOrInsertFunction(SanCovWithCheckName, VoidTy, Int32PtrTy, nullptr)); +      M.getOrInsertFunction(SanCovWithCheckName, VoidTy, Int32PtrTy));    SanCovTracePCIndir = 
checkSanitizerInterfaceFunction( -      M.getOrInsertFunction(SanCovTracePCIndirName, VoidTy, IntptrTy, nullptr)); +      M.getOrInsertFunction(SanCovTracePCIndirName, VoidTy, IntptrTy));    SanCovIndirCallFunction =        checkSanitizerInterfaceFunction(M.getOrInsertFunction( -          SanCovIndirCallName, VoidTy, IntptrTy, IntptrTy, nullptr)); +          SanCovIndirCallName, VoidTy, IntptrTy, IntptrTy));    SanCovTraceCmpFunction[0] =        checkSanitizerInterfaceFunction(M.getOrInsertFunction( -          SanCovTraceCmp1, VoidTy, IRB.getInt8Ty(), IRB.getInt8Ty(), nullptr)); +          SanCovTraceCmp1, VoidTy, IRB.getInt8Ty(), IRB.getInt8Ty()));    SanCovTraceCmpFunction[1] = checkSanitizerInterfaceFunction(        M.getOrInsertFunction(SanCovTraceCmp2, VoidTy, IRB.getInt16Ty(), -                            IRB.getInt16Ty(), nullptr)); +                            IRB.getInt16Ty()));    SanCovTraceCmpFunction[2] = checkSanitizerInterfaceFunction(        M.getOrInsertFunction(SanCovTraceCmp4, VoidTy, IRB.getInt32Ty(), -                            IRB.getInt32Ty(), nullptr)); +                            IRB.getInt32Ty()));    SanCovTraceCmpFunction[3] =        checkSanitizerInterfaceFunction(M.getOrInsertFunction( -          SanCovTraceCmp8, VoidTy, Int64Ty, Int64Ty, nullptr)); +          SanCovTraceCmp8, VoidTy, Int64Ty, Int64Ty));    SanCovTraceDivFunction[0] =        checkSanitizerInterfaceFunction(M.getOrInsertFunction( -          SanCovTraceDiv4, VoidTy, IRB.getInt32Ty(), nullptr)); +          SanCovTraceDiv4, VoidTy, IRB.getInt32Ty()));    SanCovTraceDivFunction[1] =        checkSanitizerInterfaceFunction(M.getOrInsertFunction( -          SanCovTraceDiv8, VoidTy, Int64Ty, nullptr)); +          SanCovTraceDiv8, VoidTy, Int64Ty));    SanCovTraceGepFunction =        checkSanitizerInterfaceFunction(M.getOrInsertFunction( -          SanCovTraceGep, VoidTy, IntptrTy, nullptr)); +          SanCovTraceGep, VoidTy, IntptrTy));    SanCovTraceSwitchFunction =        checkSanitizerInterfaceFunction(M.getOrInsertFunction( -          SanCovTraceSwitchName, VoidTy, Int64Ty, Int64PtrTy, nullptr)); +          SanCovTraceSwitchName, VoidTy, Int64Ty, Int64PtrTy));    // We insert an empty inline asm after cov callbacks to avoid callback merge.    EmptyAsm = InlineAsm::get(FunctionType::get(IRB.getVoidTy(), false), @@ -298,13 +302,13 @@ bool SanitizerCoverageModule::runOnModule(Module &M) {                              /*hasSideEffects=*/true);    SanCovTracePC = checkSanitizerInterfaceFunction( -      M.getOrInsertFunction(SanCovTracePCName, VoidTy, nullptr)); +      M.getOrInsertFunction(SanCovTracePCName, VoidTy));    SanCovTracePCGuard = checkSanitizerInterfaceFunction(M.getOrInsertFunction( -      SanCovTracePCGuardName, VoidTy, Int32PtrTy, nullptr)); +      SanCovTracePCGuardName, VoidTy, Int32PtrTy));    SanCovTraceEnter = checkSanitizerInterfaceFunction( -      M.getOrInsertFunction(SanCovTraceEnterName, VoidTy, Int32PtrTy, nullptr)); +      M.getOrInsertFunction(SanCovTraceEnterName, VoidTy, Int32PtrTy));    SanCovTraceBB = checkSanitizerInterfaceFunction( -      M.getOrInsertFunction(SanCovTraceBBName, VoidTy, Int32PtrTy, nullptr)); +      M.getOrInsertFunction(SanCovTraceBBName, VoidTy, Int32PtrTy));    // At this point we create a dummy array of guards because we don't    // know how many elements we will need. 
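A note on the mechanical churn in this hunk and in the MemorySanitizer.cpp hunks above: every change of the form ", nullptr);" to ");" drops the trailing nullptr sentinel from Module::getOrInsertFunction, whose parameter-type list apparently no longer needs to be null-terminated. A minimal before/after sketch using a declaration taken from this file, assuming only a Module to work with:

    #include "llvm/IR/DerivedTypes.h"
    #include "llvm/IR/IRBuilder.h"
    #include "llvm/IR/Module.h"

    using namespace llvm;

    // Declare void __sanitizer_cov_trace_pc_guard(i32*) against the updated API.
    static Constant *declareTracePCGuard(Module &M) {
      IRBuilder<> IRB(M.getContext());
      Type *Int32PtrTy = PointerType::getUnqual(IRB.getInt32Ty());
      // Before: M.getOrInsertFunction(Name, VoidTy, Int32PtrTy, nullptr);
      // After:  no sentinel; the parameter types are passed directly.
      return M.getOrInsertFunction("__sanitizer_cov_trace_pc_guard",
                                   IRB.getVoidTy(), Int32PtrTy);
    }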
@@ -363,22 +367,28 @@ bool SanitizerCoverageModule::runOnModule(Module &M) {    if (Options.TracePCGuard) {      if (HasSancovGuardsSection) {        Function *CtorFunc; -      std::string SectionName(SanCovTracePCGuardSection); -      GlobalVariable *Bounds[2]; -      const char *Prefix[2] = {"__start_", "__stop_"}; -      for (int i = 0; i < 2; i++) { -        Bounds[i] = new GlobalVariable(M, Int32PtrTy, false, -                                       GlobalVariable::ExternalLinkage, nullptr, -                                       Prefix[i] + SectionName); -        Bounds[i]->setVisibility(GlobalValue::HiddenVisibility); -      } +      GlobalVariable *SecStart = new GlobalVariable( +          M, Int32PtrTy, false, GlobalVariable::ExternalLinkage, nullptr, +          getSanCovTracePCGuardSectionStart()); +      SecStart->setVisibility(GlobalValue::HiddenVisibility); +      GlobalVariable *SecEnd = new GlobalVariable( +          M, Int32PtrTy, false, GlobalVariable::ExternalLinkage, nullptr, +          getSanCovTracePCGuardSectionEnd()); +      SecEnd->setVisibility(GlobalValue::HiddenVisibility); +        std::tie(CtorFunc, std::ignore) = createSanitizerCtorAndInitFunctions(            M, SanCovModuleCtorName, SanCovTracePCGuardInitName,            {Int32PtrTy, Int32PtrTy}, -          {IRB.CreatePointerCast(Bounds[0], Int32PtrTy), -            IRB.CreatePointerCast(Bounds[1], Int32PtrTy)}); - -      appendToGlobalCtors(M, CtorFunc, SanCtorAndDtorPriority); +          {IRB.CreatePointerCast(SecStart, Int32PtrTy), +            IRB.CreatePointerCast(SecEnd, Int32PtrTy)}); + +      if (TargetTriple.supportsCOMDAT()) { +        // Use comdat to dedup CtorFunc. +        CtorFunc->setComdat(M.getOrInsertComdat(SanCovModuleCtorName)); +        appendToGlobalCtors(M, CtorFunc, SanCtorAndDtorPriority, CtorFunc); +      } else { +        appendToGlobalCtors(M, CtorFunc, SanCtorAndDtorPriority); +      }      }    } else if (!Options.TracePC) {      Function *CtorFunc; @@ -435,6 +445,11 @@ static bool shouldInstrumentBlock(const Function& F, const BasicBlock *BB, const    if (isa<UnreachableInst>(BB->getTerminator()))      return false; +  // Don't insert coverage into blocks without a valid insertion point +  // (catchswitch blocks). 
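On ELF the SecStart/SecEnd globals created above resolve to the linker-synthesized __start___sancov_guards and __stop___sancov_guards symbols (the Mach-O spellings come from the helpers added further down), and the generated sancov.module_ctor hands them to __sanitizer_cov_trace_pc_guard_init. A sketch of a typical consumer of that documented trace-pc-guard interface; the guard-numbering scheme is just an example, not something this patch emits:

#include <cstdint>
#include <cstdio>

// Called from sancov.module_ctor, once per module, with the bounds of the
// guard section.
extern "C" void __sanitizer_cov_trace_pc_guard_init(uint32_t *start,
                                                    uint32_t *stop) {
  static uint32_t N = 0;          // counter shared across modules
  if (start == stop || *start)    // empty section or already initialized
    return;
  for (uint32_t *x = start; x < stop; x++)
    *x = ++N;                     // give every guard a distinct id
  printf("guard section: %p .. %p\n", (void *)start, (void *)stop);
}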
+  if (BB->getFirstInsertionPt() == BB->end()) +    return false; +    if (!ClPruneBlocks || &F.getEntryBlock() == BB)      return true; @@ -517,7 +532,7 @@ void SanitizerCoverageModule::CreateFunctionGuardArray(size_t NumGuards,        Constant::getNullValue(ArrayOfInt32Ty), "__sancov_gen_");    if (auto Comdat = F.getComdat())      FunctionGuardArray->setComdat(Comdat); -  FunctionGuardArray->setSection(SanCovTracePCGuardSection); +  FunctionGuardArray->setSection(getSanCovTracePCGuardSection());  }  bool SanitizerCoverageModule::InjectCoverage(Function &F, @@ -755,6 +770,27 @@ void SanitizerCoverageModule::InjectCoverageAtBlock(Function &F, BasicBlock &BB,    }  } +StringRef SanitizerCoverageModule::getSanCovTracePCGuardSection() const { +  if (TargetTriple.getObjectFormat() == Triple::COFF) +    return ".SCOV$M"; +  if (TargetTriple.isOSBinFormatMachO()) +    return "__DATA,__sancov_guards"; +  return "__sancov_guards"; +} + +StringRef SanitizerCoverageModule::getSanCovTracePCGuardSectionStart() const { +  if (TargetTriple.isOSBinFormatMachO()) +    return "\1section$start$__DATA$__sancov_guards"; +  return "__start___sancov_guards"; +} + +StringRef SanitizerCoverageModule::getSanCovTracePCGuardSectionEnd() const { +  if (TargetTriple.isOSBinFormatMachO()) +    return "\1section$end$__DATA$__sancov_guards"; +  return "__stop___sancov_guards"; +} + +  char SanitizerCoverageModule::ID = 0;  INITIALIZE_PASS_BEGIN(SanitizerCoverageModule, "sancov",                        "SanitizerCoverage: TODO." diff --git a/lib/Transforms/Instrumentation/ThreadSanitizer.cpp b/lib/Transforms/Instrumentation/ThreadSanitizer.cpp index 52035c79a4a3..9260217bd5e6 100644 --- a/lib/Transforms/Instrumentation/ThreadSanitizer.cpp +++ b/lib/Transforms/Instrumentation/ThreadSanitizer.cpp @@ -155,17 +155,18 @@ FunctionPass *llvm::createThreadSanitizerPass() {  void ThreadSanitizer::initializeCallbacks(Module &M) {    IRBuilder<> IRB(M.getContext()); -  AttributeSet Attr; -  Attr = Attr.addAttribute(M.getContext(), AttributeSet::FunctionIndex, Attribute::NoUnwind); +  AttributeList Attr; +  Attr = Attr.addAttribute(M.getContext(), AttributeList::FunctionIndex, +                           Attribute::NoUnwind);    // Initialize the callbacks.    
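The ThreadSanitizer callbacks declared next use the AttributeList that replaces AttributeSet above, alongside the same getOrInsertFunction cleanup. A small sketch of the new attribute idiom in isolation; the helper name is invented for illustration:

#include "llvm/IR/Attributes.h"
#include "llvm/IR/LLVMContext.h"
using namespace llvm;

// Build the function-level attribute list the TSan callbacks are declared
// with: just NoUnwind, attached at the function index.
static AttributeList makeNoUnwindAttrs(LLVMContext &Ctx) {
  AttributeList Attr;
  Attr = Attr.addAttribute(Ctx, AttributeList::FunctionIndex,
                           Attribute::NoUnwind);
  return Attr;
}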
TsanFuncEntry = checkSanitizerInterfaceFunction(M.getOrInsertFunction( -      "__tsan_func_entry", Attr, IRB.getVoidTy(), IRB.getInt8PtrTy(), nullptr)); +      "__tsan_func_entry", Attr, IRB.getVoidTy(), IRB.getInt8PtrTy()));    TsanFuncExit = checkSanitizerInterfaceFunction( -      M.getOrInsertFunction("__tsan_func_exit", Attr, IRB.getVoidTy(), nullptr)); +      M.getOrInsertFunction("__tsan_func_exit", Attr, IRB.getVoidTy()));    TsanIgnoreBegin = checkSanitizerInterfaceFunction(M.getOrInsertFunction( -      "__tsan_ignore_thread_begin", Attr, IRB.getVoidTy(), nullptr)); +      "__tsan_ignore_thread_begin", Attr, IRB.getVoidTy()));    TsanIgnoreEnd = checkSanitizerInterfaceFunction(M.getOrInsertFunction( -      "__tsan_ignore_thread_end", Attr, IRB.getVoidTy(), nullptr)); +      "__tsan_ignore_thread_end", Attr, IRB.getVoidTy()));    OrdTy = IRB.getInt32Ty();    for (size_t i = 0; i < kNumberOfAccessSizes; ++i) {      const unsigned ByteSize = 1U << i; @@ -174,31 +175,31 @@ void ThreadSanitizer::initializeCallbacks(Module &M) {      std::string BitSizeStr = utostr(BitSize);      SmallString<32> ReadName("__tsan_read" + ByteSizeStr);      TsanRead[i] = checkSanitizerInterfaceFunction(M.getOrInsertFunction( -        ReadName, Attr, IRB.getVoidTy(), IRB.getInt8PtrTy(), nullptr)); +        ReadName, Attr, IRB.getVoidTy(), IRB.getInt8PtrTy()));      SmallString<32> WriteName("__tsan_write" + ByteSizeStr);      TsanWrite[i] = checkSanitizerInterfaceFunction(M.getOrInsertFunction( -        WriteName, Attr, IRB.getVoidTy(), IRB.getInt8PtrTy(), nullptr)); +        WriteName, Attr, IRB.getVoidTy(), IRB.getInt8PtrTy()));      SmallString<64> UnalignedReadName("__tsan_unaligned_read" + ByteSizeStr);      TsanUnalignedRead[i] =          checkSanitizerInterfaceFunction(M.getOrInsertFunction( -            UnalignedReadName, Attr, IRB.getVoidTy(), IRB.getInt8PtrTy(), nullptr)); +            UnalignedReadName, Attr, IRB.getVoidTy(), IRB.getInt8PtrTy()));      SmallString<64> UnalignedWriteName("__tsan_unaligned_write" + ByteSizeStr);      TsanUnalignedWrite[i] =          checkSanitizerInterfaceFunction(M.getOrInsertFunction( -            UnalignedWriteName, Attr, IRB.getVoidTy(), IRB.getInt8PtrTy(), nullptr)); +            UnalignedWriteName, Attr, IRB.getVoidTy(), IRB.getInt8PtrTy()));      Type *Ty = Type::getIntNTy(M.getContext(), BitSize);      Type *PtrTy = Ty->getPointerTo();      SmallString<32> AtomicLoadName("__tsan_atomic" + BitSizeStr + "_load");      TsanAtomicLoad[i] = checkSanitizerInterfaceFunction( -        M.getOrInsertFunction(AtomicLoadName, Attr, Ty, PtrTy, OrdTy, nullptr)); +        M.getOrInsertFunction(AtomicLoadName, Attr, Ty, PtrTy, OrdTy));      SmallString<32> AtomicStoreName("__tsan_atomic" + BitSizeStr + "_store");      TsanAtomicStore[i] = checkSanitizerInterfaceFunction(M.getOrInsertFunction( -        AtomicStoreName, Attr, IRB.getVoidTy(), PtrTy, Ty, OrdTy, nullptr)); +        AtomicStoreName, Attr, IRB.getVoidTy(), PtrTy, Ty, OrdTy));      for (int op = AtomicRMWInst::FIRST_BINOP;          op <= AtomicRMWInst::LAST_BINOP; ++op) { @@ -222,33 +223,33 @@ void ThreadSanitizer::initializeCallbacks(Module &M) {          continue;        SmallString<32> RMWName("__tsan_atomic" + itostr(BitSize) + NamePart);        TsanAtomicRMW[op][i] = checkSanitizerInterfaceFunction( -          M.getOrInsertFunction(RMWName, Attr, Ty, PtrTy, Ty, OrdTy, nullptr)); +          M.getOrInsertFunction(RMWName, Attr, Ty, PtrTy, Ty, OrdTy));      }      SmallString<32> AtomicCASName("__tsan_atomic" + 
BitSizeStr +                                    "_compare_exchange_val");      TsanAtomicCAS[i] = checkSanitizerInterfaceFunction(M.getOrInsertFunction( -        AtomicCASName, Attr, Ty, PtrTy, Ty, Ty, OrdTy, OrdTy, nullptr)); +        AtomicCASName, Attr, Ty, PtrTy, Ty, Ty, OrdTy, OrdTy));    }    TsanVptrUpdate = checkSanitizerInterfaceFunction(        M.getOrInsertFunction("__tsan_vptr_update", Attr, IRB.getVoidTy(), -                            IRB.getInt8PtrTy(), IRB.getInt8PtrTy(), nullptr)); +                            IRB.getInt8PtrTy(), IRB.getInt8PtrTy()));    TsanVptrLoad = checkSanitizerInterfaceFunction(M.getOrInsertFunction( -      "__tsan_vptr_read", Attr, IRB.getVoidTy(), IRB.getInt8PtrTy(), nullptr)); +      "__tsan_vptr_read", Attr, IRB.getVoidTy(), IRB.getInt8PtrTy()));    TsanAtomicThreadFence = checkSanitizerInterfaceFunction(M.getOrInsertFunction( -      "__tsan_atomic_thread_fence", Attr, IRB.getVoidTy(), OrdTy, nullptr)); +      "__tsan_atomic_thread_fence", Attr, IRB.getVoidTy(), OrdTy));    TsanAtomicSignalFence = checkSanitizerInterfaceFunction(M.getOrInsertFunction( -      "__tsan_atomic_signal_fence", Attr, IRB.getVoidTy(), OrdTy, nullptr)); +      "__tsan_atomic_signal_fence", Attr, IRB.getVoidTy(), OrdTy));    MemmoveFn = checkSanitizerInterfaceFunction(        M.getOrInsertFunction("memmove", Attr, IRB.getInt8PtrTy(), IRB.getInt8PtrTy(), -                            IRB.getInt8PtrTy(), IntptrTy, nullptr)); +                            IRB.getInt8PtrTy(), IntptrTy));    MemcpyFn = checkSanitizerInterfaceFunction(        M.getOrInsertFunction("memcpy", Attr, IRB.getInt8PtrTy(), IRB.getInt8PtrTy(), -                            IRB.getInt8PtrTy(), IntptrTy, nullptr)); +                            IRB.getInt8PtrTy(), IntptrTy));    MemsetFn = checkSanitizerInterfaceFunction(        M.getOrInsertFunction("memset", Attr, IRB.getInt8PtrTy(), IRB.getInt8PtrTy(), -                            IRB.getInt32Ty(), IntptrTy, nullptr)); +                            IRB.getInt32Ty(), IntptrTy));  }  bool ThreadSanitizer::doInitialization(Module &M) { @@ -271,7 +272,7 @@ static bool isVtableAccess(Instruction *I) {  // Do not instrument known races/"benign races" that come from compiler  // instrumentatin. The user has no way of suppressing them. -static bool shouldInstrumentReadWriteFromAddress(Value *Addr) { +static bool shouldInstrumentReadWriteFromAddress(const Module *M, Value *Addr) {    // Peel off GEPs and BitCasts.    Addr = Addr->stripInBoundsOffsets(); @@ -279,8 +280,9 @@ static bool shouldInstrumentReadWriteFromAddress(Value *Addr) {      if (GV->hasSection()) {        StringRef SectionName = GV->getSection();        // Check if the global is in the PGO counters section. 
-      if (SectionName.endswith(getInstrProfCountersSectionName( -            /*AddSegment=*/false))) +      auto OF = Triple(M->getTargetTriple()).getObjectFormat(); +      if (SectionName.endswith( +              getInstrProfSectionName(IPSK_cnts, OF, /*AddSegmentInfo=*/false)))          return false;      } @@ -342,13 +344,13 @@ void ThreadSanitizer::chooseInstructionsToInstrument(    for (Instruction *I : reverse(Local)) {      if (StoreInst *Store = dyn_cast<StoreInst>(I)) {        Value *Addr = Store->getPointerOperand(); -      if (!shouldInstrumentReadWriteFromAddress(Addr)) +      if (!shouldInstrumentReadWriteFromAddress(I->getModule(), Addr))          continue;        WriteTargets.insert(Addr);      } else {        LoadInst *Load = cast<LoadInst>(I);        Value *Addr = Load->getPointerOperand(); -      if (!shouldInstrumentReadWriteFromAddress(Addr)) +      if (!shouldInstrumentReadWriteFromAddress(I->getModule(), Addr))          continue;        if (WriteTargets.count(Addr)) {          // We will write to this temp, so no reason to analyze the read. diff --git a/lib/Transforms/ObjCARC/ARCRuntimeEntryPoints.h b/lib/Transforms/ObjCARC/ARCRuntimeEntryPoints.h index c74827210364..c541fa4c8bee 100644 --- a/lib/Transforms/ObjCARC/ARCRuntimeEntryPoints.h +++ b/lib/Transforms/ObjCARC/ARCRuntimeEntryPoints.h @@ -127,9 +127,8 @@ private:      LLVMContext &C = TheModule->getContext();      Type *Params[] = { PointerType::getUnqual(Type::getInt8Ty(C)) }; -    AttributeSet Attr = -      AttributeSet().addAttribute(C, AttributeSet::FunctionIndex, -                                  Attribute::NoUnwind); +    AttributeList Attr = AttributeList().addAttribute( +        C, AttributeList::FunctionIndex, Attribute::NoUnwind);      FunctionType *Fty = FunctionType::get(Type::getVoidTy(C), Params,                                            /*isVarArg=*/false);      return Decl = TheModule->getOrInsertFunction(Name, Fty, Attr); @@ -144,10 +143,10 @@ private:      Type *I8X = PointerType::getUnqual(Type::getInt8Ty(C));      Type *Params[] = { I8X };      FunctionType *Fty = FunctionType::get(I8X, Params, /*isVarArg=*/false); -    AttributeSet Attr = AttributeSet(); +    AttributeList Attr = AttributeList();      if (NoUnwind) -      Attr = Attr.addAttribute(C, AttributeSet::FunctionIndex, +      Attr = Attr.addAttribute(C, AttributeList::FunctionIndex,                                 Attribute::NoUnwind);      return Decl = TheModule->getOrInsertFunction(Name, Fty, Attr); @@ -162,9 +161,8 @@ private:      Type *I8XX = PointerType::getUnqual(I8X);      Type *Params[] = { I8XX, I8X }; -    AttributeSet Attr = -      AttributeSet().addAttribute(C, AttributeSet::FunctionIndex, -                                  Attribute::NoUnwind); +    AttributeList Attr = AttributeList().addAttribute( +        C, AttributeList::FunctionIndex, Attribute::NoUnwind);      Attr = Attr.addAttribute(C, 1, Attribute::NoCapture);      FunctionType *Fty = FunctionType::get(Type::getVoidTy(C), Params, diff --git a/lib/Transforms/ObjCARC/ObjCARCContract.cpp b/lib/Transforms/ObjCARC/ObjCARCContract.cpp index 23c1f5990ba5..a86eaaec7641 100644 --- a/lib/Transforms/ObjCARC/ObjCARCContract.cpp +++ b/lib/Transforms/ObjCARC/ObjCARCContract.cpp @@ -394,6 +394,7 @@ void ObjCARCContract::tryToContractReleaseIntoStoreStrong(Instruction *Release,    DEBUG(llvm::dbgs() << "        New Store Strong: " << *StoreStrong << "\n"); +  if (&*Iter == Retain) ++Iter;    if (&*Iter == Store) ++Iter;    Store->eraseFromParent();    
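The check added above keeps the caller-owned iterator from being left pointing at the retain call these erase statements are about to delete. A hedged sketch of that skip-then-erase idiom on its own; the helper is hypothetical, not part of the patch:

#include "llvm/IR/BasicBlock.h"
#include "llvm/IR/Instruction.h"
using namespace llvm;

// When erasing Victim while a caller's iterator walks the same block, step
// the iterator past Victim first so it stays valid after the erase.
static void eraseKeepingIterator(Instruction *Victim,
                                 BasicBlock::iterator &Iter) {
  if (&*Iter == Victim)
    ++Iter;
  Victim->eraseFromParent();
}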
Release->eraseFromParent(); diff --git a/lib/Transforms/ObjCARC/ObjCARCOpts.cpp b/lib/Transforms/ObjCARC/ObjCARCOpts.cpp index 136d54a6cb75..3c73376c9906 100644 --- a/lib/Transforms/ObjCARC/ObjCARCOpts.cpp +++ b/lib/Transforms/ObjCARC/ObjCARCOpts.cpp @@ -85,41 +85,6 @@ static const Value *FindSingleUseIdentifiedObject(const Value *Arg) {    return nullptr;  } -/// This is a wrapper around getUnderlyingObjCPtr along the lines of -/// GetUnderlyingObjects except that it returns early when it sees the first -/// alloca. -static inline bool AreAnyUnderlyingObjectsAnAlloca(const Value *V, -                                                   const DataLayout &DL) { -  SmallPtrSet<const Value *, 4> Visited; -  SmallVector<const Value *, 4> Worklist; -  Worklist.push_back(V); -  do { -    const Value *P = Worklist.pop_back_val(); -    P = GetUnderlyingObjCPtr(P, DL); - -    if (isa<AllocaInst>(P)) -      return true; - -    if (!Visited.insert(P).second) -      continue; - -    if (const SelectInst *SI = dyn_cast<const SelectInst>(P)) { -      Worklist.push_back(SI->getTrueValue()); -      Worklist.push_back(SI->getFalseValue()); -      continue; -    } - -    if (const PHINode *PN = dyn_cast<const PHINode>(P)) { -      for (Value *IncValue : PN->incoming_values()) -        Worklist.push_back(IncValue); -      continue; -    } -  } while (!Worklist.empty()); - -  return false; -} - -  /// @}  ///  /// \defgroup ARCOpt ARC Optimization. @@ -481,9 +446,6 @@ namespace {      /// MDKind identifiers.      ARCMDKindCache MDKindCache; -    // This is used to track if a pointer is stored into an alloca. -    DenseSet<const Value *> MultiOwnersSet; -      /// A flag indicating whether this optimization pass should run.      bool Run; @@ -524,8 +486,7 @@ namespace {      PairUpRetainsAndReleases(DenseMap<const BasicBlock *, BBState> &BBStates,                               BlotMapVector<Value *, RRInfo> &Retains,                               DenseMap<Value *, RRInfo> &Releases, Module *M, -                             SmallVectorImpl<Instruction *> &NewRetains, -                             SmallVectorImpl<Instruction *> &NewReleases, +                             Instruction * Retain,                               SmallVectorImpl<Instruction *> &DeadInsts,                               RRInfo &RetainsToMove, RRInfo &ReleasesToMove,                               Value *Arg, bool KnownSafe, @@ -1155,29 +1116,6 @@ bool ObjCARCOpt::VisitInstructionBottomUp(    case ARCInstKind::None:      // These are irrelevant.      return NestingDetected; -  case ARCInstKind::User: -    // If we have a store into an alloca of a pointer we are tracking, the -    // pointer has multiple owners implying that we must be more conservative. -    // -    // This comes up in the context of a pointer being ``KnownSafe''. In the -    // presence of a block being initialized, the frontend will emit the -    // objc_retain on the original pointer and the release on the pointer loaded -    // from the alloca. The optimizer will through the provenance analysis -    // realize that the two are related, but since we only require KnownSafe in -    // one direction, will match the inner retain on the original pointer with -    // the guard release on the original pointer. This is fixed by ensuring that -    // in the presence of allocas we only unconditionally remove pointers if -    // both our retain and our release are KnownSafe. 
-    if (StoreInst *SI = dyn_cast<StoreInst>(Inst)) { -      const DataLayout &DL = BB->getModule()->getDataLayout(); -      if (AreAnyUnderlyingObjectsAnAlloca(SI->getPointerOperand(), DL)) { -        auto I = MyStates.findPtrBottomUpState( -            GetRCIdentityRoot(SI->getValueOperand())); -        if (I != MyStates.bottom_up_ptr_end()) -          MultiOwnersSet.insert(I->first); -      } -    } -    break;    default:      break;    } @@ -1540,8 +1478,7 @@ bool ObjCARCOpt::PairUpRetainsAndReleases(      DenseMap<const BasicBlock *, BBState> &BBStates,      BlotMapVector<Value *, RRInfo> &Retains,      DenseMap<Value *, RRInfo> &Releases, Module *M, -    SmallVectorImpl<Instruction *> &NewRetains, -    SmallVectorImpl<Instruction *> &NewReleases, +    Instruction *Retain,      SmallVectorImpl<Instruction *> &DeadInsts, RRInfo &RetainsToMove,      RRInfo &ReleasesToMove, Value *Arg, bool KnownSafe,      bool &AnyPairsCompletelyEliminated) { @@ -1549,7 +1486,6 @@ bool ObjCARCOpt::PairUpRetainsAndReleases(    // is already incremented, we can similarly ignore possible decrements unless    // we are dealing with a retainable object with multiple provenance sources.    bool KnownSafeTD = true, KnownSafeBU = true; -  bool MultipleOwners = false;    bool CFGHazardAfflicted = false;    // Connect the dots between the top-down-collected RetainsToMove and @@ -1561,14 +1497,13 @@ bool ObjCARCOpt::PairUpRetainsAndReleases(    unsigned OldCount = 0;    unsigned NewCount = 0;    bool FirstRelease = true; -  for (;;) { +  for (SmallVector<Instruction *, 4> NewRetains{Retain};;) { +    SmallVector<Instruction *, 4> NewReleases;      for (Instruction *NewRetain : NewRetains) {        auto It = Retains.find(NewRetain);        assert(It != Retains.end());        const RRInfo &NewRetainRRI = It->second;        KnownSafeTD &= NewRetainRRI.KnownSafe; -      MultipleOwners = -        MultipleOwners || MultiOwnersSet.count(GetArgRCIdentityRoot(NewRetain));        for (Instruction *NewRetainRelease : NewRetainRRI.Calls) {          auto Jt = Releases.find(NewRetainRelease);          if (Jt == Releases.end()) @@ -1691,7 +1626,6 @@ bool ObjCARCOpt::PairUpRetainsAndReleases(          }        }      } -    NewReleases.clear();      if (NewRetains.empty()) break;    } @@ -1745,10 +1679,6 @@ bool ObjCARCOpt::PerformCodePlacement(    DEBUG(dbgs() << "\n== ObjCARCOpt::PerformCodePlacement ==\n");    bool AnyPairsCompletelyEliminated = false; -  RRInfo RetainsToMove; -  RRInfo ReleasesToMove; -  SmallVector<Instruction *, 4> NewRetains; -  SmallVector<Instruction *, 4> NewReleases;    SmallVector<Instruction *, 8> DeadInsts;    // Visit each retain. @@ -1780,9 +1710,10 @@ bool ObjCARCOpt::PerformCodePlacement(      // Connect the dots between the top-down-collected RetainsToMove and      // bottom-up-collected ReleasesToMove to form sets of related calls. -    NewRetains.push_back(Retain); +    RRInfo RetainsToMove, ReleasesToMove; +      bool PerformMoveCalls = PairUpRetainsAndReleases( -        BBStates, Retains, Releases, M, NewRetains, NewReleases, DeadInsts, +        BBStates, Retains, Releases, M, Retain, DeadInsts,          RetainsToMove, ReleasesToMove, Arg, KnownSafe,          AnyPairsCompletelyEliminated); @@ -1792,12 +1723,6 @@ bool ObjCARCOpt::PerformCodePlacement(        MoveCalls(Arg, RetainsToMove, ReleasesToMove,                  Retains, Releases, DeadInsts, M);      } - -    // Clean up state for next retain. 
-    NewReleases.clear(); -    NewRetains.clear(); -    RetainsToMove.clear(); -    ReleasesToMove.clear();    }    // Now that we're done moving everything, we can delete the newly dead @@ -1987,9 +1912,6 @@ bool ObjCARCOpt::OptimizeSequences(Function &F) {                                                             Releases,                                                             F.getParent()); -  // Cleanup. -  MultiOwnersSet.clear(); -    return AnyPairsCompletelyEliminated && NestingDetected;  } diff --git a/lib/Transforms/Scalar/ADCE.cpp b/lib/Transforms/Scalar/ADCE.cpp index adc903cab31b..5b467dc9fe12 100644 --- a/lib/Transforms/Scalar/ADCE.cpp +++ b/lib/Transforms/Scalar/ADCE.cpp @@ -41,8 +41,8 @@ using namespace llvm;  STATISTIC(NumRemoved, "Number of instructions removed");  STATISTIC(NumBranchesRemoved, "Number of branch instructions removed"); -// This is a tempoary option until we change the interface -// to this pass based on optimization level. +// This is a temporary option until we change the interface to this pass based +// on optimization level.  static cl::opt<bool> RemoveControlFlowFlag("adce-remove-control-flow",                                             cl::init(true), cl::Hidden); @@ -110,7 +110,7 @@ class AggressiveDeadCodeElimination {    /// The set of blocks which we have determined whose control    /// dependence sources must be live and which have not had -  /// those dependences analyized. +  /// those dependences analyzed.    SmallPtrSet<BasicBlock *, 16> NewLiveBlocks;    /// Set up auxiliary data structures for Instructions and BasicBlocks and @@ -145,7 +145,7 @@ class AggressiveDeadCodeElimination {    /// was removed.    bool removeDeadInstructions(); -  /// Identify connected sections of the control flow grap which have +  /// Identify connected sections of the control flow graph which have    /// dead terminators and rewrite the control flow graph to remove them.    void updateDeadRegions(); @@ -234,7 +234,7 @@ void AggressiveDeadCodeElimination::initialize() {          return Iter != end() && Iter->second;        }      } State; -     +      State.reserve(F.size());      // Iterate over blocks in depth-first pre-order and      // treat all edges to a block already seen as loop back edges @@ -262,25 +262,6 @@ void AggressiveDeadCodeElimination::initialize() {        continue;      auto *BB = BBInfo.BB;      if (!PDT.getNode(BB)) { -      markLive(BBInfo.Terminator); -      continue; -    } -    for (auto *Succ : successors(BB)) -      if (!PDT.getNode(Succ)) { -        markLive(BBInfo.Terminator); -        break; -      } -  } - -  // Mark blocks live if there is no path from the block to the -  // return of the function or a successor for which this is true. -  // This protects IDFCalculator which cannot handle such blocks. 
-  for (auto &BBInfoPair : BlockInfo) { -    auto &BBInfo = BBInfoPair.second; -    if (BBInfo.terminatorIsLive()) -      continue; -    auto *BB = BBInfo.BB; -    if (!PDT.getNode(BB)) {        DEBUG(dbgs() << "Not post-dominated by return: " << BB->getName()                     << '\n';);        markLive(BBInfo.Terminator); @@ -579,7 +560,7 @@ void AggressiveDeadCodeElimination::updateDeadRegions() {          PreferredSucc = Info;      }      assert((PreferredSucc && PreferredSucc->PostOrder > 0) && -           "Failed to find safe successor for dead branc"); +           "Failed to find safe successor for dead branch");      bool First = true;      for (auto *Succ : successors(BB)) {        if (!First || Succ != PreferredSucc->BB) @@ -594,13 +575,13 @@ void AggressiveDeadCodeElimination::updateDeadRegions() {  // reverse top-sort order  void AggressiveDeadCodeElimination::computeReversePostOrder() { -   -  // This provides a post-order numbering of the reverse conrtol flow graph + +  // This provides a post-order numbering of the reverse control flow graph    // Note that it is incomplete in the presence of infinite loops but we don't    // need numbers blocks which don't reach the end of the functions since    // all branches in those blocks are forced live. -   -  // For each block without successors, extend the DFS from the bloack + +  // For each block without successors, extend the DFS from the block    // backward through the graph    SmallPtrSet<BasicBlock*, 16> Visited;    unsigned PostOrder = 0; @@ -644,8 +625,8 @@ PreservedAnalyses ADCEPass::run(Function &F, FunctionAnalysisManager &FAM) {    if (!AggressiveDeadCodeElimination(F, PDT).performDeadCodeElimination())      return PreservedAnalyses::all(); -  // FIXME: This should also 'preserve the CFG'. -  auto PA = PreservedAnalyses(); +  PreservedAnalyses PA; +  PA.preserveSet<CFGAnalyses>();    PA.preserve<GlobalsAA>();    return PA;  } diff --git a/lib/Transforms/Scalar/AlignmentFromAssumptions.cpp b/lib/Transforms/Scalar/AlignmentFromAssumptions.cpp index c1df3173c0fc..fd931c521c8f 100644 --- a/lib/Transforms/Scalar/AlignmentFromAssumptions.cpp +++ b/lib/Transforms/Scalar/AlignmentFromAssumptions.cpp @@ -438,19 +438,13 @@ AlignmentFromAssumptionsPass::run(Function &F, FunctionAnalysisManager &AM) {    AssumptionCache &AC = AM.getResult<AssumptionAnalysis>(F);    ScalarEvolution &SE = AM.getResult<ScalarEvolutionAnalysis>(F);    DominatorTree &DT = AM.getResult<DominatorTreeAnalysis>(F); -  bool Changed = runImpl(F, AC, &SE, &DT); - -  // FIXME: We need to invalidate this to avoid PR28400. Is there a better -  // solution? -  AM.invalidate<ScalarEvolutionAnalysis>(F); - -  if (!Changed) +  if (!runImpl(F, AC, &SE, &DT))      return PreservedAnalyses::all(); +    PreservedAnalyses PA; +  PA.preserveSet<CFGAnalyses>();    PA.preserve<AAManager>();    PA.preserve<ScalarEvolutionAnalysis>();    PA.preserve<GlobalsAA>(); -  PA.preserve<LoopAnalysis>(); -  PA.preserve<DominatorTreeAnalysis>();    return PA;  } diff --git a/lib/Transforms/Scalar/BDCE.cpp b/lib/Transforms/Scalar/BDCE.cpp index 251b38707769..61e8700f1cd6 100644 --- a/lib/Transforms/Scalar/BDCE.cpp +++ b/lib/Transforms/Scalar/BDCE.cpp @@ -80,8 +80,8 @@ PreservedAnalyses BDCEPass::run(Function &F, FunctionAnalysisManager &AM) {    if (!bitTrackingDCE(F, DB))      return PreservedAnalyses::all(); -  // FIXME: This should also 'preserve the CFG'. 
-  auto PA = PreservedAnalyses(); +  PreservedAnalyses PA; +  PA.preserveSet<CFGAnalyses>();    PA.preserve<GlobalsAA>();    return PA;  } diff --git a/lib/Transforms/Scalar/CMakeLists.txt b/lib/Transforms/Scalar/CMakeLists.txt index 06d3d6a73954..b323ab3bd443 100644 --- a/lib/Transforms/Scalar/CMakeLists.txt +++ b/lib/Transforms/Scalar/CMakeLists.txt @@ -16,6 +16,7 @@ add_llvm_library(LLVMScalarOpts    IVUsersPrinter.cpp    InductiveRangeCheckElimination.cpp    IndVarSimplify.cpp +  InferAddressSpaces.cpp    JumpThreading.cpp    LICM.cpp    LoopAccessAnalysisPrinter.cpp @@ -29,6 +30,7 @@ add_llvm_library(LLVMScalarOpts    LoopInterchange.cpp    LoopLoadElimination.cpp    LoopPassManager.cpp +  LoopPredication.cpp    LoopRerollPass.cpp    LoopRotation.cpp    LoopSimplifyCFG.cpp diff --git a/lib/Transforms/Scalar/ConstantHoisting.cpp b/lib/Transforms/Scalar/ConstantHoisting.cpp index 38262514c9ec..ee6333e88716 100644 --- a/lib/Transforms/Scalar/ConstantHoisting.cpp +++ b/lib/Transforms/Scalar/ConstantHoisting.cpp @@ -136,8 +136,16 @@ Instruction *ConstantHoistingPass::findMatInsertPt(Instruction *Inst,    if (Idx != ~0U && isa<PHINode>(Inst))      return cast<PHINode>(Inst)->getIncomingBlock(Idx)->getTerminator(); -  BasicBlock *IDom = DT->getNode(Inst->getParent())->getIDom()->getBlock(); -  return IDom->getTerminator(); +  // This must be an EH pad. Iterate over immediate dominators until we find a +  // non-EH pad. We need to skip over catchswitch blocks, which are both EH pads +  // and terminators. +  auto IDom = DT->getNode(Inst->getParent())->getIDom(); +  while (IDom->getBlock()->isEHPad()) { +    assert(Entry != IDom->getBlock() && "eh pad in entry block"); +    IDom = IDom->getIDom(); +  } + +  return IDom->getBlock()->getTerminator();  }  /// \brief Find an insertion point that dominates all uses. @@ -289,8 +297,8 @@ void ConstantHoistingPass::collectConstantCandidates(Function &Fn) {  // bit widths (APInt Operator- does not like that). If the value cannot be  // represented in uint64 we return an "empty" APInt. This is then interpreted  // as the value is not in range. -static llvm::Optional<APInt> calculateOffsetDiff(APInt V1, APInt V2) -{ +static llvm::Optional<APInt> calculateOffsetDiff(const APInt &V1, +                                                 const APInt &V2) {    llvm::Optional<APInt> Res = None;    unsigned BW = V1.getBitWidth() > V2.getBitWidth() ?                  V1.getBitWidth() : V2.getBitWidth(); @@ -623,6 +631,7 @@ PreservedAnalyses ConstantHoistingPass::run(Function &F,    if (!runImpl(F, TTI, DT, F.getEntryBlock()))      return PreservedAnalyses::all(); -  // FIXME: This should also 'preserve the CFG'. -  return PreservedAnalyses::none(); +  PreservedAnalyses PA; +  PA.preserveSet<CFGAnalyses>(); +  return PA;  } diff --git a/lib/Transforms/Scalar/CorrelatedValuePropagation.cpp b/lib/Transforms/Scalar/CorrelatedValuePropagation.cpp index 84f9373ae914..c843c61ea94e 100644 --- a/lib/Transforms/Scalar/CorrelatedValuePropagation.cpp +++ b/lib/Transforms/Scalar/CorrelatedValuePropagation.cpp @@ -235,9 +235,8 @@ static bool processSwitch(SwitchInst *SI, LazyValueInfo *LVI) {    // Analyse each switch case in turn.  This is done in reverse order so that    // removing a case doesn't cause trouble for the iteration.    
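In the loop that follows, SwitchInst::removeCase now returns the iterator to the next case, which lets the pass walk the cases forward and erase as it goes instead of iterating in reverse. A hedged sketch of that idiom as a standalone helper; the name and predicate are placeholders, not part of the patch:

#include "llvm/IR/Instructions.h"
using namespace llvm;

// Remove every case whose value Pred rejects, refreshing the end iterator
// after each erase because removeCase shifts the remaining cases.
template <typename Predicate>
static bool removeDeadCases(SwitchInst *SI, Predicate Pred) {
  bool Changed = false;
  for (auto CI = SI->case_begin(), CE = SI->case_end(); CI != CE;) {
    if (Pred(CI->getCaseValue())) {
      CI->getCaseSuccessor()->removePredecessor(SI->getParent());
      CI = SI->removeCase(CI); // returns the iterator to the next case
      CE = SI->case_end();     // the old end iterator is no longer valid
      Changed = true;
      continue;
    }
    ++CI;
  }
  return Changed;
}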
bool Changed = false; -  for (SwitchInst::CaseIt CI = SI->case_end(), CE = SI->case_begin(); CI-- != CE; -       ) { -    ConstantInt *Case = CI.getCaseValue(); +  for (auto CI = SI->case_begin(), CE = SI->case_end(); CI != CE;) { +    ConstantInt *Case = CI->getCaseValue();      // Check to see if the switch condition is equal to/not equal to the case      // value on every incoming edge, equal/not equal being the same each time. @@ -270,8 +269,9 @@ static bool processSwitch(SwitchInst *SI, LazyValueInfo *LVI) {      if (State == LazyValueInfo::False) {        // This case never fires - remove it. -      CI.getCaseSuccessor()->removePredecessor(BB); -      SI->removeCase(CI); // Does not invalidate the iterator. +      CI->getCaseSuccessor()->removePredecessor(BB); +      CI = SI->removeCase(CI); +      CE = SI->case_end();        // The condition can be modified by removePredecessor's PHI simplification        // logic. @@ -279,7 +279,9 @@ static bool processSwitch(SwitchInst *SI, LazyValueInfo *LVI) {        ++NumDeadCases;        Changed = true; -    } else if (State == LazyValueInfo::True) { +      continue; +    } +    if (State == LazyValueInfo::True) {        // This case always fires.  Arrange for the switch to be turned into an        // unconditional branch by replacing the switch condition with the case        // value. @@ -288,6 +290,9 @@ static bool processSwitch(SwitchInst *SI, LazyValueInfo *LVI) {        Changed = true;        break;      } + +    // Increment the case iterator sense we didn't delete it. +    ++CI;    }    if (Changed) @@ -308,7 +313,7 @@ static bool processCallSite(CallSite CS, LazyValueInfo *LVI) {      // Try to mark pointer typed parameters as non-null.  We skip the      // relatively expensive analysis for constants which are obviously either      // null or non-null to start with. -    if (Type && !CS.paramHasAttr(ArgNo + 1, Attribute::NonNull) && +    if (Type && !CS.paramHasAttr(ArgNo, Attribute::NonNull) &&          !isa<Constant>(V) &&           LVI->getPredicateAt(ICmpInst::ICMP_EQ, V,                              ConstantPointerNull::get(Type), @@ -322,7 +327,7 @@ static bool processCallSite(CallSite CS, LazyValueInfo *LVI) {    if (Indices.empty())      return false; -  AttributeSet AS = CS.getAttributes(); +  AttributeList AS = CS.getAttributes();    LLVMContext &Ctx = CS.getInstruction()->getContext();    AS = AS.addAttribute(Ctx, Indices, Attribute::get(Ctx, Attribute::NonNull));    CS.setAttributes(AS); @@ -570,10 +575,6 @@ CorrelatedValuePropagationPass::run(Function &F, FunctionAnalysisManager &AM) {    LazyValueInfo *LVI = &AM.getResult<LazyValueAnalysis>(F);    bool Changed = runImpl(F, LVI); -  // FIXME: We need to invalidate LVI to avoid PR28400. Is there a better -  // solution? 
-  AM.invalidate<LazyValueAnalysis>(F); -    if (!Changed)      return PreservedAnalyses::all();    PreservedAnalyses PA; diff --git a/lib/Transforms/Scalar/DCE.cpp b/lib/Transforms/Scalar/DCE.cpp index cc2a3cfaf9d1..07a0ba9b1222 100644 --- a/lib/Transforms/Scalar/DCE.cpp +++ b/lib/Transforms/Scalar/DCE.cpp @@ -124,9 +124,12 @@ static bool eliminateDeadCode(Function &F, TargetLibraryInfo *TLI) {  }  PreservedAnalyses DCEPass::run(Function &F, FunctionAnalysisManager &AM) { -  if (eliminateDeadCode(F, AM.getCachedResult<TargetLibraryAnalysis>(F))) -    return PreservedAnalyses::none(); -  return PreservedAnalyses::all(); +  if (!eliminateDeadCode(F, AM.getCachedResult<TargetLibraryAnalysis>(F))) +    return PreservedAnalyses::all(); + +  PreservedAnalyses PA; +  PA.preserveSet<CFGAnalyses>(); +  return PA;  }  namespace { diff --git a/lib/Transforms/Scalar/DeadStoreElimination.cpp b/lib/Transforms/Scalar/DeadStoreElimination.cpp index 4d4c3baef3f5..1ec38e56aa4c 100644 --- a/lib/Transforms/Scalar/DeadStoreElimination.cpp +++ b/lib/Transforms/Scalar/DeadStoreElimination.cpp @@ -135,13 +135,13 @@ static bool hasMemoryWrite(Instruction *I, const TargetLibraryInfo &TLI) {    if (auto CS = CallSite(I)) {      if (Function *F = CS.getCalledFunction()) {        StringRef FnName = F->getName(); -      if (TLI.has(LibFunc::strcpy) && FnName == TLI.getName(LibFunc::strcpy)) +      if (TLI.has(LibFunc_strcpy) && FnName == TLI.getName(LibFunc_strcpy))          return true; -      if (TLI.has(LibFunc::strncpy) && FnName == TLI.getName(LibFunc::strncpy)) +      if (TLI.has(LibFunc_strncpy) && FnName == TLI.getName(LibFunc_strncpy))          return true; -      if (TLI.has(LibFunc::strcat) && FnName == TLI.getName(LibFunc::strcat)) +      if (TLI.has(LibFunc_strcat) && FnName == TLI.getName(LibFunc_strcat))          return true; -      if (TLI.has(LibFunc::strncat) && FnName == TLI.getName(LibFunc::strncat)) +      if (TLI.has(LibFunc_strncat) && FnName == TLI.getName(LibFunc_strncat))          return true;      }    } @@ -287,19 +287,14 @@ static uint64_t getPointerSize(const Value *V, const DataLayout &DL,  }  namespace { -enum OverwriteResult { -  OverwriteBegin, -  OverwriteComplete, -  OverwriteEnd, -  OverwriteUnknown -}; +enum OverwriteResult { OW_Begin, OW_Complete, OW_End, OW_Unknown };  } -/// Return 'OverwriteComplete' if a store to the 'Later' location completely -/// overwrites a store to the 'Earlier' location, 'OverwriteEnd' if the end of -/// the 'Earlier' location is completely overwritten by 'Later', -/// 'OverwriteBegin' if the beginning of the 'Earlier' location is overwritten -/// by 'Later', or 'OverwriteUnknown' if nothing can be determined. +/// Return 'OW_Complete' if a store to the 'Later' location completely +/// overwrites a store to the 'Earlier' location, 'OW_End' if the end of the +/// 'Earlier' location is completely overwritten by 'Later', 'OW_Begin' if the +/// beginning of the 'Earlier' location is overwritten by 'Later', or +/// 'OW_Unknown' if nothing can be determined.  static OverwriteResult isOverwrite(const MemoryLocation &Later,                                     const MemoryLocation &Earlier,                                     const DataLayout &DL, @@ -310,7 +305,7 @@ static OverwriteResult isOverwrite(const MemoryLocation &Later,    // If we don't know the sizes of either access, then we can't do a comparison.    
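Apart from the OverwriteResult values being shortened to OW_*, the logic of isOverwrite is unchanged; once both stores are known to share a base pointer, the OW_Complete case below reduces to a containment check. A hedged standalone restatement, with a worked example in the comment:

#include <cstdint>

// True when the later store writes every byte of the earlier one, so DSE may
// delete the earlier store.  E.g. a 4-byte store at offset 8 is covered by a
// later 16-byte store at offset 0: 8 >= 0, 16 >= 4, and (8 - 0) + 4 <= 16.
static bool laterCoversEarlier(int64_t EarlierOff, uint64_t EarlierSize,
                               int64_t LaterOff, uint64_t LaterSize) {
  return EarlierOff >= LaterOff && LaterSize >= EarlierSize &&
         uint64_t(EarlierOff - LaterOff) + EarlierSize <= LaterSize;
}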
if (Later.Size == MemoryLocation::UnknownSize ||        Earlier.Size == MemoryLocation::UnknownSize) -    return OverwriteUnknown; +    return OW_Unknown;    const Value *P1 = Earlier.Ptr->stripPointerCasts();    const Value *P2 = Later.Ptr->stripPointerCasts(); @@ -320,7 +315,7 @@ static OverwriteResult isOverwrite(const MemoryLocation &Later,    if (P1 == P2) {      // Make sure that the Later size is >= the Earlier size.      if (Later.Size >= Earlier.Size) -      return OverwriteComplete; +      return OW_Complete;    }    // Check to see if the later store is to the entire object (either a global, @@ -332,13 +327,13 @@ static OverwriteResult isOverwrite(const MemoryLocation &Later,    // If we can't resolve the same pointers to the same object, then we can't    // analyze them at all.    if (UO1 != UO2) -    return OverwriteUnknown; +    return OW_Unknown;    // If the "Later" store is to a recognizable object, get its size.    uint64_t ObjectSize = getPointerSize(UO2, DL, TLI);    if (ObjectSize != MemoryLocation::UnknownSize)      if (ObjectSize == Later.Size && ObjectSize >= Earlier.Size) -      return OverwriteComplete; +      return OW_Complete;    // Okay, we have stores to two completely different pointers.  Try to    // decompose the pointer into a "base + constant_offset" form.  If the base @@ -350,7 +345,7 @@ static OverwriteResult isOverwrite(const MemoryLocation &Later,    // If the base pointers still differ, we have two completely different stores.    if (BP1 != BP2) -    return OverwriteUnknown; +    return OW_Unknown;    // The later store completely overlaps the earlier store if:    // @@ -370,7 +365,7 @@ static OverwriteResult isOverwrite(const MemoryLocation &Later,    if (EarlierOff >= LaterOff &&        Later.Size >= Earlier.Size &&        uint64_t(EarlierOff - LaterOff) + Earlier.Size <= Later.Size) -    return OverwriteComplete; +    return OW_Complete;    // We may now overlap, although the overlap is not complete. There might also    // be other incomplete overlaps, and together, they might cover the complete @@ -428,7 +423,7 @@ static OverwriteResult isOverwrite(const MemoryLocation &Later,                        ") Composite Later [" <<                        ILI->second << ", " << ILI->first << ")\n");        ++NumCompletePartials; -      return OverwriteComplete; +      return OW_Complete;      }    } @@ -443,7 +438,7 @@ static OverwriteResult isOverwrite(const MemoryLocation &Later,    if (!EnablePartialOverwriteTracking &&        (LaterOff > EarlierOff && LaterOff < int64_t(EarlierOff + Earlier.Size) &&         int64_t(LaterOff + Later.Size) >= int64_t(EarlierOff + Earlier.Size))) -    return OverwriteEnd; +    return OW_End;    // Finally, we also need to check if the later store overwrites the beginning    // of the earlier store. @@ -458,11 +453,11 @@ static OverwriteResult isOverwrite(const MemoryLocation &Later,        (LaterOff <= EarlierOff && int64_t(LaterOff + Later.Size) > EarlierOff)) {      assert(int64_t(LaterOff + Later.Size) <                 int64_t(EarlierOff + Earlier.Size) && -           "Expect to be handled as OverwriteComplete"); -    return OverwriteBegin; +           "Expect to be handled as OW_Complete"); +    return OW_Begin;    }    // Otherwise, they don't completely overlap. -  return OverwriteUnknown; +  return OW_Unknown;  }  /// If 'Inst' might be a self read (i.e. 
a noop copy of a @@ -551,7 +546,7 @@ static bool memoryIsNotModifiedBetween(Instruction *FirstI,        Instruction *I = &*BI;        if (I->mayWriteToMemory() && I != SecondI) {          auto Res = AA->getModRefInfo(I, MemLoc); -        if (Res != MRI_NoModRef) +        if (Res & MRI_Mod)            return false;        }      } @@ -909,7 +904,7 @@ static bool tryToShortenBegin(Instruction *EarlierWrite,    if (LaterStart <= EarlierStart && LaterStart + LaterSize > EarlierStart) {      assert(LaterStart + LaterSize < EarlierStart + EarlierSize && -           "Should have been handled as OverwriteComplete"); +           "Should have been handled as OW_Complete");      if (tryToShorten(EarlierWrite, EarlierStart, EarlierSize, LaterStart,                       LaterSize, false)) {        IntervalMap.erase(OII); @@ -1105,7 +1100,7 @@ static bool eliminateDeadStores(BasicBlock &BB, AliasAnalysis *AA,          OverwriteResult OR =              isOverwrite(Loc, DepLoc, DL, *TLI, DepWriteOffset, InstWriteOffset,                          DepWrite, IOL); -        if (OR == OverwriteComplete) { +        if (OR == OW_Complete) {            DEBUG(dbgs() << "DSE: Remove Dead Store:\n  DEAD: "                  << *DepWrite << "\n  KILLER: " << *Inst << '\n'); @@ -1117,15 +1112,15 @@ static bool eliminateDeadStores(BasicBlock &BB, AliasAnalysis *AA,            // We erased DepWrite; start over.            InstDep = MD->getDependency(Inst);            continue; -        } else if ((OR == OverwriteEnd && isShortenableAtTheEnd(DepWrite)) || -                   ((OR == OverwriteBegin && +        } else if ((OR == OW_End && isShortenableAtTheEnd(DepWrite)) || +                   ((OR == OW_Begin &&                       isShortenableAtTheBeginning(DepWrite)))) {            assert(!EnablePartialOverwriteTracking && "Do not expect to perform "                                                      "when partial-overwrite "                                                      "tracking is enabled");            int64_t EarlierSize = DepLoc.Size;            int64_t LaterSize = Loc.Size; -          bool IsOverwriteEnd = (OR == OverwriteEnd); +          bool IsOverwriteEnd = (OR == OW_End);            MadeChange |= tryToShorten(DepWrite, DepWriteOffset, EarlierSize,                                      InstWriteOffset, LaterSize, IsOverwriteEnd);          } @@ -1186,8 +1181,9 @@ PreservedAnalyses DSEPass::run(Function &F, FunctionAnalysisManager &AM) {    if (!eliminateDeadStores(F, AA, MD, DT, TLI))      return PreservedAnalyses::all(); +    PreservedAnalyses PA; -  PA.preserve<DominatorTreeAnalysis>(); +  PA.preserveSet<CFGAnalyses>();    PA.preserve<GlobalsAA>();    PA.preserve<MemoryDependenceAnalysis>();    return PA; diff --git a/lib/Transforms/Scalar/EarlyCSE.cpp b/lib/Transforms/Scalar/EarlyCSE.cpp index 16e08ee58fbe..04479b6e49ac 100644 --- a/lib/Transforms/Scalar/EarlyCSE.cpp +++ b/lib/Transforms/Scalar/EarlyCSE.cpp @@ -19,6 +19,8 @@  #include "llvm/Analysis/AssumptionCache.h"  #include "llvm/Analysis/GlobalsModRef.h"  #include "llvm/Analysis/InstructionSimplify.h" +#include "llvm/Analysis/MemorySSA.h" +#include "llvm/Analysis/MemorySSAUpdater.h"  #include "llvm/Analysis/TargetLibraryInfo.h"  #include "llvm/Analysis/TargetTransformInfo.h"  #include "llvm/IR/DataLayout.h" @@ -32,7 +34,6 @@  #include "llvm/Support/raw_ostream.h"  #include "llvm/Transforms/Scalar.h"  #include "llvm/Transforms/Utils/Local.h" -#include "llvm/Transforms/Utils/MemorySSA.h"  #include <deque>  using namespace llvm;  using namespace 
llvm::PatternMatch; @@ -253,6 +254,7 @@ public:    DominatorTree &DT;    AssumptionCache &AC;    MemorySSA *MSSA; +  std::unique_ptr<MemorySSAUpdater> MSSAUpdater;    typedef RecyclingAllocator<        BumpPtrAllocator, ScopedHashTableVal<SimpleValue, Value *>> AllocatorTy;    typedef ScopedHashTable<SimpleValue, Value *, DenseMapInfo<SimpleValue>, @@ -315,7 +317,9 @@ public:    /// \brief Set up the EarlyCSE runner for a particular function.    EarlyCSE(const TargetLibraryInfo &TLI, const TargetTransformInfo &TTI,             DominatorTree &DT, AssumptionCache &AC, MemorySSA *MSSA) -      : TLI(TLI), TTI(TTI), DT(DT), AC(AC), MSSA(MSSA), CurrentGeneration(0) {} +      : TLI(TLI), TTI(TTI), DT(DT), AC(AC), MSSA(MSSA), +        MSSAUpdater(make_unique<MemorySSAUpdater>(MSSA)), CurrentGeneration(0) { +  }    bool run(); @@ -388,7 +392,7 @@ private:      ParseMemoryInst(Instruction *Inst, const TargetTransformInfo &TTI)        : IsTargetMemInst(false), Inst(Inst) {        if (IntrinsicInst *II = dyn_cast<IntrinsicInst>(Inst)) -        if (TTI.getTgtMemIntrinsic(II, Info) && Info.NumMemRefs == 1) +        if (TTI.getTgtMemIntrinsic(II, Info))            IsTargetMemInst = true;      }      bool isLoad() const { @@ -400,17 +404,14 @@ private:        return isa<StoreInst>(Inst);      }      bool isAtomic() const { -      if (IsTargetMemInst) { -        assert(Info.IsSimple && "need to refine IsSimple in TTI"); -        return false; -      } +      if (IsTargetMemInst) +        return Info.Ordering != AtomicOrdering::NotAtomic;        return Inst->isAtomic();      }      bool isUnordered() const { -      if (IsTargetMemInst) { -        assert(Info.IsSimple && "need to refine IsSimple in TTI"); -        return true; -      } +      if (IsTargetMemInst) +        return Info.isUnordered(); +        if (LoadInst *LI = dyn_cast<LoadInst>(Inst)) {          return LI->isUnordered();        } else if (StoreInst *SI = dyn_cast<StoreInst>(Inst)) { @@ -421,10 +422,9 @@ private:      }      bool isVolatile() const { -      if (IsTargetMemInst) { -        assert(Info.IsSimple && "need to refine IsSimple in TTI"); -        return false; -      } +      if (IsTargetMemInst) +        return Info.IsVolatile; +        if (LoadInst *LI = dyn_cast<LoadInst>(Inst)) {          return LI->isVolatile();        } else if (StoreInst *SI = dyn_cast<StoreInst>(Inst)) { @@ -517,7 +517,7 @@ private:            if (MemoryPhi *MP = dyn_cast<MemoryPhi>(U))              PhisToCheck.push_back(MP); -        MSSA->removeMemoryAccess(WI); +        MSSAUpdater->removeMemoryAccess(WI);          for (MemoryPhi *MP : PhisToCheck) {            MemoryAccess *FirstIn = MP->getIncomingValue(0); @@ -587,27 +587,28 @@ bool EarlyCSE::processNode(DomTreeNode *Node) {    // which reaches this block where the condition might hold a different    // value.  Since we're adding this to the scoped hash table (like any other    // def), it will have been popped if we encounter a future merge block. -  if (BasicBlock *Pred = BB->getSinglePredecessor()) -    if (auto *BI = dyn_cast<BranchInst>(Pred->getTerminator())) -      if (BI->isConditional()) -        if (auto *CondInst = dyn_cast<Instruction>(BI->getCondition())) -          if (SimpleValue::canHandle(CondInst)) { -            assert(BI->getSuccessor(0) == BB || BI->getSuccessor(1) == BB); -            auto *ConditionalConstant = (BI->getSuccessor(0) == BB) ? 
-              ConstantInt::getTrue(BB->getContext()) : -              ConstantInt::getFalse(BB->getContext()); -            AvailableValues.insert(CondInst, ConditionalConstant); -            DEBUG(dbgs() << "EarlyCSE CVP: Add conditional value for '" -                  << CondInst->getName() << "' as " << *ConditionalConstant -                  << " in " << BB->getName() << "\n"); -            // Replace all dominated uses with the known value. -            if (unsigned Count = -                    replaceDominatedUsesWith(CondInst, ConditionalConstant, DT, -                                             BasicBlockEdge(Pred, BB))) { -              Changed = true; -              NumCSECVP = NumCSECVP + Count; -            } -          } +  if (BasicBlock *Pred = BB->getSinglePredecessor()) { +    auto *BI = dyn_cast<BranchInst>(Pred->getTerminator()); +    if (BI && BI->isConditional()) { +      auto *CondInst = dyn_cast<Instruction>(BI->getCondition()); +      if (CondInst && SimpleValue::canHandle(CondInst)) { +        assert(BI->getSuccessor(0) == BB || BI->getSuccessor(1) == BB); +        auto *TorF = (BI->getSuccessor(0) == BB) +                         ? ConstantInt::getTrue(BB->getContext()) +                         : ConstantInt::getFalse(BB->getContext()); +        AvailableValues.insert(CondInst, TorF); +        DEBUG(dbgs() << "EarlyCSE CVP: Add conditional value for '" +                     << CondInst->getName() << "' as " << *TorF << " in " +                     << BB->getName() << "\n"); +        // Replace all dominated uses with the known value. +        if (unsigned Count = replaceDominatedUsesWith( +                CondInst, TorF, DT, BasicBlockEdge(Pred, BB))) { +          Changed = true; +          NumCSECVP = NumCSECVP + Count; +        } +      } +    } +  }    /// LastStore - Keep track of the last non-volatile store that we saw... for    /// as long as there in no instruction that reads memory.  If we see a store @@ -761,12 +762,13 @@ bool EarlyCSE::processNode(DomTreeNode *Node) {        continue;      } -    // If this instruction may read from memory, forget LastStore. -    // Load/store intrinsics will indicate both a read and a write to -    // memory.  The target may override this (e.g. so that a store intrinsic -    // does not read  from memory, and thus will be treated the same as a -    // regular store for commoning purposes). -    if (Inst->mayReadFromMemory() && +    // If this instruction may read from memory or throw (and potentially read +    // from memory in the exception handler), forget LastStore.  Load/store +    // intrinsics will indicate both a read and a write to memory.  The target +    // may override this (e.g. so that a store intrinsic does not read from +    // memory, and thus will be treated the same as a regular store for +    // commoning purposes). +    if ((Inst->mayReadFromMemory() || Inst->mayThrow()) &&          !(MemInst.isValid() && !MemInst.mayReadFromMemory()))        LastStore = nullptr; @@ -967,10 +969,8 @@ PreservedAnalyses EarlyCSEPass::run(Function &F,    if (!CSE.run())      return PreservedAnalyses::all(); -  // CSE preserves the dominator tree because it doesn't mutate the CFG. -  // FIXME: Bundle this with other CFG-preservation.    
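The removed FIXME above asked for exactly the bundling that preserveSet<CFGAnalyses>() on the next lines provides: a pass that only rewrites instructions can mark the whole set of CFG-derived analyses preserved instead of naming DominatorTreeAnalysis and friends individually. A minimal sketch of the pattern this patch rolls out across the scalar passes; the pass and its transform are placeholders:

#include "llvm/IR/Function.h"
#include "llvm/IR/PassManager.h"
using namespace llvm;

struct ExamplePass : PassInfoMixin<ExamplePass> {
  // Hypothetical transform that touches instructions but never the CFG.
  bool simplifyBody(Function &F) { return false; }

  PreservedAnalyses run(Function &F, FunctionAnalysisManager &) {
    if (!simplifyBody(F))
      return PreservedAnalyses::all();
    PreservedAnalyses PA;
    PA.preserveSet<CFGAnalyses>(); // CFG untouched: keep the CFG-only analyses
    return PA;
  }
};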
PreservedAnalyses PA; -  PA.preserve<DominatorTreeAnalysis>(); +  PA.preserveSet<CFGAnalyses>();    PA.preserve<GlobalsAA>();    if (UseMemorySSA)      PA.preserve<MemorySSAAnalysis>(); diff --git a/lib/Transforms/Scalar/Float2Int.cpp b/lib/Transforms/Scalar/Float2Int.cpp index 545036d724ef..8a5af6195f1b 100644 --- a/lib/Transforms/Scalar/Float2Int.cpp +++ b/lib/Transforms/Scalar/Float2Int.cpp @@ -516,11 +516,10 @@ FunctionPass *createFloat2IntPass() { return new Float2IntLegacyPass(); }  PreservedAnalyses Float2IntPass::run(Function &F, FunctionAnalysisManager &) {    if (!runImpl(F))      return PreservedAnalyses::all(); -  else { -    // FIXME: This should also 'preserve the CFG'. -    PreservedAnalyses PA; -    PA.preserve<GlobalsAA>(); -    return PA; -  } + +  PreservedAnalyses PA; +  PA.preserveSet<CFGAnalyses>(); +  PA.preserve<GlobalsAA>(); +  return PA;  }  } // End namespace llvm diff --git a/lib/Transforms/Scalar/GVN.cpp b/lib/Transforms/Scalar/GVN.cpp index 0137378b828b..be696df548d5 100644 --- a/lib/Transforms/Scalar/GVN.cpp +++ b/lib/Transforms/Scalar/GVN.cpp @@ -36,7 +36,6 @@  #include "llvm/Analysis/OptimizationDiagnosticInfo.h"  #include "llvm/Analysis/PHITransAddr.h"  #include "llvm/Analysis/TargetLibraryInfo.h" -#include "llvm/Analysis/ValueTracking.h"  #include "llvm/IR/DataLayout.h"  #include "llvm/IR/Dominators.h"  #include "llvm/IR/GlobalVariable.h" @@ -51,9 +50,12 @@  #include "llvm/Transforms/Utils/BasicBlockUtils.h"  #include "llvm/Transforms/Utils/Local.h"  #include "llvm/Transforms/Utils/SSAUpdater.h" +#include "llvm/Transforms/Utils/VNCoercion.h" +  #include <vector>  using namespace llvm;  using namespace llvm::gvn; +using namespace llvm::VNCoercion;  using namespace PatternMatch;  #define DEBUG_TYPE "gvn" @@ -595,11 +597,12 @@ PreservedAnalyses GVN::run(Function &F, FunctionAnalysisManager &AM) {    PreservedAnalyses PA;    PA.preserve<DominatorTreeAnalysis>();    PA.preserve<GlobalsAA>(); +  PA.preserve<TargetLibraryAnalysis>();    return PA;  } -LLVM_DUMP_METHOD -void GVN::dump(DenseMap<uint32_t, Value*>& d) { +#if !defined(NDEBUG) || defined(LLVM_ENABLE_DUMP) +LLVM_DUMP_METHOD void GVN::dump(DenseMap<uint32_t, Value*>& d) {    errs() << "{\n";    for (DenseMap<uint32_t, Value*>::iterator I = d.begin(),         E = d.end(); I != E; ++I) { @@ -608,6 +611,7 @@ void GVN::dump(DenseMap<uint32_t, Value*>& d) {    }    errs() << "}\n";  } +#endif  /// Return true if we can prove that the value  /// we're analyzing is fully available in the specified block.  As we go, keep @@ -690,442 +694,6 @@ SpeculationFailure:  } -/// Return true if CoerceAvailableValueToLoadType will succeed. -static bool CanCoerceMustAliasedValueToLoad(Value *StoredVal, -                                            Type *LoadTy, -                                            const DataLayout &DL) { -  // If the loaded or stored value is an first class array or struct, don't try -  // to transform them.  We need to be able to bitcast to integer. -  if (LoadTy->isStructTy() || LoadTy->isArrayTy() || -      StoredVal->getType()->isStructTy() || -      StoredVal->getType()->isArrayTy()) -    return false; - -  // The store has to be at least as big as the load. -  if (DL.getTypeSizeInBits(StoredVal->getType()) < -        DL.getTypeSizeInBits(LoadTy)) -    return false; - -  return true; -} - -/// If we saw a store of a value to memory, and -/// then a load from a must-aliased pointer of a different type, try to coerce -/// the stored value.  LoadedTy is the type of the load we want to replace. 
-/// IRB is IRBuilder used to insert new instructions. -/// -/// If we can't do it, return null. -static Value *CoerceAvailableValueToLoadType(Value *StoredVal, Type *LoadedTy, -                                             IRBuilder<> &IRB, -                                             const DataLayout &DL) { -  assert(CanCoerceMustAliasedValueToLoad(StoredVal, LoadedTy, DL) && -         "precondition violation - materialization can't fail"); - -  if (auto *C = dyn_cast<Constant>(StoredVal)) -    if (auto *FoldedStoredVal = ConstantFoldConstant(C, DL)) -      StoredVal = FoldedStoredVal; - -  // If this is already the right type, just return it. -  Type *StoredValTy = StoredVal->getType(); - -  uint64_t StoredValSize = DL.getTypeSizeInBits(StoredValTy); -  uint64_t LoadedValSize = DL.getTypeSizeInBits(LoadedTy); - -  // If the store and reload are the same size, we can always reuse it. -  if (StoredValSize == LoadedValSize) { -    // Pointer to Pointer -> use bitcast. -    if (StoredValTy->getScalarType()->isPointerTy() && -        LoadedTy->getScalarType()->isPointerTy()) { -      StoredVal = IRB.CreateBitCast(StoredVal, LoadedTy); -    } else { -      // Convert source pointers to integers, which can be bitcast. -      if (StoredValTy->getScalarType()->isPointerTy()) { -        StoredValTy = DL.getIntPtrType(StoredValTy); -        StoredVal = IRB.CreatePtrToInt(StoredVal, StoredValTy); -      } - -      Type *TypeToCastTo = LoadedTy; -      if (TypeToCastTo->getScalarType()->isPointerTy()) -        TypeToCastTo = DL.getIntPtrType(TypeToCastTo); - -      if (StoredValTy != TypeToCastTo) -        StoredVal = IRB.CreateBitCast(StoredVal, TypeToCastTo); - -      // Cast to pointer if the load needs a pointer type. -      if (LoadedTy->getScalarType()->isPointerTy()) -        StoredVal = IRB.CreateIntToPtr(StoredVal, LoadedTy); -    } - -    if (auto *C = dyn_cast<ConstantExpr>(StoredVal)) -      if (auto *FoldedStoredVal = ConstantFoldConstant(C, DL)) -        StoredVal = FoldedStoredVal; - -    return StoredVal; -  } - -  // If the loaded value is smaller than the available value, then we can -  // extract out a piece from it.  If the available value is too small, then we -  // can't do anything. -  assert(StoredValSize >= LoadedValSize && -         "CanCoerceMustAliasedValueToLoad fail"); - -  // Convert source pointers to integers, which can be manipulated. -  if (StoredValTy->getScalarType()->isPointerTy()) { -    StoredValTy = DL.getIntPtrType(StoredValTy); -    StoredVal = IRB.CreatePtrToInt(StoredVal, StoredValTy); -  } - -  // Convert vectors and fp to integer, which can be manipulated. -  if (!StoredValTy->isIntegerTy()) { -    StoredValTy = IntegerType::get(StoredValTy->getContext(), StoredValSize); -    StoredVal = IRB.CreateBitCast(StoredVal, StoredValTy); -  } - -  // If this is a big-endian system, we need to shift the value down to the low -  // bits so that a truncate will work. -  if (DL.isBigEndian()) { -    uint64_t ShiftAmt = DL.getTypeStoreSizeInBits(StoredValTy) - -                        DL.getTypeStoreSizeInBits(LoadedTy); -    StoredVal = IRB.CreateLShr(StoredVal, ShiftAmt, "tmp"); -  } - -  // Truncate the integer to the right size now. -  Type *NewIntTy = IntegerType::get(StoredValTy->getContext(), LoadedValSize); -  StoredVal  = IRB.CreateTrunc(StoredVal, NewIntTy, "trunc"); - -  if (LoadedTy != NewIntTy) { -    // If the result is a pointer, inttoptr. 
-    if (LoadedTy->getScalarType()->isPointerTy()) -      StoredVal = IRB.CreateIntToPtr(StoredVal, LoadedTy, "inttoptr"); -    else -      // Otherwise, bitcast. -      StoredVal = IRB.CreateBitCast(StoredVal, LoadedTy, "bitcast"); -  } - -  if (auto *C = dyn_cast<Constant>(StoredVal)) -    if (auto *FoldedStoredVal = ConstantFoldConstant(C, DL)) -      StoredVal = FoldedStoredVal; - -  return StoredVal; -} - -/// This function is called when we have a -/// memdep query of a load that ends up being a clobbering memory write (store, -/// memset, memcpy, memmove).  This means that the write *may* provide bits used -/// by the load but we can't be sure because the pointers don't mustalias. -/// -/// Check this case to see if there is anything more we can do before we give -/// up.  This returns -1 if we have to give up, or a byte number in the stored -/// value of the piece that feeds the load. -static int AnalyzeLoadFromClobberingWrite(Type *LoadTy, Value *LoadPtr, -                                          Value *WritePtr, -                                          uint64_t WriteSizeInBits, -                                          const DataLayout &DL) { -  // If the loaded or stored value is a first class array or struct, don't try -  // to transform them.  We need to be able to bitcast to integer. -  if (LoadTy->isStructTy() || LoadTy->isArrayTy()) -    return -1; - -  int64_t StoreOffset = 0, LoadOffset = 0; -  Value *StoreBase = -      GetPointerBaseWithConstantOffset(WritePtr, StoreOffset, DL); -  Value *LoadBase = GetPointerBaseWithConstantOffset(LoadPtr, LoadOffset, DL); -  if (StoreBase != LoadBase) -    return -1; - -  // If the load and store are to the exact same address, they should have been -  // a must alias.  AA must have gotten confused. -  // FIXME: Study to see if/when this happens.  One case is forwarding a memset -  // to a load from the base of the memset. - -  // If the load and store don't overlap at all, the store doesn't provide -  // anything to the load.  In this case, they really don't alias at all, AA -  // must have gotten confused. -  uint64_t LoadSize = DL.getTypeSizeInBits(LoadTy); - -  if ((WriteSizeInBits & 7) | (LoadSize & 7)) -    return -1; -  uint64_t StoreSize = WriteSizeInBits / 8;  // Convert to bytes. -  LoadSize /= 8; - - -  bool isAAFailure = false; -  if (StoreOffset < LoadOffset) -    isAAFailure = StoreOffset+int64_t(StoreSize) <= LoadOffset; -  else -    isAAFailure = LoadOffset+int64_t(LoadSize) <= StoreOffset; - -  if (isAAFailure) -    return -1; - -  // If the Load isn't completely contained within the stored bits, we don't -  // have all the bits to feed it.  We could do something crazy in the future -  // (issue a smaller load then merge the bits in) but this seems unlikely to be -  // valuable. -  if (StoreOffset > LoadOffset || -      StoreOffset+StoreSize < LoadOffset+LoadSize) -    return -1; - -  // Okay, we can do this transformation.  Return the number of bytes into the -  // store that the load is. -  return LoadOffset-StoreOffset; -} - -/// This function is called when we have a -/// memdep query of a load that ends up being a clobbering store. -static int AnalyzeLoadFromClobberingStore(Type *LoadTy, Value *LoadPtr, -                                          StoreInst *DepSI) { -  // Cannot handle reading from store of first-class aggregate yet. 
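The offset computation in the AnalyzeLoadFromClobberingWrite code removed above is plain interval arithmetic. A standalone sketch of it, assuming both accesses have already been rebased to a common pointer base and that sizes are whole bytes (the original first rejects non-byte-sized accesses); the function name is made up:

    #include <cstdint>

    // Returns the byte offset of the load inside the stored bytes, or -1 when
    // the store does not fully cover the load.
    static int64_t loadOffsetInsideStore(int64_t StoreOffset, uint64_t StoreSize,
                                         int64_t LoadOffset, uint64_t LoadSize) {
      // Disjoint accesses forward nothing (the original treats this as an
      // alias-analysis artifact and also gives up).
      if (StoreOffset + int64_t(StoreSize) <= LoadOffset ||
          LoadOffset + int64_t(LoadSize) <= StoreOffset)
        return -1;
      // The load must be completely contained within the stored bytes.
      if (StoreOffset > LoadOffset ||
          StoreOffset + int64_t(StoreSize) < LoadOffset + int64_t(LoadSize))
        return -1;
      return LoadOffset - StoreOffset;
    }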
-  if (DepSI->getValueOperand()->getType()->isStructTy() || -      DepSI->getValueOperand()->getType()->isArrayTy()) -    return -1; - -  const DataLayout &DL = DepSI->getModule()->getDataLayout(); -  Value *StorePtr = DepSI->getPointerOperand(); -  uint64_t StoreSize =DL.getTypeSizeInBits(DepSI->getValueOperand()->getType()); -  return AnalyzeLoadFromClobberingWrite(LoadTy, LoadPtr, -                                        StorePtr, StoreSize, DL); -} - -/// This function is called when we have a -/// memdep query of a load that ends up being clobbered by another load.  See if -/// the other load can feed into the second load. -static int AnalyzeLoadFromClobberingLoad(Type *LoadTy, Value *LoadPtr, -                                         LoadInst *DepLI, const DataLayout &DL){ -  // Cannot handle reading from store of first-class aggregate yet. -  if (DepLI->getType()->isStructTy() || DepLI->getType()->isArrayTy()) -    return -1; - -  Value *DepPtr = DepLI->getPointerOperand(); -  uint64_t DepSize = DL.getTypeSizeInBits(DepLI->getType()); -  int R = AnalyzeLoadFromClobberingWrite(LoadTy, LoadPtr, DepPtr, DepSize, DL); -  if (R != -1) return R; - -  // If we have a load/load clobber an DepLI can be widened to cover this load, -  // then we should widen it! -  int64_t LoadOffs = 0; -  const Value *LoadBase = -      GetPointerBaseWithConstantOffset(LoadPtr, LoadOffs, DL); -  unsigned LoadSize = DL.getTypeStoreSize(LoadTy); - -  unsigned Size = MemoryDependenceResults::getLoadLoadClobberFullWidthSize( -      LoadBase, LoadOffs, LoadSize, DepLI); -  if (Size == 0) return -1; - -  // Check non-obvious conditions enforced by MDA which we rely on for being -  // able to materialize this potentially available value -  assert(DepLI->isSimple() && "Cannot widen volatile/atomic load!"); -  assert(DepLI->getType()->isIntegerTy() && "Can't widen non-integer load"); - -  return AnalyzeLoadFromClobberingWrite(LoadTy, LoadPtr, DepPtr, Size*8, DL); -} - - - -static int AnalyzeLoadFromClobberingMemInst(Type *LoadTy, Value *LoadPtr, -                                            MemIntrinsic *MI, -                                            const DataLayout &DL) { -  // If the mem operation is a non-constant size, we can't handle it. -  ConstantInt *SizeCst = dyn_cast<ConstantInt>(MI->getLength()); -  if (!SizeCst) return -1; -  uint64_t MemSizeInBits = SizeCst->getZExtValue()*8; - -  // If this is memset, we just need to see if the offset is valid in the size -  // of the memset.. -  if (MI->getIntrinsicID() == Intrinsic::memset) -    return AnalyzeLoadFromClobberingWrite(LoadTy, LoadPtr, MI->getDest(), -                                          MemSizeInBits, DL); - -  // If we have a memcpy/memmove, the only case we can handle is if this is a -  // copy from constant memory.  In that case, we can read directly from the -  // constant memory. -  MemTransferInst *MTI = cast<MemTransferInst>(MI); - -  Constant *Src = dyn_cast<Constant>(MTI->getSource()); -  if (!Src) return -1; - -  GlobalVariable *GV = dyn_cast<GlobalVariable>(GetUnderlyingObject(Src, DL)); -  if (!GV || !GV->isConstant()) return -1; - -  // See if the access is within the bounds of the transfer. 
-  int Offset = AnalyzeLoadFromClobberingWrite(LoadTy, LoadPtr, -                                              MI->getDest(), MemSizeInBits, DL); -  if (Offset == -1) -    return Offset; - -  unsigned AS = Src->getType()->getPointerAddressSpace(); -  // Otherwise, see if we can constant fold a load from the constant with the -  // offset applied as appropriate. -  Src = ConstantExpr::getBitCast(Src, -                                 Type::getInt8PtrTy(Src->getContext(), AS)); -  Constant *OffsetCst = -    ConstantInt::get(Type::getInt64Ty(Src->getContext()), (unsigned)Offset); -  Src = ConstantExpr::getGetElementPtr(Type::getInt8Ty(Src->getContext()), Src, -                                       OffsetCst); -  Src = ConstantExpr::getBitCast(Src, PointerType::get(LoadTy, AS)); -  if (ConstantFoldLoadFromConstPtr(Src, LoadTy, DL)) -    return Offset; -  return -1; -} - - -/// This function is called when we have a -/// memdep query of a load that ends up being a clobbering store.  This means -/// that the store provides bits used by the load but we the pointers don't -/// mustalias.  Check this case to see if there is anything more we can do -/// before we give up. -static Value *GetStoreValueForLoad(Value *SrcVal, unsigned Offset, -                                   Type *LoadTy, -                                   Instruction *InsertPt, const DataLayout &DL){ -  LLVMContext &Ctx = SrcVal->getType()->getContext(); - -  uint64_t StoreSize = (DL.getTypeSizeInBits(SrcVal->getType()) + 7) / 8; -  uint64_t LoadSize = (DL.getTypeSizeInBits(LoadTy) + 7) / 8; - -  IRBuilder<> Builder(InsertPt); - -  // Compute which bits of the stored value are being used by the load.  Convert -  // to an integer type to start with. -  if (SrcVal->getType()->getScalarType()->isPointerTy()) -    SrcVal = Builder.CreatePtrToInt(SrcVal, -        DL.getIntPtrType(SrcVal->getType())); -  if (!SrcVal->getType()->isIntegerTy()) -    SrcVal = Builder.CreateBitCast(SrcVal, IntegerType::get(Ctx, StoreSize*8)); - -  // Shift the bits to the least significant depending on endianness. -  unsigned ShiftAmt; -  if (DL.isLittleEndian()) -    ShiftAmt = Offset*8; -  else -    ShiftAmt = (StoreSize-LoadSize-Offset)*8; - -  if (ShiftAmt) -    SrcVal = Builder.CreateLShr(SrcVal, ShiftAmt); - -  if (LoadSize != StoreSize) -    SrcVal = Builder.CreateTrunc(SrcVal, IntegerType::get(Ctx, LoadSize*8)); - -  return CoerceAvailableValueToLoadType(SrcVal, LoadTy, Builder, DL); -} - -/// This function is called when we have a -/// memdep query of a load that ends up being a clobbering load.  This means -/// that the load *may* provide bits used by the load but we can't be sure -/// because the pointers don't mustalias.  Check this case to see if there is -/// anything more we can do before we give up. -static Value *GetLoadValueForLoad(LoadInst *SrcVal, unsigned Offset, -                                  Type *LoadTy, Instruction *InsertPt, -                                  GVN &gvn) { -  const DataLayout &DL = SrcVal->getModule()->getDataLayout(); -  // If Offset+LoadTy exceeds the size of SrcVal, then we must be wanting to -  // widen SrcVal out to a larger load. 
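At bit level, GetStoreValueForLoad above is a shift followed by a truncate. A standalone model of that extraction, assuming the stored value fits in 64 bits, sizes are in bytes, and the load is fully contained in the store; the function name is made up:

    #include <cstdint>

    // Pull LoadSize bytes starting at byte Offset out of a StoreSize-byte
    // stored integer, honouring endianness the same way the pass does.
    static uint64_t extractLoadedBytes(uint64_t StoredBits, unsigned StoreSize,
                                       unsigned LoadSize, unsigned Offset,
                                       bool IsLittleEndian) {
      unsigned ShiftAmt =
          IsLittleEndian ? Offset * 8 : (StoreSize - LoadSize - Offset) * 8;
      StoredBits >>= ShiftAmt;          // move the wanted bytes down to bit 0
      if (LoadSize < 8)                 // then truncate to the load's width
        StoredBits &= (1ULL << (LoadSize * 8)) - 1;
      return StoredBits;
    }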
-  unsigned SrcValStoreSize = DL.getTypeStoreSize(SrcVal->getType()); -  unsigned LoadSize = DL.getTypeStoreSize(LoadTy); -  if (Offset+LoadSize > SrcValStoreSize) { -    assert(SrcVal->isSimple() && "Cannot widen volatile/atomic load!"); -    assert(SrcVal->getType()->isIntegerTy() && "Can't widen non-integer load"); -    // If we have a load/load clobber an DepLI can be widened to cover this -    // load, then we should widen it to the next power of 2 size big enough! -    unsigned NewLoadSize = Offset+LoadSize; -    if (!isPowerOf2_32(NewLoadSize)) -      NewLoadSize = NextPowerOf2(NewLoadSize); - -    Value *PtrVal = SrcVal->getPointerOperand(); - -    // Insert the new load after the old load.  This ensures that subsequent -    // memdep queries will find the new load.  We can't easily remove the old -    // load completely because it is already in the value numbering table. -    IRBuilder<> Builder(SrcVal->getParent(), ++BasicBlock::iterator(SrcVal)); -    Type *DestPTy = -      IntegerType::get(LoadTy->getContext(), NewLoadSize*8); -    DestPTy = PointerType::get(DestPTy, -                               PtrVal->getType()->getPointerAddressSpace()); -    Builder.SetCurrentDebugLocation(SrcVal->getDebugLoc()); -    PtrVal = Builder.CreateBitCast(PtrVal, DestPTy); -    LoadInst *NewLoad = Builder.CreateLoad(PtrVal); -    NewLoad->takeName(SrcVal); -    NewLoad->setAlignment(SrcVal->getAlignment()); - -    DEBUG(dbgs() << "GVN WIDENED LOAD: " << *SrcVal << "\n"); -    DEBUG(dbgs() << "TO: " << *NewLoad << "\n"); - -    // Replace uses of the original load with the wider load.  On a big endian -    // system, we need to shift down to get the relevant bits. -    Value *RV = NewLoad; -    if (DL.isBigEndian()) -      RV = Builder.CreateLShr(RV, (NewLoadSize - SrcValStoreSize) * 8); -    RV = Builder.CreateTrunc(RV, SrcVal->getType()); -    SrcVal->replaceAllUsesWith(RV); - -    // We would like to use gvn.markInstructionForDeletion here, but we can't -    // because the load is already memoized into the leader map table that GVN -    // tracks.  It is potentially possible to remove the load from the table, -    // but then there all of the operations based on it would need to be -    // rehashed.  Just leave the dead load around. -    gvn.getMemDep().removeInstruction(SrcVal); -    SrcVal = NewLoad; -  } - -  return GetStoreValueForLoad(SrcVal, Offset, LoadTy, InsertPt, DL); -} - - -/// This function is called when we have a -/// memdep query of a load that ends up being a clobbering mem intrinsic. -static Value *GetMemInstValueForLoad(MemIntrinsic *SrcInst, unsigned Offset, -                                     Type *LoadTy, Instruction *InsertPt, -                                     const DataLayout &DL){ -  LLVMContext &Ctx = LoadTy->getContext(); -  uint64_t LoadSize = DL.getTypeSizeInBits(LoadTy)/8; - -  IRBuilder<> Builder(InsertPt); - -  // We know that this method is only called when the mem transfer fully -  // provides the bits for the load. -  if (MemSetInst *MSI = dyn_cast<MemSetInst>(SrcInst)) { -    // memset(P, 'x', 1234) -> splat('x'), even if x is a variable, and -    // independently of what the offset is. -    Value *Val = MSI->getValue(); -    if (LoadSize != 1) -      Val = Builder.CreateZExt(Val, IntegerType::get(Ctx, LoadSize*8)); - -    Value *OneElt = Val; - -    // Splat the value out to the right number of bits. -    for (unsigned NumBytesSet = 1; NumBytesSet != LoadSize; ) { -      // If we can double the number of bytes set, do it. 
-      if (NumBytesSet*2 <= LoadSize) { -        Value *ShVal = Builder.CreateShl(Val, NumBytesSet*8); -        Val = Builder.CreateOr(Val, ShVal); -        NumBytesSet <<= 1; -        continue; -      } - -      // Otherwise insert one byte at a time. -      Value *ShVal = Builder.CreateShl(Val, 1*8); -      Val = Builder.CreateOr(OneElt, ShVal); -      ++NumBytesSet; -    } - -    return CoerceAvailableValueToLoadType(Val, LoadTy, Builder, DL); -  } - -  // Otherwise, this is a memcpy/memmove from a constant global. -  MemTransferInst *MTI = cast<MemTransferInst>(SrcInst); -  Constant *Src = cast<Constant>(MTI->getSource()); -  unsigned AS = Src->getType()->getPointerAddressSpace(); - -  // Otherwise, see if we can constant fold a load from the constant with the -  // offset applied as appropriate. -  Src = ConstantExpr::getBitCast(Src, -                                 Type::getInt8PtrTy(Src->getContext(), AS)); -  Constant *OffsetCst = -    ConstantInt::get(Type::getInt64Ty(Src->getContext()), (unsigned)Offset); -  Src = ConstantExpr::getGetElementPtr(Type::getInt8Ty(Src->getContext()), Src, -                                       OffsetCst); -  Src = ConstantExpr::getBitCast(Src, PointerType::get(LoadTy, AS)); -  return ConstantFoldLoadFromConstPtr(Src, LoadTy, DL); -}  /// Given a set of loads specified by ValuesPerBlock, @@ -1171,7 +739,7 @@ Value *AvailableValue::MaterializeAdjustedValue(LoadInst *LI,    if (isSimpleValue()) {      Res = getSimpleValue();      if (Res->getType() != LoadTy) { -      Res = GetStoreValueForLoad(Res, Offset, LoadTy, InsertPt, DL); +      Res = getStoreValueForLoad(Res, Offset, LoadTy, InsertPt, DL);        DEBUG(dbgs() << "GVN COERCED NONLOCAL VAL:\nOffset: " << Offset << "  "                     << *getSimpleValue() << '\n' @@ -1182,14 +750,20 @@ Value *AvailableValue::MaterializeAdjustedValue(LoadInst *LI,      if (Load->getType() == LoadTy && Offset == 0) {        Res = Load;      } else { -      Res = GetLoadValueForLoad(Load, Offset, LoadTy, InsertPt, gvn); - +      Res = getLoadValueForLoad(Load, Offset, LoadTy, InsertPt, DL); +      // We would like to use gvn.markInstructionForDeletion here, but we can't +      // because the load is already memoized into the leader map table that GVN +      // tracks.  It is potentially possible to remove the load from the table, +      // but then there all of the operations based on it would need to be +      // rehashed.  Just leave the dead load around. +      gvn.getMemDep().removeInstruction(Load);        DEBUG(dbgs() << "GVN COERCED NONLOCAL LOAD:\nOffset: " << Offset << "  "                     << *getCoercedLoadValue() << '\n' -                   << *Res << '\n' << "\n\n\n"); +                   << *Res << '\n' +                   << "\n\n\n");      }    } else if (isMemIntrinValue()) { -    Res = GetMemInstValueForLoad(getMemIntrinValue(), Offset, LoadTy, +    Res = getMemInstValueForLoad(getMemIntrinValue(), Offset, LoadTy,                                   InsertPt, DL);      DEBUG(dbgs() << "GVN COERCED NONLOCAL MEM INTRIN:\nOffset: " << Offset                   << "  " << *getMemIntrinValue() << '\n' @@ -1258,7 +832,7 @@ bool GVN::AnalyzeLoadAvailability(LoadInst *LI, MemDepResult DepInfo,        // Can't forward from non-atomic to atomic without violating memory model.        
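The memset case of GetMemInstValueForLoad above splats one byte across the loaded width by repeated doubling, appending single bytes only for the remainder. The same arithmetic on a plain integer, assuming a load of between one and eight bytes; the function name is made up:

    #include <cstdint>

    static uint64_t splatByte(uint8_t Byte, unsigned LoadSize) {
      uint64_t Val = Byte, OneElt = Byte;
      for (unsigned NumBytesSet = 1; NumBytesSet != LoadSize;) {
        if (NumBytesSet * 2 <= LoadSize) {   // double the bytes already set
          Val |= Val << (NumBytesSet * 8);
          NumBytesSet <<= 1;
          continue;
        }
        Val = OneElt | (Val << 8);           // otherwise append one byte
        ++NumBytesSet;
      }
      return Val;
    }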
if (Address && LI->isAtomic() <= DepSI->isAtomic()) {          int Offset = -          AnalyzeLoadFromClobberingStore(LI->getType(), Address, DepSI); +          analyzeLoadFromClobberingStore(LI->getType(), Address, DepSI, DL);          if (Offset != -1) {            Res = AvailableValue::get(DepSI->getValueOperand(), Offset);            return true; @@ -1276,7 +850,7 @@ bool GVN::AnalyzeLoadAvailability(LoadInst *LI, MemDepResult DepInfo,        // Can't forward from non-atomic to atomic without violating memory model.        if (DepLI != LI && Address && LI->isAtomic() <= DepLI->isAtomic()) {          int Offset = -          AnalyzeLoadFromClobberingLoad(LI->getType(), Address, DepLI, DL); +          analyzeLoadFromClobberingLoad(LI->getType(), Address, DepLI, DL);          if (Offset != -1) {            Res = AvailableValue::getLoad(DepLI, Offset); @@ -1289,7 +863,7 @@ bool GVN::AnalyzeLoadAvailability(LoadInst *LI, MemDepResult DepInfo,      // forward a value on from it.      if (MemIntrinsic *DepMI = dyn_cast<MemIntrinsic>(DepInfo.getInst())) {        if (Address && !LI->isAtomic()) { -        int Offset = AnalyzeLoadFromClobberingMemInst(LI->getType(), Address, +        int Offset = analyzeLoadFromClobberingMemInst(LI->getType(), Address,                                                        DepMI, DL);          if (Offset != -1) {            Res = AvailableValue::getMI(DepMI, Offset); @@ -1334,7 +908,7 @@ bool GVN::AnalyzeLoadAvailability(LoadInst *LI, MemDepResult DepInfo,      // different types if we have to. If the stored value is larger or equal to      // the loaded value, we can reuse it.      if (S->getValueOperand()->getType() != LI->getType() && -        !CanCoerceMustAliasedValueToLoad(S->getValueOperand(), +        !canCoerceMustAliasedValueToLoad(S->getValueOperand(),                                           LI->getType(), DL))        return false; @@ -1351,7 +925,7 @@ bool GVN::AnalyzeLoadAvailability(LoadInst *LI, MemDepResult DepInfo,      // If the stored value is larger or equal to the loaded value, we can reuse      // it.      if (LD->getType() != LI->getType() && -        !CanCoerceMustAliasedValueToLoad(LD, LI->getType(), DL)) +        !canCoerceMustAliasedValueToLoad(LD, LI->getType(), DL))        return false;      // Can't forward from non-atomic to atomic without violating memory model. @@ -1713,7 +1287,7 @@ bool GVN::processNonLocalLoad(LoadInst *LI) {        // If instruction I has debug info, then we should not update it.        // Also, if I has a null DebugLoc, then it is still potentially incorrect        // to propagate LI's DebugLoc because LI may not post-dominate I. -      if (LI->getDebugLoc() && ValuesPerBlock.size() != 1) +      if (LI->getDebugLoc() && LI->getParent() == I->getParent())          I->setDebugLoc(LI->getDebugLoc());      if (V->getType()->getScalarType()->isPointerTy())        MD->invalidateCachedPointerInfo(V); @@ -1795,7 +1369,7 @@ static void patchReplacementInstruction(Instruction *I, Value *Repl) {    // Patch the replacement so that it is not more restrictive than the value    // being replaced. -  // Note that if 'I' is a load being replaced by some operation,  +  // Note that if 'I' is a load being replaced by some operation,    // for example, by an arithmetic operation, then andIRFlags()    // would just erase all math flags from the original arithmetic    // operation, which is clearly not wanted and not needed. 
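The static helpers deleted from GVN.cpp are not gone: per the new include and using-directive, they now come from llvm/Transforms/Utils/VNCoercion.h under lowerCamelCase names, so they can be reused outside GVN. A hedged sketch of such a reuse, built only from the signatures visible in the call sites above; tryForwardStoreToLoad itself is hypothetical:

    #include "llvm/IR/DataLayout.h"
    #include "llvm/IR/Instructions.h"
    #include "llvm/IR/Module.h"
    #include "llvm/Transforms/Utils/VNCoercion.h"
    using namespace llvm;

    // DepSI is a store already known (e.g. via MemoryDependenceResults) to
    // clobber LI with a must-aliasing, non-atomic access.
    static Value *tryForwardStoreToLoad(StoreInst *DepSI, LoadInst *LI) {
      const DataLayout &DL = LI->getModule()->getDataLayout();
      Value *StoredVal = DepSI->getValueOperand();
      if (!VNCoercion::canCoerceMustAliasedValueToLoad(StoredVal, LI->getType(),
                                                       DL))
        return nullptr;
      int Offset = VNCoercion::analyzeLoadFromClobberingStore(
          LI->getType(), LI->getPointerOperand(), DepSI, DL);
      if (Offset == -1)
        return nullptr;
      // Materialize the forwarded bits right before the load.
      return VNCoercion::getStoreValueForLoad(StoredVal, Offset, LI->getType(),
                                              LI, DL);
    }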
@@ -2187,11 +1761,11 @@ bool GVN::processInstruction(Instruction *I) {      for (SwitchInst::CaseIt i = SI->case_begin(), e = SI->case_end();           i != e; ++i) { -      BasicBlock *Dst = i.getCaseSuccessor(); +      BasicBlock *Dst = i->getCaseSuccessor();        // If there is only a single edge, propagate the case value into it.        if (SwitchEdges.lookup(Dst) == 1) {          BasicBlockEdge E(Parent, Dst); -        Changed |= propagateEquality(SwitchCond, i.getCaseValue(), E, true); +        Changed |= propagateEquality(SwitchCond, i->getCaseValue(), E, true);        }      }      return Changed; @@ -2581,21 +2155,12 @@ bool GVN::iterateOnFunction(Function &F) {    // Top-down walk of the dominator tree    bool Changed = false; -  // Save the blocks this function have before transformation begins. GVN may -  // split critical edge, and hence may invalidate the RPO/DT iterator. -  // -  std::vector<BasicBlock *> BBVect; -  BBVect.reserve(256);    // Needed for value numbering with phi construction to work. +  // RPOT walks the graph in its constructor and will not be invalidated during +  // processBlock.    ReversePostOrderTraversal<Function *> RPOT(&F); -  for (ReversePostOrderTraversal<Function *>::rpo_iterator RI = RPOT.begin(), -                                                           RE = RPOT.end(); -       RI != RE; ++RI) -    BBVect.push_back(*RI); - -  for (std::vector<BasicBlock *>::iterator I = BBVect.begin(), E = BBVect.end(); -       I != E; I++) -    Changed |= processBlock(*I); +  for (BasicBlock *BB : RPOT) +    Changed |= processBlock(BB);    return Changed;  } @@ -2783,6 +2348,7 @@ public:      AU.addPreserved<DominatorTreeWrapperPass>();      AU.addPreserved<GlobalsAAWrapperPass>(); +    AU.addPreserved<TargetLibraryInfoWrapperPass>();      AU.addRequired<OptimizationRemarkEmitterWrapperPass>();    } diff --git a/lib/Transforms/Scalar/GVNHoist.cpp b/lib/Transforms/Scalar/GVNHoist.cpp index f8e1d2e1a08a..6adfe130d148 100644 --- a/lib/Transforms/Scalar/GVNHoist.cpp +++ b/lib/Transforms/Scalar/GVNHoist.cpp @@ -17,16 +17,39 @@  // is disabled in the following cases.  // 1. Scalars across calls.  // 2. geps when corresponding load/store cannot be hoisted. +// +// TODO: Hoist from >2 successors. Currently GVNHoist will not hoist stores +// in this case because it works on two instructions at a time. +// entry: +//   switch i32 %c1, label %exit1 [ +//     i32 0, label %sw0 +//     i32 1, label %sw1 +//   ] +// +// sw0: +//   store i32 1, i32* @G +//   br label %exit +// +// sw1: +//   store i32 1, i32* @G +//   br label %exit +// +// exit1: +//   store i32 1, i32* @G +//   ret void +// exit: +//   ret void  //===----------------------------------------------------------------------===//  #include "llvm/Transforms/Scalar/GVN.h"  #include "llvm/ADT/DenseMap.h"  #include "llvm/ADT/SmallPtrSet.h"  #include "llvm/ADT/Statistic.h" +#include "llvm/Analysis/MemorySSA.h" +#include "llvm/Analysis/MemorySSAUpdater.h"  #include "llvm/Analysis/ValueTracking.h"  #include "llvm/Transforms/Scalar.h"  #include "llvm/Transforms/Utils/Local.h" -#include "llvm/Transforms/Utils/MemorySSA.h"  using namespace llvm; @@ -60,7 +83,7 @@ static cl::opt<int>                     cl::desc("Maximum length of dependent chains to hoist "                              "(default = 10, unlimited = -1)")); -namespace { +namespace llvm {  // Provides a sorting function based on the execution order of two instructions.  
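The iterateOnFunction hunk above drops the defensive copy of the block list: ReversePostOrderTraversal computes its order once in its constructor, so, as the new comment says, it is not invalidated while processBlock runs. The same shape as a free function, with the per-block work abstracted into a callback; forEachBlockInRPO is illustrative:

    #include "llvm/ADT/PostOrderIterator.h"
    #include "llvm/IR/Function.h"
    using namespace llvm;

    template <typename BlockFn>
    static bool forEachBlockInRPO(Function &F, BlockFn ProcessBlock) {
      bool Changed = false;
      ReversePostOrderTraversal<Function *> RPOT(&F); // order fixed here
      for (BasicBlock *BB : RPOT)
        Changed |= ProcessBlock(BB);
      return Changed;
    }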
struct SortByDFSIn { @@ -72,13 +95,6 @@ public:    // Returns true when A executes before B.    bool operator()(const Instruction *A, const Instruction *B) const { -    // FIXME: libc++ has a std::sort() algorithm that will call the compare -    // function on the same element.  Once PR20837 is fixed and some more years -    // pass by and all the buildbots have moved to a corrected std::sort(), -    // enable the following assert: -    // -    // assert(A != B); -      const BasicBlock *BA = A->getParent();      const BasicBlock *BB = B->getParent();      unsigned ADFS, BDFS; @@ -202,6 +218,7 @@ public:    GVNHoist(DominatorTree *DT, AliasAnalysis *AA, MemoryDependenceResults *MD,             MemorySSA *MSSA)        : DT(DT), AA(AA), MD(MD), MSSA(MSSA), +        MSSAUpdater(make_unique<MemorySSAUpdater>(MSSA)),          HoistingGeps(false),          HoistedCtr(0)    { } @@ -249,9 +266,11 @@ private:    AliasAnalysis *AA;    MemoryDependenceResults *MD;    MemorySSA *MSSA; +  std::unique_ptr<MemorySSAUpdater> MSSAUpdater;    const bool HoistingGeps;    DenseMap<const Value *, unsigned> DFSNumber;    BBSideEffectsSet BBSideEffects; +  DenseSet<const BasicBlock*> HoistBarrier;    int HoistedCtr;    enum InsKind { Unknown, Scalar, Load, Store }; @@ -307,8 +326,8 @@ private:          continue;        } -      // Check for end of function, calls that do not return, etc. -      if (!isGuaranteedToTransferExecutionToSuccessor(BB->getTerminator())) +      // We reached the leaf Basic Block => not all paths have this instruction. +      if (!BB->getTerminator()->getNumSuccessors())          return false;        // When reaching the back-edge of a loop, there may be a path through the @@ -360,7 +379,7 @@ private:              ReachedNewPt = true;            }          } -        if (defClobbersUseOrDef(Def, MU, *AA)) +        if (MemorySSAUtil::defClobbersUseOrDef(Def, MU, *AA))            return true;        } @@ -387,7 +406,8 @@ private:      // executed between the execution of NewBB and OldBB. Hoisting an expression      // from OldBB into NewBB has to be safe on all execution paths.      for (auto I = idf_begin(OldBB), E = idf_end(OldBB); I != E;) { -      if (*I == NewBB) { +      const BasicBlock *BB = *I; +      if (BB == NewBB) {          // Stop traversal when reaching HoistPt.          I.skipChildren();          continue; @@ -398,11 +418,17 @@ private:          return true;        // Impossible to hoist with exceptions on the path. -      if (hasEH(*I)) +      if (hasEH(BB)) +        return true; + +      // No such instruction after HoistBarrier in a basic block was +      // selected for hoisting so instructions selected within basic block with +      // a hoist barrier can be hoisted. +      if ((BB != OldBB) && HoistBarrier.count(BB))          return true;        // Check that we do not move a store past loads. -      if (hasMemoryUse(NewPt, Def, *I)) +      if (hasMemoryUse(NewPt, Def, BB))          return true;        // -1 is unlimited number of blocks on all paths. @@ -419,17 +445,18 @@ private:    // Decrement by 1 NBBsOnAllPaths for each block between HoistPt and BB, and    // return true when the counter NBBsOnAllPaths reaches 0, except when it is    // initialized to -1 which is unlimited. 
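GVNHoist now owns a MemorySSAUpdater next to the MemorySSA analysis and funnels every mutation through it instead of calling MemorySSA's own mutation methods. A minimal sketch of that ownership pattern, using only calls that appear in this patch; the HoistState struct is illustrative:

    #include "llvm/ADT/STLExtras.h"
    #include "llvm/Analysis/MemorySSA.h"
    #include "llvm/Analysis/MemorySSAUpdater.h"
    #include <memory>
    using namespace llvm;

    struct HoistState {
      MemorySSA *MSSA;
      std::unique_ptr<MemorySSAUpdater> MSSAUpdater;

      explicit HoistState(MemorySSA *MSSA)
          : MSSA(MSSA), MSSAUpdater(llvm::make_unique<MemorySSAUpdater>(MSSA)) {}

      // Route removals through the updater so MemorySSA stays consistent.
      void eraseAccess(MemoryAccess *MA) { MSSAUpdater->removeMemoryAccess(MA); }
    };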
-  bool hasEHOnPath(const BasicBlock *HoistPt, const BasicBlock *BB, +  bool hasEHOnPath(const BasicBlock *HoistPt, const BasicBlock *SrcBB,                     int &NBBsOnAllPaths) { -    assert(DT->dominates(HoistPt, BB) && "Invalid path"); +    assert(DT->dominates(HoistPt, SrcBB) && "Invalid path");      // Walk all basic blocks reachable in depth-first iteration on      // the inverse CFG from BBInsn to NewHoistPt. These blocks are all the      // blocks that may be executed between the execution of NewHoistPt and      // BBInsn. Hoisting an expression from BBInsn into NewHoistPt has to be safe      // on all execution paths. -    for (auto I = idf_begin(BB), E = idf_end(BB); I != E;) { -      if (*I == HoistPt) { +    for (auto I = idf_begin(SrcBB), E = idf_end(SrcBB); I != E;) { +      const BasicBlock *BB = *I; +      if (BB == HoistPt) {          // Stop traversal when reaching NewHoistPt.          I.skipChildren();          continue; @@ -440,7 +467,13 @@ private:          return true;        // Impossible to hoist with exceptions on the path. -      if (hasEH(*I)) +      if (hasEH(BB)) +        return true; + +      // No such instruction after HoistBarrier in a basic block was +      // selected for hoisting so instructions selected within basic block with +      // a hoist barrier can be hoisted. +      if ((BB != SrcBB) && HoistBarrier.count(BB))          return true;        // -1 is unlimited number of blocks on all paths. @@ -626,6 +659,8 @@ private:        // Compute the insertion point and the list of expressions to be hoisted.        SmallVecInsn InstructionsToHoist;        for (auto I : V) +        // We don't need to check for hoist-barriers here because if +        // I->getParent() is a barrier then I precedes the barrier.          if (!hasEH(I->getParent()))            InstructionsToHoist.push_back(I); @@ -809,9 +844,9 @@ private:            // legal when the ld/st is not moved past its current definition.            MemoryAccess *Def = OldMemAcc->getDefiningAccess();            NewMemAcc = -              MSSA->createMemoryAccessInBB(Repl, Def, HoistPt, MemorySSA::End); +            MSSAUpdater->createMemoryAccessInBB(Repl, Def, HoistPt, MemorySSA::End);            OldMemAcc->replaceAllUsesWith(NewMemAcc); -          MSSA->removeMemoryAccess(OldMemAcc); +          MSSAUpdater->removeMemoryAccess(OldMemAcc);          }        } @@ -850,7 +885,7 @@ private:              // Update the uses of the old MSSA access with NewMemAcc.              MemoryAccess *OldMA = MSSA->getMemoryAccess(I);              OldMA->replaceAllUsesWith(NewMemAcc); -            MSSA->removeMemoryAccess(OldMA); +            MSSAUpdater->removeMemoryAccess(OldMA);            }            Repl->andIRFlags(I); @@ -872,7 +907,7 @@ private:            auto In = Phi->incoming_values();            if (all_of(In, [&](Use &U) { return U == NewMemAcc; })) {              Phi->replaceAllUsesWith(NewMemAcc); -            MSSA->removeMemoryAccess(Phi); +            MSSAUpdater->removeMemoryAccess(Phi);            }          }        } @@ -896,6 +931,12 @@ private:      for (BasicBlock *BB : depth_first(&F.getEntryBlock())) {        int InstructionNb = 0;        for (Instruction &I1 : *BB) { +        // If I1 cannot guarantee progress, subsequent instructions +        // in BB cannot be hoisted anyways. +        if (!isGuaranteedToTransferExecutionToSuccessor(&I1)) { +           HoistBarrier.insert(BB); +           break; +        }          // Only hoist the first instructions in BB up to MaxDepthInBB. 
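The new HoistBarrier logic is the heart of this GVNHoist change: any instruction that is not guaranteed to transfer execution to its successor makes the remainder of its block unsafe to hoist from, and the path checks refuse to cross such a block. A simplified standalone sketch of the collection step (the patch walks blocks in depth-first order from the entry block and stops scanning a block at the first barrier; the helper name is made up):

    #include "llvm/ADT/DenseSet.h"
    #include "llvm/Analysis/ValueTracking.h"
    #include "llvm/IR/Function.h"
    using namespace llvm;

    static DenseSet<const BasicBlock *> collectHoistBarriers(Function &F) {
      DenseSet<const BasicBlock *> HoistBarrier;
      for (BasicBlock &BB : F)
        for (Instruction &I : BB)
          if (!isGuaranteedToTransferExecutionToSuccessor(&I)) {
            // Instructions after I may never execute, so the rest of BB is
            // not a safe hoist source.
            HoistBarrier.insert(&BB);
            break;
          }
      return HoistBarrier;
    }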
Hoisting          // deeper may increase the register pressure and compilation time.          if (MaxDepthInBB != -1 && InstructionNb++ >= MaxDepthInBB) diff --git a/lib/Transforms/Scalar/GuardWidening.cpp b/lib/Transforms/Scalar/GuardWidening.cpp index b05ef002a456..7019287954a1 100644 --- a/lib/Transforms/Scalar/GuardWidening.cpp +++ b/lib/Transforms/Scalar/GuardWidening.cpp @@ -568,8 +568,7 @@ bool GuardWideningImpl::combineRangeChecks(        return RC.getBase() == CurrentBase && RC.getLength() == CurrentLength;      }; -    std::copy_if(Checks.begin(), Checks.end(), -                 std::back_inserter(CurrentChecks), IsCurrentCheck); +    copy_if(Checks, std::back_inserter(CurrentChecks), IsCurrentCheck);      Checks.erase(remove_if(Checks, IsCurrentCheck), Checks.end());      assert(CurrentChecks.size() != 0 && "We know we have at least one!"); @@ -658,8 +657,12 @@ PreservedAnalyses GuardWideningPass::run(Function &F,    auto &DT = AM.getResult<DominatorTreeAnalysis>(F);    auto &LI = AM.getResult<LoopAnalysis>(F);    auto &PDT = AM.getResult<PostDominatorTreeAnalysis>(F); -  bool Changed = GuardWideningImpl(DT, PDT, LI).run(); -  return Changed ? PreservedAnalyses::none() : PreservedAnalyses::all(); +  if (!GuardWideningImpl(DT, PDT, LI).run()) +    return PreservedAnalyses::all(); + +  PreservedAnalyses PA; +  PA.preserveSet<CFGAnalyses>(); +  return PA;  }  StringRef GuardWideningImpl::scoreTypeToString(WideningScore WS) { diff --git a/lib/Transforms/Scalar/IndVarSimplify.cpp b/lib/Transforms/Scalar/IndVarSimplify.cpp index 1752fb75eb1b..dcb2a4a0c6e6 100644 --- a/lib/Transforms/Scalar/IndVarSimplify.cpp +++ b/lib/Transforms/Scalar/IndVarSimplify.cpp @@ -231,8 +231,9 @@ static bool ConvertToSInt(const APFloat &APF, int64_t &IntVal) {    bool isExact = false;    // See if we can convert this to an int64_t    uint64_t UIntVal; -  if (APF.convertToInteger(&UIntVal, 64, true, APFloat::rmTowardZero, -                           &isExact) != APFloat::opOK || !isExact) +  if (APF.convertToInteger(makeMutableArrayRef(UIntVal), 64, true, +                           APFloat::rmTowardZero, &isExact) != APFloat::opOK || +      !isExact)      return false;    IntVal = UIntVal;    return true; @@ -906,7 +907,7 @@ class WidenIV {    SmallVector<NarrowIVDefUse, 8> NarrowIVUsers;    enum ExtendKind { ZeroExtended, SignExtended, Unknown }; -  // A map tracking the kind of extension used to widen each narrow IV  +  // A map tracking the kind of extension used to widen each narrow IV    // and narrow IV user.    // Key: pointer to a narrow IV or IV user.    // Value: the kind of extension used to widen this Instruction. @@ -1608,7 +1609,7 @@ void WidenIV::calculatePostIncRange(Instruction *NarrowDef,        return;      CmpInst::Predicate P = -            TrueDest ? Pred : CmpInst::getInversePredicate(Pred);   +            TrueDest ? Pred : CmpInst::getInversePredicate(Pred);      auto CmpRHSRange = SE->getSignedRange(SE->getSCEV(CmpRHS));      auto CmpConstrainedLHSRange = @@ -1634,7 +1635,7 @@ void WidenIV::calculatePostIncRange(Instruction *NarrowDef,    UpdateRangeFromGuards(NarrowUser);    BasicBlock *NarrowUserBB = NarrowUser->getParent(); -  // If NarrowUserBB is statically unreachable asking dominator queries may  +  // If NarrowUserBB is statically unreachable asking dominator queries may    // yield surprising results. (e.g. 
the block may not have a dom tree node)    if (!DT->isReachableFromEntry(NarrowUserBB))      return; @@ -2152,6 +2153,8 @@ linearFunctionTestReplace(Loop *L,    Value *CmpIndVar = IndVar;    const SCEV *IVCount = BackedgeTakenCount; +  assert(L->getLoopLatch() && "Loop no longer in simplified form?"); +    // If the exiting block is the same as the backedge block, we prefer to    // compare against the post-incremented value, otherwise we must compare    // against the preincremented value. @@ -2376,6 +2379,7 @@ bool IndVarSimplify::run(Loop *L) {    //    Loop::getCanonicalInductionVariable only supports loops with preheaders,    //    and we're in trouble if we can't find the induction variable even when    //    we've manually inserted one. +  //  - LFTR relies on having a single backedge.    if (!L->isLoopSimplifyForm())      return false; @@ -2492,8 +2496,9 @@ PreservedAnalyses IndVarSimplifyPass::run(Loop &L, LoopAnalysisManager &AM,    if (!IVS.run(&L))      return PreservedAnalyses::all(); -  // FIXME: This should also 'preserve the CFG'. -  return getLoopPassPreservedAnalyses(); +  auto PA = getLoopPassPreservedAnalyses(); +  PA.preserveSet<CFGAnalyses>(); +  return PA;  }  namespace { diff --git a/lib/Transforms/Scalar/InductiveRangeCheckElimination.cpp b/lib/Transforms/Scalar/InductiveRangeCheckElimination.cpp index 8e81541c2337..85db6e5e1105 100644 --- a/lib/Transforms/Scalar/InductiveRangeCheckElimination.cpp +++ b/lib/Transforms/Scalar/InductiveRangeCheckElimination.cpp @@ -446,6 +446,15 @@ struct LoopStructure {    BasicBlock *LatchExit;    unsigned LatchBrExitIdx; +  // The loop represented by this instance of LoopStructure is semantically +  // equivalent to: +  // +  // intN_ty inc = IndVarIncreasing ? 1 : -1; +  // pred_ty predicate = IndVarIncreasing ? ICMP_SLT : ICMP_SGT; +  // +  // for (intN_ty iv = IndVarStart; predicate(iv, LoopExitAt); iv = IndVarNext) +  //   ... body ... +    Value *IndVarNext;    Value *IndVarStart;    Value *LoopExitAt; @@ -789,6 +798,10 @@ LoopStructure::parseLoopStructure(ScalarEvolution &SE, BranchProbabilityInfo &BP      return None;    } +  const SCEV *StartNext = IndVarNext->getStart(); +  const SCEV *Addend = SE.getNegativeSCEV(IndVarNext->getStepRecurrence(SE)); +  const SCEV *IndVarStart = SE.getAddExpr(StartNext, Addend); +    ConstantInt *One = ConstantInt::get(IndVarTy, 1);    // TODO: generalize the predicates here to also match their unsigned variants.    
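The hunk above hoists the IndVarStart computation so it is available to the new isLoopEntryGuardedByCond checks that follow. The computation itself is a small SCEV identity: IndVarNext is the post-increment recurrence {Start + Step, +, Step}, so the pre-increment start is its start minus one step. The same three lines as a hedged free function; the body is taken from the patch, while the wrapper and its parameter types are assumptions:

    #include "llvm/Analysis/ScalarEvolution.h"
    #include "llvm/Analysis/ScalarEvolutionExpressions.h"
    using namespace llvm;

    static const SCEV *getPreIncrementStart(ScalarEvolution &SE,
                                            const SCEVAddRecExpr *IndVarNext) {
      const SCEV *StartNext = IndVarNext->getStart();
      const SCEV *Addend = SE.getNegativeSCEV(IndVarNext->getStepRecurrence(SE));
      return SE.getAddExpr(StartNext, Addend);
    }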
if (IsIncreasing) { @@ -809,10 +822,22 @@ LoopStructure::parseLoopStructure(ScalarEvolution &SE, BranchProbabilityInfo &BP          return None;        } +      if (!SE.isLoopEntryGuardedByCond( +              &L, CmpInst::ICMP_SLT, IndVarStart, +              SE.getAddExpr(RightSCEV, SE.getOne(RightSCEV->getType())))) { +        FailureReason = "Induction variable start not bounded by upper limit"; +        return None; +      } +        IRBuilder<> B(Preheader->getTerminator());        RightValue = B.CreateAdd(RightValue, One); +    } else { +      if (!SE.isLoopEntryGuardedByCond(&L, CmpInst::ICMP_SLT, IndVarStart, +                                       RightSCEV)) { +        FailureReason = "Induction variable start not bounded by upper limit"; +        return None; +      }      } -    } else {      bool FoundExpectedPred =          (Pred == ICmpInst::ICMP_SGT && LatchBrExitIdx == 1) || @@ -831,15 +856,24 @@ LoopStructure::parseLoopStructure(ScalarEvolution &SE, BranchProbabilityInfo &BP          return None;        } +      if (!SE.isLoopEntryGuardedByCond( +              &L, CmpInst::ICMP_SGT, IndVarStart, +              SE.getMinusSCEV(RightSCEV, SE.getOne(RightSCEV->getType())))) { +        FailureReason = "Induction variable start not bounded by lower limit"; +        return None; +      } +        IRBuilder<> B(Preheader->getTerminator());        RightValue = B.CreateSub(RightValue, One); +    } else { +      if (!SE.isLoopEntryGuardedByCond(&L, CmpInst::ICMP_SGT, IndVarStart, +                                       RightSCEV)) { +        FailureReason = "Induction variable start not bounded by lower limit"; +        return None; +      }      }    } -  const SCEV *StartNext = IndVarNext->getStart(); -  const SCEV *Addend = SE.getNegativeSCEV(IndVarNext->getStepRecurrence(SE)); -  const SCEV *IndVarStart = SE.getAddExpr(StartNext, Addend); -    BasicBlock *LatchExit = LatchBr->getSuccessor(LatchBrExitIdx);    assert(SE.getLoopDisposition(LatchCount, &L) == diff --git a/lib/Transforms/Scalar/InferAddressSpaces.cpp b/lib/Transforms/Scalar/InferAddressSpaces.cpp new file mode 100644 index 000000000000..5d8701431a2c --- /dev/null +++ b/lib/Transforms/Scalar/InferAddressSpaces.cpp @@ -0,0 +1,903 @@ +//===-- NVPTXInferAddressSpace.cpp - ---------------------*- C++ -*-===// +// +//                     The LLVM Compiler Infrastructure +// +// This file is distributed under the University of Illinois Open Source +// License. See LICENSE.TXT for details. +// +//===----------------------------------------------------------------------===// +// +// CUDA C/C++ includes memory space designation as variable type qualifers (such +// as __global__ and __shared__). Knowing the space of a memory access allows +// CUDA compilers to emit faster PTX loads and stores. For example, a load from +// shared memory can be translated to `ld.shared` which is roughly 10% faster +// than a generic `ld` on an NVIDIA Tesla K40c. +// +// Unfortunately, type qualifiers only apply to variable declarations, so CUDA +// compilers must infer the memory space of an address expression from +// type-qualified variables. +// +// LLVM IR uses non-zero (so-called) specific address spaces to represent memory +// spaces (e.g. addrspace(3) means shared memory). The Clang frontend +// places only type-qualified variables in specific address spaces, and then +// conservatively `addrspacecast`s each type-qualified variable to addrspace(0) +// (so-called the generic address space) for other instructions to use. 
+// +// For example, the Clang translates the following CUDA code +//   __shared__ float a[10]; +//   float v = a[i]; +// to +//   %0 = addrspacecast [10 x float] addrspace(3)* @a to [10 x float]* +//   %1 = gep [10 x float], [10 x float]* %0, i64 0, i64 %i +//   %v = load float, float* %1 ; emits ld.f32 +// @a is in addrspace(3) since it's type-qualified, but its use from %1 is +// redirected to %0 (the generic version of @a). +// +// The optimization implemented in this file propagates specific address spaces +// from type-qualified variable declarations to its users. For example, it +// optimizes the above IR to +//   %1 = gep [10 x float] addrspace(3)* @a, i64 0, i64 %i +//   %v = load float addrspace(3)* %1 ; emits ld.shared.f32 +// propagating the addrspace(3) from @a to %1. As the result, the NVPTX +// codegen is able to emit ld.shared.f32 for %v. +// +// Address space inference works in two steps. First, it uses a data-flow +// analysis to infer as many generic pointers as possible to point to only one +// specific address space. In the above example, it can prove that %1 only +// points to addrspace(3). This algorithm was published in +//   CUDA: Compiling and optimizing for a GPU platform +//   Chakrabarti, Grover, Aarts, Kong, Kudlur, Lin, Marathe, Murphy, Wang +//   ICCS 2012 +// +// Then, address space inference replaces all refinable generic pointers with +// equivalent specific pointers. +// +// The major challenge of implementing this optimization is handling PHINodes, +// which may create loops in the data flow graph. This brings two complications. +// +// First, the data flow analysis in Step 1 needs to be circular. For example, +//     %generic.input = addrspacecast float addrspace(3)* %input to float* +//   loop: +//     %y = phi [ %generic.input, %y2 ] +//     %y2 = getelementptr %y, 1 +//     %v = load %y2 +//     br ..., label %loop, ... +// proving %y specific requires proving both %generic.input and %y2 specific, +// but proving %y2 specific circles back to %y. To address this complication, +// the data flow analysis operates on a lattice: +//   uninitialized > specific address spaces > generic. +// All address expressions (our implementation only considers phi, bitcast, +// addrspacecast, and getelementptr) start with the uninitialized address space. +// The monotone transfer function moves the address space of a pointer down a +// lattice path from uninitialized to specific and then to generic. A join +// operation of two different specific address spaces pushes the expression down +// to the generic address space. The analysis completes once it reaches a fixed +// point. +// +// Second, IR rewriting in Step 2 also needs to be circular. For example, +// converting %y to addrspace(3) requires the compiler to know the converted +// %y2, but converting %y2 needs the converted %y. To address this complication, +// we break these cycles using "undef" placeholders. When converting an +// instruction `I` to a new address space, if its operand `Op` is not converted +// yet, we let `I` temporarily use `undef` and fix all the uses of undef later. 
+// For instance, our algorithm first converts %y to +//   %y' = phi float addrspace(3)* [ %input, undef ] +// Then, it converts %y2 to +//   %y2' = getelementptr %y', 1 +// Finally, it fixes the undef in %y' so that +//   %y' = phi float addrspace(3)* [ %input, %y2' ] +// +//===----------------------------------------------------------------------===// + +#include "llvm/Transforms/Scalar.h" +#include "llvm/ADT/DenseSet.h" +#include "llvm/ADT/Optional.h" +#include "llvm/ADT/SetVector.h" +#include "llvm/Analysis/TargetTransformInfo.h" +#include "llvm/IR/Function.h" +#include "llvm/IR/InstIterator.h" +#include "llvm/IR/Instructions.h" +#include "llvm/IR/Operator.h" +#include "llvm/Support/Debug.h" +#include "llvm/Support/raw_ostream.h" +#include "llvm/Transforms/Utils/Local.h" +#include "llvm/Transforms/Utils/ValueMapper.h" + +#define DEBUG_TYPE "infer-address-spaces" + +using namespace llvm; + +namespace { +static const unsigned UninitializedAddressSpace = ~0u; + +using ValueToAddrSpaceMapTy = DenseMap<const Value *, unsigned>; + +/// \brief InferAddressSpaces +class InferAddressSpaces : public FunctionPass { +  /// Target specific address space which uses of should be replaced if +  /// possible. +  unsigned FlatAddrSpace; + +public: +  static char ID; + +  InferAddressSpaces() : FunctionPass(ID) {} + +  void getAnalysisUsage(AnalysisUsage &AU) const override { +    AU.setPreservesCFG(); +    AU.addRequired<TargetTransformInfoWrapperPass>(); +  } + +  bool runOnFunction(Function &F) override; + +private: +  // Returns the new address space of V if updated; otherwise, returns None. +  Optional<unsigned> +  updateAddressSpace(const Value &V, +                     const ValueToAddrSpaceMapTy &InferredAddrSpace) const; + +  // Tries to infer the specific address space of each address expression in +  // Postorder. +  void inferAddressSpaces(const std::vector<Value *> &Postorder, +                          ValueToAddrSpaceMapTy *InferredAddrSpace) const; + +  bool isSafeToCastConstAddrSpace(Constant *C, unsigned NewAS) const; + +  // Changes the flat address expressions in function F to point to specific +  // address spaces if InferredAddrSpace says so. Postorder is the postorder of +  // all flat expressions in the use-def graph of function F. 
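The lattice described in the header comment above (uninitialized > specific address spaces > generic) has a join that is easy to state in isolation; joinAddressSpaces later in this file implements exactly this. A standalone model, with the flat/generic space passed in explicitly:

    static const unsigned UninitializedAS = ~0u; // top of the lattice

    static unsigned joinAS(unsigned AS1, unsigned AS2, unsigned FlatAS) {
      if (AS1 == FlatAS || AS2 == FlatAS)
        return FlatAS;               // generic is the bottom: it absorbs all
      if (AS1 == UninitializedAS)
        return AS2;                  // joining with top is a no-op
      if (AS2 == UninitializedAS)
        return AS1;
      // Two different specific spaces can only meet at the generic space.
      return AS1 == AS2 ? AS1 : FlatAS;
    }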
+  bool +  rewriteWithNewAddressSpaces(const std::vector<Value *> &Postorder, +                              const ValueToAddrSpaceMapTy &InferredAddrSpace, +                              Function *F) const; + +  void appendsFlatAddressExpressionToPostorderStack( +    Value *V, std::vector<std::pair<Value *, bool>> *PostorderStack, +    DenseSet<Value *> *Visited) const; + +  bool rewriteIntrinsicOperands(IntrinsicInst *II, +                                Value *OldV, Value *NewV) const; +  void collectRewritableIntrinsicOperands( +    IntrinsicInst *II, +    std::vector<std::pair<Value *, bool>> *PostorderStack, +    DenseSet<Value *> *Visited) const; + +  std::vector<Value *> collectFlatAddressExpressions(Function &F) const; + +  Value *cloneValueWithNewAddressSpace( +    Value *V, unsigned NewAddrSpace, +    const ValueToValueMapTy &ValueWithNewAddrSpace, +    SmallVectorImpl<const Use *> *UndefUsesToFix) const; +  unsigned joinAddressSpaces(unsigned AS1, unsigned AS2) const; +}; +} // end anonymous namespace + +char InferAddressSpaces::ID = 0; + +namespace llvm { +void initializeInferAddressSpacesPass(PassRegistry &); +} + +INITIALIZE_PASS(InferAddressSpaces, DEBUG_TYPE, "Infer address spaces", +                false, false) + +// Returns true if V is an address expression. +// TODO: Currently, we consider only phi, bitcast, addrspacecast, and +// getelementptr operators. +static bool isAddressExpression(const Value &V) { +  if (!isa<Operator>(V)) +    return false; + +  switch (cast<Operator>(V).getOpcode()) { +  case Instruction::PHI: +  case Instruction::BitCast: +  case Instruction::AddrSpaceCast: +  case Instruction::GetElementPtr: +  case Instruction::Select: +    return true; +  default: +    return false; +  } +} + +// Returns the pointer operands of V. +// +// Precondition: V is an address expression. +static SmallVector<Value *, 2> getPointerOperands(const Value &V) { +  assert(isAddressExpression(V)); +  const Operator &Op = cast<Operator>(V); +  switch (Op.getOpcode()) { +  case Instruction::PHI: { +    auto IncomingValues = cast<PHINode>(Op).incoming_values(); +    return SmallVector<Value *, 2>(IncomingValues.begin(), +                                   IncomingValues.end()); +  } +  case Instruction::BitCast: +  case Instruction::AddrSpaceCast: +  case Instruction::GetElementPtr: +    return {Op.getOperand(0)}; +  case Instruction::Select: +    return {Op.getOperand(1), Op.getOperand(2)}; +  default: +    llvm_unreachable("Unexpected instruction type."); +  } +} + +// TODO: Move logic to TTI? +bool InferAddressSpaces::rewriteIntrinsicOperands(IntrinsicInst *II, +                                                  Value *OldV, +                                                  Value *NewV) const { +  Module *M = II->getParent()->getParent()->getParent(); + +  switch (II->getIntrinsicID()) { +  case Intrinsic::amdgcn_atomic_inc: +  case Intrinsic::amdgcn_atomic_dec:{ +    const ConstantInt *IsVolatile = dyn_cast<ConstantInt>(II->getArgOperand(4)); +    if (!IsVolatile || !IsVolatile->isNullValue()) +      return false; + +    LLVM_FALLTHROUGH; +  } +  case Intrinsic::objectsize: { +    Type *DestTy = II->getType(); +    Type *SrcTy = NewV->getType(); +    Function *NewDecl = +        Intrinsic::getDeclaration(M, II->getIntrinsicID(), {DestTy, SrcTy}); +    II->setArgOperand(0, NewV); +    II->setCalledFunction(NewDecl); +    return true; +  } +  default: +    return false; +  } +} + +// TODO: Move logic to TTI? 
+void InferAddressSpaces::collectRewritableIntrinsicOperands( +    IntrinsicInst *II, std::vector<std::pair<Value *, bool>> *PostorderStack, +    DenseSet<Value *> *Visited) const { +  switch (II->getIntrinsicID()) { +  case Intrinsic::objectsize: +  case Intrinsic::amdgcn_atomic_inc: +  case Intrinsic::amdgcn_atomic_dec: +    appendsFlatAddressExpressionToPostorderStack(II->getArgOperand(0), +                                                 PostorderStack, Visited); +    break; +  default: +    break; +  } +} + +// Returns all flat address expressions in function F. The elements are +// If V is an unvisited flat address expression, appends V to PostorderStack +// and marks it as visited. +void InferAddressSpaces::appendsFlatAddressExpressionToPostorderStack( +    Value *V, std::vector<std::pair<Value *, bool>> *PostorderStack, +    DenseSet<Value *> *Visited) const { +  assert(V->getType()->isPointerTy()); +  if (isAddressExpression(*V) && +      V->getType()->getPointerAddressSpace() == FlatAddrSpace) { +    if (Visited->insert(V).second) +      PostorderStack->push_back(std::make_pair(V, false)); +  } +} + +// Returns all flat address expressions in function F. The elements are ordered +// ordered in postorder. +std::vector<Value *> +InferAddressSpaces::collectFlatAddressExpressions(Function &F) const { +  // This function implements a non-recursive postorder traversal of a partial +  // use-def graph of function F. +  std::vector<std::pair<Value *, bool>> PostorderStack; +  // The set of visited expressions. +  DenseSet<Value *> Visited; + +  auto PushPtrOperand = [&](Value *Ptr) { +    appendsFlatAddressExpressionToPostorderStack(Ptr, &PostorderStack, +                                                 &Visited); +  }; + +  // We only explore address expressions that are reachable from loads and +  // stores for now because we aim at generating faster loads and stores. +  for (Instruction &I : instructions(F)) { +    if (auto *LI = dyn_cast<LoadInst>(&I)) +      PushPtrOperand(LI->getPointerOperand()); +    else if (auto *SI = dyn_cast<StoreInst>(&I)) +      PushPtrOperand(SI->getPointerOperand()); +    else if (auto *RMW = dyn_cast<AtomicRMWInst>(&I)) +      PushPtrOperand(RMW->getPointerOperand()); +    else if (auto *CmpX = dyn_cast<AtomicCmpXchgInst>(&I)) +      PushPtrOperand(CmpX->getPointerOperand()); +    else if (auto *MI = dyn_cast<MemIntrinsic>(&I)) { +      // For memset/memcpy/memmove, any pointer operand can be replaced. +      PushPtrOperand(MI->getRawDest()); + +      // Handle 2nd operand for memcpy/memmove. +      if (auto *MTI = dyn_cast<MemTransferInst>(MI)) +        PushPtrOperand(MTI->getRawSource()); +    } else if (auto *II = dyn_cast<IntrinsicInst>(&I)) +      collectRewritableIntrinsicOperands(II, &PostorderStack, &Visited); +    else if (ICmpInst *Cmp = dyn_cast<ICmpInst>(&I)) { +      // FIXME: Handle vectors of pointers +      if (Cmp->getOperand(0)->getType()->isPointerTy()) { +        PushPtrOperand(Cmp->getOperand(0)); +        PushPtrOperand(Cmp->getOperand(1)); +      } +    } +  } + +  std::vector<Value *> Postorder; // The resultant postorder. +  while (!PostorderStack.empty()) { +    // If the operands of the expression on the top are already explored, +    // adds that expression to the resultant postorder. +    if (PostorderStack.back().second) { +      Postorder.push_back(PostorderStack.back().first); +      PostorderStack.pop_back(); +      continue; +    } +    // Otherwise, adds its operands to the stack and explores them. 
+    PostorderStack.back().second = true; +    for (Value *PtrOperand : getPointerOperands(*PostorderStack.back().first)) { +      appendsFlatAddressExpressionToPostorderStack(PtrOperand, &PostorderStack, +                                                   &Visited); +    } +  } +  return Postorder; +} + +// A helper function for cloneInstructionWithNewAddressSpace. Returns the clone +// of OperandUse.get() in the new address space. If the clone is not ready yet, +// returns an undef in the new address space as a placeholder. +static Value *operandWithNewAddressSpaceOrCreateUndef( +    const Use &OperandUse, unsigned NewAddrSpace, +    const ValueToValueMapTy &ValueWithNewAddrSpace, +    SmallVectorImpl<const Use *> *UndefUsesToFix) { +  Value *Operand = OperandUse.get(); + +  Type *NewPtrTy = +      Operand->getType()->getPointerElementType()->getPointerTo(NewAddrSpace); + +  if (Constant *C = dyn_cast<Constant>(Operand)) +    return ConstantExpr::getAddrSpaceCast(C, NewPtrTy); + +  if (Value *NewOperand = ValueWithNewAddrSpace.lookup(Operand)) +    return NewOperand; + +  UndefUsesToFix->push_back(&OperandUse); +  return UndefValue::get(NewPtrTy); +} + +// Returns a clone of `I` with its operands converted to those specified in +// ValueWithNewAddrSpace. Due to potential cycles in the data flow graph, an +// operand whose address space needs to be modified might not exist in +// ValueWithNewAddrSpace. In that case, uses undef as a placeholder operand and +// adds that operand use to UndefUsesToFix so that caller can fix them later. +// +// Note that we do not necessarily clone `I`, e.g., if it is an addrspacecast +// from a pointer whose type already matches. Therefore, this function returns a +// Value* instead of an Instruction*. +static Value *cloneInstructionWithNewAddressSpace( +    Instruction *I, unsigned NewAddrSpace, +    const ValueToValueMapTy &ValueWithNewAddrSpace, +    SmallVectorImpl<const Use *> *UndefUsesToFix) { +  Type *NewPtrType = +      I->getType()->getPointerElementType()->getPointerTo(NewAddrSpace); + +  if (I->getOpcode() == Instruction::AddrSpaceCast) { +    Value *Src = I->getOperand(0); +    // Because `I` is flat, the source address space must be specific. +    // Therefore, the inferred address space must be the source space, according +    // to our algorithm. +    assert(Src->getType()->getPointerAddressSpace() == NewAddrSpace); +    if (Src->getType() != NewPtrType) +      return new BitCastInst(Src, NewPtrType); +    return Src; +  } + +  // Computes the converted pointer operands. 
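collectFlatAddressExpressions above computes a postorder with an explicit stack of (value, operands-pushed) pairs instead of recursion. The same traversal scheme stripped of the LLVM types, as a hedged generic sketch; calling it with a pointer type and a lambda that returns pointer operands reproduces the pass's ordering:

    #include <set>
    #include <utility>
    #include <vector>

    // Emits each reachable node after all of its operands, without recursion.
    template <typename Node, typename OperandsFn>
    std::vector<Node> postorderFrom(Node Root, OperandsFn Operands) {
      std::vector<std::pair<Node, bool>> Stack{{Root, false}};
      std::set<Node> Visited{Root};
      std::vector<Node> Order;
      while (!Stack.empty()) {
        if (Stack.back().second) {    // operands already explored: emit node
          Order.push_back(Stack.back().first);
          Stack.pop_back();
          continue;
        }
        Stack.back().second = true;   // mark, then push unvisited operands
        for (Node Op : Operands(Stack.back().first))
          if (Visited.insert(Op).second)
            Stack.push_back({Op, false});
      }
      return Order;
    }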
+  SmallVector<Value *, 4> NewPointerOperands; +  for (const Use &OperandUse : I->operands()) { +    if (!OperandUse.get()->getType()->isPointerTy()) +      NewPointerOperands.push_back(nullptr); +    else +      NewPointerOperands.push_back(operandWithNewAddressSpaceOrCreateUndef( +                                     OperandUse, NewAddrSpace, ValueWithNewAddrSpace, UndefUsesToFix)); +  } + +  switch (I->getOpcode()) { +  case Instruction::BitCast: +    return new BitCastInst(NewPointerOperands[0], NewPtrType); +  case Instruction::PHI: { +    assert(I->getType()->isPointerTy()); +    PHINode *PHI = cast<PHINode>(I); +    PHINode *NewPHI = PHINode::Create(NewPtrType, PHI->getNumIncomingValues()); +    for (unsigned Index = 0; Index < PHI->getNumIncomingValues(); ++Index) { +      unsigned OperandNo = PHINode::getOperandNumForIncomingValue(Index); +      NewPHI->addIncoming(NewPointerOperands[OperandNo], +                          PHI->getIncomingBlock(Index)); +    } +    return NewPHI; +  } +  case Instruction::GetElementPtr: { +    GetElementPtrInst *GEP = cast<GetElementPtrInst>(I); +    GetElementPtrInst *NewGEP = GetElementPtrInst::Create( +        GEP->getSourceElementType(), NewPointerOperands[0], +        SmallVector<Value *, 4>(GEP->idx_begin(), GEP->idx_end())); +    NewGEP->setIsInBounds(GEP->isInBounds()); +    return NewGEP; +  } +  case Instruction::Select: { +    assert(I->getType()->isPointerTy()); +    return SelectInst::Create(I->getOperand(0), NewPointerOperands[1], +                              NewPointerOperands[2], "", nullptr, I); +  } +  default: +    llvm_unreachable("Unexpected opcode"); +  } +} + +// Similar to cloneInstructionWithNewAddressSpace, returns a clone of the +// constant expression `CE` with its operands replaced as specified in +// ValueWithNewAddrSpace. +static Value *cloneConstantExprWithNewAddressSpace( +  ConstantExpr *CE, unsigned NewAddrSpace, +  const ValueToValueMapTy &ValueWithNewAddrSpace) { +  Type *TargetType = +    CE->getType()->getPointerElementType()->getPointerTo(NewAddrSpace); + +  if (CE->getOpcode() == Instruction::AddrSpaceCast) { +    // Because CE is flat, the source address space must be specific. +    // Therefore, the inferred address space must be the source space according +    // to our algorithm. +    assert(CE->getOperand(0)->getType()->getPointerAddressSpace() == +           NewAddrSpace); +    return ConstantExpr::getBitCast(CE->getOperand(0), TargetType); +  } + +  if (CE->getOpcode() == Instruction::BitCast) { +    if (Value *NewOperand = ValueWithNewAddrSpace.lookup(CE->getOperand(0))) +      return ConstantExpr::getBitCast(cast<Constant>(NewOperand), TargetType); +    return ConstantExpr::getAddrSpaceCast(CE, TargetType); +  } + +  if (CE->getOpcode() == Instruction::Select) { +    Constant *Src0 = CE->getOperand(1); +    Constant *Src1 = CE->getOperand(2); +    if (Src0->getType()->getPointerAddressSpace() == +        Src1->getType()->getPointerAddressSpace()) { + +      return ConstantExpr::getSelect( +          CE->getOperand(0), ConstantExpr::getAddrSpaceCast(Src0, TargetType), +          ConstantExpr::getAddrSpaceCast(Src1, TargetType)); +    } +  } + +  // Computes the operands of the new constant expression. 
+  SmallVector<Constant *, 4> NewOperands; +  for (unsigned Index = 0; Index < CE->getNumOperands(); ++Index) { +    Constant *Operand = CE->getOperand(Index); +    // If the address space of `Operand` needs to be modified, the new operand +    // with the new address space should already be in ValueWithNewAddrSpace +    // because (1) the constant expressions we consider (i.e. addrspacecast, +    // bitcast, and getelementptr) do not incur cycles in the data flow graph +    // and (2) this function is called on constant expressions in postorder. +    if (Value *NewOperand = ValueWithNewAddrSpace.lookup(Operand)) { +      NewOperands.push_back(cast<Constant>(NewOperand)); +    } else { +      // Otherwise, reuses the old operand. +      NewOperands.push_back(Operand); +    } +  } + +  if (CE->getOpcode() == Instruction::GetElementPtr) { +    // Needs to specify the source type while constructing a getelementptr +    // constant expression. +    return CE->getWithOperands( +      NewOperands, TargetType, /*OnlyIfReduced=*/false, +      NewOperands[0]->getType()->getPointerElementType()); +  } + +  return CE->getWithOperands(NewOperands, TargetType); +} + +// Returns a clone of the value `V`, with its operands replaced as specified in +// ValueWithNewAddrSpace. This function is called on every flat address +// expression whose address space needs to be modified, in postorder. +// +// See cloneInstructionWithNewAddressSpace for the meaning of UndefUsesToFix. +Value *InferAddressSpaces::cloneValueWithNewAddressSpace( +  Value *V, unsigned NewAddrSpace, +  const ValueToValueMapTy &ValueWithNewAddrSpace, +  SmallVectorImpl<const Use *> *UndefUsesToFix) const { +  // All values in Postorder are flat address expressions. +  assert(isAddressExpression(*V) && +         V->getType()->getPointerAddressSpace() == FlatAddrSpace); + +  if (Instruction *I = dyn_cast<Instruction>(V)) { +    Value *NewV = cloneInstructionWithNewAddressSpace( +      I, NewAddrSpace, ValueWithNewAddrSpace, UndefUsesToFix); +    if (Instruction *NewI = dyn_cast<Instruction>(NewV)) { +      if (NewI->getParent() == nullptr) { +        NewI->insertBefore(I); +        NewI->takeName(I); +      } +    } +    return NewV; +  } + +  return cloneConstantExprWithNewAddressSpace( +    cast<ConstantExpr>(V), NewAddrSpace, ValueWithNewAddrSpace); +} + +// Defines the join operation on the address space lattice (see the file header +// comments). +unsigned InferAddressSpaces::joinAddressSpaces(unsigned AS1, +                                               unsigned AS2) const { +  if (AS1 == FlatAddrSpace || AS2 == FlatAddrSpace) +    return FlatAddrSpace; + +  if (AS1 == UninitializedAddressSpace) +    return AS2; +  if (AS2 == UninitializedAddressSpace) +    return AS1; + +  // The join of two different specific address spaces is flat. +  return (AS1 == AS2) ? AS1 : FlatAddrSpace; +} + +bool InferAddressSpaces::runOnFunction(Function &F) { +  if (skipFunction(F)) +    return false; + +  const TargetTransformInfo &TTI = +      getAnalysis<TargetTransformInfoWrapperPass>().getTTI(F); +  FlatAddrSpace = TTI.getFlatAddressSpace(); +  if (FlatAddrSpace == UninitializedAddressSpace) +    return false; + +  // Collects all flat address expressions in postorder. +  std::vector<Value *> Postorder = collectFlatAddressExpressions(F); + +  // Runs a data-flow analysis to refine the address spaces of every expression +  // in Postorder. 
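The comments above describe the address-space lattice informally; a tiny standalone model of the join makes the three cases explicit before the data-flow loop that follows. The constants below are placeholders chosen for the sketch, not the values the pass obtains from TargetTransformInfo.

```cpp
#include <cassert>

// Hypothetical lattice encodings, for this sketch only.
constexpr unsigned Uninitialized = ~0u; // top: no information yet
constexpr unsigned Flat = 0;            // bottom: could point anywhere

// join(top, x) = x, join(bottom, x) = bottom, join(a, a) = a,
// and the join of two different specific spaces falls to bottom.
unsigned joinAS(unsigned A, unsigned B) {
  if (A == Flat || B == Flat)
    return Flat;
  if (A == Uninitialized)
    return B;
  if (B == Uninitialized)
    return A;
  return A == B ? A : Flat;
}

int main() {
  assert(joinAS(Uninitialized, 3) == 3); // first evidence seen wins
  assert(joinAS(3, 3) == 3);             // agreement keeps the specific space
  assert(joinAS(3, 5) == Flat);          // disagreement falls to flat
  assert(joinAS(Flat, 3) == Flat);       // flat absorbs everything
}
```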
+  ValueToAddrSpaceMapTy InferredAddrSpace; +  inferAddressSpaces(Postorder, &InferredAddrSpace); + +  // Changes the address spaces of the flat address expressions who are inferred +  // to point to a specific address space. +  return rewriteWithNewAddressSpaces(Postorder, InferredAddrSpace, &F); +} + +void InferAddressSpaces::inferAddressSpaces( +    const std::vector<Value *> &Postorder, +    ValueToAddrSpaceMapTy *InferredAddrSpace) const { +  SetVector<Value *> Worklist(Postorder.begin(), Postorder.end()); +  // Initially, all expressions are in the uninitialized address space. +  for (Value *V : Postorder) +    (*InferredAddrSpace)[V] = UninitializedAddressSpace; + +  while (!Worklist.empty()) { +    Value *V = Worklist.pop_back_val(); + +    // Tries to update the address space of the stack top according to the +    // address spaces of its operands. +    DEBUG(dbgs() << "Updating the address space of\n  " << *V << '\n'); +    Optional<unsigned> NewAS = updateAddressSpace(*V, *InferredAddrSpace); +    if (!NewAS.hasValue()) +      continue; +    // If any updates are made, grabs its users to the worklist because +    // their address spaces can also be possibly updated. +    DEBUG(dbgs() << "  to " << NewAS.getValue() << '\n'); +    (*InferredAddrSpace)[V] = NewAS.getValue(); + +    for (Value *User : V->users()) { +      // Skip if User is already in the worklist. +      if (Worklist.count(User)) +        continue; + +      auto Pos = InferredAddrSpace->find(User); +      // Our algorithm only updates the address spaces of flat address +      // expressions, which are those in InferredAddrSpace. +      if (Pos == InferredAddrSpace->end()) +        continue; + +      // Function updateAddressSpace moves the address space down a lattice +      // path. Therefore, nothing to do if User is already inferred as flat (the +      // bottom element in the lattice). +      if (Pos->second == FlatAddrSpace) +        continue; + +      Worklist.insert(User); +    } +  } +} + +Optional<unsigned> InferAddressSpaces::updateAddressSpace( +    const Value &V, const ValueToAddrSpaceMapTy &InferredAddrSpace) const { +  assert(InferredAddrSpace.count(&V)); + +  // The new inferred address space equals the join of the address spaces +  // of all its pointer operands. +  unsigned NewAS = UninitializedAddressSpace; + +  const Operator &Op = cast<Operator>(V); +  if (Op.getOpcode() == Instruction::Select) { +    Value *Src0 = Op.getOperand(1); +    Value *Src1 = Op.getOperand(2); + +    auto I = InferredAddrSpace.find(Src0); +    unsigned Src0AS = (I != InferredAddrSpace.end()) ? +      I->second : Src0->getType()->getPointerAddressSpace(); + +    auto J = InferredAddrSpace.find(Src1); +    unsigned Src1AS = (J != InferredAddrSpace.end()) ? +      J->second : Src1->getType()->getPointerAddressSpace(); + +    auto *C0 = dyn_cast<Constant>(Src0); +    auto *C1 = dyn_cast<Constant>(Src1); + +    // If one of the inputs is a constant, we may be able to do a constant +    // addrspacecast of it. Defer inferring the address space until the input +    // address space is known. 
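inferAddressSpaces above is a standard monotone worklist algorithm: a value's state only ever moves down the lattice, and its users are re-queued whenever it changes, so the loop reaches a fixed point. A generic sketch of that shape, using a toy integer "lattice" (min is the join) and made-up data that are not from the patch:

```cpp
#include <algorithm>
#include <cstdio>
#include <deque>
#include <vector>

// Toy fixed-point propagation: min() only moves a node's state down, so every
// node changes a bounded number of times and the worklist eventually drains.
int main() {
  // Users[v] = nodes that must be revisited when v changes.
  std::vector<std::vector<int>> Users = {{1, 2}, {2}, {}};
  std::vector<int> State = {1, 5, 7}; // initial per-node states

  std::deque<int> Worklist = {0, 1, 2};
  while (!Worklist.empty()) {
    int V = Worklist.back();
    Worklist.pop_back();
    for (int U : Users[V]) {
      int NewState = std::min(State[U], State[V]); // the "join" step
      if (NewState == State[U])
        continue;              // no change, like updateAddressSpace -> None
      State[U] = NewState;     // moved down the lattice
      Worklist.push_back(U);   // re-queue the user, as the pass does
    }
  }
  std::printf("%d %d %d\n", State[0], State[1], State[2]); // prints "1 1 1"
}
```

The pass additionally keeps the worklist deduplicated and skips users already pinned at the bottom element, which this sketch omits for brevity.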
+    if ((C1 && Src0AS == UninitializedAddressSpace) || +        (C0 && Src1AS == UninitializedAddressSpace)) +      return None; + +    if (C0 && isSafeToCastConstAddrSpace(C0, Src1AS)) +      NewAS = Src1AS; +    else if (C1 && isSafeToCastConstAddrSpace(C1, Src0AS)) +      NewAS = Src0AS; +    else +      NewAS = joinAddressSpaces(Src0AS, Src1AS); +  } else { +    for (Value *PtrOperand : getPointerOperands(V)) { +      auto I = InferredAddrSpace.find(PtrOperand); +      unsigned OperandAS = I != InferredAddrSpace.end() ? +        I->second : PtrOperand->getType()->getPointerAddressSpace(); + +      // join(flat, *) = flat. So we can break if NewAS is already flat. +      NewAS = joinAddressSpaces(NewAS, OperandAS); +      if (NewAS == FlatAddrSpace) +        break; +    } +  } + +  unsigned OldAS = InferredAddrSpace.lookup(&V); +  assert(OldAS != FlatAddrSpace); +  if (OldAS == NewAS) +    return None; +  return NewAS; +} + +/// \p returns true if \p U is the pointer operand of a memory instruction with +/// a single pointer operand that can have its address space changed by simply +/// mutating the use to a new value. +static bool isSimplePointerUseValidToReplace(Use &U) { +  User *Inst = U.getUser(); +  unsigned OpNo = U.getOperandNo(); + +  if (auto *LI = dyn_cast<LoadInst>(Inst)) +    return OpNo == LoadInst::getPointerOperandIndex() && !LI->isVolatile(); + +  if (auto *SI = dyn_cast<StoreInst>(Inst)) +    return OpNo == StoreInst::getPointerOperandIndex() && !SI->isVolatile(); + +  if (auto *RMW = dyn_cast<AtomicRMWInst>(Inst)) +    return OpNo == AtomicRMWInst::getPointerOperandIndex() && !RMW->isVolatile(); + +  if (auto *CmpX = dyn_cast<AtomicCmpXchgInst>(Inst)) { +    return OpNo == AtomicCmpXchgInst::getPointerOperandIndex() && +           !CmpX->isVolatile(); +  } + +  return false; +} + +/// Update memory intrinsic uses that require more complex processing than +/// simple memory instructions. Thse require re-mangling and may have multiple +/// pointer operands. +static bool handleMemIntrinsicPtrUse(MemIntrinsic *MI, Value *OldV, +                                     Value *NewV) { +  IRBuilder<> B(MI); +  MDNode *TBAA = MI->getMetadata(LLVMContext::MD_tbaa); +  MDNode *ScopeMD = MI->getMetadata(LLVMContext::MD_alias_scope); +  MDNode *NoAliasMD = MI->getMetadata(LLVMContext::MD_noalias); + +  if (auto *MSI = dyn_cast<MemSetInst>(MI)) { +    B.CreateMemSet(NewV, MSI->getValue(), +                   MSI->getLength(), MSI->getAlignment(), +                   false, // isVolatile +                   TBAA, ScopeMD, NoAliasMD); +  } else if (auto *MTI = dyn_cast<MemTransferInst>(MI)) { +    Value *Src = MTI->getRawSource(); +    Value *Dest = MTI->getRawDest(); + +    // Be careful in case this is a self-to-self copy. 
+    if (Src == OldV) +      Src = NewV; + +    if (Dest == OldV) +      Dest = NewV; + +    if (isa<MemCpyInst>(MTI)) { +      MDNode *TBAAStruct = MTI->getMetadata(LLVMContext::MD_tbaa_struct); +      B.CreateMemCpy(Dest, Src, MTI->getLength(), +                     MTI->getAlignment(), +                     false, // isVolatile +                     TBAA, TBAAStruct, ScopeMD, NoAliasMD); +    } else { +      assert(isa<MemMoveInst>(MTI)); +      B.CreateMemMove(Dest, Src, MTI->getLength(), +                      MTI->getAlignment(), +                      false, // isVolatile +                      TBAA, ScopeMD, NoAliasMD); +    } +  } else +    llvm_unreachable("unhandled MemIntrinsic"); + +  MI->eraseFromParent(); +  return true; +} + +// \p returns true if it is OK to change the address space of constant \p C with +// a ConstantExpr addrspacecast. +bool InferAddressSpaces::isSafeToCastConstAddrSpace(Constant *C, unsigned NewAS) const { +  assert(NewAS != UninitializedAddressSpace); + +  unsigned SrcAS = C->getType()->getPointerAddressSpace(); +  if (SrcAS == NewAS || isa<UndefValue>(C)) +    return true; + +  // Prevent illegal casts between different non-flat address spaces. +  if (SrcAS != FlatAddrSpace && NewAS != FlatAddrSpace) +    return false; + +  if (isa<ConstantPointerNull>(C)) +    return true; + +  if (auto *Op = dyn_cast<Operator>(C)) { +    // If we already have a constant addrspacecast, it should be safe to cast it +    // off. +    if (Op->getOpcode() == Instruction::AddrSpaceCast) +      return isSafeToCastConstAddrSpace(cast<Constant>(Op->getOperand(0)), NewAS); + +    if (Op->getOpcode() == Instruction::IntToPtr && +        Op->getType()->getPointerAddressSpace() == FlatAddrSpace) +      return true; +  } + +  return false; +} + +static Value::use_iterator skipToNextUser(Value::use_iterator I, +                                          Value::use_iterator End) { +  User *CurUser = I->getUser(); +  ++I; + +  while (I != End && I->getUser() == CurUser) +    ++I; + +  return I; +} + +bool InferAddressSpaces::rewriteWithNewAddressSpaces( +  const std::vector<Value *> &Postorder, +  const ValueToAddrSpaceMapTy &InferredAddrSpace, Function *F) const { +  // For each address expression to be modified, creates a clone of it with its +  // pointer operands converted to the new address space. Since the pointer +  // operands are converted, the clone is naturally in the new address space by +  // construction. +  ValueToValueMapTy ValueWithNewAddrSpace; +  SmallVector<const Use *, 32> UndefUsesToFix; +  for (Value* V : Postorder) { +    unsigned NewAddrSpace = InferredAddrSpace.lookup(V); +    if (V->getType()->getPointerAddressSpace() != NewAddrSpace) { +      ValueWithNewAddrSpace[V] = cloneValueWithNewAddressSpace( +        V, NewAddrSpace, ValueWithNewAddrSpace, &UndefUsesToFix); +    } +  } + +  if (ValueWithNewAddrSpace.empty()) +    return false; + +  // Fixes all the undef uses generated by cloneInstructionWithNewAddressSpace. +  for (const Use *UndefUse : UndefUsesToFix) { +    User *V = UndefUse->getUser(); +    User *NewV = cast<User>(ValueWithNewAddrSpace.lookup(V)); +    unsigned OperandNo = UndefUse->getOperandNo(); +    assert(isa<UndefValue>(NewV->getOperand(OperandNo))); +    NewV->setOperand(OperandNo, ValueWithNewAddrSpace.lookup(UndefUse->get())); +  } + +  // Replaces the uses of the old address expressions with the new ones. 
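skipToNextUser above exists because one user can occur several times in a row in a use list (for example both operands of a single icmp); the rewrite loop below therefore advances the iterator past the whole run for the current user before deciding how to handle it, which keeps the iterator valid even if handling the user mutates or erases it. The same pattern, reduced to a plain sequence of (user, operand) pairs with invented names:

```cpp
#include <cstdio>
#include <vector>

struct UseRec { int User; int OperandNo; }; // stand-in for a use entry

using UseIter = std::vector<UseRec>::const_iterator;

// Advance past every consecutive entry that belongs to the same user,
// mirroring skipToNextUser in the patch.
UseIter skipToNextUser(UseIter I, UseIter End) {
  int CurUser = I->User;
  ++I;
  while (I != End && I->User == CurUser)
    ++I;
  return I;
}

int main() {
  // User 7 shows up twice in a row (think: both sides of one compare).
  std::vector<UseRec> Uses = {{7, 0}, {7, 1}, {9, 0}};
  for (UseIter I = Uses.begin(), E = Uses.end(); I != E;) {
    const UseRec &U = *I;
    I = skipToNextUser(I, E); // step to the next *distinct* user first
    std::printf("handling user %d via operand %d\n", U.User, U.OperandNo);
  }
  // Prints one line for user 7 and one line for user 9.
}
```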
+  for (Value *V : Postorder) { +    Value *NewV = ValueWithNewAddrSpace.lookup(V); +    if (NewV == nullptr) +      continue; + +    DEBUG(dbgs() << "Replacing the uses of " << *V +                 << "\n  with\n  " << *NewV << '\n'); + +    Value::use_iterator I, E, Next; +    for (I = V->use_begin(), E = V->use_end(); I != E; ) { +      Use &U = *I; + +      // Some users may see the same pointer operand in multiple operands. Skip +      // to the next instruction. +      I = skipToNextUser(I, E); + +      if (isSimplePointerUseValidToReplace(U)) { +        // If V is used as the pointer operand of a compatible memory operation, +        // sets the pointer operand to NewV. This replacement does not change +        // the element type, so the resultant load/store is still valid. +        U.set(NewV); +        continue; +      } + +      User *CurUser = U.getUser(); +      // Handle more complex cases like intrinsic that need to be remangled. +      if (auto *MI = dyn_cast<MemIntrinsic>(CurUser)) { +        if (!MI->isVolatile() && handleMemIntrinsicPtrUse(MI, V, NewV)) +          continue; +      } + +      if (auto *II = dyn_cast<IntrinsicInst>(CurUser)) { +        if (rewriteIntrinsicOperands(II, V, NewV)) +          continue; +      } + +      if (isa<Instruction>(CurUser)) { +        if (ICmpInst *Cmp = dyn_cast<ICmpInst>(CurUser)) { +          // If we can infer that both pointers are in the same addrspace, +          // transform e.g. +          //   %cmp = icmp eq float* %p, %q +          // into +          //   %cmp = icmp eq float addrspace(3)* %new_p, %new_q + +          unsigned NewAS = NewV->getType()->getPointerAddressSpace(); +          int SrcIdx = U.getOperandNo(); +          int OtherIdx = (SrcIdx == 0) ? 1 : 0; +          Value *OtherSrc = Cmp->getOperand(OtherIdx); + +          if (Value *OtherNewV = ValueWithNewAddrSpace.lookup(OtherSrc)) { +            if (OtherNewV->getType()->getPointerAddressSpace() == NewAS) { +              Cmp->setOperand(OtherIdx, OtherNewV); +              Cmp->setOperand(SrcIdx, NewV); +              continue; +            } +          } + +          // Even if the type mismatches, we can cast the constant. +          if (auto *KOtherSrc = dyn_cast<Constant>(OtherSrc)) { +            if (isSafeToCastConstAddrSpace(KOtherSrc, NewAS)) { +              Cmp->setOperand(SrcIdx, NewV); +              Cmp->setOperand(OtherIdx, +                ConstantExpr::getAddrSpaceCast(KOtherSrc, NewV->getType())); +              continue; +            } +          } +        } + +        // Otherwise, replaces the use with flat(NewV). 
+        if (Instruction *I = dyn_cast<Instruction>(V)) { +          BasicBlock::iterator InsertPos = std::next(I->getIterator()); +          while (isa<PHINode>(InsertPos)) +            ++InsertPos; +          U.set(new AddrSpaceCastInst(NewV, V->getType(), "", &*InsertPos)); +        } else { +          U.set(ConstantExpr::getAddrSpaceCast(cast<Constant>(NewV), +                                               V->getType())); +        } +      } +    } + +    if (V->use_empty()) +      RecursivelyDeleteTriviallyDeadInstructions(V); +  } + +  return true; +} + +FunctionPass *llvm::createInferAddressSpacesPass() { +  return new InferAddressSpaces(); +} diff --git a/lib/Transforms/Scalar/JumpThreading.cpp b/lib/Transforms/Scalar/JumpThreading.cpp index 1870c3deb4f3..08eb95a1a3d3 100644 --- a/lib/Transforms/Scalar/JumpThreading.cpp +++ b/lib/Transforms/Scalar/JumpThreading.cpp @@ -17,6 +17,7 @@  #include "llvm/ADT/DenseSet.h"  #include "llvm/ADT/STLExtras.h"  #include "llvm/ADT/Statistic.h" +#include "llvm/Analysis/AliasAnalysis.h"  #include "llvm/Analysis/GlobalsModRef.h"  #include "llvm/Analysis/CFG.h"  #include "llvm/Analysis/BlockFrequencyInfoImpl.h" @@ -30,11 +31,13 @@  #include "llvm/IR/LLVMContext.h"  #include "llvm/IR/MDBuilder.h"  #include "llvm/IR/Metadata.h" +#include "llvm/IR/PatternMatch.h"  #include "llvm/Pass.h"  #include "llvm/Support/CommandLine.h"  #include "llvm/Support/Debug.h"  #include "llvm/Support/raw_ostream.h"  #include "llvm/Transforms/Utils/BasicBlockUtils.h" +#include "llvm/Transforms/Utils/Cloning.h"  #include "llvm/Transforms/Utils/Local.h"  #include "llvm/Transforms/Utils/SSAUpdater.h"  #include <algorithm> @@ -89,6 +92,7 @@ namespace {      bool runOnFunction(Function &F) override;      void getAnalysisUsage(AnalysisUsage &AU) const override { +      AU.addRequired<AAResultsWrapperPass>();        AU.addRequired<LazyValueInfoWrapperPass>();        AU.addPreserved<LazyValueInfoWrapperPass>();        AU.addPreserved<GlobalsAAWrapperPass>(); @@ -104,6 +108,7 @@ INITIALIZE_PASS_BEGIN(JumpThreading, "jump-threading",                  "Jump Threading", false, false)  INITIALIZE_PASS_DEPENDENCY(LazyValueInfoWrapperPass)  INITIALIZE_PASS_DEPENDENCY(TargetLibraryInfoWrapperPass) +INITIALIZE_PASS_DEPENDENCY(AAResultsWrapperPass)  INITIALIZE_PASS_END(JumpThreading, "jump-threading",                  "Jump Threading", false, false) @@ -121,6 +126,7 @@ bool JumpThreading::runOnFunction(Function &F) {      return false;    auto TLI = &getAnalysis<TargetLibraryInfoWrapperPass>().getTLI();    auto LVI = &getAnalysis<LazyValueInfoWrapperPass>().getLVI(); +  auto AA = &getAnalysis<AAResultsWrapperPass>().getAAResults();    std::unique_ptr<BlockFrequencyInfo> BFI;    std::unique_ptr<BranchProbabilityInfo> BPI;    bool HasProfileData = F.getEntryCount().hasValue(); @@ -129,7 +135,8 @@ bool JumpThreading::runOnFunction(Function &F) {      BPI.reset(new BranchProbabilityInfo(F, LI));      BFI.reset(new BlockFrequencyInfo(F, *BPI, LI));    } -  return Impl.runImpl(F, TLI, LVI, HasProfileData, std::move(BFI), + +  return Impl.runImpl(F, TLI, LVI, AA, HasProfileData, std::move(BFI),                        std::move(BPI));  } @@ -138,6 +145,8 @@ PreservedAnalyses JumpThreadingPass::run(Function &F,    auto &TLI = AM.getResult<TargetLibraryAnalysis>(F);    auto &LVI = AM.getResult<LazyValueAnalysis>(F); +  auto &AA = AM.getResult<AAManager>(F); +    std::unique_ptr<BlockFrequencyInfo> BFI;    std::unique_ptr<BranchProbabilityInfo> BPI;    bool HasProfileData = 
F.getEntryCount().hasValue(); @@ -146,12 +155,9 @@ PreservedAnalyses JumpThreadingPass::run(Function &F,      BPI.reset(new BranchProbabilityInfo(F, LI));      BFI.reset(new BlockFrequencyInfo(F, *BPI, LI));    } -  bool Changed = -      runImpl(F, &TLI, &LVI, HasProfileData, std::move(BFI), std::move(BPI)); -  // FIXME: We need to invalidate LVI to avoid PR28400. Is there a better -  // solution? -  AM.invalidate<LazyValueAnalysis>(F); +  bool Changed = runImpl(F, &TLI, &LVI, &AA, HasProfileData, std::move(BFI), +                         std::move(BPI));    if (!Changed)      return PreservedAnalyses::all(); @@ -161,18 +167,23 @@ PreservedAnalyses JumpThreadingPass::run(Function &F,  }  bool JumpThreadingPass::runImpl(Function &F, TargetLibraryInfo *TLI_, -                                LazyValueInfo *LVI_, bool HasProfileData_, +                                LazyValueInfo *LVI_, AliasAnalysis *AA_, +                                bool HasProfileData_,                                  std::unique_ptr<BlockFrequencyInfo> BFI_,                                  std::unique_ptr<BranchProbabilityInfo> BPI_) {    DEBUG(dbgs() << "Jump threading on function '" << F.getName() << "'\n");    TLI = TLI_;    LVI = LVI_; +  AA = AA_;    BFI.reset();    BPI.reset();    // When profile data is available, we need to update edge weights after    // successful jump threading, which requires both BPI and BFI being available.    HasProfileData = HasProfileData_; +  auto *GuardDecl = F.getParent()->getFunction( +      Intrinsic::getName(Intrinsic::experimental_guard)); +  HasGuards = GuardDecl && !GuardDecl->use_empty();    if (HasProfileData) {      BPI = std::move(BPI_);      BFI = std::move(BFI_); @@ -226,26 +237,13 @@ bool JumpThreadingPass::runImpl(Function &F, TargetLibraryInfo *TLI_,            BB != &BB->getParent()->getEntryBlock() &&            // If the terminator is the only non-phi instruction, try to nuke it.            BB->getFirstNonPHIOrDbg()->isTerminator() && !LoopHeaders.count(BB)) { -        // Since TryToSimplifyUncondBranchFromEmptyBlock may delete the -        // block, we have to make sure it isn't in the LoopHeaders set.  We -        // reinsert afterward if needed. -        bool ErasedFromLoopHeaders = LoopHeaders.erase(BB); -        BasicBlock *Succ = BI->getSuccessor(0); -          // FIXME: It is always conservatively correct to drop the info          // for a block even if it doesn't get erased.  This isn't totally          // awesome, but it allows us to use AssertingVH to prevent nasty          // dangling pointer issues within LazyValueInfo.          LVI->eraseBlock(BB); -        if (TryToSimplifyUncondBranchFromEmptyBlock(BB)) { +        if (TryToSimplifyUncondBranchFromEmptyBlock(BB))            Changed = true; -          // If we deleted BB and BB was the header of a loop, then the -          // successor is now the header of the loop. -          BB = Succ; -        } - -        if (ErasedFromLoopHeaders) -          LoopHeaders.insert(BB);        }      }      EverChanged |= Changed; @@ -255,10 +253,13 @@ bool JumpThreadingPass::runImpl(Function &F, TargetLibraryInfo *TLI_,    return EverChanged;  } -/// getJumpThreadDuplicationCost - Return the cost of duplicating this block to -/// thread across it. Stop scanning the block when passing the threshold. -static unsigned getJumpThreadDuplicationCost(const BasicBlock *BB, +/// Return the cost of duplicating a piece of this block from first non-phi +/// and before StopAt instruction to thread across it. 
Stop scanning the block +/// when exceeding the threshold. If duplication is impossible, returns ~0U. +static unsigned getJumpThreadDuplicationCost(BasicBlock *BB, +                                             Instruction *StopAt,                                               unsigned Threshold) { +  assert(StopAt->getParent() == BB && "Not an instruction from proper BB?");    /// Ignore PHI nodes, these will be flattened when duplication happens.    BasicBlock::const_iterator I(BB->getFirstNonPHI()); @@ -266,15 +267,17 @@ static unsigned getJumpThreadDuplicationCost(const BasicBlock *BB,    // branch, so they shouldn't count against the duplication cost.    unsigned Bonus = 0; -  const TerminatorInst *BBTerm = BB->getTerminator(); -  // Threading through a switch statement is particularly profitable.  If this -  // block ends in a switch, decrease its cost to make it more likely to happen. -  if (isa<SwitchInst>(BBTerm)) -    Bonus = 6; - -  // The same holds for indirect branches, but slightly more so. -  if (isa<IndirectBrInst>(BBTerm)) -    Bonus = 8; +  if (BB->getTerminator() == StopAt) { +    // Threading through a switch statement is particularly profitable.  If this +    // block ends in a switch, decrease its cost to make it more likely to +    // happen. +    if (isa<SwitchInst>(StopAt)) +      Bonus = 6; + +    // The same holds for indirect branches, but slightly more so. +    if (isa<IndirectBrInst>(StopAt)) +      Bonus = 8; +  }    // Bump the threshold up so the early exit from the loop doesn't skip the    // terminator-based Size adjustment at the end. @@ -283,7 +286,7 @@ static unsigned getJumpThreadDuplicationCost(const BasicBlock *BB,    // Sum up the cost of each instruction until we get to the terminator.  Don't    // include the terminator because the copy won't include it.    unsigned Size = 0; -  for (; !isa<TerminatorInst>(I); ++I) { +  for (; &*I != StopAt; ++I) {      // Stop scanning the block if we've reached the threshold.      if (Size > Threshold) @@ -729,6 +732,10 @@ bool JumpThreadingPass::ProcessBlock(BasicBlock *BB) {    if (TryToUnfoldSelectInCurrBB(BB))      return true; +  // Look if we can propagate guards to predecessors. +  if (HasGuards && ProcessGuards(BB)) +    return true; +    // What kind of constant we're looking for.    ConstantPreference Preference = WantInteger; @@ -804,7 +811,6 @@ bool JumpThreadingPass::ProcessBlock(BasicBlock *BB) {      return false;    } -    if (CmpInst *CondCmp = dyn_cast<CmpInst>(CondInst)) {      // If we're branching on a conditional, LVI might be able to determine      // it's value at the branch instruction.  We only handle comparisons @@ -812,7 +818,12 @@ bool JumpThreadingPass::ProcessBlock(BasicBlock *BB) {      // TODO: This should be extended to handle switches as well.      BranchInst *CondBr = dyn_cast<BranchInst>(BB->getTerminator());      Constant *CondConst = dyn_cast<Constant>(CondCmp->getOperand(1)); -    if (CondBr && CondConst && CondBr->isConditional()) { +    if (CondBr && CondConst) { +      // We should have returned as soon as we turn a conditional branch to +      // unconditional. Because its no longer interesting as far as jump +      // threading is concerned. 
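The reworked getJumpThreadDuplicationCost above now measures only the prefix of the block up to StopAt, and it grants the switch/indirect-branch bonus only when that prefix actually ends at the block terminator. A toy version over a list of opcode strings shows just that control flow; it is not LLVM's cost model, and duplicationCost and the unit costs are invented for the sketch.

```cpp
#include <cassert>
#include <cstddef>
#include <string>
#include <vector>

// Count toy "instructions" from the start of the block up to, but not
// including, StopAt; apply a bonus only if StopAt is the terminator.
unsigned duplicationCost(const std::vector<std::string> &Block,
                         std::size_t StopAt, unsigned Threshold) {
  assert(StopAt < Block.size() && "StopAt must be inside the block");
  unsigned Bonus = 0;
  if (StopAt + 1 == Block.size()) {   // the prefix reaches the terminator
    if (Block[StopAt] == "switch")
      Bonus = 6;                      // threading a switch is attractive
    else if (Block[StopAt] == "indirectbr")
      Bonus = 8;                      // even more so for indirectbr
  }
  Threshold += Bonus;                 // keep the early exit consistent

  unsigned Size = 0;
  for (std::size_t I = 0; I != StopAt; ++I) {
    if (Size > Threshold)
      return ~0u;                     // too expensive, give up early
    ++Size;                           // every toy instruction costs 1
  }
  return Size > Bonus ? Size - Bonus : 0;
}

int main() {
  std::vector<std::string> BB = {"load", "add", "cmp", "switch"};
  assert(duplicationCost(BB, 3, 10) == 0); // 3 insts minus bonus, clamped
  assert(duplicationCost(BB, 2, 10) == 2); // stop before "cmp": no bonus
}
```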
+      assert(CondBr->isConditional() && "Threading on unconditional terminator"); +        LazyValueInfo::Tristate Ret =          LVI->getPredicateAt(CondCmp->getPredicate(), CondCmp->getOperand(0),                              CondConst, CondBr); @@ -835,10 +846,12 @@ bool JumpThreadingPass::ProcessBlock(BasicBlock *BB) {          }          return true;        } -    } -    if (CondBr && CondConst && TryToUnfoldSelect(CondCmp, BB)) -      return true; +      // We did not manage to simplify this branch, try to see whether +      // CondCmp depends on a known phi-select pattern. +      if (TryToUnfoldSelect(CondCmp, BB)) +        return true; +    }    }    // Check for some cases that are worth simplifying.  Right now we want to look @@ -857,7 +870,6 @@ bool JumpThreadingPass::ProcessBlock(BasicBlock *BB) {      if (SimplifyPartiallyRedundantLoad(LI))        return true; -    // Handle a variety of cases where we are branching on something derived from    // a PHI node in the current block.  If we can prove that any predecessors    // compute a predictable value based on a PHI node, thread those predecessors. @@ -871,7 +883,6 @@ bool JumpThreadingPass::ProcessBlock(BasicBlock *BB) {      if (PN->getParent() == BB && isa<BranchInst>(BB->getTerminator()))        return ProcessBranchOnPHI(PN); -    // If this is an otherwise-unfoldable branch on a XOR, see if we can simplify.    if (CondInst->getOpcode() == Instruction::Xor &&        CondInst->getParent() == BB && isa<BranchInst>(BB->getTerminator())) @@ -920,6 +931,14 @@ bool JumpThreadingPass::ProcessImpliedCondition(BasicBlock *BB) {    return false;  } +/// Return true if Op is an instruction defined in the given block. +static bool isOpDefinedInBlock(Value *Op, BasicBlock *BB) { +  if (Instruction *OpInst = dyn_cast<Instruction>(Op)) +    if (OpInst->getParent() == BB) +      return true; +  return false; +} +  /// SimplifyPartiallyRedundantLoad - If LI is an obviously partially redundant  /// load instruction, eliminate it by replacing it with a PHI node.  This is an  /// important optimization that encourages jump threading, and needs to be run @@ -942,18 +961,17 @@ bool JumpThreadingPass::SimplifyPartiallyRedundantLoad(LoadInst *LI) {    Value *LoadedPtr = LI->getOperand(0); -  // If the loaded operand is defined in the LoadBB, it can't be available. -  // TODO: Could do simple PHI translation, that would be fun :) -  if (Instruction *PtrOp = dyn_cast<Instruction>(LoadedPtr)) -    if (PtrOp->getParent() == LoadBB) -      return false; +  // If the loaded operand is defined in the LoadBB and its not a phi, +  // it can't be available in predecessors. +  if (isOpDefinedInBlock(LoadedPtr, LoadBB) && !isa<PHINode>(LoadedPtr)) +    return false;    // Scan a few instructions up from the load, to see if it is obviously live at    // the entry to its block.    BasicBlock::iterator BBIt(LI);    bool IsLoadCSE; -  if (Value *AvailableVal = -        FindAvailableLoadedValue(LI, LoadBB, BBIt, DefMaxInstsToScan, nullptr, &IsLoadCSE)) { +  if (Value *AvailableVal = FindAvailableLoadedValue( +          LI, LoadBB, BBIt, DefMaxInstsToScan, AA, &IsLoadCSE)) {      // If the value of the load is locally available within the block, just use      // it.  This frequently occurs for reg2mem'd allocas. @@ -997,12 +1015,34 @@ bool JumpThreadingPass::SimplifyPartiallyRedundantLoad(LoadInst *LI) {      if (!PredsScanned.insert(PredBB).second)        continue; -    // Scan the predecessor to see if the value is available in the pred.      
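The loop below extends the availability search across a chain of single-predecessor blocks while charging every scanned instruction against one shared DefMaxInstsToScan budget, so the total amount of scanning stays bounded no matter how long the chain is. The walking-with-a-budget skeleton, detached from LLVM (Block, findInBlock, and the data are made up for the sketch):

```cpp
#include <cstdio>
#include <vector>

// One toy "block" is just a list of values; a negative entry means "not what
// we are looking for", a non-negative entry is an available value.
using Block = std::vector<int>;

// Scan one block backwards, decrementing the shared budget as we go.
int findInBlock(const Block &B, unsigned &Budget) {
  for (auto I = B.rbegin(), E = B.rend(); I != E && Budget != 0; ++I, --Budget)
    if (*I >= 0)
      return *I;
  return -1;
}

int main() {
  // A chain of single-predecessor blocks, nearest predecessor first.
  std::vector<Block> Chain = {{-1, -1}, {-1, -1, -1}, {-1, 42}};
  unsigned Budget = 6; // plays the role of DefMaxInstsToScan
  int Found = -1;
  for (const Block &B : Chain) {
    Found = findInBlock(B, Budget);
    if (Found >= 0 || Budget == 0)
      break; // either we found a value or we ran out of scan budget
  }
  std::printf("found=%d budget left=%u\n", Found, Budget); // found=42
  // With a budget of 6 the value in the third block is still reachable;
  // with a budget of 5 the search would stop after the second block instead.
}
```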
BBIt = PredBB->end(); -    Value *PredAvailable = FindAvailableLoadedValue(LI, PredBB, BBIt, -                                                    DefMaxInstsToScan, -                                                    nullptr, -                                                    &IsLoadCSE); +    unsigned NumScanedInst = 0; +    Value *PredAvailable = nullptr; +    // NOTE: We don't CSE load that is volatile or anything stronger than +    // unordered, that should have been checked when we entered the function. +    assert(LI->isUnordered() && "Attempting to CSE volatile or atomic loads"); +    // If this is a load on a phi pointer, phi-translate it and search +    // for available load/store to the pointer in predecessors. +    Value *Ptr = LoadedPtr->DoPHITranslation(LoadBB, PredBB); +    PredAvailable = FindAvailablePtrLoadStore( +        Ptr, LI->getType(), LI->isAtomic(), PredBB, BBIt, DefMaxInstsToScan, +        AA, &IsLoadCSE, &NumScanedInst); + +    // If PredBB has a single predecessor, continue scanning through the +    // single precessor. +    BasicBlock *SinglePredBB = PredBB; +    while (!PredAvailable && SinglePredBB && BBIt == SinglePredBB->begin() && +           NumScanedInst < DefMaxInstsToScan) { +      SinglePredBB = SinglePredBB->getSinglePredecessor(); +      if (SinglePredBB) { +        BBIt = SinglePredBB->end(); +        PredAvailable = FindAvailablePtrLoadStore( +            Ptr, LI->getType(), LI->isAtomic(), SinglePredBB, BBIt, +            (DefMaxInstsToScan - NumScanedInst), AA, &IsLoadCSE, +            &NumScanedInst); +      } +    } +      if (!PredAvailable) {        OneUnavailablePred = PredBB;        continue; @@ -1062,10 +1102,10 @@ bool JumpThreadingPass::SimplifyPartiallyRedundantLoad(LoadInst *LI) {    if (UnavailablePred) {      assert(UnavailablePred->getTerminator()->getNumSuccessors() == 1 &&             "Can't handle critical edge here!"); -    LoadInst *NewVal = -        new LoadInst(LoadedPtr, LI->getName() + ".pr", false, -                     LI->getAlignment(), LI->getOrdering(), LI->getSynchScope(), -                     UnavailablePred->getTerminator()); +    LoadInst *NewVal = new LoadInst( +        LoadedPtr->DoPHITranslation(LoadBB, UnavailablePred), +        LI->getName() + ".pr", false, LI->getAlignment(), LI->getOrdering(), +        LI->getSynchScope(), UnavailablePred->getTerminator());      NewVal->setDebugLoc(LI->getDebugLoc());      if (AATags)        NewVal->setAAMetadata(AATags); @@ -1229,7 +1269,7 @@ bool JumpThreadingPass::ProcessThreadableEdges(Value *Cond, BasicBlock *BB,      else if (BranchInst *BI = dyn_cast<BranchInst>(BB->getTerminator()))        DestBB = BI->getSuccessor(cast<ConstantInt>(Val)->isZero());      else if (SwitchInst *SI = dyn_cast<SwitchInst>(BB->getTerminator())) { -      DestBB = SI->findCaseValue(cast<ConstantInt>(Val)).getCaseSuccessor(); +      DestBB = SI->findCaseValue(cast<ConstantInt>(Val))->getCaseSuccessor();      } else {        assert(isa<IndirectBrInst>(BB->getTerminator())                && "Unexpected terminator"); @@ -1468,7 +1508,8 @@ bool JumpThreadingPass::ThreadEdge(BasicBlock *BB,      return false;    } -  unsigned JumpThreadCost = getJumpThreadDuplicationCost(BB, BBDupThreshold); +  unsigned JumpThreadCost = +      getJumpThreadDuplicationCost(BB, BB->getTerminator(), BBDupThreshold);    if (JumpThreadCost > BBDupThreshold) {      DEBUG(dbgs() << "  Not threading BB '" << BB->getName()            << "' - Cost is too high: " << JumpThreadCost << "\n"); @@ -1756,7 +1797,8 @@ bool 
JumpThreadingPass::DuplicateCondBranchOnPHIIntoPred(      return false;    } -  unsigned DuplicationCost = getJumpThreadDuplicationCost(BB, BBDupThreshold); +  unsigned DuplicationCost = +      getJumpThreadDuplicationCost(BB, BB->getTerminator(), BBDupThreshold);    if (DuplicationCost > BBDupThreshold) {      DEBUG(dbgs() << "  Not duplicating BB '" << BB->getName()            << "' - Cost is too high: " << DuplicationCost << "\n"); @@ -1888,10 +1930,10 @@ bool JumpThreadingPass::DuplicateCondBranchOnPHIIntoPred(  /// TryToUnfoldSelect - Look for blocks of the form  /// bb1:  ///   %a = select -///   br bb +///   br bb2  ///  /// bb2: -///   %p = phi [%a, %bb] ... +///   %p = phi [%a, %bb1] ...  ///   %c = icmp %p  ///   br i1 %c  /// @@ -2021,3 +2063,130 @@ bool JumpThreadingPass::TryToUnfoldSelectInCurrBB(BasicBlock *BB) {    return false;  } + +/// Try to propagate a guard from the current BB into one of its predecessors +/// in case if another branch of execution implies that the condition of this +/// guard is always true. Currently we only process the simplest case that +/// looks like: +/// +/// Start: +///   %cond = ... +///   br i1 %cond, label %T1, label %F1 +/// T1: +///   br label %Merge +/// F1: +///   br label %Merge +/// Merge: +///   %condGuard = ... +///   call void(i1, ...) @llvm.experimental.guard( i1 %condGuard )[ "deopt"() ] +/// +/// And cond either implies condGuard or !condGuard. In this case all the +/// instructions before the guard can be duplicated in both branches, and the +/// guard is then threaded to one of them. +bool JumpThreadingPass::ProcessGuards(BasicBlock *BB) { +  using namespace PatternMatch; +  // We only want to deal with two predecessors. +  BasicBlock *Pred1, *Pred2; +  auto PI = pred_begin(BB), PE = pred_end(BB); +  if (PI == PE) +    return false; +  Pred1 = *PI++; +  if (PI == PE) +    return false; +  Pred2 = *PI++; +  if (PI != PE) +    return false; +  if (Pred1 == Pred2) +    return false; + +  // Try to thread one of the guards of the block. +  // TODO: Look up deeper than to immediate predecessor? +  auto *Parent = Pred1->getSinglePredecessor(); +  if (!Parent || Parent != Pred2->getSinglePredecessor()) +    return false; + +  if (auto *BI = dyn_cast<BranchInst>(Parent->getTerminator())) +    for (auto &I : *BB) +      if (match(&I, m_Intrinsic<Intrinsic::experimental_guard>())) +        if (ThreadGuard(BB, cast<IntrinsicInst>(&I), BI)) +          return true; + +  return false; +} + +/// Try to propagate the guard from BB which is the lower block of a diamond +/// to one of its branches, in case if diamond's condition implies guard's +/// condition. +bool JumpThreadingPass::ThreadGuard(BasicBlock *BB, IntrinsicInst *Guard, +                                    BranchInst *BI) { +  assert(BI->getNumSuccessors() == 2 && "Wrong number of successors?"); +  assert(BI->isConditional() && "Unconditional branch has 2 successors?"); +  Value *GuardCond = Guard->getArgOperand(0); +  Value *BranchCond = BI->getCondition(); +  BasicBlock *TrueDest = BI->getSuccessor(0); +  BasicBlock *FalseDest = BI->getSuccessor(1); + +  auto &DL = BB->getModule()->getDataLayout(); +  bool TrueDestIsSafe = false; +  bool FalseDestIsSafe = false; + +  // True dest is safe if BranchCond => GuardCond. +  auto Impl = isImpliedCondition(BranchCond, GuardCond, DL); +  if (Impl && *Impl) +    TrueDestIsSafe = true; +  else { +    // False dest is safe if !BranchCond => GuardCond. 
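ProcessGuards above only fires on the narrow diamond shape its comment documents: the merge block must have exactly two distinct predecessors, and both of them must share the same single predecessor, which holds the conditional branch. That shape test is easy to isolate; here is a sketch over a plain predecessor map, where the node names and the Preds table are made up for the example.

```cpp
#include <cassert>
#include <map>
#include <string>
#include <vector>

using Node = std::string;
using PredMap = std::map<Node, std::vector<Node>>;

// Returns true when BB closes a simple diamond:
//   Parent -> {Pred1, Pred2} -> BB, with nothing else feeding in.
bool isSimpleDiamondMerge(const PredMap &Preds, const Node &BB) {
  auto It = Preds.find(BB);
  if (It == Preds.end() || It->second.size() != 2)
    return false;                        // need exactly two predecessors
  const Node &Pred1 = It->second[0], &Pred2 = It->second[1];
  if (Pred1 == Pred2)
    return false;                        // both edges come from one block
  auto P1 = Preds.find(Pred1), P2 = Preds.find(Pred2);
  if (P1 == Preds.end() || P2 == Preds.end() ||
      P1->second.size() != 1 || P2->second.size() != 1)
    return false;                        // each side needs a single pred
  return P1->second[0] == P2->second[0]; // and it must be the same block
}

int main() {
  PredMap Preds = {{"T1", {"Start"}},
                   {"F1", {"Start"}},
                   {"Merge", {"T1", "F1"}}};
  assert(isSimpleDiamondMerge(Preds, "Merge"));
  assert(!isSimpleDiamondMerge(Preds, "T1"));
}
```

In the pass itself the additional requirement is that Parent's terminator is a conditional branch, since that branch condition is what gets tested for implication against the guard condition.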
+    Impl = +        isImpliedCondition(BranchCond, GuardCond, DL, /* InvertAPred */ true); +    if (Impl && *Impl) +      FalseDestIsSafe = true; +  } + +  if (!TrueDestIsSafe && !FalseDestIsSafe) +    return false; + +  BasicBlock *UnguardedBlock = TrueDestIsSafe ? TrueDest : FalseDest; +  BasicBlock *GuardedBlock = FalseDestIsSafe ? TrueDest : FalseDest; + +  ValueToValueMapTy UnguardedMapping, GuardedMapping; +  Instruction *AfterGuard = Guard->getNextNode(); +  unsigned Cost = getJumpThreadDuplicationCost(BB, AfterGuard, BBDupThreshold); +  if (Cost > BBDupThreshold) +    return false; +  // Duplicate all instructions before the guard and the guard itself to the +  // branch where implication is not proved. +  GuardedBlock = DuplicateInstructionsInSplitBetween( +      BB, GuardedBlock, AfterGuard, GuardedMapping); +  assert(GuardedBlock && "Could not create the guarded block?"); +  // Duplicate all instructions before the guard in the unguarded branch. +  // Since we have successfully duplicated the guarded block and this block +  // has fewer instructions, we expect it to succeed. +  UnguardedBlock = DuplicateInstructionsInSplitBetween(BB, UnguardedBlock, +                                                       Guard, UnguardedMapping); +  assert(UnguardedBlock && "Could not create the unguarded block?"); +  DEBUG(dbgs() << "Moved guard " << *Guard << " to block " +               << GuardedBlock->getName() << "\n"); + +  // Some instructions before the guard may still have uses. For them, we need +  // to create Phi nodes merging their copies in both guarded and unguarded +  // branches. Those instructions that have no uses can be just removed. +  SmallVector<Instruction *, 4> ToRemove; +  for (auto BI = BB->begin(); &*BI != AfterGuard; ++BI) +    if (!isa<PHINode>(&*BI)) +      ToRemove.push_back(&*BI); + +  Instruction *InsertionPoint = &*BB->getFirstInsertionPt(); +  assert(InsertionPoint && "Empty block?"); +  // Substitute with Phis & remove. +  for (auto *Inst : reverse(ToRemove)) { +    if (!Inst->use_empty()) { +      PHINode *NewPN = PHINode::Create(Inst->getType(), 2); +      NewPN->addIncoming(UnguardedMapping[Inst], UnguardedBlock); +      NewPN->addIncoming(GuardedMapping[Inst], GuardedBlock); +      NewPN->insertBefore(InsertionPoint); +      Inst->replaceAllUsesWith(NewPN); +    } +    Inst->eraseFromParent(); +  } +  return true; +} diff --git a/lib/Transforms/Scalar/LICM.cpp b/lib/Transforms/Scalar/LICM.cpp index f51d11c04cb2..340c81fed0fd 100644 --- a/lib/Transforms/Scalar/LICM.cpp +++ b/lib/Transforms/Scalar/LICM.cpp @@ -77,10 +77,16 @@ STATISTIC(NumMovedLoads, "Number of load insts hoisted or sunk");  STATISTIC(NumMovedCalls, "Number of call insts hoisted or sunk");  STATISTIC(NumPromoted, "Number of memory locations promoted to registers"); +/// Memory promotion is enabled by default.  
static cl::opt<bool> -    DisablePromotion("disable-licm-promotion", cl::Hidden, +    DisablePromotion("disable-licm-promotion", cl::Hidden, cl::init(false),                       cl::desc("Disable memory promotion in LICM pass")); +static cl::opt<uint32_t> MaxNumUsesTraversed( +    "licm-max-num-uses-traversed", cl::Hidden, cl::init(8), +    cl::desc("Max num uses visited for identifying load " +             "invariance in loop using invariant start (default = 8)")); +  static bool inSubLoop(BasicBlock *BB, Loop *CurLoop, LoopInfo *LI);  static bool isNotUsedInLoop(const Instruction &I, const Loop *CurLoop,                              const LoopSafetyInfo *SafetyInfo); @@ -201,9 +207,9 @@ PreservedAnalyses LICMPass::run(Loop &L, LoopAnalysisManager &AM,    if (!LICM.runOnLoop(&L, &AR.AA, &AR.LI, &AR.DT, &AR.TLI, &AR.SE, ORE, true))      return PreservedAnalyses::all(); -  // FIXME: There is no setPreservesCFG in the new PM. When that becomes -  // available, it should be used here. -  return getLoopPassPreservedAnalyses(); +  auto PA = getLoopPassPreservedAnalyses(); +  PA.preserveSet<CFGAnalyses>(); +  return PA;  }  char LegacyLICMPass::ID = 0; @@ -425,6 +431,29 @@ bool llvm::hoistRegion(DomTreeNode *N, AliasAnalysis *AA, LoopInfo *LI,          continue;        } +      // Attempt to remove floating point division out of the loop by converting +      // it to a reciprocal multiplication. +      if (I.getOpcode() == Instruction::FDiv && +          CurLoop->isLoopInvariant(I.getOperand(1)) && +          I.hasAllowReciprocal()) { +        auto Divisor = I.getOperand(1); +        auto One = llvm::ConstantFP::get(Divisor->getType(), 1.0); +        auto ReciprocalDivisor = BinaryOperator::CreateFDiv(One, Divisor); +        ReciprocalDivisor->setFastMathFlags(I.getFastMathFlags()); +        ReciprocalDivisor->insertBefore(&I); + +        auto Product = BinaryOperator::CreateFMul(I.getOperand(0), +                                                  ReciprocalDivisor); +        Product->setFastMathFlags(I.getFastMathFlags()); +        Product->insertAfter(&I); +        I.replaceAllUsesWith(Product); +        I.eraseFromParent(); + +        hoist(*ReciprocalDivisor, DT, CurLoop, SafetyInfo, ORE); +        Changed = true; +        continue; +      } +        // Try hoisting the instruction out to the preheader.  We can only do this        // if all of the operands of the instruction are loop invariant and if it        // is safe to hoist the instruction. @@ -461,7 +490,10 @@ void llvm::computeLoopSafetyInfo(LoopSafetyInfo *SafetyInfo, Loop *CurLoop) {    SafetyInfo->MayThrow = SafetyInfo->HeaderMayThrow;    // Iterate over loop instructions and compute safety info. -  for (Loop::block_iterator BB = CurLoop->block_begin(), +  // Skip header as it has been computed and stored in HeaderMayThrow. +  // The first block in loopinfo.Blocks is guaranteed to be the header. +  assert(Header == *CurLoop->getBlocks().begin() && "First block must be header"); +  for (Loop::block_iterator BB = std::next(CurLoop->block_begin()),                              BBE = CurLoop->block_end();         (BB != BBE) && !SafetyInfo->MayThrow; ++BB)      for (BasicBlock::iterator I = (*BB)->begin(), E = (*BB)->end(); @@ -477,6 +509,59 @@ void llvm::computeLoopSafetyInfo(LoopSafetyInfo *SafetyInfo, Loop *CurLoop) {          SafetyInfo->BlockColors = colorEHFunclets(*Fn);  } +// Return true if LI is invariant within scope of the loop. 
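The new LICM transformation above turns a division by a loop-invariant divisor into a multiplication by a reciprocal that is computed once, and it is only allowed because the instruction carries the allow-reciprocal fast-math flag: the two forms can round differently. At the source level the same rewrite looks like the pair of functions below; scaleBy and scaleByReciprocal are invented names for the illustration.

```cpp
#include <cstdio>
#include <vector>

// Before: one division per element, even though the divisor never changes.
void scaleBy(std::vector<float> &V, float Divisor) {
  for (float &X : V)
    X = X / Divisor;              // fdiv inside the loop
}

// After: the reciprocal is hoisted out of the loop and reused, which is what
// LICM now does at the IR level when the fdiv permits reciprocal math.
// The results may differ in the last bit, hence the fast-math requirement.
void scaleByReciprocal(std::vector<float> &V, float Divisor) {
  const float Recip = 1.0f / Divisor; // computed once, in the "preheader"
  for (float &X : V)
    X = X * Recip;                // fmul inside the loop
}

int main() {
  std::vector<float> A = {2.0f, 4.0f, 8.0f}, B = A;
  scaleBy(A, 2.0f);
  scaleByReciprocal(B, 2.0f);
  std::printf("%g %g\n", A[0], B[0]); // both 1 here; not guaranteed in general
}
```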
LI is invariant if +// CurLoop is dominated by an invariant.start representing the same memory location +// and size as the memory location LI loads from, and also the invariant.start +// has no uses. +static bool isLoadInvariantInLoop(LoadInst *LI, DominatorTree *DT, +                                  Loop *CurLoop) { +  Value *Addr = LI->getOperand(0); +  const DataLayout &DL = LI->getModule()->getDataLayout(); +  const uint32_t LocSizeInBits = DL.getTypeSizeInBits( +      cast<PointerType>(Addr->getType())->getElementType()); + +  // if the type is i8 addrspace(x)*, we know this is the type of +  // llvm.invariant.start operand +  auto *PtrInt8Ty = PointerType::get(Type::getInt8Ty(LI->getContext()), +                                     LI->getPointerAddressSpace()); +  unsigned BitcastsVisited = 0; +  // Look through bitcasts until we reach the i8* type (this is invariant.start +  // operand type). +  while (Addr->getType() != PtrInt8Ty) { +    auto *BC = dyn_cast<BitCastInst>(Addr); +    // Avoid traversing high number of bitcast uses. +    if (++BitcastsVisited > MaxNumUsesTraversed || !BC) +      return false; +    Addr = BC->getOperand(0); +  } + +  unsigned UsesVisited = 0; +  // Traverse all uses of the load operand value, to see if invariant.start is +  // one of the uses, and whether it dominates the load instruction. +  for (auto *U : Addr->users()) { +    // Avoid traversing for Load operand with high number of users. +    if (++UsesVisited > MaxNumUsesTraversed) +      return false; +    IntrinsicInst *II = dyn_cast<IntrinsicInst>(U); +    // If there are escaping uses of invariant.start instruction, the load maybe +    // non-invariant. +    if (!II || II->getIntrinsicID() != Intrinsic::invariant_start || +        II->hasNUsesOrMore(1)) +      continue; +    unsigned InvariantSizeInBits = +        cast<ConstantInt>(II->getArgOperand(0))->getSExtValue() * 8; +    // Confirm the invariant.start location size contains the load operand size +    // in bits. Also, the invariant.start should dominate the load, and we +    // should not hoist the load out of a loop that contains this dominating +    // invariant.start. +    if (LocSizeInBits <= InvariantSizeInBits && +        DT->properlyDominates(II->getParent(), CurLoop->getHeader())) +      return true; +  } + +  return false; +} +  bool llvm::canSinkOrHoistInst(Instruction &I, AAResults *AA, DominatorTree *DT,                                Loop *CurLoop, AliasSetTracker *CurAST,                                LoopSafetyInfo *SafetyInfo, @@ -493,6 +578,10 @@ bool llvm::canSinkOrHoistInst(Instruction &I, AAResults *AA, DominatorTree *DT,      if (LI->getMetadata(LLVMContext::MD_invariant_load))        return true; +    // This checks for an invariant.start dominating the load. +    if (isLoadInvariantInLoop(LI, DT, CurLoop)) +      return true; +      // Don't hoist loads which have may-aliased stores in loop.      uint64_t Size = 0;      if (LI->getType()->isSized()) @@ -782,7 +871,7 @@ static bool hoist(Instruction &I, const DominatorTree *DT, const Loop *CurLoop,    DEBUG(dbgs() << "LICM hoisting to " << Preheader->getName() << ": " << I                 << "\n");    ORE->emit(OptimizationRemark(DEBUG_TYPE, "Hoisted", &I) -            << "hosting " << ore::NV("Inst", &I)); +            << "hoisting " << ore::NV("Inst", &I));    // Metadata can be dependent on conditions we are hoisting above.    
// Conservatively strip all metadata on the instruction unless we were @@ -852,6 +941,7 @@ class LoopPromoter : public LoadAndStorePromoter {    LoopInfo &LI;    DebugLoc DL;    int Alignment; +  bool UnorderedAtomic;    AAMDNodes AATags;    Value *maybeInsertLCSSAPHI(Value *V, BasicBlock *BB) const { @@ -875,10 +965,11 @@ public:                 SmallVectorImpl<BasicBlock *> &LEB,                 SmallVectorImpl<Instruction *> &LIP, PredIteratorCache &PIC,                 AliasSetTracker &ast, LoopInfo &li, DebugLoc dl, int alignment, -               const AAMDNodes &AATags) +               bool UnorderedAtomic, const AAMDNodes &AATags)        : LoadAndStorePromoter(Insts, S), SomePtr(SP), PointerMustAliases(PMA),          LoopExitBlocks(LEB), LoopInsertPts(LIP), PredCache(PIC), AST(ast), -        LI(li), DL(std::move(dl)), Alignment(alignment), AATags(AATags) {} +        LI(li), DL(std::move(dl)), Alignment(alignment), +        UnorderedAtomic(UnorderedAtomic),AATags(AATags) {}    bool isInstInList(Instruction *I,                      const SmallVectorImpl<Instruction *> &) const override { @@ -902,6 +993,8 @@ public:        Value *Ptr = maybeInsertLCSSAPHI(SomePtr, ExitBlock);        Instruction *InsertPos = LoopInsertPts[i];        StoreInst *NewSI = new StoreInst(LiveInValue, Ptr, InsertPos); +      if (UnorderedAtomic) +        NewSI->setOrdering(AtomicOrdering::Unordered);        NewSI->setAlignment(Alignment);        NewSI->setDebugLoc(DL);        if (AATags) @@ -992,18 +1085,41 @@ bool llvm::promoteLoopAccessesToScalars(    // We start with an alignment of one and try to find instructions that allow    // us to prove better alignment.    unsigned Alignment = 1; +  // Keep track of which types of access we see +  bool SawUnorderedAtomic = false;  +  bool SawNotAtomic = false;    AAMDNodes AATags;    const DataLayout &MDL = Preheader->getModule()->getDataLayout(); +  // Do we know this object does not escape ? +  bool IsKnownNonEscapingObject = false;    if (SafetyInfo->MayThrow) {      // If a loop can throw, we have to insert a store along each unwind edge.      // That said, we can't actually make the unwind edge explicit. Therefore,      // we have to prove that the store is dead along the unwind edge.      // -    // Currently, this code just special-cases alloca instructions. -    if (!isa<AllocaInst>(GetUnderlyingObject(SomePtr, MDL))) -      return false; +    // If the underlying object is not an alloca, nor a pointer that does not +    // escape, then we can not effectively prove that the store is dead along +    // the unwind edge. i.e. the caller of this function could have ways to +    // access the pointed object. +    Value *Object = GetUnderlyingObject(SomePtr, MDL); +    // If this is a base pointer we do not understand, simply bail. +    // We only handle alloca and return value from alloc-like fn right now. +    if (!isa<AllocaInst>(Object)) { +        if (!isAllocLikeFn(Object, TLI)) +          return false; +      // If this is an alloc like fn. There are more constraints we need to verify. +      // More specifically, we must make sure that the pointer can not escape. +      // +      // NOTE: PointerMayBeCaptured is not enough as the pointer may have escaped +      // even though its not captured by the enclosing function. Standard allocation +      // functions like malloc, calloc, and operator new return values which can +      // be assumed not to have previously escaped. 
+      if (PointerMayBeCaptured(Object, true, true)) +        return false; +      IsKnownNonEscapingObject = true; +    }    }    // Check that all of the pointers in the alias set have the same type.  We @@ -1029,8 +1145,11 @@ bool llvm::promoteLoopAccessesToScalars(        // it.        if (LoadInst *Load = dyn_cast<LoadInst>(UI)) {          assert(!Load->isVolatile() && "AST broken"); -        if (!Load->isSimple()) +        if (!Load->isUnordered())            return false; +         +        SawUnorderedAtomic |= Load->isAtomic(); +        SawNotAtomic |= !Load->isAtomic();          if (!DereferenceableInPH)            DereferenceableInPH = isSafeToExecuteUnconditionally( @@ -1041,9 +1160,12 @@ bool llvm::promoteLoopAccessesToScalars(          if (UI->getOperand(1) != ASIV)            continue;          assert(!Store->isVolatile() && "AST broken"); -        if (!Store->isSimple()) +        if (!Store->isUnordered())            return false; +        SawUnorderedAtomic |= Store->isAtomic(); +        SawNotAtomic |= !Store->isAtomic(); +          // If the store is guaranteed to execute, both properties are satisfied.          // We may want to check if a store is guaranteed to execute even if we          // already know that promotion is safe, since it may have higher @@ -1096,6 +1218,12 @@ bool llvm::promoteLoopAccessesToScalars(      }    } +  // If we found both an unordered atomic instruction and a non-atomic memory +  // access, bail.  We can't blindly promote non-atomic to atomic since we +  // might not be able to lower the result.  We can't downgrade since that +  // would violate memory model.  Also, align 0 is an error for atomics. +  if (SawUnorderedAtomic && SawNotAtomic) +    return false;    // If we couldn't prove we can hoist the load, bail.    if (!DereferenceableInPH) @@ -1106,10 +1234,15 @@ bool llvm::promoteLoopAccessesToScalars(    // stores along paths which originally didn't have them without violating the    // memory model.    if (!SafeToInsertStore) { -    Value *Object = GetUnderlyingObject(SomePtr, MDL); -    SafeToInsertStore = -        (isAllocLikeFn(Object, TLI) || isa<AllocaInst>(Object)) && +    // If this is a known non-escaping object, it is safe to insert the stores. +    if (IsKnownNonEscapingObject) +      SafeToInsertStore = true; +    else { +      Value *Object = GetUnderlyingObject(SomePtr, MDL); +      SafeToInsertStore = +        (isAllocLikeFn(Object, TLI) || isa<AllocaInst>(Object)) &&           !PointerMayBeCaptured(Object, true, true); +    }    }    // If we've still failed to prove we can sink the store, give up. @@ -1134,12 +1267,15 @@ bool llvm::promoteLoopAccessesToScalars(    SmallVector<PHINode *, 16> NewPHIs;    SSAUpdater SSA(&NewPHIs);    LoopPromoter Promoter(SomePtr, LoopUses, SSA, PointerMustAliases, ExitBlocks, -                        InsertPts, PIC, *CurAST, *LI, DL, Alignment, AATags); +                        InsertPts, PIC, *CurAST, *LI, DL, Alignment, +                        SawUnorderedAtomic, AATags);    // Set up the preheader to have a definition of the value.  It is the live-out    // value from the preheader that uses in the loop will use.    
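At the source level, the promotion performed here amounts to keeping the memory location in a local for the duration of the loop and writing it back once on exit; the preheader load created just below supplies the value the loop body will use. A hand-written illustration of that before/after shape, with a made-up countMatches example that deliberately ignores the aliasing, escape, and atomicity checks the pass has to prove first:

```cpp
#include <cstdio>
#include <vector>

// Before promotion: *Counter is loaded and stored on every iteration.
void countMatchesMem(const std::vector<int> &V, int Key, int *Counter) {
  for (int X : V)
    if (X == Key)
      *Counter = *Counter + 1;   // load + store inside the loop
}

// After promotion: one load in the "preheader", register arithmetic in the
// loop, and a single store on the exit edge -- the shape LICM produces with
// SSAUpdater once promotion has been proven safe.
void countMatchesPromoted(const std::vector<int> &V, int Key, int *Counter) {
  int Tmp = *Counter;            // preheader load
  for (int X : V)
    if (X == Key)
      Tmp = Tmp + 1;             // no memory traffic in the loop
  *Counter = Tmp;                // store at the loop exit
}

int main() {
  std::vector<int> V = {1, 2, 1, 3};
  int A = 0, B = 0;
  countMatchesMem(V, 1, &A);
  countMatchesPromoted(V, 1, &B);
  std::printf("%d %d\n", A, B);  // prints "2 2"
}
```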
LoadInst *PreheaderLoad = new LoadInst(        SomePtr, SomePtr->getName() + ".promoted", Preheader->getTerminator()); +  if (SawUnorderedAtomic) +    PreheaderLoad->setOrdering(AtomicOrdering::Unordered);    PreheaderLoad->setAlignment(Alignment);    PreheaderLoad->setDebugLoc(DL);    if (AATags) diff --git a/lib/Transforms/Scalar/LoadCombine.cpp b/lib/Transforms/Scalar/LoadCombine.cpp index 389f1c595aa4..02215d3450c2 100644 --- a/lib/Transforms/Scalar/LoadCombine.cpp +++ b/lib/Transforms/Scalar/LoadCombine.cpp @@ -19,6 +19,7 @@  #include "llvm/Analysis/GlobalsModRef.h"  #include "llvm/Analysis/TargetFolder.h"  #include "llvm/IR/DataLayout.h" +#include "llvm/IR/Dominators.h"  #include "llvm/IR/Function.h"  #include "llvm/IR/IRBuilder.h"  #include "llvm/IR/Instructions.h" @@ -53,18 +54,20 @@ struct LoadPOPPair {  class LoadCombine : public BasicBlockPass {    LLVMContext *C;    AliasAnalysis *AA; +  DominatorTree *DT;  public:    LoadCombine() : BasicBlockPass(ID), C(nullptr), AA(nullptr) {      initializeLoadCombinePass(*PassRegistry::getPassRegistry());    } -   +    using llvm::Pass::doInitialization;    bool doInitialization(Function &) override;    bool runOnBasicBlock(BasicBlock &BB) override;    void getAnalysisUsage(AnalysisUsage &AU) const override {      AU.setPreservesCFG();      AU.addRequired<AAResultsWrapperPass>(); +    AU.addRequired<DominatorTreeWrapperPass>();      AU.addPreserved<GlobalsAAWrapperPass>();    } @@ -234,6 +237,14 @@ bool LoadCombine::runOnBasicBlock(BasicBlock &BB) {      return false;    AA = &getAnalysis<AAResultsWrapperPass>().getAAResults(); +  DT = &getAnalysis<DominatorTreeWrapperPass>().getDomTree(); + +  // Skip analysing dead blocks (not forward reachable from function entry). +  if (!DT->isReachableFromEntry(&BB)) { +    DEBUG(dbgs() << "LC: skipping unreachable " << BB.getName() << +          " in " << BB.getParent()->getName() << "\n"); +    return false; +  }    IRBuilder<TargetFolder> TheBuilder(        BB.getContext(), TargetFolder(BB.getModule()->getDataLayout())); @@ -245,13 +256,17 @@ bool LoadCombine::runOnBasicBlock(BasicBlock &BB) {    bool Combined = false;    unsigned Index = 0;    for (auto &I : BB) { -    if (I.mayThrow() || (I.mayWriteToMemory() && AST.containsUnknown(&I))) { +    if (I.mayThrow() || AST.containsUnknown(&I)) {        if (combineLoads(LoadMap))          Combined = true;        LoadMap.clear();        AST.clear();        continue;      } +    if (I.mayWriteToMemory()) { +      AST.add(&I); +      continue; +    }      LoadInst *LI = dyn_cast<LoadInst>(&I);      if (!LI)        continue; diff --git a/lib/Transforms/Scalar/LoopDeletion.cpp b/lib/Transforms/Scalar/LoopDeletion.cpp index cca75a365024..73e8ce0e1d93 100644 --- a/lib/Transforms/Scalar/LoopDeletion.cpp +++ b/lib/Transforms/Scalar/LoopDeletion.cpp @@ -29,32 +29,31 @@ using namespace llvm;  STATISTIC(NumDeleted, "Number of loops deleted"); -/// isLoopDead - Determined if a loop is dead.  This assumes that we've already -/// checked for unique exit and exiting blocks, and that the code is in LCSSA -/// form. -bool LoopDeletionPass::isLoopDead(Loop *L, ScalarEvolution &SE, -                                  SmallVectorImpl<BasicBlock *> &exitingBlocks, -                                  SmallVectorImpl<BasicBlock *> &exitBlocks, -                                  bool &Changed, BasicBlock *Preheader) { -  BasicBlock *exitBlock = exitBlocks[0]; - +/// Determines if a loop is dead. 
+/// +/// This assumes that we've already checked for unique exit and exiting blocks, +/// and that the code is in LCSSA form. +static bool isLoopDead(Loop *L, ScalarEvolution &SE, +                       SmallVectorImpl<BasicBlock *> &ExitingBlocks, +                       BasicBlock *ExitBlock, bool &Changed, +                       BasicBlock *Preheader) {    // Make sure that all PHI entries coming from the loop are loop invariant.    // Because the code is in LCSSA form, any values used outside of the loop    // must pass through a PHI in the exit block, meaning that this check is    // sufficient to guarantee that no loop-variant values are used outside    // of the loop. -  BasicBlock::iterator BI = exitBlock->begin(); +  BasicBlock::iterator BI = ExitBlock->begin();    bool AllEntriesInvariant = true;    bool AllOutgoingValuesSame = true;    while (PHINode *P = dyn_cast<PHINode>(BI)) { -    Value *incoming = P->getIncomingValueForBlock(exitingBlocks[0]); +    Value *incoming = P->getIncomingValueForBlock(ExitingBlocks[0]);      // Make sure all exiting blocks produce the same incoming value for the exit      // block.  If there are different incoming values for different exiting      // blocks, then it is impossible to statically determine which value should      // be used.      AllOutgoingValuesSame = -        all_of(makeArrayRef(exitingBlocks).slice(1), [&](BasicBlock *BB) { +        all_of(makeArrayRef(ExitingBlocks).slice(1), [&](BasicBlock *BB) {            return incoming == P->getIncomingValueForBlock(BB);          }); @@ -78,33 +77,37 @@ bool LoopDeletionPass::isLoopDead(Loop *L, ScalarEvolution &SE,    // Make sure that no instructions in the block have potential side-effects.    // This includes instructions that could write to memory, and loads that are -  // marked volatile.  This could be made more aggressive by using aliasing -  // information to identify readonly and readnone calls. -  for (Loop::block_iterator LI = L->block_begin(), LE = L->block_end(); -       LI != LE; ++LI) { -    for (Instruction &I : **LI) { -      if (I.mayHaveSideEffects()) -        return false; -    } -  } - +  // marked volatile. +  for (auto &I : L->blocks()) +    if (any_of(*I, [](Instruction &I) { return I.mayHaveSideEffects(); })) +      return false;    return true;  } -/// Remove dead loops, by which we mean loops that do not impact the observable -/// behavior of the program other than finite running time.  Note we do ensure -/// that this never remove a loop that might be infinite, as doing so could -/// change the halting/non-halting nature of a program. NOTE: This entire -/// process relies pretty heavily on LoopSimplify and LCSSA in order to make -/// various safety checks work. -bool LoopDeletionPass::runImpl(Loop *L, DominatorTree &DT, ScalarEvolution &SE, -                               LoopInfo &loopInfo) { +/// Remove a loop if it is dead. +/// +/// A loop is considered dead if it does not impact the observable behavior of +/// the program other than finite running time. This never removes a loop that +/// might be infinite, as doing so could change the halting/non-halting nature +/// of a program. +/// +/// This entire process relies pretty heavily on LoopSimplify form and LCSSA in +/// order to make various safety checks work. +/// +/// \returns true if any changes were made. This may mutate the loop even if it +/// is unable to delete it due to hoisting trivially loop invariant +/// instructions out of the loop. 
+/// +/// This also updates the relevant analysis information in \p DT, \p SE, and \p +/// LI. It also updates the loop PM if an updater struct is provided. +static bool deleteLoopIfDead(Loop *L, DominatorTree &DT, ScalarEvolution &SE, +                             LoopInfo &LI, LPMUpdater *Updater = nullptr) {    assert(L->isLCSSAForm(DT) && "Expected LCSSA!");    // We can only remove the loop if there is a preheader that we can    // branch from after removing it. -  BasicBlock *preheader = L->getLoopPreheader(); -  if (!preheader) +  BasicBlock *Preheader = L->getLoopPreheader(); +  if (!Preheader)      return false;    // If LoopSimplify form is not available, stay out of trouble. @@ -116,22 +119,20 @@ bool LoopDeletionPass::runImpl(Loop *L, DominatorTree &DT, ScalarEvolution &SE,    if (L->begin() != L->end())      return false; -  SmallVector<BasicBlock *, 4> exitingBlocks; -  L->getExitingBlocks(exitingBlocks); - -  SmallVector<BasicBlock *, 4> exitBlocks; -  L->getUniqueExitBlocks(exitBlocks); +  SmallVector<BasicBlock *, 4> ExitingBlocks; +  L->getExitingBlocks(ExitingBlocks);    // We require that the loop only have a single exit block.  Otherwise, we'd    // be in the situation of needing to be able to solve statically which exit    // block will be branched to, or trying to preserve the branching logic in    // a loop invariant manner. -  if (exitBlocks.size() != 1) +  BasicBlock *ExitBlock = L->getUniqueExitBlock(); +  if (!ExitBlock)      return false;    // Finally, we have to check that the loop really is dead.    bool Changed = false; -  if (!isLoopDead(L, SE, exitingBlocks, exitBlocks, Changed, preheader)) +  if (!isLoopDead(L, SE, ExitingBlocks, ExitBlock, Changed, Preheader))      return Changed;    // Don't remove loops for which we can't solve the trip count. @@ -142,11 +143,13 @@ bool LoopDeletionPass::runImpl(Loop *L, DominatorTree &DT, ScalarEvolution &SE,    // Now that we know the removal is safe, remove the loop by changing the    // branch from the preheader to go to the single exit block. -  BasicBlock *exitBlock = exitBlocks[0]; - +  //    // Because we're deleting a large chunk of code at once, the sequence in which -  // we remove things is very important to avoid invalidation issues.  Don't -  // mess with this unless you have good reason and know what you're doing. +  // we remove things is very important to avoid invalidation issues. + +  // If we have an LPM updater, tell it about the loop being removed. +  if (Updater) +    Updater->markLoopAsDeleted(*L);    // Tell ScalarEvolution that the loop is deleted. Do this before    // deleting the loop so that ScalarEvolution can look at the loop @@ -154,19 +157,19 @@ bool LoopDeletionPass::runImpl(Loop *L, DominatorTree &DT, ScalarEvolution &SE,    SE.forgetLoop(L);    // Connect the preheader directly to the exit block. -  TerminatorInst *TI = preheader->getTerminator(); -  TI->replaceUsesOfWith(L->getHeader(), exitBlock); +  TerminatorInst *TI = Preheader->getTerminator(); +  TI->replaceUsesOfWith(L->getHeader(), ExitBlock);    // Rewrite phis in the exit block to get their inputs from    // the preheader instead of the exiting block. 
-  BasicBlock *exitingBlock = exitingBlocks[0]; -  BasicBlock::iterator BI = exitBlock->begin(); +  BasicBlock *ExitingBlock = ExitingBlocks[0]; +  BasicBlock::iterator BI = ExitBlock->begin();    while (PHINode *P = dyn_cast<PHINode>(BI)) { -    int j = P->getBasicBlockIndex(exitingBlock); +    int j = P->getBasicBlockIndex(ExitingBlock);      assert(j >= 0 && "Can't find exiting block in exit block's phi node!"); -    P->setIncomingBlock(j, preheader); -    for (unsigned i = 1; i < exitingBlocks.size(); ++i) -      P->removeIncomingValue(exitingBlocks[i]); +    P->setIncomingBlock(j, Preheader); +    for (unsigned i = 1; i < ExitingBlocks.size(); ++i) +      P->removeIncomingValue(ExitingBlocks[i]);      ++BI;    } @@ -175,11 +178,11 @@ bool LoopDeletionPass::runImpl(Loop *L, DominatorTree &DT, ScalarEvolution &SE,    SmallVector<DomTreeNode*, 8> ChildNodes;    for (Loop::block_iterator LI = L->block_begin(), LE = L->block_end();         LI != LE; ++LI) { -    // Move all of the block's children to be children of the preheader, which +    // Move all of the block's children to be children of the Preheader, which      // allows us to remove the domtree entry for the block.      ChildNodes.insert(ChildNodes.begin(), DT[*LI]->begin(), DT[*LI]->end());      for (DomTreeNode *ChildNode : ChildNodes) { -      DT.changeImmediateDominator(ChildNode, DT[preheader]); +      DT.changeImmediateDominator(ChildNode, DT[Preheader]);      }      ChildNodes.clear(); @@ -204,22 +207,19 @@ bool LoopDeletionPass::runImpl(Loop *L, DominatorTree &DT, ScalarEvolution &SE,    SmallPtrSet<BasicBlock *, 8> blocks;    blocks.insert(L->block_begin(), L->block_end());    for (BasicBlock *BB : blocks) -    loopInfo.removeBlock(BB); +    LI.removeBlock(BB);    // The last step is to update LoopInfo now that we've eliminated this loop. 
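[Editorial aside] To ground the notion of "dead" used by isLoopDead and deleteLoopIfDead above: a loop qualifies when its only observable effect is finite running time, i.e. it contains no side-effecting instructions, has a unique exit block whose PHIs receive loop-invariant values, and has a computable trip count. A minimal source-level sketch (plain C++, hypothetical function, not the pass's IR view):

int keepsOnlyTime(int n) {
  int acc = 0;
  // No stores, calls, or volatile accesses; 'acc' is never read after the
  // loop; the trip count (100) is computable. LoopDeletion can remove the
  // loop without changing observable behaviour.
  for (int i = 0; i < 100; ++i)
    acc += i;
  return n;
}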
-  loopInfo.markAsRemoved(L); -  Changed = true; - +  LI.markAsRemoved(L);    ++NumDeleted; -  return Changed; +  return true;  }  PreservedAnalyses LoopDeletionPass::run(Loop &L, LoopAnalysisManager &AM,                                          LoopStandardAnalysisResults &AR, -                                        LPMUpdater &) { -  bool Changed = runImpl(&L, AR.DT, AR.SE, AR.LI); -  if (!Changed) +                                        LPMUpdater &Updater) { +  if (!deleteLoopIfDead(&L, AR.DT, AR.SE, AR.LI, &Updater))      return PreservedAnalyses::all();    return getLoopPassPreservedAnalyses(); @@ -257,8 +257,7 @@ bool LoopDeletionLegacyPass::runOnLoop(Loop *L, LPPassManager &) {    DominatorTree &DT = getAnalysis<DominatorTreeWrapperPass>().getDomTree();    ScalarEvolution &SE = getAnalysis<ScalarEvolutionWrapperPass>().getSE(); -  LoopInfo &loopInfo = getAnalysis<LoopInfoWrapperPass>().getLoopInfo(); +  LoopInfo &LI = getAnalysis<LoopInfoWrapperPass>().getLoopInfo(); -  LoopDeletionPass Impl; -  return Impl.runImpl(L, DT, SE, loopInfo); +  return deleteLoopIfDead(L, DT, SE, LI);  } diff --git a/lib/Transforms/Scalar/LoopDistribute.cpp b/lib/Transforms/Scalar/LoopDistribute.cpp index 19716b28ad66..3624bba10345 100644 --- a/lib/Transforms/Scalar/LoopDistribute.cpp +++ b/lib/Transforms/Scalar/LoopDistribute.cpp @@ -812,29 +812,29 @@ private:        const RuntimePointerChecking *RtPtrChecking) {      SmallVector<RuntimePointerChecking::PointerCheck, 4> Checks; -    std::copy_if(AllChecks.begin(), AllChecks.end(), std::back_inserter(Checks), -                 [&](const RuntimePointerChecking::PointerCheck &Check) { -                   for (unsigned PtrIdx1 : Check.first->Members) -                     for (unsigned PtrIdx2 : Check.second->Members) -                       // Only include this check if there is a pair of pointers -                       // that require checking and the pointers fall into -                       // separate partitions. -                       // -                       // (Note that we already know at this point that the two -                       // pointer groups need checking but it doesn't follow -                       // that each pair of pointers within the two groups need -                       // checking as well. -                       // -                       // In other words we don't want to include a check just -                       // because there is a pair of pointers between the two -                       // pointer groups that require checks and a different -                       // pair whose pointers fall into different partitions.) -                       if (RtPtrChecking->needsChecking(PtrIdx1, PtrIdx2) && -                           !RuntimePointerChecking::arePointersInSamePartition( -                               PtrToPartition, PtrIdx1, PtrIdx2)) -                         return true; -                   return false; -                 }); +    copy_if(AllChecks, std::back_inserter(Checks), +            [&](const RuntimePointerChecking::PointerCheck &Check) { +              for (unsigned PtrIdx1 : Check.first->Members) +                for (unsigned PtrIdx2 : Check.second->Members) +                  // Only include this check if there is a pair of pointers +                  // that require checking and the pointers fall into +                  // separate partitions. 
+                  // +                  // (Note that we already know at this point that the two +                  // pointer groups need checking but it doesn't follow +                  // that each pair of pointers within the two groups need +                  // checking as well. +                  // +                  // In other words we don't want to include a check just +                  // because there is a pair of pointers between the two +                  // pointer groups that require checks and a different +                  // pair whose pointers fall into different partitions.) +                  if (RtPtrChecking->needsChecking(PtrIdx1, PtrIdx2) && +                      !RuntimePointerChecking::arePointersInSamePartition( +                          PtrToPartition, PtrIdx1, PtrIdx2)) +                    return true; +              return false; +            });      return Checks;    } diff --git a/lib/Transforms/Scalar/LoopIdiomRecognize.cpp b/lib/Transforms/Scalar/LoopIdiomRecognize.cpp index 5fec51c095d0..946d85d7360f 100644 --- a/lib/Transforms/Scalar/LoopIdiomRecognize.cpp +++ b/lib/Transforms/Scalar/LoopIdiomRecognize.cpp @@ -236,9 +236,9 @@ bool LoopIdiomRecognize::runOnLoop(Loop *L) {    ApplyCodeSizeHeuristics =        L->getHeader()->getParent()->optForSize() && UseLIRCodeSizeHeurs; -  HasMemset = TLI->has(LibFunc::memset); -  HasMemsetPattern = TLI->has(LibFunc::memset_pattern16); -  HasMemcpy = TLI->has(LibFunc::memcpy); +  HasMemset = TLI->has(LibFunc_memset); +  HasMemsetPattern = TLI->has(LibFunc_memset_pattern16); +  HasMemcpy = TLI->has(LibFunc_memcpy);    if (HasMemset || HasMemsetPattern || HasMemcpy)      if (SE->hasLoopInvariantBackedgeTakenCount(L)) @@ -823,7 +823,7 @@ bool LoopIdiomRecognize::processLoopStridedStore(      Module *M = TheStore->getModule();      Value *MSP =          M->getOrInsertFunction("memset_pattern16", Builder.getVoidTy(), -                               Int8PtrTy, Int8PtrTy, IntPtr, (void *)nullptr); +                               Int8PtrTy, Int8PtrTy, IntPtr);      inferLibFuncAttributes(*M->getFunction("memset_pattern16"), *TLI);      // Otherwise we should form a memset_pattern16.  
PatternValue is known to be diff --git a/lib/Transforms/Scalar/LoopInstSimplify.cpp b/lib/Transforms/Scalar/LoopInstSimplify.cpp index 69102d10ff60..28e71ca05436 100644 --- a/lib/Transforms/Scalar/LoopInstSimplify.cpp +++ b/lib/Transforms/Scalar/LoopInstSimplify.cpp @@ -189,7 +189,9 @@ PreservedAnalyses LoopInstSimplifyPass::run(Loop &L, LoopAnalysisManager &AM,    if (!SimplifyLoopInst(&L, &AR.DT, &AR.LI, &AR.AC, &AR.TLI))      return PreservedAnalyses::all(); -  return getLoopPassPreservedAnalyses(); +  auto PA = getLoopPassPreservedAnalyses(); +  PA.preserveSet<CFGAnalyses>(); +  return PA;  }  char LoopInstSimplifyLegacyPass::ID = 0; diff --git a/lib/Transforms/Scalar/LoopInterchange.cpp b/lib/Transforms/Scalar/LoopInterchange.cpp index e9f84edd1cbf..9f3875a3027f 100644 --- a/lib/Transforms/Scalar/LoopInterchange.cpp +++ b/lib/Transforms/Scalar/LoopInterchange.cpp @@ -39,7 +39,7 @@  #include "llvm/Transforms/Scalar.h"  #include "llvm/Transforms/Utils/BasicBlockUtils.h"  #include "llvm/Transforms/Utils/LoopUtils.h" -#include "llvm/Transforms/Utils/SSAUpdater.h" +  using namespace llvm;  #define DEBUG_TYPE "loop-interchange" diff --git a/lib/Transforms/Scalar/LoopLoadElimination.cpp b/lib/Transforms/Scalar/LoopLoadElimination.cpp index 8fb580183e30..cf63cb660db8 100644 --- a/lib/Transforms/Scalar/LoopLoadElimination.cpp +++ b/lib/Transforms/Scalar/LoopLoadElimination.cpp @@ -20,13 +20,14 @@  //  //===----------------------------------------------------------------------===// +#include "llvm/Transforms/Scalar/LoopLoadElimination.h"  #include "llvm/ADT/APInt.h"  #include "llvm/ADT/DenseMap.h"  #include "llvm/ADT/DepthFirstIterator.h" +#include "llvm/ADT/STLExtras.h"  #include "llvm/ADT/SmallSet.h"  #include "llvm/ADT/SmallVector.h"  #include "llvm/ADT/Statistic.h" -#include "llvm/ADT/STLExtras.h"  #include "llvm/Analysis/GlobalsModRef.h"  #include "llvm/Analysis/LoopAccessAnalysis.h"  #include "llvm/Analysis/LoopInfo.h" @@ -45,9 +46,9 @@  #include "llvm/Support/Debug.h"  #include "llvm/Transforms/Scalar.h"  #include "llvm/Transforms/Utils/LoopVersioning.h" -#include <forward_list> -#include <cassert>  #include <algorithm> +#include <cassert> +#include <forward_list>  #include <set>  #include <tuple>  #include <utility> @@ -373,15 +374,15 @@ public:      const auto &AllChecks = LAI.getRuntimePointerChecking()->getChecks();      SmallVector<RuntimePointerChecking::PointerCheck, 4> Checks; -    std::copy_if(AllChecks.begin(), AllChecks.end(), std::back_inserter(Checks), -                 [&](const RuntimePointerChecking::PointerCheck &Check) { -                   for (auto PtrIdx1 : Check.first->Members) -                     for (auto PtrIdx2 : Check.second->Members) -                       if (needsChecking(PtrIdx1, PtrIdx2, -                                         PtrsWrittenOnFwdingPath, CandLoadPtrs)) -                         return true; -                   return false; -                 }); +    copy_if(AllChecks, std::back_inserter(Checks), +            [&](const RuntimePointerChecking::PointerCheck &Check) { +              for (auto PtrIdx1 : Check.first->Members) +                for (auto PtrIdx2 : Check.second->Members) +                  if (needsChecking(PtrIdx1, PtrIdx2, PtrsWrittenOnFwdingPath, +                                    CandLoadPtrs)) +                    return true; +              return false; +            });      DEBUG(dbgs() << "\nPointer Checks (count: " << Checks.size() << "):\n");      DEBUG(LAI.getRuntimePointerChecking()->printChecks(dbgs(), Checks)); 
@@ -558,6 +559,32 @@ private:    PredicatedScalarEvolution PSE;  }; +static bool +eliminateLoadsAcrossLoops(Function &F, LoopInfo &LI, DominatorTree &DT, +                          function_ref<const LoopAccessInfo &(Loop &)> GetLAI) { +  // Build up a worklist of inner-loops to transform to avoid iterator +  // invalidation. +  // FIXME: This logic comes from other passes that actually change the loop +  // nest structure. It isn't clear this is necessary (or useful) for a pass +  // which merely optimizes the use of loads in a loop. +  SmallVector<Loop *, 8> Worklist; + +  for (Loop *TopLevelLoop : LI) +    for (Loop *L : depth_first(TopLevelLoop)) +      // We only handle inner-most loops. +      if (L->empty()) +        Worklist.push_back(L); + +  // Now walk the identified inner loops. +  bool Changed = false; +  for (Loop *L : Worklist) { +    // The actual work is performed by LoadEliminationForLoop. +    LoadEliminationForLoop LEL(L, &LI, GetLAI(*L), &DT); +    Changed |= LEL.processLoop(); +  } +  return Changed; +} +  /// \brief The pass.  Most of the work is delegated to the per-loop  /// LoadEliminationForLoop class.  class LoopLoadElimination : public FunctionPass { @@ -570,32 +597,14 @@ public:      if (skipFunction(F))        return false; -    auto *LI = &getAnalysis<LoopInfoWrapperPass>().getLoopInfo(); -    auto *LAA = &getAnalysis<LoopAccessLegacyAnalysis>(); -    auto *DT = &getAnalysis<DominatorTreeWrapperPass>().getDomTree(); - -    // Build up a worklist of inner-loops to vectorize. This is necessary as the -    // act of distributing a loop creates new loops and can invalidate iterators -    // across the loops. -    SmallVector<Loop *, 8> Worklist; - -    for (Loop *TopLevelLoop : *LI) -      for (Loop *L : depth_first(TopLevelLoop)) -        // We only handle inner-most loops. -        if (L->empty()) -          Worklist.push_back(L); - -    // Now walk the identified inner loops. -    bool Changed = false; -    for (Loop *L : Worklist) { -      const LoopAccessInfo &LAI = LAA->getInfo(L); -      // The actual work is performed by LoadEliminationForLoop. -      LoadEliminationForLoop LEL(L, LI, LAI, DT); -      Changed |= LEL.processLoop(); -    } +    auto &LI = getAnalysis<LoopInfoWrapperPass>().getLoopInfo(); +    auto &LAA = getAnalysis<LoopAccessLegacyAnalysis>(); +    auto &DT = getAnalysis<DominatorTreeWrapperPass>().getDomTree();      // Process each loop nest in the function. 
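[Editorial aside] The hunk above hoists the worklist construction into eliminateLoadsAcrossLoops so the legacy pass and the new pass manager share it: inner-most loops are collected first and only then transformed, so the transformation cannot invalidate the loop-tree walk (the FIXME notes this may be stricter than this pass actually needs). A generic sketch of that collect-then-mutate pattern, in plain C++ with a hypothetical Node type rather than the LLVM loop tree:

#include <vector>

struct Node {
  std::vector<Node *> Children;
};

// Depth-first collection of the leaves ("inner-most loops") only.
static void collectLeaves(Node *N, std::vector<Node *> &Out) {
  if (N->Children.empty()) {
    Out.push_back(N);
    return;
  }
  for (Node *C : N->Children)
    collectLeaves(C, Out);
}

template <typename Fn>
bool transformLeaves(std::vector<Node *> &Roots, Fn Transform) {
  std::vector<Node *> Worklist;
  for (Node *Root : Roots)
    collectLeaves(Root, Worklist);   // finish the walk first...
  bool Changed = false;
  for (Node *Leaf : Worklist)
    Changed |= Transform(*Leaf);     // ...then mutate freely
  return Changed;
}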
-    return Changed; +    return eliminateLoadsAcrossLoops( +        F, LI, DT, +        [&LAA](Loop &L) -> const LoopAccessInfo & { return LAA.getInfo(&L); });    }    void getAnalysisUsage(AnalysisUsage &AU) const override { @@ -631,4 +640,28 @@ FunctionPass *createLoopLoadEliminationPass() {    return new LoopLoadElimination();  } +PreservedAnalyses LoopLoadEliminationPass::run(Function &F, +                                               FunctionAnalysisManager &AM) { +  auto &SE = AM.getResult<ScalarEvolutionAnalysis>(F); +  auto &LI = AM.getResult<LoopAnalysis>(F); +  auto &TTI = AM.getResult<TargetIRAnalysis>(F); +  auto &DT = AM.getResult<DominatorTreeAnalysis>(F); +  auto &TLI = AM.getResult<TargetLibraryAnalysis>(F); +  auto &AA = AM.getResult<AAManager>(F); +  auto &AC = AM.getResult<AssumptionAnalysis>(F); + +  auto &LAM = AM.getResult<LoopAnalysisManagerFunctionProxy>(F).getManager(); +  bool Changed = eliminateLoadsAcrossLoops( +      F, LI, DT, [&](Loop &L) -> const LoopAccessInfo & { +        LoopStandardAnalysisResults AR = {AA, AC, DT, LI, SE, TLI, TTI}; +        return LAM.getResult<LoopAccessAnalysis>(L, AR); +      }); + +  if (!Changed) +    return PreservedAnalyses::all(); + +  PreservedAnalyses PA; +  return PA; +} +  } // end namespace llvm diff --git a/lib/Transforms/Scalar/LoopPassManager.cpp b/lib/Transforms/Scalar/LoopPassManager.cpp index 028f4bba8b1d..10f6fcdcfdb7 100644 --- a/lib/Transforms/Scalar/LoopPassManager.cpp +++ b/lib/Transforms/Scalar/LoopPassManager.cpp @@ -42,6 +42,13 @@ PassManager<Loop, LoopAnalysisManager, LoopStandardAnalysisResults &,        break;      } +#ifndef NDEBUG +    // Verify the loop structure and LCSSA form before visiting the loop. +    L.verifyLoop(); +    assert(L.isRecursivelyLCSSAForm(AR.DT, AR.LI) && +           "Loops must remain in LCSSA form!"); +#endif +      // Update the analysis manager as each pass runs and potentially      // invalidates analyses.      AM.invalidate(L, PassPA); diff --git a/lib/Transforms/Scalar/LoopPredication.cpp b/lib/Transforms/Scalar/LoopPredication.cpp new file mode 100644 index 000000000000..0ce604429326 --- /dev/null +++ b/lib/Transforms/Scalar/LoopPredication.cpp @@ -0,0 +1,282 @@ +//===-- LoopPredication.cpp - Guard based loop predication pass -----------===// +// +//                     The LLVM Compiler Infrastructure +// +// This file is distributed under the University of Illinois Open Source +// License. See LICENSE.TXT for details. +// +//===----------------------------------------------------------------------===// +// +// The LoopPredication pass tries to convert loop variant range checks to loop +// invariant by widening checks across loop iterations. For example, it will +// convert +// +//   for (i = 0; i < n; i++) { +//     guard(i < len); +//     ... +//   } +// +// to +// +//   for (i = 0; i < n; i++) { +//     guard(n - 1 < len); +//     ... +//   } +// +// After this transformation the condition of the guard is loop invariant, so +// loop-unswitch can later unswitch the loop by this condition which basically +// predicates the loop by the widened condition: +// +//   if (n - 1 < len) +//     for (i = 0; i < n; i++) { +//       ... 
+//     } +//   else +//     deoptimize +// +//===----------------------------------------------------------------------===// + +#include "llvm/Transforms/Scalar/LoopPredication.h" +#include "llvm/Pass.h" +#include "llvm/Analysis/LoopInfo.h" +#include "llvm/Analysis/LoopPass.h" +#include "llvm/Analysis/ScalarEvolution.h" +#include "llvm/Analysis/ScalarEvolutionExpander.h" +#include "llvm/Analysis/ScalarEvolutionExpressions.h" +#include "llvm/IR/Function.h" +#include "llvm/IR/GlobalValue.h" +#include "llvm/IR/IntrinsicInst.h" +#include "llvm/IR/Module.h" +#include "llvm/IR/PatternMatch.h" +#include "llvm/Support/Debug.h" +#include "llvm/Transforms/Scalar.h" +#include "llvm/Transforms/Utils/LoopUtils.h" + +#define DEBUG_TYPE "loop-predication" + +using namespace llvm; + +namespace { +class LoopPredication { +  ScalarEvolution *SE; + +  Loop *L; +  const DataLayout *DL; +  BasicBlock *Preheader; + +  Optional<Value *> widenICmpRangeCheck(ICmpInst *ICI, SCEVExpander &Expander, +                                        IRBuilder<> &Builder); +  bool widenGuardConditions(IntrinsicInst *II, SCEVExpander &Expander); + +public: +  LoopPredication(ScalarEvolution *SE) : SE(SE){}; +  bool runOnLoop(Loop *L); +}; + +class LoopPredicationLegacyPass : public LoopPass { +public: +  static char ID; +  LoopPredicationLegacyPass() : LoopPass(ID) { +    initializeLoopPredicationLegacyPassPass(*PassRegistry::getPassRegistry()); +  } + +  void getAnalysisUsage(AnalysisUsage &AU) const override { +    getLoopAnalysisUsage(AU); +  } + +  bool runOnLoop(Loop *L, LPPassManager &LPM) override { +    if (skipLoop(L)) +      return false; +    auto *SE = &getAnalysis<ScalarEvolutionWrapperPass>().getSE(); +    LoopPredication LP(SE); +    return LP.runOnLoop(L); +  } +}; + +char LoopPredicationLegacyPass::ID = 0; +} // end namespace llvm + +INITIALIZE_PASS_BEGIN(LoopPredicationLegacyPass, "loop-predication", +                      "Loop predication", false, false) +INITIALIZE_PASS_DEPENDENCY(LoopPass) +INITIALIZE_PASS_END(LoopPredicationLegacyPass, "loop-predication", +                    "Loop predication", false, false) + +Pass *llvm::createLoopPredicationPass() { +  return new LoopPredicationLegacyPass(); +} + +PreservedAnalyses LoopPredicationPass::run(Loop &L, LoopAnalysisManager &AM, +                                           LoopStandardAnalysisResults &AR, +                                           LPMUpdater &U) { +  LoopPredication LP(&AR.SE); +  if (!LP.runOnLoop(&L)) +    return PreservedAnalyses::all(); + +  return getLoopPassPreservedAnalyses(); +} + +/// If ICI can be widened to a loop invariant condition emits the loop +/// invariant condition in the loop preheader and return it, otherwise +/// returns None. 
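[Editorial aside] To make the monotonicity rule applied by widenICmpRangeCheck (defined next) concrete, here is a source-level sketch in plain C++; guard() is a hypothetical stand-in for llvm.experimental.guard and both loops are invented examples:

void guard(int cond);                 // stand-in: deoptimizes when cond is false

void widening_examples(int n, int len, int lo, int *a) {
  // 'i < len' can only flip true -> false as i grows, so the widened,
  // loop-invariant check uses the last iteration:  guard(n - 1 < len).
  for (int i = 0; i < n; i++) {
    guard(i < len);
    a[i] = 0;
  }
  // 'i >= lo' can only flip false -> true as i grows, so the widened check
  // uses the first iteration:  guard(0 >= lo).
  for (int i = 0; i < n; i++) {
    guard(i >= lo);
    a[i] = 1;
  }
}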
+Optional<Value *> LoopPredication::widenICmpRangeCheck(ICmpInst *ICI, +                                                       SCEVExpander &Expander, +                                                       IRBuilder<> &Builder) { +  DEBUG(dbgs() << "Analyzing ICmpInst condition:\n"); +  DEBUG(ICI->dump()); + +  ICmpInst::Predicate Pred = ICI->getPredicate(); +  Value *LHS = ICI->getOperand(0); +  Value *RHS = ICI->getOperand(1); +  const SCEV *LHSS = SE->getSCEV(LHS); +  if (isa<SCEVCouldNotCompute>(LHSS)) +    return None; +  const SCEV *RHSS = SE->getSCEV(RHS); +  if (isa<SCEVCouldNotCompute>(RHSS)) +    return None; + +  // Canonicalize RHS to be loop invariant bound, LHS - a loop computable index +  if (SE->isLoopInvariant(LHSS, L)) { +    std::swap(LHS, RHS); +    std::swap(LHSS, RHSS); +    Pred = ICmpInst::getSwappedPredicate(Pred); +  } +  if (!SE->isLoopInvariant(RHSS, L) || !isSafeToExpand(RHSS, *SE)) +    return None; + +  const SCEVAddRecExpr *IndexAR = dyn_cast<SCEVAddRecExpr>(LHSS); +  if (!IndexAR || IndexAR->getLoop() != L) +    return None; + +  DEBUG(dbgs() << "IndexAR: "); +  DEBUG(IndexAR->dump()); + +  bool IsIncreasing = false; +  if (!SE->isMonotonicPredicate(IndexAR, Pred, IsIncreasing)) +    return None; + +  // If the predicate is increasing the condition can change from false to true +  // as the loop progresses, in this case take the value on the first iteration +  // for the widened check. Otherwise the condition can change from true to +  // false as the loop progresses, so take the value on the last iteration. +  const SCEV *NewLHSS = IsIncreasing +                            ? IndexAR->getStart() +                            : SE->getSCEVAtScope(IndexAR, L->getParentLoop()); +  if (NewLHSS == IndexAR) { +    DEBUG(dbgs() << "Can't compute NewLHSS!\n"); +    return None; +  } + +  DEBUG(dbgs() << "NewLHSS: "); +  DEBUG(NewLHSS->dump()); + +  if (!SE->isLoopInvariant(NewLHSS, L) || !isSafeToExpand(NewLHSS, *SE)) +    return None; + +  DEBUG(dbgs() << "NewLHSS is loop invariant and safe to expand. Expand!\n"); + +  Type *Ty = LHS->getType(); +  Instruction *InsertAt = Preheader->getTerminator(); +  assert(Ty == RHS->getType() && "icmp operands have different types?"); +  Value *NewLHS = Expander.expandCodeFor(NewLHSS, Ty, InsertAt); +  Value *NewRHS = Expander.expandCodeFor(RHSS, Ty, InsertAt); +  return Builder.CreateICmp(Pred, NewLHS, NewRHS); +} + +bool LoopPredication::widenGuardConditions(IntrinsicInst *Guard, +                                           SCEVExpander &Expander) { +  DEBUG(dbgs() << "Processing guard:\n"); +  DEBUG(Guard->dump()); + +  IRBuilder<> Builder(cast<Instruction>(Preheader->getTerminator())); + +  // The guard condition is expected to be in form of: +  //   cond1 && cond2 && cond3 ... +  // Iterate over subconditions looking for for icmp conditions which can be +  // widened across loop iterations. Widening these conditions remember the +  // resulting list of subconditions in Checks vector. 
+  SmallVector<Value *, 4> Worklist(1, Guard->getOperand(0)); +  SmallPtrSet<Value *, 4> Visited; + +  SmallVector<Value *, 4> Checks; + +  unsigned NumWidened = 0; +  do { +    Value *Condition = Worklist.pop_back_val(); +    if (!Visited.insert(Condition).second) +      continue; + +    Value *LHS, *RHS; +    using namespace llvm::PatternMatch; +    if (match(Condition, m_And(m_Value(LHS), m_Value(RHS)))) { +      Worklist.push_back(LHS); +      Worklist.push_back(RHS); +      continue; +    } + +    if (ICmpInst *ICI = dyn_cast<ICmpInst>(Condition)) { +      if (auto NewRangeCheck = widenICmpRangeCheck(ICI, Expander, Builder)) { +        Checks.push_back(NewRangeCheck.getValue()); +        NumWidened++; +        continue; +      } +    } + +    // Save the condition as is if we can't widen it +    Checks.push_back(Condition); +  } while (Worklist.size() != 0); + +  if (NumWidened == 0) +    return false; + +  // Emit the new guard condition +  Builder.SetInsertPoint(Guard); +  Value *LastCheck = nullptr; +  for (auto *Check : Checks) +    if (!LastCheck) +      LastCheck = Check; +    else +      LastCheck = Builder.CreateAnd(LastCheck, Check); +  Guard->setOperand(0, LastCheck); + +  DEBUG(dbgs() << "Widened checks = " << NumWidened << "\n"); +  return true; +} + +bool LoopPredication::runOnLoop(Loop *Loop) { +  L = Loop; + +  DEBUG(dbgs() << "Analyzing "); +  DEBUG(L->dump()); + +  Module *M = L->getHeader()->getModule(); + +  // There is nothing to do if the module doesn't use guards +  auto *GuardDecl = +      M->getFunction(Intrinsic::getName(Intrinsic::experimental_guard)); +  if (!GuardDecl || GuardDecl->use_empty()) +    return false; + +  DL = &M->getDataLayout(); + +  Preheader = L->getLoopPreheader(); +  if (!Preheader) +    return false; + +  // Collect all the guards into a vector and process later, so as not +  // to invalidate the instruction iterator. +  SmallVector<IntrinsicInst *, 4> Guards; +  for (const auto BB : L->blocks()) +    for (auto &I : *BB) +      if (auto *II = dyn_cast<IntrinsicInst>(&I)) +        if (II->getIntrinsicID() == Intrinsic::experimental_guard) +          Guards.push_back(II); + +  SCEVExpander Expander(*SE, *DL, "loop-predication"); + +  bool Changed = false; +  for (auto *Guard : Guards) +    Changed |= widenGuardConditions(Guard, Expander); + +  return Changed; +} diff --git a/lib/Transforms/Scalar/LoopRotation.cpp b/lib/Transforms/Scalar/LoopRotation.cpp index cc83069d5f52..e5689368de80 100644 --- a/lib/Transforms/Scalar/LoopRotation.cpp +++ b/lib/Transforms/Scalar/LoopRotation.cpp @@ -79,7 +79,8 @@ private:  /// to merge the two values.  Do this now.  static void RewriteUsesOfClonedInstructions(BasicBlock *OrigHeader,                                              BasicBlock *OrigPreheader, -                                            ValueToValueMapTy &ValueMap) { +                                            ValueToValueMapTy &ValueMap, +                                SmallVectorImpl<PHINode*> *InsertedPHIs) {    // Remove PHI node entries that are no longer live.    BasicBlock::iterator I, E = OrigHeader->end();    for (I = OrigHeader->begin(); PHINode *PN = dyn_cast<PHINode>(I); ++I) @@ -87,7 +88,7 @@ static void RewriteUsesOfClonedInstructions(BasicBlock *OrigHeader,    // Now fix up users of the instructions in OrigHeader, inserting PHI nodes    // as necessary. 
-  SSAUpdater SSA; +  SSAUpdater SSA(InsertedPHIs);    for (I = OrigHeader->begin(); I != E; ++I) {      Value *OrigHeaderVal = &*I; @@ -174,6 +175,38 @@ static void RewriteUsesOfClonedInstructions(BasicBlock *OrigHeader,    }  } +/// Propagate dbg.value intrinsics through the newly inserted Phis. +static void insertDebugValues(BasicBlock *OrigHeader, +                              SmallVectorImpl<PHINode*> &InsertedPHIs) { +  ValueToValueMapTy DbgValueMap; + +  // Map existing PHI nodes to their dbg.values. +  for (auto &I : *OrigHeader) { +    if (auto DbgII = dyn_cast<DbgInfoIntrinsic>(&I)) { +      if (auto *Loc = dyn_cast_or_null<PHINode>(DbgII->getVariableLocation())) +        DbgValueMap.insert({Loc, DbgII}); +    } +  } + +  // Then iterate through the new PHIs and look to see if they use one of the +  // previously mapped PHIs. If so, insert a new dbg.value intrinsic that will +  // propagate the info through the new PHI. +  LLVMContext &C = OrigHeader->getContext(); +  for (auto PHI : InsertedPHIs) { +    for (auto VI : PHI->operand_values()) { +      auto V = DbgValueMap.find(VI); +      if (V != DbgValueMap.end()) { +        auto *DbgII = cast<DbgInfoIntrinsic>(V->second); +        Instruction *NewDbgII = DbgII->clone(); +        auto PhiMAV = MetadataAsValue::get(C, ValueAsMetadata::get(PHI)); +        NewDbgII->setOperand(0, PhiMAV); +        BasicBlock *Parent = PHI->getParent(); +        NewDbgII->insertBefore(Parent->getFirstNonPHIOrDbgOrLifetime()); +      } +    } +  } +} +  /// Rotate loop LP. Return true if the loop is rotated.  ///  /// \param SimplifiedLatch is true if the latch was just folded into the final @@ -347,9 +380,18 @@ bool LoopRotate::rotateLoop(Loop *L, bool SimplifiedLatch) {    // remove the corresponding incoming values from the PHI nodes in OrigHeader.    LoopEntryBranch->eraseFromParent(); + +  SmallVector<PHINode*, 2> InsertedPHIs;    // If there were any uses of instructions in the duplicated block outside the    // loop, update them, inserting PHI nodes as required -  RewriteUsesOfClonedInstructions(OrigHeader, OrigPreheader, ValueMap); +  RewriteUsesOfClonedInstructions(OrigHeader, OrigPreheader, ValueMap, +                                  &InsertedPHIs); + +  // Attach dbg.value intrinsics to the new phis if that phi uses a value that +  // previously had debug metadata attached. This keeps the debug info +  // up-to-date in the loop body. +  if (!InsertedPHIs.empty()) +    insertDebugValues(OrigHeader, InsertedPHIs);    // NewHeader is now the header of the loop.    
L->moveToHeader(NewHeader); @@ -634,6 +676,7 @@ PreservedAnalyses LoopRotatePass::run(Loop &L, LoopAnalysisManager &AM,    bool Changed = LR.processLoop(&L);    if (!Changed)      return PreservedAnalyses::all(); +    return getLoopPassPreservedAnalyses();  } diff --git a/lib/Transforms/Scalar/LoopSimplifyCFG.cpp b/lib/Transforms/Scalar/LoopSimplifyCFG.cpp index 16061212ba38..a5a81c33a8eb 100644 --- a/lib/Transforms/Scalar/LoopSimplifyCFG.cpp +++ b/lib/Transforms/Scalar/LoopSimplifyCFG.cpp @@ -69,6 +69,7 @@ PreservedAnalyses LoopSimplifyCFGPass::run(Loop &L, LoopAnalysisManager &AM,                                             LPMUpdater &) {    if (!simplifyLoopCFG(L, AR.DT, AR.LI))      return PreservedAnalyses::all(); +    return getLoopPassPreservedAnalyses();  } diff --git a/lib/Transforms/Scalar/LoopSink.cpp b/lib/Transforms/Scalar/LoopSink.cpp index f3f415275c0e..c9d55b4594fe 100644 --- a/lib/Transforms/Scalar/LoopSink.cpp +++ b/lib/Transforms/Scalar/LoopSink.cpp @@ -1,4 +1,4 @@ -//===-- LoopSink.cpp - Loop Sink Pass ------------------------===// +//===-- LoopSink.cpp - Loop Sink Pass -------------------------------------===//  //  //                     The LLVM Compiler Infrastructure  // @@ -28,8 +28,10 @@  //       InsertBBs = UseBBs - DomBBs + BB  //   For BB in InsertBBs:  //     Insert I at BB's beginning +//  //===----------------------------------------------------------------------===// +#include "llvm/Transforms/Scalar/LoopSink.h"  #include "llvm/ADT/Statistic.h"  #include "llvm/Analysis/AliasAnalysis.h"  #include "llvm/Analysis/AliasSetTracker.h" @@ -297,6 +299,42 @@ static bool sinkLoopInvariantInstructions(Loop &L, AAResults &AA, LoopInfo &LI,    return Changed;  } +PreservedAnalyses LoopSinkPass::run(Function &F, FunctionAnalysisManager &FAM) { +  LoopInfo &LI = FAM.getResult<LoopAnalysis>(F); +  // Nothing to do if there are no loops. +  if (LI.empty()) +    return PreservedAnalyses::all(); + +  AAResults &AA = FAM.getResult<AAManager>(F); +  DominatorTree &DT = FAM.getResult<DominatorTreeAnalysis>(F); +  BlockFrequencyInfo &BFI = FAM.getResult<BlockFrequencyAnalysis>(F); + +  // We want to do a postorder walk over the loops. Since loops are a tree this +  // is equivalent to a reversed preorder walk and preorder is easy to compute +  // without recursion. Since we reverse the preorder, we will visit siblings +  // in reverse program order. This isn't expected to matter at all but is more +  // consistent with sinking algorithms which generally work bottom-up. +  SmallVector<Loop *, 4> PreorderLoops = LI.getLoopsInPreorder(); + +  bool Changed = false; +  do { +    Loop &L = *PreorderLoops.pop_back_val(); + +    // Note that we don't pass SCEV here because it is only used to invalidate +    // loops in SCEV and we don't preserve (or request) SCEV at all making that +    // unnecessary. 
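[Editorial aside] The LoopSink run method above relies on the fact that consuming a preorder listing from the back yields a bottom-up visit (every node is seen after all of its descendants, siblings in reverse order). A minimal sketch of that trick in plain C++ with a hypothetical Tree type; recursion is used here only to keep the sketch short:

#include <vector>

struct Tree {
  std::vector<Tree *> Children;
};

static void preorder(Tree *N, std::vector<Tree *> &Out) {
  Out.push_back(N);                  // parent precedes all descendants
  for (Tree *C : N->Children)
    preorder(C, Out);
}

// Popping the preorder list from the back therefore visits each node only
// after every one of its descendants has been visited.
template <typename Fn>
void visitBottomUp(Tree *Root, Fn Visit) {
  std::vector<Tree *> Order;
  preorder(Root, Order);
  while (!Order.empty()) {
    Visit(*Order.back());
    Order.pop_back();
  }
}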
+    Changed |= sinkLoopInvariantInstructions(L, AA, LI, DT, BFI, +                                             /*ScalarEvolution*/ nullptr); +  } while (!PreorderLoops.empty()); + +  if (!Changed) +    return PreservedAnalyses::all(); + +  PreservedAnalyses PA; +  PA.preserveSet<CFGAnalyses>(); +  return PA; +} +  namespace {  struct LegacyLoopSinkPass : public LoopPass {    static char ID; diff --git a/lib/Transforms/Scalar/LoopStrengthReduce.cpp b/lib/Transforms/Scalar/LoopStrengthReduce.cpp index 194587a85e7c..af137f6faa63 100644 --- a/lib/Transforms/Scalar/LoopStrengthReduce.cpp +++ b/lib/Transforms/Scalar/LoopStrengthReduce.cpp @@ -129,6 +129,17 @@ static cl::opt<bool> EnablePhiElim(    "enable-lsr-phielim", cl::Hidden, cl::init(true),    cl::desc("Enable LSR phi elimination")); +// The flag adds instruction count to solutions cost comparision. +static cl::opt<bool> InsnsCost( +  "lsr-insns-cost", cl::Hidden, cl::init(false), +  cl::desc("Add instruction count to a LSR cost model")); + +// Flag to choose how to narrow complex lsr solution +static cl::opt<bool> LSRExpNarrow( +  "lsr-exp-narrow", cl::Hidden, cl::init(false), +  cl::desc("Narrow LSR complex solution using" +           " expectation of registers number")); +  #ifndef NDEBUG  // Stress test IV chain generation.  static cl::opt<bool> StressIVChain( @@ -181,10 +192,11 @@ void RegSortData::print(raw_ostream &OS) const {    OS << "[NumUses=" << UsedByIndices.count() << ']';  } -LLVM_DUMP_METHOD -void RegSortData::dump() const { +#if !defined(NDEBUG) || defined(LLVM_ENABLE_DUMP) +LLVM_DUMP_METHOD void RegSortData::dump() const {    print(errs()); errs() << '\n';  } +#endif  namespace { @@ -295,9 +307,13 @@ struct Formula {    /// canonical representation of a formula is    /// 1. BaseRegs.size > 1 implies ScaledReg != NULL and    /// 2. ScaledReg != NULL implies Scale != 1 || !BaseRegs.empty(). +  /// 3. The reg containing recurrent expr related with currect loop in the +  /// formula should be put in the ScaledReg.    /// #1 enforces that the scaled register is always used when at least two    /// registers are needed by the formula: e.g., reg1 + reg2 is reg1 + 1 * reg2.    /// #2 enforces that 1 * reg is reg. +  /// #3 ensures invariant regs with respect to current loop can be combined +  /// together in LSR codegen.    /// This invariant can be temporarly broken while building a formula.    /// However, every formula inserted into the LSRInstance must be in canonical    /// form. @@ -318,12 +334,14 @@ struct Formula {    void initialMatch(const SCEV *S, Loop *L, ScalarEvolution &SE); -  bool isCanonical() const; +  bool isCanonical(const Loop &L) const; -  void canonicalize(); +  void canonicalize(const Loop &L);    bool unscale(); +  bool hasZeroEnd() const; +    size_t getNumRegs() const;    Type *getType() const; @@ -410,16 +428,35 @@ void Formula::initialMatch(const SCEV *S, Loop *L, ScalarEvolution &SE) {        BaseRegs.push_back(Sum);      HasBaseReg = true;    } -  canonicalize(); +  canonicalize(*L);  }  /// \brief Check whether or not this formula statisfies the canonical  /// representation.  /// \see Formula::BaseRegs. 
-bool Formula::isCanonical() const { -  if (ScaledReg) -    return Scale != 1 || !BaseRegs.empty(); -  return BaseRegs.size() <= 1; +bool Formula::isCanonical(const Loop &L) const { +  if (!ScaledReg) +    return BaseRegs.size() <= 1; + +  if (Scale != 1) +    return true; + +  if (Scale == 1 && BaseRegs.empty()) +    return false; + +  const SCEVAddRecExpr *SAR = dyn_cast<const SCEVAddRecExpr>(ScaledReg); +  if (SAR && SAR->getLoop() == &L) +    return true; + +  // If ScaledReg is not a recurrent expr, or it is but its loop is not current +  // loop, meanwhile BaseRegs contains a recurrent expr reg related with current +  // loop, we want to swap the reg in BaseRegs with ScaledReg. +  auto I = +      find_if(make_range(BaseRegs.begin(), BaseRegs.end()), [&](const SCEV *S) { +        return isa<const SCEVAddRecExpr>(S) && +               (cast<SCEVAddRecExpr>(S)->getLoop() == &L); +      }); +  return I == BaseRegs.end();  }  /// \brief Helper method to morph a formula into its canonical representation. @@ -428,21 +465,33 @@ bool Formula::isCanonical() const {  /// field. Otherwise, we would have to do special cases everywhere in LSR  /// to treat reg1 + reg2 + ... the same way as reg1 + 1*reg2 + ...  /// On the other hand, 1*reg should be canonicalized into reg. -void Formula::canonicalize() { -  if (isCanonical()) +void Formula::canonicalize(const Loop &L) { +  if (isCanonical(L))      return;    // So far we did not need this case. This is easy to implement but it is    // useless to maintain dead code. Beside it could hurt compile time.    assert(!BaseRegs.empty() && "1*reg => reg, should not be needed."); +    // Keep the invariant sum in BaseRegs and one of the variant sum in ScaledReg. -  ScaledReg = BaseRegs.back(); -  BaseRegs.pop_back(); -  Scale = 1; -  size_t BaseRegsSize = BaseRegs.size(); -  size_t Try = 0; -  // If ScaledReg is an invariant, try to find a variant expression. -  while (Try < BaseRegsSize && !isa<SCEVAddRecExpr>(ScaledReg)) -    std::swap(ScaledReg, BaseRegs[Try++]); +  if (!ScaledReg) { +    ScaledReg = BaseRegs.back(); +    BaseRegs.pop_back(); +    Scale = 1; +  } + +  // If ScaledReg is an invariant with respect to L, find the reg from +  // BaseRegs containing the recurrent expr related with Loop L. Swap the +  // reg with ScaledReg. +  const SCEVAddRecExpr *SAR = dyn_cast<const SCEVAddRecExpr>(ScaledReg); +  if (!SAR || SAR->getLoop() != &L) { +    auto I = find_if(make_range(BaseRegs.begin(), BaseRegs.end()), +                     [&](const SCEV *S) { +                       return isa<const SCEVAddRecExpr>(S) && +                              (cast<SCEVAddRecExpr>(S)->getLoop() == &L); +                     }); +    if (I != BaseRegs.end()) +      std::swap(ScaledReg, *I); +  }  }  /// \brief Get rid of the scale in the formula. @@ -458,6 +507,14 @@ bool Formula::unscale() {    return true;  } +bool Formula::hasZeroEnd() const { +  if (UnfoldedOffset || BaseOffset) +    return false; +  if (BaseRegs.size() != 1 || ScaledReg) +    return false; +  return true; +} +  /// Return the total number of register operands used by this formula. This does  /// not include register uses implied by non-constant addrec strides.  
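[Editorial aside] Canonical-form rule #3 introduced above can be paraphrased: when a formula mixes loop-invariant registers with the current loop's recurrence, the recurrence must occupy the scaled position so the invariant parts can later be combined. A toy model of the swap performed by canonicalize, in plain C++ with a hypothetical ToyReg type standing in for SCEV expressions:

#include <algorithm>
#include <string>
#include <vector>

struct ToyReg {
  std::string Name;
  bool IsCurrentLoopRec;   // stands in for "SCEVAddRecExpr whose loop is L"
};

struct ToyFormula {
  ToyReg ScaledReg;
  std::vector<ToyReg> BaseRegs;
};

// If the scaled register is not the current loop's recurrence but one of the
// base registers is, swap them so the recurrence ends up scaled.
static void canonicalizeForLoop(ToyFormula &F) {
  if (F.ScaledReg.IsCurrentLoopRec)
    return;
  auto I = std::find_if(F.BaseRegs.begin(), F.BaseRegs.end(),
                        [](const ToyReg &R) { return R.IsCurrentLoopRec; });
  if (I != F.BaseRegs.end())
    std::swap(F.ScaledReg, *I);
}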
size_t Formula::getNumRegs() const { @@ -534,10 +591,11 @@ void Formula::print(raw_ostream &OS) const {    }  } -LLVM_DUMP_METHOD -void Formula::dump() const { +#if !defined(NDEBUG) || defined(LLVM_ENABLE_DUMP) +LLVM_DUMP_METHOD void Formula::dump() const {    print(errs()); errs() << '\n';  } +#endif  /// Return true if the given addrec can be sign-extended without changing its  /// value. @@ -711,7 +769,7 @@ static GlobalValue *ExtractSymbol(const SCEV *&S, ScalarEvolution &SE) {  static bool isAddressUse(Instruction *Inst, Value *OperandVal) {    bool isAddress = isa<LoadInst>(Inst);    if (StoreInst *SI = dyn_cast<StoreInst>(Inst)) { -    if (SI->getOperand(1) == OperandVal) +    if (SI->getPointerOperand() == OperandVal)        isAddress = true;    } else if (IntrinsicInst *II = dyn_cast<IntrinsicInst>(Inst)) {      // Addressing modes can also be folded into prefetches and a variety @@ -723,6 +781,12 @@ static bool isAddressUse(Instruction *Inst, Value *OperandVal) {            isAddress = true;          break;      } +  } else if (AtomicRMWInst *RMW = dyn_cast<AtomicRMWInst>(Inst)) { +    if (RMW->getPointerOperand() == OperandVal) +      isAddress = true; +  } else if (AtomicCmpXchgInst *CmpX = dyn_cast<AtomicCmpXchgInst>(Inst)) { +    if (CmpX->getPointerOperand() == OperandVal) +      isAddress = true;    }    return isAddress;  } @@ -735,6 +799,10 @@ static MemAccessTy getAccessType(const Instruction *Inst) {      AccessTy.AddrSpace = SI->getPointerAddressSpace();    } else if (const LoadInst *LI = dyn_cast<LoadInst>(Inst)) {      AccessTy.AddrSpace = LI->getPointerAddressSpace(); +  } else if (const AtomicRMWInst *RMW = dyn_cast<AtomicRMWInst>(Inst)) { +    AccessTy.AddrSpace = RMW->getPointerAddressSpace(); +  } else if (const AtomicCmpXchgInst *CmpX = dyn_cast<AtomicCmpXchgInst>(Inst)) { +    AccessTy.AddrSpace = CmpX->getPointerAddressSpace();    }    // All pointers have the same requirements, so canonicalize them to an @@ -875,7 +943,8 @@ static bool isAMCompletelyFolded(const TargetTransformInfo &TTI,                                   const LSRUse &LU, const Formula &F);  // Get the cost of the scaling factor used in F for LU.  static unsigned getScalingFactorCost(const TargetTransformInfo &TTI, -                                     const LSRUse &LU, const Formula &F); +                                     const LSRUse &LU, const Formula &F, +                                     const Loop &L);  namespace { @@ -883,6 +952,7 @@ namespace {  class Cost {    /// TODO: Some of these could be merged. Also, a lexical ordering    /// isn't always optimal. +  unsigned Insns;    unsigned NumRegs;    unsigned AddRecCost;    unsigned NumIVMuls; @@ -893,8 +963,8 @@ class Cost {  public:    Cost() -    : NumRegs(0), AddRecCost(0), NumIVMuls(0), NumBaseAdds(0), ImmCost(0), -      SetupCost(0), ScaleCost(0) {} +    : Insns(0), NumRegs(0), AddRecCost(0), NumIVMuls(0), NumBaseAdds(0), +      ImmCost(0), SetupCost(0), ScaleCost(0) {}    bool operator<(const Cost &Other) const; @@ -903,9 +973,9 @@ public:  #ifndef NDEBUG    // Once any of the metrics loses, they must all remain losers.    
bool isValid() { -    return ((NumRegs | AddRecCost | NumIVMuls | NumBaseAdds +    return ((Insns | NumRegs | AddRecCost | NumIVMuls | NumBaseAdds               | ImmCost | SetupCost | ScaleCost) != ~0u) -      || ((NumRegs & AddRecCost & NumIVMuls & NumBaseAdds +      || ((Insns & NumRegs & AddRecCost & NumIVMuls & NumBaseAdds             & ImmCost & SetupCost & ScaleCost) == ~0u);    }  #endif @@ -1067,7 +1137,8 @@ public:    }    bool HasFormulaWithSameRegs(const Formula &F) const; -  bool InsertFormula(const Formula &F); +  float getNotSelectedProbability(const SCEV *Reg) const; +  bool InsertFormula(const Formula &F, const Loop &L);    void DeleteFormula(Formula &F);    void RecomputeRegs(size_t LUIdx, RegUseTracker &Reguses); @@ -1083,17 +1154,23 @@ void Cost::RateRegister(const SCEV *Reg,                          const Loop *L,                          ScalarEvolution &SE, DominatorTree &DT) {    if (const SCEVAddRecExpr *AR = dyn_cast<SCEVAddRecExpr>(Reg)) { -    // If this is an addrec for another loop, don't second-guess its addrec phi -    // nodes. LSR isn't currently smart enough to reason about more than one -    // loop at a time. LSR has already run on inner loops, will not run on outer -    // loops, and cannot be expected to change sibling loops. +    // If this is an addrec for another loop, it should be an invariant +    // with respect to L since L is the innermost loop (at least +    // for now LSR only handles innermost loops).      if (AR->getLoop() != L) {        // If the AddRec exists, consider it's register free and leave it alone.        if (isExistingPhi(AR, SE))          return; -      // Otherwise, do not consider this formula at all. -      Lose(); +      // It is bad to allow LSR for current loop to add induction variables +      // for its sibling loops. +      if (!AR->getLoop()->contains(L)) { +        Lose(); +        return; +      } + +      // Otherwise, it will be an invariant with respect to Loop L. +      ++NumRegs;        return;      }      AddRecCost += 1; /// TODO: This should be a function of the stride. @@ -1150,8 +1227,11 @@ void Cost::RateFormula(const TargetTransformInfo &TTI,                         ScalarEvolution &SE, DominatorTree &DT,                         const LSRUse &LU,                         SmallPtrSetImpl<const SCEV *> *LoserRegs) { -  assert(F.isCanonical() && "Cost is accurate only for canonical formula"); +  assert(F.isCanonical(*L) && "Cost is accurate only for canonical formula");    // Tally up the registers. +  unsigned PrevAddRecCost = AddRecCost; +  unsigned PrevNumRegs = NumRegs; +  unsigned PrevNumBaseAdds = NumBaseAdds;    if (const SCEV *ScaledReg = F.ScaledReg) {      if (VisitedRegs.count(ScaledReg)) {        Lose(); @@ -1171,6 +1251,18 @@ void Cost::RateFormula(const TargetTransformInfo &TTI,        return;    } +  // Treat every new register that exceeds TTI.getNumberOfRegisters() - 1 as +  // additional instruction (at least fill). +  unsigned TTIRegNum = TTI.getNumberOfRegisters(false) - 1; +  if (NumRegs > TTIRegNum) { +    // Cost already exceeded TTIRegNum, then only newly added register can add +    // new instructions. +    if (PrevNumRegs > TTIRegNum) +      Insns += (NumRegs - PrevNumRegs); +    else +      Insns += (NumRegs - TTIRegNum); +  } +    // Determine how many (unfolded) adds we'll need inside the loop.    
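[Editorial aside] The register-pressure term added to RateFormula just above charges one instruction (at least a fill) per register beyond the target's budget, and only for the registers added by the current step when the budget was already exceeded. A stand-alone restatement in plain C++, with hypothetical parameter names, plus two sanity values:

#include <cassert>

static unsigned extraInsns(unsigned PrevNumRegs, unsigned NumRegs,
                           unsigned TTIRegNum) {
  assert(NumRegs >= PrevNumRegs && "registers only accumulate");
  if (NumRegs <= TTIRegNum)
    return 0;                       // still within the register budget
  if (PrevNumRegs > TTIRegNum)
    return NumRegs - PrevNumRegs;   // budget already blown: charge the delta
  return NumRegs - TTIRegNum;       // charge only the part above the budget
}

// e.g. with a budget of 8: extraInsns(6, 10, 8) == 2, extraInsns(9, 10, 8) == 1.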
size_t NumBaseParts = F.getNumRegs();    if (NumBaseParts > 1) @@ -1181,7 +1273,7 @@ void Cost::RateFormula(const TargetTransformInfo &TTI,    NumBaseAdds += (F.UnfoldedOffset != 0);    // Accumulate non-free scaling amounts. -  ScaleCost += getScalingFactorCost(TTI, LU, F); +  ScaleCost += getScalingFactorCost(TTI, LU, F, *L);    // Tally up the non-zero immediates.    for (const LSRFixup &Fixup : LU.Fixups) { @@ -1199,11 +1291,30 @@ void Cost::RateFormula(const TargetTransformInfo &TTI,          !TTI.isFoldableMemAccessOffset(Fixup.UserInst, Offset))        NumBaseAdds++;    } + +  // If ICmpZero formula ends with not 0, it could not be replaced by +  // just add or sub. We'll need to compare final result of AddRec. +  // That means we'll need an additional instruction. +  // For -10 + {0, +, 1}: +  // i = i + 1; +  // cmp i, 10 +  // +  // For {-10, +, 1}: +  // i = i + 1; +  if (LU.Kind == LSRUse::ICmpZero && !F.hasZeroEnd()) +    Insns++; +  // Each new AddRec adds 1 instruction to calculation. +  Insns += (AddRecCost - PrevAddRecCost); + +  // BaseAdds adds instructions for unfolded registers. +  if (LU.Kind != LSRUse::ICmpZero) +    Insns += NumBaseAdds - PrevNumBaseAdds;    assert(isValid() && "invalid cost");  }  /// Set this cost to a losing value.  void Cost::Lose() { +  Insns = ~0u;    NumRegs = ~0u;    AddRecCost = ~0u;    NumIVMuls = ~0u; @@ -1215,6 +1326,8 @@ void Cost::Lose() {  /// Choose the lower cost.  bool Cost::operator<(const Cost &Other) const { +  if (InsnsCost && Insns != Other.Insns) +    return Insns < Other.Insns;    return std::tie(NumRegs, AddRecCost, NumIVMuls, NumBaseAdds, ScaleCost,                    ImmCost, SetupCost) <           std::tie(Other.NumRegs, Other.AddRecCost, Other.NumIVMuls, @@ -1223,6 +1336,7 @@ bool Cost::operator<(const Cost &Other) const {  }  void Cost::print(raw_ostream &OS) const { +  OS << Insns << " instruction" << (Insns == 1 ? " " : "s ");    OS << NumRegs << " reg" << (NumRegs == 1 ? "" : "s");    if (AddRecCost != 0)      OS << ", with addrec cost " << AddRecCost; @@ -1239,10 +1353,11 @@ void Cost::print(raw_ostream &OS) const {      OS << ", plus " << SetupCost << " setup cost";  } -LLVM_DUMP_METHOD -void Cost::dump() const { +#if !defined(NDEBUG) || defined(LLVM_ENABLE_DUMP) +LLVM_DUMP_METHOD void Cost::dump() const {    print(errs()); errs() << '\n';  } +#endif  LSRFixup::LSRFixup()    : UserInst(nullptr), OperandValToReplace(nullptr), @@ -1285,10 +1400,11 @@ void LSRFixup::print(raw_ostream &OS) const {      OS << ", Offset=" << Offset;  } -LLVM_DUMP_METHOD -void LSRFixup::dump() const { +#if !defined(NDEBUG) || defined(LLVM_ENABLE_DUMP) +LLVM_DUMP_METHOD void LSRFixup::dump() const {    print(errs()); errs() << '\n';  } +#endif  /// Test whether this use as a formula which has the same registers as the given  /// formula. @@ -1300,10 +1416,19 @@ bool LSRUse::HasFormulaWithSameRegs(const Formula &F) const {    return Uniquifier.count(Key);  } +/// The function returns a probability of selecting formula without Reg. +float LSRUse::getNotSelectedProbability(const SCEV *Reg) const { +  unsigned FNum = 0; +  for (const Formula &F : Formulae) +    if (F.referencesReg(Reg)) +      FNum++; +  return ((float)(Formulae.size() - FNum)) / Formulae.size(); +} +  /// If the given formula has not yet been inserted, add it to the list, and  /// return true. Return false otherwise.  The formula must be in canonical form. 
-bool LSRUse::InsertFormula(const Formula &F) { -  assert(F.isCanonical() && "Invalid canonical representation"); +bool LSRUse::InsertFormula(const Formula &F, const Loop &L) { +  assert(F.isCanonical(L) && "Invalid canonical representation");    if (!Formulae.empty() && RigidFormula)      return false; @@ -1391,10 +1516,11 @@ void LSRUse::print(raw_ostream &OS) const {      OS << ", widest fixup type: " << *WidestFixupType;  } -LLVM_DUMP_METHOD -void LSRUse::dump() const { +#if !defined(NDEBUG) || defined(LLVM_ENABLE_DUMP) +LLVM_DUMP_METHOD void LSRUse::dump() const {    print(errs()); errs() << '\n';  } +#endif  static bool isAMCompletelyFolded(const TargetTransformInfo &TTI,                                   LSRUse::KindType Kind, MemAccessTy AccessTy, @@ -1472,7 +1598,7 @@ static bool isAMCompletelyFolded(const TargetTransformInfo &TTI,  static bool isAMCompletelyFolded(const TargetTransformInfo &TTI,                                   int64_t MinOffset, int64_t MaxOffset,                                   LSRUse::KindType Kind, MemAccessTy AccessTy, -                                 const Formula &F) { +                                 const Formula &F, const Loop &L) {    // For the purpose of isAMCompletelyFolded either having a canonical formula    // or a scale not equal to zero is correct.    // Problems may arise from non canonical formulae having a scale == 0. @@ -1480,7 +1606,7 @@ static bool isAMCompletelyFolded(const TargetTransformInfo &TTI,    // However, when we generate the scaled formulae, we first check that the    // scaling factor is profitable before computing the actual ScaledReg for    // compile time sake. -  assert((F.isCanonical() || F.Scale != 0)); +  assert((F.isCanonical(L) || F.Scale != 0));    return isAMCompletelyFolded(TTI, MinOffset, MaxOffset, Kind, AccessTy,                                F.BaseGV, F.BaseOffset, F.HasBaseReg, F.Scale);  } @@ -1515,14 +1641,15 @@ static bool isAMCompletelyFolded(const TargetTransformInfo &TTI,  }  static unsigned getScalingFactorCost(const TargetTransformInfo &TTI, -                                     const LSRUse &LU, const Formula &F) { +                                     const LSRUse &LU, const Formula &F, +                                     const Loop &L) {    if (!F.Scale)      return 0;    // If the use is not completely folded in that instruction, we will have to    // pay an extra cost only for scale != 1.    if (!isAMCompletelyFolded(TTI, LU.MinOffset, LU.MaxOffset, LU.Kind, -                            LU.AccessTy, F)) +                            LU.AccessTy, F, L))      return F.Scale != 1;    switch (LU.Kind) { @@ -1772,6 +1899,7 @@ class LSRInstance {    void NarrowSearchSpaceByDetectingSupersets();    void NarrowSearchSpaceByCollapsingUnrolledCode();    void NarrowSearchSpaceByRefilteringUndesirableDedicatedRegisters(); +  void NarrowSearchSpaceByDeletingCostlyFormulas();    void NarrowSearchSpaceByPickingWinnerRegs();    void NarrowSearchSpaceUsingHeuristics(); @@ -2492,7 +2620,12 @@ static Value *getWideOperand(Value *Oper) {  static bool isCompatibleIVType(Value *LVal, Value *RVal) {    Type *LType = LVal->getType();    Type *RType = RVal->getType(); -  return (LType == RType) || (LType->isPointerTy() && RType->isPointerTy()); +  return (LType == RType) || (LType->isPointerTy() && RType->isPointerTy() && +                              // Different address spaces means (possibly) +                              // different types of the pointer implementation, +                              // e.g. 
i16 vs i32 so disallow that. +                              (LType->getPointerAddressSpace() == +                               RType->getPointerAddressSpace()));  }  /// Return an approximation of this SCEV expression's "base", or NULL for any @@ -2989,8 +3122,10 @@ void LSRInstance::CollectFixupsAndInitialFormulae() {      User::op_iterator UseI =          find(UserInst->operands(), U.getOperandValToReplace());      assert(UseI != UserInst->op_end() && "cannot find IV operand"); -    if (IVIncSet.count(UseI)) +    if (IVIncSet.count(UseI)) { +      DEBUG(dbgs() << "Use is in profitable chain: " << **UseI << '\n');        continue; +    }      LSRUse::KindType Kind = LSRUse::Basic;      MemAccessTy AccessTy; @@ -3025,8 +3160,7 @@ void LSRInstance::CollectFixupsAndInitialFormulae() {          if (SE.isLoopInvariant(N, L) && isSafeToExpand(N, SE)) {            // S is normalized, so normalize N before folding it into S            // to keep the result normalized. -          N = TransformForPostIncUse(Normalize, N, CI, nullptr, -                                     TmpPostIncLoops, SE, DT); +          N = normalizeForPostIncUse(N, TmpPostIncLoops, SE);            Kind = LSRUse::ICmpZero;            S = SE.getMinusSCEV(N, S);          } @@ -3108,7 +3242,8 @@ bool LSRInstance::InsertFormula(LSRUse &LU, unsigned LUIdx, const Formula &F) {    // Do not insert formula that we will not be able to expand.    assert(isLegalUse(TTI, LU.MinOffset, LU.MaxOffset, LU.Kind, LU.AccessTy, F) &&           "Formula is illegal"); -  if (!LU.InsertFormula(F)) + +  if (!LU.InsertFormula(F, *L))      return false;    CountRegisters(F, LUIdx); @@ -3347,7 +3482,7 @@ void LSRInstance::GenerateReassociationsImpl(LSRUse &LU, unsigned LUIdx,        F.BaseRegs.push_back(*J);      // We may have changed the number of register in base regs, adjust the      // formula accordingly. -    F.canonicalize(); +    F.canonicalize(*L);      if (InsertFormula(LU, LUIdx, F))        // If that formula hadn't been seen before, recurse to find more like @@ -3359,7 +3494,7 @@ void LSRInstance::GenerateReassociationsImpl(LSRUse &LU, unsigned LUIdx,  /// Split out subexpressions from adds and the bases of addrecs.  void LSRInstance::GenerateReassociations(LSRUse &LU, unsigned LUIdx,                                           Formula Base, unsigned Depth) { -  assert(Base.isCanonical() && "Input must be in the canonical form"); +  assert(Base.isCanonical(*L) && "Input must be in the canonical form");    // Arbitrarily cap recursion to protect compile time.    if (Depth >= 3)      return; @@ -3400,7 +3535,7 @@ void LSRInstance::GenerateCombinations(LSRUse &LU, unsigned LUIdx,      // rather than proceed with zero in a register.      if (!Sum->isZero()) {        F.BaseRegs.push_back(Sum); -      F.canonicalize(); +      F.canonicalize(*L);        (void)InsertFormula(LU, LUIdx, F);      }    } @@ -3457,7 +3592,7 @@ void LSRInstance::GenerateConstantOffsetsImpl(            F.ScaledReg = nullptr;          } else            F.deleteBaseReg(F.BaseRegs[Idx]); -        F.canonicalize(); +        F.canonicalize(*L);        } else if (IsScaledReg)          F.ScaledReg = NewG;        else @@ -3620,10 +3755,10 @@ void LSRInstance::GenerateScales(LSRUse &LU, unsigned LUIdx, Formula Base) {      if (LU.Kind == LSRUse::ICmpZero &&          !Base.HasBaseReg && Base.BaseOffset == 0 && !Base.BaseGV)        continue; -    // For each addrec base reg, apply the scale, if possible. 
-    for (size_t i = 0, e = Base.BaseRegs.size(); i != e; ++i) -      if (const SCEVAddRecExpr *AR = -            dyn_cast<SCEVAddRecExpr>(Base.BaseRegs[i])) { +    // For each addrec base reg, if its loop is current loop, apply the scale. +    for (size_t i = 0, e = Base.BaseRegs.size(); i != e; ++i) { +      const SCEVAddRecExpr *AR = dyn_cast<SCEVAddRecExpr>(Base.BaseRegs[i]); +      if (AR && (AR->getLoop() == L || LU.AllFixupsOutsideLoop)) {          const SCEV *FactorS = SE.getConstant(IntTy, Factor);          if (FactorS->isZero())            continue; @@ -3637,11 +3772,17 @@ void LSRInstance::GenerateScales(LSRUse &LU, unsigned LUIdx, Formula Base) {            // The canonical representation of 1*reg is reg, which is already in            // Base. In that case, do not try to insert the formula, it will be            // rejected anyway. -          if (F.Scale == 1 && F.BaseRegs.empty()) +          if (F.Scale == 1 && (F.BaseRegs.empty() || +                               (AR->getLoop() != L && LU.AllFixupsOutsideLoop)))              continue; +          // If AllFixupsOutsideLoop is true and F.Scale is 1, we may generate +          // non canonical Formula with ScaledReg's loop not being L. +          if (F.Scale == 1 && LU.AllFixupsOutsideLoop) +            F.canonicalize(*L);            (void)InsertFormula(LU, LUIdx, F);          }        } +    }    }  } @@ -3697,10 +3838,11 @@ void WorkItem::print(raw_ostream &OS) const {       << " , add offset " << Imm;  } -LLVM_DUMP_METHOD -void WorkItem::dump() const { +#if !defined(NDEBUG) || defined(LLVM_ENABLE_DUMP) +LLVM_DUMP_METHOD void WorkItem::dump() const {    print(errs()); errs() << '\n';  } +#endif  /// Look for registers which are a constant distance apart and try to form reuse  /// opportunities between them. @@ -3821,7 +3963,7 @@ void LSRInstance::GenerateCrossUseConstantOffsets() {              continue;          // OK, looks good. -        NewF.canonicalize(); +        NewF.canonicalize(*this->L);          (void)InsertFormula(LU, LUIdx, NewF);        } else {          // Use the immediate in a base register. @@ -3853,7 +3995,7 @@ void LSRInstance::GenerateCrossUseConstantOffsets() {                  goto skip_formula;            // Ok, looks good. -          NewF.canonicalize(); +          NewF.canonicalize(*this->L);            (void)InsertFormula(LU, LUIdx, NewF);            break;          skip_formula:; @@ -4165,6 +4307,144 @@ void LSRInstance::NarrowSearchSpaceByRefilteringUndesirableDedicatedRegisters(){    }  } +/// The function delete formulas with high registers number expectation. +/// Assuming we don't know the value of each formula (already delete +/// all inefficient), generate probability of not selecting for each +/// register. 
+/// For example,
+/// Use1:
+///  reg(a) + reg({0,+,1})
+///  reg(a) + reg({-1,+,1}) + 1
+///  reg({a,+,1})
+/// Use2:
+///  reg(b) + reg({0,+,1})
+///  reg(b) + reg({-1,+,1}) + 1
+///  reg({b,+,1})
+/// Use3:
+///  reg(c) + reg(b) + reg({0,+,1})
+///  reg(c) + reg({b,+,1})
+///
+/// Probability of not selecting
+///                 Use1   Use2    Use3
+/// reg(a)         (1/3) *   1   *   1
+/// reg(b)           1   * (1/3) * (1/2)
+/// reg({0,+,1})   (2/3) * (2/3) * (1/2)
+/// reg({-1,+,1})  (2/3) * (2/3) *   1
+/// reg({a,+,1})   (2/3) *   1   *   1
+/// reg({b,+,1})     1   * (2/3) * (2/3)
+/// reg(c)           1   *   1   *   0
+///
+/// Now compute the expected number of registers for each formula.
+/// Note that for each use we exclude the probability of not selecting for that
+/// use. For example, for Use1 the contribution of reg(a) is just 1 * 1
+/// (excluding the probability 1/3 of not selecting for Use1).
+/// Use1:
+///  reg(a) + reg({0,+,1})          1 + 1/3       -- to be deleted
+///  reg(a) + reg({-1,+,1}) + 1     1 + 4/9       -- to be deleted
+///  reg({a,+,1})                   1
+/// Use2:
+///  reg(b) + reg({0,+,1})          1/2 + 1/3     -- to be deleted
+///  reg(b) + reg({-1,+,1}) + 1     1/2 + 2/3     -- to be deleted
+///  reg({b,+,1})                   2/3
+/// Use3:
+///  reg(c) + reg(b) + reg({0,+,1}) 1 + 1/3 + 4/9 -- to be deleted
+///  reg(c) + reg({b,+,1})          1 + 2/3
+
+void LSRInstance::NarrowSearchSpaceByDeletingCostlyFormulas() {
+  if (EstimateSearchSpaceComplexity() < ComplexityLimit)
+    return;
+  // Ok, we have too many formulae on our hands to conveniently handle.
+  // Use a rough heuristic to thin out the list.
+
+  // Set of Regs which will be 100% used in the final solution, i.e. used in
+  // every formula of a solution (in the example above this is reg(c)).
+  // We can skip them in the calculations.
+  SmallPtrSet<const SCEV *, 4> UniqRegs;
+  DEBUG(dbgs() << "The search space is too complex.\n");
+
+  // Map each register to the probability of it not being selected.
+  DenseMap <const SCEV *, float> RegNumMap;
+  for (const SCEV *Reg : RegUses) {
+    if (UniqRegs.count(Reg))
+      continue;
+    float PNotSel = 1;
+    for (const LSRUse &LU : Uses) {
+      if (!LU.Regs.count(Reg))
+        continue;
+      float P = LU.getNotSelectedProbability(Reg);
+      if (P != 0.0)
+        PNotSel *= P;
+      else
+        UniqRegs.insert(Reg);
+    }
+    RegNumMap.insert(std::make_pair(Reg, PNotSel));
+  }
+
+  DEBUG(dbgs() << "Narrowing the search space by deleting costly formulas\n");
+
+  // Delete formulas whose expected number of registers is high.
+  for (size_t LUIdx = 0, NumUses = Uses.size(); LUIdx != NumUses; ++LUIdx) {
+    LSRUse &LU = Uses[LUIdx];
+    // If there is nothing to delete, continue.
+    if (LU.Formulae.size() < 2)
+      continue;
+    // This is a temporary solution to test performance. The float should be
+    // replaced with a rounding-independent type (based on integers) to avoid
+    // different results for different target builds.
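As an aside, the expectation bookkeeping spelled out in the comment above can be modeled outside of LSR. The sketch below is a self-contained toy: Use, Formula, notSelectedProbability and the container layout are invented for the illustration (they are not LSRUse, RegUseTracker or the pass's API), and it mirrors the mechanics of the heuristic on the reg(a)/reg(b)/reg(c) example rather than reproducing the exact table.

// Toy model of the bookkeeping above; all names here are invented for the
// illustration and are not the LSR data structures.
#include <cstdio>
#include <map>
#include <set>
#include <string>
#include <vector>

using Formula = std::set<std::string>; // registers referenced by one formula
using Use = std::vector<Formula>;      // candidate formulas for one use

// Fraction of U's formulas that do not reference Reg ("not selected" for U).
static double notSelectedProbability(const Use &U, const std::string &Reg) {
  unsigned NotUsing = 0;
  for (const Formula &F : U)
    if (!F.count(Reg))
      ++NotUsing;
  return double(NotUsing) / double(U.size());
}

int main() {
  const std::vector<Use> Uses = {
      /*Use1*/ {{"a", "{0,+,1}"}, {"a", "{-1,+,1}"}, {"{a,+,1}"}},
      /*Use2*/ {{"b", "{0,+,1}"}, {"b", "{-1,+,1}"}, {"{b,+,1}"}},
      /*Use3*/ {{"c", "b", "{0,+,1}"}, {"c", "{b,+,1}"}}};

  // Collect every register name that appears in some formula.
  std::set<std::string> Regs;
  for (const Use &U : Uses)
    for (const Formula &F : U)
      Regs.insert(F.begin(), F.end());

  // Overall probability that a register is never selected, plus the registers
  // that some use is certain to select (the analogue of UniqRegs above).
  std::map<std::string, double> PNotSel;
  std::set<std::string> Certain;
  for (const std::string &Reg : Regs) {
    double P = 1.0;
    for (const Use &U : Uses) {
      bool Mentioned = false;
      for (const Formula &F : U)
        Mentioned |= F.count(Reg) != 0;
      if (!Mentioned)
        continue;
      double PU = notSelectedProbability(U, Reg);
      if (PU == 0.0)
        Certain.insert(Reg); // used by every formula of U, selected for sure
      else
        P *= PU;
    }
    PNotSel[Reg] = P;
  }

  // Score each formula: for each register it uses, the probability that no
  // other use would keep that register live anyway (registers that are
  // certainly live are skipped). A lower score means fewer expected extra
  // registers.
  for (size_t UI = 0; UI != Uses.size(); ++UI)
    for (size_t FI = 0; FI != Uses[UI].size(); ++FI) {
      double Expectation = 0.0;
      for (const std::string &Reg : Uses[UI][FI])
        if (!Certain.count(Reg))
          Expectation += PNotSel[Reg] / notSelectedProbability(Uses[UI], Reg);
      std::printf("Use%zu, formula %zu: score %g\n", UI + 1, FI + 1,
                  Expectation);
    }
  return 0;
}

The pass itself then keeps only the lowest-scoring formula per use and, from that point on, treats that formula's registers as certain.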
+    float FMinRegNum = LU.Formulae[0].getNumRegs(); +    float FMinARegNum = LU.Formulae[0].getNumRegs(); +    size_t MinIdx = 0; +    for (size_t i = 0, e = LU.Formulae.size(); i != e; ++i) { +      Formula &F = LU.Formulae[i]; +      float FRegNum = 0; +      float FARegNum = 0; +      for (const SCEV *BaseReg : F.BaseRegs) { +        if (UniqRegs.count(BaseReg)) +          continue; +        FRegNum += RegNumMap[BaseReg] / LU.getNotSelectedProbability(BaseReg); +        if (isa<SCEVAddRecExpr>(BaseReg)) +          FARegNum += +              RegNumMap[BaseReg] / LU.getNotSelectedProbability(BaseReg); +      } +      if (const SCEV *ScaledReg = F.ScaledReg) { +        if (!UniqRegs.count(ScaledReg)) { +          FRegNum += +              RegNumMap[ScaledReg] / LU.getNotSelectedProbability(ScaledReg); +          if (isa<SCEVAddRecExpr>(ScaledReg)) +            FARegNum += +                RegNumMap[ScaledReg] / LU.getNotSelectedProbability(ScaledReg); +        } +      } +      if (FMinRegNum > FRegNum || +          (FMinRegNum == FRegNum && FMinARegNum > FARegNum)) { +        FMinRegNum = FRegNum; +        FMinARegNum = FARegNum; +        MinIdx = i; +      } +    } +    DEBUG(dbgs() << "  The formula "; LU.Formulae[MinIdx].print(dbgs()); +          dbgs() << " with min reg num " << FMinRegNum << '\n'); +    if (MinIdx != 0) +      std::swap(LU.Formulae[MinIdx], LU.Formulae[0]); +    while (LU.Formulae.size() != 1) { +      DEBUG(dbgs() << "  Deleting "; LU.Formulae.back().print(dbgs()); +            dbgs() << '\n'); +      LU.Formulae.pop_back(); +    } +    LU.RecomputeRegs(LUIdx, RegUses); +    assert(LU.Formulae.size() == 1 && "Should be exactly 1 min regs formula"); +    Formula &F = LU.Formulae[0]; +    DEBUG(dbgs() << "  Leaving only "; F.print(dbgs()); dbgs() << '\n'); +    // When we choose the formula, the regs become unique. +    UniqRegs.insert(F.BaseRegs.begin(), F.BaseRegs.end()); +    if (F.ScaledReg) +      UniqRegs.insert(F.ScaledReg); +  } +  DEBUG(dbgs() << "After pre-selection:\n"; +  print_uses(dbgs())); +} + +  /// Pick a register which seems likely to be profitable, and then in any use  /// which has any reference to that register, delete all formulae which do not  /// reference that register. @@ -4237,7 +4517,10 @@ void LSRInstance::NarrowSearchSpaceUsingHeuristics() {    NarrowSearchSpaceByDetectingSupersets();    NarrowSearchSpaceByCollapsingUnrolledCode();    NarrowSearchSpaceByRefilteringUndesirableDedicatedRegisters(); -  NarrowSearchSpaceByPickingWinnerRegs(); +  if (LSRExpNarrow) +    NarrowSearchSpaceByDeletingCostlyFormulas(); +  else +    NarrowSearchSpaceByPickingWinnerRegs();  }  /// This is the recursive solver. @@ -4515,11 +4798,7 @@ Value *LSRInstance::Expand(const LSRUse &LU,      assert(!Reg->isZero() && "Zero allocated in a base register!");      // If we're expanding for a post-inc user, make the post-inc adjustment. -    PostIncLoopSet &Loops = const_cast<PostIncLoopSet &>(LF.PostIncLoops); -    Reg = TransformForPostIncUse(Denormalize, Reg, -                                 LF.UserInst, LF.OperandValToReplace, -                                 Loops, SE, DT); - +    Reg = denormalizeForPostIncUse(Reg, LF.PostIncLoops, SE);      Ops.push_back(SE.getUnknown(Rewriter.expandCodeFor(Reg, nullptr)));    } @@ -4530,9 +4809,7 @@ Value *LSRInstance::Expand(const LSRUse &LU,      // If we're expanding for a post-inc user, make the post-inc adjustment.      
PostIncLoopSet &Loops = const_cast<PostIncLoopSet &>(LF.PostIncLoops); -    ScaledS = TransformForPostIncUse(Denormalize, ScaledS, -                                     LF.UserInst, LF.OperandValToReplace, -                                     Loops, SE, DT); +    ScaledS = denormalizeForPostIncUse(ScaledS, Loops, SE);      if (LU.Kind == LSRUse::ICmpZero) {        // Expand ScaleReg as if it was part of the base regs. @@ -4975,10 +5252,11 @@ void LSRInstance::print(raw_ostream &OS) const {    print_uses(OS);  } -LLVM_DUMP_METHOD -void LSRInstance::dump() const { +#if !defined(NDEBUG) || defined(LLVM_ENABLE_DUMP) +LLVM_DUMP_METHOD void LSRInstance::dump() const {    print(errs()); errs() << '\n';  } +#endif  namespace { diff --git a/lib/Transforms/Scalar/LoopUnrollPass.cpp b/lib/Transforms/Scalar/LoopUnrollPass.cpp index c7f91226d222..62aa6ee48069 100644 --- a/lib/Transforms/Scalar/LoopUnrollPass.cpp +++ b/lib/Transforms/Scalar/LoopUnrollPass.cpp @@ -44,7 +44,11 @@ using namespace llvm;  static cl::opt<unsigned>      UnrollThreshold("unroll-threshold", cl::Hidden, -                    cl::desc("The baseline cost threshold for loop unrolling")); +                    cl::desc("The cost threshold for loop unrolling")); + +static cl::opt<unsigned> UnrollPartialThreshold( +    "unroll-partial-threshold", cl::Hidden, +    cl::desc("The cost threshold for partial loop unrolling"));  static cl::opt<unsigned> UnrollMaxPercentThresholdBoost(      "unroll-max-percent-threshold-boost", cl::init(400), cl::Hidden, @@ -106,10 +110,19 @@ static cl::opt<unsigned> FlatLoopTripCountThreshold(               "aggressively unrolled."));  static cl::opt<bool> -    UnrollAllowPeeling("unroll-allow-peeling", cl::Hidden, +    UnrollAllowPeeling("unroll-allow-peeling", cl::init(true), cl::Hidden,                         cl::desc("Allows loops to be peeled when the dynamic "                                  "trip count is known to be low.")); +// This option isn't ever intended to be enabled, it serves to allow +// experiments to check the assumptions about when this kind of revisit is +// necessary. +static cl::opt<bool> UnrollRevisitChildLoops( +    "unroll-revisit-child-loops", cl::Hidden, +    cl::desc("Enqueue and re-visit child loops in the loop PM after unrolling. " +             "This shouldn't typically be needed as child loops (or their " +             "clones) were already visited.")); +  /// A magic value for use with the Threshold parameter to indicate  /// that the loop unroll should be performed regardless of how much  /// code expansion would result. @@ -118,16 +131,17 @@ static const unsigned NoThreshold = UINT_MAX;  /// Gather the various unrolling parameters based on the defaults, compiler  /// flags, TTI overrides and user specified parameters.  static TargetTransformInfo::UnrollingPreferences gatherUnrollingPreferences( -    Loop *L, const TargetTransformInfo &TTI, Optional<unsigned> UserThreshold, -    Optional<unsigned> UserCount, Optional<bool> UserAllowPartial, -    Optional<bool> UserRuntime, Optional<bool> UserUpperBound) { +    Loop *L, const TargetTransformInfo &TTI, int OptLevel, +    Optional<unsigned> UserThreshold, Optional<unsigned> UserCount, +    Optional<bool> UserAllowPartial, Optional<bool> UserRuntime, +    Optional<bool> UserUpperBound) {    TargetTransformInfo::UnrollingPreferences UP;    // Set up the defaults -  UP.Threshold = 150; +  UP.Threshold = OptLevel > 2 ? 
300 : 150;    UP.MaxPercentThresholdBoost = 400;    UP.OptSizeThreshold = 0; -  UP.PartialThreshold = UP.Threshold; +  UP.PartialThreshold = 150;    UP.PartialOptSizeThreshold = 0;    UP.Count = 0;    UP.PeelCount = 0; @@ -141,7 +155,7 @@ static TargetTransformInfo::UnrollingPreferences gatherUnrollingPreferences(    UP.AllowExpensiveTripCount = false;    UP.Force = false;    UP.UpperBound = false; -  UP.AllowPeeling = false; +  UP.AllowPeeling = true;    // Override with any target specific settings    TTI.getUnrollingPreferences(L, UP); @@ -153,10 +167,10 @@ static TargetTransformInfo::UnrollingPreferences gatherUnrollingPreferences(    }    // Apply any user values specified by cl::opt -  if (UnrollThreshold.getNumOccurrences() > 0) { +  if (UnrollThreshold.getNumOccurrences() > 0)      UP.Threshold = UnrollThreshold; -    UP.PartialThreshold = UnrollThreshold; -  } +  if (UnrollPartialThreshold.getNumOccurrences() > 0) +    UP.PartialThreshold = UnrollPartialThreshold;    if (UnrollMaxPercentThresholdBoost.getNumOccurrences() > 0)      UP.MaxPercentThresholdBoost = UnrollMaxPercentThresholdBoost;    if (UnrollMaxCount.getNumOccurrences() > 0) @@ -495,7 +509,7 @@ analyzeLoopUnrollCost(const Loop *L, unsigned TripCount, DominatorTree &DT,              KnownSucc = SI->getSuccessor(0);            else if (ConstantInt *SimpleCondVal =                         dyn_cast<ConstantInt>(SimpleCond)) -            KnownSucc = SI->findCaseValue(SimpleCondVal).getCaseSuccessor(); +            KnownSucc = SI->findCaseValue(SimpleCondVal)->getCaseSuccessor();          }        }        if (KnownSucc) { @@ -770,7 +784,15 @@ static bool computeUnrollCount(      }    } -  // 4rd priority is partial unrolling. +  // 4th priority is loop peeling +  computePeelCount(L, LoopSize, UP, TripCount); +  if (UP.PeelCount) { +    UP.Runtime = false; +    UP.Count = 1; +    return ExplicitUnroll; +  } + +  // 5th priority is partial unrolling.    // Try partial unroll only when TripCount could be staticaly calculated.    if (TripCount) {      UP.Partial |= ExplicitUnroll; @@ -833,14 +855,6 @@ static bool computeUnrollCount(          << "Unable to fully unroll loop as directed by unroll(full) pragma "             "because loop has a runtime trip count."); -  // 5th priority is loop peeling -  computePeelCount(L, LoopSize, UP); -  if (UP.PeelCount) { -    UP.Runtime = false; -    UP.Count = 1; -    return ExplicitUnroll; -  } -    // 6th priority is runtime unrolling.    // Don't unroll a runtime trip count loop when it is disabled.    
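As an aside on the retuned defaults above: the full-unroll threshold now depends on the optimization level, the partial threshold gets its own default and its own -unroll-partial-threshold flag (previously -unroll-threshold drove both), and peeling is enabled by default. A minimal sketch of just that default/override layering, with invented names (UnrollPrefs, gatherPrefs) standing in for the real UnrollingPreferences plumbing:

// Hypothetical stand-alone model of the new defaults and flag overrides; the
// real logic lives in gatherUnrollingPreferences. A negative user value means
// "flag not given on the command line".
struct UnrollPrefs {
  unsigned Threshold;        // full-unroll cost threshold
  unsigned PartialThreshold; // partial-unroll cost threshold
  bool AllowPeeling;
};

UnrollPrefs gatherPrefs(int OptLevel, int UserThreshold,
                        int UserPartialThreshold) {
  UnrollPrefs P;
  P.Threshold = OptLevel > 2 ? 300 : 150; // more aggressive full unrolling at -O3
  P.PartialThreshold = 150;               // no longer tied to Threshold
  P.AllowPeeling = true;                  // peeling is now on by default
  if (UserThreshold >= 0)
    P.Threshold = unsigned(UserThreshold);               // -unroll-threshold
  if (UserPartialThreshold >= 0)
    P.PartialThreshold = unsigned(UserPartialThreshold); // -unroll-partial-threshold
  return P;
}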
if (HasRuntimeUnrollDisablePragma(L)) { @@ -914,7 +928,7 @@ static bool computeUnrollCount(  static bool tryToUnrollLoop(Loop *L, DominatorTree &DT, LoopInfo *LI,                              ScalarEvolution *SE, const TargetTransformInfo &TTI,                              AssumptionCache &AC, OptimizationRemarkEmitter &ORE, -                            bool PreserveLCSSA, +                            bool PreserveLCSSA, int OptLevel,                              Optional<unsigned> ProvidedCount,                              Optional<unsigned> ProvidedThreshold,                              Optional<bool> ProvidedAllowPartial, @@ -934,7 +948,7 @@ static bool tryToUnrollLoop(Loop *L, DominatorTree &DT, LoopInfo *LI,    bool NotDuplicatable;    bool Convergent;    TargetTransformInfo::UnrollingPreferences UP = gatherUnrollingPreferences( -      L, TTI, ProvidedThreshold, ProvidedCount, ProvidedAllowPartial, +      L, TTI, OptLevel, ProvidedThreshold, ProvidedCount, ProvidedAllowPartial,        ProvidedRuntime, ProvidedUpperBound);    // Exit early if unrolling is disabled.    if (UP.Threshold == 0 && (!UP.Partial || UP.PartialThreshold == 0)) @@ -1034,16 +1048,17 @@ namespace {  class LoopUnroll : public LoopPass {  public:    static char ID; // Pass ID, replacement for typeid -  LoopUnroll(Optional<unsigned> Threshold = None, +  LoopUnroll(int OptLevel = 2, Optional<unsigned> Threshold = None,               Optional<unsigned> Count = None,               Optional<bool> AllowPartial = None, Optional<bool> Runtime = None,               Optional<bool> UpperBound = None) -      : LoopPass(ID), ProvidedCount(std::move(Count)), +      : LoopPass(ID), OptLevel(OptLevel), ProvidedCount(std::move(Count)),          ProvidedThreshold(Threshold), ProvidedAllowPartial(AllowPartial),          ProvidedRuntime(Runtime), ProvidedUpperBound(UpperBound) {      initializeLoopUnrollPass(*PassRegistry::getPassRegistry());    } +  int OptLevel;    Optional<unsigned> ProvidedCount;    Optional<unsigned> ProvidedThreshold;    Optional<bool> ProvidedAllowPartial; @@ -1068,7 +1083,7 @@ public:      OptimizationRemarkEmitter ORE(&F);      bool PreserveLCSSA = mustPreserveAnalysisID(LCSSAID); -    return tryToUnrollLoop(L, DT, LI, SE, TTI, AC, ORE, PreserveLCSSA, +    return tryToUnrollLoop(L, DT, LI, SE, TTI, AC, ORE, PreserveLCSSA, OptLevel,                             ProvidedCount, ProvidedThreshold,                             ProvidedAllowPartial, ProvidedRuntime,                             ProvidedUpperBound); @@ -1094,26 +1109,27 @@ INITIALIZE_PASS_DEPENDENCY(LoopPass)  INITIALIZE_PASS_DEPENDENCY(TargetTransformInfoWrapperPass)  INITIALIZE_PASS_END(LoopUnroll, "loop-unroll", "Unroll loops", false, false) -Pass *llvm::createLoopUnrollPass(int Threshold, int Count, int AllowPartial, -                                 int Runtime, int UpperBound) { +Pass *llvm::createLoopUnrollPass(int OptLevel, int Threshold, int Count, +                                 int AllowPartial, int Runtime, +                                 int UpperBound) {    // TODO: It would make more sense for this function to take the optionals    // directly, but that's dangerous since it would silently break out of tree    // callers. -  return new LoopUnroll(Threshold == -1 ? None : Optional<unsigned>(Threshold), -                        Count == -1 ? None : Optional<unsigned>(Count), -                        AllowPartial == -1 ? 
None -                                           : Optional<bool>(AllowPartial), -                        Runtime == -1 ? None : Optional<bool>(Runtime), -                        UpperBound == -1 ? None : Optional<bool>(UpperBound)); +  return new LoopUnroll( +      OptLevel, Threshold == -1 ? None : Optional<unsigned>(Threshold), +      Count == -1 ? None : Optional<unsigned>(Count), +      AllowPartial == -1 ? None : Optional<bool>(AllowPartial), +      Runtime == -1 ? None : Optional<bool>(Runtime), +      UpperBound == -1 ? None : Optional<bool>(UpperBound));  } -Pass *llvm::createSimpleLoopUnrollPass() { -  return llvm::createLoopUnrollPass(-1, -1, 0, 0, 0); +Pass *llvm::createSimpleLoopUnrollPass(int OptLevel) { +  return llvm::createLoopUnrollPass(OptLevel, -1, -1, 0, 0, 0);  }  PreservedAnalyses LoopUnrollPass::run(Loop &L, LoopAnalysisManager &AM,                                        LoopStandardAnalysisResults &AR, -                                      LPMUpdater &) { +                                      LPMUpdater &Updater) {    const auto &FAM =        AM.getResult<FunctionAnalysisManagerLoopProxy>(L, AR).getManager();    Function *F = L.getHeader()->getParent(); @@ -1124,12 +1140,84 @@ PreservedAnalyses LoopUnrollPass::run(Loop &L, LoopAnalysisManager &AM,      report_fatal_error("LoopUnrollPass: OptimizationRemarkEmitterAnalysis not "                         "cached at a higher level"); -  bool Changed = tryToUnrollLoop(&L, AR.DT, &AR.LI, &AR.SE, AR.TTI, AR.AC, *ORE, -                                 /*PreserveLCSSA*/ true, ProvidedCount, -                                 ProvidedThreshold, ProvidedAllowPartial, -                                 ProvidedRuntime, ProvidedUpperBound); - +  // Keep track of the previous loop structure so we can identify new loops +  // created by unrolling. +  Loop *ParentL = L.getParentLoop(); +  SmallPtrSet<Loop *, 4> OldLoops; +  if (ParentL) +    OldLoops.insert(ParentL->begin(), ParentL->end()); +  else +    OldLoops.insert(AR.LI.begin(), AR.LI.end()); + +  // The API here is quite complex to call, but there are only two interesting +  // states we support: partial and full (or "simple") unrolling. However, to +  // enable these things we actually pass "None" in for the optional to avoid +  // providing an explicit choice. +  Optional<bool> AllowPartialParam, RuntimeParam, UpperBoundParam; +  if (!AllowPartialUnrolling) +    AllowPartialParam = RuntimeParam = UpperBoundParam = false; +  bool Changed = tryToUnrollLoop( +      &L, AR.DT, &AR.LI, &AR.SE, AR.TTI, AR.AC, *ORE, +      /*PreserveLCSSA*/ true, OptLevel, /*Count*/ None, +      /*Threshold*/ None, AllowPartialParam, RuntimeParam, UpperBoundParam);    if (!Changed)      return PreservedAnalyses::all(); + +  // The parent must not be damaged by unrolling! +#ifndef NDEBUG +  if (ParentL) +    ParentL->verifyLoop(); +#endif + +  // Unrolling can do several things to introduce new loops into a loop nest: +  // - Partial unrolling clones child loops within the current loop. If it +  //   uses a remainder, then it can also create any number of sibling loops. +  // - Full unrolling clones child loops within the current loop but then +  //   removes the current loop making all of the children appear to be new +  //   sibling loops. +  // - Loop peeling can directly introduce new sibling loops by peeling one +  //   iteration. 
+  // +  // When a new loop appears as a sibling loop, either from peeling an +  // iteration or fully unrolling, its nesting structure has fundamentally +  // changed and we want to revisit it to reflect that. +  // +  // When unrolling has removed the current loop, we need to tell the +  // infrastructure that it is gone. +  // +  // Finally, we support a debugging/testing mode where we revisit child loops +  // as well. These are not expected to require further optimizations as either +  // they or the loop they were cloned from have been directly visited already. +  // But the debugging mode allows us to check this assumption. +  bool IsCurrentLoopValid = false; +  SmallVector<Loop *, 4> SibLoops; +  if (ParentL) +    SibLoops.append(ParentL->begin(), ParentL->end()); +  else +    SibLoops.append(AR.LI.begin(), AR.LI.end()); +  erase_if(SibLoops, [&](Loop *SibLoop) { +    if (SibLoop == &L) { +      IsCurrentLoopValid = true; +      return true; +    } + +    // Otherwise erase the loop from the list if it was in the old loops. +    return OldLoops.count(SibLoop) != 0; +  }); +  Updater.addSiblingLoops(SibLoops); + +  if (!IsCurrentLoopValid) { +    Updater.markLoopAsDeleted(L); +  } else { +    // We can only walk child loops if the current loop remained valid. +    if (UnrollRevisitChildLoops) { +      // Walk *all* of the child loops. This is a highly speculative mode +      // anyways so look for any simplifications that arose from partial +      // unrolling or peeling off of iterations. +      SmallVector<Loop *, 4> ChildLoops(L.begin(), L.end()); +      Updater.addChildLoops(ChildLoops); +    } +  } +    return getLoopPassPreservedAnalyses();  } diff --git a/lib/Transforms/Scalar/LoopUnswitch.cpp b/lib/Transforms/Scalar/LoopUnswitch.cpp index 76fe91884c7b..a99c9999c619 100644 --- a/lib/Transforms/Scalar/LoopUnswitch.cpp +++ b/lib/Transforms/Scalar/LoopUnswitch.cpp @@ -33,6 +33,7 @@  #include "llvm/Analysis/GlobalsModRef.h"  #include "llvm/Analysis/AssumptionCache.h"  #include "llvm/Analysis/CodeMetrics.h" +#include "llvm/Analysis/DivergenceAnalysis.h"  #include "llvm/Analysis/InstructionSimplify.h"  #include "llvm/Analysis/LoopInfo.h"  #include "llvm/Analysis/LoopPass.h" @@ -47,6 +48,7 @@  #include "llvm/IR/Dominators.h"  #include "llvm/IR/Function.h"  #include "llvm/IR/Instructions.h" +#include "llvm/IR/InstrTypes.h"  #include "llvm/IR/Module.h"  #include "llvm/IR/MDBuilder.h"  #include "llvm/Support/CommandLine.h" @@ -77,19 +79,6 @@ static cl::opt<unsigned>  Threshold("loop-unswitch-threshold", cl::desc("Max loop size to unswitch"),            cl::init(100), cl::Hidden); -static cl::opt<bool> -LoopUnswitchWithBlockFrequency("loop-unswitch-with-block-frequency", -    cl::init(false), cl::Hidden, -    cl::desc("Enable the use of the block frequency analysis to access PGO " -             "heuristics to minimize code growth in cold regions.")); - -static cl::opt<unsigned> -ColdnessThreshold("loop-unswitch-coldness-threshold", cl::init(1), cl::Hidden, -    cl::desc("Coldness threshold in percentage. The loop header frequency " -             "(relative to the entry frequency) is compared with this " -             "threshold to determine if non-trivial unswitching should be " -             "enabled.")); -  namespace {    class LUAnalysisCache { @@ -174,13 +163,6 @@ namespace {      LUAnalysisCache BranchesInfo; -    bool EnabledPGO; - -    // BFI and ColdEntryFreq are only used when PGO and -    // LoopUnswitchWithBlockFrequency are enabled. 
-    BlockFrequencyInfo BFI; -    BlockFrequency ColdEntryFreq; -      bool OptimizeForSize;      bool redoLoop; @@ -199,12 +181,14 @@ namespace {      // NewBlocks contained cloned copy of basic blocks from LoopBlocks.      std::vector<BasicBlock*> NewBlocks; +    bool hasBranchDivergence; +    public:      static char ID; // Pass ID, replacement for typeid -    explicit LoopUnswitch(bool Os = false) : +    explicit LoopUnswitch(bool Os = false, bool hasBranchDivergence = false) :        LoopPass(ID), OptimizeForSize(Os), redoLoop(false),        currentLoop(nullptr), DT(nullptr), loopHeader(nullptr), -      loopPreheader(nullptr) { +      loopPreheader(nullptr), hasBranchDivergence(hasBranchDivergence) {          initializeLoopUnswitchPass(*PassRegistry::getPassRegistry());        } @@ -217,6 +201,8 @@ namespace {      void getAnalysisUsage(AnalysisUsage &AU) const override {        AU.addRequired<AssumptionCacheTracker>();        AU.addRequired<TargetTransformInfoWrapperPass>(); +      if (hasBranchDivergence) +        AU.addRequired<DivergenceAnalysis>();        getLoopAnalysisUsage(AU);      } @@ -255,6 +241,11 @@ namespace {                                          TerminatorInst *TI);      void SimplifyCode(std::vector<Instruction*> &Worklist, Loop *L); + +    /// Given that the Invariant is not equal to Val. Simplify instructions +    /// in the loop. +    Value *SimplifyInstructionWithNotEqual(Instruction *Inst, Value *Invariant, +                                           Constant *Val);    };  } @@ -381,16 +372,35 @@ INITIALIZE_PASS_BEGIN(LoopUnswitch, "loop-unswitch", "Unswitch loops",  INITIALIZE_PASS_DEPENDENCY(AssumptionCacheTracker)  INITIALIZE_PASS_DEPENDENCY(LoopPass)  INITIALIZE_PASS_DEPENDENCY(TargetTransformInfoWrapperPass) +INITIALIZE_PASS_DEPENDENCY(DivergenceAnalysis)  INITIALIZE_PASS_END(LoopUnswitch, "loop-unswitch", "Unswitch loops",                        false, false) -Pass *llvm::createLoopUnswitchPass(bool Os) { -  return new LoopUnswitch(Os); +Pass *llvm::createLoopUnswitchPass(bool Os, bool hasBranchDivergence) { +  return new LoopUnswitch(Os, hasBranchDivergence);  } +/// Operator chain lattice. +enum OperatorChain { +  OC_OpChainNone,    ///< There is no operator. +  OC_OpChainOr,      ///< There are only ORs. +  OC_OpChainAnd,     ///< There are only ANDs. +  OC_OpChainMixed    ///< There are ANDs and ORs. +}; +  /// Cond is a condition that occurs in L. If it is invariant in the loop, or has  /// an invariant piece, return the invariant. Otherwise, return null. +// +/// NOTE: FindLIVLoopCondition will not return a partial LIV by walking up a +/// mixed operator chain, as we can not reliably find a value which will simplify +/// the operator chain. If the chain is AND-only or OR-only, we can use 0 or ~0 +/// to simplify the chain. +/// +/// NOTE: In case a partial LIV and a mixed operator chain, we may be able to +/// simplify the condition itself to a loop variant condition, but at the +/// cost of creating an entirely new loop.  static Value *FindLIVLoopCondition(Value *Cond, Loop *L, bool &Changed, +                                   OperatorChain &ParentChain,                                     DenseMap<Value *, Value *> &Cache) {    auto CacheIt = Cache.find(Cond);    if (CacheIt != Cache.end()) @@ -414,21 +424,53 @@ static Value *FindLIVLoopCondition(Value *Cond, Loop *L, bool &Changed,      return Cond;    } +  // Walk up the operator chain to find partial invariant conditions.    
if (BinaryOperator *BO = dyn_cast<BinaryOperator>(Cond))      if (BO->getOpcode() == Instruction::And ||          BO->getOpcode() == Instruction::Or) { -      // If either the left or right side is invariant, we can unswitch on this, -      // which will cause the branch to go away in one loop and the condition to -      // simplify in the other one. -      if (Value *LHS = -              FindLIVLoopCondition(BO->getOperand(0), L, Changed, Cache)) { -        Cache[Cond] = LHS; -        return LHS; +      // Given the previous operator, compute the current operator chain status. +      OperatorChain NewChain; +      switch (ParentChain) { +      case OC_OpChainNone: +        NewChain = BO->getOpcode() == Instruction::And ? OC_OpChainAnd : +                                      OC_OpChainOr; +        break; +      case OC_OpChainOr: +        NewChain = BO->getOpcode() == Instruction::Or ? OC_OpChainOr : +                                      OC_OpChainMixed; +        break; +      case OC_OpChainAnd: +        NewChain = BO->getOpcode() == Instruction::And ? OC_OpChainAnd : +                                      OC_OpChainMixed; +        break; +      case OC_OpChainMixed: +        NewChain = OC_OpChainMixed; +        break;        } -      if (Value *RHS = -              FindLIVLoopCondition(BO->getOperand(1), L, Changed, Cache)) { -        Cache[Cond] = RHS; -        return RHS; + +      // If we reach a Mixed state, we do not want to keep walking up as we can not +      // reliably find a value that will simplify the chain. With this check, we +      // will return null on the first sight of mixed chain and the caller will +      // either backtrack to find partial LIV in other operand or return null. +      if (NewChain != OC_OpChainMixed) { +        // Update the current operator chain type before we search up the chain. +        ParentChain = NewChain; +        // If either the left or right side is invariant, we can unswitch on this, +        // which will cause the branch to go away in one loop and the condition to +        // simplify in the other one. +        if (Value *LHS = FindLIVLoopCondition(BO->getOperand(0), L, Changed, +                                              ParentChain, Cache)) { +          Cache[Cond] = LHS; +          return LHS; +        } +        // We did not manage to find a partial LIV in operand(0). Backtrack and try +        // operand(1). +        ParentChain = NewChain; +        if (Value *RHS = FindLIVLoopCondition(BO->getOperand(1), L, Changed, +                                              ParentChain, Cache)) { +          Cache[Cond] = RHS; +          return RHS; +        }        }      } @@ -436,9 +478,21 @@ static Value *FindLIVLoopCondition(Value *Cond, Loop *L, bool &Changed,    return nullptr;  } -static Value *FindLIVLoopCondition(Value *Cond, Loop *L, bool &Changed) { +/// Cond is a condition that occurs in L. If it is invariant in the loop, or has +/// an invariant piece, return the invariant along with the operator chain type. +/// Otherwise, return null. 
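As an aside before the two-result FindLIVLoopCondition wrapper that follows: the OperatorChain lattice above only needs to remember whether every operator seen so far was an AND, an OR, or a mix, because only a single-operator chain can be collapsed by substituting one constant for the invariant piece. A minimal standalone restatement of that state update (updateChain and its bool parameter are invented for the sketch):

// Illustrative restatement of the OperatorChain transition performed while
// walking up an and/or chain; a mixed chain is a dead end for partial
// unswitching.
enum OperatorChain {
  OC_OpChainNone,  // no operator seen yet
  OC_OpChainOr,    // only ORs so far
  OC_OpChainAnd,   // only ANDs so far
  OC_OpChainMixed  // both ANDs and ORs: stop the walk
};

OperatorChain updateChain(OperatorChain Parent, bool OpIsAnd) {
  switch (Parent) {
  case OC_OpChainNone:  return OpIsAnd ? OC_OpChainAnd : OC_OpChainOr;
  case OC_OpChainAnd:   return OpIsAnd ? OC_OpChainAnd : OC_OpChainMixed;
  case OC_OpChainOr:    return OpIsAnd ? OC_OpChainMixed : OC_OpChainOr;
  case OC_OpChainMixed: break;
  }
  return OC_OpChainMixed;
}

For an AND-only chain such as (LIV & a & b) the result is known once LIV is 0, and for an OR-only chain once LIV is ~0; for a mixed chain like ((LIV & a) | b) no single constant for LIV settles the result, which is why the walk gives up at OC_OpChainMixed.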
+static std::pair<Value *, OperatorChain> FindLIVLoopCondition(Value *Cond, +                                                              Loop *L, +                                                              bool &Changed) {    DenseMap<Value *, Value *> Cache; -  return FindLIVLoopCondition(Cond, L, Changed, Cache); +  OperatorChain OpChain = OC_OpChainNone; +  Value *FCond = FindLIVLoopCondition(Cond, L, Changed, OpChain, Cache); + +  // In case we do find a LIV, it can not be obtained by walking up a mixed +  // operator chain. +  assert((!FCond || OpChain != OC_OpChainMixed) && +        "Do not expect a partial LIV with mixed operator chain"); +  return {FCond, OpChain};  }  bool LoopUnswitch::runOnLoop(Loop *L, LPPassManager &LPM_Ref) { @@ -457,19 +511,6 @@ bool LoopUnswitch::runOnLoop(Loop *L, LPPassManager &LPM_Ref) {    if (SanitizeMemory)      computeLoopSafetyInfo(&SafetyInfo, L); -  EnabledPGO = F->getEntryCount().hasValue(); - -  if (LoopUnswitchWithBlockFrequency && EnabledPGO) { -    BranchProbabilityInfo BPI(*F, *LI); -    BFI.calculate(*L->getHeader()->getParent(), BPI, *LI); - -    // Use BranchProbability to compute a minimum frequency based on -    // function entry baseline frequency. Loops with headers below this -    // frequency are considered as cold. -    const BranchProbability ColdProb(ColdnessThreshold, 100); -    ColdEntryFreq = BlockFrequency(BFI.getEntryFreq()) * ColdProb; -  } -    bool Changed = false;    do {      assert(currentLoop->isLCSSAForm(*DT)); @@ -581,19 +622,9 @@ bool LoopUnswitch::processCurrentLoop() {        loopHeader->getParent()->hasFnAttribute(Attribute::OptimizeForSize))      return false; -  if (LoopUnswitchWithBlockFrequency && EnabledPGO) { -    // Compute the weighted frequency of the hottest block in the -    // loop (loopHeader in this case since inner loops should be -    // processed before outer loop). If it is less than ColdFrequency, -    // we should not unswitch. -    BlockFrequency LoopEntryFreq = BFI.getBlockFreq(loopHeader); -    if (LoopEntryFreq < ColdEntryFreq) -      return false; -  } -    for (IntrinsicInst *Guard : Guards) {      Value *LoopCond = -        FindLIVLoopCondition(Guard->getOperand(0), currentLoop, Changed); +        FindLIVLoopCondition(Guard->getOperand(0), currentLoop, Changed).first;      if (LoopCond &&          UnswitchIfProfitable(LoopCond, ConstantInt::getTrue(Context))) {        // NB! Unswitching (if successful) could have erased some of the @@ -634,7 +665,7 @@ bool LoopUnswitch::processCurrentLoop() {          // See if this, or some part of it, is loop invariant.  If so, we can          // unswitch on it if we desire.          
Value *LoopCond = FindLIVLoopCondition(BI->getCondition(), -                                               currentLoop, Changed); +                                               currentLoop, Changed).first;          if (LoopCond &&              UnswitchIfProfitable(LoopCond, ConstantInt::getTrue(Context), TI)) {            ++NumBranches; @@ -642,24 +673,48 @@ bool LoopUnswitch::processCurrentLoop() {          }        }      } else if (SwitchInst *SI = dyn_cast<SwitchInst>(TI)) { -      Value *LoopCond = FindLIVLoopCondition(SI->getCondition(), -                                             currentLoop, Changed); +      Value *SC = SI->getCondition(); +      Value *LoopCond; +      OperatorChain OpChain; +      std::tie(LoopCond, OpChain) = +        FindLIVLoopCondition(SC, currentLoop, Changed); +        unsigned NumCases = SI->getNumCases();        if (LoopCond && NumCases) {          // Find a value to unswitch on:          // FIXME: this should chose the most expensive case!          // FIXME: scan for a case with a non-critical edge?          Constant *UnswitchVal = nullptr; - -        // Do not process same value again and again. -        // At this point we have some cases already unswitched and -        // some not yet unswitched. Let's find the first not yet unswitched one. -        for (SwitchInst::CaseIt i = SI->case_begin(), e = SI->case_end(); -             i != e; ++i) { -          Constant *UnswitchValCandidate = i.getCaseValue(); -          if (!BranchesInfo.isUnswitched(SI, UnswitchValCandidate)) { -            UnswitchVal = UnswitchValCandidate; -            break; +        // Find a case value such that at least one case value is unswitched +        // out. +        if (OpChain == OC_OpChainAnd) { +          // If the chain only has ANDs and the switch has a case value of 0. +          // Dropping in a 0 to the chain will unswitch out the 0-casevalue. +          auto *AllZero = cast<ConstantInt>(Constant::getNullValue(SC->getType())); +          if (BranchesInfo.isUnswitched(SI, AllZero)) +            continue; +          // We are unswitching 0 out. +          UnswitchVal = AllZero; +        } else if (OpChain == OC_OpChainOr) { +          // If the chain only has ORs and the switch has a case value of ~0. +          // Dropping in a ~0 to the chain will unswitch out the ~0-casevalue. +          auto *AllOne = cast<ConstantInt>(Constant::getAllOnesValue(SC->getType())); +          if (BranchesInfo.isUnswitched(SI, AllOne)) +            continue; +          // We are unswitching ~0 out. +          UnswitchVal = AllOne; +        } else { +          assert(OpChain == OC_OpChainNone &&  +                 "Expect to unswitch on trivial chain"); +          // Do not process same value again and again. +          // At this point we have some cases already unswitched and +          // some not yet unswitched. Let's find the first not yet unswitched one. +          for (auto Case : SI->cases()) { +            Constant *UnswitchValCandidate = Case.getCaseValue(); +            if (!BranchesInfo.isUnswitched(SI, UnswitchValCandidate)) { +              UnswitchVal = UnswitchValCandidate; +              break; +            }            }          } @@ -668,6 +723,11 @@ bool LoopUnswitch::processCurrentLoop() {          if (UnswitchIfProfitable(LoopCond, UnswitchVal)) {            ++NumSwitches; +          // In case of a full LIV, UnswitchVal is the value we unswitched out. +          // In case of a partial LIV, we only unswitch when its an AND-chain +          // or OR-chain. 
In both cases switch input value simplifies to +          // UnswitchVal. +          BranchesInfo.setUnswitched(SI, UnswitchVal);            return true;          }        } @@ -678,7 +738,7 @@ bool LoopUnswitch::processCurrentLoop() {           BBI != E; ++BBI)        if (SelectInst *SI = dyn_cast<SelectInst>(BBI)) {          Value *LoopCond = FindLIVLoopCondition(SI->getCondition(), -                                               currentLoop, Changed); +                                               currentLoop, Changed).first;          if (LoopCond && UnswitchIfProfitable(LoopCond,                                               ConstantInt::getTrue(Context))) {            ++NumSelects; @@ -753,6 +813,15 @@ bool LoopUnswitch::UnswitchIfProfitable(Value *LoopCond, Constant *Val,                   << ". Cost too high.\n");      return false;    } +  if (hasBranchDivergence && +      getAnalysis<DivergenceAnalysis>().isDivergent(LoopCond)) { +    DEBUG(dbgs() << "NOT unswitching loop %" +                 << currentLoop->getHeader()->getName() +                 << " at non-trivial condition '" << *Val +                 << "' == " << *LoopCond << "\n" +                 << ". Condition is divergent.\n"); +    return false; +  }    UnswitchNontrivialCondition(LoopCond, Val, currentLoop, TI);    return true; @@ -899,7 +968,6 @@ bool LoopUnswitch::TryTrivialLoopUnswitch(bool &Changed) {        if (I.mayHaveSideEffects())          return false; -    // FIXME: add check for constant foldable switch instructions.      if (BranchInst *BI = dyn_cast<BranchInst>(CurrentTerm)) {        if (BI->isUnconditional()) {          CurrentBB = BI->getSuccessor(0); @@ -911,7 +979,16 @@ bool LoopUnswitch::TryTrivialLoopUnswitch(bool &Changed) {          // Found a trivial condition candidate: non-foldable conditional branch.          break;        } +    } else if (SwitchInst *SI = dyn_cast<SwitchInst>(CurrentTerm)) { +      // At this point, any constant-foldable instructions should have probably +      // been folded. +      ConstantInt *Cond = dyn_cast<ConstantInt>(SI->getCondition()); +      if (!Cond) +        break; +      // Find the target block we are definitely going to. +      CurrentBB = SI->findCaseValue(Cond)->getCaseSuccessor();      } else { +      // We do not understand these terminator instructions.        break;      } @@ -929,7 +1006,7 @@ bool LoopUnswitch::TryTrivialLoopUnswitch(bool &Changed) {        return false;      Value *LoopCond = FindLIVLoopCondition(BI->getCondition(), -                                           currentLoop, Changed); +                                           currentLoop, Changed).first;      // Unswitch only if the trivial condition itself is an LIV (not      // partial LIV which could occur in and/or) @@ -960,7 +1037,7 @@ bool LoopUnswitch::TryTrivialLoopUnswitch(bool &Changed) {    } else if (SwitchInst *SI = dyn_cast<SwitchInst>(CurrentTerm)) {      // If this isn't switching on an invariant condition, we can't unswitch it.      Value *LoopCond = FindLIVLoopCondition(SI->getCondition(), -                                           currentLoop, Changed); +                                           currentLoop, Changed).first;      // Unswitch only if the trivial condition itself is an LIV (not      // partial LIV which could occur in and/or) @@ -973,13 +1050,12 @@ bool LoopUnswitch::TryTrivialLoopUnswitch(bool &Changed) {      // this.      // Note that we can't trivially unswitch on the default case or      // on already unswitched cases. 
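A short aside on the case-value selection added above: for an AND-only chain the pass unswitches on 0 and for an OR-only chain on ~0, since either constant decides the whole chain and therefore the switch operand; a full LIV falls back to scanning for a case that has not been unswitched yet. A hedged sketch of only that choice, with invented names (ChainKind, pickUnswitchValue):

// Illustrative only: how the unswitch constant falls out of the chain kind.
#include <cstdint>

enum class ChainKind { None, And, Or, Mixed };

// Returns true and sets Val when a single constant settles the whole chain.
// BitWidth is assumed to be in [1, 64] for this toy version.
bool pickUnswitchValue(ChainKind K, unsigned BitWidth, uint64_t &Val) {
  switch (K) {
  case ChainKind::And:
    Val = 0;                               // LIV == 0  => chain == 0
    return true;
  case ChainKind::Or:
    Val = ~uint64_t(0) >> (64 - BitWidth); // LIV == ~0 => chain == ~0
    return true;
  case ChainKind::None:  // full LIV: scan for a not-yet-unswitched case value
  case ChainKind::Mixed: // no single constant settles a mixed chain
    return false;
  }
  return false;
}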
-    for (SwitchInst::CaseIt i = SI->case_begin(), e = SI->case_end();
-         i != e; ++i) {
+    for (auto Case : SI->cases()) {
       BasicBlock *LoopExitCandidate;
-      if ((LoopExitCandidate = isTrivialLoopExitBlock(currentLoop,
-                                               i.getCaseSuccessor()))) {
+      if ((LoopExitCandidate =
+               isTrivialLoopExitBlock(currentLoop, Case.getCaseSuccessor()))) {
         // Okay, we found a trivial case, remember the value that is trivial.
-        ConstantInt *CaseVal = i.getCaseValue();
+        ConstantInt *CaseVal = Case.getCaseValue();
         // Check that it was not unswitched before, since already unswitched
         // trivial vals are looks trivial too.
@@ -998,6 +1074,9 @@ bool LoopUnswitch::TryTrivialLoopUnswitch(bool &Changed) {
     UnswitchTrivialCondition(currentLoop, LoopCond, CondVal, LoopExitBB,
                              nullptr);
+
+    // We are only unswitching full LIV.
+    BranchesInfo.setUnswitched(SI, CondVal);
     ++NumSwitches;
     return true;
   }
@@ -1253,18 +1332,38 @@ void LoopUnswitch::RewriteLoopBodyWithConditionConstant(Loop *L, Value *LIC,
     if (!UI || !L->contains(UI))
       continue;
-    Worklist.push_back(UI);
+    // At this point, we know LIC is definitely not Val. Try to use some simple
+    // logic to simplify the user w.r.t. the context.
+    if (Value *Replacement = SimplifyInstructionWithNotEqual(UI, LIC, Val)) {
+      if (LI->replacementPreservesLCSSAForm(UI, Replacement)) {
+        // This in-loop instruction has been simplified w.r.t. its context,
+        // i.e. LIC != Val, so make sure we propagate its replacement value to
+        // all its users.
+        //
+        // We cannot delete UI, the LIC user, yet, because that would
+        // invalidate the LIC->users() iterator. However, we can make this
+        // instruction dead by replacing all its users and pushing it onto the
+        // worklist so that it can be properly deleted and its operands
+        // simplified.
+        UI->replaceAllUsesWith(Replacement);
+      }
+    }
-    // TODO: We could do other simplifications, for example, turning
-    // 'icmp eq LIC, Val' -> false.
+    // This is a LIC user, push it into the worklist so that SimplifyCode can
+    // attempt to simplify it.
+    Worklist.push_back(UI);
     // If we know that LIC is not Val, use this info to simplify code.
     SwitchInst *SI = dyn_cast<SwitchInst>(UI);
     if (!SI || !isa<ConstantInt>(Val)) continue;
-    SwitchInst::CaseIt DeadCase = SI->findCaseValue(cast<ConstantInt>(Val));
+    // NOTE: if a case value for the switch is unswitched out, we record it
+    // after the unswitch finishes. We cannot record it here as the switch
+    // is not a direct user of the partial LIV.
+    SwitchInst::CaseHandle DeadCase =
+        *SI->findCaseValue(cast<ConstantInt>(Val));
     // Default case is live for multiple values.
-    if (DeadCase == SI->case_default()) continue;
+    if (DeadCase == *SI->case_default())
+      continue;
     // Found a dead case value.  Don't remove PHI nodes in the
     // successor if they become single-entry, those PHI nodes may
@@ -1274,8 +1373,6 @@ void LoopUnswitch::RewriteLoopBodyWithConditionConstant(Loop *L, Value *LIC,
     BasicBlock *SISucc = DeadCase.getCaseSuccessor();
     BasicBlock *Latch = L->getLoopLatch();
-    BranchesInfo.setUnswitched(SI, Val);
-
     if (!SI->findCaseDest(SISucc)) continue;  // Edge is critical.
// If the DeadCase successor dominates the loop latch, then the      // transformation isn't safe since it will delete the sole predecessor edge @@ -1397,3 +1494,27 @@ void LoopUnswitch::SimplifyCode(std::vector<Instruction*> &Worklist, Loop *L) {      }    }  } + +/// Simple simplifications we can do given the information that Cond is +/// definitely not equal to Val. +Value *LoopUnswitch::SimplifyInstructionWithNotEqual(Instruction *Inst, +                                                     Value *Invariant, +                                                     Constant *Val) { +  // icmp eq cond, val -> false +  ICmpInst *CI = dyn_cast<ICmpInst>(Inst); +  if (CI && CI->isEquality()) { +    Value *Op0 = CI->getOperand(0); +    Value *Op1 = CI->getOperand(1); +    if ((Op0 == Invariant && Op1 == Val) || (Op0 == Val && Op1 == Invariant)) { +      LLVMContext &Ctx = Inst->getContext(); +      if (CI->getPredicate() == CmpInst::ICMP_EQ) +        return ConstantInt::getFalse(Ctx); +      else  +        return ConstantInt::getTrue(Ctx); +     } +  } + +  // FIXME: there may be other opportunities, e.g. comparison with floating +  // point, or Invariant - Val != 0, etc. +  return nullptr; +} diff --git a/lib/Transforms/Scalar/LowerExpectIntrinsic.cpp b/lib/Transforms/Scalar/LowerExpectIntrinsic.cpp index 52975ef35153..a143b9a3c645 100644 --- a/lib/Transforms/Scalar/LowerExpectIntrinsic.cpp +++ b/lib/Transforms/Scalar/LowerExpectIntrinsic.cpp @@ -67,11 +67,11 @@ static bool handleSwitchExpect(SwitchInst &SI) {    if (!ExpectedValue)      return false; -  SwitchInst::CaseIt Case = SI.findCaseValue(ExpectedValue); +  SwitchInst::CaseHandle Case = *SI.findCaseValue(ExpectedValue);    unsigned n = SI.getNumCases(); // +1 for default case.    SmallVector<uint32_t, 16> Weights(n + 1, UnlikelyBranchWeight); -  if (Case == SI.case_default()) +  if (Case == *SI.case_default())      Weights[0] = LikelyBranchWeight;    else      Weights[Case.getCaseIndex() + 1] = LikelyBranchWeight; diff --git a/lib/Transforms/Scalar/MemCpyOptimizer.cpp b/lib/Transforms/Scalar/MemCpyOptimizer.cpp index 1b590140f70a..a3f3f25c1e0f 100644 --- a/lib/Transforms/Scalar/MemCpyOptimizer.cpp +++ b/lib/Transforms/Scalar/MemCpyOptimizer.cpp @@ -12,20 +12,49 @@  //  //===----------------------------------------------------------------------===// -#include "llvm/Transforms/Scalar/MemCpyOptimizer.h" -#include "llvm/Transforms/Scalar.h"  #include "llvm/ADT/DenseSet.h" +#include "llvm/ADT/iterator_range.h"  #include "llvm/ADT/SmallVector.h"  #include "llvm/ADT/Statistic.h" +#include "llvm/ADT/STLExtras.h" +#include "llvm/Analysis/AssumptionCache.h" +#include "llvm/Analysis/GlobalsModRef.h" +#include "llvm/Analysis/MemoryDependenceAnalysis.h" +#include "llvm/Analysis/MemoryLocation.h" +#include "llvm/Analysis/TargetLibraryInfo.h"  #include "llvm/Analysis/ValueTracking.h" +#include "llvm/IR/Argument.h" +#include "llvm/IR/Constants.h"  #include "llvm/IR/DataLayout.h" +#include "llvm/IR/DerivedTypes.h" +#include "llvm/IR/Dominators.h" +#include "llvm/IR/Function.h"  #include "llvm/IR/GetElementPtrTypeIterator.h"  #include "llvm/IR/GlobalVariable.h" +#include "llvm/IR/InstrTypes.h" +#include "llvm/IR/Instruction.h" +#include "llvm/IR/Instructions.h" +#include "llvm/IR/IntrinsicInst.h" +#include "llvm/IR/Intrinsics.h"  #include "llvm/IR/IRBuilder.h" +#include "llvm/IR/LLVMContext.h" +#include "llvm/IR/Module.h" +#include "llvm/IR/Operator.h" +#include "llvm/IR/Type.h" +#include "llvm/IR/User.h" +#include "llvm/IR/Value.h" +#include 
"llvm/Pass.h" +#include "llvm/Support/Casting.h"  #include "llvm/Support/Debug.h" +#include "llvm/Support/MathExtras.h"  #include "llvm/Support/raw_ostream.h" +#include "llvm/Transforms/Scalar.h" +#include "llvm/Transforms/Scalar/MemCpyOptimizer.h"  #include "llvm/Transforms/Utils/Local.h"  #include <algorithm> +#include <cassert> +#include <cstdint> +  using namespace llvm;  #define DEBUG_TYPE "memcpyopt" @@ -119,6 +148,7 @@ static bool IsPointerOffset(Value *Ptr1, Value *Ptr2, int64_t &Offset,    return true;  } +namespace {  /// Represents a range of memset'd bytes with the ByteVal value.  /// This allows us to analyze stores like: @@ -130,7 +160,6 @@ static bool IsPointerOffset(Value *Ptr1, Value *Ptr2, int64_t &Offset,  /// the first store, we make a range [1, 2).  The second store extends the range  /// to [0, 2).  The third makes a new range [2, 3).  The fourth store joins the  /// two ranges into [0, 3) which is memset'able. -namespace {  struct MemsetRange {    // Start/End - A semi range that describes the span that this range covers.    // The range is closed at the start and open at the end: [Start, End). @@ -148,7 +177,8 @@ struct MemsetRange {    bool isProfitableToUseMemset(const DataLayout &DL) const;  }; -} // end anon namespace + +} // end anonymous namespace  bool MemsetRange::isProfitableToUseMemset(const DataLayout &DL) const {    // If we found more than 4 stores to merge or 16 bytes, use memset. @@ -192,13 +222,14 @@ bool MemsetRange::isProfitableToUseMemset(const DataLayout &DL) const {    return TheStores.size() > NumPointerStores+NumByteStores;  } -  namespace { +  class MemsetRanges {    /// A sorted list of the memset ranges.    SmallVector<MemsetRange, 8> Ranges;    typedef SmallVectorImpl<MemsetRange>::iterator range_iterator;    const DataLayout &DL; +  public:    MemsetRanges(const DataLayout &DL) : DL(DL) {} @@ -231,8 +262,7 @@ public:  }; -} // end anon namespace - +} // end anonymous namespace  /// Add a new store to the MemsetRanges data structure.  
This adds a  /// new range for the specified store at the specified offset, merging into @@ -299,48 +329,36 @@ void MemsetRanges::addRange(int64_t Start, int64_t Size, Value *Ptr,  //===----------------------------------------------------------------------===//  namespace { -  class MemCpyOptLegacyPass : public FunctionPass { -    MemCpyOptPass Impl; -  public: -    static char ID; // Pass identification, replacement for typeid -    MemCpyOptLegacyPass() : FunctionPass(ID) { -      initializeMemCpyOptLegacyPassPass(*PassRegistry::getPassRegistry()); -    } -    bool runOnFunction(Function &F) override; - -  private: -    // This transformation requires dominator postdominator info -    void getAnalysisUsage(AnalysisUsage &AU) const override { -      AU.setPreservesCFG(); -      AU.addRequired<AssumptionCacheTracker>(); -      AU.addRequired<DominatorTreeWrapperPass>(); -      AU.addRequired<MemoryDependenceWrapperPass>(); -      AU.addRequired<AAResultsWrapperPass>(); -      AU.addRequired<TargetLibraryInfoWrapperPass>(); -      AU.addPreserved<GlobalsAAWrapperPass>(); -      AU.addPreserved<MemoryDependenceWrapperPass>(); -    } +class MemCpyOptLegacyPass : public FunctionPass { +  MemCpyOptPass Impl; -    // Helper functions -    bool processStore(StoreInst *SI, BasicBlock::iterator &BBI); -    bool processMemSet(MemSetInst *SI, BasicBlock::iterator &BBI); -    bool processMemCpy(MemCpyInst *M); -    bool processMemMove(MemMoveInst *M); -    bool performCallSlotOptzn(Instruction *cpy, Value *cpyDst, Value *cpySrc, -                              uint64_t cpyLen, unsigned cpyAlign, CallInst *C); -    bool processMemCpyMemCpyDependence(MemCpyInst *M, MemCpyInst *MDep); -    bool processMemSetMemCpyDependence(MemCpyInst *M, MemSetInst *MDep); -    bool performMemCpyToMemSetOptzn(MemCpyInst *M, MemSetInst *MDep); -    bool processByValArgument(CallSite CS, unsigned ArgNo); -    Instruction *tryMergingIntoMemset(Instruction *I, Value *StartPtr, -                                      Value *ByteVal); - -    bool iterateOnFunction(Function &F); -  }; +public: +  static char ID; // Pass identification, replacement for typeid -  char MemCpyOptLegacyPass::ID = 0; -} +  MemCpyOptLegacyPass() : FunctionPass(ID) { +    initializeMemCpyOptLegacyPassPass(*PassRegistry::getPassRegistry()); +  } + +  bool runOnFunction(Function &F) override; + +private: +  // This transformation requires dominator postdominator info +  void getAnalysisUsage(AnalysisUsage &AU) const override { +    AU.setPreservesCFG(); +    AU.addRequired<AssumptionCacheTracker>(); +    AU.addRequired<DominatorTreeWrapperPass>(); +    AU.addRequired<MemoryDependenceWrapperPass>(); +    AU.addRequired<AAResultsWrapperPass>(); +    AU.addRequired<TargetLibraryInfoWrapperPass>(); +    AU.addPreserved<GlobalsAAWrapperPass>(); +    AU.addPreserved<MemoryDependenceWrapperPass>(); +  } +}; + +char MemCpyOptLegacyPass::ID = 0; + +} // end anonymous namespace  /// The public interface to this file...  
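As an aside on the MemsetRange machinery whose definition is only re-homed inside the anonymous namespace above: the pass grows and joins half-open [Start, End) byte ranges as it sees nearby stores, and a range that becomes wide enough is turned into a single memset. A simplified, self-contained sketch of that interval merging, using the offsets 1, 0, 3, 2 purely as an illustration (ByteRange and addByte are invented names, not the pass's types):

// Self-contained toy version of the half-open range merging; it tracks
// offsets only, not the stores or the value byte.
#include <algorithm>
#include <cstdint>
#include <cstdio>
#include <vector>

struct ByteRange {
  int64_t Start, End; // closed at Start, open at End: [Start, End)
};

// Add a one-byte store at offset Off, joining any ranges it touches or abuts.
static void addByte(std::vector<ByteRange> &Ranges, int64_t Off) {
  ByteRange New = {Off, Off + 1};
  std::vector<ByteRange> Out;
  for (const ByteRange &R : Ranges) {
    if (R.End < New.Start || R.Start > New.End) {
      Out.push_back(R); // disjoint and not adjacent: keep as-is
    } else {
      New.Start = std::min(New.Start, R.Start); // overlap or abut: absorb
      New.End = std::max(New.End, R.End);
    }
  }
  Out.push_back(New);
  Ranges.swap(Out);
}

int main() {
  std::vector<ByteRange> Ranges;
  for (int64_t Off : {1, 0, 3, 2}) // an out-of-order byte-store pattern
    addByte(Ranges, Off);
  for (const ByteRange &R : Ranges) // ends up as the single range [0, 4)
    std::printf("[%lld, %lld)\n", (long long)R.Start, (long long)R.End);
  return 0;
}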
FunctionPass *llvm::createMemCpyOptPass() { return new MemCpyOptLegacyPass(); } @@ -523,14 +541,15 @@ static bool moveUp(AliasAnalysis &AA, StoreInst *SI, Instruction *P,      if (Args.erase(C))        NeedLift = true;      else if (MayAlias) { -      NeedLift = any_of(MemLocs, [C, &AA](const MemoryLocation &ML) { +      NeedLift = llvm::any_of(MemLocs, [C, &AA](const MemoryLocation &ML) {          return AA.getModRefInfo(C, ML);        });        if (!NeedLift) -        NeedLift = any_of(CallSites, [C, &AA](const ImmutableCallSite &CS) { -          return AA.getModRefInfo(C, CS); -        }); +        NeedLift = +            llvm::any_of(CallSites, [C, &AA](const ImmutableCallSite &CS) { +              return AA.getModRefInfo(C, CS); +            });      }      if (!NeedLift) @@ -567,7 +586,7 @@ static bool moveUp(AliasAnalysis &AA, StoreInst *SI, Instruction *P,    }    // We made it, we need to lift -  for (auto *I : reverse(ToLift)) { +  for (auto *I : llvm::reverse(ToLift)) {      DEBUG(dbgs() << "Lifting " << *I << " before " << *P << "\n");      I->moveBefore(P);    } @@ -761,7 +780,6 @@ bool MemCpyOptPass::processMemSet(MemSetInst *MSI, BasicBlock::iterator &BBI) {    return false;  } -  /// Takes a memcpy and a call that it depends on,  /// and checks for the possibility of a call slot optimization by having  /// the call write its result directly into the destination of the memcpy. @@ -914,6 +932,17 @@ bool MemCpyOptPass::performCallSlotOptzn(Instruction *cpy, Value *cpyDest,    if (MR != MRI_NoModRef)      return false; +  // We can't create address space casts here because we don't know if they're +  // safe for the target. +  if (cpySrc->getType()->getPointerAddressSpace() != +      cpyDest->getType()->getPointerAddressSpace()) +    return false; +  for (unsigned i = 0; i < CS.arg_size(); ++i) +    if (CS.getArgument(i)->stripPointerCasts() == cpySrc && +        cpySrc->getType()->getPointerAddressSpace() != +        CS.getArgument(i)->getType()->getPointerAddressSpace()) +      return false; +    // All the checks have passed, so do the transformation.    bool changedArgument = false;    for (unsigned i = 0; i < CS.arg_size(); ++i) @@ -1240,7 +1269,7 @@ bool MemCpyOptPass::processMemCpy(MemCpyInst *M) {  bool MemCpyOptPass::processMemMove(MemMoveInst *M) {    AliasAnalysis &AA = LookupAliasAnalysis(); -  if (!TLI->has(LibFunc::memmove)) +  if (!TLI->has(LibFunc_memmove))      return false;    // See if the pointers alias. @@ -1306,6 +1335,11 @@ bool MemCpyOptPass::processByValArgument(CallSite CS, unsigned ArgNo) {                                   CS.getInstruction(), &AC, &DT) < ByValAlign)      return false; +  // The address space of the memcpy source must match the byval argument +  if (MDep->getSource()->getType()->getPointerAddressSpace() != +      ByValArg->getType()->getPointerAddressSpace()) +    return false; +    // Verify that the copied-from memory doesn't change in between the memcpy and    // the byval call.    
//    memcpy(a <- b) @@ -1375,7 +1409,6 @@ bool MemCpyOptPass::iterateOnFunction(Function &F) {  }  PreservedAnalyses MemCpyOptPass::run(Function &F, FunctionAnalysisManager &AM) { -    auto &MD = AM.getResult<MemoryDependenceAnalysis>(F);    auto &TLI = AM.getResult<TargetLibraryAnalysis>(F); @@ -1393,7 +1426,9 @@ PreservedAnalyses MemCpyOptPass::run(Function &F, FunctionAnalysisManager &AM) {                              LookupAssumptionCache, LookupDomTree);    if (!MadeChange)      return PreservedAnalyses::all(); +    PreservedAnalyses PA; +  PA.preserveSet<CFGAnalyses>();    PA.preserve<GlobalsAA>();    PA.preserve<MemoryDependenceAnalysis>();    return PA; @@ -1414,10 +1449,10 @@ bool MemCpyOptPass::runImpl(    // If we don't have at least memset and memcpy, there is little point of doing    // anything here.  These are required by a freestanding implementation, so if    // even they are disabled, there is no point in trying hard. -  if (!TLI->has(LibFunc::memset) || !TLI->has(LibFunc::memcpy)) +  if (!TLI->has(LibFunc_memset) || !TLI->has(LibFunc_memcpy))      return false; -  while (1) { +  while (true) {      if (!iterateOnFunction(F))        break;      MadeChange = true; diff --git a/lib/Transforms/Scalar/MergedLoadStoreMotion.cpp b/lib/Transforms/Scalar/MergedLoadStoreMotion.cpp index 6a64c6b3619c..acd3ef6791be 100644 --- a/lib/Transforms/Scalar/MergedLoadStoreMotion.cpp +++ b/lib/Transforms/Scalar/MergedLoadStoreMotion.cpp @@ -19,6 +19,8 @@  // thinks it safe to do so.  This optimization helps with eg. hiding load  // latencies, triggering if-conversion, and reducing static code size.  // +// NOTE: This code no longer performs load hoisting, it is subsumed by GVNHoist. +//  //===----------------------------------------------------------------------===//  //  // @@ -87,7 +89,6 @@  #include "llvm/Support/raw_ostream.h"  #include "llvm/Transforms/Scalar.h"  #include "llvm/Transforms/Utils/BasicBlockUtils.h" -#include "llvm/Transforms/Utils/SSAUpdater.h"  using namespace llvm; @@ -118,16 +119,6 @@ private:    void removeInstruction(Instruction *Inst);    BasicBlock *getDiamondTail(BasicBlock *BB);    bool isDiamondHead(BasicBlock *BB); -  // Routines for hoisting loads -  bool isLoadHoistBarrierInRange(const Instruction &Start, -                                 const Instruction &End, LoadInst *LI, -                                 bool SafeToLoadUnconditionally); -  LoadInst *canHoistFromBlock(BasicBlock *BB, LoadInst *LI); -  void hoistInstruction(BasicBlock *BB, Instruction *HoistCand, -                        Instruction *ElseInst); -  bool isSafeToHoist(Instruction *I) const; -  bool hoistLoad(BasicBlock *BB, LoadInst *HoistCand, LoadInst *ElseInst); -  bool mergeLoads(BasicBlock *BB);    // Routines for sinking stores    StoreInst *canSinkFromBlock(BasicBlock *BB, StoreInst *SI);    PHINode *getPHIOperand(BasicBlock *BB, StoreInst *S0, StoreInst *S1); @@ -188,169 +179,6 @@ bool MergedLoadStoreMotion::isDiamondHead(BasicBlock *BB) {    return true;  } -/// -/// \brief True when instruction is a hoist barrier for a load -/// -/// Whenever an instruction could possibly modify the value -/// being loaded or protect against the load from happening -/// it is considered a hoist barrier. 
-/// -bool MergedLoadStoreMotion::isLoadHoistBarrierInRange( -    const Instruction &Start, const Instruction &End, LoadInst *LI, -    bool SafeToLoadUnconditionally) { -  if (!SafeToLoadUnconditionally) -    for (const Instruction &Inst : -         make_range(Start.getIterator(), End.getIterator())) -      if (!isGuaranteedToTransferExecutionToSuccessor(&Inst)) -        return true; -  MemoryLocation Loc = MemoryLocation::get(LI); -  return AA->canInstructionRangeModRef(Start, End, Loc, MRI_Mod); -} - -/// -/// \brief Decide if a load can be hoisted -/// -/// When there is a load in \p BB to the same address as \p LI -/// and it can be hoisted from \p BB, return that load. -/// Otherwise return Null. -/// -LoadInst *MergedLoadStoreMotion::canHoistFromBlock(BasicBlock *BB1, -                                                   LoadInst *Load0) { -  BasicBlock *BB0 = Load0->getParent(); -  BasicBlock *Head = BB0->getSinglePredecessor(); -  bool SafeToLoadUnconditionally = isSafeToLoadUnconditionally( -      Load0->getPointerOperand(), Load0->getAlignment(), -      Load0->getModule()->getDataLayout(), -      /*ScanFrom=*/Head->getTerminator()); -  for (BasicBlock::iterator BBI = BB1->begin(), BBE = BB1->end(); BBI != BBE; -       ++BBI) { -    Instruction *Inst = &*BBI; - -    // Only merge and hoist loads when their result in used only in BB -    auto *Load1 = dyn_cast<LoadInst>(Inst); -    if (!Load1 || Inst->isUsedOutsideOfBlock(BB1)) -      continue; - -    MemoryLocation Loc0 = MemoryLocation::get(Load0); -    MemoryLocation Loc1 = MemoryLocation::get(Load1); -    if (Load0->isSameOperationAs(Load1) && AA->isMustAlias(Loc0, Loc1) && -        !isLoadHoistBarrierInRange(BB1->front(), *Load1, Load1, -                                   SafeToLoadUnconditionally) && -        !isLoadHoistBarrierInRange(BB0->front(), *Load0, Load0, -                                   SafeToLoadUnconditionally)) { -      return Load1; -    } -  } -  return nullptr; -} - -/// -/// \brief Merge two equivalent instructions \p HoistCand and \p ElseInst into -/// \p BB -/// -/// BB is the head of a diamond -/// -void MergedLoadStoreMotion::hoistInstruction(BasicBlock *BB, -                                             Instruction *HoistCand, -                                             Instruction *ElseInst) { -  DEBUG(dbgs() << " Hoist Instruction into BB \n"; BB->dump(); -        dbgs() << "Instruction Left\n"; HoistCand->dump(); dbgs() << "\n"; -        dbgs() << "Instruction Right\n"; ElseInst->dump(); dbgs() << "\n"); -  // Hoist the instruction. -  assert(HoistCand->getParent() != BB); - -  // Intersect optional metadata. -  HoistCand->andIRFlags(ElseInst); -  HoistCand->dropUnknownNonDebugMetadata(); - -  // Prepend point for instruction insert -  Instruction *HoistPt = BB->getTerminator(); - -  // Merged instruction -  Instruction *HoistedInst = HoistCand->clone(); - -  // Hoist instruction. -  HoistedInst->insertBefore(HoistPt); - -  HoistCand->replaceAllUsesWith(HoistedInst); -  removeInstruction(HoistCand); -  // Replace the else block instruction. 
-  ElseInst->replaceAllUsesWith(HoistedInst); -  removeInstruction(ElseInst); -} - -/// -/// \brief Return true if no operand of \p I is defined in I's parent block -/// -bool MergedLoadStoreMotion::isSafeToHoist(Instruction *I) const { -  BasicBlock *Parent = I->getParent(); -  for (Use &U : I->operands()) -    if (auto *Instr = dyn_cast<Instruction>(&U)) -      if (Instr->getParent() == Parent) -        return false; -  return true; -} - -/// -/// \brief Merge two equivalent loads and GEPs and hoist into diamond head -/// -bool MergedLoadStoreMotion::hoistLoad(BasicBlock *BB, LoadInst *L0, -                                      LoadInst *L1) { -  // Only one definition? -  auto *A0 = dyn_cast<Instruction>(L0->getPointerOperand()); -  auto *A1 = dyn_cast<Instruction>(L1->getPointerOperand()); -  if (A0 && A1 && A0->isIdenticalTo(A1) && isSafeToHoist(A0) && -      A0->hasOneUse() && (A0->getParent() == L0->getParent()) && -      A1->hasOneUse() && (A1->getParent() == L1->getParent()) && -      isa<GetElementPtrInst>(A0)) { -    DEBUG(dbgs() << "Hoist Instruction into BB \n"; BB->dump(); -          dbgs() << "Instruction Left\n"; L0->dump(); dbgs() << "\n"; -          dbgs() << "Instruction Right\n"; L1->dump(); dbgs() << "\n"); -    hoistInstruction(BB, A0, A1); -    hoistInstruction(BB, L0, L1); -    return true; -  } -  return false; -} - -/// -/// \brief Try to hoist two loads to same address into diamond header -/// -/// Starting from a diamond head block, iterate over the instructions in one -/// successor block and try to match a load in the second successor. -/// -bool MergedLoadStoreMotion::mergeLoads(BasicBlock *BB) { -  bool MergedLoads = false; -  assert(isDiamondHead(BB)); -  BranchInst *BI = cast<BranchInst>(BB->getTerminator()); -  BasicBlock *Succ0 = BI->getSuccessor(0); -  BasicBlock *Succ1 = BI->getSuccessor(1); -  // #Instructions in Succ1 for Compile Time Control -  int Size1 = Succ1->size(); -  int NLoads = 0; -  for (BasicBlock::iterator BBI = Succ0->begin(), BBE = Succ0->end(); -       BBI != BBE;) { -    Instruction *I = &*BBI; -    ++BBI; - -    // Don't move non-simple (atomic, volatile) loads. -    auto *L0 = dyn_cast<LoadInst>(I); -    if (!L0 || !L0->isSimple() || L0->isUsedOutsideOfBlock(Succ0)) -      continue; - -    ++NLoads; -    if (NLoads * Size1 >= MagicCompileTimeControl) -      break; -    if (LoadInst *L1 = canHoistFromBlock(Succ1, L0)) { -      bool Res = hoistLoad(BB, L0, L1); -      MergedLoads |= Res; -      // Don't attempt to hoist above loads that had not been hoisted. -      if (!Res) -        break; -    } -  } -  return MergedLoads; -}  ///  /// \brief True when instruction is a sink barrier for a store @@ -534,7 +362,6 @@ bool MergedLoadStoreMotion::run(Function &F, MemoryDependenceResults *MD,      // Hoist equivalent loads and sink stores      // outside diamonds when possible      if (isDiamondHead(BB)) { -      Changed |= mergeLoads(BB);        Changed |= mergeStores(getDiamondTail(BB));      }    } @@ -596,8 +423,8 @@ MergedLoadStoreMotionPass::run(Function &F, FunctionAnalysisManager &AM) {    if (!Impl.run(F, MD, AA))      return PreservedAnalyses::all(); -  // FIXME: This should also 'preserve the CFG'.    
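
The removed FIXME above, together with the matching hunks in MemCpyOpt earlier and in MergedLoadStoreMotion and NaryReassociate just below, is resolved the same way: because these passes never change the CFG, their new-PM run() methods can mark the whole CFGAnalyses set preserved instead of naming each CFG-dependent analysis. A condensed sketch of that idiom, with ExamplePass and runImpl as placeholder names:

#include "llvm/IR/Function.h"
#include "llvm/IR/PassManager.h"

using namespace llvm;

// Hypothetical pass, used only to illustrate the PreservedAnalyses idiom.
struct ExamplePass : PassInfoMixin<ExamplePass> {
  bool runImpl(Function &F); // assumed helper; it never modifies the CFG

  PreservedAnalyses run(Function &F, FunctionAnalysisManager &AM) {
    if (!runImpl(F))
      return PreservedAnalyses::all();
    PreservedAnalyses PA;
    // One call covers every analysis that only depends on CFG shape
    // (dominators, loops, ...), so they need not be listed one by one.
    PA.preserveSet<CFGAnalyses>();
    return PA;
  }
};
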
   PreservedAnalyses PA;
+  PA.preserveSet<CFGAnalyses>();
   PA.preserve<GlobalsAA>();
   PA.preserve<MemoryDependenceAnalysis>();
   return PA;
diff --git a/lib/Transforms/Scalar/NaryReassociate.cpp b/lib/Transforms/Scalar/NaryReassociate.cpp
index 0a3bf7b4c31b..c5bf2f28d185 100644
--- a/lib/Transforms/Scalar/NaryReassociate.cpp
+++ b/lib/Transforms/Scalar/NaryReassociate.cpp
@@ -156,20 +156,12 @@ PreservedAnalyses NaryReassociatePass::run(Function &F,
   auto *TLI = &AM.getResult<TargetLibraryAnalysis>(F);
   auto *TTI = &AM.getResult<TargetIRAnalysis>(F);
 
-  bool Changed = runImpl(F, AC, DT, SE, TLI, TTI);
-
-  // FIXME: We need to invalidate this to avoid PR28400. Is there a better
-  // solution?
-  AM.invalidate<ScalarEvolutionAnalysis>(F);
-
-  if (!Changed)
+  if (!runImpl(F, AC, DT, SE, TLI, TTI))
     return PreservedAnalyses::all();
 
-  // FIXME: This should also 'preserve the CFG'.
   PreservedAnalyses PA;
-  PA.preserve<DominatorTreeAnalysis>();
+  PA.preserveSet<CFGAnalyses>();
   PA.preserve<ScalarEvolutionAnalysis>();
-  PA.preserve<TargetLibraryAnalysis>();
   return PA;
 }
diff --git a/lib/Transforms/Scalar/NewGVN.cpp b/lib/Transforms/Scalar/NewGVN.cpp
index 57e6e3ddad94..3d8ce888867e 100644
--- a/lib/Transforms/Scalar/NewGVN.cpp
+++ b/lib/Transforms/Scalar/NewGVN.cpp
@@ -17,6 +17,27 @@
 /// "A Sparse Algorithm for Predicated Global Value Numbering" from
 /// Karthik Gargi.
 ///
+/// A brief overview of the algorithm: The algorithm is essentially the same as
+/// the standard RPO value numbering algorithm (a good reference is the paper
+/// "SCC based value numbering" by L. Taylor Simpson) with one major difference:
+/// The RPO algorithm proceeds, on every iteration, to process every reachable
+/// block and every instruction in that block.  This is because the standard RPO
+/// algorithm does not track what things have the same value number, it only
+/// tracks what the value number of a given operation is (the mapping is
+/// operation -> value number).  Thus, when a value number of an operation
+/// changes, it must reprocess everything to ensure all uses of a value number
+/// get updated properly.  In contrast, the sparse algorithm we use *also*
+/// tracks what operations have a given value number (IE it also tracks the
+/// reverse mapping from value number -> operations with that value number), so
+/// that it only needs to reprocess the instructions that are affected when
+/// something's value number changes.  The rest of the algorithm is devoted to
+/// performing symbolic evaluation, forward propagation, and simplification of
+/// operations based on the value numbers deduced so far.
+///
+/// We also do not perform elimination by using any published algorithm.  All
+/// published algorithms are O(Instructions). Instead, we use a technique that
+/// is O(number of operations with the same value number), enabling us to skip
+/// trying to eliminate things that have unique value numbers.
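
A tiny standalone model of the sparse bookkeeping described in the overview above (toy strings and maps, nothing from the patch): because the reverse mapping from value number to operations is maintained, changing one operation's value number only re-queues that operation's users instead of forcing another full RPO sweep.

#include <deque>
#include <set>
#include <string>
#include <unordered_map>

// Toy model of the dense vs. sparse bookkeeping: operations are just strings.
struct ToySparseVN {
  std::unordered_map<std::string, unsigned> NumberOf;              // op -> VN
  std::unordered_map<unsigned, std::set<std::string>> OpsWith;     // VN -> ops
  std::unordered_map<std::string, std::set<std::string>> UsersOf;  // op -> users
  std::deque<std::string> Worklist; // only what actually needs revisiting

  void assign(const std::string &Op, unsigned VN) {
    auto It = NumberOf.find(Op);
    if (It != NumberOf.end()) {
      if (It->second == VN)
        return;                      // unchanged: nothing to reprocess
      OpsWith[It->second].erase(Op); // leave the old congruence class
    }
    NumberOf[Op] = VN;
    OpsWith[VN].insert(Op);
    // The sparse step: only direct users of Op can be affected, so only they
    // are re-queued; a dense RPO algorithm would rescan every instruction.
    for (const std::string &User : UsersOf[Op])
      Worklist.push_back(User);
  }
};
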
//===----------------------------------------------------------------------===//  #include "llvm/Transforms/Scalar/NewGVN.h" @@ -40,13 +61,10 @@  #include "llvm/Analysis/ConstantFolding.h"  #include "llvm/Analysis/GlobalsModRef.h"  #include "llvm/Analysis/InstructionSimplify.h" -#include "llvm/Analysis/Loads.h"  #include "llvm/Analysis/MemoryBuiltins.h" -#include "llvm/Analysis/MemoryDependenceAnalysis.h"  #include "llvm/Analysis/MemoryLocation.h" -#include "llvm/Analysis/PHITransAddr.h" +#include "llvm/Analysis/MemorySSA.h"  #include "llvm/Analysis/TargetLibraryInfo.h" -#include "llvm/Analysis/ValueTracking.h"  #include "llvm/IR/DataLayout.h"  #include "llvm/IR/Dominators.h"  #include "llvm/IR/GlobalVariable.h" @@ -55,24 +73,25 @@  #include "llvm/IR/LLVMContext.h"  #include "llvm/IR/Metadata.h"  #include "llvm/IR/PatternMatch.h" -#include "llvm/IR/PredIteratorCache.h"  #include "llvm/IR/Type.h"  #include "llvm/Support/Allocator.h"  #include "llvm/Support/CommandLine.h"  #include "llvm/Support/Debug.h" +#include "llvm/Support/DebugCounter.h"  #include "llvm/Transforms/Scalar.h"  #include "llvm/Transforms/Scalar/GVNExpression.h"  #include "llvm/Transforms/Utils/BasicBlockUtils.h"  #include "llvm/Transforms/Utils/Local.h" -#include "llvm/Transforms/Utils/MemorySSA.h" -#include "llvm/Transforms/Utils/SSAUpdater.h" +#include "llvm/Transforms/Utils/PredicateInfo.h" +#include "llvm/Transforms/Utils/VNCoercion.h" +#include <numeric>  #include <unordered_map>  #include <utility>  #include <vector>  using namespace llvm;  using namespace PatternMatch;  using namespace llvm::GVNExpression; - +using namespace llvm::VNCoercion;  #define DEBUG_TYPE "newgvn"  STATISTIC(NumGVNInstrDeleted, "Number of instructions deleted"); @@ -87,6 +106,15 @@ STATISTIC(NumGVNAvoidedSortedLeaderChanges,            "Number of avoided sorted leader changes");  STATISTIC(NumGVNNotMostDominatingLeader,            "Number of times a member dominated it's new classes' leader"); +STATISTIC(NumGVNDeadStores, "Number of redundant/dead stores eliminated"); +DEBUG_COUNTER(VNCounter, "newgvn-vn", +              "Controls which instructions are value numbered") + +// Currently store defining access refinement is too slow due to basicaa being +// egregiously slow.  This flag lets us keep it working while we work on this +// issue. +static cl::opt<bool> EnableStoreRefinement("enable-store-refinement", +                                           cl::init(false), cl::Hidden);  //===----------------------------------------------------------------------===//  //                                GVN Pass @@ -105,6 +133,77 @@ PHIExpression::~PHIExpression() = default;  }  } +// Tarjan's SCC finding algorithm with Nuutila's improvements +// SCCIterator is actually fairly complex for the simple thing we want. +// It also wants to hand us SCC's that are unrelated to the phi node we ask +// about, and have us process them there or risk redoing work. +// Graph traits over a filter iterator also doesn't work that well here. +// This SCC finder is specialized to walk use-def chains, and only follows instructions, +// not generic values (arguments, etc). 
+struct TarjanSCC { + +  TarjanSCC() : Components(1) {} + +  void Start(const Instruction *Start) { +    if (Root.lookup(Start) == 0) +      FindSCC(Start); +  } + +  const SmallPtrSetImpl<const Value *> &getComponentFor(const Value *V) const { +    unsigned ComponentID = ValueToComponent.lookup(V); + +    assert(ComponentID > 0 && +           "Asking for a component for a value we never processed"); +    return Components[ComponentID]; +  } + +private: +  void FindSCC(const Instruction *I) { +    Root[I] = ++DFSNum; +    // Store the DFS Number we had before it possibly gets incremented. +    unsigned int OurDFS = DFSNum; +    for (auto &Op : I->operands()) { +      if (auto *InstOp = dyn_cast<Instruction>(Op)) { +        if (Root.lookup(Op) == 0) +          FindSCC(InstOp); +        if (!InComponent.count(Op)) +          Root[I] = std::min(Root.lookup(I), Root.lookup(Op)); +      } +    } +    // See if we really were the root of a component, by seeing if we still have our DFSNumber. +    // If we do, we are the root of the component, and we have completed a component. If we do not, +    // we are not the root of a component, and belong on the component stack. +    if (Root.lookup(I) == OurDFS) { +      unsigned ComponentID = Components.size(); +      Components.resize(Components.size() + 1); +      auto &Component = Components.back(); +      Component.insert(I); +      DEBUG(dbgs() << "Component root is " << *I << "\n"); +      InComponent.insert(I); +      ValueToComponent[I] = ComponentID; +      // Pop a component off the stack and label it. +      while (!Stack.empty() && Root.lookup(Stack.back()) >= OurDFS) { +        auto *Member = Stack.back(); +        DEBUG(dbgs() << "Component member is " << *Member << "\n"); +        Component.insert(Member); +        InComponent.insert(Member); +        ValueToComponent[Member] = ComponentID; +        Stack.pop_back(); +      } +    } else { +      // Part of a component, push to stack +      Stack.push_back(I); +    } +  } +  unsigned int DFSNum = 1; +  SmallPtrSet<const Value *, 8> InComponent; +  DenseMap<const Value *, unsigned int> Root; +  SmallVector<const Value *, 8> Stack; +  // Store the components as vector of ptr sets, because we need the topo order +  // of SCC's, but not individual member order +  SmallVector<SmallPtrSet<const Value *, 8>, 8> Components; +  DenseMap<const Value *, unsigned> ValueToComponent; +};  // Congruence classes represent the set of expressions/instructions  // that are all the same *during some scope in the function*.  // That is, because of the way we perform equality propagation, and @@ -115,43 +214,152 @@ PHIExpression::~PHIExpression() = default;  // For any Value in the Member set, it is valid to replace any dominated member  // with that Value.  // -// Every congruence class has a leader, and the leader is used to -// symbolize instructions in a canonical way (IE every operand of an -// instruction that is a member of the same congruence class will -// always be replaced with leader during symbolization). -// To simplify symbolization, we keep the leader as a constant if class can be -// proved to be a constant value. -// Otherwise, the leader is a randomly chosen member of the value set, it does -// not matter which one is chosen. -// Each congruence class also has a defining expression, -// though the expression may be null.  If it exists, it can be used for forward -// propagation and reassociation of values. 
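
The finder above is meant to be queried lazily: Start() walks use-def chains from one instruction and getComponentFor() returns the component that instruction landed in. The patch only declares isCycleFree(const PHINode *) further down, so the fragment below is merely one plausible way such a query could look; it is not taken from the patch and assumes the surrounding file's headers.

// Hypothetical fragment (not from the patch): a phi is trivially cycle-free
// when the use-def SCC it belongs to contains nothing but the phi itself.
static bool isPhiTriviallyCycleFree(TarjanSCC &SCCFinder, const PHINode *PN) {
  SCCFinder.Start(PN);
  return SCCFinder.getComponentFor(PN).size() == 1;
}
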
-// -struct CongruenceClass { -  using MemberSet = SmallPtrSet<Value *, 4>; +// Every congruence class has a leader, and the leader is used to symbolize +// instructions in a canonical way (IE every operand of an instruction that is a +// member of the same congruence class will always be replaced with leader +// during symbolization).  To simplify symbolization, we keep the leader as a +// constant if class can be proved to be a constant value.  Otherwise, the +// leader is the member of the value set with the smallest DFS number.  Each +// congruence class also has a defining expression, though the expression may be +// null.  If it exists, it can be used for forward propagation and reassociation +// of values. + +// For memory, we also track a representative MemoryAccess, and a set of memory +// members for MemoryPhis (which have no real instructions). Note that for +// memory, it seems tempting to try to split the memory members into a +// MemoryCongruenceClass or something.  Unfortunately, this does not work +// easily.  The value numbering of a given memory expression depends on the +// leader of the memory congruence class, and the leader of memory congruence +// class depends on the value numbering of a given memory expression.  This +// leads to wasted propagation, and in some cases, missed optimization.  For +// example: If we had value numbered two stores together before, but now do not, +// we move them to a new value congruence class.  This in turn will move at one +// of the memorydefs to a new memory congruence class.  Which in turn, affects +// the value numbering of the stores we just value numbered (because the memory +// congruence class is part of the value number).  So while theoretically +// possible to split them up, it turns out to be *incredibly* complicated to get +// it to work right, because of the interdependency.  While structurally +// slightly messier, it is algorithmically much simpler and faster to do what we +// do here, and track them both at once in the same class. +// Note: The default iterators for this class iterate over values +class CongruenceClass { +public: +  using MemberType = Value; +  using MemberSet = SmallPtrSet<MemberType *, 4>; +  using MemoryMemberType = MemoryPhi; +  using MemoryMemberSet = SmallPtrSet<const MemoryMemberType *, 2>; + +  explicit CongruenceClass(unsigned ID) : ID(ID) {} +  CongruenceClass(unsigned ID, Value *Leader, const Expression *E) +      : ID(ID), RepLeader(Leader), DefiningExpr(E) {} +  unsigned getID() const { return ID; } +  // True if this class has no members left.  This is mainly used for assertion +  // purposes, and for skipping empty classes. +  bool isDead() const { +    // If it's both dead from a value perspective, and dead from a memory +    // perspective, it's really dead. 
+    return empty() && memory_empty();
+  }
+  // Leader functions
+  Value *getLeader() const { return RepLeader; }
+  void setLeader(Value *Leader) { RepLeader = Leader; }
+  const std::pair<Value *, unsigned int> &getNextLeader() const {
+    return NextLeader;
+  }
+  void resetNextLeader() { NextLeader = {nullptr, ~0}; }
+
+  void addPossibleNextLeader(std::pair<Value *, unsigned int> LeaderPair) {
+    if (LeaderPair.second < NextLeader.second)
+      NextLeader = LeaderPair;
+  }
+
+  Value *getStoredValue() const { return RepStoredValue; }
+  void setStoredValue(Value *Leader) { RepStoredValue = Leader; }
+  const MemoryAccess *getMemoryLeader() const { return RepMemoryAccess; }
+  void setMemoryLeader(const MemoryAccess *Leader) { RepMemoryAccess = Leader; }
+
+  // Forward propagation info
+  const Expression *getDefiningExpr() const { return DefiningExpr; }
+  void setDefiningExpr(const Expression *E) { DefiningExpr = E; }
+
+  // Value member set
+  bool empty() const { return Members.empty(); }
+  unsigned size() const { return Members.size(); }
+  MemberSet::const_iterator begin() const { return Members.begin(); }
+  MemberSet::const_iterator end() const { return Members.end(); }
+  void insert(MemberType *M) { Members.insert(M); }
+  void erase(MemberType *M) { Members.erase(M); }
+  void swap(MemberSet &Other) { Members.swap(Other); }
+
+  // Memory member set
+  bool memory_empty() const { return MemoryMembers.empty(); }
+  unsigned memory_size() const { return MemoryMembers.size(); }
+  MemoryMemberSet::const_iterator memory_begin() const {
+    return MemoryMembers.begin();
+  }
+  MemoryMemberSet::const_iterator memory_end() const {
+    return MemoryMembers.end();
+  }
+  iterator_range<MemoryMemberSet::const_iterator> memory() const {
+    return make_range(memory_begin(), memory_end());
+  }
+  void memory_insert(const MemoryMemberType *M) { MemoryMembers.insert(M); }
+  void memory_erase(const MemoryMemberType *M) { MemoryMembers.erase(M); }
+
+  // Store count
+  unsigned getStoreCount() const { return StoreCount; }
+  void incStoreCount() { ++StoreCount; }
+  void decStoreCount() {
+    assert(StoreCount != 0 && "Store count went negative");
+    --StoreCount;
+  }
+
+  // Return true if two congruence classes are equivalent to each other. This
+  // means that every field but the ID number and the dead field are equivalent.
+  bool isEquivalentTo(const CongruenceClass *Other) const {
+    if (!Other)
+      return false;
+    if (this == Other)
+      return true;
+
+    if (std::tie(StoreCount, RepLeader, RepStoredValue, RepMemoryAccess) !=
+        std::tie(Other->StoreCount, Other->RepLeader, Other->RepStoredValue,
+                 Other->RepMemoryAccess))
+      return false;
+    if (DefiningExpr != Other->DefiningExpr)
+      if (!DefiningExpr || !Other->DefiningExpr ||
+          *DefiningExpr != *Other->DefiningExpr)
+        return false;
+    // We need some ordered set
+    std::set<Value *> AMembers(Members.begin(), Members.end());
+    std::set<Value *> BMembers(Other->Members.begin(), Other->Members.end());
+    return AMembers == BMembers;
+  }
+
+private:
   unsigned ID;
   // Representative leader.
   Value *RepLeader = nullptr;
+  // The most dominating leader after our current leader, because the member set
+  // is not sorted and is expensive to keep sorted all the time.
+  std::pair<Value *, unsigned int> NextLeader = {nullptr, ~0U};
+  // If this is represented by a store, the value of the store.
+  Value *RepStoredValue = nullptr; +  // If this class contains MemoryDefs or MemoryPhis, this is the leading memory +  // access. +  const MemoryAccess *RepMemoryAccess = nullptr;    // Defining Expression.    const Expression *DefiningExpr = nullptr;    // Actual members of this class.    MemberSet Members; - -  // True if this class has no members left.  This is mainly used for assertion -  // purposes, and for skipping empty classes. -  bool Dead = false; - +  // This is the set of MemoryPhis that exist in the class. MemoryDefs and +  // MemoryUses have real instructions representing them, so we only need to +  // track MemoryPhis here. +  MemoryMemberSet MemoryMembers;    // Number of stores in this congruence class.    // This is used so we can detect store equivalence changes properly.    int StoreCount = 0; - -  // The most dominating leader after our current leader, because the member set -  // is not sorted and is expensive to keep sorted all the time. -  std::pair<Value *, unsigned int> NextLeader = {nullptr, ~0U}; - -  explicit CongruenceClass(unsigned ID) : ID(ID) {} -  CongruenceClass(unsigned ID, Value *Leader, const Expression *E) -      : ID(ID), RepLeader(Leader), DefiningExpr(E) {}  };  namespace llvm { @@ -180,19 +388,34 @@ template <> struct DenseMapInfo<const Expression *> {  };  } // end namespace llvm -class NewGVN : public FunctionPass { +namespace { +class NewGVN { +  Function &F;    DominatorTree *DT; -  const DataLayout *DL; -  const TargetLibraryInfo *TLI;    AssumptionCache *AC; +  const TargetLibraryInfo *TLI;    AliasAnalysis *AA;    MemorySSA *MSSA;    MemorySSAWalker *MSSAWalker; +  const DataLayout &DL; +  std::unique_ptr<PredicateInfo> PredInfo;    BumpPtrAllocator ExpressionAllocator;    ArrayRecycler<Value *> ArgRecycler; +  TarjanSCC SCCFinder; + +  // Number of function arguments, used by ranking +  unsigned int NumFuncArgs; + +  // RPOOrdering of basic blocks +  DenseMap<const DomTreeNode *, unsigned> RPOOrdering;    // Congruence class info. -  CongruenceClass *InitialClass; + +  // This class is called INITIAL in the paper. It is the class everything +  // startsout in, and represents any value. Being an optimistic analysis, +  // anything in the TOP class has the value TOP, which is indeterminate and +  // equivalent to everything. +  CongruenceClass *TOPClass;    std::vector<CongruenceClass *> CongruenceClasses;    unsigned NextCongruenceNum; @@ -200,13 +423,38 @@ class NewGVN : public FunctionPass {    DenseMap<Value *, CongruenceClass *> ValueToClass;    DenseMap<Value *, const Expression *> ValueToExpression; +  // Mapping from predicate info we used to the instructions we used it with. +  // In order to correctly ensure propagation, we must keep track of what +  // comparisons we used, so that when the values of the comparisons change, we +  // propagate the information to the places we used the comparison. +  DenseMap<const Value *, SmallPtrSet<Instruction *, 2>> PredicateToUsers; +  // Mapping from MemoryAccess we used to the MemoryAccess we used it with.  Has +  // the same reasoning as PredicateToUsers.  When we skip MemoryAccesses for +  // stores, we no longer can rely solely on the def-use chains of MemorySSA. +  DenseMap<const MemoryAccess *, SmallPtrSet<MemoryAccess *, 2>> MemoryToUsers; +    // A table storing which memorydefs/phis represent a memory state provably    // equivalent to another memory state.    
// We could use the congruence class machinery, but the MemoryAccess's are    // abstract memory states, so they can only ever be equivalent to each other,    // and not to constants, etc. -  DenseMap<const MemoryAccess *, MemoryAccess *> MemoryAccessEquiv; - +  DenseMap<const MemoryAccess *, CongruenceClass *> MemoryAccessToClass; + +  // We could, if we wanted, build MemoryPhiExpressions and +  // MemoryVariableExpressions, etc, and value number them the same way we value +  // number phi expressions.  For the moment, this seems like overkill.  They +  // can only exist in one of three states: they can be TOP (equal to +  // everything), Equivalent to something else, or unique.  Because we do not +  // create expressions for them, we need to simulate leader change not just +  // when they change class, but when they change state.  Note: We can do the +  // same thing for phis, and avoid having phi expressions if we wanted, We +  // should eventually unify in one direction or the other, so this is a little +  // bit of an experiment in which turns out easier to maintain. +  enum MemoryPhiState { MPS_Invalid, MPS_TOP, MPS_Equivalent, MPS_Unique }; +  DenseMap<const MemoryPhi *, MemoryPhiState> MemoryPhiState; + +  enum PhiCycleState { PCS_Unknown, PCS_CycleFree, PCS_Cycle }; +  DenseMap<const PHINode *, PhiCycleState> PhiCycleState;    // Expression to class mapping.    using ExpressionClassMap = DenseMap<const Expression *, CongruenceClass *>;    ExpressionClassMap ExpressionToClass; @@ -231,8 +479,6 @@ class NewGVN : public FunctionPass {    BitVector TouchedInstructions;    DenseMap<const BasicBlock *, std::pair<unsigned, unsigned>> BlockInstRange; -  DenseMap<const DomTreeNode *, std::pair<unsigned, unsigned>> -      DominatedInstRange;  #ifndef NDEBUG    // Debugging for how many times each block and instruction got processed. @@ -240,56 +486,42 @@ class NewGVN : public FunctionPass {  #endif    // DFS info. -  DenseMap<const BasicBlock *, std::pair<int, int>> DFSDomMap; +  // This contains a mapping from Instructions to DFS numbers. +  // The numbering starts at 1. An instruction with DFS number zero +  // means that the instruction is dead.    DenseMap<const Value *, unsigned> InstrDFS; + +  // This contains the mapping DFS numbers to instructions.    SmallVector<Value *, 32> DFSToInstr;    // Deletion info.    SmallPtrSet<Instruction *, 8> InstructionsToErase;  public: -  static char ID; // Pass identification, replacement for typeid. -  NewGVN() : FunctionPass(ID) { -    initializeNewGVNPass(*PassRegistry::getPassRegistry()); -  } - -  bool runOnFunction(Function &F) override; -  bool runGVN(Function &F, DominatorTree *DT, AssumptionCache *AC, -              TargetLibraryInfo *TLI, AliasAnalysis *AA, MemorySSA *MSSA); +  NewGVN(Function &F, DominatorTree *DT, AssumptionCache *AC, +         TargetLibraryInfo *TLI, AliasAnalysis *AA, MemorySSA *MSSA, +         const DataLayout &DL) +      : F(F), DT(DT), AC(AC), TLI(TLI), AA(AA), MSSA(MSSA), DL(DL), +        PredInfo(make_unique<PredicateInfo>(F, *DT, *AC)) {} +  bool runGVN();  private: -  // This transformation requires dominator postdominator info. 
-  void getAnalysisUsage(AnalysisUsage &AU) const override { -    AU.addRequired<AssumptionCacheTracker>(); -    AU.addRequired<DominatorTreeWrapperPass>(); -    AU.addRequired<TargetLibraryInfoWrapperPass>(); -    AU.addRequired<MemorySSAWrapperPass>(); -    AU.addRequired<AAResultsWrapperPass>(); - -    AU.addPreserved<DominatorTreeWrapperPass>(); -    AU.addPreserved<GlobalsAAWrapperPass>(); -  } -    // Expression handling. -  const Expression *createExpression(Instruction *, const BasicBlock *); -  const Expression *createBinaryExpression(unsigned, Type *, Value *, Value *, -                                           const BasicBlock *); -  PHIExpression *createPHIExpression(Instruction *); +  const Expression *createExpression(Instruction *); +  const Expression *createBinaryExpression(unsigned, Type *, Value *, Value *); +  PHIExpression *createPHIExpression(Instruction *, bool &HasBackedge, +                                     bool &AllConstant);    const VariableExpression *createVariableExpression(Value *);    const ConstantExpression *createConstantExpression(Constant *); -  const Expression *createVariableOrConstant(Value *V, const BasicBlock *B); +  const Expression *createVariableOrConstant(Value *V);    const UnknownExpression *createUnknownExpression(Instruction *); -  const StoreExpression *createStoreExpression(StoreInst *, MemoryAccess *, -                                               const BasicBlock *); +  const StoreExpression *createStoreExpression(StoreInst *, +                                               const MemoryAccess *);    LoadExpression *createLoadExpression(Type *, Value *, LoadInst *, -                                       MemoryAccess *, const BasicBlock *); - -  const CallExpression *createCallExpression(CallInst *, MemoryAccess *, -                                             const BasicBlock *); -  const AggregateValueExpression * -  createAggregateValueExpression(Instruction *, const BasicBlock *); -  bool setBasicExpressionInfo(Instruction *, BasicExpression *, -                              const BasicBlock *); +                                       const MemoryAccess *); +  const CallExpression *createCallExpression(CallInst *, const MemoryAccess *); +  const AggregateValueExpression *createAggregateValueExpression(Instruction *); +  bool setBasicExpressionInfo(Instruction *, BasicExpression *);    // Congruence class handling.    CongruenceClass *createCongruenceClass(Value *Leader, const Expression *E) { @@ -298,9 +530,21 @@ private:      return result;    } +  CongruenceClass *createMemoryClass(MemoryAccess *MA) { +    auto *CC = createCongruenceClass(nullptr, nullptr); +    CC->setMemoryLeader(MA); +    return CC; +  } +  CongruenceClass *ensureLeaderOfMemoryClass(MemoryAccess *MA) { +    auto *CC = getMemoryClass(MA); +    if (CC->getMemoryLeader() != MA) +      CC = createMemoryClass(MA); +    return CC; +  } +    CongruenceClass *createSingletonCongruenceClass(Value *Member) {      CongruenceClass *CClass = createCongruenceClass(Member, nullptr); -    CClass->Members.insert(Member); +    CClass->insert(Member);      ValueToClass[Member] = CClass;      return CClass;    } @@ -313,37 +557,49 @@ private:    // Symbolic evaluation.    
const Expression *checkSimplificationResults(Expression *, Instruction *,                                                 Value *); -  const Expression *performSymbolicEvaluation(Value *, const BasicBlock *); -  const Expression *performSymbolicLoadEvaluation(Instruction *, -                                                  const BasicBlock *); -  const Expression *performSymbolicStoreEvaluation(Instruction *, -                                                   const BasicBlock *); -  const Expression *performSymbolicCallEvaluation(Instruction *, -                                                  const BasicBlock *); -  const Expression *performSymbolicPHIEvaluation(Instruction *, -                                                 const BasicBlock *); -  bool setMemoryAccessEquivTo(MemoryAccess *From, MemoryAccess *To); -  const Expression *performSymbolicAggrValueEvaluation(Instruction *, -                                                       const BasicBlock *); +  const Expression *performSymbolicEvaluation(Value *); +  const Expression *performSymbolicLoadCoercion(Type *, Value *, LoadInst *, +                                                Instruction *, MemoryAccess *); +  const Expression *performSymbolicLoadEvaluation(Instruction *); +  const Expression *performSymbolicStoreEvaluation(Instruction *); +  const Expression *performSymbolicCallEvaluation(Instruction *); +  const Expression *performSymbolicPHIEvaluation(Instruction *); +  const Expression *performSymbolicAggrValueEvaluation(Instruction *); +  const Expression *performSymbolicCmpEvaluation(Instruction *); +  const Expression *performSymbolicPredicateInfoEvaluation(Instruction *);    // Congruence finding. -  // Templated to allow them to work both on BB's and BB-edges. -  template <class T> -  Value *lookupOperandLeader(Value *, const User *, const T &) const; +  bool someEquivalentDominates(const Instruction *, const Instruction *) const; +  Value *lookupOperandLeader(Value *) const;    void performCongruenceFinding(Instruction *, const Expression *); -  void moveValueToNewCongruenceClass(Instruction *, CongruenceClass *, -                                     CongruenceClass *); +  void moveValueToNewCongruenceClass(Instruction *, const Expression *, +                                     CongruenceClass *, CongruenceClass *); +  void moveMemoryToNewCongruenceClass(Instruction *, MemoryAccess *, +                                      CongruenceClass *, CongruenceClass *); +  Value *getNextValueLeader(CongruenceClass *) const; +  const MemoryAccess *getNextMemoryLeader(CongruenceClass *) const; +  bool setMemoryClass(const MemoryAccess *From, CongruenceClass *To); +  CongruenceClass *getMemoryClass(const MemoryAccess *MA) const; +  const MemoryAccess *lookupMemoryLeader(const MemoryAccess *) const; +  bool isMemoryAccessTop(const MemoryAccess *) const; + +  // Ranking +  unsigned int getRank(const Value *) const; +  bool shouldSwapOperands(const Value *, const Value *) const; +    // Reachability handling.    void updateReachableEdge(BasicBlock *, BasicBlock *);    void processOutgoingEdges(TerminatorInst *, BasicBlock *); -  bool isOnlyReachableViaThisEdge(const BasicBlockEdge &) const; -  Value *findConditionEquivalence(Value *, BasicBlock *) const; -  MemoryAccess *lookupMemoryAccessEquiv(MemoryAccess *) const; +  Value *findConditionEquivalence(Value *) const;    // Elimination.    
struct ValueDFS; -  void convertDenseToDFSOrdered(CongruenceClass::MemberSet &, -                                SmallVectorImpl<ValueDFS> &); +  void convertClassToDFSOrdered(const CongruenceClass &, +                                SmallVectorImpl<ValueDFS> &, +                                DenseMap<const Value *, unsigned int> &, +                                SmallPtrSetImpl<Instruction *> &) const; +  void convertClassToLoadsAndStores(const CongruenceClass &, +                                    SmallVectorImpl<ValueDFS> &) const;    bool eliminateInstructions(Function &);    void replaceInstruction(Instruction *, Value *); @@ -355,35 +611,58 @@ private:    // Various instruction touch utilities    void markUsersTouched(Value *); -  void markMemoryUsersTouched(MemoryAccess *); -  void markLeaderChangeTouched(CongruenceClass *CC); +  void markMemoryUsersTouched(const MemoryAccess *); +  void markMemoryDefTouched(const MemoryAccess *); +  void markPredicateUsersTouched(Instruction *); +  void markValueLeaderChangeTouched(CongruenceClass *CC); +  void markMemoryLeaderChangeTouched(CongruenceClass *CC); +  void addPredicateUsers(const PredicateBase *, Instruction *); +  void addMemoryUsers(const MemoryAccess *To, MemoryAccess *U); + +  // Main loop of value numbering +  void iterateTouchedInstructions();    // Utilities.    void cleanupTables();    std::pair<unsigned, unsigned> assignDFSNumbers(BasicBlock *, unsigned);    void updateProcessedCount(Value *V);    void verifyMemoryCongruency() const; +  void verifyIterationSettled(Function &F);    bool singleReachablePHIPath(const MemoryAccess *, const MemoryAccess *) const; -}; - -char NewGVN::ID = 0; +  BasicBlock *getBlockForValue(Value *V) const; +  void deleteExpression(const Expression *E); +  unsigned InstrToDFSNum(const Value *V) const { +    assert(isa<Instruction>(V) && "This should not be used for MemoryAccesses"); +    return InstrDFS.lookup(V); +  } -// createGVNPass - The public interface to this file. -FunctionPass *llvm::createNewGVNPass() { return new NewGVN(); } +  unsigned InstrToDFSNum(const MemoryAccess *MA) const { +    return MemoryToDFSNum(MA); +  } +  Value *InstrFromDFSNum(unsigned DFSNum) { return DFSToInstr[DFSNum]; } +  // Given a MemoryAccess, return the relevant instruction DFS number.  Note: +  // This deliberately takes a value so it can be used with Use's, which will +  // auto-convert to Value's but not to MemoryAccess's. +  unsigned MemoryToDFSNum(const Value *MA) const { +    assert(isa<MemoryAccess>(MA) && +           "This should not be used with instructions"); +    return isa<MemoryUseOrDef>(MA) +               ? InstrToDFSNum(cast<MemoryUseOrDef>(MA)->getMemoryInst()) +               : InstrDFS.lookup(MA); +  } +  bool isCycleFree(const PHINode *PN); +  template <class T, class Range> T *getMinDFSOfRange(const Range &) const; +  // Debug counter info.  When verifying, we have to reset the value numbering +  // debug counter to the same state it started in to get the same results. 
+  std::pair<int, int> StartingVNCounter; +}; +} // end anonymous namespace  template <typename T>  static bool equalsLoadStoreHelper(const T &LHS, const Expression &RHS) { -  if ((!isa<LoadExpression>(RHS) && !isa<StoreExpression>(RHS)) || -      !LHS.BasicExpression::equals(RHS)) { +  if (!isa<LoadExpression>(RHS) && !isa<StoreExpression>(RHS))      return false; -  } else if (const auto *L = dyn_cast<LoadExpression>(&RHS)) { -    if (LHS.getDefiningAccess() != L->getDefiningAccess()) -      return false; -  } else if (const auto *S = dyn_cast<StoreExpression>(&RHS)) { -    if (LHS.getDefiningAccess() != S->getDefiningAccess()) -      return false; -  } -  return true; +  return LHS.MemoryExpression::equals(RHS);  }  bool LoadExpression::equals(const Expression &Other) const { @@ -391,7 +670,13 @@ bool LoadExpression::equals(const Expression &Other) const {  }  bool StoreExpression::equals(const Expression &Other) const { -  return equalsLoadStoreHelper(*this, Other); +  if (!equalsLoadStoreHelper(*this, Other)) +    return false; +  // Make sure that store vs store includes the value operand. +  if (const auto *S = dyn_cast<StoreExpression>(&Other)) +    if (getStoredValue() != S->getStoredValue()) +      return false; +  return true;  }  #ifndef NDEBUG @@ -400,16 +685,28 @@ static std::string getBlockName(const BasicBlock *B) {  }  #endif -INITIALIZE_PASS_BEGIN(NewGVN, "newgvn", "Global Value Numbering", false, false) -INITIALIZE_PASS_DEPENDENCY(AssumptionCacheTracker) -INITIALIZE_PASS_DEPENDENCY(MemorySSAWrapperPass) -INITIALIZE_PASS_DEPENDENCY(DominatorTreeWrapperPass) -INITIALIZE_PASS_DEPENDENCY(TargetLibraryInfoWrapperPass) -INITIALIZE_PASS_DEPENDENCY(AAResultsWrapperPass) -INITIALIZE_PASS_DEPENDENCY(GlobalsAAWrapperPass) -INITIALIZE_PASS_END(NewGVN, "newgvn", "Global Value Numbering", false, false) +// Get the basic block from an instruction/memory value. +BasicBlock *NewGVN::getBlockForValue(Value *V) const { +  if (auto *I = dyn_cast<Instruction>(V)) +    return I->getParent(); +  else if (auto *MP = dyn_cast<MemoryPhi>(V)) +    return MP->getBlock(); +  llvm_unreachable("Should have been able to figure out a block for our value"); +  return nullptr; +} -PHIExpression *NewGVN::createPHIExpression(Instruction *I) { +// Delete a definitely dead expression, so it can be reused by the expression +// allocator.  Some of these are not in creation functions, so we have to accept +// const versions. +void NewGVN::deleteExpression(const Expression *E) { +  assert(isa<BasicExpression>(E)); +  auto *BE = cast<BasicExpression>(E); +  const_cast<BasicExpression *>(BE)->deallocateOperands(ArgRecycler); +  ExpressionAllocator.Deallocate(E); +} + +PHIExpression *NewGVN::createPHIExpression(Instruction *I, bool &HasBackedge, +                                           bool &AllConstant) {    BasicBlock *PHIBlock = I->getParent();    auto *PN = cast<PHINode>(I);    auto *E = @@ -419,28 +716,32 @@ PHIExpression *NewGVN::createPHIExpression(Instruction *I) {    E->setType(I->getType());    E->setOpcode(I->getOpcode()); -  auto ReachablePhiArg = [&](const Use &U) { -    return ReachableBlocks.count(PN->getIncomingBlock(U)); -  }; +  unsigned PHIRPO = RPOOrdering.lookup(DT->getNode(PHIBlock)); -  // Filter out unreachable operands -  auto Filtered = make_filter_range(PN->operands(), ReachablePhiArg); +  // Filter out unreachable phi operands. 
+  auto Filtered = make_filter_range(PN->operands(), [&](const Use &U) { +    return ReachableEdges.count({PN->getIncomingBlock(U), PHIBlock}); +  });    std::transform(Filtered.begin(), Filtered.end(), op_inserter(E),                   [&](const Use &U) -> Value * { +                   auto *BB = PN->getIncomingBlock(U); +                   auto *DTN = DT->getNode(BB); +                   if (RPOOrdering.lookup(DTN) >= PHIRPO) +                     HasBackedge = true; +                   AllConstant &= isa<UndefValue>(U) || isa<Constant>(U); +                     // Don't try to transform self-defined phis.                     if (U == PN)                       return PN; -                   const BasicBlockEdge BBE(PN->getIncomingBlock(U), PHIBlock); -                   return lookupOperandLeader(U, I, BBE); +                   return lookupOperandLeader(U);                   });    return E;  }  // Set basic expression info (Arguments, type, opcode) for Expression  // E from Instruction I in block B. -bool NewGVN::setBasicExpressionInfo(Instruction *I, BasicExpression *E, -                                    const BasicBlock *B) { +bool NewGVN::setBasicExpressionInfo(Instruction *I, BasicExpression *E) {    bool AllConstant = true;    if (auto *GEP = dyn_cast<GetElementPtrInst>(I))      E->setType(GEP->getSourceElementType()); @@ -452,7 +753,7 @@ bool NewGVN::setBasicExpressionInfo(Instruction *I, BasicExpression *E,    // Transform the operand array into an operand leader array, and keep track of    // whether all members are constant.    std::transform(I->op_begin(), I->op_end(), op_inserter(E), [&](Value *O) { -    auto Operand = lookupOperandLeader(O, I, B); +    auto Operand = lookupOperandLeader(O);      AllConstant &= isa<Constant>(Operand);      return Operand;    }); @@ -461,8 +762,7 @@ bool NewGVN::setBasicExpressionInfo(Instruction *I, BasicExpression *E,  }  const Expression *NewGVN::createBinaryExpression(unsigned Opcode, Type *T, -                                                 Value *Arg1, Value *Arg2, -                                                 const BasicBlock *B) { +                                                 Value *Arg1, Value *Arg2) {    auto *E = new (ExpressionAllocator) BasicExpression(2);    E->setType(T); @@ -473,13 +773,13 @@ const Expression *NewGVN::createBinaryExpression(unsigned Opcode, Type *T,      // of their operands get the same value number by sorting the operand value      // numbers.  Since all commutative instructions have two operands it is more      // efficient to sort by hand rather than using, say, std::sort. 
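
The comment above explains why commutative operands are put into a canonical order before the expression is hashed; the hunk that follows does it with shouldSwapOperands, which the patch bases on a per-value rank (getRank, declared earlier) rather than the old raw pointer comparison. A standalone toy of the same idea, with the rank reduced to a plain integer (ToyOperand and commutativeKey are illustrative names):

#include <tuple>
#include <utility>

// The rank only has to be deterministic; the patch computes its own rank from
// the IR, here it is simply an integer carried with each operand.
struct ToyOperand {
  unsigned ValueNumber;
  unsigned Rank;
};

std::tuple<unsigned, unsigned, unsigned>
commutativeKey(unsigned Opcode, ToyOperand A, ToyOperand B) {
  if (B.Rank < A.Rank) // plays the role of shouldSwapOperands(A, B)
    std::swap(A, B);
  // "a + b" and "b + a" now produce the same key, hence the same value number.
  return std::make_tuple(Opcode, A.ValueNumber, B.ValueNumber);
}
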
-    if (Arg1 > Arg2) +    if (shouldSwapOperands(Arg1, Arg2))        std::swap(Arg1, Arg2);    } -  E->op_push_back(lookupOperandLeader(Arg1, nullptr, B)); -  E->op_push_back(lookupOperandLeader(Arg2, nullptr, B)); +  E->op_push_back(lookupOperandLeader(Arg1)); +  E->op_push_back(lookupOperandLeader(Arg2)); -  Value *V = SimplifyBinOp(Opcode, E->getOperand(0), E->getOperand(1), *DL, TLI, +  Value *V = SimplifyBinOp(Opcode, E->getOperand(0), E->getOperand(1), DL, TLI,                             DT, AC);    if (const Expression *SimplifiedE = checkSimplificationResults(E, nullptr, V))      return SimplifiedE; @@ -502,40 +802,32 @@ const Expression *NewGVN::checkSimplificationResults(Expression *E,      NumGVNOpsSimplified++;      assert(isa<BasicExpression>(E) &&             "We should always have had a basic expression here"); - -    cast<BasicExpression>(E)->deallocateOperands(ArgRecycler); -    ExpressionAllocator.Deallocate(E); +    deleteExpression(E);      return createConstantExpression(C);    } else if (isa<Argument>(V) || isa<GlobalVariable>(V)) {      if (I)        DEBUG(dbgs() << "Simplified " << *I << " to "                     << " variable " << *V << "\n"); -    cast<BasicExpression>(E)->deallocateOperands(ArgRecycler); -    ExpressionAllocator.Deallocate(E); +    deleteExpression(E);      return createVariableExpression(V);    }    CongruenceClass *CC = ValueToClass.lookup(V); -  if (CC && CC->DefiningExpr) { +  if (CC && CC->getDefiningExpr()) {      if (I)        DEBUG(dbgs() << "Simplified " << *I << " to "                     << " expression " << *V << "\n");      NumGVNOpsSimplified++; -    assert(isa<BasicExpression>(E) && -           "We should always have had a basic expression here"); -    cast<BasicExpression>(E)->deallocateOperands(ArgRecycler); -    ExpressionAllocator.Deallocate(E); -    return CC->DefiningExpr; +    deleteExpression(E); +    return CC->getDefiningExpr();    }    return nullptr;  } -const Expression *NewGVN::createExpression(Instruction *I, -                                           const BasicBlock *B) { - +const Expression *NewGVN::createExpression(Instruction *I) {    auto *E = new (ExpressionAllocator) BasicExpression(I->getNumOperands()); -  bool AllConstant = setBasicExpressionInfo(I, E, B); +  bool AllConstant = setBasicExpressionInfo(I, E);    if (I->isCommutative()) {      // Ensure that commutative instructions that only differ by a permutation @@ -543,7 +835,7 @@ const Expression *NewGVN::createExpression(Instruction *I,      // numbers.  Since all commutative instructions have two operands it is more      // efficient to sort by hand rather than using, say, std::sort.      assert(I->getNumOperands() == 2 && "Unsupported commutative instruction!"); -    if (E->getOperand(0) > E->getOperand(1)) +    if (shouldSwapOperands(E->getOperand(0), E->getOperand(1)))        E->swapOperands(0, 1);    } @@ -559,48 +851,43 @@ const Expression *NewGVN::createExpression(Instruction *I,      // Sort the operand value numbers so x<y and y>x get the same value      // number.      
CmpInst::Predicate Predicate = CI->getPredicate(); -    if (E->getOperand(0) > E->getOperand(1)) { +    if (shouldSwapOperands(E->getOperand(0), E->getOperand(1))) {        E->swapOperands(0, 1);        Predicate = CmpInst::getSwappedPredicate(Predicate);      }      E->setOpcode((CI->getOpcode() << 8) | Predicate);      // TODO: 25% of our time is spent in SimplifyCmpInst with pointer operands -    // TODO: Since we noop bitcasts, we may need to check types before -    // simplifying, so that we don't end up simplifying based on a wrong -    // type assumption. We should clean this up so we can use constants of the -    // wrong type -      assert(I->getOperand(0)->getType() == I->getOperand(1)->getType() &&             "Wrong types on cmp instruction"); -    if ((E->getOperand(0)->getType() == I->getOperand(0)->getType() && -         E->getOperand(1)->getType() == I->getOperand(1)->getType())) { -      Value *V = SimplifyCmpInst(Predicate, E->getOperand(0), E->getOperand(1), -                                 *DL, TLI, DT, AC); -      if (const Expression *SimplifiedE = checkSimplificationResults(E, I, V)) -        return SimplifiedE; -    } +    assert((E->getOperand(0)->getType() == I->getOperand(0)->getType() && +            E->getOperand(1)->getType() == I->getOperand(1)->getType())); +    Value *V = SimplifyCmpInst(Predicate, E->getOperand(0), E->getOperand(1), +                               DL, TLI, DT, AC); +    if (const Expression *SimplifiedE = checkSimplificationResults(E, I, V)) +      return SimplifiedE;    } else if (isa<SelectInst>(I)) {      if (isa<Constant>(E->getOperand(0)) || -        (E->getOperand(1)->getType() == I->getOperand(1)->getType() && -         E->getOperand(2)->getType() == I->getOperand(2)->getType())) { +        E->getOperand(0) == E->getOperand(1)) { +      assert(E->getOperand(1)->getType() == I->getOperand(1)->getType() && +             E->getOperand(2)->getType() == I->getOperand(2)->getType());        Value *V = SimplifySelectInst(E->getOperand(0), E->getOperand(1), -                                    E->getOperand(2), *DL, TLI, DT, AC); +                                    E->getOperand(2), DL, TLI, DT, AC);        if (const Expression *SimplifiedE = checkSimplificationResults(E, I, V))          return SimplifiedE;      }    } else if (I->isBinaryOp()) {      Value *V = SimplifyBinOp(E->getOpcode(), E->getOperand(0), E->getOperand(1), -                             *DL, TLI, DT, AC); +                             DL, TLI, DT, AC);      if (const Expression *SimplifiedE = checkSimplificationResults(E, I, V))        return SimplifiedE;    } else if (auto *BI = dyn_cast<BitCastInst>(I)) { -    Value *V = SimplifyInstruction(BI, *DL, TLI, DT, AC); +    Value *V = SimplifyInstruction(BI, DL, TLI, DT, AC);      if (const Expression *SimplifiedE = checkSimplificationResults(E, I, V))        return SimplifiedE;    } else if (isa<GetElementPtrInst>(I)) {      Value *V = SimplifyGEPInst(E->getType(),                                 ArrayRef<Value *>(E->op_begin(), E->op_end()), -                               *DL, TLI, DT, AC); +                               DL, TLI, DT, AC);      if (const Expression *SimplifiedE = checkSimplificationResults(E, I, V))        return SimplifiedE;    } else if (AllConstant) { @@ -615,7 +902,7 @@ const Expression *NewGVN::createExpression(Instruction *I,      for (Value *Arg : E->operands())        C.emplace_back(cast<Constant>(Arg)); -    if (Value *V = ConstantFoldInstOperands(I, C, *DL, TLI)) +    if (Value *V = 
ConstantFoldInstOperands(I, C, DL, TLI))        if (const Expression *SimplifiedE = checkSimplificationResults(E, I, V))          return SimplifiedE;    } @@ -623,18 +910,18 @@ const Expression *NewGVN::createExpression(Instruction *I,  }  const AggregateValueExpression * -NewGVN::createAggregateValueExpression(Instruction *I, const BasicBlock *B) { +NewGVN::createAggregateValueExpression(Instruction *I) {    if (auto *II = dyn_cast<InsertValueInst>(I)) {      auto *E = new (ExpressionAllocator)          AggregateValueExpression(I->getNumOperands(), II->getNumIndices()); -    setBasicExpressionInfo(I, E, B); +    setBasicExpressionInfo(I, E);      E->allocateIntOperands(ExpressionAllocator);      std::copy(II->idx_begin(), II->idx_end(), int_op_inserter(E));      return E;    } else if (auto *EI = dyn_cast<ExtractValueInst>(I)) {      auto *E = new (ExpressionAllocator)          AggregateValueExpression(I->getNumOperands(), EI->getNumIndices()); -    setBasicExpressionInfo(EI, E, B); +    setBasicExpressionInfo(EI, E);      E->allocateIntOperands(ExpressionAllocator);      std::copy(EI->idx_begin(), EI->idx_end(), int_op_inserter(E));      return E; @@ -648,12 +935,10 @@ const VariableExpression *NewGVN::createVariableExpression(Value *V) {    return E;  } -const Expression *NewGVN::createVariableOrConstant(Value *V, -                                                   const BasicBlock *B) { -  auto Leader = lookupOperandLeader(V, nullptr, B); -  if (auto *C = dyn_cast<Constant>(Leader)) +const Expression *NewGVN::createVariableOrConstant(Value *V) { +  if (auto *C = dyn_cast<Constant>(V))      return createConstantExpression(C); -  return createVariableExpression(Leader); +  return createVariableExpression(V);  }  const ConstantExpression *NewGVN::createConstantExpression(Constant *C) { @@ -669,40 +954,90 @@ const UnknownExpression *NewGVN::createUnknownExpression(Instruction *I) {  }  const CallExpression *NewGVN::createCallExpression(CallInst *CI, -                                                   MemoryAccess *HV, -                                                   const BasicBlock *B) { +                                                   const MemoryAccess *MA) {    // FIXME: Add operand bundles for calls.    auto *E = -      new (ExpressionAllocator) CallExpression(CI->getNumOperands(), CI, HV); -  setBasicExpressionInfo(CI, E, B); +      new (ExpressionAllocator) CallExpression(CI->getNumOperands(), CI, MA); +  setBasicExpressionInfo(CI, E);    return E;  } +// Return true if some equivalent of instruction Inst dominates instruction U. +bool NewGVN::someEquivalentDominates(const Instruction *Inst, +                                     const Instruction *U) const { +  auto *CC = ValueToClass.lookup(Inst); +  // This must be an instruction because we are only called from phi nodes +  // in the case that the value it needs to check against is an instruction. + +  // The most likely candiates for dominance are the leader and the next leader. +  // The leader or nextleader will dominate in all cases where there is an +  // equivalent that is higher up in the dom tree. +  // We can't *only* check them, however, because the +  // dominator tree could have an infinite number of non-dominating siblings +  // with instructions that are in the right congruence class. +  //       A +  // B C D E F G +  // | +  // H +  // Instruction U could be in H,  with equivalents in every other sibling. 
+  // Depending on the rpo order picked, the leader could be the equivalent in
+  // any of these siblings.
+  if (!CC)
+    return false;
+  if (DT->dominates(cast<Instruction>(CC->getLeader()), U))
+    return true;
+  if (CC->getNextLeader().first &&
+      DT->dominates(cast<Instruction>(CC->getNextLeader().first), U))
+    return true;
+  return llvm::any_of(*CC, [&](const Value *Member) {
+    return Member != CC->getLeader() &&
+           DT->dominates(cast<Instruction>(Member), U);
+  });
+}
+
 // See if we have a congruence class and leader for this operand, and if so,
 // return it. Otherwise, return the operand itself.
-template <class T>
-Value *NewGVN::lookupOperandLeader(Value *V, const User *U, const T &B) const {
+Value *NewGVN::lookupOperandLeader(Value *V) const {
   CongruenceClass *CC = ValueToClass.lookup(V);
-  if (CC && (CC != InitialClass))
-    return CC->RepLeader;
+  if (CC) {
+    // Everything in TOP is represented by undef, as it can be any value.
+    // We do have to make sure we get the type right though, so we can't set the
+    // RepLeader to undef.
+    if (CC == TOPClass)
+      return UndefValue::get(V->getType());
+    return CC->getStoredValue() ? CC->getStoredValue() : CC->getLeader();
+  }
+
   return V;
 }
 
-MemoryAccess *NewGVN::lookupMemoryAccessEquiv(MemoryAccess *MA) const {
-  MemoryAccess *Result = MemoryAccessEquiv.lookup(MA);
-  return Result ? Result : MA;
+const MemoryAccess *NewGVN::lookupMemoryLeader(const MemoryAccess *MA) const {
+  auto *CC = getMemoryClass(MA);
+  assert(CC->getMemoryLeader() &&
+         "Every MemoryAccess should be mapped to a "
+         "congruence class with a representative memory "
+         "access");
+  return CC->getMemoryLeader();
+}
+
+// Return true if the MemoryAccess is really equivalent to everything. This is
+// equivalent to the lattice value "TOP" in most lattices.  This is the initial
+// state of all MemoryAccesses.
+bool NewGVN::isMemoryAccessTop(const MemoryAccess *MA) const {
+  return getMemoryClass(MA) == TOPClass;
 }
 
 LoadExpression *NewGVN::createLoadExpression(Type *LoadType, Value *PointerOp,
-                                             LoadInst *LI, MemoryAccess *DA,
-                                             const BasicBlock *B) {
-  auto *E = new (ExpressionAllocator) LoadExpression(1, LI, DA);
+                                             LoadInst *LI,
+                                             const MemoryAccess *MA) {
+  auto *E =
+      new (ExpressionAllocator) LoadExpression(1, LI, lookupMemoryLeader(MA));
   E->allocateOperands(ArgRecycler, ExpressionAllocator);
   E->setType(LoadType);
 
   // Give store and loads same opcode so they value number together.
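A minimal standalone sketch of the fallback order used by someEquivalentDominates above: try the class leader, then the next leader, then scan the remaining members. The Node/CongClass types and the DFS-interval dominance test are hypothetical stand-ins, not LLVM API:

    #include <algorithm>
    #include <vector>

    struct Node {
      unsigned DFSIn, DFSOut; // dominator-tree DFS interval
    };

    // A dominates B iff A's interval encloses B's.
    static bool dominates(const Node &A, const Node &B) {
      return A.DFSIn <= B.DFSIn && B.DFSOut <= A.DFSOut;
    }

    struct CongClass {
      Node *Leader = nullptr;
      Node *NextLeader = nullptr;
      std::vector<Node *> Members;
    };

    // Check the cheap candidates first, then fall back to scanning the class.
    static bool someEquivalentDominates(const CongClass &CC, const Node &U) {
      if (CC.Leader && dominates(*CC.Leader, U))
        return true;
      if (CC.NextLeader && dominates(*CC.NextLeader, U))
        return true;
      return std::any_of(CC.Members.begin(), CC.Members.end(),
                         [&](const Node *M) {
                           return M != CC.Leader && dominates(*M, U);
                         });
    }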
E->setOpcode(0); -  E->op_push_back(lookupOperandLeader(PointerOp, LI, B)); +  E->op_push_back(PointerOp);    if (LI)      E->setAlignment(LI->getAlignment()); @@ -713,16 +1048,16 @@ LoadExpression *NewGVN::createLoadExpression(Type *LoadType, Value *PointerOp,  }  const StoreExpression *NewGVN::createStoreExpression(StoreInst *SI, -                                                     MemoryAccess *DA, -                                                     const BasicBlock *B) { -  auto *E = -      new (ExpressionAllocator) StoreExpression(SI->getNumOperands(), SI, DA); +                                                     const MemoryAccess *MA) { +  auto *StoredValueLeader = lookupOperandLeader(SI->getValueOperand()); +  auto *E = new (ExpressionAllocator) +      StoreExpression(SI->getNumOperands(), SI, StoredValueLeader, MA);    E->allocateOperands(ArgRecycler, ExpressionAllocator);    E->setType(SI->getValueOperand()->getType());    // Give store and loads same opcode so they value number together.    E->setOpcode(0); -  E->op_push_back(lookupOperandLeader(SI->getPointerOperand(), SI, B)); +  E->op_push_back(lookupOperandLeader(SI->getPointerOperand()));    // TODO: Value number heap versions. We may be able to discover    // things alias analysis can't on it's own (IE that a store and a @@ -730,44 +1065,140 @@ const StoreExpression *NewGVN::createStoreExpression(StoreInst *SI,    return E;  } -// Utility function to check whether the congruence class has a member other -// than the given instruction. -bool hasMemberOtherThanUs(const CongruenceClass *CC, Instruction *I) { -  // Either it has more than one store, in which case it must contain something -  // other than us (because it's indexed by value), or if it only has one store -  // right now, that member should not be us. -  return CC->StoreCount > 1 || CC->Members.count(I) == 0; -} - -const Expression *NewGVN::performSymbolicStoreEvaluation(Instruction *I, -                                                         const BasicBlock *B) { +const Expression *NewGVN::performSymbolicStoreEvaluation(Instruction *I) {    // Unlike loads, we never try to eliminate stores, so we do not check if they    // are simple and avoid value numbering them.    auto *SI = cast<StoreInst>(I); -  MemoryAccess *StoreAccess = MSSA->getMemoryAccess(SI); -  // See if we are defined by a previous store expression, it already has a -  // value, and it's the same value as our current store. FIXME: Right now, we -  // only do this for simple stores, we should expand to cover memcpys, etc. +  auto *StoreAccess = MSSA->getMemoryAccess(SI); +  // Get the expression, if any, for the RHS of the MemoryDef. +  const MemoryAccess *StoreRHS = StoreAccess->getDefiningAccess(); +  if (EnableStoreRefinement) +    StoreRHS = MSSAWalker->getClobberingMemoryAccess(StoreAccess); +  // If we bypassed the use-def chains, make sure we add a use. +  if (StoreRHS != StoreAccess->getDefiningAccess()) +    addMemoryUsers(StoreRHS, StoreAccess); + +  StoreRHS = lookupMemoryLeader(StoreRHS); +  // If we are defined by ourselves, use the live on entry def. +  if (StoreRHS == StoreAccess) +    StoreRHS = MSSA->getLiveOnEntryDef(); +    if (SI->isSimple()) { -    // Get the expression, if any, for the RHS of the MemoryDef. 
-    MemoryAccess *StoreRHS = lookupMemoryAccessEquiv(
-        cast<MemoryDef>(StoreAccess)->getDefiningAccess());
-    const Expression *OldStore = createStoreExpression(SI, StoreRHS, B);
-    CongruenceClass *CC = ExpressionToClass.lookup(OldStore);
+    // See if we are defined by a previous store expression, it already has a
+    // value, and it's the same value as our current store. FIXME: Right now, we
+    // only do this for simple stores, we should expand to cover memcpys, etc.
+    const auto *LastStore = createStoreExpression(SI, StoreRHS);
+    const auto *LastCC = ExpressionToClass.lookup(LastStore);
     // Basically, check if the congruence class the store is in is defined by a
     // store that isn't us, and has the same value.  MemorySSA takes care of
     // ensuring the store has the same memory state as us already.
-    if (CC && CC->DefiningExpr && isa<StoreExpression>(CC->DefiningExpr) &&
-        CC->RepLeader == lookupOperandLeader(SI->getValueOperand(), SI, B) &&
-        hasMemberOtherThanUs(CC, I))
-      return createStoreExpression(SI, StoreRHS, B);
+    // The RepStoredValue gets nulled if all the stores disappear in a class, so
+    // we don't need to check if the class contains a store besides us.
+    if (LastCC &&
+        LastCC->getStoredValue() == lookupOperandLeader(SI->getValueOperand()))
+      return LastStore;
+    deleteExpression(LastStore);
+    // Also check if our value operand is defined by a load of the same memory
+    // location, and the memory state is the same as it was then (otherwise, it
+    // could have been overwritten later. See test32 in
+    // transforms/DeadStoreElimination/simple.ll).
+    if (auto *LI =
+            dyn_cast<LoadInst>(lookupOperandLeader(SI->getValueOperand()))) {
+      if ((lookupOperandLeader(LI->getPointerOperand()) ==
+           lookupOperandLeader(SI->getPointerOperand())) &&
+          (lookupMemoryLeader(MSSA->getMemoryAccess(LI)->getDefiningAccess()) ==
+           StoreRHS))
+        return createVariableExpression(LI);
+    }
   }
-  return createStoreExpression(SI, StoreAccess, B);
+  // If the store is not equivalent to anything, value number it as a store that
+  // produces a unique memory state (instead of using its MemoryUse, we use
+  // its MemoryDef).
+  return createStoreExpression(SI, StoreAccess);
 }
 
-const Expression *NewGVN::performSymbolicLoadEvaluation(Instruction *I,
-                                                        const BasicBlock *B) {
+// See if we can extract the value of a loaded pointer from a load, a store, or
+// a memory instruction.
+const Expression *
+NewGVN::performSymbolicLoadCoercion(Type *LoadType, Value *LoadPtr,
+                                    LoadInst *LI, Instruction *DepInst,
+                                    MemoryAccess *DefiningAccess) {
+  assert((!LI || LI->isSimple()) && "Not a simple load");
+  if (auto *DepSI = dyn_cast<StoreInst>(DepInst)) {
+    // Can't forward from non-atomic to atomic without violating memory model.
+    // Also don't need to coerce if they are the same type, we will just
+    // propagate.
+    if (LI->isAtomic() > DepSI->isAtomic() || +        LoadType == DepSI->getValueOperand()->getType()) +      return nullptr; +    int Offset = analyzeLoadFromClobberingStore(LoadType, LoadPtr, DepSI, DL); +    if (Offset >= 0) { +      if (auto *C = dyn_cast<Constant>( +              lookupOperandLeader(DepSI->getValueOperand()))) { +        DEBUG(dbgs() << "Coercing load from store " << *DepSI << " to constant " +                     << *C << "\n"); +        return createConstantExpression( +            getConstantStoreValueForLoad(C, Offset, LoadType, DL)); +      } +    } + +  } else if (LoadInst *DepLI = dyn_cast<LoadInst>(DepInst)) { +    // Can't forward from non-atomic to atomic without violating memory model. +    if (LI->isAtomic() > DepLI->isAtomic()) +      return nullptr; +    int Offset = analyzeLoadFromClobberingLoad(LoadType, LoadPtr, DepLI, DL); +    if (Offset >= 0) { +      // We can coerce a constant load into a load +      if (auto *C = dyn_cast<Constant>(lookupOperandLeader(DepLI))) +        if (auto *PossibleConstant = +                getConstantLoadValueForLoad(C, Offset, LoadType, DL)) { +          DEBUG(dbgs() << "Coercing load from load " << *LI << " to constant " +                       << *PossibleConstant << "\n"); +          return createConstantExpression(PossibleConstant); +        } +    } + +  } else if (MemIntrinsic *DepMI = dyn_cast<MemIntrinsic>(DepInst)) { +    int Offset = analyzeLoadFromClobberingMemInst(LoadType, LoadPtr, DepMI, DL); +    if (Offset >= 0) { +      if (auto *PossibleConstant = +              getConstantMemInstValueForLoad(DepMI, Offset, LoadType, DL)) { +        DEBUG(dbgs() << "Coercing load from meminst " << *DepMI +                     << " to constant " << *PossibleConstant << "\n"); +        return createConstantExpression(PossibleConstant); +      } +    } +  } + +  // All of the below are only true if the loaded pointer is produced +  // by the dependent instruction. +  if (LoadPtr != lookupOperandLeader(DepInst) && +      !AA->isMustAlias(LoadPtr, DepInst)) +    return nullptr; +  // If this load really doesn't depend on anything, then we must be loading an +  // undef value.  This can happen when loading for a fresh allocation with no +  // intervening stores, for example.  Note that this is only true in the case +  // that the result of the allocation is pointer equal to the load ptr. +  if (isa<AllocaInst>(DepInst) || isMallocLikeFn(DepInst, TLI)) { +    return createConstantExpression(UndefValue::get(LoadType)); +  } +  // If this load occurs either right after a lifetime begin, +  // then the loaded value is undefined. 
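The getConstant*ValueForLoad helpers used above reduce to reading a narrower constant out of the bytes of a wider constant at a byte offset. A standalone sketch of just that byte-extraction step, ignoring endianness concerns and non-integer types; the function name is made up:

    #include <cstdint>
    #include <cstring>
    #include <iostream>

    // Reads sizeof(To) bytes starting at byte Offset of Stored.  The caller is
    // responsible for ensuring Offset + sizeof(To) <= sizeof(From).
    template <typename To, typename From>
    static To readConstantAtOffset(From Stored, unsigned Offset) {
      To Result;
      std::memcpy(&Result, reinterpret_cast<const char *>(&Stored) + Offset,
                  sizeof(To));
      return Result;
    }

    int main() {
      std::uint32_t Stored = 0x11223344; // the "clobbering" stored constant
      // A 1-byte load at offset 2 into the 4-byte store; little-endian hosts
      // print 22.
      std::cout << std::hex
                << +readConstantAtOffset<std::uint8_t>(Stored, 2) << '\n';
    }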
+  else if (auto *II = dyn_cast<IntrinsicInst>(DepInst)) {
+    if (II->getIntrinsicID() == Intrinsic::lifetime_start)
+      return createConstantExpression(UndefValue::get(LoadType));
+  }
+  // If this load follows a calloc (which zero initializes memory),
+  // then the loaded value is zero.
+  else if (isCallocLikeFn(DepInst, TLI)) {
+    return createConstantExpression(Constant::getNullValue(LoadType));
+  }
+
+  return nullptr;
+}
+
+const Expression *NewGVN::performSymbolicLoadEvaluation(Instruction *I) {
   auto *LI = cast<LoadInst>(I);
 
   // We can eliminate in favor of non-simple loads, but we won't be able to
@@ -775,7 +1206,7 @@ const Expression *NewGVN::performSymbolicLoadEvaluation(Instruction *I,
   if (!LI->isSimple())
     return nullptr;
 
-  Value *LoadAddressLeader = lookupOperandLeader(LI->getPointerOperand(), I, B);
+  Value *LoadAddressLeader = lookupOperandLeader(LI->getPointerOperand());
   // Load of undef is undef.
   if (isa<UndefValue>(LoadAddressLeader))
     return createConstantExpression(UndefValue::get(LI->getType()));
@@ -788,61 +1219,233 @@ const Expression *NewGVN::performSymbolicLoadEvaluation(Instruction *I,
       // If the defining instruction is not reachable, replace with undef.
       if (!ReachableBlocks.count(DefiningInst->getParent()))
        return createConstantExpression(UndefValue::get(LI->getType()));
+      // This will handle stores and memory insts.  We only do this if the
+      // defining access has a different type, or it is a pointer produced by
+      // certain memory operations that cause the memory to have a fixed value
+      // (IE things like calloc).
+      if (const auto *CoercionResult =
+              performSymbolicLoadCoercion(LI->getType(), LoadAddressLeader, LI,
+                                          DefiningInst, DefiningAccess))
+        return CoercionResult;
     }
   }
 
-  const Expression *E =
-      createLoadExpression(LI->getType(), LI->getPointerOperand(), LI,
-                           lookupMemoryAccessEquiv(DefiningAccess), B);
+  const Expression *E = createLoadExpression(LI->getType(), LoadAddressLeader,
+                                             LI, DefiningAccess);
   return E;
 }
 
+const Expression *
+NewGVN::performSymbolicPredicateInfoEvaluation(Instruction *I) {
+  auto *PI = PredInfo->getPredicateInfoFor(I);
+  if (!PI)
+    return nullptr;
+
+  DEBUG(dbgs() << "Found predicate info from instruction!\n");
+
+  auto *PWC = dyn_cast<PredicateWithCondition>(PI);
+  if (!PWC)
+    return nullptr;
+
+  auto *CopyOf = I->getOperand(0);
+  auto *Cond = PWC->Condition;
+
+  // If this is a copy of the condition, it must be either true or false
+  // depending on the predicate info type and edge.
+  if (CopyOf == Cond) {
+    // We should not need to add predicate users because the predicate info is
+    // already a use of this operand.
+    if (isa<PredicateAssume>(PI))
+      return createConstantExpression(ConstantInt::getTrue(Cond->getType()));
+    if (auto *PBranch = dyn_cast<PredicateBranch>(PI)) {
+      if (PBranch->TrueEdge)
+        return createConstantExpression(ConstantInt::getTrue(Cond->getType()));
+      return createConstantExpression(ConstantInt::getFalse(Cond->getType()));
+    }
+    if (auto *PSwitch = dyn_cast<PredicateSwitch>(PI))
+      return createConstantExpression(cast<Constant>(PSwitch->CaseValue));
+  }
+
+  // Not a copy of the condition, so see what the predicates tell us about this
+  // value.  First, though, we check to make sure the value is actually a copy
+  // of one of the condition operands. It's possible, in certain cases, for it
+  // to be a copy of a predicateinfo copy. In particular, if two branch
+  // operations use the same condition, and one branch dominates the other, we
+  // will end up with a copy of a copy.  This is currently a small deficiency in
+  // predicateinfo.  What will end up happening here is that we will value
+  // number both copies the same anyway.
+
+  // Everything below relies on the condition being a comparison.
+  auto *Cmp = dyn_cast<CmpInst>(Cond);
+  if (!Cmp)
+    return nullptr;
+
+  if (CopyOf != Cmp->getOperand(0) && CopyOf != Cmp->getOperand(1)) {
+    DEBUG(dbgs() << "Copy is not of any condition operands!");
+    return nullptr;
+  }
+  Value *FirstOp = lookupOperandLeader(Cmp->getOperand(0));
+  Value *SecondOp = lookupOperandLeader(Cmp->getOperand(1));
+  bool SwappedOps = false;
+  // Sort the ops.
+  if (shouldSwapOperands(FirstOp, SecondOp)) {
+    std::swap(FirstOp, SecondOp);
+    SwappedOps = true;
+  }
+  CmpInst::Predicate Predicate =
+      SwappedOps ? Cmp->getSwappedPredicate() : Cmp->getPredicate();
+
+  if (isa<PredicateAssume>(PI)) {
+    // If the comparison is true when the operands are equal, then we know the
+    // operands are equal, because assumes must always be true.
+    if (CmpInst::isTrueWhenEqual(Predicate)) {
+      addPredicateUsers(PI, I);
+      return createVariableOrConstant(FirstOp);
+    }
+  }
+  if (const auto *PBranch = dyn_cast<PredicateBranch>(PI)) {
+    // If we are *not* a copy of the comparison, we may be equal to the other
+    // operand when the predicate implies something about equality of
+    // operations.  In particular, if the comparison is true/false when the
+    // operands are equal, and we are on the right edge, we know this operation
+    // is equal to something.
+    if ((PBranch->TrueEdge && Predicate == CmpInst::ICMP_EQ) ||
+        (!PBranch->TrueEdge && Predicate == CmpInst::ICMP_NE)) {
+      addPredicateUsers(PI, I);
+      return createVariableOrConstant(FirstOp);
+    }
+    // Handle the special case of floating point.
+    if (((PBranch->TrueEdge && Predicate == CmpInst::FCMP_OEQ) ||
+         (!PBranch->TrueEdge && Predicate == CmpInst::FCMP_UNE)) &&
+        isa<ConstantFP>(FirstOp) && !cast<ConstantFP>(FirstOp)->isZero()) {
+      addPredicateUsers(PI, I);
+      return createConstantExpression(cast<Constant>(FirstOp));
+    }
+  }
+  return nullptr;
+}
+
 // Evaluate read only and pure calls, and create an expression result.
-const Expression *NewGVN::performSymbolicCallEvaluation(Instruction *I,
-                                                        const BasicBlock *B) {
+const Expression *NewGVN::performSymbolicCallEvaluation(Instruction *I) {
   auto *CI = cast<CallInst>(I);
-  if (AA->doesNotAccessMemory(CI))
-    return createCallExpression(CI, nullptr, B);
-  if (AA->onlyReadsMemory(CI)) {
+  if (auto *II = dyn_cast<IntrinsicInst>(I)) {
+    // Intrinsics with the returned attribute are copies of arguments.
+    if (auto *ReturnedValue = II->getReturnedArgOperand()) { +      if (II->getIntrinsicID() == Intrinsic::ssa_copy) +        if (const auto *Result = performSymbolicPredicateInfoEvaluation(I)) +          return Result; +      return createVariableOrConstant(ReturnedValue); +    } +  } +  if (AA->doesNotAccessMemory(CI)) { +    return createCallExpression(CI, TOPClass->getMemoryLeader()); +  } else if (AA->onlyReadsMemory(CI)) {      MemoryAccess *DefiningAccess = MSSAWalker->getClobberingMemoryAccess(CI); -    return createCallExpression(CI, lookupMemoryAccessEquiv(DefiningAccess), B); +    return createCallExpression(CI, DefiningAccess);    }    return nullptr;  } -// Update the memory access equivalence table to say that From is equal to To, +// Retrieve the memory class for a given MemoryAccess. +CongruenceClass *NewGVN::getMemoryClass(const MemoryAccess *MA) const { + +  auto *Result = MemoryAccessToClass.lookup(MA); +  assert(Result && "Should have found memory class"); +  return Result; +} + +// Update the MemoryAccess equivalence table to say that From is equal to To,  // and return true if this is different from what already existed in the table. -bool NewGVN::setMemoryAccessEquivTo(MemoryAccess *From, MemoryAccess *To) { -  DEBUG(dbgs() << "Setting " << *From << " equivalent to "); -  if (!To) -    DEBUG(dbgs() << "itself"); -  else -    DEBUG(dbgs() << *To); +bool NewGVN::setMemoryClass(const MemoryAccess *From, +                            CongruenceClass *NewClass) { +  assert(NewClass && +         "Every MemoryAccess should be getting mapped to a non-null class"); +  DEBUG(dbgs() << "Setting " << *From); +  DEBUG(dbgs() << " equivalent to congruence class "); +  DEBUG(dbgs() << NewClass->getID() << " with current MemoryAccess leader "); +  DEBUG(dbgs() << *NewClass->getMemoryLeader());    DEBUG(dbgs() << "\n"); -  auto LookupResult = MemoryAccessEquiv.find(From); + +  auto LookupResult = MemoryAccessToClass.find(From);    bool Changed = false;    // If it's already in the table, see if the value changed. -  if (LookupResult != MemoryAccessEquiv.end()) { -    if (To && LookupResult->second != To) { +  if (LookupResult != MemoryAccessToClass.end()) { +    auto *OldClass = LookupResult->second; +    if (OldClass != NewClass) { +      // If this is a phi, we have to handle memory member updates. +      if (auto *MP = dyn_cast<MemoryPhi>(From)) { +        OldClass->memory_erase(MP); +        NewClass->memory_insert(MP); +        // This may have killed the class if it had no non-memory members +        if (OldClass->getMemoryLeader() == From) { +          if (OldClass->memory_empty()) { +            OldClass->setMemoryLeader(nullptr); +          } else { +            OldClass->setMemoryLeader(getNextMemoryLeader(OldClass)); +            DEBUG(dbgs() << "Memory class leader change for class " +                         << OldClass->getID() << " to " +                         << *OldClass->getMemoryLeader() +                         << " due to removal of a memory member " << *From +                         << "\n"); +            markMemoryLeaderChangeTouched(OldClass); +          } +        } +      }        // It wasn't equivalent before, and now it is. -      LookupResult->second = To; -      Changed = true; -    } else if (!To) { -      // It used to be equivalent to something, and now it's not. 
-      MemoryAccessEquiv.erase(LookupResult);
+      LookupResult->second = NewClass;
       Changed = true;
     }
   }
-  } else {
-    assert(!To &&
-           "Memory equivalence should never change from nothing to something");
   }
   return Changed;
 }
+
+// Determine if a phi is cycle-free.  That means the values in the phi don't
+// depend on any expressions that can change value as a result of the phi.
+// For example, a non-cycle free phi would be v = phi(0, v+1).
+bool NewGVN::isCycleFree(const PHINode *PN) {
+  // In order to compute cycle-freeness, we do SCC finding on the phi, and see
+  // what kind of SCC it ends up in.  If it is a singleton, it is cycle-free.
+  // If it is not in a singleton, it is only cycle free if the other members are
+  // all phi nodes (as they do not compute anything, they are copies).  TODO:
+  // There are likely a few other intrinsics or expressions that could be
+  // included here, but this happens so infrequently already that it is not
+  // likely to be worth it.
+  auto PCS = PhiCycleState.lookup(PN);
+  if (PCS == PCS_Unknown) {
+    SCCFinder.Start(PN);
+    auto &SCC = SCCFinder.getComponentFor(PN);
+    // It's cycle free if its size is 1 or the SCC is *only* phi nodes.
+    if (SCC.size() == 1)
+      PhiCycleState.insert({PN, PCS_CycleFree});
+    else {
+      bool AllPhis =
+          llvm::all_of(SCC, [](const Value *V) { return isa<PHINode>(V); });
+      PCS = AllPhis ? PCS_CycleFree : PCS_Cycle;
+      for (auto *Member : SCC)
+        if (auto *MemberPhi = dyn_cast<PHINode>(Member))
+          PhiCycleState.insert({MemberPhi, PCS});
+    }
+  }
+  if (PCS == PCS_Cycle)
+    return false;
+  return true;
+}
+
 // Evaluate PHI nodes symbolically, and create an expression result.
-const Expression *NewGVN::performSymbolicPHIEvaluation(Instruction *I,
-                                                       const BasicBlock *B) {
-  auto *E = cast<PHIExpression>(createPHIExpression(I));
+const Expression *NewGVN::performSymbolicPHIEvaluation(Instruction *I) {
+  // True if one of the incoming phi edges is a backedge.
+  bool HasBackedge = false;
+  // AllConstant tracks whether all of the *original* phi operands were
+  // constant.  This is really shorthand for "this phi cannot cycle due to
+  // forward propagation", as any change in the value of the phi is guaranteed
+  // not to later change the value of the phi.
+  // IE it can't be v = phi(undef, v+1).
+  bool AllConstant = true;
+  auto *E =
+      cast<PHIExpression>(createPHIExpression(I, HasBackedge, AllConstant));
 
   // We match the semantics of SimplifyPhiNode from InstructionSimplify here.
   // See if all arguments are the same.
@@ -861,14 +1464,15 @@ const Expression *NewGVN::performSymbolicPHIEvaluation(Instruction *I,
   if (Filtered.begin() == Filtered.end()) {
     DEBUG(dbgs() << "Simplified PHI node " << *I << " to undef"
                  << "\n");
-    E->deallocateOperands(ArgRecycler);
-    ExpressionAllocator.Deallocate(E);
+    deleteExpression(E);
     return createConstantExpression(UndefValue::get(I->getType()));
   }
+  unsigned NumOps = 0;
   Value *AllSameValue = *(Filtered.begin());
   ++Filtered.begin();
   // Can't use std::equal here, sadly, because filter.begin moves.
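isCycleFree above uses SCCs to decide whether a phi sits on a value-dependency cycle containing anything other than phis. A simplified standalone model of the same test using plain reachability over a toy dependency graph; the graph encoding and names are hypothetical:

    #include <vector>

    struct ValueNode {
      bool IsPhi = false;
      std::vector<int> Operands; // edges to the values this one depends on
    };

    // Marks every node reachable from Start through at least one edge,
    // following either forward (operand) or reversed edges.
    static std::vector<char> reachable(const std::vector<ValueNode> &G,
                                       int Start, bool Reverse) {
      std::vector<std::vector<int>> Adj(G.size());
      for (int N = 0; N < (int)G.size(); ++N)
        for (int Op : G[N].Operands)
          Reverse ? Adj[Op].push_back(N) : Adj[N].push_back(Op);
      std::vector<char> Seen(G.size(), 0);
      std::vector<int> Stack = {Start};
      while (!Stack.empty()) {
        int N = Stack.back();
        Stack.pop_back();
        for (int M : Adj[N])
          if (!Seen[M]) {
            Seen[M] = 1;
            Stack.push_back(M);
          }
      }
      return Seen;
    }

    // A phi is cycle-free if every node lying on a cycle through it is a phi.
    static bool isCycleFree(const std::vector<ValueNode> &G, int Phi) {
      auto Fwd = reachable(G, Phi, /*Reverse=*/false);
      auto Bwd = reachable(G, Phi, /*Reverse=*/true);
      for (int N = 0; N < (int)G.size(); ++N)
        if (Fwd[N] && Bwd[N] && !G[N].IsPhi)
          return false; // a non-phi computation participates in the cycle
      return true;
    }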
-  if (llvm::all_of(Filtered, [AllSameValue](const Value *V) { +  if (llvm::all_of(Filtered, [AllSameValue, &NumOps](const Value *V) { +        ++NumOps;          return V == AllSameValue;        })) {      // In LLVM's non-standard representation of phi nodes, it's possible to have @@ -881,27 +1485,32 @@ const Expression *NewGVN::performSymbolicPHIEvaluation(Instruction *I,      // We also special case undef, so that if we have an undef, we can't use the      // common value unless it dominates the phi block.      if (HasUndef) { +      // If we have undef and at least one other value, this is really a +      // multivalued phi, and we need to know if it's cycle free in order to +      // evaluate whether we can ignore the undef.  The other parts of this are +      // just shortcuts.  If there is no backedge, or all operands are +      // constants, or all operands are ignored but the undef, it also must be +      // cycle free. +      if (!AllConstant && HasBackedge && NumOps > 0 && +          !isa<UndefValue>(AllSameValue) && !isCycleFree(cast<PHINode>(I))) +        return E; +        // Only have to check for instructions        if (auto *AllSameInst = dyn_cast<Instruction>(AllSameValue)) -        if (!DT->dominates(AllSameInst, I)) +        if (!someEquivalentDominates(AllSameInst, I))            return E;      }      NumGVNPhisAllSame++;      DEBUG(dbgs() << "Simplified PHI node " << *I << " to " << *AllSameValue                   << "\n"); -    E->deallocateOperands(ArgRecycler); -    ExpressionAllocator.Deallocate(E); -    if (auto *C = dyn_cast<Constant>(AllSameValue)) -      return createConstantExpression(C); -    return createVariableExpression(AllSameValue); +    deleteExpression(E); +    return createVariableOrConstant(AllSameValue);    }    return E;  } -const Expression * -NewGVN::performSymbolicAggrValueEvaluation(Instruction *I, -                                           const BasicBlock *B) { +const Expression *NewGVN::performSymbolicAggrValueEvaluation(Instruction *I) {    if (auto *EI = dyn_cast<ExtractValueInst>(I)) {      auto *II = dyn_cast<IntrinsicInst>(EI->getAggregateOperand());      if (II && EI->getNumIndices() == 1 && *EI->idx_begin() == 0) { @@ -931,19 +1540,130 @@ NewGVN::performSymbolicAggrValueEvaluation(Instruction *I,          // expression.          assert(II->getNumArgOperands() == 2 &&                 "Expect two args for recognised intrinsics."); -        return createBinaryExpression(Opcode, EI->getType(), -                                      II->getArgOperand(0), -                                      II->getArgOperand(1), B); +        return createBinaryExpression( +            Opcode, EI->getType(), II->getArgOperand(0), II->getArgOperand(1));        }      }    } -  return createAggregateValueExpression(I, B); +  return createAggregateValueExpression(I); +} +const Expression *NewGVN::performSymbolicCmpEvaluation(Instruction *I) { +  auto *CI = dyn_cast<CmpInst>(I); +  // See if our operands are equal to those of a previous predicate, and if so, +  // if it implies true or false. +  auto Op0 = lookupOperandLeader(CI->getOperand(0)); +  auto Op1 = lookupOperandLeader(CI->getOperand(1)); +  auto OurPredicate = CI->getPredicate(); +  if (shouldSwapOperands(Op0, Op1)) { +    std::swap(Op0, Op1); +    OurPredicate = CI->getSwappedPredicate(); +  } + +  // Avoid processing the same info twice +  const PredicateBase *LastPredInfo = nullptr; +  // See if we know something about the comparison itself, like it is the target +  // of an assume. 
+  auto *CmpPI = PredInfo->getPredicateInfoFor(I);
+  if (dyn_cast_or_null<PredicateAssume>(CmpPI))
+    return createConstantExpression(ConstantInt::getTrue(CI->getType()));
+
+  if (Op0 == Op1) {
+    // This condition does not depend on predicates, no need to add users.
+    if (CI->isTrueWhenEqual())
+      return createConstantExpression(ConstantInt::getTrue(CI->getType()));
+    else if (CI->isFalseWhenEqual())
+      return createConstantExpression(ConstantInt::getFalse(CI->getType()));
+  }
+
+  // NOTE: Because we are comparing both operands here and below, and using
+  // previous comparisons, we rely on the fact that predicateinfo knows to mark
+  // comparisons that use renamed operands as users of the earlier comparisons.
+  // It is *not* enough to just mark predicateinfo renamed operands as users of
+  // the earlier comparisons, because the *other* operand may have changed in a
+  // previous iteration.
+  // Example:
+  // icmp slt %a, %b
+  // %b.0 = ssa.copy(%b)
+  // false branch:
+  // icmp slt %c, %b.0
+
+  // %c and %a may start out equal, and thus, the code below will say the second
+  // icmp is false.  %c may become equal to something else, and in that case the
+  // second icmp *must* be reexamined, but would not be if only the renamed
+  // operands are considered users of the icmp.
+
+  // *Currently* we only check one level of comparisons back, and only mark one
+  // level back as touched when changes happen.  If you modify this code to look
+  // back farther through comparisons, you *must* mark the appropriate
+  // comparisons as users in PredicateInfo.cpp, or you will cause bugs.  See if
+  // we know something just from the operands themselves.
+
+  // See if our operands have predicate info, so that we may be able to derive
+  // something from a previous comparison.
+  for (const auto &Op : CI->operands()) {
+    auto *PI = PredInfo->getPredicateInfoFor(Op);
+    if (const auto *PBranch = dyn_cast_or_null<PredicateBranch>(PI)) {
+      if (PI == LastPredInfo)
+        continue;
+      LastPredInfo = PI;
+
+      // TODO: Along the false edge, we may know more things too, like icmp of
+      // same operands is false.
+      // TODO: We only handle actual comparison conditions below, not and/or.
+      auto *BranchCond = dyn_cast<CmpInst>(PBranch->Condition);
+      if (!BranchCond)
+        continue;
+      auto *BranchOp0 = lookupOperandLeader(BranchCond->getOperand(0));
+      auto *BranchOp1 = lookupOperandLeader(BranchCond->getOperand(1));
+      auto BranchPredicate = BranchCond->getPredicate();
+      if (shouldSwapOperands(BranchOp0, BranchOp1)) {
+        std::swap(BranchOp0, BranchOp1);
+        BranchPredicate = BranchCond->getSwappedPredicate();
+      }
+      if (BranchOp0 == Op0 && BranchOp1 == Op1) {
+        if (PBranch->TrueEdge) {
+          // If we know the previous predicate is true and we are on the true
+          // edge, then this comparison may be implied true or false.
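The loop above folds the comparison by reusing a dominating branch on the same operand leaders, via CmpInst::isImpliedTrueByMatchingCmp and isImpliedFalseByMatchingCmp. A tiny standalone model of that kind of implication table, covering only a handful of signed predicates; the enum and helper are hypothetical:

    #include <optional>

    enum class Pred { SLT, SLE, SGT, SGE, EQ, NE };

    // Given that `Known` holds on (a, b), return the truth value it forces for
    // `Ask` on the same (a, b), or nullopt if it decides nothing.
    static std::optional<bool> impliedByTrue(Pred Known, Pred Ask) {
      if (Known == Ask)
        return true;
      if (Known == Pred::SLT) {
        if (Ask == Pred::SLE || Ask == Pred::NE)
          return true;  // a < b  implies  a <= b  and  a != b
        if (Ask == Pred::SGT || Ask == Pred::SGE || Ask == Pred::EQ)
          return false; // a < b  rules out  a > b, a >= b, a == b
      }
      return std::nullopt; // not decided by this (deliberately partial) table
    }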
+          if (CmpInst::isImpliedTrueByMatchingCmp(OurPredicate, +                                                  BranchPredicate)) { +            addPredicateUsers(PI, I); +            return createConstantExpression( +                ConstantInt::getTrue(CI->getType())); +          } + +          if (CmpInst::isImpliedFalseByMatchingCmp(OurPredicate, +                                                   BranchPredicate)) { +            addPredicateUsers(PI, I); +            return createConstantExpression( +                ConstantInt::getFalse(CI->getType())); +          } + +        } else { +          // Just handle the ne and eq cases, where if we have the same +          // operands, we may know something. +          if (BranchPredicate == OurPredicate) { +            addPredicateUsers(PI, I); +            // Same predicate, same ops,we know it was false, so this is false. +            return createConstantExpression( +                ConstantInt::getFalse(CI->getType())); +          } else if (BranchPredicate == +                     CmpInst::getInversePredicate(OurPredicate)) { +            addPredicateUsers(PI, I); +            // Inverse predicate, we know the other was false, so this is true. +            return createConstantExpression( +                ConstantInt::getTrue(CI->getType())); +          } +        } +      } +    } +  } +  // Create expression will take care of simplifyCmpInst +  return createExpression(I);  }  // Substitute and symbolize the value before value numbering. -const Expression *NewGVN::performSymbolicEvaluation(Value *V, -                                                    const BasicBlock *B) { +const Expression *NewGVN::performSymbolicEvaluation(Value *V) {    const Expression *E = nullptr;    if (auto *C = dyn_cast<Constant>(V))      E = createConstantExpression(C); @@ -957,24 +1677,27 @@ const Expression *NewGVN::performSymbolicEvaluation(Value *V,      switch (I->getOpcode()) {      case Instruction::ExtractValue:      case Instruction::InsertValue: -      E = performSymbolicAggrValueEvaluation(I, B); +      E = performSymbolicAggrValueEvaluation(I);        break;      case Instruction::PHI: -      E = performSymbolicPHIEvaluation(I, B); +      E = performSymbolicPHIEvaluation(I);        break;      case Instruction::Call: -      E = performSymbolicCallEvaluation(I, B); +      E = performSymbolicCallEvaluation(I);        break;      case Instruction::Store: -      E = performSymbolicStoreEvaluation(I, B); +      E = performSymbolicStoreEvaluation(I);        break;      case Instruction::Load: -      E = performSymbolicLoadEvaluation(I, B); +      E = performSymbolicLoadEvaluation(I);        break;      case Instruction::BitCast: { -      E = createExpression(I, B); +      E = createExpression(I); +    } break; +    case Instruction::ICmp: +    case Instruction::FCmp: { +      E = performSymbolicCmpEvaluation(I);      } break; -      case Instruction::Add:      case Instruction::FAdd:      case Instruction::Sub: @@ -993,8 +1716,6 @@ const Expression *NewGVN::performSymbolicEvaluation(Value *V,      case Instruction::And:      case Instruction::Or:      case Instruction::Xor: -    case Instruction::ICmp: -    case Instruction::FCmp:      case Instruction::Trunc:      case Instruction::ZExt:      case Instruction::SExt: @@ -1011,7 +1732,7 @@ const Expression *NewGVN::performSymbolicEvaluation(Value *V,      case Instruction::InsertElement:      case Instruction::ShuffleVector:      case Instruction::GetElementPtr: -      E = createExpression(I, 
B);
+      E = createExpression(I);
       break;
     default:
       return nullptr;
@@ -1020,129 +1741,297 @@ const Expression *NewGVN::performSymbolicEvaluation(Value *V,
   return E;
 }
 
-// There is an edge from 'Src' to 'Dst'.  Return true if every path from
-// the entry block to 'Dst' passes via this edge.  In particular 'Dst'
-// must not be reachable via another edge from 'Src'.
-bool NewGVN::isOnlyReachableViaThisEdge(const BasicBlockEdge &E) const {
-
-  // While in theory it is interesting to consider the case in which Dst has
-  // more than one predecessor, because Dst might be part of a loop which is
-  // only reachable from Src, in practice it is pointless since at the time
-  // GVN runs all such loops have preheaders, which means that Dst will have
-  // been changed to have only one predecessor, namely Src.
-  const BasicBlock *Pred = E.getEnd()->getSinglePredecessor();
-  const BasicBlock *Src = E.getStart();
-  assert((!Pred || Pred == Src) && "No edge between these basic blocks!");
-  (void)Src;
-  return Pred != nullptr;
-}
-
 void NewGVN::markUsersTouched(Value *V) {
   // Now mark the users as touched.
   for (auto *User : V->users()) {
     assert(isa<Instruction>(User) && "Use of value not within an instruction?");
-    TouchedInstructions.set(InstrDFS[User]);
+    TouchedInstructions.set(InstrToDFSNum(User));
   }
 }
 
-void NewGVN::markMemoryUsersTouched(MemoryAccess *MA) {
-  for (auto U : MA->users()) {
-    if (auto *MUD = dyn_cast<MemoryUseOrDef>(U))
-      TouchedInstructions.set(InstrDFS[MUD->getMemoryInst()]);
-    else
-      TouchedInstructions.set(InstrDFS[U]);
+void NewGVN::addMemoryUsers(const MemoryAccess *To, MemoryAccess *U) {
+  DEBUG(dbgs() << "Adding memory user " << *U << " to " << *To << "\n");
+  MemoryToUsers[To].insert(U);
+}
+
+void NewGVN::markMemoryDefTouched(const MemoryAccess *MA) {
+  TouchedInstructions.set(MemoryToDFSNum(MA));
+}
+
+void NewGVN::markMemoryUsersTouched(const MemoryAccess *MA) {
+  if (isa<MemoryUse>(MA))
+    return;
+  for (auto U : MA->users())
+    TouchedInstructions.set(MemoryToDFSNum(U));
+  const auto Result = MemoryToUsers.find(MA);
+  if (Result != MemoryToUsers.end()) {
+    for (auto *User : Result->second)
+      TouchedInstructions.set(MemoryToDFSNum(User));
+    MemoryToUsers.erase(Result);
+  }
+}
+
+// Add I to the set of users of a given predicate.
+void NewGVN::addPredicateUsers(const PredicateBase *PB, Instruction *I) {
+  if (auto *PBranch = dyn_cast<PredicateBranch>(PB))
+    PredicateToUsers[PBranch->Condition].insert(I);
+  else if (auto *PAssume = dyn_cast<PredicateAssume>(PB))
+    PredicateToUsers[PAssume->Condition].insert(I);
+}
+
+// Touch all the predicates that depend on this instruction.
+void NewGVN::markPredicateUsersTouched(Instruction *I) {
+  const auto Result = PredicateToUsers.find(I);
+  if (Result != PredicateToUsers.end()) {
+    for (auto *User : Result->second)
+      TouchedInstructions.set(InstrToDFSNum(User));
+    PredicateToUsers.erase(Result);
   }
 }
 
+// Mark users affected by a memory leader change.
+void NewGVN::markMemoryLeaderChangeTouched(CongruenceClass *CC) {
+  for (auto M : CC->memory())
+    markMemoryDefTouched(M);
+}
+
 // Touch the instructions that need to be updated after a congruence class has a
 // leader change, and mark changed values.
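All of these helpers feed the same worklist scheme: whenever something changes, the DFS numbers of its users are set in TouchedInstructions so the main loop re-evaluates them. A toy standalone model of that bookkeeping, with plain integer IDs standing in for DFS numbers:

    #include <set>
    #include <unordered_map>
    #include <vector>

    struct TouchedTracker {
      std::unordered_map<int, std::vector<int>> Users; // value id -> user ids
      std::set<int> Touched; // ids queued for re-evaluation by the main loop

      // When a value changes congruence class, every user must be re-examined.
      void markUsersTouched(int Value) {
        for (int User : Users[Value])
          Touched.insert(User);
      }
    };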
-void NewGVN::markLeaderChangeTouched(CongruenceClass *CC) {
-  for (auto M : CC->Members) {
+void NewGVN::markValueLeaderChangeTouched(CongruenceClass *CC) {
+  for (auto M : *CC) {
     if (auto *I = dyn_cast<Instruction>(M))
-      TouchedInstructions.set(InstrDFS[I]);
+      TouchedInstructions.set(InstrToDFSNum(I));
     LeaderChanges.insert(M);
   }
 }
 
+// Given a range of things that have instruction DFS numbers, this will return
+// the member of the range with the smallest DFS number.
+template <class T, class Range>
+T *NewGVN::getMinDFSOfRange(const Range &R) const {
+  std::pair<T *, unsigned> MinDFS = {nullptr, ~0U};
+  for (const auto X : R) {
+    auto DFSNum = InstrToDFSNum(X);
+    if (DFSNum < MinDFS.second)
+      MinDFS = {X, DFSNum};
+  }
+  return MinDFS.first;
+}
+
+// This function returns the MemoryAccess that should be the next leader of
+// congruence class CC, under the assumption that the current leader is going to
+// disappear.
+const MemoryAccess *NewGVN::getNextMemoryLeader(CongruenceClass *CC) const {
+  // TODO: If this ends up too slow, we can maintain a next memory leader like
+  // we do for regular leaders.
+  // Make sure there will be a leader to find.
+  assert((CC->getStoreCount() > 0 || !CC->memory_empty()) &&
+         "Can't get next leader if there is none");
+  if (CC->getStoreCount() > 0) {
+    if (auto *NL = dyn_cast_or_null<StoreInst>(CC->getNextLeader().first))
+      return MSSA->getMemoryAccess(NL);
+    // Find the store with the minimum DFS number.
+    auto *V = getMinDFSOfRange<Value>(make_filter_range(
+        *CC, [&](const Value *V) { return isa<StoreInst>(V); }));
+    return MSSA->getMemoryAccess(cast<StoreInst>(V));
+  }
+  assert(CC->getStoreCount() == 0);
+
+  // Given our assertion, hitting this part must mean
+  // !OldClass->memory_empty()
+  if (CC->memory_size() == 1)
+    return *CC->memory_begin();
+  return getMinDFSOfRange<const MemoryPhi>(CC->memory());
+}
+
+// This function returns the next value leader of a congruence class, under the
+// assumption that the current leader is going away.  This should end up being
+// the next most dominating member.
+Value *NewGVN::getNextValueLeader(CongruenceClass *CC) const {
+  // We don't need to sort members if there is only 1, and we don't care about
+  // sorting the TOP class because everything either gets out of it or is
+  // unreachable.
+
+  if (CC->size() == 1 || CC == TOPClass) {
+    return *(CC->begin());
+  } else if (CC->getNextLeader().first) {
+    ++NumGVNAvoidedSortedLeaderChanges;
+    return CC->getNextLeader().first;
+  } else {
+    ++NumGVNSortedLeaderChanges;
+    // NOTE: If this ends up too slow, we can maintain a dual structure for
+    // member testing/insertion, or keep things mostly sorted, and sort only
+    // here, or use SparseBitVector or ....
+    return getMinDFSOfRange<Value>(*CC);
+  }
+}
+
+// Move a MemoryAccess, currently in OldClass, to NewClass, including updates to
+// the memory members, etc., for the move.
+//
+// The invariants of this function are:
+//
+// I must be moving to NewClass from OldClass.  The StoreCount of OldClass and
+// NewClass is expected to have been updated for I already if it is a store.
+// The OldClass memory leader has not been updated yet if I was the leader.
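Both next-leader helpers above eventually fall back to the member with the smallest DFS number, i.e. the most dominating remaining candidate in reverse post-order. A one-function standalone sketch of that selection; the Member type is hypothetical:

    #include <limits>
    #include <vector>

    struct Member {
      const void *V;
      unsigned DFSNum;
    };

    // Returns the member with the smallest DFS number, or nullptr if empty.
    static const Member *getMinDFSOfRange(const std::vector<Member> &Range) {
      const Member *Best = nullptr;
      unsigned BestDFS = std::numeric_limits<unsigned>::max();
      for (const Member &M : Range)
        if (M.DFSNum < BestDFS) {
          Best = &M;
          BestDFS = M.DFSNum;
        }
      return Best;
    }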
+void NewGVN::moveMemoryToNewCongruenceClass(Instruction *I, +                                            MemoryAccess *InstMA, +                                            CongruenceClass *OldClass, +                                            CongruenceClass *NewClass) { +  // If the leader is I, and we had a represenative MemoryAccess, it should +  // be the MemoryAccess of OldClass. +  assert((!InstMA || !OldClass->getMemoryLeader() || +          OldClass->getLeader() != I || +          OldClass->getMemoryLeader() == InstMA) && +         "Representative MemoryAccess mismatch"); +  // First, see what happens to the new class +  if (!NewClass->getMemoryLeader()) { +    // Should be a new class, or a store becoming a leader of a new class. +    assert(NewClass->size() == 1 || +           (isa<StoreInst>(I) && NewClass->getStoreCount() == 1)); +    NewClass->setMemoryLeader(InstMA); +    // Mark it touched if we didn't just create a singleton +    DEBUG(dbgs() << "Memory class leader change for class " << NewClass->getID() +                 << " due to new memory instruction becoming leader\n"); +    markMemoryLeaderChangeTouched(NewClass); +  } +  setMemoryClass(InstMA, NewClass); +  // Now, fixup the old class if necessary +  if (OldClass->getMemoryLeader() == InstMA) { +    if (OldClass->getStoreCount() != 0 || !OldClass->memory_empty()) { +      OldClass->setMemoryLeader(getNextMemoryLeader(OldClass)); +      DEBUG(dbgs() << "Memory class leader change for class " +                   << OldClass->getID() << " to " +                   << *OldClass->getMemoryLeader() +                   << " due to removal of old leader " << *InstMA << "\n"); +      markMemoryLeaderChangeTouched(OldClass); +    } else +      OldClass->setMemoryLeader(nullptr); +  } +} +  // Move a value, currently in OldClass, to be part of NewClass -// Update OldClass for the move (including changing leaders, etc) -void NewGVN::moveValueToNewCongruenceClass(Instruction *I, +// Update OldClass and NewClass for the move (including changing leaders, etc). +void NewGVN::moveValueToNewCongruenceClass(Instruction *I, const Expression *E,                                             CongruenceClass *OldClass,                                             CongruenceClass *NewClass) { -  DEBUG(dbgs() << "New congruence class for " << I << " is " << NewClass->ID -               << "\n"); - -  if (I == OldClass->NextLeader.first) -    OldClass->NextLeader = {nullptr, ~0U}; +  if (I == OldClass->getNextLeader().first) +    OldClass->resetNextLeader();    // It's possible, though unlikely, for us to discover equivalences such    // that the current leader does not dominate the old one.    // This statistic tracks how often this happens.    // We assert on phi nodes when this happens, currently, for debugging, because    // we want to make sure we name phi node cycles properly. 
-  if (isa<Instruction>(NewClass->RepLeader) && NewClass->RepLeader &&
-      I != NewClass->RepLeader &&
-      DT->properlyDominates(
-          I->getParent(),
-          cast<Instruction>(NewClass->RepLeader)->getParent())) {
-    ++NumGVNNotMostDominatingLeader;
-    assert(!isa<PHINode>(I) &&
-           "New class for instruction should not be dominated by instruction");
-  }
-
-  if (NewClass->RepLeader != I) {
-    auto DFSNum = InstrDFS.lookup(I);
-    if (DFSNum < NewClass->NextLeader.second)
-      NewClass->NextLeader = {I, DFSNum};
+  if (isa<Instruction>(NewClass->getLeader()) && NewClass->getLeader() &&
+      I != NewClass->getLeader()) {
+    auto *IBB = I->getParent();
+    auto *NCBB = cast<Instruction>(NewClass->getLeader())->getParent();
+    bool Dominated =
+        IBB == NCBB && InstrToDFSNum(I) < InstrToDFSNum(NewClass->getLeader());
+    Dominated = Dominated || DT->properlyDominates(IBB, NCBB);
+    if (Dominated) {
+      ++NumGVNNotMostDominatingLeader;
+      assert(
+          !isa<PHINode>(I) &&
+          "New class for instruction should not be dominated by instruction");
+    }
   }
 
-  OldClass->Members.erase(I);
-  NewClass->Members.insert(I);
-  if (isa<StoreInst>(I)) {
-    --OldClass->StoreCount;
-    assert(OldClass->StoreCount >= 0);
-    ++NewClass->StoreCount;
-    assert(NewClass->StoreCount > 0);
+  if (NewClass->getLeader() != I)
+    NewClass->addPossibleNextLeader({I, InstrToDFSNum(I)});
+
+  OldClass->erase(I);
+  NewClass->insert(I);
+  // Handle our special casing of stores.
+  if (auto *SI = dyn_cast<StoreInst>(I)) {
+    OldClass->decStoreCount();
+    // Okay, so when do we want to make a store a leader of a class?
+    // If we have a store defined by an earlier load, we want the earlier load
+    // to lead the class.
+    // If we have a store defined by something else, we want the store to lead
+    // the class so everything else gets the "something else" as a value.
+    // If we have a store as the single member of the class, we want the store
+    // as the leader.
+    if (NewClass->getStoreCount() == 0 && !NewClass->getStoredValue()) {
+      // If it's a store expression we are using, it means we are not equivalent
+      // to something earlier.
+      if (isa<StoreExpression>(E)) {
+        assert(lookupOperandLeader(SI->getValueOperand()) !=
+               NewClass->getLeader());
+        NewClass->setStoredValue(lookupOperandLeader(SI->getValueOperand()));
+        markValueLeaderChangeTouched(NewClass);
+        // Shift the new class leader to be the store.
+        DEBUG(dbgs() << "Changing leader of congruence class "
+                     << NewClass->getID() << " from " << *NewClass->getLeader()
+                     << " to  " << *SI << " because store joined class\n");
+        // If we changed the leader, we have to mark it changed because we don't
+        // know what it will do to symbolic evaluation.
+        NewClass->setLeader(SI);
+      }
+      // We rely on the code below handling the MemoryAccess change.
+    }
+    NewClass->incStoreCount();
   }
-
+  // True if there are no memory instructions left in a class that had memory
+  // instructions before.
+ +  // If it's not a memory use, set the MemoryAccess equivalence +  auto *InstMA = dyn_cast_or_null<MemoryDef>(MSSA->getMemoryAccess(I)); +  bool InstWasMemoryLeader = InstMA && OldClass->getMemoryLeader() == InstMA; +  if (InstMA) +    moveMemoryToNewCongruenceClass(I, InstMA, OldClass, NewClass);    ValueToClass[I] = NewClass;    // See if we destroyed the class or need to swap leaders. -  if (OldClass->Members.empty() && OldClass != InitialClass) { -    if (OldClass->DefiningExpr) { -      OldClass->Dead = true; -      DEBUG(dbgs() << "Erasing expression " << OldClass->DefiningExpr +  if (OldClass->empty() && OldClass != TOPClass) { +    if (OldClass->getDefiningExpr()) { +      DEBUG(dbgs() << "Erasing expression " << OldClass->getDefiningExpr()                     << " from table\n"); -      ExpressionToClass.erase(OldClass->DefiningExpr); +      ExpressionToClass.erase(OldClass->getDefiningExpr());      } -  } else if (OldClass->RepLeader == I) { +  } else if (OldClass->getLeader() == I) {      // When the leader changes, the value numbering of      // everything may change due to symbolization changes, so we need to      // reprocess. -    DEBUG(dbgs() << "Leader change!\n"); +    DEBUG(dbgs() << "Value class leader change for class " << OldClass->getID() +                 << "\n");      ++NumGVNLeaderChanges; -    // We don't need to sort members if there is only 1, and we don't care about -    // sorting the initial class because everything either gets out of it or is -    // unreachable. -    if (OldClass->Members.size() == 1 || OldClass == InitialClass) { -      OldClass->RepLeader = *(OldClass->Members.begin()); -    } else if (OldClass->NextLeader.first) { -      ++NumGVNAvoidedSortedLeaderChanges; -      OldClass->RepLeader = OldClass->NextLeader.first; -      OldClass->NextLeader = {nullptr, ~0U}; -    } else { -      ++NumGVNSortedLeaderChanges; -      // TODO: If this ends up to slow, we can maintain a dual structure for -      // member testing/insertion, or keep things mostly sorted, and sort only -      // here, or .... -      std::pair<Value *, unsigned> MinDFS = {nullptr, ~0U}; -      for (const auto X : OldClass->Members) { -        auto DFSNum = InstrDFS.lookup(X); -        if (DFSNum < MinDFS.second) -          MinDFS = {X, DFSNum}; -      } -      OldClass->RepLeader = MinDFS.first; +    // Destroy the stored value if there are no more stores to represent it. +    // Note that this is basically clean up for the expression removal that +    // happens below.  If we remove stores from a class, we may leave it as a +    // class of equivalent memory phis. +    if (OldClass->getStoreCount() == 0) { +      if (OldClass->getStoredValue()) +        OldClass->setStoredValue(nullptr); +    } +    // If we destroy the old access leader and it's a store, we have to +    // effectively destroy the congruence class.  When it comes to scalars, +    // anything with the same value is as good as any other.  That means that +    // one leader is as good as another, and as long as you have some leader for +    // the value, you are good.. When it comes to *memory states*, only one +    // particular thing really represents the definition of a given memory +    // state.  Once it goes away, we need to re-evaluate which pieces of memory +    // are really still equivalent. The best way to do this is to re-value +    // number things.  The only way to really make that happen is to destroy the +    // rest of the class.  
In order to effectively destroy the class, we reset +    // ExpressionToClass for each by using the ValueToExpression mapping.  The +    // members later get marked as touched due to the leader change.  We will +    // create new congruence classes, and the pieces that are still equivalent +    // will end back together in a new class.  If this becomes too expensive, it +    // is possible to use a versioning scheme for the congruence classes to +    // avoid the expressions finding this old class.  Note that the situation is +    // different for memory phis, becuase they are evaluated anew each time, and +    // they become equal not by hashing, but by seeing if all operands are the +    // same (or only one is reachable). +    if (OldClass->getStoreCount() > 0 && InstWasMemoryLeader) { +      DEBUG(dbgs() << "Kicking everything out of class " << OldClass->getID() +                   << " because MemoryAccess leader changed"); +      for (auto Member : *OldClass) +        ExpressionToClass.erase(ValueToExpression.lookup(Member));      } -    markLeaderChangeTouched(OldClass); +    OldClass->setLeader(getNextValueLeader(OldClass)); +    OldClass->resetNextLeader(); +    markValueLeaderChangeTouched(OldClass);    }  } @@ -1150,12 +2039,12 @@ void NewGVN::moveValueToNewCongruenceClass(Instruction *I,  void NewGVN::performCongruenceFinding(Instruction *I, const Expression *E) {    ValueToExpression[I] = E;    // This is guaranteed to return something, since it will at least find -  // INITIAL. +  // TOP.    CongruenceClass *IClass = ValueToClass[I];    assert(IClass && "Should have found a IClass");    // Dead classes should have been eliminated from the mapping. -  assert(!IClass->Dead && "Found a dead class"); +  assert(!IClass->isDead() && "Found a dead class");    CongruenceClass *EClass;    if (const auto *VE = dyn_cast<VariableExpression>(E)) { @@ -1171,79 +2060,52 @@ void NewGVN::performCongruenceFinding(Instruction *I, const Expression *E) {        // Constants and variables should always be made the leader.        if (const auto *CE = dyn_cast<ConstantExpression>(E)) { -        NewClass->RepLeader = CE->getConstantValue(); +        NewClass->setLeader(CE->getConstantValue());        } else if (const auto *SE = dyn_cast<StoreExpression>(E)) {          StoreInst *SI = SE->getStoreInst(); -        NewClass->RepLeader = -            lookupOperandLeader(SI->getValueOperand(), SI, SI->getParent()); +        NewClass->setLeader(SI); +        NewClass->setStoredValue(lookupOperandLeader(SI->getValueOperand())); +        // The RepMemoryAccess field will be filled in properly by the +        // moveValueToNewCongruenceClass call.        
} else { -        NewClass->RepLeader = I; +        NewClass->setLeader(I);        }        assert(!isa<VariableExpression>(E) &&               "VariableExpression should have been handled already");        EClass = NewClass;        DEBUG(dbgs() << "Created new congruence class for " << *I -                   << " using expression " << *E << " at " << NewClass->ID -                   << " and leader " << *(NewClass->RepLeader) << "\n"); -      DEBUG(dbgs() << "Hash value was " << E->getHashValue() << "\n"); +                   << " using expression " << *E << " at " << NewClass->getID() +                   << " and leader " << *(NewClass->getLeader())); +      if (NewClass->getStoredValue()) +        DEBUG(dbgs() << " and stored value " << *(NewClass->getStoredValue())); +      DEBUG(dbgs() << "\n");      } else {        EClass = lookupResult.first->second;        if (isa<ConstantExpression>(E)) -        assert(isa<Constant>(EClass->RepLeader) && +        assert((isa<Constant>(EClass->getLeader()) || +                (EClass->getStoredValue() && +                 isa<Constant>(EClass->getStoredValue()))) &&                 "Any class with a constant expression should have a "                 "constant leader");        assert(EClass && "Somehow don't have an eclass"); -      assert(!EClass->Dead && "We accidentally looked up a dead class"); +      assert(!EClass->isDead() && "We accidentally looked up a dead class");      }    }    bool ClassChanged = IClass != EClass;    bool LeaderChanged = LeaderChanges.erase(I);    if (ClassChanged || LeaderChanged) { -    DEBUG(dbgs() << "Found class " << EClass->ID << " for expression " << E +    DEBUG(dbgs() << "New class " << EClass->getID() << " for expression " << *E                   << "\n"); -      if (ClassChanged) -      moveValueToNewCongruenceClass(I, IClass, EClass); +      moveValueToNewCongruenceClass(I, E, IClass, EClass);      markUsersTouched(I); -    if (MemoryAccess *MA = MSSA->getMemoryAccess(I)) { -      // If this is a MemoryDef, we need to update the equivalence table. If -      // we determined the expression is congruent to a different memory -      // state, use that different memory state.  If we determined it didn't, -      // we update that as well.  Right now, we only support store -      // expressions. -      if (!isa<MemoryUse>(MA) && isa<StoreExpression>(E) && -          EClass->Members.size() != 1) { -        auto *DefAccess = cast<StoreExpression>(E)->getDefiningAccess(); -        setMemoryAccessEquivTo(MA, DefAccess != MA ? DefAccess : nullptr); -      } else { -        setMemoryAccessEquivTo(MA, nullptr); -      } +    if (MemoryAccess *MA = MSSA->getMemoryAccess(I))        markMemoryUsersTouched(MA); -    } -  } else if (auto *SI = dyn_cast<StoreInst>(I)) { -    // There is, sadly, one complicating thing for stores.  Stores do not -    // produce values, only consume them.  However, in order to make loads and -    // stores value number the same, we ignore the value operand of the store. -    // But the value operand will still be the leader of our class, and thus, it -    // may change.  Because the store is a use, the store will get reprocessed, -    // but nothing will change about it, and so nothing above will catch it -    // (since the class will not change).  In order to make sure everything ends -    // up okay, we need to recheck the leader of the class.  
Since stores of -    // different values value number differently due to different memorydefs, we -    // are guaranteed the leader is always the same between stores in the same -    // class. -    DEBUG(dbgs() << "Checking store leader\n"); -    auto ProperLeader = -        lookupOperandLeader(SI->getValueOperand(), SI, SI->getParent()); -    if (EClass->RepLeader != ProperLeader) { -      DEBUG(dbgs() << "Store leader changed, fixing\n"); -      EClass->RepLeader = ProperLeader; -      markLeaderChangeTouched(EClass); -      markMemoryUsersTouched(MSSA->getMemoryAccess(SI)); -    } +    if (auto *CI = dyn_cast<CmpInst>(I)) +      markPredicateUsersTouched(CI);    }  } @@ -1267,11 +2129,11 @@ void NewGVN::updateReachableEdge(BasicBlock *From, BasicBlock *To) {        // they are the only thing that depend on new edges. Anything using their        // values will get propagated to if necessary.        if (MemoryAccess *MemPhi = MSSA->getMemoryAccess(To)) -        TouchedInstructions.set(InstrDFS[MemPhi]); +        TouchedInstructions.set(InstrToDFSNum(MemPhi));        auto BI = To->begin();        while (isa<PHINode>(BI)) { -        TouchedInstructions.set(InstrDFS[&*BI]); +        TouchedInstructions.set(InstrToDFSNum(&*BI));          ++BI;        }      } @@ -1280,8 +2142,8 @@ void NewGVN::updateReachableEdge(BasicBlock *From, BasicBlock *To) {  // Given a predicate condition (from a switch, cmp, or whatever) and a block,  // see if we know some constant value for it already. -Value *NewGVN::findConditionEquivalence(Value *Cond, BasicBlock *B) const { -  auto Result = lookupOperandLeader(Cond, nullptr, B); +Value *NewGVN::findConditionEquivalence(Value *Cond) const { +  auto Result = lookupOperandLeader(Cond);    if (isa<Constant>(Result))      return Result;    return nullptr; @@ -1293,10 +2155,10 @@ void NewGVN::processOutgoingEdges(TerminatorInst *TI, BasicBlock *B) {    BranchInst *BR;    if ((BR = dyn_cast<BranchInst>(TI)) && BR->isConditional()) {      Value *Cond = BR->getCondition(); -    Value *CondEvaluated = findConditionEquivalence(Cond, B); +    Value *CondEvaluated = findConditionEquivalence(Cond);      if (!CondEvaluated) {        if (auto *I = dyn_cast<Instruction>(Cond)) { -        const Expression *E = createExpression(I, B); +        const Expression *E = createExpression(I);          if (const auto *CE = dyn_cast<ConstantExpression>(E)) {            CondEvaluated = CE->getConstantValue();          } @@ -1329,13 +2191,13 @@ void NewGVN::processOutgoingEdges(TerminatorInst *TI, BasicBlock *B) {      SmallDenseMap<BasicBlock *, unsigned, 16> SwitchEdges;      Value *SwitchCond = SI->getCondition(); -    Value *CondEvaluated = findConditionEquivalence(SwitchCond, B); +    Value *CondEvaluated = findConditionEquivalence(SwitchCond);      // See if we were able to turn this switch statement into a constant.      if (CondEvaluated && isa<ConstantInt>(CondEvaluated)) {        auto *CondVal = cast<ConstantInt>(CondEvaluated);        // We should be able to get case value for this. -      auto CaseVal = SI->findCaseValue(CondVal); -      if (CaseVal.getCaseSuccessor() == SI->getDefaultDest()) { +      auto Case = *SI->findCaseValue(CondVal); +      if (Case.getCaseSuccessor() == SI->getDefaultDest()) {          // We proved the value is outside of the range of the case.          // We can't do anything other than mark the default dest as reachable,          // and go home. 
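The findConditionEquivalence / processOutgoingEdges changes above evaluate a branch or switch condition through the leader table and, when the leader is a constant, mark only the taken edge reachable. A stripped-down, self-contained illustration of that control-flow pruning (the string-based "IR" is purely illustrative, not the pass's real types):

#include <map>
#include <string>
#include <vector>

using LeaderTable = std::map<std::string, std::string>;

std::vector<std::string> reachableSuccessors(const LeaderTable &Leaders,
                                             const std::string &Cond,
                                             const std::string &TrueSucc,
                                             const std::string &FalseSucc) {
  auto It = Leaders.find(Cond);
  // Stand-in for findConditionEquivalence returning a constant leader.
  if (It != Leaders.end() && (It->second == "true" || It->second == "false"))
    return {It->second == "true" ? TrueSucc : FalseSucc};
  // Condition not known constant: both outgoing edges stay reachable.
  return {TrueSucc, FalseSucc};
}

// For example, if %cond has the constant leader "true", only the then-block is
// marked reachable, so instructions in the else-block are never value numbered.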
@@ -1343,7 +2205,7 @@ void NewGVN::processOutgoingEdges(TerminatorInst *TI, BasicBlock *B) {          return;        }        // Now get where it goes and mark it reachable. -      BasicBlock *TargetBlock = CaseVal.getCaseSuccessor(); +      BasicBlock *TargetBlock = Case.getCaseSuccessor();        updateReachableEdge(B, TargetBlock);      } else {        for (unsigned i = 0, e = SI->getNumSuccessors(); i != e; ++i) { @@ -1361,45 +2223,66 @@ void NewGVN::processOutgoingEdges(TerminatorInst *TI, BasicBlock *B) {      }      // This also may be a memory defining terminator, in which case, set it -    // equivalent to nothing. -    if (MemoryAccess *MA = MSSA->getMemoryAccess(TI)) -      setMemoryAccessEquivTo(MA, nullptr); +    // equivalent only to itself. +    // +    auto *MA = MSSA->getMemoryAccess(TI); +    if (MA && !isa<MemoryUse>(MA)) { +      auto *CC = ensureLeaderOfMemoryClass(MA); +      if (setMemoryClass(MA, CC)) +        markMemoryUsersTouched(MA); +    }    }  } -// The algorithm initially places the values of the routine in the INITIAL -// congruence -// class. The leader of INITIAL is the undetermined value `TOP`. -// When the algorithm has finished, values still in INITIAL are unreachable. +// The algorithm initially places the values of the routine in the TOP +// congruence class. The leader of TOP is the undetermined value `undef`. +// When the algorithm has finished, values still in TOP are unreachable.  void NewGVN::initializeCongruenceClasses(Function &F) { -  // FIXME now i can't remember why this is 2 -  NextCongruenceNum = 2; -  // Initialize all other instructions to be in INITIAL class. -  CongruenceClass::MemberSet InitialValues; -  InitialClass = createCongruenceClass(nullptr, nullptr); +  NextCongruenceNum = 0; + +  // Note that even though we use the live on entry def as a representative +  // MemoryAccess, it is *not* the same as the actual live on entry def. We +  // have no real equivalent to undef for MemoryAccesses, and so we really +  // should be checking whether the MemoryAccess is top if we want to know if it +  // is equivalent to everything.  Otherwise, what this really signifies is that +  // the access reaches all the way back to the beginning of the function. + +  // Initialize all other instructions to be in TOP class. +  TOPClass = createCongruenceClass(nullptr, nullptr); +  TOPClass->setMemoryLeader(MSSA->getLiveOnEntryDef()); +  //  The live on entry def gets put into its own class +  MemoryAccessToClass[MSSA->getLiveOnEntryDef()] = +      createMemoryClass(MSSA->getLiveOnEntryDef()); +    for (auto &B : F) { -    if (auto *MP = MSSA->getMemoryAccess(&B)) -      MemoryAccessEquiv.insert({MP, MSSA->getLiveOnEntryDef()}); +    // All MemoryAccesses are equivalent to live on entry to start. They must +    // be initialized to something so that initial changes are noticed. For +    // the maximal answer, we initialize them all to be the same as +    // liveOnEntry. +    auto *MemoryBlockDefs = MSSA->getBlockDefs(&B); +    if (MemoryBlockDefs) +      for (const auto &Def : *MemoryBlockDefs) { +        MemoryAccessToClass[&Def] = TOPClass; +        auto *MD = dyn_cast<MemoryDef>(&Def); +        // Insert the memory phis into the member list.
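The initialization above starts every memory def and phi in TOP, the optimistic maximal answer, so that the first real evaluation registers as a change and re-touches users. A self-contained sketch of that bookkeeping (MemoryClassTable, ClassStub and AccessId are illustrative names, not the pass's types):

#include <unordered_map>

struct ClassStub { unsigned ID; };
using AccessId = unsigned;

class MemoryClassTable {
  ClassStub Top{0};
  std::unordered_map<AccessId, ClassStub *> AccessToClass;

public:
  void initialize(AccessId MA) { AccessToClass[MA] = &Top; }
  // Mirrors the role of setMemoryClass in the patch: returns true only if the
  // class actually changed, which is what decides whether memory users get
  // re-touched.
  bool setClass(AccessId MA, ClassStub *CC) {
    auto &Slot = AccessToClass[MA];
    if (Slot == CC)
      return false;
    Slot = CC;
    return true;
  }
  bool isTop(AccessId MA) const { return AccessToClass.at(MA) == &Top; }
};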
+        if (!MD) { +          const MemoryPhi *MP = cast<MemoryPhi>(&Def); +          TOPClass->memory_insert(MP); +          MemoryPhiState.insert({MP, MPS_TOP}); +        } -    for (auto &I : B) { -      InitialValues.insert(&I); -      ValueToClass[&I] = InitialClass; -      // All memory accesses are equivalent to live on entry to start. They must -      // be initialized to something so that initial changes are noticed. For -      // the maximal answer, we initialize them all to be the same as -      // liveOnEntry.  Note that to save time, we only initialize the -      // MemoryDef's for stores and all MemoryPhis to be equal.  Right now, no -      // other expression can generate a memory equivalence.  If we start -      // handling memcpy/etc, we can expand this. -      if (isa<StoreInst>(&I)) { -        MemoryAccessEquiv.insert( -            {MSSA->getMemoryAccess(&I), MSSA->getLiveOnEntryDef()}); -        ++InitialClass->StoreCount; -        assert(InitialClass->StoreCount > 0); +        if (MD && isa<StoreInst>(MD->getMemoryInst())) +          TOPClass->incStoreCount();        } +    for (auto &I : B) { +      // Don't insert void terminators into the class. We don't value number +      // them, and they just end up sitting in TOP. +      if (isa<TerminatorInst>(I) && I.getType()->isVoidTy()) +        continue; +      TOPClass->insert(&I); +      ValueToClass[&I] = TOPClass;      }    } -  InitialClass->Members.swap(InitialValues);    // Initialize arguments to be in their own unique congruence classes    for (auto &FA : F.args()) @@ -1408,8 +2291,8 @@ void NewGVN::initializeCongruenceClasses(Function &F) {  void NewGVN::cleanupTables() {    for (unsigned i = 0, e = CongruenceClasses.size(); i != e; ++i) { -    DEBUG(dbgs() << "Congruence class " << CongruenceClasses[i]->ID << " has " -                 << CongruenceClasses[i]->Members.size() << " members\n"); +    DEBUG(dbgs() << "Congruence class " << CongruenceClasses[i]->getID() +                 << " has " << CongruenceClasses[i]->size() << " members\n");      // Make sure we delete the congruence class (probably worth switching to      // a unique_ptr at some point.      delete CongruenceClasses[i]; @@ -1427,15 +2310,14 @@ void NewGVN::cleanupTables() {  #ifndef NDEBUG    ProcessedCount.clear();  #endif -  DFSDomMap.clear();    InstrDFS.clear();    InstructionsToErase.clear(); -    DFSToInstr.clear();    BlockInstRange.clear();    TouchedInstructions.clear(); -  DominatedInstRange.clear(); -  MemoryAccessEquiv.clear(); +  MemoryAccessToClass.clear(); +  PredicateToUsers.clear(); +  MemoryToUsers.clear();  }  std::pair<unsigned, unsigned> NewGVN::assignDFSNumbers(BasicBlock *B, @@ -1447,6 +2329,16 @@ std::pair<unsigned, unsigned> NewGVN::assignDFSNumbers(BasicBlock *B,    }    for (auto &I : *B) { +    // There's no need to call isInstructionTriviallyDead more than once on +    // an instruction. Therefore, once we know that an instruction is dead +    // we change its DFS number so that it doesn't get value numbered. 
+    if (isInstructionTriviallyDead(&I, TLI)) { +      InstrDFS[&I] = 0; +      DEBUG(dbgs() << "Skipping trivially dead instruction " << I << "\n"); +      markInstructionForDeletion(&I); +      continue; +    } +      InstrDFS[&I] = End++;      DFSToInstr.emplace_back(&I);    } @@ -1462,7 +2354,7 @@ void NewGVN::updateProcessedCount(Value *V) {    if (ProcessedCount.count(V) == 0) {      ProcessedCount.insert({V, 1});    } else { -    ProcessedCount[V] += 1; +    ++ProcessedCount[V];      assert(ProcessedCount[V] < 100 &&             "Seem to have processed the same Value a lot");    } @@ -1472,26 +2364,33 @@ void NewGVN::updateProcessedCount(Value *V) {  void NewGVN::valueNumberMemoryPhi(MemoryPhi *MP) {    // If all the arguments are the same, the MemoryPhi has the same value as the    // argument. -  // Filter out unreachable blocks from our operands. +  // Filter out unreachable blocks and self phis from our operands. +  const BasicBlock *PHIBlock = MP->getBlock();    auto Filtered = make_filter_range(MP->operands(), [&](const Use &U) { -    return ReachableBlocks.count(MP->getIncomingBlock(U)); +    return lookupMemoryLeader(cast<MemoryAccess>(U)) != MP && +           !isMemoryAccessTop(cast<MemoryAccess>(U)) && +           ReachableEdges.count({MP->getIncomingBlock(U), PHIBlock});    }); - -  assert(Filtered.begin() != Filtered.end() && -         "We should not be processing a MemoryPhi in a completely " -         "unreachable block"); +  // If all that is left is nothing, our memoryphi is undef. We keep it as +  // InitialClass.  Note: The only case this should happen is if we have at +  // least one self-argument. +  if (Filtered.begin() == Filtered.end()) { +    if (setMemoryClass(MP, TOPClass)) +      markMemoryUsersTouched(MP); +    return; +  }    // Transform the remaining operands into operand leaders.    // FIXME: mapped_iterator should have a range version.    auto LookupFunc = [&](const Use &U) { -    return lookupMemoryAccessEquiv(cast<MemoryAccess>(U)); +    return lookupMemoryLeader(cast<MemoryAccess>(U));    };    auto MappedBegin = map_iterator(Filtered.begin(), LookupFunc);    auto MappedEnd = map_iterator(Filtered.end(), LookupFunc);    // and now check if all the elements are equal.    // Sadly, we can't use std::equals since these are random access iterators. -  MemoryAccess *AllSameValue = *MappedBegin; +  const auto *AllSameValue = *MappedBegin;    ++MappedBegin;    bool AllEqual = std::all_of(        MappedBegin, MappedEnd, @@ -1501,8 +2400,18 @@ void NewGVN::valueNumberMemoryPhi(MemoryPhi *MP) {      DEBUG(dbgs() << "Memory Phi value numbered to " << *AllSameValue << "\n");    else      DEBUG(dbgs() << "Memory Phi value numbered to itself\n"); - -  if (setMemoryAccessEquivTo(MP, AllEqual ? AllSameValue : nullptr)) +  // If it's equal to something, it's in that class. Otherwise, it has to be in +  // a class where it is the leader (other things may be equivalent to it, but +  // it needs to start off in its own class, which means it must have been the +  // leader, and it can't have stopped being the leader because it was never +  // removed). +  CongruenceClass *CC = +      AllEqual ? getMemoryClass(AllSameValue) : ensureLeaderOfMemoryClass(MP); +  auto OldState = MemoryPhiState.lookup(MP); +  assert(OldState != MPS_Invalid && "Invalid memory phi state"); +  auto NewState = AllEqual ? 
MPS_Equivalent : MPS_Unique; +  MemoryPhiState[MP] = NewState; +  if (setMemoryClass(MP, CC) || OldState != NewState)      markMemoryUsersTouched(MP);  } @@ -1510,21 +2419,25 @@ void NewGVN::valueNumberMemoryPhi(MemoryPhi *MP) {  // congruence finding, and updating mappings.  void NewGVN::valueNumberInstruction(Instruction *I) {    DEBUG(dbgs() << "Processing instruction " << *I << "\n"); -  if (isInstructionTriviallyDead(I, TLI)) { -    DEBUG(dbgs() << "Skipping unused instruction\n"); -    markInstructionForDeletion(I); -    return; -  }    if (!I->isTerminator()) { -    const auto *Symbolized = performSymbolicEvaluation(I, I->getParent()); +    const Expression *Symbolized = nullptr; +    if (DebugCounter::shouldExecute(VNCounter)) { +      Symbolized = performSymbolicEvaluation(I); +    } else { +      // Mark the instruction as unused so we don't value number it again. +      InstrDFS[I] = 0; +    }      // If we couldn't come up with a symbolic expression, use the unknown      // expression -    if (Symbolized == nullptr) +    if (Symbolized == nullptr) {        Symbolized = createUnknownExpression(I); +    } +      performCongruenceFinding(I, Symbolized);    } else {      // Handle terminators that return values. All of them produce values we -    // don't currently understand. +    // don't currently understand.  We don't place non-value producing +    // terminators in a class.      if (!I->getType()->isVoidTy()) {        auto *Symbolized = createUnknownExpression(I);        performCongruenceFinding(I, Symbolized); @@ -1539,72 +2452,102 @@ bool NewGVN::singleReachablePHIPath(const MemoryAccess *First,                                      const MemoryAccess *Second) const {    if (First == Second)      return true; - -  if (auto *FirstDef = dyn_cast<MemoryUseOrDef>(First)) { -    auto *DefAccess = FirstDef->getDefiningAccess(); -    return singleReachablePHIPath(DefAccess, Second); -  } else { -    auto *MP = cast<MemoryPhi>(First); -    auto ReachableOperandPred = [&](const Use &U) { -      return ReachableBlocks.count(MP->getIncomingBlock(U)); -    }; -    auto FilteredPhiArgs = -        make_filter_range(MP->operands(), ReachableOperandPred); -    SmallVector<const Value *, 32> OperandList; -    std::copy(FilteredPhiArgs.begin(), FilteredPhiArgs.end(), -              std::back_inserter(OperandList)); -    bool Okay = OperandList.size() == 1; -    if (!Okay) -      Okay = std::equal(OperandList.begin(), OperandList.end(), -                        OperandList.begin()); -    if (Okay) -      return singleReachablePHIPath(cast<MemoryAccess>(OperandList[0]), Second); +  if (MSSA->isLiveOnEntryDef(First))      return false; + +  const auto *EndDef = First; +  for (auto *ChainDef : optimized_def_chain(First)) { +    if (ChainDef == Second) +      return true; +    if (MSSA->isLiveOnEntryDef(ChainDef)) +      return false; +    EndDef = ChainDef;    } +  auto *MP = cast<MemoryPhi>(EndDef); +  auto ReachableOperandPred = [&](const Use &U) { +    return ReachableEdges.count({MP->getIncomingBlock(U), MP->getBlock()}); +  }; +  auto FilteredPhiArgs = +      make_filter_range(MP->operands(), ReachableOperandPred); +  SmallVector<const Value *, 32> OperandList; +  std::copy(FilteredPhiArgs.begin(), FilteredPhiArgs.end(), +            std::back_inserter(OperandList)); +  bool Okay = OperandList.size() == 1; +  if (!Okay) +    Okay = +        std::equal(OperandList.begin(), OperandList.end(), OperandList.begin()); +  if (Okay) +    return 
singleReachablePHIPath(cast<MemoryAccess>(OperandList[0]), Second); +  return false;  }  // Verify the that the memory equivalence table makes sense relative to the  // congruence classes.  Note that this checking is not perfect, and is currently -// subject to very rare false negatives. It is only useful for testing/debugging. +// subject to very rare false negatives. It is only useful for +// testing/debugging.  void NewGVN::verifyMemoryCongruency() const { -  // Anything equivalent in the memory access table should be in the same +#ifndef NDEBUG +  // Verify that the memory table equivalence and memory member set match +  for (const auto *CC : CongruenceClasses) { +    if (CC == TOPClass || CC->isDead()) +      continue; +    if (CC->getStoreCount() != 0) { +      assert((CC->getStoredValue() || !isa<StoreInst>(CC->getLeader())) && +             "Any class with a store as a " +             "leader should have a " +             "representative stored value\n"); +      assert(CC->getMemoryLeader() && +             "Any congruence class with a store should " +             "have a representative access\n"); +    } + +    if (CC->getMemoryLeader()) +      assert(MemoryAccessToClass.lookup(CC->getMemoryLeader()) == CC && +             "Representative MemoryAccess does not appear to be reverse " +             "mapped properly"); +    for (auto M : CC->memory()) +      assert(MemoryAccessToClass.lookup(M) == CC && +             "Memory member does not appear to be reverse mapped properly"); +  } + +  // Anything equivalent in the MemoryAccess table should be in the same    // congruence class.    // Filter out the unreachable and trivially dead entries, because they may    // never have been updated if the instructions were not processed.    auto ReachableAccessPred = -      [&](const std::pair<const MemoryAccess *, MemoryAccess *> Pair) { +      [&](const std::pair<const MemoryAccess *, CongruenceClass *> Pair) {          bool Result = ReachableBlocks.count(Pair.first->getBlock());          if (!Result)            return false; +        if (MSSA->isLiveOnEntryDef(Pair.first)) +          return true;          if (auto *MemDef = dyn_cast<MemoryDef>(Pair.first))            return !isInstructionTriviallyDead(MemDef->getMemoryInst()); +        if (MemoryToDFSNum(Pair.first) == 0) +          return false;          return true;        }; -  auto Filtered = make_filter_range(MemoryAccessEquiv, ReachableAccessPred); +  auto Filtered = make_filter_range(MemoryAccessToClass, ReachableAccessPred);    for (auto KV : Filtered) { -    assert(KV.first != KV.second && -           "We added a useless equivalence to the memory equivalence table"); -    // Unreachable instructions may not have changed because we never process -    // them. 
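The reverse-mapping checks just above enforce that a class's memory leader and every memory member map back to that same class in the access-to-class table. A minimal standalone sketch of that invariant check, using illustrative types rather than the real MemorySSA ones:

#include <cassert>
#include <unordered_map>
#include <vector>

struct MemClass {
  unsigned ID;
  unsigned MemoryLeader;           // Representative access, 0 if none.
  std::vector<unsigned> MemoryMembers;
};

void verifyReverseMapping(
    const std::vector<MemClass> &Classes,
    const std::unordered_map<unsigned, const MemClass *> &AccessToClass) {
  for (const MemClass &CC : Classes) {
    if (CC.MemoryLeader)
      assert(AccessToClass.at(CC.MemoryLeader) == &CC &&
             "Representative access not reverse mapped to its class");
    for (unsigned M : CC.MemoryMembers)
      assert(AccessToClass.at(M) == &CC &&
             "Memory member not reverse mapped to its class");
  }
}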
-    if (!ReachableBlocks.count(KV.first->getBlock())) -      continue; +    assert(KV.second != TOPClass && +           "Memory not unreachable but ended up in TOP");      if (auto *FirstMUD = dyn_cast<MemoryUseOrDef>(KV.first)) { -      auto *SecondMUD = dyn_cast<MemoryUseOrDef>(KV.second); +      auto *SecondMUD = dyn_cast<MemoryUseOrDef>(KV.second->getMemoryLeader());        if (FirstMUD && SecondMUD)          assert((singleReachablePHIPath(FirstMUD, SecondMUD) || -               ValueToClass.lookup(FirstMUD->getMemoryInst()) == -                       ValueToClass.lookup(SecondMUD->getMemoryInst())) && -                   "The instructions for these memory operations should have " -                   "been in the same congruence class or reachable through" -                   "a single argument phi"); +                ValueToClass.lookup(FirstMUD->getMemoryInst()) == +                    ValueToClass.lookup(SecondMUD->getMemoryInst())) && +               "The instructions for these memory operations should have " +               "been in the same congruence class or reachable through" +               "a single argument phi");      } else if (auto *FirstMP = dyn_cast<MemoryPhi>(KV.first)) { -        // We can only sanely verify that MemoryDefs in the operand list all have        // the same class.        auto ReachableOperandPred = [&](const Use &U) { -        return ReachableBlocks.count(FirstMP->getIncomingBlock(U)) && +        return ReachableEdges.count( +                   {FirstMP->getIncomingBlock(U), FirstMP->getBlock()}) &&                 isa<MemoryDef>(U);        }; @@ -1622,19 +2565,127 @@ void NewGVN::verifyMemoryCongruency() const {               "All MemoryPhi arguments should be in the same class");      }    } +#endif +} + +// Verify that the sparse propagation we did actually found the maximal fixpoint +// We do this by storing the value to class mapping, touching all instructions, +// and redoing the iteration to see if anything changed. +void NewGVN::verifyIterationSettled(Function &F) { +#ifndef NDEBUG +  DEBUG(dbgs() << "Beginning iteration verification\n"); +  if (DebugCounter::isCounterSet(VNCounter)) +    DebugCounter::setCounterValue(VNCounter, StartingVNCounter); + +  // Note that we have to store the actual classes, as we may change existing +  // classes during iteration.  This is because our memory iteration propagation +  // is not perfect, and so may waste a little work.  But it should generate +  // exactly the same congruence classes we have now, with different IDs. +  std::map<const Value *, CongruenceClass> BeforeIteration; + +  for (auto &KV : ValueToClass) { +    if (auto *I = dyn_cast<Instruction>(KV.first)) +      // Skip unused/dead instructions. +      if (InstrToDFSNum(I) == 0) +        continue; +    BeforeIteration.insert({KV.first, *KV.second}); +  } + +  TouchedInstructions.set(); +  TouchedInstructions.reset(0); +  iterateTouchedInstructions(); +  DenseSet<std::pair<const CongruenceClass *, const CongruenceClass *>> +      EqualClasses; +  for (const auto &KV : ValueToClass) { +    if (auto *I = dyn_cast<Instruction>(KV.first)) +      // Skip unused/dead instructions. +      if (InstrToDFSNum(I) == 0) +        continue; +    // We could sink these uses, but i think this adds a bit of clarity here as +    // to what we are comparing. +    auto *BeforeCC = &BeforeIteration.find(KV.first)->second; +    auto *AfterCC = KV.second; +    // Note that the classes can't change at this point, so we memoize the set +    // that are equal. 
+    if (!EqualClasses.count({BeforeCC, AfterCC})) { +      assert(BeforeCC->isEquivalentTo(AfterCC) && +             "Value number changed after main loop completed!"); +      EqualClasses.insert({BeforeCC, AfterCC}); +    } +  } +#endif +} + +// This is the main value numbering loop, it iterates over the initial touched +// instruction set, propagating value numbers, marking things touched, etc, +// until the set of touched instructions is completely empty. +void NewGVN::iterateTouchedInstructions() { +  unsigned int Iterations = 0; +  // Figure out where touchedinstructions starts +  int FirstInstr = TouchedInstructions.find_first(); +  // Nothing set, nothing to iterate, just return. +  if (FirstInstr == -1) +    return; +  BasicBlock *LastBlock = getBlockForValue(InstrFromDFSNum(FirstInstr)); +  while (TouchedInstructions.any()) { +    ++Iterations; +    // Walk through all the instructions in all the blocks in RPO. +    // TODO: As we hit a new block, we should push and pop equalities into a +    // table lookupOperandLeader can use, to catch things PredicateInfo +    // might miss, like edge-only equivalences. +    for (int InstrNum = TouchedInstructions.find_first(); InstrNum != -1; +         InstrNum = TouchedInstructions.find_next(InstrNum)) { + +      // This instruction was found to be dead. We don't bother looking +      // at it again. +      if (InstrNum == 0) { +        TouchedInstructions.reset(InstrNum); +        continue; +      } + +      Value *V = InstrFromDFSNum(InstrNum); +      BasicBlock *CurrBlock = getBlockForValue(V); + +      // If we hit a new block, do reachability processing. +      if (CurrBlock != LastBlock) { +        LastBlock = CurrBlock; +        bool BlockReachable = ReachableBlocks.count(CurrBlock); +        const auto &CurrInstRange = BlockInstRange.lookup(CurrBlock); + +        // If it's not reachable, erase any touched instructions and move on. +        if (!BlockReachable) { +          TouchedInstructions.reset(CurrInstRange.first, CurrInstRange.second); +          DEBUG(dbgs() << "Skipping instructions in block " +                       << getBlockName(CurrBlock) +                       << " because it is unreachable\n"); +          continue; +        } +        updateProcessedCount(CurrBlock); +      } + +      if (auto *MP = dyn_cast<MemoryPhi>(V)) { +        DEBUG(dbgs() << "Processing MemoryPhi " << *MP << "\n"); +        valueNumberMemoryPhi(MP); +      } else if (auto *I = dyn_cast<Instruction>(V)) { +        valueNumberInstruction(I); +      } else { +        llvm_unreachable("Should have been a MemoryPhi or Instruction"); +      } +      updateProcessedCount(V); +      // Reset after processing (because we may mark ourselves as touched when +      // we propagate equalities). +      TouchedInstructions.reset(InstrNum); +    } +  } +  NumGVNMaxIterations = std::max(NumGVNMaxIterations.getValue(), Iterations);  }  // This is the main transformation entry point. 
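Before the entry point below, here is a stripped-down, standalone version of the worklist pattern that the iterateTouchedInstructions loop above implements: instructions are identified by dense DFS numbers, "touched" is a bitvector over those numbers, and processing an entry may touch others, so the sweep repeats until the set drains. The Process callback is a stand-in, not the pass's real per-instruction logic.

#include <cstddef>
#include <functional>
#include <vector>

void iterateUntilSettled(
    std::vector<bool> &Touched,
    const std::function<void(std::size_t, std::vector<bool> &)> &Process) {
  bool Any = true;
  while (Any) {
    Any = false;
    for (std::size_t Num = 0, E = Touched.size(); Num != E; ++Num) {
      if (!Touched[Num])
        continue;
      Any = true;
      Process(Num, Touched);
      // Reset after processing: if Process re-touched this very entry while
      // propagating, that self-touch is dropped, mirroring the comment in the
      // loop above.
      Touched[Num] = false;
    }
  }
}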
-bool NewGVN::runGVN(Function &F, DominatorTree *_DT, AssumptionCache *_AC, -                    TargetLibraryInfo *_TLI, AliasAnalysis *_AA, -                    MemorySSA *_MSSA) { +bool NewGVN::runGVN() { +  if (DebugCounter::isCounterSet(VNCounter)) +    StartingVNCounter = DebugCounter::getCounterValue(VNCounter);    bool Changed = false; -  DT = _DT; -  AC = _AC; -  TLI = _TLI; -  AA = _AA; -  MSSA = _MSSA; -  DL = &F.getParent()->getDataLayout(); +  NumFuncArgs = F.arg_size();    MSSAWalker = MSSA->getWalker();    // Count number of instructions for sizing of hash tables, and come @@ -1642,15 +2693,14 @@ bool NewGVN::runGVN(Function &F, DominatorTree *_DT, AssumptionCache *_AC,    unsigned ICount = 1;    // Add an empty instruction to account for the fact that we start at 1    DFSToInstr.emplace_back(nullptr); -  // Note: We want RPO traversal of the blocks, which is not quite the same as -  // dominator tree order, particularly with regard whether backedges get -  // visited first or second, given a block with multiple successors. +  // Note: We want ideal RPO traversal of the blocks, which is not quite the +  // same as dominator tree order, particularly with regard whether backedges +  // get visited first or second, given a block with multiple successors.    // If we visit in the wrong order, we will end up performing N times as many    // iterations.    // The dominator tree does guarantee that, for a given dom tree node, it's    // parent must occur before it in the RPO ordering. Thus, we only need to sort    // the siblings. -  DenseMap<const DomTreeNode *, unsigned> RPOOrdering;    ReversePostOrderTraversal<Function *> RPOT(&F);    unsigned Counter = 0;    for (auto &B : RPOT) { @@ -1663,7 +2713,7 @@ bool NewGVN::runGVN(Function &F, DominatorTree *_DT, AssumptionCache *_AC,      auto *Node = DT->getNode(B);      if (Node->getChildren().size() > 1)        std::sort(Node->begin(), Node->end(), -                [&RPOOrdering](const DomTreeNode *A, const DomTreeNode *B) { +                [&](const DomTreeNode *A, const DomTreeNode *B) {                    return RPOOrdering[A] < RPOOrdering[B];                  });    } @@ -1689,7 +2739,6 @@ bool NewGVN::runGVN(Function &F, DominatorTree *_DT, AssumptionCache *_AC,    }    TouchedInstructions.resize(ICount); -  DominatedInstRange.reserve(F.size());    // Ensure we don't end up resizing the expressionToClass map, as    // that can be quite expensive. At most, we have one expression per    // instruction. @@ -1701,62 +2750,10 @@ bool NewGVN::runGVN(Function &F, DominatorTree *_DT, AssumptionCache *_AC,    ReachableBlocks.insert(&F.getEntryBlock());    initializeCongruenceClasses(F); - -  unsigned int Iterations = 0; -  // We start out in the entry block. -  BasicBlock *LastBlock = &F.getEntryBlock(); -  while (TouchedInstructions.any()) { -    ++Iterations; -    // Walk through all the instructions in all the blocks in RPO. 
-    for (int InstrNum = TouchedInstructions.find_first(); InstrNum != -1; -         InstrNum = TouchedInstructions.find_next(InstrNum)) { -      assert(InstrNum != 0 && "Bit 0 should never be set, something touched an " -                              "instruction not in the lookup table"); -      Value *V = DFSToInstr[InstrNum]; -      BasicBlock *CurrBlock = nullptr; - -      if (auto *I = dyn_cast<Instruction>(V)) -        CurrBlock = I->getParent(); -      else if (auto *MP = dyn_cast<MemoryPhi>(V)) -        CurrBlock = MP->getBlock(); -      else -        llvm_unreachable("DFSToInstr gave us an unknown type of instruction"); - -      // If we hit a new block, do reachability processing. -      if (CurrBlock != LastBlock) { -        LastBlock = CurrBlock; -        bool BlockReachable = ReachableBlocks.count(CurrBlock); -        const auto &CurrInstRange = BlockInstRange.lookup(CurrBlock); - -        // If it's not reachable, erase any touched instructions and move on. -        if (!BlockReachable) { -          TouchedInstructions.reset(CurrInstRange.first, CurrInstRange.second); -          DEBUG(dbgs() << "Skipping instructions in block " -                       << getBlockName(CurrBlock) -                       << " because it is unreachable\n"); -          continue; -        } -        updateProcessedCount(CurrBlock); -      } - -      if (auto *MP = dyn_cast<MemoryPhi>(V)) { -        DEBUG(dbgs() << "Processing MemoryPhi " << *MP << "\n"); -        valueNumberMemoryPhi(MP); -      } else if (auto *I = dyn_cast<Instruction>(V)) { -        valueNumberInstruction(I); -      } else { -        llvm_unreachable("Should have been a MemoryPhi or Instruction"); -      } -      updateProcessedCount(V); -      // Reset after processing (because we may mark ourselves as touched when -      // we propagate equalities). -      TouchedInstructions.reset(InstrNum); -    } -  } -  NumGVNMaxIterations = std::max(NumGVNMaxIterations.getValue(), Iterations); -#ifndef NDEBUG +  iterateTouchedInstructions();    verifyMemoryCongruency(); -#endif +  verifyIterationSettled(F); +    Changed |= eliminateInstructions(F);    // Delete all instructions marked for deletion. @@ -1783,36 +2780,6 @@ bool NewGVN::runGVN(Function &F, DominatorTree *_DT, AssumptionCache *_AC,    return Changed;  } -bool NewGVN::runOnFunction(Function &F) { -  if (skipFunction(F)) -    return false; -  return runGVN(F, &getAnalysis<DominatorTreeWrapperPass>().getDomTree(), -                &getAnalysis<AssumptionCacheTracker>().getAssumptionCache(F), -                &getAnalysis<TargetLibraryInfoWrapperPass>().getTLI(), -                &getAnalysis<AAResultsWrapperPass>().getAAResults(), -                &getAnalysis<MemorySSAWrapperPass>().getMSSA()); -} - -PreservedAnalyses NewGVNPass::run(Function &F, AnalysisManager<Function> &AM) { -  NewGVN Impl; - -  // Apparently the order in which we get these results matter for -  // the old GVN (see Chandler's comment in GVN.cpp). I'll keep -  // the same order here, just in case. 
-  auto &AC = AM.getResult<AssumptionAnalysis>(F); -  auto &DT = AM.getResult<DominatorTreeAnalysis>(F); -  auto &TLI = AM.getResult<TargetLibraryAnalysis>(F); -  auto &AA = AM.getResult<AAManager>(F); -  auto &MSSA = AM.getResult<MemorySSAAnalysis>(F).getMSSA(); -  bool Changed = Impl.runGVN(F, &DT, &AC, &TLI, &AA, &MSSA); -  if (!Changed) -    return PreservedAnalyses::all(); -  PreservedAnalyses PA; -  PA.preserve<DominatorTreeAnalysis>(); -  PA.preserve<GlobalsAA>(); -  return PA; -} -  // Return true if V is a value that will always be available (IE can  // be placed anywhere) in the function.  We don't do globals here  // because they are often worse to put in place. @@ -1821,21 +2788,15 @@ static bool alwaysAvailable(Value *V) {    return isa<Constant>(V) || isa<Argument>(V);  } -// Get the basic block from an instruction/value. -static BasicBlock *getBlockForValue(Value *V) { -  if (auto *I = dyn_cast<Instruction>(V)) -    return I->getParent(); -  return nullptr; -} -  struct NewGVN::ValueDFS {    int DFSIn = 0;    int DFSOut = 0;    int LocalNum = 0; -  // Only one of these will be set. -  Value *Val = nullptr; +  // Only one of Def and U will be set. +  // The bool in the Def tells us whether the Def is the stored value of a +  // store. +  PointerIntPair<Value *, 1, bool> Def;    Use *U = nullptr; -    bool operator<(const ValueDFS &Other) const {      // It's not enough that any given field be less than - we have sets      // of fields that need to be evaluated together to give a proper ordering. @@ -1875,89 +2836,151 @@ struct NewGVN::ValueDFS {      // but .val  and .u.      // It does not matter what order we replace these operands in.      // You will always end up with the same IR, and this is guaranteed. -    return std::tie(DFSIn, DFSOut, LocalNum, Val, U) < -           std::tie(Other.DFSIn, Other.DFSOut, Other.LocalNum, Other.Val, +    return std::tie(DFSIn, DFSOut, LocalNum, Def, U) < +           std::tie(Other.DFSIn, Other.DFSOut, Other.LocalNum, Other.Def,                      Other.U);    }  }; -void NewGVN::convertDenseToDFSOrdered( -    CongruenceClass::MemberSet &Dense, -    SmallVectorImpl<ValueDFS> &DFSOrderedSet) { +// This function converts the set of members for a congruence class from values, +// to sets of defs and uses with associated DFS info.  The total number of +// reachable uses for each value is stored in UseCount, and instructions that +// seem +// dead (have no non-dead uses) are stored in ProbablyDead. +void NewGVN::convertClassToDFSOrdered( +    const CongruenceClass &Dense, SmallVectorImpl<ValueDFS> &DFSOrderedSet, +    DenseMap<const Value *, unsigned int> &UseCounts, +    SmallPtrSetImpl<Instruction *> &ProbablyDead) const {    for (auto D : Dense) {      // First add the value.      BasicBlock *BB = getBlockForValue(D);      // Constants are handled prior to ever calling this function, so      // we should only be left with instructions as members.      assert(BB && "Should have figured out a basic block for value"); -    ValueDFS VD; - -    std::pair<int, int> DFSPair = DFSDomMap[BB]; -    assert(DFSPair.first != -1 && DFSPair.second != -1 && "Invalid DFS Pair"); -    VD.DFSIn = DFSPair.first; -    VD.DFSOut = DFSPair.second; -    VD.Val = D; -    // If it's an instruction, use the real local dfs number. -    if (auto *I = dyn_cast<Instruction>(D)) -      VD.LocalNum = InstrDFS[I]; -    else -      llvm_unreachable("Should have been an instruction"); - -    DFSOrderedSet.emplace_back(VD); - -    // Now add the users. 
-    for (auto &U : D->uses()) { +    ValueDFS VDDef; +    DomTreeNode *DomNode = DT->getNode(BB); +    VDDef.DFSIn = DomNode->getDFSNumIn(); +    VDDef.DFSOut = DomNode->getDFSNumOut(); +    // If it's a store, use the leader of the value operand, if it's always +    // available, or the value operand.  TODO: We could do dominance checks to +    // find a dominating leader, but not worth it ATM. +    if (auto *SI = dyn_cast<StoreInst>(D)) { +      auto Leader = lookupOperandLeader(SI->getValueOperand()); +      if (alwaysAvailable(Leader)) { +        VDDef.Def.setPointer(Leader); +      } else { +        VDDef.Def.setPointer(SI->getValueOperand()); +        VDDef.Def.setInt(true); +      } +    } else { +      VDDef.Def.setPointer(D); +    } +    assert(isa<Instruction>(D) && +           "The dense set member should always be an instruction"); +    VDDef.LocalNum = InstrToDFSNum(D); +    DFSOrderedSet.emplace_back(VDDef); +    Instruction *Def = cast<Instruction>(D); +    unsigned int UseCount = 0; +    // Now add the uses. +    for (auto &U : Def->uses()) {        if (auto *I = dyn_cast<Instruction>(U.getUser())) { -        ValueDFS VD; +        // Don't try to replace into dead uses +        if (InstructionsToErase.count(I)) +          continue; +        ValueDFS VDUse;          // Put the phi node uses in the incoming block.          BasicBlock *IBlock;          if (auto *P = dyn_cast<PHINode>(I)) {            IBlock = P->getIncomingBlock(U);            // Make phi node users appear last in the incoming block            // they are from. -          VD.LocalNum = InstrDFS.size() + 1; +          VDUse.LocalNum = InstrDFS.size() + 1;          } else {            IBlock = I->getParent(); -          VD.LocalNum = InstrDFS[I]; +          VDUse.LocalNum = InstrToDFSNum(I);          } -        std::pair<int, int> DFSPair = DFSDomMap[IBlock]; -        VD.DFSIn = DFSPair.first; -        VD.DFSOut = DFSPair.second; -        VD.U = &U; -        DFSOrderedSet.emplace_back(VD); + +        // Skip uses in unreachable blocks, as we're going +        // to delete them. +        if (ReachableBlocks.count(IBlock) == 0) +          continue; + +        DomTreeNode *DomNode = DT->getNode(IBlock); +        VDUse.DFSIn = DomNode->getDFSNumIn(); +        VDUse.DFSOut = DomNode->getDFSNumOut(); +        VDUse.U = &U; +        ++UseCount; +        DFSOrderedSet.emplace_back(VDUse);        }      } + +    // If there are no uses, it's probably dead (but it may have side-effects, +    // so not definitely dead. Otherwise, store the number of uses so we can +    // track if it becomes dead later). +    if (UseCount == 0) +      ProbablyDead.insert(Def); +    else +      UseCounts[Def] = UseCount;    }  } -static void patchReplacementInstruction(Instruction *I, Value *Repl) { -  // Patch the replacement so that it is not more restrictive than the value -  // being replaced. -  auto *Op = dyn_cast<BinaryOperator>(I); -  auto *ReplOp = dyn_cast<BinaryOperator>(Repl); +// This function converts the set of members for a congruence class from values, +// to the set of defs for loads and stores, with associated DFS info. 
+void NewGVN::convertClassToLoadsAndStores( +    const CongruenceClass &Dense, +    SmallVectorImpl<ValueDFS> &LoadsAndStores) const { +  for (auto D : Dense) { +    if (!isa<LoadInst>(D) && !isa<StoreInst>(D)) +      continue; -  if (Op && ReplOp) -    ReplOp->andIRFlags(Op); +    BasicBlock *BB = getBlockForValue(D); +    ValueDFS VD; +    DomTreeNode *DomNode = DT->getNode(BB); +    VD.DFSIn = DomNode->getDFSNumIn(); +    VD.DFSOut = DomNode->getDFSNumOut(); +    VD.Def.setPointer(D); -  if (auto *ReplInst = dyn_cast<Instruction>(Repl)) { -    // FIXME: If both the original and replacement value are part of the -    // same control-flow region (meaning that the execution of one -    // guarentees the executation of the other), then we can combine the -    // noalias scopes here and do better than the general conservative -    // answer used in combineMetadata(). +    // If it's an instruction, use the real local dfs number. +    if (auto *I = dyn_cast<Instruction>(D)) +      VD.LocalNum = InstrToDFSNum(I); +    else +      llvm_unreachable("Should have been an instruction"); -    // In general, GVN unifies expressions over different control-flow -    // regions, and so we need a conservative combination of the noalias -    // scopes. -    unsigned KnownIDs[] = { -        LLVMContext::MD_tbaa,           LLVMContext::MD_alias_scope, -        LLVMContext::MD_noalias,        LLVMContext::MD_range, -        LLVMContext::MD_fpmath,         LLVMContext::MD_invariant_load, -        LLVMContext::MD_invariant_group}; -    combineMetadata(ReplInst, I, KnownIDs); +    LoadsAndStores.emplace_back(VD);    }  } +static void patchReplacementInstruction(Instruction *I, Value *Repl) { +  auto *ReplInst = dyn_cast<Instruction>(Repl); +  if (!ReplInst) +    return; + +  // Patch the replacement so that it is not more restrictive than the value +  // being replaced. +  // Note that if 'I' is a load being replaced by some operation, +  // for example, by an arithmetic operation, then andIRFlags() +  // would just erase all math flags from the original arithmetic +  // operation, which is clearly not wanted and not needed. +  if (!isa<LoadInst>(I)) +    ReplInst->andIRFlags(I); + +  // FIXME: If both the original and replacement value are part of the +  // same control-flow region (meaning that the execution of one +  // guarantees the execution of the other), then we can combine the +  // noalias scopes here and do better than the general conservative +  // answer used in combineMetadata(). + +  // In general, GVN unifies expressions over different control-flow +  // regions, and so we need a conservative combination of the noalias +  // scopes. +  static const unsigned KnownIDs[] = { +      LLVMContext::MD_tbaa,           LLVMContext::MD_alias_scope, +      LLVMContext::MD_noalias,        LLVMContext::MD_range, +      LLVMContext::MD_fpmath,         LLVMContext::MD_invariant_load, +      LLVMContext::MD_invariant_group}; +  combineMetadata(ReplInst, I, KnownIDs); +} +  static void patchAndReplaceAllUsesWith(Instruction *I, Value *Repl) {    patchReplacementInstruction(I, Repl);    I->replaceAllUsesWith(Repl); @@ -1967,10 +2990,6 @@ void NewGVN::deleteInstructionsInBlock(BasicBlock *BB) {    DEBUG(dbgs() << "  BasicBlock Dead:" << *BB);    ++NumGVNBlocksDeleted; -  // Check to see if there are non-terminating instructions to delete. 
-  if (isa<TerminatorInst>(BB->begin())) -    return; -    // Delete the instructions backwards, as it has a reduced likelihood of having    // to update as many def-use and use-def chains. Start after the terminator.    auto StartPoint = BB->rbegin(); @@ -1987,6 +3006,11 @@ void NewGVN::deleteInstructionsInBlock(BasicBlock *BB) {      Inst.eraseFromParent();      ++NumGVNInstrDeleted;    } +  // Now insert something that simplifycfg will turn into an unreachable. +  Type *Int8Ty = Type::getInt8Ty(BB->getContext()); +  new StoreInst(UndefValue::get(Int8Ty), +                Constant::getNullValue(Int8Ty->getPointerTo()), +                BB->getTerminator());  }  void NewGVN::markInstructionForDeletion(Instruction *I) { @@ -2086,59 +3110,59 @@ bool NewGVN::eliminateInstructions(Function &F) {          }        }      } -    DomTreeNode *Node = DT->getNode(&B); -    if (Node) -      DFSDomMap[&B] = {Node->getDFSNumIn(), Node->getDFSNumOut()};    } -  for (CongruenceClass *CC : CongruenceClasses) { -    // FIXME: We should eventually be able to replace everything still -    // in the initial class with undef, as they should be unreachable. -    // Right now, initial still contains some things we skip value -    // numbering of (UNREACHABLE's, for example). -    if (CC == InitialClass || CC->Dead) +  // Map to store the use counts +  DenseMap<const Value *, unsigned int> UseCounts; +  for (CongruenceClass *CC : reverse(CongruenceClasses)) { +    // Track the equivalent store info so we can decide whether to try +    // dead store elimination. +    SmallVector<ValueDFS, 8> PossibleDeadStores; +    SmallPtrSet<Instruction *, 8> ProbablyDead; +    if (CC->isDead() || CC->empty())        continue; -    assert(CC->RepLeader && "We should have had a leader"); +    // Everything still in the TOP class is unreachable or dead. +    if (CC == TOPClass) { +#ifndef NDEBUG +      for (auto M : *CC) +        assert((!ReachableBlocks.count(cast<Instruction>(M)->getParent()) || +                InstructionsToErase.count(cast<Instruction>(M))) && +               "Everything in TOP should be unreachable or dead at this " +               "point"); +#endif +      continue; +    } +    assert(CC->getLeader() && "We should have had a leader");      // If this is a leader that is always available, and it's a      // constant or has no equivalences, just replace everything with      // it. We then update the congruence class with whatever members      // are left. -    if (alwaysAvailable(CC->RepLeader)) { -      SmallPtrSet<Value *, 4> MembersLeft; -      for (auto M : CC->Members) { - +    Value *Leader = +        CC->getStoredValue() ? CC->getStoredValue() : CC->getLeader(); +    if (alwaysAvailable(Leader)) { +      CongruenceClass::MemberSet MembersLeft; +      for (auto M : *CC) {          Value *Member = M; -          // Void things have no uses we can replace. -        if (Member == CC->RepLeader || Member->getType()->isVoidTy()) { +        if (Member == Leader || !isa<Instruction>(Member) || +            Member->getType()->isVoidTy()) {            MembersLeft.insert(Member);            continue;          } - -        DEBUG(dbgs() << "Found replacement " << *(CC->RepLeader) << " for " -                     << *Member << "\n"); -        // Due to equality propagation, these may not always be -        // instructions, they may be real values.  We don't really -        // care about trying to replace the non-instructions. 
-        if (auto *I = dyn_cast<Instruction>(Member)) { -          assert(CC->RepLeader != I && -                 "About to accidentally remove our leader"); -          replaceInstruction(I, CC->RepLeader); -          AnythingReplaced = true; - -          continue; -        } else { -          MembersLeft.insert(I); -        } +        DEBUG(dbgs() << "Found replacement " << *(Leader) << " for " << *Member +                     << "\n"); +        auto *I = cast<Instruction>(Member); +        assert(Leader != I && "About to accidentally remove our leader"); +        replaceInstruction(I, Leader); +        AnythingReplaced = true;        } -      CC->Members.swap(MembersLeft); - +      CC->swap(MembersLeft);      } else { -      DEBUG(dbgs() << "Eliminating in congruence class " << CC->ID << "\n"); +      DEBUG(dbgs() << "Eliminating in congruence class " << CC->getID() +                   << "\n");        // If this is a singleton, we can skip it. -      if (CC->Members.size() != 1) { - +      if (CC->size() != 1) {          // This is a stack because equality replacement/etc may place          // constants in the middle of the member list, and we want to use          // those constant values in preference to the current leader, over @@ -2147,24 +3171,19 @@ bool NewGVN::eliminateInstructions(Function &F) {          // Convert the members to DFS ordered sets and then merge them.          SmallVector<ValueDFS, 8> DFSOrderedSet; -        convertDenseToDFSOrdered(CC->Members, DFSOrderedSet); +        convertClassToDFSOrdered(*CC, DFSOrderedSet, UseCounts, ProbablyDead);          // Sort the whole thing.          std::sort(DFSOrderedSet.begin(), DFSOrderedSet.end()); -          for (auto &VD : DFSOrderedSet) {            int MemberDFSIn = VD.DFSIn;            int MemberDFSOut = VD.DFSOut; -          Value *Member = VD.Val; -          Use *MemberUse = VD.U; - -          if (Member) { -            // We ignore void things because we can't get a value from them. -            // FIXME: We could actually use this to kill dead stores that are -            // dominated by equivalent earlier stores. -            if (Member->getType()->isVoidTy()) -              continue; -          } +          Value *Def = VD.Def.getPointer(); +          bool FromStore = VD.Def.getInt(); +          Use *U = VD.U; +          // We ignore void things because we can't get a value from them. +          if (Def && Def->getType()->isVoidTy()) +            continue;            if (EliminationStack.empty()) {              DEBUG(dbgs() << "Elimination Stack is empty\n"); @@ -2189,69 +3208,240 @@ bool NewGVN::eliminateInstructions(Function &F) {            // start using, we also push.            // Otherwise, we walk along, processing members who are            // dominated by this scope, and eliminate them. -          bool ShouldPush = -              Member && (EliminationStack.empty() || isa<Constant>(Member)); +          bool ShouldPush = Def && EliminationStack.empty();            bool OutOfScope =                !EliminationStack.isInScope(MemberDFSIn, MemberDFSOut);            if (OutOfScope || ShouldPush) {              // Sync to our current scope.              
EliminationStack.popUntilDFSScope(MemberDFSIn, MemberDFSOut); -            ShouldPush |= Member && EliminationStack.empty(); +            bool ShouldPush = Def && EliminationStack.empty();              if (ShouldPush) { -              EliminationStack.push_back(Member, MemberDFSIn, MemberDFSOut); +              EliminationStack.push_back(Def, MemberDFSIn, MemberDFSOut); +            } +          } + +          // Skip the Defs; we only want to eliminate on their uses.  But mark +          // dominated defs as dead. +          if (Def) { +            // For anything in this case, what and how we value number +            // guarantees that any side-effects that would have occurred (i.e. +            // throwing, etc.) can be proven to either still occur (because it's +            // dominated by something that has the same side-effects), or never +            // occur.  Otherwise, we would not have been able to prove it value +            // equivalent to something else. For these things, we can just mark +            // it all dead.  Note that this is different from the "ProbablyDead" +            // set, which may not be dominated by anything, and thus, are only +            // easy to prove dead if they are also side-effect free. Note that +            // because stores are put in terms of the stored value, we skip +            // stored values here. If the stored value is really dead, it will +            // still be marked for deletion when we process it in its own class. +            if (!EliminationStack.empty() && Def != EliminationStack.back() && +                isa<Instruction>(Def) && !FromStore) +              markInstructionForDeletion(cast<Instruction>(Def)); +            continue; +          } +          // At this point, we know it is a Use we are trying to possibly +          // replace. + +          assert(isa<Instruction>(U->get()) && +                 "Current def should have been an instruction"); +          assert(isa<Instruction>(U->getUser()) && +                 "Current user should have been an instruction"); + +          // If the thing we are replacing into is already marked to be dead, +          // this use is dead.  Note that this is true regardless of whether +          // we have anything dominating the use or not.  We do this here +          // because we are already walking all the uses anyway. +          Instruction *InstUse = cast<Instruction>(U->getUser()); +          if (InstructionsToErase.count(InstUse)) { +            auto &UseCount = UseCounts[U->get()]; +            if (--UseCount == 0) { +              ProbablyDead.insert(cast<Instruction>(U->get()));              }            }            // If we get to this point, and the stack is empty we must have a use -          // with nothing we can use to eliminate it, just skip it. +          // with nothing we can use to eliminate this use, so just skip it.            if (EliminationStack.empty())              continue; -          // Skip the Value's, we only want to eliminate on their uses. -          if (Member) -            continue; -          Value *Result = EliminationStack.back(); +          Value *DominatingLeader = EliminationStack.back();            // Don't replace our existing users with ourselves.
-          if (MemberUse->get() == Result) +          if (U->get() == DominatingLeader)              continue; - -          DEBUG(dbgs() << "Found replacement " << *Result << " for " -                       << *MemberUse->get() << " in " << *(MemberUse->getUser()) -                       << "\n"); +          DEBUG(dbgs() << "Found replacement " << *DominatingLeader << " for " +                       << *U->get() << " in " << *(U->getUser()) << "\n");            // If we replaced something in an instruction, handle the patching of -          // metadata. -          if (auto *ReplacedInst = dyn_cast<Instruction>(MemberUse->get())) -            patchReplacementInstruction(ReplacedInst, Result); - -          assert(isa<Instruction>(MemberUse->getUser())); -          MemberUse->set(Result); +          // metadata.  Skip this if we are replacing predicateinfo with its +          // original operand, as we already know we can just drop it. +          auto *ReplacedInst = cast<Instruction>(U->get()); +          auto *PI = PredInfo->getPredicateInfoFor(ReplacedInst); +          if (!PI || DominatingLeader != PI->OriginalOp) +            patchReplacementInstruction(ReplacedInst, DominatingLeader); +          U->set(DominatingLeader); +          // This is now a use of the dominating leader, which means if the +          // dominating leader was dead, it's now live! +          auto &LeaderUseCount = UseCounts[DominatingLeader]; +          // It's about to be alive again. +          if (LeaderUseCount == 0 && isa<Instruction>(DominatingLeader)) +            ProbablyDead.erase(cast<Instruction>(DominatingLeader)); +          ++LeaderUseCount;            AnythingReplaced = true;          }        }      } +    // At this point, anything still in the ProbablyDead set is actually dead if it +    // would be trivially dead. +    for (auto *I : ProbablyDead) +      if (wouldInstructionBeTriviallyDead(I)) +        markInstructionForDeletion(I); +      // Cleanup the congruence class. -    SmallPtrSet<Value *, 4> MembersLeft; -    for (Value *Member : CC->Members) { -      if (Member->getType()->isVoidTy()) { +    CongruenceClass::MemberSet MembersLeft; +    for (auto *Member : *CC) +      if (!isa<Instruction>(Member) || +          !InstructionsToErase.count(cast<Instruction>(Member)))          MembersLeft.insert(Member); -        continue; -      } - -      if (auto *MemberInst = dyn_cast<Instruction>(Member)) { -        if (isInstructionTriviallyDead(MemberInst)) { -          // TODO: Don't mark loads of undefs. -          markInstructionForDeletion(MemberInst); -          continue; +    CC->swap(MembersLeft); + +    // If we have possible dead stores to look at, try to eliminate them. +    if (CC->getStoreCount() > 0) { +      convertClassToLoadsAndStores(*CC, PossibleDeadStores); +      std::sort(PossibleDeadStores.begin(), PossibleDeadStores.end()); +      ValueDFSStack EliminationStack; +      for (auto &VD : PossibleDeadStores) { +        int MemberDFSIn = VD.DFSIn; +        int MemberDFSOut = VD.DFSOut; +        Instruction *Member = cast<Instruction>(VD.Def.getPointer()); +        if (EliminationStack.empty() || +            !EliminationStack.isInScope(MemberDFSIn, MemberDFSOut)) { +          // Sync to our current scope.
+          EliminationStack.popUntilDFSScope(MemberDFSIn, MemberDFSOut); +          if (EliminationStack.empty()) { +            EliminationStack.push_back(Member, MemberDFSIn, MemberDFSOut); +            continue; +          } +        } +        // We already did load elimination, so nothing to do here. +        if (isa<LoadInst>(Member)) +          continue; +        assert(!EliminationStack.empty()); +        Instruction *Leader = cast<Instruction>(EliminationStack.back()); +        (void)Leader; +        assert(DT->dominates(Leader->getParent(), Member->getParent())); +        // Member is dominated by Leader, and thus dead. +        DEBUG(dbgs() << "Marking dead store " << *Member +                     << " that is dominated by " << *Leader << "\n"); +        markInstructionForDeletion(Member); +        CC->erase(Member); +        ++NumGVNDeadStores;        } -      MembersLeft.insert(Member);      } -    CC->Members.swap(MembersLeft);    }    return AnythingReplaced;  } + +// This function provides global ranking of operations so that we can place them +// in a canonical order.  Note that rank alone is not necessarily enough for a +// complete ordering, as constants all have the same rank.  However, generally, +// we will simplify an operation with all constants so that it doesn't matter +// what order they appear in. +unsigned int NewGVN::getRank(const Value *V) const { +  // Prefer undef to anything else +  if (isa<UndefValue>(V)) +    return 0; +  if (isa<Constant>(V)) +    return 1; +  else if (auto *A = dyn_cast<Argument>(V)) +    return 2 + A->getArgNo(); + +  // Need to shift the instruction DFS by number of arguments + 3 to account for +  // the constant and argument ranking above. +  unsigned Result = InstrToDFSNum(V); +  if (Result > 0) +    return 3 + NumFuncArgs + Result; +  // Unreachable or something else, just return a really large number. +  return ~0; +} + +// This is a function that says whether two commutative operations should +// have their order swapped when canonicalizing. +bool NewGVN::shouldSwapOperands(const Value *A, const Value *B) const { +  // Because we only care about a total ordering, and don't rewrite expressions +  // in this order, we order by rank, which will give a strict weak ordering to +  // everything but constants, and then we order by pointer address. +  return std::make_pair(getRank(A), A) > std::make_pair(getRank(B), B); +} + +class NewGVNLegacyPass : public FunctionPass { +public: +  static char ID; // Pass identification, replacement for typeid.
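A standalone illustration of the ranking scheme defined in getRank and shouldSwapOperands just above: undef ranks lowest, then constants, then arguments by index, then instructions by DFS number, and commutative operands are reordered so both operand orders of the same operation produce the same canonical form. The toy value representation below is illustrative only, not the pass's Value hierarchy.

#include <utility>

enum class Kind { Undef, Constant, Argument, Instruction };

struct ToyValue {
  Kind K;
  unsigned Num; // Argument index for arguments, DFS number for instructions.
};

unsigned rank(const ToyValue &V, unsigned NumFuncArgs) {
  switch (V.K) {
  case Kind::Undef:       return 0;
  case Kind::Constant:    return 1;
  case Kind::Argument:    return 2 + V.Num;
  case Kind::Instruction: return 3 + NumFuncArgs + V.Num;
  }
  return ~0u; // Unreachable; keeps the sketch warning-free.
}

// Mirrors the swap decision: if the first operand ranks higher, swap so the
// lower-ranked value comes first and e.g. both orderings of a commutative add
// hash to the same expression.
void canonicalizeOperands(ToyValue &A, ToyValue &B, unsigned NumFuncArgs) {
  if (rank(A, NumFuncArgs) > rank(B, NumFuncArgs))
    std::swap(A, B);
}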
+  NewGVNLegacyPass() : FunctionPass(ID) { +    initializeNewGVNLegacyPassPass(*PassRegistry::getPassRegistry()); +  } +  bool runOnFunction(Function &F) override; + +private: +  void getAnalysisUsage(AnalysisUsage &AU) const override { +    AU.addRequired<AssumptionCacheTracker>(); +    AU.addRequired<DominatorTreeWrapperPass>(); +    AU.addRequired<TargetLibraryInfoWrapperPass>(); +    AU.addRequired<MemorySSAWrapperPass>(); +    AU.addRequired<AAResultsWrapperPass>(); +    AU.addPreserved<DominatorTreeWrapperPass>(); +    AU.addPreserved<GlobalsAAWrapperPass>(); +  } +}; + +bool NewGVNLegacyPass::runOnFunction(Function &F) { +  if (skipFunction(F)) +    return false; +  return NewGVN(F, &getAnalysis<DominatorTreeWrapperPass>().getDomTree(), +                &getAnalysis<AssumptionCacheTracker>().getAssumptionCache(F), +                &getAnalysis<TargetLibraryInfoWrapperPass>().getTLI(), +                &getAnalysis<AAResultsWrapperPass>().getAAResults(), +                &getAnalysis<MemorySSAWrapperPass>().getMSSA(), +                F.getParent()->getDataLayout()) +      .runGVN(); +} + +INITIALIZE_PASS_BEGIN(NewGVNLegacyPass, "newgvn", "Global Value Numbering", +                      false, false) +INITIALIZE_PASS_DEPENDENCY(AssumptionCacheTracker) +INITIALIZE_PASS_DEPENDENCY(MemorySSAWrapperPass) +INITIALIZE_PASS_DEPENDENCY(DominatorTreeWrapperPass) +INITIALIZE_PASS_DEPENDENCY(TargetLibraryInfoWrapperPass) +INITIALIZE_PASS_DEPENDENCY(AAResultsWrapperPass) +INITIALIZE_PASS_DEPENDENCY(GlobalsAAWrapperPass) +INITIALIZE_PASS_END(NewGVNLegacyPass, "newgvn", "Global Value Numbering", false, +                    false) + +char NewGVNLegacyPass::ID = 0; + +// createGVNPass - The public interface to this file. +FunctionPass *llvm::createNewGVNPass() { return new NewGVNLegacyPass(); } + +PreservedAnalyses NewGVNPass::run(Function &F, AnalysisManager<Function> &AM) { +  // Apparently the order in which we get these results matter for +  // the old GVN (see Chandler's comment in GVN.cpp). I'll keep +  // the same order here, just in case. +  auto &AC = AM.getResult<AssumptionAnalysis>(F); +  auto &DT = AM.getResult<DominatorTreeAnalysis>(F); +  auto &TLI = AM.getResult<TargetLibraryAnalysis>(F); +  auto &AA = AM.getResult<AAManager>(F); +  auto &MSSA = AM.getResult<MemorySSAAnalysis>(F).getMSSA(); +  bool Changed = +      NewGVN(F, &DT, &AC, &TLI, &AA, &MSSA, F.getParent()->getDataLayout()) +          .runGVN(); +  if (!Changed) +    return PreservedAnalyses::all(); +  PreservedAnalyses PA; +  PA.preserve<DominatorTreeAnalysis>(); +  PA.preserve<GlobalsAA>(); +  return PA; +} diff --git a/lib/Transforms/Scalar/PartiallyInlineLibCalls.cpp b/lib/Transforms/Scalar/PartiallyInlineLibCalls.cpp index 1a7ddc9585ba..1bfecea2f61e 100644 --- a/lib/Transforms/Scalar/PartiallyInlineLibCalls.cpp +++ b/lib/Transforms/Scalar/PartiallyInlineLibCalls.cpp @@ -66,7 +66,7 @@ static bool optimizeSQRT(CallInst *Call, Function *CalledFunc,    // Add attribute "readnone" so that backend can use a native sqrt instruction    // for this call. Insert a FP compare instruction and a conditional branch    // at the end of CurrBB. 
-  Call->addAttribute(AttributeSet::FunctionIndex, Attribute::ReadNone); +  Call->addAttribute(AttributeList::FunctionIndex, Attribute::ReadNone);    CurrBB.getTerminator()->eraseFromParent();    Builder.SetInsertPoint(&CurrBB);    Value *FCmp = Builder.CreateFCmpOEQ(Call, Call); @@ -98,14 +98,14 @@ static bool runPartiallyInlineLibCalls(Function &F, TargetLibraryInfo *TLI,        // Skip if function either has local linkage or is not a known library        // function. -      LibFunc::Func LibFunc; +      LibFunc LF;        if (CalledFunc->hasLocalLinkage() || !CalledFunc->hasName() || -          !TLI->getLibFunc(CalledFunc->getName(), LibFunc)) +          !TLI->getLibFunc(CalledFunc->getName(), LF))          continue; -      switch (LibFunc) { -      case LibFunc::sqrtf: -      case LibFunc::sqrt: +      switch (LF) { +      case LibFunc_sqrtf: +      case LibFunc_sqrt:          if (TTI->haveFastSqrt(Call->getType()) &&              optimizeSQRT(Call, CalledFunc, *CurrBB, BB))            break; diff --git a/lib/Transforms/Scalar/Reassociate.cpp b/lib/Transforms/Scalar/Reassociate.cpp index 65c814d7a63b..3dcab6090789 100644 --- a/lib/Transforms/Scalar/Reassociate.cpp +++ b/lib/Transforms/Scalar/Reassociate.cpp @@ -1069,8 +1069,7 @@ Value *ReassociatePass::RemoveFactorFromExpression(Value *V, Value *Factor) {  ///  /// Ops is the top-level list of add operands we're trying to factor.  static void FindSingleUseMultiplyFactors(Value *V, -                                         SmallVectorImpl<Value*> &Factors, -                                       const SmallVectorImpl<ValueEntry> &Ops) { +                                         SmallVectorImpl<Value*> &Factors) {    BinaryOperator *BO = isReassociableOp(V, Instruction::Mul, Instruction::FMul);    if (!BO) {      Factors.push_back(V); @@ -1078,8 +1077,8 @@ static void FindSingleUseMultiplyFactors(Value *V,    }    // Otherwise, add the LHS and RHS to the list of factors. -  FindSingleUseMultiplyFactors(BO->getOperand(1), Factors, Ops); -  FindSingleUseMultiplyFactors(BO->getOperand(0), Factors, Ops); +  FindSingleUseMultiplyFactors(BO->getOperand(1), Factors); +  FindSingleUseMultiplyFactors(BO->getOperand(0), Factors);  }  /// Optimize a series of operands to an 'and', 'or', or 'xor' instruction. @@ -1499,7 +1498,7 @@ Value *ReassociatePass::OptimizeAdd(Instruction *I,      // Compute all of the factors of this added value.      SmallVector<Value*, 8> Factors; -    FindSingleUseMultiplyFactors(BOp, Factors, Ops); +    FindSingleUseMultiplyFactors(BOp, Factors);      assert(Factors.size() > 1 && "Bad linearize!");      // Add one to FactorOccurrences for each unique factor in this op. @@ -2236,8 +2235,8 @@ PreservedAnalyses ReassociatePass::run(Function &F, FunctionAnalysisManager &) {    ValueRankMap.clear();    if (MadeChange) { -    // FIXME: This should also 'preserve the CFG'. -    auto PA = PreservedAnalyses(); +    PreservedAnalyses PA; +    PA.preserveSet<CFGAnalyses>();      PA.preserve<GlobalsAA>();      return PA;    } diff --git a/lib/Transforms/Scalar/RewriteStatepointsForGC.cpp b/lib/Transforms/Scalar/RewriteStatepointsForGC.cpp index 1de742050cb3..f344eb151464 100644 --- a/lib/Transforms/Scalar/RewriteStatepointsForGC.cpp +++ b/lib/Transforms/Scalar/RewriteStatepointsForGC.cpp @@ -365,6 +365,11 @@ findBaseDefiningValueOfVector(Value *I) {      // for particular sufflevector patterns.      
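At the source level, the sqrt partial inlining driven by optimizeSQRT above amounts to roughly the following sketch (the real transformation works on IR; __builtin_sqrt is used here only to stand in for the call marked readnone, and the self-comparison is the NaN check that the inserted fcmp performs):

#include <cmath>

double partiallyInlinedSqrt(double X) {
  // Fast path: a call the backend can lower to a native sqrt instruction,
  // because it has been marked as not writing memory (no errno update).
  double R = __builtin_sqrt(X);
  // sqrt only "fails" for negative inputs, producing a NaN, and NaN is the
  // only value that compares unequal to itself.
  if (R != R)
    R = std::sqrt(X); // slow path: the real library call, errno and all
  return R;
}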
return BaseDefiningValueResult(I, false); +  // The behavior of getelementptr instructions is the same for vector and +  // non-vector data types. +  if (auto *GEP = dyn_cast<GetElementPtrInst>(I)) +    return findBaseDefiningValue(GEP->getPointerOperand()); +    // A PHI or Select is a base defining value.  The outer findBasePointer    // algorithm is responsible for constructing a base value for this BDV.    assert((isa<SelectInst>(I) || isa<PHINode>(I)) && @@ -634,7 +639,7 @@ static BDVState meetBDVStateImpl(const BDVState &LHS, const BDVState &RHS) {  // Values of type BDVState form a lattice, and this function implements the meet  // operation. -static BDVState meetBDVState(BDVState LHS, BDVState RHS) { +static BDVState meetBDVState(const BDVState &LHS, const BDVState &RHS) {    BDVState Result = meetBDVStateImpl(LHS, RHS);    assert(Result == meetBDVStateImpl(RHS, LHS) &&           "Math is wrong: meet does not commute!"); @@ -1123,14 +1128,14 @@ normalizeForInvokeSafepoint(BasicBlock *BB, BasicBlock *InvokeParent,  // Create new attribute set containing only attributes which can be transferred  // from original call to the safepoint. -static AttributeSet legalizeCallAttributes(AttributeSet AS) { -  AttributeSet Ret; +static AttributeList legalizeCallAttributes(AttributeList AS) { +  AttributeList Ret;    for (unsigned Slot = 0; Slot < AS.getNumSlots(); Slot++) {      unsigned Index = AS.getSlotIndex(Slot); -    if (Index == AttributeSet::ReturnIndex || -        Index == AttributeSet::FunctionIndex) { +    if (Index == AttributeList::ReturnIndex || +        Index == AttributeList::FunctionIndex) {        for (Attribute Attr : make_range(AS.begin(Slot), AS.end(Slot))) { @@ -1148,7 +1153,7 @@ static AttributeSet legalizeCallAttributes(AttributeSet AS) {          Ret = Ret.addAttributes(              AS.getContext(), Index, -            AttributeSet::get(AS.getContext(), Index, AttrBuilder(Attr))); +            AttributeList::get(AS.getContext(), Index, AttrBuilder(Attr)));        }      } @@ -1299,12 +1304,11 @@ static StringRef getDeoptLowering(CallSite CS) {    const char *DeoptLowering = "deopt-lowering";    if (CS.hasFnAttr(DeoptLowering)) {      // FIXME: CallSite has a *really* confusing interface around attributes -    // with values.   -    const AttributeSet &CSAS = CS.getAttributes(); -    if (CSAS.hasAttribute(AttributeSet::FunctionIndex, -                          DeoptLowering)) -      return CSAS.getAttribute(AttributeSet::FunctionIndex, -                               DeoptLowering).getValueAsString(); +    // with values. +    const AttributeList &CSAS = CS.getAttributes(); +    if (CSAS.hasAttribute(AttributeList::FunctionIndex, DeoptLowering)) +      return CSAS.getAttribute(AttributeList::FunctionIndex, DeoptLowering) +          .getValueAsString();      Function *F = CS.getCalledFunction();      assert(F && F->hasFnAttribute(DeoptLowering));      return F->getFnAttribute(DeoptLowering).getValueAsString(); @@ -1388,7 +1392,6 @@ makeStatepointExplicitImpl(const CallSite CS, /* to replace */    // Create the statepoint given all the arguments    Instruction *Token = nullptr; -  AttributeSet ReturnAttrs;    if (CS.isCall()) {      CallInst *ToReplace = cast<CallInst>(CS.getInstruction());      CallInst *Call = Builder.CreateGCStatepointCall( @@ -1400,11 +1403,12 @@ makeStatepointExplicitImpl(const CallSite CS, /* to replace */      // Currently we will fail on parameter attributes and on certain      // function attributes. 
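The meetBDVState wrapper in this file asserts that its combine operation commutes; the shape of that check can be seen in a tiny standalone lattice sketch (the three states and the max-based combine are inventions of this example, not the pass's real BDVState):

#include <algorithm>
#include <cassert>

// A three-element chain: Unknown < Base < Conflict.
enum State { Unknown = 0, Base = 1, Conflict = 2 };

static State meetImpl(State L, State R) {
  // Combining information only ever moves down the chain toward Conflict,
  // so for this toy lattice the combine is simply the maximum.
  return std::max(L, R);
}

State meet(State L, State R) {
  State Result = meetImpl(L, R);
  assert(Result == meetImpl(R, L) && "combine operation must commute");
  return Result;
}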
-    AttributeSet NewAttrs = legalizeCallAttributes(ToReplace->getAttributes()); +    AttributeList NewAttrs = legalizeCallAttributes(ToReplace->getAttributes());      // In case if we can handle this set of attributes - set up function attrs      // directly on statepoint and return attrs later for gc_result intrinsic. -    Call->setAttributes(NewAttrs.getFnAttributes()); -    ReturnAttrs = NewAttrs.getRetAttributes(); +    Call->setAttributes(AttributeList::get(Call->getContext(), +                                           AttributeList::FunctionIndex, +                                           NewAttrs.getFnAttributes()));      Token = Call; @@ -1428,11 +1432,12 @@ makeStatepointExplicitImpl(const CallSite CS, /* to replace */      // Currently we will fail on parameter attributes and on certain      // function attributes. -    AttributeSet NewAttrs = legalizeCallAttributes(ToReplace->getAttributes()); +    AttributeList NewAttrs = legalizeCallAttributes(ToReplace->getAttributes());      // In case if we can handle this set of attributes - set up function attrs      // directly on statepoint and return attrs later for gc_result intrinsic. -    Invoke->setAttributes(NewAttrs.getFnAttributes()); -    ReturnAttrs = NewAttrs.getRetAttributes(); +    Invoke->setAttributes(AttributeList::get(Invoke->getContext(), +                                             AttributeList::FunctionIndex, +                                             NewAttrs.getFnAttributes()));      Token = Invoke; @@ -1478,7 +1483,9 @@ makeStatepointExplicitImpl(const CallSite CS, /* to replace */        StringRef Name =            CS.getInstruction()->hasName() ? CS.getInstruction()->getName() : "";        CallInst *GCResult = Builder.CreateGCResult(Token, CS.getType(), Name); -      GCResult->setAttributes(CS.getAttributes().getRetAttributes()); +      GCResult->setAttributes( +          AttributeList::get(GCResult->getContext(), AttributeList::ReturnIndex, +                             CS.getAttributes().getRetAttributes()));        // We cannot RAUW or delete CS.getInstruction() because it could be in the        // live set of some other safepoint, in which case that safepoint's @@ -1615,8 +1622,10 @@ static void relocationViaAlloca(    // Emit alloca for "LiveValue" and record it in "allocaMap" and    // "PromotableAllocas" +  const DataLayout &DL = F.getParent()->getDataLayout();    auto emitAllocaFor = [&](Value *LiveValue) { -    AllocaInst *Alloca = new AllocaInst(LiveValue->getType(), "", +    AllocaInst *Alloca = new AllocaInst(LiveValue->getType(), +                                        DL.getAllocaAddrSpace(), "",                                          F.getEntryBlock().getFirstNonPHI());      AllocaMap[LiveValue] = Alloca;      PromotableAllocas.push_back(Alloca); @@ -1873,7 +1882,7 @@ chainToBasePointerCost(SmallVectorImpl<Instruction*> &Chain,               "non noop cast is found during rematerialization");        Type *SrcTy = CI->getOperand(0)->getType(); -      Cost += TTI.getCastInstrCost(CI->getOpcode(), CI->getType(), SrcTy); +      Cost += TTI.getCastInstrCost(CI->getOpcode(), CI->getType(), SrcTy, CI);      } else if (GetElementPtrInst *GEP = dyn_cast<GetElementPtrInst>(Instr)) {        // Cost of the address calculation @@ -2304,7 +2313,7 @@ static void RemoveNonValidAttrAtIndex(LLVMContext &Ctx, AttrHolder &AH,    if (!R.empty())      AH.setAttributes(AH.getAttributes().removeAttributes( -        Ctx, Index, AttributeSet::get(Ctx, Index, R))); +        Ctx, Index, AttributeList::get(Ctx, 
Index, R)));  }  void @@ -2316,7 +2325,7 @@ RewriteStatepointsForGC::stripNonValidAttributesFromPrototype(Function &F) {        RemoveNonValidAttrAtIndex(Ctx, F, A.getArgNo() + 1);    if (isa<PointerType>(F.getReturnType())) -    RemoveNonValidAttrAtIndex(Ctx, F, AttributeSet::ReturnIndex); +    RemoveNonValidAttrAtIndex(Ctx, F, AttributeList::ReturnIndex);  }  void RewriteStatepointsForGC::stripNonValidAttributesFromBody(Function &F) { @@ -2351,7 +2360,7 @@ void RewriteStatepointsForGC::stripNonValidAttributesFromBody(Function &F) {          if (isa<PointerType>(CS.getArgument(i)->getType()))            RemoveNonValidAttrAtIndex(Ctx, CS, i + 1);        if (isa<PointerType>(CS.getType())) -        RemoveNonValidAttrAtIndex(Ctx, CS, AttributeSet::ReturnIndex); +        RemoveNonValidAttrAtIndex(Ctx, CS, AttributeList::ReturnIndex);      }    }  } diff --git a/lib/Transforms/Scalar/SCCP.cpp b/lib/Transforms/Scalar/SCCP.cpp index ede381c4c243..8908dae2f545 100644 --- a/lib/Transforms/Scalar/SCCP.cpp +++ b/lib/Transforms/Scalar/SCCP.cpp @@ -140,6 +140,14 @@ public:      return nullptr;    } +  /// getBlockAddress - If this is a constant with a BlockAddress value, return +  /// it, otherwise return null. +  BlockAddress *getBlockAddress() const { +    if (isConstant()) +      return dyn_cast<BlockAddress>(getConstant()); +    return nullptr; +  } +    void markForcedConstant(Constant *V) {      assert(isUnknown() && "Can't force a defined value!");      Val.setInt(forcedconstant); @@ -306,20 +314,14 @@ public:      return MRVFunctionsTracked;    } -  void markOverdefined(Value *V) { -    assert(!V->getType()->isStructTy() && -           "structs should use markAnythingOverdefined"); -    markOverdefined(ValueState[V], V); -  } - -  /// markAnythingOverdefined - Mark the specified value overdefined.  This +  /// markOverdefined - Mark the specified value overdefined.  This    /// works with both scalars and structs. 
-  void markAnythingOverdefined(Value *V) { +  void markOverdefined(Value *V) {      if (auto *STy = dyn_cast<StructType>(V->getType()))        for (unsigned i = 0, e = STy->getNumElements(); i != e; ++i)          markOverdefined(getStructValueState(V, i), V);      else -      markOverdefined(V); +      markOverdefined(ValueState[V], V);    }    // isStructLatticeConstant - Return true if all the lattice values @@ -513,12 +515,12 @@ private:    void visitCmpInst(CmpInst &I);    void visitExtractValueInst(ExtractValueInst &EVI);    void visitInsertValueInst(InsertValueInst &IVI); -  void visitLandingPadInst(LandingPadInst &I) { markAnythingOverdefined(&I); } +  void visitLandingPadInst(LandingPadInst &I) { markOverdefined(&I); }    void visitFuncletPadInst(FuncletPadInst &FPI) { -    markAnythingOverdefined(&FPI); +    markOverdefined(&FPI);    }    void visitCatchSwitchInst(CatchSwitchInst &CPI) { -    markAnythingOverdefined(&CPI); +    markOverdefined(&CPI);      visitTerminatorInst(CPI);    } @@ -538,16 +540,16 @@ private:    void visitUnreachableInst(TerminatorInst &I) { /*returns void*/ }    void visitFenceInst     (FenceInst &I) { /*returns void*/ }    void visitAtomicCmpXchgInst(AtomicCmpXchgInst &I) { -    markAnythingOverdefined(&I); +    markOverdefined(&I);    }    void visitAtomicRMWInst (AtomicRMWInst &I) { markOverdefined(&I); }    void visitAllocaInst    (Instruction &I) { markOverdefined(&I); } -  void visitVAArgInst     (Instruction &I) { markAnythingOverdefined(&I); } +  void visitVAArgInst     (Instruction &I) { markOverdefined(&I); }    void visitInstruction(Instruction &I) {      // If a new instruction is added to LLVM that we don't handle.      DEBUG(dbgs() << "SCCP: Don't know how to handle: " << I << '\n'); -    markAnythingOverdefined(&I);   // Just in case +    markOverdefined(&I);   // Just in case    }  }; @@ -602,14 +604,36 @@ void SCCPSolver::getFeasibleSuccessors(TerminatorInst &TI,        return;      } -    Succs[SI->findCaseValue(CI).getSuccessorIndex()] = true; +    Succs[SI->findCaseValue(CI)->getSuccessorIndex()] = true;      return;    } -  // TODO: This could be improved if the operand is a [cast of a] BlockAddress. -  if (isa<IndirectBrInst>(&TI)) { -    // Just mark all destinations executable! -    Succs.assign(TI.getNumSuccessors(), true); +  // In case of indirect branch and its address is a blockaddress, we mark +  // the target as executable. +  if (auto *IBR = dyn_cast<IndirectBrInst>(&TI)) { +    // Casts are folded by visitCastInst. +    LatticeVal IBRValue = getValueState(IBR->getAddress()); +    BlockAddress *Addr = IBRValue.getBlockAddress(); +    if (!Addr) {   // Overdefined or unknown condition? +      // All destinations are executable! +      if (!IBRValue.isUnknown()) +        Succs.assign(TI.getNumSuccessors(), true); +      return; +    } + +    BasicBlock* T = Addr->getBasicBlock(); +    assert(Addr->getFunction() == T->getParent() && +           "Block address of a different function ?"); +    for (unsigned i = 0; i < IBR->getNumSuccessors(); ++i) { +      // This is the target. +      if (IBR->getDestination(i) == T) { +        Succs[i] = true; +        return; +      } +    } + +    // If we didn't find our destination in the IBR successor list, then we +    // have undefined behavior. Its ok to assume no successor is executable.      
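A condensed, self-contained sketch of the successor marking just added above, with plain structs standing in for the LLVM classes (AddressLattice and Feasible are illustrative names): an unknown address marks nothing, a known blockaddress marks exactly its destination, and anything else marks every destination.

#include <cstddef>
#include <vector>

struct Block {};

struct AddressLattice {
  bool Known = false;          // false: still unknown
  Block *BlockAddr = nullptr;  // non-null: a constant blockaddress
  // Known && !BlockAddr: overdefined (or some non-blockaddress constant)
};

void markFeasibleSuccessors(const AddressLattice &Addr,
                            const std::vector<Block *> &Dests,
                            std::vector<bool> &Feasible) {
  Feasible.assign(Dests.size(), false);
  if (!Addr.Known)
    return;                                // unknown: mark nothing yet
  if (!Addr.BlockAddr) {
    Feasible.assign(Dests.size(), true);   // overdefined: every target
    return;
  }
  for (std::size_t I = 0; I < Dests.size(); ++I)
    if (Dests[I] == Addr.BlockAddr) {      // the single known target
      Feasible[I] = true;
      return;
    }
  // Target not in the successor list: undefined behaviour, so it is safe to
  // leave every successor unmarked, as the code above does.
}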
return;    } @@ -659,13 +683,21 @@ bool SCCPSolver::isEdgeFeasible(BasicBlock *From, BasicBlock *To) {      if (!CI)        return !SCValue.isUnknown(); -    return SI->findCaseValue(CI).getCaseSuccessor() == To; +    return SI->findCaseValue(CI)->getCaseSuccessor() == To;    } -  // Just mark all destinations executable! -  // TODO: This could be improved if the operand is a [cast of a] BlockAddress. -  if (isa<IndirectBrInst>(TI)) -    return true; +  // In case of indirect branch and its address is a blockaddress, we mark +  // the target as executable. +  if (auto *IBR = dyn_cast<IndirectBrInst>(TI)) { +    LatticeVal IBRValue = getValueState(IBR->getAddress()); +    BlockAddress *Addr = IBRValue.getBlockAddress(); + +    if (!Addr) +      return !IBRValue.isUnknown(); + +    // At this point, the indirectbr is branching on a blockaddress. +    return Addr->getBasicBlock() == To; +  }    DEBUG(dbgs() << "Unknown terminator instruction: " << *TI << '\n');    llvm_unreachable("SCCP: Don't know how to handle this terminator!"); @@ -693,7 +725,7 @@ void SCCPSolver::visitPHINode(PHINode &PN) {    // If this PN returns a struct, just mark the result overdefined.    // TODO: We could do a lot better than this if code actually uses this.    if (PN.getType()->isStructTy()) -    return markAnythingOverdefined(&PN); +    return markOverdefined(&PN);    if (getValueState(&PN).isOverdefined())      return;  // Quick exit @@ -803,7 +835,7 @@ void SCCPSolver::visitExtractValueInst(ExtractValueInst &EVI) {    // If this returns a struct, mark all elements over defined, we don't track    // structs in structs.    if (EVI.getType()->isStructTy()) -    return markAnythingOverdefined(&EVI); +    return markOverdefined(&EVI);    // If this is extracting from more than one level of struct, we don't know.    if (EVI.getNumIndices() != 1) @@ -828,7 +860,7 @@ void SCCPSolver::visitInsertValueInst(InsertValueInst &IVI) {    // If this has more than one index, we can't handle it, drive all results to    // undef.    if (IVI.getNumIndices() != 1) -    return markAnythingOverdefined(&IVI); +    return markOverdefined(&IVI);    Value *Aggr = IVI.getAggregateOperand();    unsigned Idx = *IVI.idx_begin(); @@ -857,7 +889,7 @@ void SCCPSolver::visitSelectInst(SelectInst &I) {    // If this select returns a struct, just mark the result overdefined.    // TODO: We could do a lot better than this if code actually uses this.    if (I.getType()->isStructTy()) -    return markAnythingOverdefined(&I); +    return markOverdefined(&I);    LatticeVal CondValue = getValueState(I.getCondition());    if (CondValue.isUnknown()) @@ -910,9 +942,16 @@ void SCCPSolver::visitBinaryOperator(Instruction &I) {    // Otherwise, one of our operands is overdefined.  Try to produce something    // better than overdefined with some tricks. - -  // If this is an AND or OR with 0 or -1, it doesn't matter that the other -  // operand is overdefined. +  // If this is 0 / Y, it doesn't matter that the second operand is +  // overdefined, and we can replace it with zero. +  if (I.getOpcode() == Instruction::UDiv || I.getOpcode() == Instruction::SDiv) +    if (V1State.isConstant() && V1State.getConstant()->isNullValue()) +      return markConstant(IV, &I, V1State.getConstant()); + +  // If this is: +  // -> AND/MUL with 0 +  // -> OR with -1 +  // it doesn't matter that the other operand is overdefined.    
if (I.getOpcode() == Instruction::And || I.getOpcode() == Instruction::Mul ||        I.getOpcode() == Instruction::Or) {      LatticeVal *NonOverdefVal = nullptr; @@ -1021,7 +1060,7 @@ void SCCPSolver::visitStoreInst(StoreInst &SI) {  void SCCPSolver::visitLoadInst(LoadInst &I) {    // If this load is of a struct, just mark the result overdefined.    if (I.getType()->isStructTy()) -    return markAnythingOverdefined(&I); +    return markOverdefined(&I);    LatticeVal PtrVal = getValueState(I.getOperand(0));    if (PtrVal.isUnknown()) return;   // The pointer is not resolved yet! @@ -1107,7 +1146,7 @@ CallOverdefined:      }      // Otherwise, we don't know anything about this call, mark it overdefined. -    return markAnythingOverdefined(I); +    return markOverdefined(I);    }    // If this is a local function that doesn't have its address taken, mark its @@ -1483,6 +1522,31 @@ bool SCCPSolver::ResolvedUndefsIn(Function &F) {        return true;      } +   if (auto *IBR = dyn_cast<IndirectBrInst>(TI)) { +      // Indirect branch with no successor ?. Its ok to assume it branches +      // to no target. +      if (IBR->getNumSuccessors() < 1) +        continue; + +      if (!getValueState(IBR->getAddress()).isUnknown()) +        continue; + +      // If the input to SCCP is actually branch on undef, fix the undef to +      // the first successor of the indirect branch. +      if (isa<UndefValue>(IBR->getAddress())) { +        IBR->setAddress(BlockAddress::get(IBR->getSuccessor(0))); +        markEdgeExecutable(&BB, IBR->getSuccessor(0)); +        return true; +      } + +      // Otherwise, it is a branch on a symbolic value which is currently +      // considered to be undef.  Handle this by forcing the input value to the +      // branch to the first successor. +      markForcedConstant(IBR->getAddress(), +                         BlockAddress::get(IBR->getSuccessor(0))); +      return true; +    } +      if (auto *SI = dyn_cast<SwitchInst>(TI)) {        if (!SI->getNumCases() || !getValueState(SI->getCondition()).isUnknown())          continue; @@ -1490,12 +1554,12 @@ bool SCCPSolver::ResolvedUndefsIn(Function &F) {        // If the input to SCCP is actually switch on undef, fix the undef to        // the first constant.        if (isa<UndefValue>(SI->getCondition())) { -        SI->setCondition(SI->case_begin().getCaseValue()); -        markEdgeExecutable(&BB, SI->case_begin().getCaseSuccessor()); +        SI->setCondition(SI->case_begin()->getCaseValue()); +        markEdgeExecutable(&BB, SI->case_begin()->getCaseSuccessor());          return true;        } -      markForcedConstant(SI->getCondition(), SI->case_begin().getCaseValue()); +      markForcedConstant(SI->getCondition(), SI->case_begin()->getCaseValue());        return true;      }    } @@ -1545,7 +1609,7 @@ static bool runSCCP(Function &F, const DataLayout &DL,    // Mark all arguments to the function as being overdefined.    for (Argument &AI : F.args()) -    Solver.markAnythingOverdefined(&AI); +    Solver.markOverdefined(&AI);    // Solve for constants.    bool ResolvedUndefs = true; @@ -1728,7 +1792,7 @@ static bool runIPSCCP(Module &M, const DataLayout &DL,      // Assume nothing about the incoming arguments.      for (Argument &AI : F.args()) -      Solver.markAnythingOverdefined(&AI); +      Solver.markOverdefined(&AI);    }    // Loop over global variables.  
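As noted in the visitBinaryOperator change earlier in this hunk, a few results are known even when one operand is overdefined. A sketch over plain 64-bit integers (the pass itself works on lattice values and arbitrary constants; the enum and function name are invented here):

#include <cstdint>
#include <optional>

enum class Op { UDiv, SDiv, And, Mul, Or };

// `Known` is the operand whose value we have. For the divisions it must be
// the dividend (0 / Y == 0); and/mul/or are commutative, so either operand
// works there (X & 0 == 0, X * 0 == 0, X | -1 == -1).
std::optional<std::int64_t> foldWithOverdefinedOther(Op O, std::int64_t Known) {
  switch (O) {
  case Op::UDiv:
  case Op::SDiv:
    if (Known == 0) return 0;
    break;
  case Op::And:
  case Op::Mul:
    if (Known == 0) return 0;
    break;
  case Op::Or:
    if (Known == -1) return -1;
    break;
  }
  return std::nullopt; // the result genuinely depends on the other operand
}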
We inform the solver about any internal global @@ -1817,32 +1881,9 @@ static bool runIPSCCP(Module &M, const DataLayout &DL,          if (!I) continue;          bool Folded = ConstantFoldTerminator(I->getParent()); -        if (!Folded) { -          // The constant folder may not have been able to fold the terminator -          // if this is a branch or switch on undef.  Fold it manually as a -          // branch to the first successor. -#ifndef NDEBUG -          if (auto *BI = dyn_cast<BranchInst>(I)) { -            assert(BI->isConditional() && isa<UndefValue>(BI->getCondition()) && -                   "Branch should be foldable!"); -          } else if (auto *SI = dyn_cast<SwitchInst>(I)) { -            assert(isa<UndefValue>(SI->getCondition()) && "Switch should fold"); -          } else { -            llvm_unreachable("Didn't fold away reference to block!"); -          } -#endif - -          // Make this an uncond branch to the first successor. -          TerminatorInst *TI = I->getParent()->getTerminator(); -          BranchInst::Create(TI->getSuccessor(0), TI); - -          // Remove entries in successor phi nodes to remove edges. -          for (unsigned i = 1, e = TI->getNumSuccessors(); i != e; ++i) -            TI->getSuccessor(i)->removePredecessor(TI->getParent()); - -          // Remove the old terminator. -          TI->eraseFromParent(); -        } +        assert(Folded && +              "Expect TermInst on constantint or blockaddress to be folded"); +        (void) Folded;        }        // Finally, delete the basic block. diff --git a/lib/Transforms/Scalar/SROA.cpp b/lib/Transforms/Scalar/SROA.cpp index bfcb15530ef5..d01e91a7f235 100644 --- a/lib/Transforms/Scalar/SROA.cpp +++ b/lib/Transforms/Scalar/SROA.cpp @@ -1825,6 +1825,7 @@ static VectorType *isVectorPromotionViable(Partition &P, const DataLayout &DL) {      // Rank the remaining candidate vector types. This is easy because we know      // they're all integer vectors. We sort by ascending number of elements.      auto RankVectorTypes = [&DL](VectorType *RHSTy, VectorType *LHSTy) { +      (void)DL;        assert(DL.getTypeSizeInBits(RHSTy) == DL.getTypeSizeInBits(LHSTy) &&               "Cannot have vector types of different sizes!");        assert(RHSTy->getElementType()->isIntegerTy() && @@ -2294,7 +2295,8 @@ private:  #endif      return getAdjustedPtr(IRB, DL, &NewAI, -                          APInt(DL.getPointerSizeInBits(), Offset), PointerTy, +                          APInt(DL.getPointerTypeSizeInBits(PointerTy), Offset), +                          PointerTy,  #ifndef NDEBUG                            Twine(OldName) + "."  #else @@ -2369,6 +2371,8 @@ private:      Value *OldOp = LI.getOperand(0);      assert(OldOp == OldPtr); +    unsigned AS = LI.getPointerAddressSpace(); +      Type *TargetTy = IsSplit ? 
Type::getIntNTy(LI.getContext(), SliceSize * 8)                               : LI.getType();      const bool IsLoadPastEnd = DL.getTypeStoreSize(TargetTy) > SliceSize; @@ -2387,6 +2391,10 @@ private:                                                LI.isVolatile(), LI.getName());        if (LI.isVolatile())          NewLI->setAtomic(LI.getOrdering(), LI.getSynchScope()); + +      // Try to preserve nonnull metadata +      if (TargetTy->isPointerTy()) +        NewLI->copyMetadata(LI, LLVMContext::MD_nonnull);        V = NewLI;        // If this is an integer load past the end of the slice (which means the @@ -2401,7 +2409,7 @@ private:                                  "endian_shift");            }      } else { -      Type *LTy = TargetTy->getPointerTo(); +      Type *LTy = TargetTy->getPointerTo(AS);        LoadInst *NewLI = IRB.CreateAlignedLoad(getNewAllocaSlicePtr(IRB, LTy),                                                getSliceAlign(TargetTy),                                                LI.isVolatile(), LI.getName()); @@ -2429,7 +2437,7 @@ private:        // the computed value, and then replace the placeholder with LI, leaving        // LI only used for this computation.        Value *Placeholder = -          new LoadInst(UndefValue::get(LI.getType()->getPointerTo())); +          new LoadInst(UndefValue::get(LI.getType()->getPointerTo(AS)));        V = insertInteger(DL, IRB, Placeholder, V, NewBeginOffset - BeginOffset,                          "insert");        LI.replaceAllUsesWith(V); @@ -2542,7 +2550,8 @@ private:        NewSI = IRB.CreateAlignedStore(V, &NewAI, NewAI.getAlignment(),                                       SI.isVolatile());      } else { -      Value *NewPtr = getNewAllocaSlicePtr(IRB, V->getType()->getPointerTo()); +      unsigned AS = SI.getPointerAddressSpace(); +      Value *NewPtr = getNewAllocaSlicePtr(IRB, V->getType()->getPointerTo(AS));        NewSI = IRB.CreateAlignedStore(V, NewPtr, getSliceAlign(V->getType()),                                       SI.isVolatile());      } @@ -3857,7 +3866,7 @@ AllocaInst *SROA::rewritePartition(AllocaInst &AI, AllocaSlices &AS,      if (Alignment <= DL.getABITypeAlignment(SliceTy))        Alignment = 0;      NewAI = new AllocaInst( -        SliceTy, nullptr, Alignment, +      SliceTy, AI.getType()->getAddressSpace(), nullptr, Alignment,          AI.getName() + ".sroa." + Twine(P.begin() - AS.begin()), &AI);      ++NumNewAllocas;    } @@ -4184,7 +4193,7 @@ bool SROA::promoteAllocas(Function &F) {    NumPromoted += PromotableAllocas.size();    DEBUG(dbgs() << "Promoting allocas with mem2reg...\n"); -  PromoteMemToReg(PromotableAllocas, *DT, nullptr, AC); +  PromoteMemToReg(PromotableAllocas, *DT, AC);    PromotableAllocas.clear();    return true;  } @@ -4234,9 +4243,8 @@ PreservedAnalyses SROA::runImpl(Function &F, DominatorTree &RunDT,    if (!Changed)      return PreservedAnalyses::all(); -  // FIXME: Even when promoting allocas we should preserve some abstract set of -  // CFG-specific analyses.    
PreservedAnalyses PA; +  PA.preserveSet<CFGAnalyses>();    PA.preserve<GlobalsAA>();    return PA;  } diff --git a/lib/Transforms/Scalar/Scalar.cpp b/lib/Transforms/Scalar/Scalar.cpp index afe7483006ae..00e3c95f6f06 100644 --- a/lib/Transforms/Scalar/Scalar.cpp +++ b/lib/Transforms/Scalar/Scalar.cpp @@ -43,13 +43,14 @@ void llvm::initializeScalarOpts(PassRegistry &Registry) {    initializeDSELegacyPassPass(Registry);    initializeGuardWideningLegacyPassPass(Registry);    initializeGVNLegacyPassPass(Registry); -  initializeNewGVNPass(Registry); +  initializeNewGVNLegacyPassPass(Registry);    initializeEarlyCSELegacyPassPass(Registry);    initializeEarlyCSEMemSSALegacyPassPass(Registry);    initializeGVNHoistLegacyPassPass(Registry);    initializeFlattenCFGPassPass(Registry);    initializeInductiveRangeCheckEliminationPass(Registry);    initializeIndVarSimplifyLegacyPassPass(Registry); +  initializeInferAddressSpacesPass(Registry);    initializeJumpThreadingPass(Registry);    initializeLegacyLICMPassPass(Registry);    initializeLegacyLoopSinkPassPass(Registry); @@ -58,6 +59,7 @@ void llvm::initializeScalarOpts(PassRegistry &Registry) {    initializeLoopAccessLegacyAnalysisPass(Registry);    initializeLoopInstSimplifyLegacyPassPass(Registry);    initializeLoopInterchangePass(Registry); +  initializeLoopPredicationLegacyPassPass(Registry);    initializeLoopRotateLegacyPassPass(Registry);    initializeLoopStrengthReducePass(Registry);    initializeLoopRerollPass(Registry); @@ -79,6 +81,7 @@ void llvm::initializeScalarOpts(PassRegistry &Registry) {    initializeIPSCCPLegacyPassPass(Registry);    initializeSROALegacyPassPass(Registry);    initializeCFGSimplifyPassPass(Registry); +  initializeLateCFGSimplifyPassPass(Registry);    initializeStructurizeCFGPass(Registry);    initializeSinkingLegacyPassPass(Registry);    initializeTailCallElimPass(Registry); @@ -115,6 +118,10 @@ void LLVMAddCFGSimplificationPass(LLVMPassManagerRef PM) {    unwrap(PM)->add(createCFGSimplificationPass());  } +void LLVMAddLateCFGSimplificationPass(LLVMPassManagerRef PM) { +  unwrap(PM)->add(createLateCFGSimplificationPass()); +} +  void LLVMAddDeadStoreEliminationPass(LLVMPassManagerRef PM) {    unwrap(PM)->add(createDeadStoreEliminationPass());  } diff --git a/lib/Transforms/Scalar/Scalarizer.cpp b/lib/Transforms/Scalar/Scalarizer.cpp index 39969e27367f..c0c09a7e43fe 100644 --- a/lib/Transforms/Scalar/Scalarizer.cpp +++ b/lib/Transforms/Scalar/Scalarizer.cpp @@ -520,12 +520,25 @@ bool Scalarizer::visitGetElementPtrInst(GetElementPtrInst &GEPI) {    unsigned NumElems = VT->getNumElements();    unsigned NumIndices = GEPI.getNumIndices(); -  Scatterer Base = scatter(&GEPI, GEPI.getOperand(0)); +  // The base pointer might be scalar even if it's a vector GEP. In those cases, +  // splat the pointer into a vector value, and scatter that vector. +  Value *Op0 = GEPI.getOperand(0); +  if (!Op0->getType()->isVectorTy()) +    Op0 = Builder.CreateVectorSplat(NumElems, Op0); +  Scatterer Base = scatter(&GEPI, Op0);    SmallVector<Scatterer, 8> Ops;    Ops.resize(NumIndices); -  for (unsigned I = 0; I < NumIndices; ++I) -    Ops[I] = scatter(&GEPI, GEPI.getOperand(I + 1)); +  for (unsigned I = 0; I < NumIndices; ++I) { +    Value *Op = GEPI.getOperand(I + 1); + +    // The indices might be scalars even if it's a vector GEP. In those cases, +    // splat the scalar into a vector value, and scatter that vector. 
+    if (!Op->getType()->isVectorTy()) +      Op = Builder.CreateVectorSplat(NumElems, Op); + +    Ops[I] = scatter(&GEPI, Op); +  }    ValueVector Res;    Res.resize(NumElems); diff --git a/lib/Transforms/Scalar/SimplifyCFGPass.cpp b/lib/Transforms/Scalar/SimplifyCFGPass.cpp index f2723bd7af82..8754c714c5b2 100644 --- a/lib/Transforms/Scalar/SimplifyCFGPass.cpp +++ b/lib/Transforms/Scalar/SimplifyCFGPass.cpp @@ -130,7 +130,8 @@ static bool mergeEmptyReturnBlocks(Function &F) {  /// iterating until no more changes are made.  static bool iterativelySimplifyCFG(Function &F, const TargetTransformInfo &TTI,                                     AssumptionCache *AC, -                                   unsigned BonusInstThreshold) { +                                   unsigned BonusInstThreshold, +                                   bool LateSimplifyCFG) {    bool Changed = false;    bool LocalChange = true; @@ -145,7 +146,7 @@ static bool iterativelySimplifyCFG(Function &F, const TargetTransformInfo &TTI,      // Loop over all of the basic blocks and remove them if they are unneeded.      for (Function::iterator BBIt = F.begin(); BBIt != F.end(); ) { -      if (SimplifyCFG(&*BBIt++, TTI, BonusInstThreshold, AC, &LoopHeaders)) { +      if (SimplifyCFG(&*BBIt++, TTI, BonusInstThreshold, AC, &LoopHeaders, LateSimplifyCFG)) {          LocalChange = true;          ++NumSimpl;        } @@ -156,10 +157,12 @@ static bool iterativelySimplifyCFG(Function &F, const TargetTransformInfo &TTI,  }  static bool simplifyFunctionCFG(Function &F, const TargetTransformInfo &TTI, -                                AssumptionCache *AC, int BonusInstThreshold) { +                                AssumptionCache *AC, int BonusInstThreshold, +                                bool LateSimplifyCFG) {    bool EverChanged = removeUnreachableBlocks(F);    EverChanged |= mergeEmptyReturnBlocks(F); -  EverChanged |= iterativelySimplifyCFG(F, TTI, AC, BonusInstThreshold); +  EverChanged |= iterativelySimplifyCFG(F, TTI, AC, BonusInstThreshold, +                                        LateSimplifyCFG);    // If neither pass changed anything, we're done.    
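The Scalarizer change above splats any scalar operand of a vector GEP so that every lane sees a vector; the per-lane address computation it ends up expressing looks roughly like this plain-C++ sketch (the fixed lane count, int elements and the lanewiseGEP name are all assumptions of the example):

#include <array>
#include <cstddef>

constexpr std::size_t Lanes = 4;

// A vector GEP with a scalar base and a scalar index behaves as if both were
// splatted to every lane and each lane were computed independently.
std::array<int *, Lanes> lanewiseGEP(int *ScalarBase, std::ptrdiff_t ScalarIdx) {
  std::array<int *, Lanes> Result{};
  for (std::size_t L = 0; L < Lanes; ++L)
    Result[L] = ScalarBase + ScalarIdx;   // every lane uses the splatted values
  return Result;
}

// When an operand really is per-lane, no splat is needed and each lane uses
// its own value.
std::array<int *, Lanes>
lanewiseGEP(int *ScalarBase, const std::array<std::ptrdiff_t, Lanes> &Idx) {
  std::array<int *, Lanes> Result{};
  for (std::size_t L = 0; L < Lanes; ++L)
    Result[L] = ScalarBase + Idx[L];
  return Result;
}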
if (!EverChanged) return false; @@ -173,7 +176,8 @@ static bool simplifyFunctionCFG(Function &F, const TargetTransformInfo &TTI,      return true;    do { -    EverChanged = iterativelySimplifyCFG(F, TTI, AC, BonusInstThreshold); +    EverChanged = iterativelySimplifyCFG(F, TTI, AC, BonusInstThreshold, +                                         LateSimplifyCFG);      EverChanged |= removeUnreachableBlocks(F);    } while (EverChanged); @@ -181,17 +185,19 @@ static bool simplifyFunctionCFG(Function &F, const TargetTransformInfo &TTI,  }  SimplifyCFGPass::SimplifyCFGPass() -    : BonusInstThreshold(UserBonusInstThreshold) {} +    : BonusInstThreshold(UserBonusInstThreshold), +      LateSimplifyCFG(true) {} -SimplifyCFGPass::SimplifyCFGPass(int BonusInstThreshold) -    : BonusInstThreshold(BonusInstThreshold) {} +SimplifyCFGPass::SimplifyCFGPass(int BonusInstThreshold, bool LateSimplifyCFG) +    : BonusInstThreshold(BonusInstThreshold), +      LateSimplifyCFG(LateSimplifyCFG) {}  PreservedAnalyses SimplifyCFGPass::run(Function &F,                                         FunctionAnalysisManager &AM) {    auto &TTI = AM.getResult<TargetIRAnalysis>(F);    auto &AC = AM.getResult<AssumptionAnalysis>(F); -  if (!simplifyFunctionCFG(F, TTI, &AC, BonusInstThreshold)) +  if (!simplifyFunctionCFG(F, TTI, &AC, BonusInstThreshold, LateSimplifyCFG))      return PreservedAnalyses::all();    PreservedAnalyses PA;    PA.preserve<GlobalsAA>(); @@ -199,16 +205,17 @@ PreservedAnalyses SimplifyCFGPass::run(Function &F,  }  namespace { -struct CFGSimplifyPass : public FunctionPass { -  static char ID; // Pass identification, replacement for typeid +struct BaseCFGSimplifyPass : public FunctionPass {    unsigned BonusInstThreshold;    std::function<bool(const Function &)> PredicateFtor; +  bool LateSimplifyCFG; -  CFGSimplifyPass(int T = -1, -                  std::function<bool(const Function &)> Ftor = nullptr) -      : FunctionPass(ID), PredicateFtor(std::move(Ftor)) { +  BaseCFGSimplifyPass(int T, bool LateSimplifyCFG, +                      std::function<bool(const Function &)> Ftor, +                      char &ID) +      : FunctionPass(ID), PredicateFtor(std::move(Ftor)), +        LateSimplifyCFG(LateSimplifyCFG) {      BonusInstThreshold = (T == -1) ? 
UserBonusInstThreshold : unsigned(T); -    initializeCFGSimplifyPassPass(*PassRegistry::getPassRegistry());    }    bool runOnFunction(Function &F) override {      if (skipFunction(F) || (PredicateFtor && !PredicateFtor(F))) @@ -218,7 +225,7 @@ struct CFGSimplifyPass : public FunctionPass {          &getAnalysis<AssumptionCacheTracker>().getAssumptionCache(F);      const TargetTransformInfo &TTI =          getAnalysis<TargetTransformInfoWrapperPass>().getTTI(F); -    return simplifyFunctionCFG(F, TTI, AC, BonusInstThreshold); +    return simplifyFunctionCFG(F, TTI, AC, BonusInstThreshold, LateSimplifyCFG);    }    void getAnalysisUsage(AnalysisUsage &AU) const override { @@ -227,6 +234,26 @@ struct CFGSimplifyPass : public FunctionPass {      AU.addPreserved<GlobalsAAWrapperPass>();    }  }; + +struct CFGSimplifyPass : public BaseCFGSimplifyPass { +  static char ID; // Pass identification, replacement for typeid + +  CFGSimplifyPass(int T = -1, +                  std::function<bool(const Function &)> Ftor = nullptr) +                  : BaseCFGSimplifyPass(T, false, Ftor, ID) { +    initializeCFGSimplifyPassPass(*PassRegistry::getPassRegistry()); +  } +}; + +struct LateCFGSimplifyPass : public BaseCFGSimplifyPass { +  static char ID; // Pass identification, replacement for typeid + +  LateCFGSimplifyPass(int T = -1, +                      std::function<bool(const Function &)> Ftor = nullptr) +                      : BaseCFGSimplifyPass(T, true, Ftor, ID) { +    initializeLateCFGSimplifyPassPass(*PassRegistry::getPassRegistry()); +  } +};  }  char CFGSimplifyPass::ID = 0; @@ -237,9 +264,24 @@ INITIALIZE_PASS_DEPENDENCY(AssumptionCacheTracker)  INITIALIZE_PASS_END(CFGSimplifyPass, "simplifycfg", "Simplify the CFG", false,                      false) +char LateCFGSimplifyPass::ID = 0; +INITIALIZE_PASS_BEGIN(LateCFGSimplifyPass, "latesimplifycfg", +                      "Simplify the CFG more aggressively", false, false) +INITIALIZE_PASS_DEPENDENCY(TargetTransformInfoWrapperPass) +INITIALIZE_PASS_DEPENDENCY(AssumptionCacheTracker) +INITIALIZE_PASS_END(LateCFGSimplifyPass, "latesimplifycfg", +                    "Simplify the CFG more aggressively", false, false) +  // Public interface to the CFGSimplification pass  FunctionPass *  llvm::createCFGSimplificationPass(int Threshold, -                                  std::function<bool(const Function &)> Ftor) { +    std::function<bool(const Function &)> Ftor) {    return new CFGSimplifyPass(Threshold, std::move(Ftor));  } + +// Public interface to the LateCFGSimplification pass +FunctionPass * +llvm::createLateCFGSimplificationPass(int Threshold,  +                                  std::function<bool(const Function &)> Ftor) { +  return new LateCFGSimplifyPass(Threshold, std::move(Ftor)); +} diff --git a/lib/Transforms/Scalar/Sink.cpp b/lib/Transforms/Scalar/Sink.cpp index c3f14a0f4b1e..102e9eaeab77 100644 --- a/lib/Transforms/Scalar/Sink.cpp +++ b/lib/Transforms/Scalar/Sink.cpp @@ -164,13 +164,14 @@ static bool SinkInstruction(Instruction *Inst,    // Instructions can only be sunk if all their uses are in blocks    // dominated by one of the successors. -  // Look at all the postdominators and see if we can sink it in one. +  // Look at all the dominated blocks and see if we can sink it in one.    
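The CFGSimplifyPass / LateCFGSimplifyPass split above shares one implementation through a base class that carries the aggressiveness flag; reduced to its bare shape (generic names, no pass-manager machinery) the pattern is:

struct SimplifierBase {
  explicit SimplifierBase(bool Aggressive) : Aggressive(Aggressive) {}
  bool run() { return simplify(Aggressive); } // one shared implementation
protected:
  bool Aggressive; // only widens the set of transformations attempted
private:
  bool simplify(bool Late) { return Late; }   // placeholder body
};

// Two thin wrappers so each variant can be registered and scheduled
// independently, exactly as the two legacy passes are above.
struct EarlySimplifier : SimplifierBase {
  EarlySimplifier() : SimplifierBase(false) {}
};
struct LateSimplifier : SimplifierBase {
  LateSimplifier() : SimplifierBase(true) {}
};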
DomTreeNode *DTN = DT.getNode(Inst->getParent());    for (DomTreeNode::iterator I = DTN->begin(), E = DTN->end();        I != E && SuccToSinkTo == nullptr; ++I) {      BasicBlock *Candidate = (*I)->getBlock(); -    if ((*I)->getIDom()->getBlock() == Inst->getParent() && -        IsAcceptableTarget(Inst, Candidate, DT, LI)) +    // A node always immediate-dominates its children on the dominator +    // tree. +    if (IsAcceptableTarget(Inst, Candidate, DT, LI))        SuccToSinkTo = Candidate;    } @@ -262,9 +263,8 @@ PreservedAnalyses SinkingPass::run(Function &F, FunctionAnalysisManager &AM) {    if (!iterativelySinkInstructions(F, DT, LI, AA))      return PreservedAnalyses::all(); -  auto PA = PreservedAnalyses(); -  PA.preserve<DominatorTreeAnalysis>(); -  PA.preserve<LoopAnalysis>(); +  PreservedAnalyses PA; +  PA.preserveSet<CFGAnalyses>();    return PA;  } diff --git a/lib/Transforms/Utils/AddDiscriminators.cpp b/lib/Transforms/Utils/AddDiscriminators.cpp index 2e95926c0b3f..4c9746b8c691 100644 --- a/lib/Transforms/Utils/AddDiscriminators.cpp +++ b/lib/Transforms/Utils/AddDiscriminators.cpp @@ -102,6 +102,10 @@ FunctionPass *llvm::createAddDiscriminatorsPass() {    return new AddDiscriminatorsLegacyPass();  } +static bool shouldHaveDiscriminator(const Instruction *I) { +  return !isa<IntrinsicInst>(I) || isa<MemIntrinsic>(I); +} +  /// \brief Assign DWARF discriminators.  ///  /// To assign discriminators, we examine the boundaries of every @@ -176,7 +180,13 @@ static bool addDiscriminators(Function &F) {    // discriminator for this instruction.    for (BasicBlock &B : F) {      for (auto &I : B.getInstList()) { -      if (isa<IntrinsicInst>(&I)) +      // Not all intrinsic calls should have a discriminator. +      // We want to avoid a non-deterministic assignment of discriminators at +      // different debug levels. We still allow discriminators on memory +      // intrinsic calls because those can be early expanded by SROA into +      // pairs of loads and stores, and the expanded load/store instructions +      // should have a valid discriminator. +      if (!shouldHaveDiscriminator(&I))          continue;        const DILocation *DIL = I.getDebugLoc();        if (!DIL) @@ -190,8 +200,8 @@ static bool addDiscriminators(Function &F) {        // discriminator is needed to distinguish both instructions.        // Only the lowest 7 bits are used to represent a discriminator to fit        // it in 1 byte ULEB128 representation. -      unsigned Discriminator = (R.second ? ++LDM[L] : LDM[L]) & 0x7f; -      I.setDebugLoc(DIL->cloneWithDiscriminator(Discriminator)); +      unsigned Discriminator = R.second ? ++LDM[L] : LDM[L]; +      I.setDebugLoc(DIL->setBaseDiscriminator(Discriminator));        DEBUG(dbgs() << DIL->getFilename() << ":" << DIL->getLine() << ":"                     << DIL->getColumn() << ":" << Discriminator << " " << I                     << "\n"); @@ -207,6 +217,10 @@ static bool addDiscriminators(Function &F) {      LocationSet CallLocations;      for (auto &I : B.getInstList()) {        CallInst *Current = dyn_cast<CallInst>(&I); +      // We bypass intrinsic calls for the following two reasons: +      //  1) We want to avoid a non-deterministic assigment of +      //     discriminators. +      //  2) We want to minimize the number of base discriminators used.        
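A reduced model of the discriminator assignment in AddDiscriminators, keyed on a (file, line) pair as the pass is (the Location alias and discriminatorForCall name exist only for this sketch): the first call seen on a source line in a block keeps discriminator 0, and each further call on the same line gets a fresh base discriminator so a profile can tell the copies apart.

#include <map>
#include <set>
#include <string>
#include <utility>

using Location = std::pair<std::string, unsigned>; // (filename, line)

// Per-location running counter, playing the role of LDM above.
static std::map<Location, unsigned> BaseDiscriminator;
// Locations already holding a call in the current basic block.
static std::set<Location> CallLocations;

unsigned discriminatorForCall(const Location &Loc) {
  if (CallLocations.insert(Loc).second)
    return 0;                        // first call at this line: unchanged
  return ++BaseDiscriminator[Loc];   // later calls: new base discriminator
}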
if (!Current || isa<IntrinsicInst>(&I))          continue; @@ -216,8 +230,8 @@ static bool addDiscriminators(Function &F) {        Location L =            std::make_pair(CurrentDIL->getFilename(), CurrentDIL->getLine());        if (!CallLocations.insert(L).second) { -        Current->setDebugLoc( -            CurrentDIL->cloneWithDiscriminator((++LDM[L]) & 0x7f)); +        unsigned Discriminator = ++LDM[L]; +        Current->setDebugLoc(CurrentDIL->setBaseDiscriminator(Discriminator));          Changed = true;        }      } diff --git a/lib/Transforms/Utils/BasicBlockUtils.cpp b/lib/Transforms/Utils/BasicBlockUtils.cpp index b90349d3cdad..22af21d55c01 100644 --- a/lib/Transforms/Utils/BasicBlockUtils.cpp +++ b/lib/Transforms/Utils/BasicBlockUtils.cpp @@ -438,7 +438,7 @@ BasicBlock *llvm::SplitBlockPredecessors(BasicBlock *BB,    // The new block unconditionally branches to the old block.    BranchInst *BI = BranchInst::Create(BB, NewBB); -  BI->setDebugLoc(BB->getFirstNonPHI()->getDebugLoc()); +  BI->setDebugLoc(BB->getFirstNonPHIOrDbg()->getDebugLoc());    // Move the edges from Preds to point to NewBB instead of BB.    for (unsigned i = 0, e = Preds.size(); i != e; ++i) { @@ -646,9 +646,10 @@ llvm::SplitBlockAndInsertIfThen(Value *Cond, Instruction *SplitBefore,    }    if (LI) { -    Loop *L = LI->getLoopFor(Head); -    L->addBasicBlockToLoop(ThenBlock, *LI); -    L->addBasicBlockToLoop(Tail, *LI); +    if (Loop *L = LI->getLoopFor(Head)) { +      L->addBasicBlockToLoop(ThenBlock, *LI); +      L->addBasicBlockToLoop(Tail, *LI); +    }    }    return CheckTerm; diff --git a/lib/Transforms/Utils/BuildLibCalls.cpp b/lib/Transforms/Utils/BuildLibCalls.cpp index e61b04fbdd57..6cd9f1614991 100644 --- a/lib/Transforms/Utils/BuildLibCalls.cpp +++ b/lib/Transforms/Utils/BuildLibCalls.cpp @@ -96,9 +96,9 @@ static bool setDoesNotAlias(Function &F, unsigned n) {  }  static bool setNonNull(Function &F, unsigned n) { -  assert((n != AttributeSet::ReturnIndex || -          F.getReturnType()->isPointerTy()) && -         "nonnull applies only to pointers"); +  assert( +      (n != AttributeList::ReturnIndex || F.getReturnType()->isPointerTy()) && +      "nonnull applies only to pointers");    if (F.getAttributes().hasAttribute(n, Attribute::NonNull))      return false;    F.addAttribute(n, Attribute::NonNull); @@ -107,255 +107,255 @@ static bool setNonNull(Function &F, unsigned n) {  }  bool llvm::inferLibFuncAttributes(Function &F, const TargetLibraryInfo &TLI) { -  LibFunc::Func TheLibFunc; +  LibFunc TheLibFunc;    if (!(TLI.getLibFunc(F, TheLibFunc) && TLI.has(TheLibFunc)))      return false;    bool Changed = false;    switch (TheLibFunc) { -  case LibFunc::strlen: +  case LibFunc_strlen:      Changed |= setOnlyReadsMemory(F);      Changed |= setDoesNotThrow(F);      Changed |= setDoesNotCapture(F, 1);      return Changed; -  case LibFunc::strchr: -  case LibFunc::strrchr: +  case LibFunc_strchr: +  case LibFunc_strrchr:      Changed |= setOnlyReadsMemory(F);      Changed |= setDoesNotThrow(F);      return Changed; -  case LibFunc::strtol: -  case LibFunc::strtod: -  case LibFunc::strtof: -  case LibFunc::strtoul: -  case LibFunc::strtoll: -  case LibFunc::strtold: -  case LibFunc::strtoull: +  case LibFunc_strtol: +  case LibFunc_strtod: +  case LibFunc_strtof: +  case LibFunc_strtoul: +  case LibFunc_strtoll: +  case LibFunc_strtold: +  case LibFunc_strtoull:      Changed |= setDoesNotThrow(F);      Changed |= setDoesNotCapture(F, 2);      Changed |= setOnlyReadsMemory(F, 1);      return 
Changed; -  case LibFunc::strcpy: -  case LibFunc::stpcpy: -  case LibFunc::strcat: -  case LibFunc::strncat: -  case LibFunc::strncpy: -  case LibFunc::stpncpy: +  case LibFunc_strcpy: +  case LibFunc_stpcpy: +  case LibFunc_strcat: +  case LibFunc_strncat: +  case LibFunc_strncpy: +  case LibFunc_stpncpy:      Changed |= setDoesNotThrow(F);      Changed |= setDoesNotCapture(F, 2);      Changed |= setOnlyReadsMemory(F, 2);      return Changed; -  case LibFunc::strxfrm: +  case LibFunc_strxfrm:      Changed |= setDoesNotThrow(F);      Changed |= setDoesNotCapture(F, 1);      Changed |= setDoesNotCapture(F, 2);      Changed |= setOnlyReadsMemory(F, 2);      return Changed; -  case LibFunc::strcmp:      // 0,1 -  case LibFunc::strspn:      // 0,1 -  case LibFunc::strncmp:     // 0,1 -  case LibFunc::strcspn:     // 0,1 -  case LibFunc::strcoll:     // 0,1 -  case LibFunc::strcasecmp:  // 0,1 -  case LibFunc::strncasecmp: // +  case LibFunc_strcmp:      // 0,1 +  case LibFunc_strspn:      // 0,1 +  case LibFunc_strncmp:     // 0,1 +  case LibFunc_strcspn:     // 0,1 +  case LibFunc_strcoll:     // 0,1 +  case LibFunc_strcasecmp:  // 0,1 +  case LibFunc_strncasecmp: //      Changed |= setOnlyReadsMemory(F);      Changed |= setDoesNotThrow(F);      Changed |= setDoesNotCapture(F, 1);      Changed |= setDoesNotCapture(F, 2);      return Changed; -  case LibFunc::strstr: -  case LibFunc::strpbrk: +  case LibFunc_strstr: +  case LibFunc_strpbrk:      Changed |= setOnlyReadsMemory(F);      Changed |= setDoesNotThrow(F);      Changed |= setDoesNotCapture(F, 2);      return Changed; -  case LibFunc::strtok: -  case LibFunc::strtok_r: +  case LibFunc_strtok: +  case LibFunc_strtok_r:      Changed |= setDoesNotThrow(F);      Changed |= setDoesNotCapture(F, 2);      Changed |= setOnlyReadsMemory(F, 2);      return Changed; -  case LibFunc::scanf: +  case LibFunc_scanf:      Changed |= setDoesNotThrow(F);      Changed |= setDoesNotCapture(F, 1);      Changed |= setOnlyReadsMemory(F, 1);      return Changed; -  case LibFunc::setbuf: -  case LibFunc::setvbuf: +  case LibFunc_setbuf: +  case LibFunc_setvbuf:      Changed |= setDoesNotThrow(F);      Changed |= setDoesNotCapture(F, 1);      return Changed; -  case LibFunc::strdup: -  case LibFunc::strndup: +  case LibFunc_strdup: +  case LibFunc_strndup:      Changed |= setDoesNotThrow(F);      Changed |= setDoesNotAlias(F, 0);      Changed |= setDoesNotCapture(F, 1);      Changed |= setOnlyReadsMemory(F, 1);      return Changed; -  case LibFunc::stat: -  case LibFunc::statvfs: +  case LibFunc_stat: +  case LibFunc_statvfs:      Changed |= setDoesNotThrow(F);      Changed |= setDoesNotCapture(F, 1);      Changed |= setDoesNotCapture(F, 2);      Changed |= setOnlyReadsMemory(F, 1);      return Changed; -  case LibFunc::sscanf: +  case LibFunc_sscanf:      Changed |= setDoesNotThrow(F);      Changed |= setDoesNotCapture(F, 1);      Changed |= setDoesNotCapture(F, 2);      Changed |= setOnlyReadsMemory(F, 1);      Changed |= setOnlyReadsMemory(F, 2);      return Changed; -  case LibFunc::sprintf: +  case LibFunc_sprintf:      Changed |= setDoesNotThrow(F);      Changed |= setDoesNotCapture(F, 1);      Changed |= setDoesNotCapture(F, 2);      Changed |= setOnlyReadsMemory(F, 2);      return Changed; -  case LibFunc::snprintf: +  case LibFunc_snprintf:      Changed |= setDoesNotThrow(F);      Changed |= setDoesNotCapture(F, 1);      Changed |= setDoesNotCapture(F, 3);      Changed |= setOnlyReadsMemory(F, 3);      return Changed; -  case LibFunc::setitimer: +  
case LibFunc_setitimer:      Changed |= setDoesNotThrow(F);      Changed |= setDoesNotCapture(F, 2);      Changed |= setDoesNotCapture(F, 3);      Changed |= setOnlyReadsMemory(F, 2);      return Changed; -  case LibFunc::system: +  case LibFunc_system:      // May throw; "system" is a valid pthread cancellation point.      Changed |= setDoesNotCapture(F, 1);      Changed |= setOnlyReadsMemory(F, 1);      return Changed; -  case LibFunc::malloc: +  case LibFunc_malloc:      Changed |= setDoesNotThrow(F);      Changed |= setDoesNotAlias(F, 0);      return Changed; -  case LibFunc::memcmp: +  case LibFunc_memcmp:      Changed |= setOnlyReadsMemory(F);      Changed |= setDoesNotThrow(F);      Changed |= setDoesNotCapture(F, 1);      Changed |= setDoesNotCapture(F, 2);      return Changed; -  case LibFunc::memchr: -  case LibFunc::memrchr: +  case LibFunc_memchr: +  case LibFunc_memrchr:      Changed |= setOnlyReadsMemory(F);      Changed |= setDoesNotThrow(F);      return Changed; -  case LibFunc::modf: -  case LibFunc::modff: -  case LibFunc::modfl: +  case LibFunc_modf: +  case LibFunc_modff: +  case LibFunc_modfl:      Changed |= setDoesNotThrow(F);      Changed |= setDoesNotCapture(F, 2);      return Changed; -  case LibFunc::memcpy: -  case LibFunc::mempcpy: -  case LibFunc::memccpy: -  case LibFunc::memmove: +  case LibFunc_memcpy: +  case LibFunc_mempcpy: +  case LibFunc_memccpy: +  case LibFunc_memmove:      Changed |= setDoesNotThrow(F);      Changed |= setDoesNotCapture(F, 2);      Changed |= setOnlyReadsMemory(F, 2);      return Changed; -  case LibFunc::memcpy_chk: +  case LibFunc_memcpy_chk:      Changed |= setDoesNotThrow(F);      return Changed; -  case LibFunc::memalign: +  case LibFunc_memalign:      Changed |= setDoesNotAlias(F, 0);      return Changed; -  case LibFunc::mkdir: +  case LibFunc_mkdir:      Changed |= setDoesNotThrow(F);      Changed |= setDoesNotCapture(F, 1);      Changed |= setOnlyReadsMemory(F, 1);      return Changed; -  case LibFunc::mktime: +  case LibFunc_mktime:      Changed |= setDoesNotThrow(F);      Changed |= setDoesNotCapture(F, 1);      return Changed; -  case LibFunc::realloc: +  case LibFunc_realloc:      Changed |= setDoesNotThrow(F);      Changed |= setDoesNotAlias(F, 0);      Changed |= setDoesNotCapture(F, 1);      return Changed; -  case LibFunc::read: +  case LibFunc_read:      // May throw; "read" is a valid pthread cancellation point.      Changed |= setDoesNotCapture(F, 2);      return Changed; -  case LibFunc::rewind: +  case LibFunc_rewind:      Changed |= setDoesNotThrow(F);      Changed |= setDoesNotCapture(F, 1);      return Changed; -  case LibFunc::rmdir: -  case LibFunc::remove: -  case LibFunc::realpath: +  case LibFunc_rmdir: +  case LibFunc_remove: +  case LibFunc_realpath:      Changed |= setDoesNotThrow(F);      Changed |= setDoesNotCapture(F, 1);      Changed |= setOnlyReadsMemory(F, 1);      return Changed; -  case LibFunc::rename: +  case LibFunc_rename:      Changed |= setDoesNotThrow(F);      Changed |= setDoesNotCapture(F, 1);      Changed |= setDoesNotCapture(F, 2);      Changed |= setOnlyReadsMemory(F, 1);      Changed |= setOnlyReadsMemory(F, 2);      return Changed; -  case LibFunc::readlink: +  case LibFunc_readlink:      Changed |= setDoesNotThrow(F);      Changed |= setDoesNotCapture(F, 1);      Changed |= setDoesNotCapture(F, 2);      Changed |= setOnlyReadsMemory(F, 1);      return Changed; -  case LibFunc::write: +  case LibFunc_write:      // May throw; "write" is a valid pthread cancellation point.      
Changed |= setDoesNotCapture(F, 2);      Changed |= setOnlyReadsMemory(F, 2);      return Changed; -  case LibFunc::bcopy: +  case LibFunc_bcopy:      Changed |= setDoesNotThrow(F);      Changed |= setDoesNotCapture(F, 1);      Changed |= setDoesNotCapture(F, 2);      Changed |= setOnlyReadsMemory(F, 1);      return Changed; -  case LibFunc::bcmp: +  case LibFunc_bcmp:      Changed |= setDoesNotThrow(F);      Changed |= setOnlyReadsMemory(F);      Changed |= setDoesNotCapture(F, 1);      Changed |= setDoesNotCapture(F, 2);      return Changed; -  case LibFunc::bzero: +  case LibFunc_bzero:      Changed |= setDoesNotThrow(F);      Changed |= setDoesNotCapture(F, 1);      return Changed; -  case LibFunc::calloc: +  case LibFunc_calloc:      Changed |= setDoesNotThrow(F);      Changed |= setDoesNotAlias(F, 0);      return Changed; -  case LibFunc::chmod: -  case LibFunc::chown: +  case LibFunc_chmod: +  case LibFunc_chown:      Changed |= setDoesNotThrow(F);      Changed |= setDoesNotCapture(F, 1);      Changed |= setOnlyReadsMemory(F, 1);      return Changed; -  case LibFunc::ctermid: -  case LibFunc::clearerr: -  case LibFunc::closedir: +  case LibFunc_ctermid: +  case LibFunc_clearerr: +  case LibFunc_closedir:      Changed |= setDoesNotThrow(F);      Changed |= setDoesNotCapture(F, 1);      return Changed; -  case LibFunc::atoi: -  case LibFunc::atol: -  case LibFunc::atof: -  case LibFunc::atoll: +  case LibFunc_atoi: +  case LibFunc_atol: +  case LibFunc_atof: +  case LibFunc_atoll:      Changed |= setDoesNotThrow(F);      Changed |= setOnlyReadsMemory(F);      Changed |= setDoesNotCapture(F, 1);      return Changed; -  case LibFunc::access: +  case LibFunc_access:      Changed |= setDoesNotThrow(F);      Changed |= setDoesNotCapture(F, 1);      Changed |= setOnlyReadsMemory(F, 1);      return Changed; -  case LibFunc::fopen: +  case LibFunc_fopen:      Changed |= setDoesNotThrow(F);      Changed |= setDoesNotAlias(F, 0);      Changed |= setDoesNotCapture(F, 1); @@ -363,150 +363,150 @@ bool llvm::inferLibFuncAttributes(Function &F, const TargetLibraryInfo &TLI) {      Changed |= setOnlyReadsMemory(F, 1);      Changed |= setOnlyReadsMemory(F, 2);      return Changed; -  case LibFunc::fdopen: +  case LibFunc_fdopen:      Changed |= setDoesNotThrow(F);      Changed |= setDoesNotAlias(F, 0);      Changed |= setDoesNotCapture(F, 2);      Changed |= setOnlyReadsMemory(F, 2);      return Changed; -  case LibFunc::feof: -  case LibFunc::free: -  case LibFunc::fseek: -  case LibFunc::ftell: -  case LibFunc::fgetc: -  case LibFunc::fseeko: -  case LibFunc::ftello: -  case LibFunc::fileno: -  case LibFunc::fflush: -  case LibFunc::fclose: -  case LibFunc::fsetpos: -  case LibFunc::flockfile: -  case LibFunc::funlockfile: -  case LibFunc::ftrylockfile: +  case LibFunc_feof: +  case LibFunc_free: +  case LibFunc_fseek: +  case LibFunc_ftell: +  case LibFunc_fgetc: +  case LibFunc_fseeko: +  case LibFunc_ftello: +  case LibFunc_fileno: +  case LibFunc_fflush: +  case LibFunc_fclose: +  case LibFunc_fsetpos: +  case LibFunc_flockfile: +  case LibFunc_funlockfile: +  case LibFunc_ftrylockfile:      Changed |= setDoesNotThrow(F);      Changed |= setDoesNotCapture(F, 1);      return Changed; -  case LibFunc::ferror: +  case LibFunc_ferror:      Changed |= setDoesNotThrow(F);      Changed |= setDoesNotCapture(F, 1);      Changed |= setOnlyReadsMemory(F);      return Changed; -  case LibFunc::fputc: -  case LibFunc::fstat: -  case LibFunc::frexp: -  case LibFunc::frexpf: -  case LibFunc::frexpl: -  case 
LibFunc::fstatvfs: +  case LibFunc_fputc: +  case LibFunc_fstat: +  case LibFunc_frexp: +  case LibFunc_frexpf: +  case LibFunc_frexpl: +  case LibFunc_fstatvfs:      Changed |= setDoesNotThrow(F);      Changed |= setDoesNotCapture(F, 2);      return Changed; -  case LibFunc::fgets: +  case LibFunc_fgets:      Changed |= setDoesNotThrow(F);      Changed |= setDoesNotCapture(F, 3);      return Changed; -  case LibFunc::fread: +  case LibFunc_fread:      Changed |= setDoesNotThrow(F);      Changed |= setDoesNotCapture(F, 1);      Changed |= setDoesNotCapture(F, 4);      return Changed; -  case LibFunc::fwrite: +  case LibFunc_fwrite:      Changed |= setDoesNotThrow(F);      Changed |= setDoesNotCapture(F, 1);      Changed |= setDoesNotCapture(F, 4);      // FIXME: readonly #1?      return Changed; -  case LibFunc::fputs: +  case LibFunc_fputs:      Changed |= setDoesNotThrow(F);      Changed |= setDoesNotCapture(F, 1);      Changed |= setDoesNotCapture(F, 2);      Changed |= setOnlyReadsMemory(F, 1);      return Changed; -  case LibFunc::fscanf: -  case LibFunc::fprintf: +  case LibFunc_fscanf: +  case LibFunc_fprintf:      Changed |= setDoesNotThrow(F);      Changed |= setDoesNotCapture(F, 1);      Changed |= setDoesNotCapture(F, 2);      Changed |= setOnlyReadsMemory(F, 2);      return Changed; -  case LibFunc::fgetpos: +  case LibFunc_fgetpos:      Changed |= setDoesNotThrow(F);      Changed |= setDoesNotCapture(F, 1);      Changed |= setDoesNotCapture(F, 2);      return Changed; -  case LibFunc::getc: -  case LibFunc::getlogin_r: -  case LibFunc::getc_unlocked: +  case LibFunc_getc: +  case LibFunc_getlogin_r: +  case LibFunc_getc_unlocked:      Changed |= setDoesNotThrow(F);      Changed |= setDoesNotCapture(F, 1);      return Changed; -  case LibFunc::getenv: +  case LibFunc_getenv:      Changed |= setDoesNotThrow(F);      Changed |= setOnlyReadsMemory(F);      Changed |= setDoesNotCapture(F, 1);      return Changed; -  case LibFunc::gets: -  case LibFunc::getchar: +  case LibFunc_gets: +  case LibFunc_getchar:      Changed |= setDoesNotThrow(F);      return Changed; -  case LibFunc::getitimer: +  case LibFunc_getitimer:      Changed |= setDoesNotThrow(F);      Changed |= setDoesNotCapture(F, 2);      return Changed; -  case LibFunc::getpwnam: +  case LibFunc_getpwnam:      Changed |= setDoesNotThrow(F);      Changed |= setDoesNotCapture(F, 1);      Changed |= setOnlyReadsMemory(F, 1);      return Changed; -  case LibFunc::ungetc: +  case LibFunc_ungetc:      Changed |= setDoesNotThrow(F);      Changed |= setDoesNotCapture(F, 2);      return Changed; -  case LibFunc::uname: +  case LibFunc_uname:      Changed |= setDoesNotThrow(F);      Changed |= setDoesNotCapture(F, 1);      return Changed; -  case LibFunc::unlink: +  case LibFunc_unlink:      Changed |= setDoesNotThrow(F);      Changed |= setDoesNotCapture(F, 1);      Changed |= setOnlyReadsMemory(F, 1);      return Changed; -  case LibFunc::unsetenv: +  case LibFunc_unsetenv:      Changed |= setDoesNotThrow(F);      Changed |= setDoesNotCapture(F, 1);      Changed |= setOnlyReadsMemory(F, 1);      return Changed; -  case LibFunc::utime: -  case LibFunc::utimes: +  case LibFunc_utime: +  case LibFunc_utimes:      Changed |= setDoesNotThrow(F);      Changed |= setDoesNotCapture(F, 1);      Changed |= setDoesNotCapture(F, 2);      Changed |= setOnlyReadsMemory(F, 1);      Changed |= setOnlyReadsMemory(F, 2);      return Changed; -  case LibFunc::putc: +  case LibFunc_putc:      Changed |= setDoesNotThrow(F);      Changed |= 
setDoesNotCapture(F, 2);      return Changed; -  case LibFunc::puts: -  case LibFunc::printf: -  case LibFunc::perror: +  case LibFunc_puts: +  case LibFunc_printf: +  case LibFunc_perror:      Changed |= setDoesNotThrow(F);      Changed |= setDoesNotCapture(F, 1);      Changed |= setOnlyReadsMemory(F, 1);      return Changed; -  case LibFunc::pread: +  case LibFunc_pread:      // May throw; "pread" is a valid pthread cancellation point.      Changed |= setDoesNotCapture(F, 2);      return Changed; -  case LibFunc::pwrite: +  case LibFunc_pwrite:      // May throw; "pwrite" is a valid pthread cancellation point.      Changed |= setDoesNotCapture(F, 2);      Changed |= setOnlyReadsMemory(F, 2);      return Changed; -  case LibFunc::putchar: +  case LibFunc_putchar:      Changed |= setDoesNotThrow(F);      return Changed; -  case LibFunc::popen: +  case LibFunc_popen:      Changed |= setDoesNotThrow(F);      Changed |= setDoesNotAlias(F, 0);      Changed |= setDoesNotCapture(F, 1); @@ -514,132 +514,132 @@ bool llvm::inferLibFuncAttributes(Function &F, const TargetLibraryInfo &TLI) {      Changed |= setOnlyReadsMemory(F, 1);      Changed |= setOnlyReadsMemory(F, 2);      return Changed; -  case LibFunc::pclose: +  case LibFunc_pclose:      Changed |= setDoesNotThrow(F);      Changed |= setDoesNotCapture(F, 1);      return Changed; -  case LibFunc::vscanf: +  case LibFunc_vscanf:      Changed |= setDoesNotThrow(F);      Changed |= setDoesNotCapture(F, 1);      Changed |= setOnlyReadsMemory(F, 1);      return Changed; -  case LibFunc::vsscanf: +  case LibFunc_vsscanf:      Changed |= setDoesNotThrow(F);      Changed |= setDoesNotCapture(F, 1);      Changed |= setDoesNotCapture(F, 2);      Changed |= setOnlyReadsMemory(F, 1);      Changed |= setOnlyReadsMemory(F, 2);      return Changed; -  case LibFunc::vfscanf: +  case LibFunc_vfscanf:      Changed |= setDoesNotThrow(F);      Changed |= setDoesNotCapture(F, 1);      Changed |= setDoesNotCapture(F, 2);      Changed |= setOnlyReadsMemory(F, 2);      return Changed; -  case LibFunc::valloc: +  case LibFunc_valloc:      Changed |= setDoesNotThrow(F);      Changed |= setDoesNotAlias(F, 0);      return Changed; -  case LibFunc::vprintf: +  case LibFunc_vprintf:      Changed |= setDoesNotThrow(F);      Changed |= setDoesNotCapture(F, 1);      Changed |= setOnlyReadsMemory(F, 1);      return Changed; -  case LibFunc::vfprintf: -  case LibFunc::vsprintf: +  case LibFunc_vfprintf: +  case LibFunc_vsprintf:      Changed |= setDoesNotThrow(F);      Changed |= setDoesNotCapture(F, 1);      Changed |= setDoesNotCapture(F, 2);      Changed |= setOnlyReadsMemory(F, 2);      return Changed; -  case LibFunc::vsnprintf: +  case LibFunc_vsnprintf:      Changed |= setDoesNotThrow(F);      Changed |= setDoesNotCapture(F, 1);      Changed |= setDoesNotCapture(F, 3);      Changed |= setOnlyReadsMemory(F, 3);      return Changed; -  case LibFunc::open: +  case LibFunc_open:      // May throw; "open" is a valid pthread cancellation point.      
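Note that the cases commented "May throw" above deliberately skip setDoesNotThrow: those routines are pthread cancellation points and cancellation may unwind through them, so only argument attributes are inferred. For context, a hedged sketch of how a caller can drive inferLibFuncAttributes over a module; the driver name is invented, while the inferLibFuncAttributes signature is the one shown in these hunks:

#include "llvm/Analysis/TargetLibraryInfo.h"
#include "llvm/IR/Module.h"
#include "llvm/Transforms/Utils/BuildLibCalls.h"

// Annotate every declared function that TargetLibraryInfo recognizes as a
// library routine; unknown declarations are left untouched by the callee.
static bool annotateKnownLibCalls(llvm::Module &M,
                                  const llvm::TargetLibraryInfo &TLI) {
  bool Changed = false;
  for (llvm::Function &F : M)
    if (F.isDeclaration() && !F.isIntrinsic())
      Changed |= llvm::inferLibFuncAttributes(F, TLI);
  return Changed;
}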
Changed |= setDoesNotCapture(F, 1);      Changed |= setOnlyReadsMemory(F, 1);      return Changed; -  case LibFunc::opendir: +  case LibFunc_opendir:      Changed |= setDoesNotThrow(F);      Changed |= setDoesNotAlias(F, 0);      Changed |= setDoesNotCapture(F, 1);      Changed |= setOnlyReadsMemory(F, 1);      return Changed; -  case LibFunc::tmpfile: +  case LibFunc_tmpfile:      Changed |= setDoesNotThrow(F);      Changed |= setDoesNotAlias(F, 0);      return Changed; -  case LibFunc::times: +  case LibFunc_times:      Changed |= setDoesNotThrow(F);      Changed |= setDoesNotCapture(F, 1);      return Changed; -  case LibFunc::htonl: -  case LibFunc::htons: -  case LibFunc::ntohl: -  case LibFunc::ntohs: +  case LibFunc_htonl: +  case LibFunc_htons: +  case LibFunc_ntohl: +  case LibFunc_ntohs:      Changed |= setDoesNotThrow(F);      Changed |= setDoesNotAccessMemory(F);      return Changed; -  case LibFunc::lstat: +  case LibFunc_lstat:      Changed |= setDoesNotThrow(F);      Changed |= setDoesNotCapture(F, 1);      Changed |= setDoesNotCapture(F, 2);      Changed |= setOnlyReadsMemory(F, 1);      return Changed; -  case LibFunc::lchown: +  case LibFunc_lchown:      Changed |= setDoesNotThrow(F);      Changed |= setDoesNotCapture(F, 1);      Changed |= setOnlyReadsMemory(F, 1);      return Changed; -  case LibFunc::qsort: +  case LibFunc_qsort:      // May throw; places call through function pointer.      Changed |= setDoesNotCapture(F, 4);      return Changed; -  case LibFunc::dunder_strdup: -  case LibFunc::dunder_strndup: +  case LibFunc_dunder_strdup: +  case LibFunc_dunder_strndup:      Changed |= setDoesNotThrow(F);      Changed |= setDoesNotAlias(F, 0);      Changed |= setDoesNotCapture(F, 1);      Changed |= setOnlyReadsMemory(F, 1);      return Changed; -  case LibFunc::dunder_strtok_r: +  case LibFunc_dunder_strtok_r:      Changed |= setDoesNotThrow(F);      Changed |= setDoesNotCapture(F, 2);      Changed |= setOnlyReadsMemory(F, 2);      return Changed; -  case LibFunc::under_IO_getc: +  case LibFunc_under_IO_getc:      Changed |= setDoesNotThrow(F);      Changed |= setDoesNotCapture(F, 1);      return Changed; -  case LibFunc::under_IO_putc: +  case LibFunc_under_IO_putc:      Changed |= setDoesNotThrow(F);      Changed |= setDoesNotCapture(F, 2);      return Changed; -  case LibFunc::dunder_isoc99_scanf: +  case LibFunc_dunder_isoc99_scanf:      Changed |= setDoesNotThrow(F);      Changed |= setDoesNotCapture(F, 1);      Changed |= setOnlyReadsMemory(F, 1);      return Changed; -  case LibFunc::stat64: -  case LibFunc::lstat64: -  case LibFunc::statvfs64: +  case LibFunc_stat64: +  case LibFunc_lstat64: +  case LibFunc_statvfs64:      Changed |= setDoesNotThrow(F);      Changed |= setDoesNotCapture(F, 1);      Changed |= setDoesNotCapture(F, 2);      Changed |= setOnlyReadsMemory(F, 1);      return Changed; -  case LibFunc::dunder_isoc99_sscanf: +  case LibFunc_dunder_isoc99_sscanf:      Changed |= setDoesNotThrow(F);      Changed |= setDoesNotCapture(F, 1);      Changed |= setDoesNotCapture(F, 2);      Changed |= setOnlyReadsMemory(F, 1);      Changed |= setOnlyReadsMemory(F, 2);      return Changed; -  case LibFunc::fopen64: +  case LibFunc_fopen64:      Changed |= setDoesNotThrow(F);      Changed |= setDoesNotAlias(F, 0);      Changed |= setDoesNotCapture(F, 1); @@ -647,26 +647,26 @@ bool llvm::inferLibFuncAttributes(Function &F, const TargetLibraryInfo &TLI) {      Changed |= setOnlyReadsMemory(F, 1);      Changed |= setOnlyReadsMemory(F, 2);      return Changed; -  
case LibFunc::fseeko64: -  case LibFunc::ftello64: +  case LibFunc_fseeko64: +  case LibFunc_ftello64:      Changed |= setDoesNotThrow(F);      Changed |= setDoesNotCapture(F, 1);      return Changed; -  case LibFunc::tmpfile64: +  case LibFunc_tmpfile64:      Changed |= setDoesNotThrow(F);      Changed |= setDoesNotAlias(F, 0);      return Changed; -  case LibFunc::fstat64: -  case LibFunc::fstatvfs64: +  case LibFunc_fstat64: +  case LibFunc_fstatvfs64:      Changed |= setDoesNotThrow(F);      Changed |= setDoesNotCapture(F, 2);      return Changed; -  case LibFunc::open64: +  case LibFunc_open64:      // May throw; "open" is a valid pthread cancellation point.      Changed |= setDoesNotCapture(F, 1);      Changed |= setOnlyReadsMemory(F, 1);      return Changed; -  case LibFunc::gettimeofday: +  case LibFunc_gettimeofday:      // Currently some platforms have the restrict keyword on the arguments to      // gettimeofday. To be conservative, do not add noalias to gettimeofday's      // arguments. @@ -674,29 +674,29 @@ bool llvm::inferLibFuncAttributes(Function &F, const TargetLibraryInfo &TLI) {      Changed |= setDoesNotCapture(F, 1);      Changed |= setDoesNotCapture(F, 2);      return Changed; -  case LibFunc::Znwj: // new(unsigned int) -  case LibFunc::Znwm: // new(unsigned long) -  case LibFunc::Znaj: // new[](unsigned int) -  case LibFunc::Znam: // new[](unsigned long) -  case LibFunc::msvc_new_int: // new(unsigned int) -  case LibFunc::msvc_new_longlong: // new(unsigned long long) -  case LibFunc::msvc_new_array_int: // new[](unsigned int) -  case LibFunc::msvc_new_array_longlong: // new[](unsigned long long) +  case LibFunc_Znwj: // new(unsigned int) +  case LibFunc_Znwm: // new(unsigned long) +  case LibFunc_Znaj: // new[](unsigned int) +  case LibFunc_Znam: // new[](unsigned long) +  case LibFunc_msvc_new_int: // new(unsigned int) +  case LibFunc_msvc_new_longlong: // new(unsigned long long) +  case LibFunc_msvc_new_array_int: // new[](unsigned int) +  case LibFunc_msvc_new_array_longlong: // new[](unsigned long long)      // Operator new always returns a nonnull noalias pointer -    Changed |= setNonNull(F, AttributeSet::ReturnIndex); -    Changed |= setDoesNotAlias(F, AttributeSet::ReturnIndex); +    Changed |= setNonNull(F, AttributeList::ReturnIndex); +    Changed |= setDoesNotAlias(F, AttributeList::ReturnIndex);      return Changed;    //TODO: add LibFunc entries for: -  //case LibFunc::memset_pattern4: -  //case LibFunc::memset_pattern8: -  case LibFunc::memset_pattern16: +  //case LibFunc_memset_pattern4: +  //case LibFunc_memset_pattern8: +  case LibFunc_memset_pattern16:      Changed |= setOnlyAccessesArgMemory(F);      Changed |= setDoesNotCapture(F, 1);      Changed |= setDoesNotCapture(F, 2);      Changed |= setOnlyReadsMemory(F, 2);      return Changed;    // int __nvvm_reflect(const char *) -  case LibFunc::nvvm_reflect: +  case LibFunc_nvvm_reflect:      Changed |= setDoesNotAccessMemory(F);      Changed |= setDoesNotThrow(F);      return Changed; @@ -717,13 +717,13 @@ Value *llvm::castToCStr(Value *V, IRBuilder<> &B) {  Value *llvm::emitStrLen(Value *Ptr, IRBuilder<> &B, const DataLayout &DL,                          const TargetLibraryInfo *TLI) { -  if (!TLI->has(LibFunc::strlen)) +  if (!TLI->has(LibFunc_strlen))      return nullptr;    Module *M = B.GetInsertBlock()->getModule();    LLVMContext &Context = B.GetInsertBlock()->getContext();    Constant *StrLen = M->getOrInsertFunction("strlen", DL.getIntPtrType(Context), -                                      
      B.getInt8PtrTy(), nullptr); +                                            B.getInt8PtrTy());    inferLibFuncAttributes(*M->getFunction("strlen"), *TLI);    CallInst *CI = B.CreateCall(StrLen, castToCStr(Ptr, B), "strlen");    if (const Function *F = dyn_cast<Function>(StrLen->stripPointerCasts())) @@ -734,14 +734,14 @@ Value *llvm::emitStrLen(Value *Ptr, IRBuilder<> &B, const DataLayout &DL,  Value *llvm::emitStrChr(Value *Ptr, char C, IRBuilder<> &B,                          const TargetLibraryInfo *TLI) { -  if (!TLI->has(LibFunc::strchr)) +  if (!TLI->has(LibFunc_strchr))      return nullptr;    Module *M = B.GetInsertBlock()->getModule();    Type *I8Ptr = B.getInt8PtrTy();    Type *I32Ty = B.getInt32Ty();    Constant *StrChr = -      M->getOrInsertFunction("strchr", I8Ptr, I8Ptr, I32Ty, nullptr); +      M->getOrInsertFunction("strchr", I8Ptr, I8Ptr, I32Ty);    inferLibFuncAttributes(*M->getFunction("strchr"), *TLI);    CallInst *CI = B.CreateCall(        StrChr, {castToCStr(Ptr, B), ConstantInt::get(I32Ty, C)}, "strchr"); @@ -752,14 +752,14 @@ Value *llvm::emitStrChr(Value *Ptr, char C, IRBuilder<> &B,  Value *llvm::emitStrNCmp(Value *Ptr1, Value *Ptr2, Value *Len, IRBuilder<> &B,                           const DataLayout &DL, const TargetLibraryInfo *TLI) { -  if (!TLI->has(LibFunc::strncmp)) +  if (!TLI->has(LibFunc_strncmp))      return nullptr;    Module *M = B.GetInsertBlock()->getModule();    LLVMContext &Context = B.GetInsertBlock()->getContext();    Value *StrNCmp = M->getOrInsertFunction("strncmp", B.getInt32Ty(),                                            B.getInt8PtrTy(), B.getInt8PtrTy(), -                                          DL.getIntPtrType(Context), nullptr); +                                          DL.getIntPtrType(Context));    inferLibFuncAttributes(*M->getFunction("strncmp"), *TLI);    CallInst *CI = B.CreateCall(        StrNCmp, {castToCStr(Ptr1, B), castToCStr(Ptr2, B), Len}, "strncmp"); @@ -772,12 +772,12 @@ Value *llvm::emitStrNCmp(Value *Ptr1, Value *Ptr2, Value *Len, IRBuilder<> &B,  Value *llvm::emitStrCpy(Value *Dst, Value *Src, IRBuilder<> &B,                          const TargetLibraryInfo *TLI, StringRef Name) { -  if (!TLI->has(LibFunc::strcpy)) +  if (!TLI->has(LibFunc_strcpy))      return nullptr;    Module *M = B.GetInsertBlock()->getModule();    Type *I8Ptr = B.getInt8PtrTy(); -  Value *StrCpy = M->getOrInsertFunction(Name, I8Ptr, I8Ptr, I8Ptr, nullptr); +  Value *StrCpy = M->getOrInsertFunction(Name, I8Ptr, I8Ptr, I8Ptr);    inferLibFuncAttributes(*M->getFunction(Name), *TLI);    CallInst *CI =        B.CreateCall(StrCpy, {castToCStr(Dst, B), castToCStr(Src, B)}, Name); @@ -788,13 +788,13 @@ Value *llvm::emitStrCpy(Value *Dst, Value *Src, IRBuilder<> &B,  Value *llvm::emitStrNCpy(Value *Dst, Value *Src, Value *Len, IRBuilder<> &B,                           const TargetLibraryInfo *TLI, StringRef Name) { -  if (!TLI->has(LibFunc::strncpy)) +  if (!TLI->has(LibFunc_strncpy))      return nullptr;    Module *M = B.GetInsertBlock()->getModule();    Type *I8Ptr = B.getInt8PtrTy();    Value *StrNCpy = M->getOrInsertFunction(Name, I8Ptr, I8Ptr, I8Ptr, -                                          Len->getType(), nullptr); +                                          Len->getType());    inferLibFuncAttributes(*M->getFunction(Name), *TLI);    CallInst *CI = B.CreateCall(        StrNCpy, {castToCStr(Dst, B), castToCStr(Src, B), Len}, "strncpy"); @@ -806,18 +806,18 @@ Value *llvm::emitStrNCpy(Value *Dst, Value *Src, Value *Len, IRBuilder<> &B,  Value 
*llvm::emitMemCpyChk(Value *Dst, Value *Src, Value *Len, Value *ObjSize,                             IRBuilder<> &B, const DataLayout &DL,                             const TargetLibraryInfo *TLI) { -  if (!TLI->has(LibFunc::memcpy_chk)) +  if (!TLI->has(LibFunc_memcpy_chk))      return nullptr;    Module *M = B.GetInsertBlock()->getModule(); -  AttributeSet AS; -  AS = AttributeSet::get(M->getContext(), AttributeSet::FunctionIndex, -                         Attribute::NoUnwind); +  AttributeList AS; +  AS = AttributeList::get(M->getContext(), AttributeList::FunctionIndex, +                          Attribute::NoUnwind);    LLVMContext &Context = B.GetInsertBlock()->getContext();    Value *MemCpy = M->getOrInsertFunction( -      "__memcpy_chk", AttributeSet::get(M->getContext(), AS), B.getInt8PtrTy(), +      "__memcpy_chk", AttributeList::get(M->getContext(), AS), B.getInt8PtrTy(),        B.getInt8PtrTy(), B.getInt8PtrTy(), DL.getIntPtrType(Context), -      DL.getIntPtrType(Context), nullptr); +      DL.getIntPtrType(Context));    Dst = castToCStr(Dst, B);    Src = castToCStr(Src, B);    CallInst *CI = B.CreateCall(MemCpy, {Dst, Src, Len, ObjSize}); @@ -828,14 +828,14 @@ Value *llvm::emitMemCpyChk(Value *Dst, Value *Src, Value *Len, Value *ObjSize,  Value *llvm::emitMemChr(Value *Ptr, Value *Val, Value *Len, IRBuilder<> &B,                          const DataLayout &DL, const TargetLibraryInfo *TLI) { -  if (!TLI->has(LibFunc::memchr)) +  if (!TLI->has(LibFunc_memchr))      return nullptr;    Module *M = B.GetInsertBlock()->getModule();    LLVMContext &Context = B.GetInsertBlock()->getContext();    Value *MemChr = M->getOrInsertFunction("memchr", B.getInt8PtrTy(),                                           B.getInt8PtrTy(), B.getInt32Ty(), -                                         DL.getIntPtrType(Context), nullptr); +                                         DL.getIntPtrType(Context));    inferLibFuncAttributes(*M->getFunction("memchr"), *TLI);    CallInst *CI = B.CreateCall(MemChr, {castToCStr(Ptr, B), Val, Len}, "memchr"); @@ -847,14 +847,14 @@ Value *llvm::emitMemChr(Value *Ptr, Value *Val, Value *Len, IRBuilder<> &B,  Value *llvm::emitMemCmp(Value *Ptr1, Value *Ptr2, Value *Len, IRBuilder<> &B,                          const DataLayout &DL, const TargetLibraryInfo *TLI) { -  if (!TLI->has(LibFunc::memcmp)) +  if (!TLI->has(LibFunc_memcmp))      return nullptr;    Module *M = B.GetInsertBlock()->getModule();    LLVMContext &Context = B.GetInsertBlock()->getContext();    Value *MemCmp = M->getOrInsertFunction("memcmp", B.getInt32Ty(),                                           B.getInt8PtrTy(), B.getInt8PtrTy(), -                                         DL.getIntPtrType(Context), nullptr); +                                         DL.getIntPtrType(Context));    inferLibFuncAttributes(*M->getFunction("memcmp"), *TLI);    CallInst *CI = B.CreateCall(        MemCmp, {castToCStr(Ptr1, B), castToCStr(Ptr2, B), Len}, "memcmp"); @@ -881,13 +881,13 @@ static void appendTypeSuffix(Value *Op, StringRef &Name,  }  Value *llvm::emitUnaryFloatFnCall(Value *Op, StringRef Name, IRBuilder<> &B, -                                  const AttributeSet &Attrs) { +                                  const AttributeList &Attrs) {    SmallString<20> NameBuffer;    appendTypeSuffix(Op, Name, NameBuffer);    Module *M = B.GetInsertBlock()->getModule();    Value *Callee = M->getOrInsertFunction(Name, Op->getType(), -                                         Op->getType(), nullptr); +                                     
    Op->getType());    CallInst *CI = B.CreateCall(Callee, Op, Name);    CI->setAttributes(Attrs);    if (const Function *F = dyn_cast<Function>(Callee->stripPointerCasts())) @@ -897,13 +897,13 @@ Value *llvm::emitUnaryFloatFnCall(Value *Op, StringRef Name, IRBuilder<> &B,  }  Value *llvm::emitBinaryFloatFnCall(Value *Op1, Value *Op2, StringRef Name, -                                  IRBuilder<> &B, const AttributeSet &Attrs) { +                                   IRBuilder<> &B, const AttributeList &Attrs) {    SmallString<20> NameBuffer;    appendTypeSuffix(Op1, Name, NameBuffer);    Module *M = B.GetInsertBlock()->getModule();    Value *Callee = M->getOrInsertFunction(Name, Op1->getType(), Op1->getType(), -                                         Op2->getType(), nullptr); +                                         Op2->getType());    CallInst *CI = B.CreateCall(Callee, {Op1, Op2}, Name);    CI->setAttributes(Attrs);    if (const Function *F = dyn_cast<Function>(Callee->stripPointerCasts())) @@ -914,12 +914,12 @@ Value *llvm::emitBinaryFloatFnCall(Value *Op1, Value *Op2, StringRef Name,  Value *llvm::emitPutChar(Value *Char, IRBuilder<> &B,                           const TargetLibraryInfo *TLI) { -  if (!TLI->has(LibFunc::putchar)) +  if (!TLI->has(LibFunc_putchar))      return nullptr;    Module *M = B.GetInsertBlock()->getModule(); -  Value *PutChar = M->getOrInsertFunction("putchar", B.getInt32Ty(), -                                          B.getInt32Ty(), nullptr); +  Value *PutChar = M->getOrInsertFunction("putchar", B.getInt32Ty(), B.getInt32Ty()); +  inferLibFuncAttributes(*M->getFunction("putchar"), *TLI);    CallInst *CI = B.CreateCall(PutChar,                                B.CreateIntCast(Char,                                B.getInt32Ty(), @@ -934,12 +934,12 @@ Value *llvm::emitPutChar(Value *Char, IRBuilder<> &B,  Value *llvm::emitPutS(Value *Str, IRBuilder<> &B,                        const TargetLibraryInfo *TLI) { -  if (!TLI->has(LibFunc::puts)) +  if (!TLI->has(LibFunc_puts))      return nullptr;    Module *M = B.GetInsertBlock()->getModule();    Value *PutS = -      M->getOrInsertFunction("puts", B.getInt32Ty(), B.getInt8PtrTy(), nullptr); +      M->getOrInsertFunction("puts", B.getInt32Ty(), B.getInt8PtrTy());    inferLibFuncAttributes(*M->getFunction("puts"), *TLI);    CallInst *CI = B.CreateCall(PutS, castToCStr(Str, B), "puts");    if (const Function *F = dyn_cast<Function>(PutS->stripPointerCasts())) @@ -949,12 +949,12 @@ Value *llvm::emitPutS(Value *Str, IRBuilder<> &B,  Value *llvm::emitFPutC(Value *Char, Value *File, IRBuilder<> &B,                         const TargetLibraryInfo *TLI) { -  if (!TLI->has(LibFunc::fputc)) +  if (!TLI->has(LibFunc_fputc))      return nullptr;    Module *M = B.GetInsertBlock()->getModule();    Constant *F = M->getOrInsertFunction("fputc", B.getInt32Ty(), B.getInt32Ty(), -                                       File->getType(), nullptr); +                                       File->getType());    if (File->getType()->isPointerTy())      inferLibFuncAttributes(*M->getFunction("fputc"), *TLI);    Char = B.CreateIntCast(Char, B.getInt32Ty(), /*isSigned*/true, @@ -968,13 +968,13 @@ Value *llvm::emitFPutC(Value *Char, Value *File, IRBuilder<> &B,  Value *llvm::emitFPutS(Value *Str, Value *File, IRBuilder<> &B,                         const TargetLibraryInfo *TLI) { -  if (!TLI->has(LibFunc::fputs)) +  if (!TLI->has(LibFunc_fputs))      return nullptr;    Module *M = B.GetInsertBlock()->getModule(); -  StringRef FPutsName = 
TLI->getName(LibFunc::fputs); +  StringRef FPutsName = TLI->getName(LibFunc_fputs);    Constant *F = M->getOrInsertFunction( -      FPutsName, B.getInt32Ty(), B.getInt8PtrTy(), File->getType(), nullptr); +      FPutsName, B.getInt32Ty(), B.getInt8PtrTy(), File->getType());    if (File->getType()->isPointerTy())      inferLibFuncAttributes(*M->getFunction(FPutsName), *TLI);    CallInst *CI = B.CreateCall(F, {castToCStr(Str, B), File}, "fputs"); @@ -986,16 +986,16 @@ Value *llvm::emitFPutS(Value *Str, Value *File, IRBuilder<> &B,  Value *llvm::emitFWrite(Value *Ptr, Value *Size, Value *File, IRBuilder<> &B,                          const DataLayout &DL, const TargetLibraryInfo *TLI) { -  if (!TLI->has(LibFunc::fwrite)) +  if (!TLI->has(LibFunc_fwrite))      return nullptr;    Module *M = B.GetInsertBlock()->getModule();    LLVMContext &Context = B.GetInsertBlock()->getContext(); -  StringRef FWriteName = TLI->getName(LibFunc::fwrite); +  StringRef FWriteName = TLI->getName(LibFunc_fwrite);    Constant *F = M->getOrInsertFunction(        FWriteName, DL.getIntPtrType(Context), B.getInt8PtrTy(), -      DL.getIntPtrType(Context), DL.getIntPtrType(Context), File->getType(), -      nullptr); +      DL.getIntPtrType(Context), DL.getIntPtrType(Context), File->getType()); +    if (File->getType()->isPointerTy())      inferLibFuncAttributes(*M->getFunction(FWriteName), *TLI);    CallInst *CI = diff --git a/lib/Transforms/Utils/BypassSlowDivision.cpp b/lib/Transforms/Utils/BypassSlowDivision.cpp index bc2cef26edcb..1cfe3bd53648 100644 --- a/lib/Transforms/Utils/BypassSlowDivision.cpp +++ b/lib/Transforms/Utils/BypassSlowDivision.cpp @@ -17,6 +17,8 @@  #include "llvm/Transforms/Utils/BypassSlowDivision.h"  #include "llvm/ADT/DenseMap.h" +#include "llvm/ADT/SmallPtrSet.h" +#include "llvm/Analysis/ValueTracking.h"  #include "llvm/IR/Function.h"  #include "llvm/IR/IRBuilder.h"  #include "llvm/IR/Instructions.h" @@ -36,12 +38,21 @@ namespace {        : SignedOp(InSignedOp), Dividend(InDividend), Divisor(InDivisor) {}    }; -  struct DivPhiNodes { -    PHINode *Quotient; -    PHINode *Remainder; +  struct QuotRemPair { +    Value *Quotient; +    Value *Remainder; -    DivPhiNodes(PHINode *InQuotient, PHINode *InRemainder) -      : Quotient(InQuotient), Remainder(InRemainder) {} +    QuotRemPair(Value *InQuotient, Value *InRemainder) +        : Quotient(InQuotient), Remainder(InRemainder) {} +  }; + +  /// A quotient and remainder, plus a BB from which they logically "originate". +  /// If you use Quotient or Remainder in a Phi node, you should use BB as its +  /// corresponding predecessor. +  struct QuotRemWithBB { +    BasicBlock *BB = nullptr; +    Value *Quotient = nullptr; +    Value *Remainder = nullptr;    };  } @@ -69,159 +80,376 @@ namespace llvm {      }    }; -  typedef DenseMap<DivOpInfo, DivPhiNodes> DivCacheTy; +  typedef DenseMap<DivOpInfo, QuotRemPair> DivCacheTy; +  typedef DenseMap<unsigned, unsigned> BypassWidthsTy; +  typedef SmallPtrSet<Instruction *, 4> VisitedSetTy;  } -// insertFastDiv - Substitutes the div/rem instruction with code that checks the -// value of the operands and uses a shorter-faster div/rem instruction when -// possible and the longer-slower div/rem instruction otherwise. 
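A second mechanical change running through the BuildLibCalls.cpp hunks above is that every Module::getOrInsertFunction call drops its trailing nullptr: the old null-terminated varargs overload appears to have been replaced by a variadic-template form, so the sentinel is no longer needed and would now be read as a null parameter type. A hedged before/after sketch with an invented function name:

#include "llvm/IR/IRBuilder.h"
#include "llvm/IR/Module.h"

// Illustrative only: declare i8* @my_strdup(i8*) in M.
static llvm::Constant *declareMyStrdup(llvm::Module &M, llvm::IRBuilder<> &B) {
  llvm::Type *I8Ptr = B.getInt8PtrTy();
  // Old style: M.getOrInsertFunction("my_strdup", I8Ptr, I8Ptr, nullptr);
  // New style, as in the hunks above: just list the parameter types.
  return M.getOrInsertFunction("my_strdup", I8Ptr, I8Ptr);
}

The emitted declarations are unchanged; only the call syntax differs.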
-static bool insertFastDiv(Instruction *I, IntegerType *BypassType, -                          bool UseDivOp, bool UseSignedOp, -                          DivCacheTy &PerBBDivCache) { -  Function *F = I->getParent()->getParent(); -  // Get instruction operands -  Value *Dividend = I->getOperand(0); -  Value *Divisor = I->getOperand(1); +namespace { +enum ValueRange { +  /// Operand definitely fits into BypassType. No runtime checks are needed. +  VALRNG_KNOWN_SHORT, +  /// A runtime check is required, as value range is unknown. +  VALRNG_UNKNOWN, +  /// Operand is unlikely to fit into BypassType. The bypassing should be +  /// disabled. +  VALRNG_LIKELY_LONG +}; + +class FastDivInsertionTask { +  bool IsValidTask = false; +  Instruction *SlowDivOrRem = nullptr; +  IntegerType *BypassType = nullptr; +  BasicBlock *MainBB = nullptr; + +  bool isHashLikeValue(Value *V, VisitedSetTy &Visited); +  ValueRange getValueRange(Value *Op, VisitedSetTy &Visited); +  QuotRemWithBB createSlowBB(BasicBlock *Successor); +  QuotRemWithBB createFastBB(BasicBlock *Successor); +  QuotRemPair createDivRemPhiNodes(QuotRemWithBB &LHS, QuotRemWithBB &RHS, +                                   BasicBlock *PhiBB); +  Value *insertOperandRuntimeCheck(Value *Op1, Value *Op2); +  Optional<QuotRemPair> insertFastDivAndRem(); + +  bool isSignedOp() { +    return SlowDivOrRem->getOpcode() == Instruction::SDiv || +           SlowDivOrRem->getOpcode() == Instruction::SRem; +  } +  bool isDivisionOp() { +    return SlowDivOrRem->getOpcode() == Instruction::SDiv || +           SlowDivOrRem->getOpcode() == Instruction::UDiv; +  } +  Type *getSlowType() { return SlowDivOrRem->getType(); } + +public: +  FastDivInsertionTask(Instruction *I, const BypassWidthsTy &BypassWidths); +  Value *getReplacement(DivCacheTy &Cache); +}; +} // anonymous namespace + +FastDivInsertionTask::FastDivInsertionTask(Instruction *I, +                                           const BypassWidthsTy &BypassWidths) { +  switch (I->getOpcode()) { +  case Instruction::UDiv: +  case Instruction::SDiv: +  case Instruction::URem: +  case Instruction::SRem: +    SlowDivOrRem = I; +    break; +  default: +    // I is not a div/rem operation. +    return; +  } -  if (isa<ConstantInt>(Divisor)) { -    // Division by a constant should have been been solved and replaced earlier -    // in the pipeline. -    return false; +  // Skip division on vector types. Only optimize integer instructions. +  IntegerType *SlowType = dyn_cast<IntegerType>(SlowDivOrRem->getType()); +  if (!SlowType) +    return; + +  // Skip if this bitwidth is not bypassed. +  auto BI = BypassWidths.find(SlowType->getBitWidth()); +  if (BI == BypassWidths.end()) +    return; + +  // Get type for div/rem instruction with bypass bitwidth. +  IntegerType *BT = IntegerType::get(I->getContext(), BI->second); +  BypassType = BT; + +  // The original basic block. +  MainBB = I->getParent(); + +  // The instruction is indeed a slow div or rem operation. +  IsValidTask = true; +} + +/// Reuses previously-computed dividend or remainder from the current BB if +/// operands and operation are identical. Otherwise calls insertFastDivAndRem to +/// perform the optimization and caches the resulting dividend and remainder. +/// If no replacement can be generated, nullptr is returned. +Value *FastDivInsertionTask::getReplacement(DivCacheTy &Cache) { +  // First, make sure that the task is valid. +  if (!IsValidTask) +    return nullptr; + +  // Then, look for a value in Cache. 
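The cache consulted by the lookup that follows is per basic block and keyed on (signedness, dividend, divisor), so a division and a remainder over the same operands share one inserted fast/slow diamond; as the comment at the end of bypassSlowDivision notes further below, emitting them as a pair also lets the backend form a combined divrem. A hedged source-level illustration of the pattern that benefits (function name invented), assuming a target with a slow 64-bit divider that declares a 32-bit bypass width:

// Both operations below become one udiv and one urem over the same operands
// in the same block, so the pass is expected to insert a single runtime
// check plus fast/slow blocks and serve both results from one cached
// QuotRemPair.
static unsigned long long quotAndRem(unsigned long long A,
                                     unsigned long long B,
                                     unsigned long long *Rem) {
  *Rem = A % B; // urem: remainder half of the cached pair
  return A / B; // udiv: quotient half of the cached pair
}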
+  Value *Dividend = SlowDivOrRem->getOperand(0); +  Value *Divisor = SlowDivOrRem->getOperand(1); +  DivOpInfo Key(isSignedOp(), Dividend, Divisor); +  auto CacheI = Cache.find(Key); + +  if (CacheI == Cache.end()) { +    // If previous instance does not exist, try to insert fast div. +    Optional<QuotRemPair> OptResult = insertFastDivAndRem(); +    // Bail out if insertFastDivAndRem has failed. +    if (!OptResult) +      return nullptr; +    CacheI = Cache.insert({Key, *OptResult}).first;    } -  // If the numerator is a constant, bail if it doesn't fit into BypassType. -  if (ConstantInt *ConstDividend = dyn_cast<ConstantInt>(Dividend)) -    if (ConstDividend->getValue().getActiveBits() > BypassType->getBitWidth()) +  QuotRemPair &Value = CacheI->second; +  return isDivisionOp() ? Value.Quotient : Value.Remainder; +} + +/// \brief Check if a value looks like a hash. +/// +/// The routine is expected to detect values computed using the most common hash +/// algorithms. Typically, hash computations end with one of the following +/// instructions: +/// +/// 1) MUL with a constant wider than BypassType +/// 2) XOR instruction +/// +/// And even if we are wrong and the value is not a hash, it is still quite +/// unlikely that such values will fit into BypassType. +/// +/// To detect string hash algorithms like FNV we have to look through PHI-nodes. +/// It is implemented as a depth-first search for values that look neither long +/// nor hash-like. +bool FastDivInsertionTask::isHashLikeValue(Value *V, VisitedSetTy &Visited) { +  Instruction *I = dyn_cast<Instruction>(V); +  if (!I) +    return false; + +  switch (I->getOpcode()) { +  case Instruction::Xor: +    return true; +  case Instruction::Mul: { +    // After Constant Hoisting pass, long constants may be represented as +    // bitcast instructions. As a result, some constants may look like an +    // instruction at first, and an additional check is necessary to find out if +    // an operand is actually a constant. +    Value *Op1 = I->getOperand(1); +    ConstantInt *C = dyn_cast<ConstantInt>(Op1); +    if (!C && isa<BitCastInst>(Op1)) +      C = dyn_cast<ConstantInt>(cast<BitCastInst>(Op1)->getOperand(0)); +    return C && C->getValue().getMinSignedBits() > BypassType->getBitWidth(); +  } +  case Instruction::PHI: { +    // Stop IR traversal in case of a crazy input code. This limits recursion +    // depth. +    if (Visited.size() >= 16)        return false; +    // Do not visit nodes that have been visited already. We return true because +    // it means that we couldn't find any value that doesn't look hash-like. +    if (Visited.find(I) != Visited.end()) +      return true; +    Visited.insert(I); +    return llvm::all_of(cast<PHINode>(I)->incoming_values(), [&](Value *V) { +      // Ignore undef values as they probably don't affect the division +      // operands. +      return getValueRange(V, Visited) == VALRNG_LIKELY_LONG || +             isa<UndefValue>(V); +    }); +  } +  default: +    return false; +  } +} + +/// Check if an integer value fits into our bypass type. 
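isHashLikeValue exists because classic string-hash loops feed their result straight into a modulo by the table size; such values almost never fit the narrow type, so probing them at runtime would be wasted work, and getValueRange (defined next) reports them as VALRNG_LIKELY_LONG instead. A hedged sketch of the kind of code the heuristic is tuned to recognize, an FNV-1a style loop whose multiply-by-a-wide-constant and xor are exactly the opcodes checked above (function name invented):

#include <cstdint>

// The hash is built from xors and 64-bit multiplies by a constant wider than
// 32 bits, so the urem's dividend is expected to be classified as
// VALRNG_LIKELY_LONG and the long division is intentionally left alone.
static uint64_t bucketFor(const char *S, uint64_t NBuckets) {
  uint64_t H = 14695981039346656037ull;  // FNV offset basis
  for (; *S; ++S) {
    H ^= static_cast<unsigned char>(*S); // Xor -> looks hash-like
    H *= 1099511628211ull;               // Mul by a wide constant -> hash-like
  }
  return H % NBuckets;
}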
+ValueRange FastDivInsertionTask::getValueRange(Value *V, +                                               VisitedSetTy &Visited) { +  unsigned ShortLen = BypassType->getBitWidth(); +  unsigned LongLen = V->getType()->getIntegerBitWidth(); + +  assert(LongLen > ShortLen && "Value type must be wider than BypassType"); +  unsigned HiBits = LongLen - ShortLen; + +  const DataLayout &DL = SlowDivOrRem->getModule()->getDataLayout(); +  APInt Zeros(LongLen, 0), Ones(LongLen, 0); -  // Basic Block is split before divide -  BasicBlock *MainBB = &*I->getParent(); -  BasicBlock *SuccessorBB = MainBB->splitBasicBlock(I); - -  // Add new basic block for slow divide operation -  BasicBlock *SlowBB = -      BasicBlock::Create(F->getContext(), "", MainBB->getParent(), SuccessorBB); -  SlowBB->moveBefore(SuccessorBB); -  IRBuilder<> SlowBuilder(SlowBB, SlowBB->begin()); -  Value *SlowQuotientV; -  Value *SlowRemainderV; -  if (UseSignedOp) { -    SlowQuotientV = SlowBuilder.CreateSDiv(Dividend, Divisor); -    SlowRemainderV = SlowBuilder.CreateSRem(Dividend, Divisor); +  computeKnownBits(V, Zeros, Ones, DL); + +  if (Zeros.countLeadingOnes() >= HiBits) +    return VALRNG_KNOWN_SHORT; + +  if (Ones.countLeadingZeros() < HiBits) +    return VALRNG_LIKELY_LONG; + +  // Long integer divisions are often used in hashtable implementations. It's +  // not worth bypassing such divisions because hash values are extremely +  // unlikely to have enough leading zeros. The call below tries to detect +  // values that are unlikely to fit BypassType (including hashes). +  if (isHashLikeValue(V, Visited)) +    return VALRNG_LIKELY_LONG; + +  return VALRNG_UNKNOWN; +} + +/// Add new basic block for slow div and rem operations and put it before +/// SuccessorBB. +QuotRemWithBB FastDivInsertionTask::createSlowBB(BasicBlock *SuccessorBB) { +  QuotRemWithBB DivRemPair; +  DivRemPair.BB = BasicBlock::Create(MainBB->getParent()->getContext(), "", +                                     MainBB->getParent(), SuccessorBB); +  IRBuilder<> Builder(DivRemPair.BB, DivRemPair.BB->begin()); + +  Value *Dividend = SlowDivOrRem->getOperand(0); +  Value *Divisor = SlowDivOrRem->getOperand(1); + +  if (isSignedOp()) { +    DivRemPair.Quotient = Builder.CreateSDiv(Dividend, Divisor); +    DivRemPair.Remainder = Builder.CreateSRem(Dividend, Divisor);    } else { -    SlowQuotientV = SlowBuilder.CreateUDiv(Dividend, Divisor); -    SlowRemainderV = SlowBuilder.CreateURem(Dividend, Divisor); +    DivRemPair.Quotient = Builder.CreateUDiv(Dividend, Divisor); +    DivRemPair.Remainder = Builder.CreateURem(Dividend, Divisor);    } -  SlowBuilder.CreateBr(SuccessorBB); - -  // Add new basic block for fast divide operation -  BasicBlock *FastBB = -      BasicBlock::Create(F->getContext(), "", MainBB->getParent(), SuccessorBB); -  FastBB->moveBefore(SlowBB); -  IRBuilder<> FastBuilder(FastBB, FastBB->begin()); -  Value *ShortDivisorV = FastBuilder.CreateCast(Instruction::Trunc, Divisor, -                                                BypassType); -  Value *ShortDividendV = FastBuilder.CreateCast(Instruction::Trunc, Dividend, -                                                 BypassType); - -  // udiv/urem because optimization only handles positive numbers -  Value *ShortQuotientV = FastBuilder.CreateUDiv(ShortDividendV, ShortDivisorV); -  Value *ShortRemainderV = FastBuilder.CreateURem(ShortDividendV, -                                                  ShortDivisorV); -  Value *FastQuotientV = FastBuilder.CreateCast(Instruction::ZExt, -                      
                          ShortQuotientV, -                                                Dividend->getType()); -  Value *FastRemainderV = FastBuilder.CreateCast(Instruction::ZExt, -                                                 ShortRemainderV, -                                                 Dividend->getType()); -  FastBuilder.CreateBr(SuccessorBB); - -  // Phi nodes for result of div and rem -  IRBuilder<> SuccessorBuilder(SuccessorBB, SuccessorBB->begin()); -  PHINode *QuoPhi = SuccessorBuilder.CreatePHI(I->getType(), 2); -  QuoPhi->addIncoming(SlowQuotientV, SlowBB); -  QuoPhi->addIncoming(FastQuotientV, FastBB); -  PHINode *RemPhi = SuccessorBuilder.CreatePHI(I->getType(), 2); -  RemPhi->addIncoming(SlowRemainderV, SlowBB); -  RemPhi->addIncoming(FastRemainderV, FastBB); - -  // Replace I with appropriate phi node -  if (UseDivOp) -    I->replaceAllUsesWith(QuoPhi); -  else -    I->replaceAllUsesWith(RemPhi); -  I->eraseFromParent(); -  // Combine operands into a single value with OR for value testing below -  MainBB->getInstList().back().eraseFromParent(); -  IRBuilder<> MainBuilder(MainBB, MainBB->end()); +  Builder.CreateBr(SuccessorBB); +  return DivRemPair; +} + +/// Add new basic block for fast div and rem operations and put it before +/// SuccessorBB. +QuotRemWithBB FastDivInsertionTask::createFastBB(BasicBlock *SuccessorBB) { +  QuotRemWithBB DivRemPair; +  DivRemPair.BB = BasicBlock::Create(MainBB->getParent()->getContext(), "", +                                     MainBB->getParent(), SuccessorBB); +  IRBuilder<> Builder(DivRemPair.BB, DivRemPair.BB->begin()); + +  Value *Dividend = SlowDivOrRem->getOperand(0); +  Value *Divisor = SlowDivOrRem->getOperand(1); +  Value *ShortDivisorV = +      Builder.CreateCast(Instruction::Trunc, Divisor, BypassType); +  Value *ShortDividendV = +      Builder.CreateCast(Instruction::Trunc, Dividend, BypassType); + +  // udiv/urem because this optimization only handles positive numbers. +  Value *ShortQV = Builder.CreateUDiv(ShortDividendV, ShortDivisorV); +  Value *ShortRV = Builder.CreateURem(ShortDividendV, ShortDivisorV); +  DivRemPair.Quotient = +      Builder.CreateCast(Instruction::ZExt, ShortQV, getSlowType()); +  DivRemPair.Remainder = +      Builder.CreateCast(Instruction::ZExt, ShortRV, getSlowType()); +  Builder.CreateBr(SuccessorBB); + +  return DivRemPair; +} -  // We should have bailed out above if the divisor is a constant, but the -  // dividend may still be a constant.  Set OrV to our non-constant operands -  // OR'ed together. -  assert(!isa<ConstantInt>(Divisor)); +/// Creates Phi nodes for result of Div and Rem. +QuotRemPair FastDivInsertionTask::createDivRemPhiNodes(QuotRemWithBB &LHS, +                                                       QuotRemWithBB &RHS, +                                                       BasicBlock *PhiBB) { +  IRBuilder<> Builder(PhiBB, PhiBB->begin()); +  PHINode *QuoPhi = Builder.CreatePHI(getSlowType(), 2); +  QuoPhi->addIncoming(LHS.Quotient, LHS.BB); +  QuoPhi->addIncoming(RHS.Quotient, RHS.BB); +  PHINode *RemPhi = Builder.CreatePHI(getSlowType(), 2); +  RemPhi->addIncoming(LHS.Remainder, LHS.BB); +  RemPhi->addIncoming(RHS.Remainder, RHS.BB); +  return QuotRemPair(QuoPhi, RemPhi); +} + +/// Creates a runtime check to test whether both the divisor and dividend fit +/// into BypassType. The check is inserted at the end of MainBB. True return +/// value means that the operands fit. Either of the operands may be NULL if it +/// doesn't need a runtime check. 
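The check built by insertOperandRuntimeCheck (defined next) is the usual "do these values fit in N bits" trick: OR the operands together, mask off the low bits, and compare the remainder against zero. A hedged scalar equivalent for a 64-bit division with a 32-bit bypass type, purely for illustration (function name invented):

#include <cstdint>

// True iff both A and B fit in 32 bits, i.e. the fast 32-bit udiv/urem path
// is safe to take.  Mirrors ((A | B) & ~BypassType->getBitMask()) == 0.
static bool bothFitIn32Bits(uint64_t A, uint64_t B) {
  const uint64_t HighMask = ~uint64_t{0xffffffffu}; // bits above the bypass width
  return ((A | B) & HighMask) == 0;
}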
+Value *FastDivInsertionTask::insertOperandRuntimeCheck(Value *Op1, Value *Op2) { +  assert((Op1 || Op2) && "Nothing to check"); +  IRBuilder<> Builder(MainBB, MainBB->end());    Value *OrV; -  if (!isa<ConstantInt>(Dividend)) -    OrV = MainBuilder.CreateOr(Dividend, Divisor); +  if (Op1 && Op2) +    OrV = Builder.CreateOr(Op1, Op2);    else -    OrV = Divisor; +    OrV = Op1 ? Op1 : Op2;    // BitMask is inverted to check if the operands are    // larger than the bypass type    uint64_t BitMask = ~BypassType->getBitMask(); -  Value *AndV = MainBuilder.CreateAnd(OrV, BitMask); - -  // Compare operand values and branch -  Value *ZeroV = ConstantInt::getSigned(Dividend->getType(), 0); -  Value *CmpV = MainBuilder.CreateICmpEQ(AndV, ZeroV); -  MainBuilder.CreateCondBr(CmpV, FastBB, SlowBB); - -  // Cache phi nodes to be used later in place of other instances -  // of div or rem with the same sign, dividend, and divisor -  DivOpInfo Key(UseSignedOp, Dividend, Divisor); -  DivPhiNodes Value(QuoPhi, RemPhi); -  PerBBDivCache.insert(std::pair<DivOpInfo, DivPhiNodes>(Key, Value)); -  return true; +  Value *AndV = Builder.CreateAnd(OrV, BitMask); + +  // Compare operand values +  Value *ZeroV = ConstantInt::getSigned(getSlowType(), 0); +  return Builder.CreateICmpEQ(AndV, ZeroV);  } -// reuseOrInsertFastDiv - Reuses previously computed dividend or remainder from -// the current BB if operands and operation are identical. Otherwise calls -// insertFastDiv to perform the optimization and caches the resulting dividend -// and remainder. -static bool reuseOrInsertFastDiv(Instruction *I, IntegerType *BypassType, -                                 bool UseDivOp, bool UseSignedOp, -                                 DivCacheTy &PerBBDivCache) { -  // Get instruction operands -  DivOpInfo Key(UseSignedOp, I->getOperand(0), I->getOperand(1)); -  DivCacheTy::iterator CacheI = PerBBDivCache.find(Key); - -  if (CacheI == PerBBDivCache.end()) { -    // If previous instance does not exist, insert fast div -    return insertFastDiv(I, BypassType, UseDivOp, UseSignedOp, PerBBDivCache); +/// Substitutes the div/rem instruction with code that checks the value of the +/// operands and uses a shorter-faster div/rem instruction when possible. +Optional<QuotRemPair> FastDivInsertionTask::insertFastDivAndRem() { +  Value *Dividend = SlowDivOrRem->getOperand(0); +  Value *Divisor = SlowDivOrRem->getOperand(1); + +  if (isa<ConstantInt>(Divisor)) { +    // Keep division by a constant for DAGCombiner. +    return None;    } -  // Replace operation value with previously generated phi node -  DivPhiNodes &Value = CacheI->second; -  if (UseDivOp) { -    // Replace all uses of div instruction with quotient phi node -    I->replaceAllUsesWith(Value.Quotient); +  VisitedSetTy SetL; +  ValueRange DividendRange = getValueRange(Dividend, SetL); +  if (DividendRange == VALRNG_LIKELY_LONG) +    return None; + +  VisitedSetTy SetR; +  ValueRange DivisorRange = getValueRange(Divisor, SetR); +  if (DivisorRange == VALRNG_LIKELY_LONG) +    return None; + +  bool DividendShort = (DividendRange == VALRNG_KNOWN_SHORT); +  bool DivisorShort = (DivisorRange == VALRNG_KNOWN_SHORT); + +  if (DividendShort && DivisorShort) { +    // If both operands are known to be short then just replace the long +    // division with a short one in-place. 
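The both-operands-short case that follows fires when computeKnownBits proves the high bits of both operands are zero; no branch is needed and the wide division is rewritten in place as trunc, narrow udiv/urem, zext. A hedged example of source code expected to reach this path (function name invented), assuming the masks survive earlier IR-level folds:

#include <cstdint>

// Both operands have their upper 32 bits masked to zero, so computeKnownBits
// should classify them as VALRNG_KNOWN_SHORT and the 64-bit udiv is expected
// to be replaced in place by a 32-bit one, with no runtime check.
static uint64_t narrowedDivide(uint64_t A, uint64_t B) {
  uint64_t Num = A & 0xffffffffu;
  uint64_t Den = (B & 0xffffffffu) | 1; // keep the divisor nonzero
  return Num / Den;
}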
+ +    IRBuilder<> Builder(SlowDivOrRem); +    Value *TruncDividend = Builder.CreateTrunc(Dividend, BypassType); +    Value *TruncDivisor = Builder.CreateTrunc(Divisor, BypassType); +    Value *TruncDiv = Builder.CreateUDiv(TruncDividend, TruncDivisor); +    Value *TruncRem = Builder.CreateURem(TruncDividend, TruncDivisor); +    Value *ExtDiv = Builder.CreateZExt(TruncDiv, getSlowType()); +    Value *ExtRem = Builder.CreateZExt(TruncRem, getSlowType()); +    return QuotRemPair(ExtDiv, ExtRem); +  } else if (DividendShort && !isSignedOp()) { +    // If the division is unsigned and Dividend is known to be short, then +    // either +    // 1) Divisor is less or equal to Dividend, and the result can be computed +    //    with a short division. +    // 2) Divisor is greater than Dividend. In this case, no division is needed +    //    at all: The quotient is 0 and the remainder is equal to Dividend. +    // +    // So instead of checking at runtime whether Divisor fits into BypassType, +    // we emit a runtime check to differentiate between these two cases. This +    // lets us entirely avoid a long div. + +    // Split the basic block before the div/rem. +    BasicBlock *SuccessorBB = MainBB->splitBasicBlock(SlowDivOrRem); +    // Remove the unconditional branch from MainBB to SuccessorBB. +    MainBB->getInstList().back().eraseFromParent(); +    QuotRemWithBB Long; +    Long.BB = MainBB; +    Long.Quotient = ConstantInt::get(getSlowType(), 0); +    Long.Remainder = Dividend; +    QuotRemWithBB Fast = createFastBB(SuccessorBB); +    QuotRemPair Result = createDivRemPhiNodes(Fast, Long, SuccessorBB); +    IRBuilder<> Builder(MainBB, MainBB->end()); +    Value *CmpV = Builder.CreateICmpUGE(Dividend, Divisor); +    Builder.CreateCondBr(CmpV, Fast.BB, SuccessorBB); +    return Result;    } else { -    // Replace all uses of rem instruction with remainder phi node -    I->replaceAllUsesWith(Value.Remainder); +    // General case. Create both slow and fast div/rem pairs and choose one of +    // them at runtime. + +    // Split the basic block before the div/rem. +    BasicBlock *SuccessorBB = MainBB->splitBasicBlock(SlowDivOrRem); +    // Remove the unconditional branch from MainBB to SuccessorBB. +    MainBB->getInstList().back().eraseFromParent(); +    QuotRemWithBB Fast = createFastBB(SuccessorBB); +    QuotRemWithBB Slow = createSlowBB(SuccessorBB); +    QuotRemPair Result = createDivRemPhiNodes(Fast, Slow, SuccessorBB); +    Value *CmpV = insertOperandRuntimeCheck(DividendShort ? nullptr : Dividend, +                                            DivisorShort ? nullptr : Divisor); +    IRBuilder<> Builder(MainBB, MainBB->end()); +    Builder.CreateCondBr(CmpV, Fast.BB, Slow.BB); +    return Result;    } - -  // Remove redundant operation -  I->eraseFromParent(); -  return true;  } -// bypassSlowDivision - This optimization identifies DIV instructions in a BB -// that can be profitably bypassed and carried out with a shorter, faster -// divide. -bool llvm::bypassSlowDivision( -    BasicBlock *BB, const DenseMap<unsigned int, unsigned int> &BypassWidths) { -  DivCacheTy DivCache; +/// This optimization identifies DIV/REM instructions in a BB that can be +/// profitably bypassed and carried out with a shorter, faster divide. 
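The entry point defined just below is driven from CodeGenPrepare: a target with a slow wide divider advertises a narrower bypass width, and the utility is invoked once per original basic block. A hedged sketch of that wiring; the driver name is invented, and the width map would normally come from TargetLowering::getBypassSlowDivWidths():

#include "llvm/ADT/DenseMap.h"
#include "llvm/IR/Function.h"
#include "llvm/Transforms/Utils/BypassSlowDivision.h"

// Run bypassSlowDivision over every original block of F.  Widths maps a wide
// divide width to the narrow width to try first, e.g. {64 -> 32}.
static bool bypassDivsInFunction(
    llvm::Function &F, const llvm::DenseMap<unsigned, unsigned> &Widths) {
  bool Changed = false;
  llvm::BasicBlock *BB = &F.getEntryBlock();
  while (BB) {
    // Grab the next block first: the transformation splits BB and inserts
    // new blocks that should not be re-processed.
    llvm::BasicBlock *Next = BB->getNextNode();
    Changed |= llvm::bypassSlowDivision(BB, Widths);
    BB = Next;
  }
  return Changed;
}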
+bool llvm::bypassSlowDivision(BasicBlock *BB, +                              const BypassWidthsTy &BypassWidths) { +  DivCacheTy PerBBDivCache;    bool MadeChange = false;    Instruction* Next = &*BB->begin(); @@ -231,42 +459,20 @@ bool llvm::bypassSlowDivision(      Instruction* I = Next;      Next = Next->getNextNode(); -    // Get instruction details -    unsigned Opcode = I->getOpcode(); -    bool UseDivOp = Opcode == Instruction::SDiv || Opcode == Instruction::UDiv; -    bool UseRemOp = Opcode == Instruction::SRem || Opcode == Instruction::URem; -    bool UseSignedOp = Opcode == Instruction::SDiv || -                       Opcode == Instruction::SRem; - -    // Only optimize div or rem ops -    if (!UseDivOp && !UseRemOp) -      continue; - -    // Skip division on vector types, only optimize integer instructions -    if (!I->getType()->isIntegerTy()) -      continue; - -    // Get bitwidth of div/rem instruction -    IntegerType *T = cast<IntegerType>(I->getType()); -    unsigned int bitwidth = T->getBitWidth(); - -    // Continue if bitwidth is not bypassed -    DenseMap<unsigned int, unsigned int>::const_iterator BI = BypassWidths.find(bitwidth); -    if (BI == BypassWidths.end()) -      continue; - -    // Get type for div/rem instruction with bypass bitwidth -    IntegerType *BT = IntegerType::get(I->getContext(), BI->second); - -    MadeChange |= reuseOrInsertFastDiv(I, BT, UseDivOp, UseSignedOp, DivCache); +    FastDivInsertionTask Task(I, BypassWidths); +    if (Value *Replacement = Task.getReplacement(PerBBDivCache)) { +      I->replaceAllUsesWith(Replacement); +      I->eraseFromParent(); +      MadeChange = true; +    }    }    // Above we eagerly create divs and rems, as pairs, so that we can efficiently    // create divrem machine instructions.  Now erase any unused divs / rems so we    // don't leave extra instructions sitting around. -  for (auto &KV : DivCache) -    for (Instruction *Phi : {KV.second.Quotient, KV.second.Remainder}) -      RecursivelyDeleteTriviallyDeadInstructions(Phi); +  for (auto &KV : PerBBDivCache) +    for (Value *V : {KV.second.Quotient, KV.second.Remainder}) +      RecursivelyDeleteTriviallyDeadInstructions(V);    return MadeChange;  } diff --git a/lib/Transforms/Utils/CMakeLists.txt b/lib/Transforms/Utils/CMakeLists.txt index 69889ec72f90..7a21c03da221 100644 --- a/lib/Transforms/Utils/CMakeLists.txt +++ b/lib/Transforms/Utils/CMakeLists.txt @@ -31,12 +31,13 @@ add_llvm_library(LLVMTransformUtils    LoopUtils.cpp    LoopVersioning.cpp    LowerInvoke.cpp +  LowerMemIntrinsics.cpp    LowerSwitch.cpp    Mem2Reg.cpp -  MemorySSA.cpp    MetaRenamer.cpp    ModuleUtils.cpp    NameAnonGlobals.cpp +  PredicateInfo.cpp    PromoteMemoryToRegister.cpp    StripGCRelocates.cpp    SSAUpdater.cpp @@ -51,6 +52,7 @@ add_llvm_library(LLVMTransformUtils    UnifyFunctionExitNodes.cpp    Utils.cpp    ValueMapper.cpp +  VNCoercion.cpp    ADDITIONAL_HEADER_DIRS    ${LLVM_MAIN_INCLUDE_DIR}/llvm/Transforms diff --git a/lib/Transforms/Utils/CloneFunction.cpp b/lib/Transforms/Utils/CloneFunction.cpp index 4d33e22fecfb..385c12302e04 100644 --- a/lib/Transforms/Utils/CloneFunction.cpp +++ b/lib/Transforms/Utils/CloneFunction.cpp @@ -90,9 +90,9 @@ void llvm::CloneFunctionInto(Function *NewFunc, const Function *OldFunc,      assert(VMap.count(&I) && "No mapping from source argument specified!");  #endif -  // Copy all attributes other than those stored in the AttributeSet.  We need -  // to remap the parameter indices of the AttributeSet. 
-  AttributeSet NewAttrs = NewFunc->getAttributes(); +  // Copy all attributes other than those stored in the AttributeList.  We need +  // to remap the parameter indices of the AttributeList. +  AttributeList NewAttrs = NewFunc->getAttributes();    NewFunc->copyAttributesFrom(OldFunc);    NewFunc->setAttributes(NewAttrs); @@ -103,22 +103,20 @@ void llvm::CloneFunctionInto(Function *NewFunc, const Function *OldFunc,                   ModuleLevelChanges ? RF_None : RF_NoModuleLevelChanges,                   TypeMapper, Materializer)); -  AttributeSet OldAttrs = OldFunc->getAttributes(); +  SmallVector<AttributeSet, 4> NewArgAttrs(NewFunc->arg_size()); +  AttributeList OldAttrs = OldFunc->getAttributes(); +    // Clone any argument attributes that are present in the VMap. -  for (const Argument &OldArg : OldFunc->args()) +  for (const Argument &OldArg : OldFunc->args()) {      if (Argument *NewArg = dyn_cast<Argument>(VMap[&OldArg])) { -      AttributeSet attrs = -          OldAttrs.getParamAttributes(OldArg.getArgNo() + 1); -      if (attrs.getNumSlots() > 0) -        NewArg->addAttr(attrs); +      NewArgAttrs[NewArg->getArgNo()] = +          OldAttrs.getParamAttributes(OldArg.getArgNo());      } +  }    NewFunc->setAttributes( -      NewFunc->getAttributes() -          .addAttributes(NewFunc->getContext(), AttributeSet::ReturnIndex, -                         OldAttrs.getRetAttributes()) -          .addAttributes(NewFunc->getContext(), AttributeSet::FunctionIndex, -                         OldAttrs.getFnAttributes())); +      AttributeList::get(NewFunc->getContext(), OldAttrs.getFnAttributes(), +                         OldAttrs.getRetAttributes(), NewArgAttrs));    SmallVector<std::pair<unsigned, MDNode *>, 1> MDs;    OldFunc->getAllMetadata(MDs); @@ -353,7 +351,7 @@ void PruningFunctionCloner::CloneBlock(const BasicBlock *BB,        Cond = dyn_cast_or_null<ConstantInt>(V);      }      if (Cond) {     // Constant fold to uncond branch! -      SwitchInst::ConstCaseIt Case = SI->findCaseValue(Cond); +      SwitchInst::ConstCaseHandle Case = *SI->findCaseValue(Cond);        BasicBlock *Dest = const_cast<BasicBlock*>(Case.getCaseSuccessor());        VMap[OldTI] = BranchInst::Create(Dest, NewBB);        ToClone.push_back(Dest); @@ -747,3 +745,40 @@ Loop *llvm::cloneLoopWithPreheader(BasicBlock *Before, BasicBlock *LoopDomBB,    return NewLoop;  } + +/// \brief Duplicate non-Phi instructions from the beginning of block up to +/// StopAt instruction into a split block between BB and its predecessor. +BasicBlock * +llvm::DuplicateInstructionsInSplitBetween(BasicBlock *BB, BasicBlock *PredBB, +                                          Instruction *StopAt, +                                          ValueToValueMapTy &ValueMapping) { +  // We are going to have to map operands from the original BB block to the new +  // copy of the block 'NewBB'.  If there are PHI nodes in BB, evaluate them to +  // account for entry from PredBB. +  BasicBlock::iterator BI = BB->begin(); +  for (; PHINode *PN = dyn_cast<PHINode>(BI); ++BI) +    ValueMapping[PN] = PN->getIncomingValueForBlock(PredBB); + +  BasicBlock *NewBB = SplitEdge(PredBB, BB); +  NewBB->setName(PredBB->getName() + ".split"); +  Instruction *NewTerm = NewBB->getTerminator(); + +  // Clone the non-phi instructions of BB into NewBB, keeping track of the +  // mapping and using it to remap operands in the cloned instructions. 
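The CloneFunctionInto hunk above also shows the shape of the post-split attribute API: per-argument AttributeSets are gathered into a vector and recombined with the function and return sets through a single AttributeList::get call. A hedged standalone sketch of that construction; the attribute choices and helper name are arbitrary, and the exact AttributeSet::get overloads may differ slightly by LLVM version:

#include "llvm/ADT/SmallVector.h"
#include "llvm/IR/Attributes.h"
#include "llvm/IR/LLVMContext.h"

// Build an AttributeList for a two-argument function: nounwind on the
// function, noalias on the return value, nocapture on the first argument.
static llvm::AttributeList buildExampleAttrs(llvm::LLVMContext &Ctx) {
  llvm::AttributeSet FnAttrs = llvm::AttributeSet::get(
      Ctx, {llvm::Attribute::get(Ctx, llvm::Attribute::NoUnwind)});
  llvm::AttributeSet RetAttrs = llvm::AttributeSet::get(
      Ctx, {llvm::Attribute::get(Ctx, llvm::Attribute::NoAlias)});
  llvm::SmallVector<llvm::AttributeSet, 2> ArgAttrs(2);
  ArgAttrs[0] = llvm::AttributeSet::get(
      Ctx, {llvm::Attribute::get(Ctx, llvm::Attribute::NoCapture)});
  return llvm::AttributeList::get(Ctx, FnAttrs, RetAttrs, ArgAttrs);
}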
+  for (; StopAt != &*BI; ++BI) { +    Instruction *New = BI->clone(); +    New->setName(BI->getName()); +    New->insertBefore(NewTerm); +    ValueMapping[&*BI] = New; + +    // Remap operands to patch up intra-block references. +    for (unsigned i = 0, e = New->getNumOperands(); i != e; ++i) +      if (Instruction *Inst = dyn_cast<Instruction>(New->getOperand(i))) { +        auto I = ValueMapping.find(Inst); +        if (I != ValueMapping.end()) +          New->setOperand(i, I->second); +      } +  } + +  return NewBB; +} diff --git a/lib/Transforms/Utils/CloneModule.cpp b/lib/Transforms/Utils/CloneModule.cpp index 7ebeb615d248..4e9d67252d6c 100644 --- a/lib/Transforms/Utils/CloneModule.cpp +++ b/lib/Transforms/Utils/CloneModule.cpp @@ -20,6 +20,15 @@  #include "llvm-c/Core.h"  using namespace llvm; +static void copyComdat(GlobalObject *Dst, const GlobalObject *Src) { +  const Comdat *SC = Src->getComdat(); +  if (!SC) +    return; +  Comdat *DC = Dst->getParent()->getOrInsertComdat(SC->getName()); +  DC->setSelectionKind(SC->getSelectionKind()); +  Dst->setComdat(DC); +} +  /// This is not as easy as it might seem because we have to worry about making  /// copies of global variables and functions, and making their (initializers and  /// references, respectively) refer to the right globals. @@ -124,6 +133,8 @@ std::unique_ptr<Module> llvm::CloneModule(      I->getAllMetadata(MDs);      for (auto MD : MDs)        GV->addMetadata(MD.first, *MapMetadata(MD.second, VMap)); + +    copyComdat(GV, &*I);    }    // Similarly, copy over function bodies now... @@ -153,6 +164,8 @@ std::unique_ptr<Module> llvm::CloneModule(      if (I.hasPersonalityFn())        F->setPersonalityFn(MapValue(I.getPersonalityFn(), VMap)); + +    copyComdat(F, &I);    }    // And aliases diff --git a/lib/Transforms/Utils/CodeExtractor.cpp b/lib/Transforms/Utils/CodeExtractor.cpp index c514c9c9cd4a..644d93b727b3 100644 --- a/lib/Transforms/Utils/CodeExtractor.cpp +++ b/lib/Transforms/Utils/CodeExtractor.cpp @@ -362,9 +362,8 @@ Function *CodeExtractor::constructFunction(const ValueSet &inputs,    //  "target-features" attribute allowing it to be lowered.    // FIXME: This should be changed to check to see if a specific    //           attribute can not be inherited. 
-  AttributeSet OldFnAttrs = oldFunction->getAttributes().getFnAttributes(); -  AttrBuilder AB(OldFnAttrs, AttributeSet::FunctionIndex); -  for (auto Attr : AB.td_attrs()) +  AttrBuilder AB(oldFunction->getAttributes().getFnAttributes()); +  for (const auto &Attr : AB.td_attrs())      newFunction->addFnAttr(Attr.first, Attr.second);    newFunction->getBasicBlockList().push_back(newRootNode); @@ -440,8 +439,10 @@ emitCallAndSwitchStatement(Function *newFunction, BasicBlock *codeReplacer,    // Emit a call to the new function, passing in: *pointer to struct (if    // aggregating parameters), or plan inputs and allocated memory for outputs    std::vector<Value*> params, StructValues, ReloadOutputs, Reloads; -   -  LLVMContext &Context = newFunction->getContext(); + +  Module *M = newFunction->getParent(); +  LLVMContext &Context = M->getContext(); +  const DataLayout &DL = M->getDataLayout();    // Add inputs as params, or to be filled into the struct    for (Value *input : inputs) @@ -456,8 +457,9 @@ emitCallAndSwitchStatement(Function *newFunction, BasicBlock *codeReplacer,        StructValues.push_back(output);      } else {        AllocaInst *alloca = -          new AllocaInst(output->getType(), nullptr, output->getName() + ".loc", -                         &codeReplacer->getParent()->front().front()); +        new AllocaInst(output->getType(), DL.getAllocaAddrSpace(), +                       nullptr, output->getName() + ".loc", +                       &codeReplacer->getParent()->front().front());        ReloadOutputs.push_back(alloca);        params.push_back(alloca);      } @@ -473,7 +475,8 @@ emitCallAndSwitchStatement(Function *newFunction, BasicBlock *codeReplacer,      // Allocate a struct at the beginning of this function      StructArgTy = StructType::get(newFunction->getContext(), ArgTypes); -    Struct = new AllocaInst(StructArgTy, nullptr, "structArg", +    Struct = new AllocaInst(StructArgTy, DL.getAllocaAddrSpace(), nullptr, +                            "structArg",                              &codeReplacer->getParent()->front().front());      params.push_back(Struct); diff --git a/lib/Transforms/Utils/DemoteRegToStack.cpp b/lib/Transforms/Utils/DemoteRegToStack.cpp index 75a1dde57c4c..0eee6e19efac 100644 --- a/lib/Transforms/Utils/DemoteRegToStack.cpp +++ b/lib/Transforms/Utils/DemoteRegToStack.cpp @@ -28,15 +28,17 @@ AllocaInst *llvm::DemoteRegToStack(Instruction &I, bool VolatileLoads,      return nullptr;    } +  Function *F = I.getParent()->getParent(); +  const DataLayout &DL = F->getParent()->getDataLayout(); +    // Create a stack slot to hold the value.    AllocaInst *Slot;    if (AllocaPoint) { -    Slot = new AllocaInst(I.getType(), nullptr, +    Slot = new AllocaInst(I.getType(), DL.getAllocaAddrSpace(), nullptr,                            I.getName()+".reg2mem", AllocaPoint);    } else { -    Function *F = I.getParent()->getParent(); -    Slot = new AllocaInst(I.getType(), nullptr, I.getName() + ".reg2mem", -                          &F->getEntryBlock().front()); +    Slot = new AllocaInst(I.getType(), DL.getAllocaAddrSpace(), nullptr, +                          I.getName() + ".reg2mem", &F->getEntryBlock().front());    }    // We cannot demote invoke instructions to the stack if their normal edge @@ -110,14 +112,17 @@ AllocaInst *llvm::DemotePHIToStack(PHINode *P, Instruction *AllocaPoint) {      return nullptr;    } +  const DataLayout &DL = P->getModule()->getDataLayout(); +    // Create a stack slot to hold the value.    
AllocaInst *Slot;    if (AllocaPoint) { -    Slot = new AllocaInst(P->getType(), nullptr, +    Slot = new AllocaInst(P->getType(), DL.getAllocaAddrSpace(), nullptr,                            P->getName()+".reg2mem", AllocaPoint);    } else {      Function *F = P->getParent()->getParent(); -    Slot = new AllocaInst(P->getType(), nullptr, P->getName() + ".reg2mem", +    Slot = new AllocaInst(P->getType(), DL.getAllocaAddrSpace(), nullptr, +                          P->getName() + ".reg2mem",                            &F->getEntryBlock().front());    } diff --git a/lib/Transforms/Utils/Evaluator.cpp b/lib/Transforms/Utils/Evaluator.cpp index 4adf1754253d..59f176e2f231 100644 --- a/lib/Transforms/Utils/Evaluator.cpp +++ b/lib/Transforms/Utils/Evaluator.cpp @@ -16,6 +16,7 @@  #include "llvm/IR/BasicBlock.h"  #include "llvm/IR/CallSite.h"  #include "llvm/IR/Constants.h" +#include "llvm/IR/DataLayout.h"  #include "llvm/IR/DerivedTypes.h"  #include "llvm/IR/DiagnosticPrinter.h"  #include "llvm/IR/GlobalVariable.h" @@ -486,7 +487,7 @@ bool Evaluator::EvaluateBlock(BasicBlock::iterator CurInst,          ConstantInt *Val =            dyn_cast<ConstantInt>(getVal(SI->getCondition()));          if (!Val) return false;  // Cannot determine. -        NextBB = SI->findCaseValue(Val).getCaseSuccessor(); +        NextBB = SI->findCaseValue(Val)->getCaseSuccessor();        } else if (IndirectBrInst *IBI = dyn_cast<IndirectBrInst>(CurInst)) {          Value *Val = getVal(IBI->getAddress())->stripPointerCasts();          if (BlockAddress *BA = dyn_cast<BlockAddress>(Val)) diff --git a/lib/Transforms/Utils/FunctionComparator.cpp b/lib/Transforms/Utils/FunctionComparator.cpp index 81a7c4ceffab..73a0b2737e95 100644 --- a/lib/Transforms/Utils/FunctionComparator.cpp +++ b/lib/Transforms/Utils/FunctionComparator.cpp @@ -74,14 +74,14 @@ int FunctionComparator::cmpMem(StringRef L, StringRef R) const {    return L.compare(R);  } -int FunctionComparator::cmpAttrs(const AttributeSet L, -                                 const AttributeSet R) const { +int FunctionComparator::cmpAttrs(const AttributeList L, +                                 const AttributeList R) const {    if (int Res = cmpNumbers(L.getNumSlots(), R.getNumSlots()))      return Res;    for (unsigned i = 0, e = L.getNumSlots(); i != e; ++i) { -    AttributeSet::iterator LI = L.begin(i), LE = L.end(i), RI = R.begin(i), -                           RE = R.end(i); +    AttributeList::iterator LI = L.begin(i), LE = L.end(i), RI = R.begin(i), +                            RE = R.end(i);      for (; LI != LE && RI != RE; ++LI, ++RI) {        Attribute LA = *LI;        Attribute RA = *RI; diff --git a/lib/Transforms/Utils/FunctionImportUtils.cpp b/lib/Transforms/Utils/FunctionImportUtils.cpp index 9844190ef84a..b00f4b14068a 100644 --- a/lib/Transforms/Utils/FunctionImportUtils.cpp +++ b/lib/Transforms/Utils/FunctionImportUtils.cpp @@ -21,11 +21,11 @@ using namespace llvm;  /// Checks if we should import SGV as a definition, otherwise import as a  /// declaration.  bool FunctionImportGlobalProcessing::doImportAsDefinition( -    const GlobalValue *SGV, DenseSet<const GlobalValue *> *GlobalsToImport) { +    const GlobalValue *SGV, SetVector<GlobalValue *> *GlobalsToImport) {    // For alias, we tie the definition to the base object. 
Extract it and recurse    if (auto *GA = dyn_cast<GlobalAlias>(SGV)) { -    if (GA->hasWeakAnyLinkage()) +    if (GA->isInterposable())        return false;      const GlobalObject *GO = GA->getBaseObject();      if (!GO->hasLinkOnceODRLinkage()) @@ -34,7 +34,7 @@ bool FunctionImportGlobalProcessing::doImportAsDefinition(          GO, GlobalsToImport);    }    // Only import the globals requested for importing. -  if (GlobalsToImport->count(SGV)) +  if (GlobalsToImport->count(const_cast<GlobalValue *>(SGV)))      return true;    // Otherwise no.    return false; @@ -57,7 +57,8 @@ bool FunctionImportGlobalProcessing::shouldPromoteLocalToGlobal(      return false;    if (isPerformingImport()) { -    assert((!GlobalsToImport->count(SGV) || !isNonRenamableLocal(*SGV)) && +    assert((!GlobalsToImport->count(const_cast<GlobalValue *>(SGV)) || +            !isNonRenamableLocal(*SGV)) &&             "Attempting to promote non-renamable local");      // We don't know for sure yet if we are importing this value (as either      // a reference or a def), since we are simply walking all values in the @@ -254,9 +255,8 @@ bool FunctionImportGlobalProcessing::run() {    return false;  } -bool llvm::renameModuleForThinLTO( -    Module &M, const ModuleSummaryIndex &Index, -    DenseSet<const GlobalValue *> *GlobalsToImport) { +bool llvm::renameModuleForThinLTO(Module &M, const ModuleSummaryIndex &Index, +                                  SetVector<GlobalValue *> *GlobalsToImport) {    FunctionImportGlobalProcessing ThinLTOProcessing(M, Index, GlobalsToImport);    return ThinLTOProcessing.run();  } diff --git a/lib/Transforms/Utils/GlobalStatus.cpp b/lib/Transforms/Utils/GlobalStatus.cpp index 74ebcda8355c..ba4b78ac758a 100644 --- a/lib/Transforms/Utils/GlobalStatus.cpp +++ b/lib/Transforms/Utils/GlobalStatus.cpp @@ -10,9 +10,22 @@  #include "llvm/ADT/SmallPtrSet.h"  #include "llvm/IR/BasicBlock.h"  #include "llvm/IR/CallSite.h" +#include "llvm/IR/Constant.h" +#include "llvm/IR/Constants.h" +#include "llvm/IR/GlobalValue.h"  #include "llvm/IR/GlobalVariable.h" +#include "llvm/IR/InstrTypes.h" +#include "llvm/IR/Instruction.h" +#include "llvm/IR/Instructions.h"  #include "llvm/IR/IntrinsicInst.h"  #include "llvm/Transforms/Utils/GlobalStatus.h" +#include "llvm/IR/Use.h" +#include "llvm/IR/User.h" +#include "llvm/IR/Value.h" +#include "llvm/Support/AtomicOrdering.h" +#include "llvm/Support/Casting.h" +#include <algorithm> +#include <cassert>  using namespace llvm; @@ -175,13 +188,9 @@ static bool analyzeGlobalAux(const Value *V, GlobalStatus &GS,    return false;  } +GlobalStatus::GlobalStatus() = default; +  bool GlobalStatus::analyzeGlobal(const Value *V, GlobalStatus &GS) {    SmallPtrSet<const PHINode *, 16> PhiUsers;    return analyzeGlobalAux(V, GS, PhiUsers);  } - -GlobalStatus::GlobalStatus() -    : IsCompared(false), IsLoaded(false), StoredType(NotStored), -      StoredOnceValue(nullptr), AccessingFunction(nullptr), -      HasMultipleAccessingFunctions(false), HasNonInstructionUser(false), -      Ordering(AtomicOrdering::NotAtomic) {} diff --git a/lib/Transforms/Utils/ImportedFunctionsInliningStatistics.cpp b/lib/Transforms/Utils/ImportedFunctionsInliningStatistics.cpp index ed018bb73107..b8c12ad5ea84 100644 --- a/lib/Transforms/Utils/ImportedFunctionsInliningStatistics.cpp +++ b/lib/Transforms/Utils/ImportedFunctionsInliningStatistics.cpp @@ -62,6 +62,8 @@ void ImportedFunctionsInliningStatistics::recordInline(const Function &Caller,  void ImportedFunctionsInliningStatistics::setModuleInfo(const 
Module &M) {    ModuleName = M.getName();    for (const auto &F : M.functions()) { +    if (F.isDeclaration()) +      continue;      AllFunctions++;      ImportedFunctions += int(F.getMetadata("thinlto_src_module") != nullptr);    } diff --git a/lib/Transforms/Utils/InlineFunction.cpp b/lib/Transforms/Utils/InlineFunction.cpp index a40079ca8e76..5d6fbc3325ff 100644 --- a/lib/Transforms/Utils/InlineFunction.cpp +++ b/lib/Transforms/Utils/InlineFunction.cpp @@ -20,10 +20,12 @@  #include "llvm/ADT/StringExtras.h"  #include "llvm/Analysis/AliasAnalysis.h"  #include "llvm/Analysis/AssumptionCache.h" +#include "llvm/Analysis/BlockFrequencyInfo.h"  #include "llvm/Analysis/CallGraph.h"  #include "llvm/Analysis/CaptureTracking.h"  #include "llvm/Analysis/EHPersonalities.h"  #include "llvm/Analysis/InstructionSimplify.h" +#include "llvm/Analysis/ProfileSummaryInfo.h"  #include "llvm/Analysis/ValueTracking.h"  #include "llvm/IR/Attributes.h"  #include "llvm/IR/CallSite.h" @@ -40,8 +42,8 @@  #include "llvm/IR/Intrinsics.h"  #include "llvm/IR/MDBuilder.h"  #include "llvm/IR/Module.h" -#include "llvm/Transforms/Utils/Local.h"  #include "llvm/Support/CommandLine.h" +#include "llvm/Transforms/Utils/Local.h"  #include <algorithm>  using namespace llvm; @@ -1107,26 +1109,23 @@ static void AddAlignmentAssumptions(CallSite CS, InlineFunctionInfo &IFI) {    bool DTCalculated = false;    Function *CalledFunc = CS.getCalledFunction(); -  for (Function::arg_iterator I = CalledFunc->arg_begin(), -                              E = CalledFunc->arg_end(); -       I != E; ++I) { -    unsigned Align = I->getType()->isPointerTy() ? I->getParamAlignment() : 0; -    if (Align && !I->hasByValOrInAllocaAttr() && !I->hasNUses(0)) { +  for (Argument &Arg : CalledFunc->args()) { +    unsigned Align = Arg.getType()->isPointerTy() ? Arg.getParamAlignment() : 0; +    if (Align && !Arg.hasByValOrInAllocaAttr() && !Arg.hasNUses(0)) {        if (!DTCalculated) { -        DT.recalculate(const_cast<Function&>(*CS.getInstruction()->getParent() -                                               ->getParent())); +        DT.recalculate(*CS.getCaller());          DTCalculated = true;        }        // If we can already prove the asserted alignment in the context of the        // caller, then don't bother inserting the assumption. 
-      Value *Arg = CS.getArgument(I->getArgNo()); -      if (getKnownAlignment(Arg, DL, CS.getInstruction(), AC, &DT) >= Align) +      Value *ArgVal = CS.getArgument(Arg.getArgNo()); +      if (getKnownAlignment(ArgVal, DL, CS.getInstruction(), AC, &DT) >= Align)          continue; -      CallInst *NewAssumption = IRBuilder<>(CS.getInstruction()) -                                    .CreateAlignmentAssumption(DL, Arg, Align); -      AC->registerAssumption(NewAssumption); +      CallInst *NewAsmp = IRBuilder<>(CS.getInstruction()) +                              .CreateAlignmentAssumption(DL, ArgVal, Align); +      AC->registerAssumption(NewAsmp);      }    }  } @@ -1140,7 +1139,7 @@ static void UpdateCallGraphAfterInlining(CallSite CS,                                           ValueToValueMapTy &VMap,                                           InlineFunctionInfo &IFI) {    CallGraph &CG = *IFI.CG; -  const Function *Caller = CS.getInstruction()->getParent()->getParent(); +  const Function *Caller = CS.getCaller();    const Function *Callee = CS.getCalledFunction();    CallGraphNode *CalleeNode = CG[Callee];    CallGraphNode *CallerNode = CG[Caller]; @@ -1225,7 +1224,8 @@ static Value *HandleByValArgument(Value *Arg, Instruction *TheCall,    PointerType *ArgTy = cast<PointerType>(Arg->getType());    Type *AggTy = ArgTy->getElementType(); -  Function *Caller = TheCall->getParent()->getParent(); +  Function *Caller = TheCall->getFunction(); +  const DataLayout &DL = Caller->getParent()->getDataLayout();    // If the called function is readonly, then it could not mutate the caller's    // copy of the byval'd memory.  In this case, it is safe to elide the copy and @@ -1239,31 +1239,30 @@ static Value *HandleByValArgument(Value *Arg, Instruction *TheCall,      AssumptionCache *AC =          IFI.GetAssumptionCache ? &(*IFI.GetAssumptionCache)(*Caller) : nullptr; -    const DataLayout &DL = Caller->getParent()->getDataLayout();      // If the pointer is already known to be sufficiently aligned, or if we can      // round it up to a larger alignment, then we don't need a temporary.      if (getOrEnforceKnownAlignment(Arg, ByValAlignment, DL, TheCall, AC) >=          ByValAlignment)        return Arg; -     +      // Otherwise, we have to make a memcpy to get a safe alignment.  This is bad      // for code quality, but rarely happens and is required for correctness.    }    // Create the alloca.  If we have DataLayout, use nice alignment. -  unsigned Align = -      Caller->getParent()->getDataLayout().getPrefTypeAlignment(AggTy); +  unsigned Align = DL.getPrefTypeAlignment(AggTy);    // If the byval had an alignment specified, we *must* use at least that    // alignment, as it is required by the byval argument (and uses of the    // pointer inside the callee).    Align = std::max(Align, ByValAlignment); -   -  Value *NewAlloca = new AllocaInst(AggTy, nullptr, Align, Arg->getName(),  + +  Value *NewAlloca = new AllocaInst(AggTy, DL.getAllocaAddrSpace(), +                                    nullptr, Align, Arg->getName(),                                      &*Caller->begin()->begin());    IFI.StaticAllocas.push_back(cast<AllocaInst>(NewAlloca)); -   +    // Uses of the argument in the function should use our new alloca    // instead.    return NewAlloca; @@ -1393,6 +1392,89 @@ static void fixupLineNumbers(Function *Fn, Function::iterator FI,      }    }  } +/// Update the block frequencies of the caller after a callee has been inlined. 
+/// +/// Each block cloned into the caller has its block frequency scaled by the +/// ratio of CallSiteFreq/CalleeEntryFreq. This ensures that the cloned copy of +/// callee's entry block gets the same frequency as the callsite block and the +/// relative frequencies of all cloned blocks remain the same after cloning. +static void updateCallerBFI(BasicBlock *CallSiteBlock, +                            const ValueToValueMapTy &VMap, +                            BlockFrequencyInfo *CallerBFI, +                            BlockFrequencyInfo *CalleeBFI, +                            const BasicBlock &CalleeEntryBlock) { +  SmallPtrSet<BasicBlock *, 16> ClonedBBs; +  for (auto const &Entry : VMap) { +    if (!isa<BasicBlock>(Entry.first) || !Entry.second) +      continue; +    auto *OrigBB = cast<BasicBlock>(Entry.first); +    auto *ClonedBB = cast<BasicBlock>(Entry.second); +    uint64_t Freq = CalleeBFI->getBlockFreq(OrigBB).getFrequency(); +    if (!ClonedBBs.insert(ClonedBB).second) { +      // Multiple blocks in the callee might get mapped to one cloned block in +      // the caller since we prune the callee as we clone it. When that happens, +      // we want to use the maximum among the original blocks' frequencies. +      uint64_t NewFreq = CallerBFI->getBlockFreq(ClonedBB).getFrequency(); +      if (NewFreq > Freq) +        Freq = NewFreq; +    } +    CallerBFI->setBlockFreq(ClonedBB, Freq); +  } +  BasicBlock *EntryClone = cast<BasicBlock>(VMap.lookup(&CalleeEntryBlock)); +  CallerBFI->setBlockFreqAndScale( +      EntryClone, CallerBFI->getBlockFreq(CallSiteBlock).getFrequency(), +      ClonedBBs); +} + +/// Update the branch metadata for cloned call instructions. +static void updateCallProfile(Function *Callee, const ValueToValueMapTy &VMap, +                              const Optional<uint64_t> &CalleeEntryCount, +                              const Instruction *TheCall) { +  if (!CalleeEntryCount.hasValue() || CalleeEntryCount.getValue() < 1) +    return; +  Optional<uint64_t> CallSiteCount = +      ProfileSummaryInfo::getProfileCount(TheCall, nullptr); +  uint64_t CallCount = +      std::min(CallSiteCount.hasValue() ? CallSiteCount.getValue() : 0, +               CalleeEntryCount.getValue()); + +  for (auto const &Entry : VMap) +    if (isa<CallInst>(Entry.first)) +      if (auto *CI = dyn_cast_or_null<CallInst>(Entry.second)) +        CI->updateProfWeight(CallCount, CalleeEntryCount.getValue()); +  for (BasicBlock &BB : *Callee) +    // No need to update the callsite if it is pruned during inlining. +    if (VMap.count(&BB)) +      for (Instruction &I : BB) +        if (CallInst *CI = dyn_cast<CallInst>(&I)) +          CI->updateProfWeight(CalleeEntryCount.getValue() - CallCount, +                               CalleeEntryCount.getValue()); +} + +/// Update the entry count of callee after inlining. +/// +/// The callsite's block count is subtracted from the callee's function entry +/// count. +static void updateCalleeCount(BlockFrequencyInfo *CallerBFI, BasicBlock *CallBB, +                              Instruction *CallInst, Function *Callee) { +  // If the callee has an original count of N, and the estimated count of +  // callsite is M, the new callee count is set to N - M. M is estimated from +  // the caller's entry count, its entry block frequency and the block frequency +  // of the callsite. 
+  Optional<uint64_t> CalleeCount = Callee->getEntryCount(); +  if (!CalleeCount.hasValue()) +    return; +  Optional<uint64_t> CallCount = +      ProfileSummaryInfo::getProfileCount(CallInst, CallerBFI); +  if (!CallCount.hasValue()) +    return; +  // Since CallSiteCount is an estimate, it could exceed the original callee +  // count and has to be set to 0. +  if (CallCount.getValue() > CalleeCount.getValue()) +    Callee->setEntryCount(0); +  else +    Callee->setEntryCount(CalleeCount.getValue() - CallCount.getValue()); +}  /// This function inlines the called function into the basic block of the  /// caller. This returns false if it is not possible to inline this call. @@ -1405,13 +1487,13 @@ static void fixupLineNumbers(Function *Fn, Function::iterator FI,  bool llvm::InlineFunction(CallSite CS, InlineFunctionInfo &IFI,                            AAResults *CalleeAAR, bool InsertLifetime) {    Instruction *TheCall = CS.getInstruction(); -  assert(TheCall->getParent() && TheCall->getParent()->getParent() && -         "Instruction not in function!"); +  assert(TheCall->getParent() && TheCall->getFunction() +         && "Instruction not in function!");    // If IFI has any state in it, zap it before we fill it in.    IFI.reset(); -   -  const Function *CalledFunc = CS.getCalledFunction(); + +  Function *CalledFunc = CS.getCalledFunction();    if (!CalledFunc ||              // Can't inline external function or indirect        CalledFunc->isDeclaration() || // call, or call to a vararg function!        CalledFunc->getFunctionType()->isVarArg()) return false; @@ -1548,7 +1630,7 @@ bool llvm::InlineFunction(CallSite CS, InlineFunctionInfo &IFI,      // matches up the formal to the actual argument values.      CallSite::arg_iterator AI = CS.arg_begin();      unsigned ArgNo = 0; -    for (Function::const_arg_iterator I = CalledFunc->arg_begin(), +    for (Function::arg_iterator I = CalledFunc->arg_begin(),           E = CalledFunc->arg_end(); I != E; ++I, ++AI, ++ArgNo) {        Value *ActualArg = *AI; @@ -1578,10 +1660,18 @@ bool llvm::InlineFunction(CallSite CS, InlineFunctionInfo &IFI,      CloneAndPruneFunctionInto(Caller, CalledFunc, VMap,                                /*ModuleLevelChanges=*/false, Returns, ".i",                                &InlinedFunctionInfo, TheCall); -      // Remember the first block that is newly cloned over.      FirstNewBlock = LastBlock; ++FirstNewBlock; +    if (IFI.CallerBFI != nullptr && IFI.CalleeBFI != nullptr) +      // Update the BFI of blocks cloned into the caller. +      updateCallerBFI(OrigBB, VMap, IFI.CallerBFI, IFI.CalleeBFI, +                      CalledFunc->front()); + +    updateCallProfile(CalledFunc, VMap, CalledFunc->getEntryCount(), TheCall); +    // Update the profile count of callee. +    updateCalleeCount(IFI.CallerBFI, OrigBB, TheCall, CalledFunc); +      // Inject byval arguments initialization.      for (std::pair<Value*, Value*> &Init : ByValInit)        HandleByValArgumentInit(Init.first, Init.second, Caller->getParent(), @@ -2087,6 +2177,12 @@ bool llvm::InlineFunction(CallSite CS, InlineFunctionInfo &IFI,                                            CalledFunc->getName() + ".exit");    } +  if (IFI.CallerBFI) { +    // Copy original BB's block frequency to AfterCallBB +    IFI.CallerBFI->setBlockFreq( +        AfterCallBB, IFI.CallerBFI->getBlockFreq(OrigBB).getFrequency()); +  } +    // Change the branch that used to go to AfterCallBB to branch to the first    // basic block of the inlined function.    
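As a worked example of the profile bookkeeping done by updateCallProfile and updateCalleeCount above (all numbers are invented for illustration; only the min/subtract/clamp logic is taken from the patch):

#include <algorithm>
#include <cstdint>

int main() {
  uint64_t CalleeEntryCount = 1000; // profiled entry count of the callee (N)
  uint64_t CallSiteCount = 300;     // estimated execution count of the call site (M)

  // Weight credited to the inlined copy of the body.
  uint64_t CallCount = std::min(CallSiteCount, CalleeEntryCount); // 300

  // Cloned calls are scaled by CallCount/N, calls left in the out-of-line
  // callee by (N - CallCount)/N, and the callee's entry count drops by
  // CallCount, clamping at zero when the estimate exceeds N.
  uint64_t NewCalleeEntryCount =
      CallCount > CalleeEntryCount ? 0 : CalleeEntryCount - CallCount; // 700
  return NewCalleeEntryCount == 700 ? 0 : 1;
}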
// diff --git a/lib/Transforms/Utils/LCSSA.cpp b/lib/Transforms/Utils/LCSSA.cpp index 68c6b74d5e5b..49b4bd92faf4 100644 --- a/lib/Transforms/Utils/LCSSA.cpp +++ b/lib/Transforms/Utils/LCSSA.cpp @@ -87,7 +87,8 @@ bool llvm::formLCSSAForInstructions(SmallVectorImpl<Instruction *> &Worklist,      Instruction *I = Worklist.pop_back_val();      BasicBlock *InstBB = I->getParent();      Loop *L = LI.getLoopFor(InstBB); -    if (!LoopExitBlocks.count(L))    +    assert(L && "Instruction belongs to a BB that's not part of a loop"); +    if (!LoopExitBlocks.count(L))        L->getExitBlocks(LoopExitBlocks[L]);      assert(LoopExitBlocks.count(L));      const SmallVectorImpl<BasicBlock *> &ExitBlocks = LoopExitBlocks[L]; @@ -105,7 +106,7 @@ bool llvm::formLCSSAForInstructions(SmallVectorImpl<Instruction *> &Worklist,      for (Use &U : I->uses()) {        Instruction *User = cast<Instruction>(U.getUser());        BasicBlock *UserBB = User->getParent(); -      if (PHINode *PN = dyn_cast<PHINode>(User)) +      if (auto *PN = dyn_cast<PHINode>(User))          UserBB = PN->getIncomingBlock(U);        if (InstBB != UserBB && !L->contains(UserBB)) @@ -123,7 +124,7 @@ bool llvm::formLCSSAForInstructions(SmallVectorImpl<Instruction *> &Worklist,      // DomBB dominates the value, so adjust DomBB to the normal destination      // block, which is effectively where the value is first usable.      BasicBlock *DomBB = InstBB; -    if (InvokeInst *Inv = dyn_cast<InvokeInst>(I)) +    if (auto *Inv = dyn_cast<InvokeInst>(I))        DomBB = Inv->getNormalDest();      DomTreeNode *DomNode = DT.getNode(DomBB); @@ -188,7 +189,7 @@ bool llvm::formLCSSAForInstructions(SmallVectorImpl<Instruction *> &Worklist,        // block.        Instruction *User = cast<Instruction>(UseToRewrite->getUser());        BasicBlock *UserBB = User->getParent(); -      if (PHINode *PN = dyn_cast<PHINode>(User)) +      if (auto *PN = dyn_cast<PHINode>(User))          UserBB = PN->getIncomingBlock(*UseToRewrite);        if (isa<PHINode>(UserBB->begin()) && isExitBlock(UserBB, ExitBlocks)) { @@ -237,40 +238,75 @@ bool llvm::formLCSSAForInstructions(SmallVectorImpl<Instruction *> &Worklist,    return Changed;  } -/// Return true if the specified block dominates at least -/// one of the blocks in the specified list. -static bool -blockDominatesAnExit(BasicBlock *BB, -                     DominatorTree &DT, -                     const SmallVectorImpl<BasicBlock *> &ExitBlocks) { -  DomTreeNode *DomNode = DT.getNode(BB); -  return any_of(ExitBlocks, [&](BasicBlock *EB) { -    return DT.dominates(DomNode, DT.getNode(EB)); -  }); +// Compute the set of BasicBlocks in the loop `L` dominating at least one exit. +static void computeBlocksDominatingExits( +    Loop &L, DominatorTree &DT, SmallVector<BasicBlock *, 8> &ExitBlocks, +    SmallPtrSet<BasicBlock *, 8> &BlocksDominatingExits) { +  SmallVector<BasicBlock *, 8> BBWorklist; + +  // We start from the exit blocks, as every block trivially dominates itself +  // (not strictly). +  for (BasicBlock *BB : ExitBlocks) +    BBWorklist.push_back(BB); + +  while (!BBWorklist.empty()) { +    BasicBlock *BB = BBWorklist.pop_back_val(); + +    // Check if this is a loop header. If this is the case, we're done. +    if (L.getHeader() == BB) +      continue; + +    // Otherwise, add its immediate predecessor in the dominator tree to the +    // worklist, unless we visited it already. 
+    BasicBlock *IDomBB = DT.getNode(BB)->getIDom()->getBlock(); + +    // Exit blocks can have an immediate dominator not belonging to the +    // loop. For an exit block to be immediately dominated by another block +    // outside the loop, it implies that not all paths from that dominator to +    // the exit block go through the loop. +    // Example: +    // +    // |---- A +    // |     | +    // |     B<-- +    // |     |  | +    // |---> C -- +    //       | +    //       D +    // +    // C is the exit block of the loop and it's immediately dominated by A, +    // which doesn't belong to the loop. +    if (!L.contains(IDomBB)) +      continue; + +    if (BlocksDominatingExits.insert(IDomBB).second) +      BBWorklist.push_back(IDomBB); +  }  }  bool llvm::formLCSSA(Loop &L, DominatorTree &DT, LoopInfo *LI,                       ScalarEvolution *SE) {    bool Changed = false; -  // Get the set of exiting blocks.    SmallVector<BasicBlock *, 8> ExitBlocks;    L.getExitBlocks(ExitBlocks); -    if (ExitBlocks.empty())      return false; +  SmallPtrSet<BasicBlock *, 8> BlocksDominatingExits; + +  // We want to avoid use-scanning by leveraging dominance information. +  // If a block doesn't dominate any of the loop exits, then none of the values +  // defined in the loop can be used outside. +  // We compute the set of blocks fulfilling this condition in advance by +  // walking the dominator tree upwards until we hit a loop header. +  computeBlocksDominatingExits(L, DT, ExitBlocks, BlocksDominatingExits); +    SmallVector<Instruction *, 8> Worklist;    // Look at all the instructions in the loop, checking to see if they have uses    // outside the loop.  If so, put them into the worklist to rewrite those uses. -  for (BasicBlock *BB : L.blocks()) { -    // For large loops, avoid use-scanning by using dominance information:  In -    // particular, if a block does not dominate any of the loop exits, then none -    // of the values defined in the block could be used outside the loop. -    if (!blockDominatesAnExit(BB, DT, ExitBlocks)) -      continue; - +  for (BasicBlock *BB : BlocksDominatingExits) {      for (Instruction &I : *BB) {        // Reject two common cases fast: instructions with no uses (like stores)        // and instructions with one use that is in the same block as this. @@ -395,8 +431,8 @@ PreservedAnalyses LCSSAPass::run(Function &F, FunctionAnalysisManager &AM) {    if (!formLCSSAOnAllLoops(&LI, DT, SE))      return PreservedAnalyses::all(); -  // FIXME: This should also 'preserve the CFG'.    
PreservedAnalyses PA; +  PA.preserveSet<CFGAnalyses>();    PA.preserve<BasicAA>();    PA.preserve<GlobalsAA>();    PA.preserve<SCEVAA>(); diff --git a/lib/Transforms/Utils/LibCallsShrinkWrap.cpp b/lib/Transforms/Utils/LibCallsShrinkWrap.cpp index d97cd7582eaa..fe93d6927c63 100644 --- a/lib/Transforms/Utils/LibCallsShrinkWrap.cpp +++ b/lib/Transforms/Utils/LibCallsShrinkWrap.cpp @@ -100,12 +100,12 @@ private:    bool perform(CallInst *CI);    void checkCandidate(CallInst &CI);    void shrinkWrapCI(CallInst *CI, Value *Cond); -  bool performCallDomainErrorOnly(CallInst *CI, const LibFunc::Func &Func); -  bool performCallErrors(CallInst *CI, const LibFunc::Func &Func); -  bool performCallRangeErrorOnly(CallInst *CI, const LibFunc::Func &Func); -  Value *generateOneRangeCond(CallInst *CI, const LibFunc::Func &Func); -  Value *generateTwoRangeCond(CallInst *CI, const LibFunc::Func &Func); -  Value *generateCondForPow(CallInst *CI, const LibFunc::Func &Func); +  bool performCallDomainErrorOnly(CallInst *CI, const LibFunc &Func); +  bool performCallErrors(CallInst *CI, const LibFunc &Func); +  bool performCallRangeErrorOnly(CallInst *CI, const LibFunc &Func); +  Value *generateOneRangeCond(CallInst *CI, const LibFunc &Func); +  Value *generateTwoRangeCond(CallInst *CI, const LibFunc &Func); +  Value *generateCondForPow(CallInst *CI, const LibFunc &Func);    // Create an OR of two conditions.    Value *createOrCond(CallInst *CI, CmpInst::Predicate Cmp, float Val, @@ -141,44 +141,44 @@ private:  // Perform the transformation to calls with errno set by domain error.  bool LibCallsShrinkWrap::performCallDomainErrorOnly(CallInst *CI, -                                                    const LibFunc::Func &Func) { +                                                    const LibFunc &Func) {    Value *Cond = nullptr;    switch (Func) { -  case LibFunc::acos:  // DomainError: (x < -1 || x > 1) -  case LibFunc::acosf: // Same as acos -  case LibFunc::acosl: // Same as acos -  case LibFunc::asin:  // DomainError: (x < -1 || x > 1) -  case LibFunc::asinf: // Same as asin -  case LibFunc::asinl: // Same as asin +  case LibFunc_acos:  // DomainError: (x < -1 || x > 1) +  case LibFunc_acosf: // Same as acos +  case LibFunc_acosl: // Same as acos +  case LibFunc_asin:  // DomainError: (x < -1 || x > 1) +  case LibFunc_asinf: // Same as asin +  case LibFunc_asinl: // Same as asin    {      ++NumWrappedTwoCond;      Cond = createOrCond(CI, CmpInst::FCMP_OLT, -1.0f, CmpInst::FCMP_OGT, 1.0f);      break;    } -  case LibFunc::cos:  // DomainError: (x == +inf || x == -inf) -  case LibFunc::cosf: // Same as cos -  case LibFunc::cosl: // Same as cos -  case LibFunc::sin:  // DomainError: (x == +inf || x == -inf) -  case LibFunc::sinf: // Same as sin -  case LibFunc::sinl: // Same as sin +  case LibFunc_cos:  // DomainError: (x == +inf || x == -inf) +  case LibFunc_cosf: // Same as cos +  case LibFunc_cosl: // Same as cos +  case LibFunc_sin:  // DomainError: (x == +inf || x == -inf) +  case LibFunc_sinf: // Same as sin +  case LibFunc_sinl: // Same as sin    {      ++NumWrappedTwoCond;      Cond = createOrCond(CI, CmpInst::FCMP_OEQ, INFINITY, CmpInst::FCMP_OEQ,                          -INFINITY);      break;    } -  case LibFunc::acosh:  // DomainError: (x < 1) -  case LibFunc::acoshf: // Same as acosh -  case LibFunc::acoshl: // Same as acosh +  case LibFunc_acosh:  // DomainError: (x < 1) +  case LibFunc_acoshf: // Same as acosh +  case LibFunc_acoshl: // Same as acosh    {      ++NumWrappedOneCond;      Cond = 
createCond(CI, CmpInst::FCMP_OLT, 1.0f);      break;    } -  case LibFunc::sqrt:  // DomainError: (x < 0) -  case LibFunc::sqrtf: // Same as sqrt -  case LibFunc::sqrtl: // Same as sqrt +  case LibFunc_sqrt:  // DomainError: (x < 0) +  case LibFunc_sqrtf: // Same as sqrt +  case LibFunc_sqrtl: // Same as sqrt    {      ++NumWrappedOneCond;      Cond = createCond(CI, CmpInst::FCMP_OLT, 0.0f); @@ -193,31 +193,31 @@ bool LibCallsShrinkWrap::performCallDomainErrorOnly(CallInst *CI,  // Perform the transformation to calls with errno set by range error.  bool LibCallsShrinkWrap::performCallRangeErrorOnly(CallInst *CI, -                                                   const LibFunc::Func &Func) { +                                                   const LibFunc &Func) {    Value *Cond = nullptr;    switch (Func) { -  case LibFunc::cosh: -  case LibFunc::coshf: -  case LibFunc::coshl: -  case LibFunc::exp: -  case LibFunc::expf: -  case LibFunc::expl: -  case LibFunc::exp10: -  case LibFunc::exp10f: -  case LibFunc::exp10l: -  case LibFunc::exp2: -  case LibFunc::exp2f: -  case LibFunc::exp2l: -  case LibFunc::sinh: -  case LibFunc::sinhf: -  case LibFunc::sinhl: { +  case LibFunc_cosh: +  case LibFunc_coshf: +  case LibFunc_coshl: +  case LibFunc_exp: +  case LibFunc_expf: +  case LibFunc_expl: +  case LibFunc_exp10: +  case LibFunc_exp10f: +  case LibFunc_exp10l: +  case LibFunc_exp2: +  case LibFunc_exp2f: +  case LibFunc_exp2l: +  case LibFunc_sinh: +  case LibFunc_sinhf: +  case LibFunc_sinhl: {      Cond = generateTwoRangeCond(CI, Func);      break;    } -  case LibFunc::expm1:  // RangeError: (709, inf) -  case LibFunc::expm1f: // RangeError: (88, inf) -  case LibFunc::expm1l: // RangeError: (11356, inf) +  case LibFunc_expm1:  // RangeError: (709, inf) +  case LibFunc_expm1f: // RangeError: (88, inf) +  case LibFunc_expm1l: // RangeError: (11356, inf)    {      Cond = generateOneRangeCond(CI, Func);      break; @@ -231,15 +231,15 @@ bool LibCallsShrinkWrap::performCallRangeErrorOnly(CallInst *CI,  // Perform the transformation to calls with errno set by combination of errors.  
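For context, the conditions built in this file feed shrinkWrapCI, which guards calls whose result is unused and which are therefore kept only for their errno side effect. Roughly, and only as an illustrative source-level sketch rather than IR from this patch (the function names are hypothetical), a dead-result sqrt call goes from the first form to the second, using the domain-error condition (x < 0) generated above:

#include <cmath>

// Before: the result is unused, so the call matters only if it sets errno.
void callSqrt(double x) {
  std::sqrt(x);
}

// After shrink-wrapping: the call executes only when errno could be set.
void callSqrtWrapped(double x) {
  if (x < 0.0)
    std::sqrt(x);
}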
bool LibCallsShrinkWrap::performCallErrors(CallInst *CI, -                                           const LibFunc::Func &Func) { +                                           const LibFunc &Func) {    Value *Cond = nullptr;    switch (Func) { -  case LibFunc::atanh:  // DomainError: (x < -1 || x > 1) +  case LibFunc_atanh:  // DomainError: (x < -1 || x > 1)                          // PoleError:   (x == -1 || x == 1)                          // Overall Cond: (x <= -1 || x >= 1) -  case LibFunc::atanhf: // Same as atanh -  case LibFunc::atanhl: // Same as atanh +  case LibFunc_atanhf: // Same as atanh +  case LibFunc_atanhl: // Same as atanh    {      if (!LibCallsShrinkWrapDoDomainError || !LibCallsShrinkWrapDoPoleError)        return false; @@ -247,20 +247,20 @@ bool LibCallsShrinkWrap::performCallErrors(CallInst *CI,      Cond = createOrCond(CI, CmpInst::FCMP_OLE, -1.0f, CmpInst::FCMP_OGE, 1.0f);      break;    } -  case LibFunc::log:    // DomainError: (x < 0) +  case LibFunc_log:    // DomainError: (x < 0)                          // PoleError:   (x == 0)                          // Overall Cond: (x <= 0) -  case LibFunc::logf:   // Same as log -  case LibFunc::logl:   // Same as log -  case LibFunc::log10:  // Same as log -  case LibFunc::log10f: // Same as log -  case LibFunc::log10l: // Same as log -  case LibFunc::log2:   // Same as log -  case LibFunc::log2f:  // Same as log -  case LibFunc::log2l:  // Same as log -  case LibFunc::logb:   // Same as log -  case LibFunc::logbf:  // Same as log -  case LibFunc::logbl:  // Same as log +  case LibFunc_logf:   // Same as log +  case LibFunc_logl:   // Same as log +  case LibFunc_log10:  // Same as log +  case LibFunc_log10f: // Same as log +  case LibFunc_log10l: // Same as log +  case LibFunc_log2:   // Same as log +  case LibFunc_log2f:  // Same as log +  case LibFunc_log2l:  // Same as log +  case LibFunc_logb:   // Same as log +  case LibFunc_logbf:  // Same as log +  case LibFunc_logbl:  // Same as log    {      if (!LibCallsShrinkWrapDoDomainError || !LibCallsShrinkWrapDoPoleError)        return false; @@ -268,11 +268,11 @@ bool LibCallsShrinkWrap::performCallErrors(CallInst *CI,      Cond = createCond(CI, CmpInst::FCMP_OLE, 0.0f);      break;    } -  case LibFunc::log1p:  // DomainError: (x < -1) +  case LibFunc_log1p:  // DomainError: (x < -1)                          // PoleError:   (x == -1)                          // Overall Cond: (x <= -1) -  case LibFunc::log1pf: // Same as log1p -  case LibFunc::log1pl: // Same as log1p +  case LibFunc_log1pf: // Same as log1p +  case LibFunc_log1pl: // Same as log1p    {      if (!LibCallsShrinkWrapDoDomainError || !LibCallsShrinkWrapDoPoleError)        return false; @@ -280,11 +280,11 @@ bool LibCallsShrinkWrap::performCallErrors(CallInst *CI,      Cond = createCond(CI, CmpInst::FCMP_OLE, -1.0f);      break;    } -  case LibFunc::pow: // DomainError: x < 0 and y is noninteger +  case LibFunc_pow: // DomainError: x < 0 and y is noninteger                       // PoleError:   x == 0 and y < 0                       // RangeError:  overflow or underflow -  case LibFunc::powf: -  case LibFunc::powl: { +  case LibFunc_powf: +  case LibFunc_powl: {      if (!LibCallsShrinkWrapDoDomainError || !LibCallsShrinkWrapDoPoleError ||          !LibCallsShrinkWrapDoRangeError)        return false; @@ -313,7 +313,7 @@ void LibCallsShrinkWrap::checkCandidate(CallInst &CI) {    if (!CI.use_empty())      return; -  LibFunc::Func Func; +  LibFunc Func;    Function *Callee = CI.getCalledFunction();    if 
(!Callee)      return; @@ -333,16 +333,16 @@ void LibCallsShrinkWrap::checkCandidate(CallInst &CI) {  // Generate the upper bound condition for RangeError.  Value *LibCallsShrinkWrap::generateOneRangeCond(CallInst *CI, -                                                const LibFunc::Func &Func) { +                                                const LibFunc &Func) {    float UpperBound;    switch (Func) { -  case LibFunc::expm1: // RangeError: (709, inf) +  case LibFunc_expm1: // RangeError: (709, inf)      UpperBound = 709.0f;      break; -  case LibFunc::expm1f: // RangeError: (88, inf) +  case LibFunc_expm1f: // RangeError: (88, inf)      UpperBound = 88.0f;      break; -  case LibFunc::expm1l: // RangeError: (11356, inf) +  case LibFunc_expm1l: // RangeError: (11356, inf)      UpperBound = 11356.0f;      break;    default: @@ -355,57 +355,57 @@ Value *LibCallsShrinkWrap::generateOneRangeCond(CallInst *CI,  // Generate the lower and upper bound condition for RangeError.  Value *LibCallsShrinkWrap::generateTwoRangeCond(CallInst *CI, -                                                const LibFunc::Func &Func) { +                                                const LibFunc &Func) {    float UpperBound, LowerBound;    switch (Func) { -  case LibFunc::cosh: // RangeError: (x < -710 || x > 710) -  case LibFunc::sinh: // Same as cosh +  case LibFunc_cosh: // RangeError: (x < -710 || x > 710) +  case LibFunc_sinh: // Same as cosh      LowerBound = -710.0f;      UpperBound = 710.0f;      break; -  case LibFunc::coshf: // RangeError: (x < -89 || x > 89) -  case LibFunc::sinhf: // Same as coshf +  case LibFunc_coshf: // RangeError: (x < -89 || x > 89) +  case LibFunc_sinhf: // Same as coshf      LowerBound = -89.0f;      UpperBound = 89.0f;      break; -  case LibFunc::coshl: // RangeError: (x < -11357 || x > 11357) -  case LibFunc::sinhl: // Same as coshl +  case LibFunc_coshl: // RangeError: (x < -11357 || x > 11357) +  case LibFunc_sinhl: // Same as coshl      LowerBound = -11357.0f;      UpperBound = 11357.0f;      break; -  case LibFunc::exp: // RangeError: (x < -745 || x > 709) +  case LibFunc_exp: // RangeError: (x < -745 || x > 709)      LowerBound = -745.0f;      UpperBound = 709.0f;      break; -  case LibFunc::expf: // RangeError: (x < -103 || x > 88) +  case LibFunc_expf: // RangeError: (x < -103 || x > 88)      LowerBound = -103.0f;      UpperBound = 88.0f;      break; -  case LibFunc::expl: // RangeError: (x < -11399 || x > 11356) +  case LibFunc_expl: // RangeError: (x < -11399 || x > 11356)      LowerBound = -11399.0f;      UpperBound = 11356.0f;      break; -  case LibFunc::exp10: // RangeError: (x < -323 || x > 308) +  case LibFunc_exp10: // RangeError: (x < -323 || x > 308)      LowerBound = -323.0f;      UpperBound = 308.0f;      break; -  case LibFunc::exp10f: // RangeError: (x < -45 || x > 38) +  case LibFunc_exp10f: // RangeError: (x < -45 || x > 38)      LowerBound = -45.0f;      UpperBound = 38.0f;      break; -  case LibFunc::exp10l: // RangeError: (x < -4950 || x > 4932) +  case LibFunc_exp10l: // RangeError: (x < -4950 || x > 4932)      LowerBound = -4950.0f;      UpperBound = 4932.0f;      break; -  case LibFunc::exp2: // RangeError: (x < -1074 || x > 1023) +  case LibFunc_exp2: // RangeError: (x < -1074 || x > 1023)      LowerBound = -1074.0f;      UpperBound = 1023.0f;      break; -  case LibFunc::exp2f: // RangeError: (x < -149 || x > 127) +  case LibFunc_exp2f: // RangeError: (x < -149 || x > 127)      LowerBound = -149.0f;      UpperBound = 127.0f;      break; -  
case LibFunc::exp2l: // RangeError: (x < -16445 || x > 11383) +  case LibFunc_exp2l: // RangeError: (x < -16445 || x > 11383)      LowerBound = -16445.0f;      UpperBound = 11383.0f;      break; @@ -434,9 +434,9 @@ Value *LibCallsShrinkWrap::generateTwoRangeCond(CallInst *CI,  // (i.e. we might invoke the calls that will not set the errno.).  //  Value *LibCallsShrinkWrap::generateCondForPow(CallInst *CI, -                                              const LibFunc::Func &Func) { -  // FIXME: LibFunc::powf and powl TBD. -  if (Func != LibFunc::pow) { +                                              const LibFunc &Func) { +  // FIXME: LibFunc_powf and powl TBD. +  if (Func != LibFunc_pow) {      DEBUG(dbgs() << "Not handled powf() and powl()\n");      return nullptr;    } @@ -516,7 +516,7 @@ void LibCallsShrinkWrap::shrinkWrapCI(CallInst *CI, Value *Cond) {  // Perform the transformation to a single candidate.  bool LibCallsShrinkWrap::perform(CallInst *CI) { -  LibFunc::Func Func; +  LibFunc Func;    Function *Callee = CI->getCalledFunction();    assert(Callee && "perform() should apply to a non-empty callee");    TLI.getLibFunc(*Callee, Func); diff --git a/lib/Transforms/Utils/Local.cpp b/lib/Transforms/Utils/Local.cpp index 6e4174aa0cda..18b29226c2ef 100644 --- a/lib/Transforms/Utils/Local.cpp +++ b/lib/Transforms/Utils/Local.cpp @@ -126,21 +126,20 @@ bool llvm::ConstantFoldTerminator(BasicBlock *BB, bool DeleteDeadConditions,      // If the default is unreachable, ignore it when searching for TheOnlyDest.      if (isa<UnreachableInst>(DefaultDest->getFirstNonPHIOrDbg()) &&          SI->getNumCases() > 0) { -      TheOnlyDest = SI->case_begin().getCaseSuccessor(); +      TheOnlyDest = SI->case_begin()->getCaseSuccessor();      }      // Figure out which case it goes to. -    for (SwitchInst::CaseIt i = SI->case_begin(), e = SI->case_end(); -         i != e; ++i) { +    for (auto i = SI->case_begin(), e = SI->case_end(); i != e;) {        // Found case matching a constant operand? -      if (i.getCaseValue() == CI) { -        TheOnlyDest = i.getCaseSuccessor(); +      if (i->getCaseValue() == CI) { +        TheOnlyDest = i->getCaseSuccessor();          break;        }        // Check to see if this branch is going to the same place as the default        // dest.  If so, eliminate it as an explicit compare. -      if (i.getCaseSuccessor() == DefaultDest) { +      if (i->getCaseSuccessor() == DefaultDest) {          MDNode *MD = SI->getMetadata(LLVMContext::MD_prof);          unsigned NCases = SI->getNumCases();          // Fold the case metadata into the default if there will be any branches @@ -154,7 +153,7 @@ bool llvm::ConstantFoldTerminator(BasicBlock *BB, bool DeleteDeadConditions,              Weights.push_back(CI->getValue().getZExtValue());            }            // Merge weight of this case to the default weight. -          unsigned idx = i.getCaseIndex(); +          unsigned idx = i->getCaseIndex();            Weights[0] += Weights[idx+1];            // Remove weight for this case.            std::swap(Weights[idx+1], Weights.back()); @@ -165,15 +164,19 @@ bool llvm::ConstantFoldTerminator(BasicBlock *BB, bool DeleteDeadConditions,          }          // Remove this entry.          DefaultDest->removePredecessor(SI->getParent()); -        SI->removeCase(i); -        --i; --e; +        i = SI->removeCase(i); +        e = SI->case_end();          continue;        }        // Otherwise, check to see if the switch only branches to one destination.        
// We do this by reseting "TheOnlyDest" to null when we find two non-equal        // destinations. -      if (i.getCaseSuccessor() != TheOnlyDest) TheOnlyDest = nullptr; +      if (i->getCaseSuccessor() != TheOnlyDest) +        TheOnlyDest = nullptr; + +      // Increment this iterator as we haven't removed the case. +      ++i;      }      if (CI && !TheOnlyDest) { @@ -209,7 +212,7 @@ bool llvm::ConstantFoldTerminator(BasicBlock *BB, bool DeleteDeadConditions,      if (SI->getNumCases() == 1) {        // Otherwise, we can fold this switch into a conditional branch        // instruction if it has only one non-default destination. -      SwitchInst::CaseIt FirstCase = SI->case_begin(); +      auto FirstCase = *SI->case_begin();        Value *Cond = Builder.CreateICmpEQ(SI->getCondition(),            FirstCase.getCaseValue(), "cond"); @@ -287,7 +290,15 @@ bool llvm::ConstantFoldTerminator(BasicBlock *BB, bool DeleteDeadConditions,  ///  bool llvm::isInstructionTriviallyDead(Instruction *I,                                        const TargetLibraryInfo *TLI) { -  if (!I->use_empty() || isa<TerminatorInst>(I)) return false; +  if (!I->use_empty()) +    return false; +  return wouldInstructionBeTriviallyDead(I, TLI); +} + +bool llvm::wouldInstructionBeTriviallyDead(Instruction *I, +                                           const TargetLibraryInfo *TLI) { +  if (isa<TerminatorInst>(I)) +    return false;    // We don't want the landingpad-like instructions removed by anything this    // general. @@ -307,7 +318,8 @@ bool llvm::isInstructionTriviallyDead(Instruction *I,      return true;    } -  if (!I->mayHaveSideEffects()) return true; +  if (!I->mayHaveSideEffects()) +    return true;    // Special case intrinsics that "may have side effects" but can be deleted    // when dead. @@ -334,7 +346,8 @@ bool llvm::isInstructionTriviallyDead(Instruction *I,      }    } -  if (isAllocLikeFn(I, TLI)) return true; +  if (isAllocLikeFn(I, TLI)) +    return true;    if (CallInst *CI = isFreeCall(I, TLI))      if (Constant *C = dyn_cast<Constant>(CI->getArgOperand(0))) @@ -1075,11 +1088,11 @@ static bool PhiHasDebugValue(DILocalVariable *DIVar,    // Since we can't guarantee that the original dbg.declare instrinsic    // is removed by LowerDbgDeclare(), we need to make sure that we are    // not inserting the same dbg.value intrinsic over and over. -  DbgValueList DbgValues; -  FindAllocaDbgValues(DbgValues, APN); -  for (auto DVI : DbgValues) { -    assert (DVI->getValue() == APN); -    assert (DVI->getOffset() == 0); +  SmallVector<DbgValueInst *, 1> DbgValues; +  findDbgValues(DbgValues, APN); +  for (auto *DVI : DbgValues) { +    assert(DVI->getValue() == APN); +    assert(DVI->getOffset() == 0);      if ((DVI->getVariable() == DIVar) && (DVI->getExpression() == DIExpr))        return true;    } @@ -1241,9 +1254,7 @@ DbgDeclareInst *llvm::FindAllocaDbgDeclare(Value *V) {    return nullptr;  } -/// FindAllocaDbgValues - Finds the llvm.dbg.value intrinsics describing the -/// alloca 'V', if any. 
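The case-removal loop rewritten above leans on the updated SwitchInst API, where removeCase now returns the iterator following the removed case. A minimal sketch of the same erase-while-iterating pattern, assuming a hypothetical DeadBB successor being eliminated (not something from this patch):

#include "llvm/IR/BasicBlock.h"
#include "llvm/IR/Instructions.h"
using namespace llvm;

// Remove every case of SI that branches to DeadBB while keeping the case
// iterators valid across removals.
static void removeCasesTargeting(SwitchInst *SI, BasicBlock *DeadBB) {
  for (auto I = SI->case_begin(), E = SI->case_end(); I != E;) {
    if (I->getCaseSuccessor() == DeadBB) {
      DeadBB->removePredecessor(SI->getParent());
      I = SI->removeCase(I); // returns the iterator after the removed case
      E = SI->case_end();    // the old end iterator is no longer valid
    } else {
      ++I;                   // only advance when nothing was removed
    }
  }
}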
-void llvm::FindAllocaDbgValues(DbgValueList &DbgValues, Value *V) { +void llvm::findDbgValues(SmallVectorImpl<DbgValueInst *> &DbgValues, Value *V) {    if (auto *L = LocalAsMetadata::getIfExists(V))      if (auto *MDV = MetadataAsValue::getIfExists(V->getContext(), L))        for (User *U : MDV->users()) @@ -1251,36 +1262,32 @@ void llvm::FindAllocaDbgValues(DbgValueList &DbgValues, Value *V) {            DbgValues.push_back(DVI);  } -static void DIExprAddDeref(SmallVectorImpl<uint64_t> &Expr) { -  Expr.push_back(dwarf::DW_OP_deref); -} - -static void DIExprAddOffset(SmallVectorImpl<uint64_t> &Expr, int Offset) { +static void appendOffset(SmallVectorImpl<uint64_t> &Ops, int64_t Offset) {    if (Offset > 0) { -    Expr.push_back(dwarf::DW_OP_plus); -    Expr.push_back(Offset); +    Ops.push_back(dwarf::DW_OP_plus); +    Ops.push_back(Offset);    } else if (Offset < 0) { -    Expr.push_back(dwarf::DW_OP_minus); -    Expr.push_back(-Offset); +    Ops.push_back(dwarf::DW_OP_minus); +    Ops.push_back(-Offset);    }  } -static DIExpression *BuildReplacementDIExpr(DIBuilder &Builder, -                                            DIExpression *DIExpr, bool Deref, -                                            int Offset) { +/// Prepend \p DIExpr with a deref and offset operation. +static DIExpression *prependDIExpr(DIBuilder &Builder, DIExpression *DIExpr, +                                   bool Deref, int64_t Offset) {    if (!Deref && !Offset)      return DIExpr;    // Create a copy of the original DIDescriptor for user variable, prepending    // "deref" operation to a list of address elements, as new llvm.dbg.declare    // will take a value storing address of the memory for variable, not    // alloca itself. -  SmallVector<uint64_t, 4> NewDIExpr; +  SmallVector<uint64_t, 4> Ops;    if (Deref) -    DIExprAddDeref(NewDIExpr); -  DIExprAddOffset(NewDIExpr, Offset); +    Ops.push_back(dwarf::DW_OP_deref); +  appendOffset(Ops, Offset);    if (DIExpr) -    NewDIExpr.append(DIExpr->elements_begin(), DIExpr->elements_end()); -  return Builder.createExpression(NewDIExpr); +    Ops.append(DIExpr->elements_begin(), DIExpr->elements_end()); +  return Builder.createExpression(Ops);  }  bool llvm::replaceDbgDeclare(Value *Address, Value *NewAddress, @@ -1294,7 +1301,7 @@ bool llvm::replaceDbgDeclare(Value *Address, Value *NewAddress,    auto *DIExpr = DDI->getExpression();    assert(DIVar && "Missing variable"); -  DIExpr = BuildReplacementDIExpr(Builder, DIExpr, Deref, Offset); +  DIExpr = prependDIExpr(Builder, DIExpr, Deref, Offset);    // Insert llvm.dbg.declare immediately after the original alloca, and remove    // old llvm.dbg.declare. @@ -1326,11 +1333,11 @@ static void replaceOneDbgValueForAlloca(DbgValueInst *DVI, Value *NewAddress,    // Insert the offset immediately after the first deref.    // We could just change the offset argument of dbg.value, but it's unsigned...    
if (Offset) { -    SmallVector<uint64_t, 4> NewDIExpr; -    DIExprAddDeref(NewDIExpr); -    DIExprAddOffset(NewDIExpr, Offset); -    NewDIExpr.append(DIExpr->elements_begin() + 1, DIExpr->elements_end()); -    DIExpr = Builder.createExpression(NewDIExpr); +    SmallVector<uint64_t, 4> Ops; +    Ops.push_back(dwarf::DW_OP_deref); +    appendOffset(Ops, Offset); +    Ops.append(DIExpr->elements_begin() + 1, DIExpr->elements_end()); +    DIExpr = Builder.createExpression(Ops);    }    Builder.insertDbgValueIntrinsic(NewAddress, DVI->getOffset(), DIVar, DIExpr, @@ -1349,6 +1356,53 @@ void llvm::replaceDbgValueForAlloca(AllocaInst *AI, Value *NewAllocaAddress,        }  } +void llvm::salvageDebugInfo(Instruction &I) { +  SmallVector<DbgValueInst *, 1> DbgValues; +  auto &M = *I.getModule(); + +  auto MDWrap = [&](Value *V) { +    return MetadataAsValue::get(I.getContext(), ValueAsMetadata::get(V)); +  }; + +  if (isa<BitCastInst>(&I)) { +    findDbgValues(DbgValues, &I); +    for (auto *DVI : DbgValues) { +      // Bitcasts are entirely irrelevant for debug info. Rewrite the dbg.value +      // to use the cast's source. +      DVI->setOperand(0, MDWrap(I.getOperand(0))); +      DEBUG(dbgs() << "SALVAGE: " << *DVI << '\n'); +    } +  } else if (auto *GEP = dyn_cast<GetElementPtrInst>(&I)) { +    findDbgValues(DbgValues, &I); +    for (auto *DVI : DbgValues) { +      unsigned BitWidth = +          M.getDataLayout().getPointerSizeInBits(GEP->getPointerAddressSpace()); +      APInt Offset(BitWidth, 0); +      // Rewrite a constant GEP into a DIExpression. +      if (GEP->accumulateConstantOffset(M.getDataLayout(), Offset)) { +        auto *DIExpr = DVI->getExpression(); +        DIBuilder DIB(M, /*AllowUnresolved*/ false); +        // GEP offsets are i32 and thus always fit into an int64_t. +        DIExpr = prependDIExpr(DIB, DIExpr, NoDeref, Offset.getSExtValue()); +        DVI->setOperand(0, MDWrap(I.getOperand(0))); +        DVI->setOperand(3, MetadataAsValue::get(I.getContext(), DIExpr)); +        DEBUG(dbgs() << "SALVAGE: " << *DVI << '\n'); +      } +    } +  } else if (isa<LoadInst>(&I)) { +    findDbgValues(DbgValues, &I); +    for (auto *DVI : DbgValues) { +      // Rewrite the load into DW_OP_deref. 
+      auto *DIExpr = DVI->getExpression(); +      DIBuilder DIB(M, /*AllowUnresolved*/ false); +      DIExpr = prependDIExpr(DIB, DIExpr, WithDeref, 0); +      DVI->setOperand(0, MDWrap(I.getOperand(0))); +      DVI->setOperand(3, MetadataAsValue::get(I.getContext(), DIExpr)); +      DEBUG(dbgs() << "SALVAGE:  " << *DVI << '\n'); +    } +  } +} +  unsigned llvm::removeAllNonTerminatorAndEHPadInstructions(BasicBlock *BB) {    unsigned NumDeadInst = 0;    // Delete the instructions backwards, as it has a reduced likelihood of @@ -2068,9 +2122,9 @@ bool llvm::recognizeBSwapOrBitReverseIdiom(  void llvm::maybeMarkSanitizerLibraryCallNoBuiltin(      CallInst *CI, const TargetLibraryInfo *TLI) {    Function *F = CI->getCalledFunction(); -  LibFunc::Func Func; +  LibFunc Func;    if (F && !F->hasLocalLinkage() && F->hasName() &&        TLI->getLibFunc(F->getName(), Func) && TLI->hasOptimizedCodeGen(Func) &&        !F->doesNotAccessMemory()) -    CI->addAttribute(AttributeSet::FunctionIndex, Attribute::NoBuiltin); +    CI->addAttribute(AttributeList::FunctionIndex, Attribute::NoBuiltin);  } diff --git a/lib/Transforms/Utils/LoopSimplify.cpp b/lib/Transforms/Utils/LoopSimplify.cpp index 00cda2af00c6..e7ba19665d59 100644 --- a/lib/Transforms/Utils/LoopSimplify.cpp +++ b/lib/Transforms/Utils/LoopSimplify.cpp @@ -645,14 +645,7 @@ ReprocessLoop:    // loop-invariant instructions out of the way to open up more    // opportunities, and the disadvantage of having the responsibility    // to preserve dominator information. -  bool UniqueExit = true; -  if (!ExitBlocks.empty()) -    for (unsigned i = 1, e = ExitBlocks.size(); i != e; ++i) -      if (ExitBlocks[i] != ExitBlocks[0]) { -        UniqueExit = false; -        break; -      } -  if (UniqueExit) { +  if (ExitBlockSet.size() == 1) {      for (unsigned i = 0, e = ExitingBlocks.size(); i != e; ++i) {        BasicBlock *ExitingBlock = ExitingBlocks[i];        if (!ExitingBlock->getSinglePredecessor()) continue; @@ -735,6 +728,17 @@ bool llvm::simplifyLoop(Loop *L, DominatorTree *DT, LoopInfo *LI,                          bool PreserveLCSSA) {    bool Changed = false; +#ifndef NDEBUG +  // If we're asked to preserve LCSSA, the loop nest needs to start in LCSSA +  // form. +  if (PreserveLCSSA) { +    assert(DT && "DT not available."); +    assert(LI && "LI not available."); +    assert(L->isRecursivelyLCSSAForm(*DT, *LI) && +           "Requested to preserve LCSSA, but it's already broken."); +  } +#endif +    // Worklist maintains our depth-first queue of loops in this nest to process.    SmallVector<Loop *, 4> Worklist;    Worklist.push_back(L); @@ -814,15 +818,6 @@ bool LoopSimplify::runOnFunction(Function &F) {        &getAnalysis<AssumptionCacheTracker>().getAssumptionCache(F);    bool PreserveLCSSA = mustPreserveAnalysisID(LCSSAID); -#ifndef NDEBUG -  if (PreserveLCSSA) { -    assert(DT && "DT not available."); -    assert(LI && "LI not available."); -    bool InLCSSA = all_of( -        *LI, [&](Loop *L) { return L->isRecursivelyLCSSAForm(*DT, *LI); }); -    assert(InLCSSA && "Requested to preserve LCSSA, but it's already broken."); -  } -#endif    // Simplify each loop nest in the function.    
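A sketch of how a transform might use the salvageDebugInfo helper added above when it is about to delete an instruction; the wrapper is hypothetical and assumes the declaration sits in llvm/Transforms/Utils/Local.h next to the other helpers from this file:

#include "llvm/IR/Instruction.h"
#include "llvm/Transforms/Utils/Local.h"
using namespace llvm;

// Before erasing a bitcast, constant GEP or load that dbg.value intrinsics
// still refer to, rewrite those intrinsics to describe the operand (with an
// adjusted DIExpression) instead of letting the variable go optimized-out.
static void eraseAndSalvage(Instruction *I) {
  salvageDebugInfo(*I); // rewrites llvm.dbg.value uses of I where possible
  I->eraseFromParent();
}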
for (LoopInfo::iterator I = LI->begin(), E = LI->end(); I != E; ++I) @@ -846,17 +841,14 @@ PreservedAnalyses LoopSimplifyPass::run(Function &F,    ScalarEvolution *SE = AM.getCachedResult<ScalarEvolutionAnalysis>(F);    AssumptionCache *AC = &AM.getResult<AssumptionAnalysis>(F); -  // FIXME: This pass should verify that the loops on which it's operating -  // are in canonical SSA form, and that the pass itself preserves this form. +  // Note that we don't preserve LCSSA in the new PM; if you need it, run LCSSA +  // after simplifying the loops.    for (LoopInfo::iterator I = LI->begin(), E = LI->end(); I != E; ++I) -    Changed |= simplifyLoop(*I, DT, LI, SE, AC, true /* PreserveLCSSA */); - -  // FIXME: We need to invalidate this to avoid PR28400. Is there a better -  // solution? -  AM.invalidate<ScalarEvolutionAnalysis>(F); +    Changed |= simplifyLoop(*I, DT, LI, SE, AC, /*PreserveLCSSA*/ false);    if (!Changed)      return PreservedAnalyses::all(); +    PreservedAnalyses PA;    PA.preserve<DominatorTreeAnalysis>();    PA.preserve<LoopAnalysis>(); diff --git a/lib/Transforms/Utils/LoopUnroll.cpp b/lib/Transforms/Utils/LoopUnroll.cpp index e346ebd6a000..3c669ce644e2 100644 --- a/lib/Transforms/Utils/LoopUnroll.cpp +++ b/lib/Transforms/Utils/LoopUnroll.cpp @@ -27,6 +27,7 @@  #include "llvm/Analysis/ScalarEvolution.h"  #include "llvm/IR/BasicBlock.h"  #include "llvm/IR/DataLayout.h" +#include "llvm/IR/DebugInfoMetadata.h"  #include "llvm/IR/Dominators.h"  #include "llvm/IR/IntrinsicInst.h"  #include "llvm/IR/LLVMContext.h" @@ -51,6 +52,16 @@  UnrollRuntimeEpilog("unroll-runtime-epilog", cl::init(false), cl::Hidden,                      cl::desc("Allow runtime unrolled loops to be unrolled "                               "with epilog instead of prolog.")); +static cl::opt<bool> +UnrollVerifyDomtree("unroll-verify-domtree", cl::Hidden, +                    cl::desc("Verify domtree after unrolling"), +#ifdef NDEBUG +    cl::init(false) +#else +    cl::init(true) +#endif +                    ); +  /// Convert the instruction operands from referencing the current values into  /// those specified by VMap.  static inline void remapInstruction(Instruction *I, @@ -205,6 +216,45 @@ const Loop* llvm::addClonedBlockToLoopInfo(BasicBlock *OriginalBB,    }  } +/// The function chooses which type of unroll (epilog or prolog) is more +/// profitable. +/// Epilog unroll is more profitable when there is a PHI that starts from a +/// constant.  In this case the epilog will leave the PHI starting from a +/// constant, but the prolog will convert it to a non-constant. +/// +/// loop: +///   PN = PHI [I, Latch], [CI, PreHeader] +///   I = foo(PN) +///   ... +/// +/// Epilog unroll case. +/// loop: +///   PN = PHI [I2, Latch], [CI, PreHeader] +///   I1 = foo(PN) +///   I2 = foo(I1) +///   ... +/// Prolog unroll case. +///   NewPN = PHI [PrologI, Prolog], [CI, PreHeader] +/// loop: +///   PN = PHI [I2, Latch], [NewPN, PreHeader] +///   I1 = foo(PN) +///   I2 = foo(I1) +///   ... +/// +static bool isEpilogProfitable(Loop *L) { +  BasicBlock *PreHeader = L->getLoopPreheader(); +  BasicBlock *Header = L->getHeader(); +  assert(PreHeader && Header); +  for (Instruction &BBI : *Header) { +    PHINode *PN = dyn_cast<PHINode>(&BBI); +    if (!PN) +      break; +    if (isa<ConstantInt>(PN->getIncomingValueForBlock(PreHeader))) +      return true; +  } +  return false; +} +  /// Unroll the given loop by Count. The loop must be in LCSSA form. Returns true  /// if unrolling was successful, or false if the loop was unmodified. 
Unrolling  /// can only fail when the loop's latch block is not terminated by a conditional @@ -296,8 +346,10 @@ bool llvm::UnrollLoop(Loop *L, unsigned Count, unsigned TripCount, bool Force,      Count = TripCount;    // Don't enter the unroll code if there is nothing to do. -  if (TripCount == 0 && Count < 2 && PeelCount == 0) +  if (TripCount == 0 && Count < 2 && PeelCount == 0) { +    DEBUG(dbgs() << "Won't unroll; almost nothing to do\n");      return false; +  }    assert(Count > 0);    assert(TripMultiple > 0); @@ -330,7 +382,7 @@ bool llvm::UnrollLoop(Loop *L, unsigned Count, unsigned TripCount, bool Force,           "and peeling for the same loop");    if (PeelCount) -    peelLoop(L, PeelCount, LI, SE, DT, PreserveLCSSA); +    peelLoop(L, PeelCount, LI, SE, DT, AC, PreserveLCSSA);    // Loops containing convergent instructions must have a count that divides    // their TripMultiple. @@ -346,14 +398,22 @@ bool llvm::UnrollLoop(Loop *L, unsigned Count, unsigned TripCount, bool Force,                 "convergent operation.");        }); +  bool EpilogProfitability = +      UnrollRuntimeEpilog.getNumOccurrences() ? UnrollRuntimeEpilog +                                              : isEpilogProfitable(L); +    if (RuntimeTripCount && TripMultiple % Count != 0 &&        !UnrollRuntimeLoopRemainder(L, Count, AllowExpensiveTripCount, -                                  UnrollRuntimeEpilog, LI, SE, DT,  +                                  EpilogProfitability, LI, SE, DT,                                    PreserveLCSSA)) {      if (Force)        RuntimeTripCount = false; -    else +    else { +      DEBUG( +          dbgs() << "Wont unroll; remainder loop could not be generated" +                    "when assuming runtime trip count\n");        return false; +    }    }    // Notify ScalarEvolution that the loop will be substantially changed, @@ -446,6 +506,12 @@ bool llvm::UnrollLoop(Loop *L, unsigned Count, unsigned TripCount, bool Force,    for (Loop *SubLoop : *L)      LoopsToSimplify.insert(SubLoop); +  if (Header->getParent()->isDebugInfoForProfiling()) +    for (BasicBlock *BB : L->getBlocks()) +      for (Instruction &I : *BB) +        if (const DILocation *DIL = I.getDebugLoc()) +          I.setDebugLoc(DIL->cloneWithDuplicationFactor(Count)); +    for (unsigned It = 1; It != Count; ++It) {      std::vector<BasicBlock*> NewBlocks;      SmallDenseMap<const Loop *, Loop *, 4> NewLoops; @@ -456,19 +522,16 @@ bool llvm::UnrollLoop(Loop *L, unsigned Count, unsigned TripCount, bool Force,        BasicBlock *New = CloneBasicBlock(*BB, VMap, "." + Twine(It));        Header->getParent()->getBasicBlockList().push_back(New); +      assert((*BB != Header || LI->getLoopFor(*BB) == L) && +             "Header should not be in a sub-loop");        // Tell LI about New. -      if (*BB == Header) { -        assert(LI->getLoopFor(*BB) == L && "Header should not be in a sub-loop"); -        L->addBasicBlockToLoop(New, *LI); -      } else { -        const Loop *OldLoop = addClonedBlockToLoopInfo(*BB, New, LI, NewLoops); -        if (OldLoop) { -          LoopsToSimplify.insert(NewLoops[OldLoop]); +      const Loop *OldLoop = addClonedBlockToLoopInfo(*BB, New, LI, NewLoops); +      if (OldLoop) { +        LoopsToSimplify.insert(NewLoops[OldLoop]); -          // Forget the old loop, since its inputs may have changed. -          if (SE) -            SE->forgetLoop(OldLoop); -        } +        // Forget the old loop, since its inputs may have changed. 
+        if (SE) +          SE->forgetLoop(OldLoop);        }        if (*BB == Header) @@ -615,14 +678,11 @@ bool llvm::UnrollLoop(Loop *L, unsigned Count, unsigned TripCount, bool Force,        Term->eraseFromParent();      }    } +    // Update dominators of blocks we might reach through exits.    // Immediate dominator of such block might change, because we add more    // routes which can lead to the exit: we can now reach it from the copied -  // iterations too. Thus, the new idom of the block will be the nearest -  // common dominator of the previous idom and common dominator of all copies of -  // the previous idom. This is equivalent to the nearest common dominator of -  // the previous idom and the first latch, which dominates all copies of the -  // previous idom. +  // iterations too.    if (DT && Count > 1) {      for (auto *BB : OriginalLoopBlocks) {        auto *BBDomNode = DT->getNode(BB); @@ -632,12 +692,38 @@ bool llvm::UnrollLoop(Loop *L, unsigned Count, unsigned TripCount, bool Force,          if (!L->contains(ChildBB))            ChildrenToUpdate.push_back(ChildBB);        } -      BasicBlock *NewIDom = DT->findNearestCommonDominator(BB, Latches[0]); +      BasicBlock *NewIDom; +      if (BB == LatchBlock) { +        // The latch is special because we emit unconditional branches in +        // some cases where the original loop contained a conditional branch. +        // Since the latch is always at the bottom of the loop, if the latch +        // dominated an exit before unrolling, the new dominator of that exit +        // must also be a latch.  Specifically, the dominator is the first +        // latch which ends in a conditional branch, or the last latch if +        // there is no such latch. +        NewIDom = Latches.back(); +        for (BasicBlock *IterLatch : Latches) { +          TerminatorInst *Term = IterLatch->getTerminator(); +          if (isa<BranchInst>(Term) && cast<BranchInst>(Term)->isConditional()) { +            NewIDom = IterLatch; +            break; +          } +        } +      } else { +        // The new idom of the block will be the nearest common dominator +        // of all copies of the previous idom. This is equivalent to the +        // nearest common dominator of the previous idom and the first latch, +        // which dominates all copies of the previous idom. +        NewIDom = DT->findNearestCommonDominator(BB, LatchBlock); +      }        for (auto *ChildBB : ChildrenToUpdate)          DT->changeImmediateDominator(ChildBB, NewIDom);      }    } +  if (DT && UnrollVerifyDomtree) +    DT->verifyDomTree(); +    // Merge adjacent basic blocks, if possible.    SmallPtrSet<Loop *, 4> ForgottenLoops;    for (BasicBlock *Latch : Latches) { @@ -655,13 +741,6 @@ bool llvm::UnrollLoop(Loop *L, unsigned Count, unsigned TripCount, bool Force,      }    } -  // FIXME: We only preserve DT info for complete unrolling now. Incrementally -  // updating domtree after partial loop unrolling should also be easy. -  if (DT && !CompletelyUnroll) -    DT->recalculate(*L->getHeader()->getParent()); -  else if (DT) -    DEBUG(DT->verifyDomTree()); -    // Simplify any new induction variables in the partially unrolled loop.    if (SE && !CompletelyUnroll && Count > 1) {      SmallVector<WeakVH, 16> DeadInsts; @@ -721,29 +800,29 @@ bool llvm::UnrollLoop(Loop *L, unsigned Count, unsigned TripCount, bool Force,    // at least one layer outside of the loop that was unrolled so that any    // changes to the parent loop exposed by the unrolling are considered.    
if (DT) { -    if (!OuterL && !CompletelyUnroll) -      OuterL = L;      if (OuterL) {        // OuterL includes all loops for which we can break loop-simplify, so        // it's sufficient to simplify only it (it'll recursively simplify inner        // loops too). +      if (NeedToFixLCSSA) { +        // LCSSA must be performed on the outermost affected loop. The unrolled +        // loop's last loop latch is guaranteed to be in the outermost loop +        // after LoopInfo's been updated by markAsRemoved. +        Loop *LatchLoop = LI->getLoopFor(Latches.back()); +        Loop *FixLCSSALoop = OuterL; +        if (!FixLCSSALoop->contains(LatchLoop)) +          while (FixLCSSALoop->getParentLoop() != LatchLoop) +            FixLCSSALoop = FixLCSSALoop->getParentLoop(); + +        formLCSSARecursively(*FixLCSSALoop, *DT, LI, SE); +      } else if (PreserveLCSSA) { +        assert(OuterL->isLCSSAForm(*DT) && +               "Loops should be in LCSSA form after loop-unroll."); +      } +        // TODO: That potentially might be compile-time expensive. We should try        // to fix the loop-simplified form incrementally.        simplifyLoop(OuterL, DT, LI, SE, AC, PreserveLCSSA); - -      // LCSSA must be performed on the outermost affected loop. The unrolled -      // loop's last loop latch is guaranteed to be in the outermost loop after -      // LoopInfo's been updated by markAsRemoved. -      Loop *LatchLoop = LI->getLoopFor(Latches.back()); -      if (!OuterL->contains(LatchLoop)) -        while (OuterL->getParentLoop() != LatchLoop) -          OuterL = OuterL->getParentLoop(); - -      if (NeedToFixLCSSA) -        formLCSSARecursively(*OuterL, *DT, LI, SE); -      else -        assert(OuterL->isLCSSAForm(*DT) && -               "Loops should be in LCSSA form after loop-unroll.");      } else {        // Simplify loops for which we might've broken loop-simplify form.        for (Loop *SubLoop : LoopsToSimplify) diff --git a/lib/Transforms/Utils/LoopUnrollPeel.cpp b/lib/Transforms/Utils/LoopUnrollPeel.cpp index 842cf31f2e3d..73c14f5606b7 100644 --- a/lib/Transforms/Utils/LoopUnrollPeel.cpp +++ b/lib/Transforms/Utils/LoopUnrollPeel.cpp @@ -28,6 +28,7 @@  #include "llvm/Transforms/Scalar.h"  #include "llvm/Transforms/Utils/BasicBlockUtils.h"  #include "llvm/Transforms/Utils/Cloning.h" +#include "llvm/Transforms/Utils/LoopSimplify.h"  #include "llvm/Transforms/Utils/LoopUtils.h"  #include "llvm/Transforms/Utils/UnrollLoop.h"  #include <algorithm> @@ -55,12 +56,20 @@ static bool canPeel(Loop *L) {    if (!L->getExitingBlock() || !L->getUniqueExitBlock())      return false; +  // Don't try to peel loops where the latch is not the exiting block. +  // This can be an indication of two different things: +  // 1) The loop is not rotated. +  // 2) The loop contains irreducible control flow that involves the latch. +  if (L->getLoopLatch() != L->getExitingBlock()) +    return false; +    return true;  }  // Return the number of iterations we want to peel off.  void llvm::computePeelCount(Loop *L, unsigned LoopSize, -                            TargetTransformInfo::UnrollingPreferences &UP) { +                            TargetTransformInfo::UnrollingPreferences &UP, +                            unsigned &TripCount) {    UP.PeelCount = 0;    if (!canPeel(L))      return; @@ -69,6 +78,39 @@ void llvm::computePeelCount(Loop *L, unsigned LoopSize,    if (!L->empty())      return; +  // Try to find a Phi node that has the same loop invariant as an input from +  // its only back edge. 
If there is such Phi, peeling 1 iteration from the +  // loop is profitable, because starting from 2nd iteration we will have an +  // invariant instead of this Phi. +  if (LoopSize <= UP.Threshold) { +    BasicBlock *BackEdge = L->getLoopLatch(); +    assert(BackEdge && "Loop is not in simplified form?"); +    BasicBlock *Header = L->getHeader(); +    // Iterate over Phis to find one with invariant input on back edge. +    bool FoundCandidate = false; +    PHINode *Phi; +    for (auto BI = Header->begin(); isa<PHINode>(&*BI); ++BI) { +      Phi = cast<PHINode>(&*BI); +      Value *Input = Phi->getIncomingValueForBlock(BackEdge); +      if (L->isLoopInvariant(Input)) { +        FoundCandidate = true; +        break; +      } +    } +    if (FoundCandidate) { +      DEBUG(dbgs() << "Peel one iteration to get rid of " << *Phi +                   << " because starting from 2nd iteration it is always" +                   << " an invariant\n"); +      UP.PeelCount = 1; +      return; +    } +  } + +  // Bail if we know the statically calculated trip count. +  // In this case we rather prefer partial unrolling. +  if (TripCount) +    return; +    // If the user provided a peel count, use that.    bool UserPeelCount = UnrollForcePeelCount.getNumOccurrences() > 0;    if (UserPeelCount) { @@ -164,7 +206,8 @@ static void cloneLoopBlocks(Loop *L, unsigned IterNumber, BasicBlock *InsertTop,                              BasicBlock *InsertBot, BasicBlock *Exit,                              SmallVectorImpl<BasicBlock *> &NewBlocks,                              LoopBlocksDFS &LoopBlocks, ValueToValueMapTy &VMap, -                            ValueToValueMapTy &LVMap, LoopInfo *LI) { +                            ValueToValueMapTy &LVMap, DominatorTree *DT, +                            LoopInfo *LI) {    BasicBlock *Header = L->getHeader();    BasicBlock *Latch = L->getLoopLatch(); @@ -185,6 +228,17 @@ static void cloneLoopBlocks(Loop *L, unsigned IterNumber, BasicBlock *InsertTop,        ParentLoop->addBasicBlockToLoop(NewBB, *LI);      VMap[*BB] = NewBB; + +    // If dominator tree is available, insert nodes to represent cloned blocks. +    if (DT) { +      if (Header == *BB) +        DT->addNewBlock(NewBB, InsertTop); +      else { +        DomTreeNode *IDom = DT->getNode(*BB)->getIDom(); +        // VMap must contain entry for IDom, as the iteration order is RPO. +        DT->addNewBlock(NewBB, cast<BasicBlock>(VMap[IDom->getBlock()])); +      } +    }    }    // Hook-up the control flow for the newly inserted blocks. @@ -198,11 +252,13 @@ static void cloneLoopBlocks(Loop *L, unsigned IterNumber, BasicBlock *InsertTop,    // The backedge now goes to the "bottom", which is either the loop's real    // header (for the last peeled iteration) or the copied header of the next    // iteration (for every other iteration) -  BranchInst *LatchBR = -      cast<BranchInst>(cast<BasicBlock>(VMap[Latch])->getTerminator()); +  BasicBlock *NewLatch = cast<BasicBlock>(VMap[Latch]); +  BranchInst *LatchBR = cast<BranchInst>(NewLatch->getTerminator());    unsigned HeaderIdx = (LatchBR->getSuccessor(0) == Header ? 
0 : 1);    LatchBR->setSuccessor(HeaderIdx, InsertBot);    LatchBR->setSuccessor(1 - HeaderIdx, Exit); +  if (DT) +    DT->changeImmediateDominator(InsertBot, NewLatch);    // The new copy of the loop body starts with a bunch of PHI nodes    // that pick an incoming value from either the preheader, or the previous @@ -257,7 +313,7 @@ static void cloneLoopBlocks(Loop *L, unsigned IterNumber, BasicBlock *InsertTop,  /// optimizations.  bool llvm::peelLoop(Loop *L, unsigned PeelCount, LoopInfo *LI,                      ScalarEvolution *SE, DominatorTree *DT, -                    bool PreserveLCSSA) { +                    AssumptionCache *AC, bool PreserveLCSSA) {    if (!canPeel(L))      return false; @@ -358,7 +414,24 @@ bool llvm::peelLoop(Loop *L, unsigned PeelCount, LoopInfo *LI,        CurHeaderWeight = 1;      cloneLoopBlocks(L, Iter, InsertTop, InsertBot, Exit, -                    NewBlocks, LoopBlocks, VMap, LVMap, LI); +                    NewBlocks, LoopBlocks, VMap, LVMap, DT, LI); + +    // Remap to use values from the current iteration instead of the +    // previous one. +    remapInstructionsInBlocks(NewBlocks, VMap); + +    if (DT) { +      // Latches of the cloned loops dominate over the loop exit, so idom of the +      // latter is the first cloned loop body, as original PreHeader dominates +      // the original loop body. +      if (Iter == 0) +        DT->changeImmediateDominator(Exit, cast<BasicBlock>(LVMap[Latch])); +#ifndef NDEBUG +      if (VerifyDomInfo) +        DT->verifyDomTree(); +#endif +    } +      updateBranchWeights(InsertBot, cast<BranchInst>(VMap[LatchBR]), Iter,                          PeelCount, ExitWeight); @@ -369,10 +442,6 @@ bool llvm::peelLoop(Loop *L, unsigned PeelCount, LoopInfo *LI,      F->getBasicBlockList().splice(InsertTop->getIterator(),                                    F->getBasicBlockList(),                                    NewBlocks[0]->getIterator(), F->end()); - -    // Remap to use values from the current iteration instead of the -    // previous one. -    remapInstructionsInBlocks(NewBlocks, VMap);    }    // Now adjust the phi nodes in the loop header to get their initial values @@ -405,9 +474,16 @@ bool llvm::peelLoop(Loop *L, unsigned PeelCount, LoopInfo *LI,    }    // If the loop is nested, we changed the parent loop, update SE. -  if (Loop *ParentLoop = L->getParentLoop()) +  if (Loop *ParentLoop = L->getParentLoop()) {      SE->forgetLoop(ParentLoop); +    // FIXME: Incrementally update loop-simplify +    simplifyLoop(ParentLoop, DT, LI, SE, AC, PreserveLCSSA); +  } else { +    // FIXME: Incrementally update loop-simplify +    simplifyLoop(L, DT, LI, SE, AC, PreserveLCSSA); +  } +    NumPeeled++;    return true; diff --git a/lib/Transforms/Utils/LoopUnrollRuntime.cpp b/lib/Transforms/Utils/LoopUnrollRuntime.cpp index d3ea1564115b..85db734fb182 100644 --- a/lib/Transforms/Utils/LoopUnrollRuntime.cpp +++ b/lib/Transforms/Utils/LoopUnrollRuntime.cpp @@ -146,6 +146,8 @@ static void ConnectProlog(Loop *L, Value *BECount, unsigned Count,    // Add the branch to the exit block (around the unrolled loop)    B.CreateCondBr(BrLoopExit, Exit, NewPreHeader);    InsertPt->eraseFromParent(); +  if (DT) +    DT->changeImmediateDominator(Exit, PrologExit);  }  /// Connect the unrolling epilog code to the original loop. 
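The hunks above all follow the same incremental DominatorTree maintenance pattern instead of a full recalculation: each cloned block is registered with DT->addNewBlock() under an immediate dominator that is already known, and blocks whose idom changes because of new incoming edges are fixed up with DT->changeImmediateDominator(). A minimal sketch of that pattern, with a hypothetical helper name that is not part of this patch, could look like:

#include "llvm/IR/BasicBlock.h"
#include "llvm/IR/Dominators.h"
#include "llvm/IR/Function.h"
#include "llvm/Transforms/Utils/Cloning.h"
#include "llvm/Transforms/Utils/ValueMapper.h"

using namespace llvm;

// Sketch only: clone Src into F and register the copy in the dominator tree.
// Assumes IDom is the correct immediate dominator for the copy; the helper
// name and the ".copy" suffix are illustrative, not taken from this patch.
static BasicBlock *cloneBlockUpdatingDT(BasicBlock *Src, BasicBlock *IDom,
                                        ValueToValueMapTy &VMap, Function &F,
                                        DominatorTree *DT) {
  BasicBlock *NewBB = CloneBasicBlock(Src, VMap, ".copy", &F);
  VMap[Src] = NewBB;
  // Instructions in NewBB still reference the original values until
  // remapInstructionsInBlocks (or equivalent) is run over the clones.
  if (DT)
    DT->addNewBlock(NewBB, IDom); // NewBB's initial idom is IDom.
  return NewBB;
}

If a later CFG edit gives an existing block a new immediate dominator (for example, an exit that becomes reachable from the cloned blocks as well), DT->changeImmediateDominator(ExitBB, NewIDom) repairs it, and a DT->verifyDomTree() call in +Asserts builds catches mistakes early, as the patch does under -unroll-verify-domtree and VerifyDomInfo.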
@@ -260,13 +262,20 @@ static void ConnectEpilog(Loop *L, Value *ModVal, BasicBlock *NewExit,    IRBuilder<> B(InsertPt);    Value *BrLoopExit = B.CreateIsNotNull(ModVal, "lcmp.mod");    assert(Exit && "Loop must have a single exit block only"); -  // Split the exit to maintain loop canonicalization guarantees +  // Split the epilogue exit to maintain loop canonicalization guarantees    SmallVector<BasicBlock*, 4> Preds(predecessors(Exit));    SplitBlockPredecessors(Exit, Preds, ".epilog-lcssa", DT, LI,                           PreserveLCSSA);    // Add the branch to the exit block (around the unrolling loop)    B.CreateCondBr(BrLoopExit, EpilogPreHeader, Exit);    InsertPt->eraseFromParent(); +  if (DT) +    DT->changeImmediateDominator(Exit, NewExit); + +  // Split the main loop exit to maintain canonicalization guarantees. +  SmallVector<BasicBlock*, 4> NewExitPreds{Latch}; +  SplitBlockPredecessors(NewExit, NewExitPreds, ".loopexit", DT, LI, +                         PreserveLCSSA);  }  /// Create a clone of the blocks in a loop and connect them together. @@ -284,27 +293,17 @@ static void CloneLoopBlocks(Loop *L, Value *NewIter,                              BasicBlock *Preheader,                              std::vector<BasicBlock *> &NewBlocks,                              LoopBlocksDFS &LoopBlocks, ValueToValueMapTy &VMap, -                            LoopInfo *LI) { +                            DominatorTree *DT, LoopInfo *LI) {    StringRef suffix = UseEpilogRemainder ? "epil" : "prol";    BasicBlock *Header = L->getHeader();    BasicBlock *Latch = L->getLoopLatch();    Function *F = Header->getParent();    LoopBlocksDFS::RPOIterator BlockBegin = LoopBlocks.beginRPO();    LoopBlocksDFS::RPOIterator BlockEnd = LoopBlocks.endRPO(); -  Loop *NewLoop = nullptr;    Loop *ParentLoop = L->getParentLoop(); -  if (CreateRemainderLoop) { -    NewLoop = new Loop(); -    if (ParentLoop) -      ParentLoop->addChildLoop(NewLoop); -    else -      LI->addTopLevelLoop(NewLoop); -  } -    NewLoopsMap NewLoops; -  if (NewLoop) -    NewLoops[L] = NewLoop; -  else if (ParentLoop) +  NewLoops[ParentLoop] = ParentLoop; +  if (!CreateRemainderLoop)      NewLoops[L] = ParentLoop;    // For each block in the original loop, create a new copy, @@ -312,7 +311,7 @@ static void CloneLoopBlocks(Loop *L, Value *NewIter,    for (LoopBlocksDFS::RPOIterator BB = BlockBegin; BB != BlockEnd; ++BB) {      BasicBlock *NewBB = CloneBasicBlock(*BB, VMap, "." + suffix, F);      NewBlocks.push_back(NewBB); -    +      // If we're unrolling the outermost loop, there's no remainder loop,      // and this block isn't in a nested loop, then the new block is not      // in any loop. Otherwise, add it to loopinfo. @@ -326,6 +325,17 @@ static void CloneLoopBlocks(Loop *L, Value *NewIter,        InsertTop->getTerminator()->setSuccessor(0, NewBB);      } +    if (DT) { +      if (Header == *BB) { +        // The header is dominated by the preheader. +        DT->addNewBlock(NewBB, InsertTop); +      } else { +        // Copy information from original loop to unrolled loop. +        BasicBlock *IDomBB = DT->getNode(*BB)->getIDom()->getBlock(); +        DT->addNewBlock(NewBB, cast<BasicBlock>(VMap[IDomBB])); +      } +    } +      if (Latch == *BB) {        // For the last block, if CreateRemainderLoop is false, create a direct        // jump to InsertBot. If not, create a loop back to cloned head. 
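Both exit splits in ConnectEpilog rely on SplitBlockPredecessors to keep the exit blocks dedicated while updating the dominator tree and LoopInfo in one step. A self-contained sketch of that call, assuming the analyses passed in are the ones being preserved (the helper name and ".split" suffix are illustrative, not from this patch):

#include "llvm/ADT/SmallVector.h"
#include "llvm/Analysis/LoopInfo.h"
#include "llvm/IR/CFG.h"
#include "llvm/IR/Dominators.h"
#include "llvm/Transforms/Utils/BasicBlockUtils.h"

using namespace llvm;

// Sketch only: give Exit a single dedicated predecessor block.
// All current predecessors are redirected to the new block, which then
// branches to Exit; DT and LI are updated by SplitBlockPredecessors itself.
static BasicBlock *splitExitPreds(BasicBlock *Exit, DominatorTree *DT,
                                  LoopInfo *LI, bool PreserveLCSSA) {
  SmallVector<BasicBlock *, 4> Preds(predecessors(Exit));
  return SplitBlockPredecessors(Exit, Preds, ".split", DT, LI, PreserveLCSSA);
}

The suffix is arbitrary; the patch uses ".epilog-lcssa" for the epilogue exit and ".loopexit" for the main loop exit so the resulting block names record which exit was canonicalized.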
@@ -376,7 +386,9 @@ static void CloneLoopBlocks(Loop *L, Value *NewIter,          NewPHI->setIncomingValue(idx, V);      }    } -  if (NewLoop) { +  if (CreateRemainderLoop) { +    Loop *NewLoop = NewLoops[L]; +    assert(NewLoop && "L should have been cloned");      // Add unroll disable metadata to disable future unrolling for this loop.      SmallVector<Metadata *, 4> MDs;      // Reserve first location for self reference to the LoopID metadata node. @@ -599,6 +611,12 @@ bool llvm::UnrollRuntimeLoopRemainder(Loop *L, unsigned Count,    // Branch to either remainder (extra iterations) loop or unrolling loop.    B.CreateCondBr(BranchVal, RemainderLoop, UnrollingLoop);    PreHeaderBR->eraseFromParent(); +  if (DT) { +    if (UseEpilogRemainder) +      DT->changeImmediateDominator(NewExit, PreHeader); +    else +      DT->changeImmediateDominator(PrologExit, PreHeader); +  }    Function *F = Header->getParent();    // Get an ordered list of blocks in the loop to help with the ordering of the    // cloned blocks in the prolog/epilog code @@ -623,7 +641,7 @@ bool llvm::UnrollRuntimeLoopRemainder(Loop *L, unsigned Count,    BasicBlock *InsertBot = UseEpilogRemainder ? Exit : PrologExit;    BasicBlock *InsertTop = UseEpilogRemainder ? EpilogPreHeader : PrologPreHeader;    CloneLoopBlocks(L, ModVal, CreateRemainderLoop, UseEpilogRemainder, InsertTop, -                  InsertBot, NewPreHeader, NewBlocks, LoopBlocks, VMap, LI); +                  InsertBot, NewPreHeader, NewBlocks, LoopBlocks, VMap, DT, LI);    // Insert the cloned blocks into the function.    F->getBasicBlockList().splice(InsertBot->getIterator(), diff --git a/lib/Transforms/Utils/LoopUtils.cpp b/lib/Transforms/Utils/LoopUtils.cpp index c8efa9efc7f3..175d013a011d 100644 --- a/lib/Transforms/Utils/LoopUtils.cpp +++ b/lib/Transforms/Utils/LoopUtils.cpp @@ -230,8 +230,9 @@ bool RecurrenceDescriptor::AddReductionVar(PHINode *Phi, RecurrenceKind Kind,    //      - PHI:    //        - All uses of the PHI must be the reduction (safe).    //        - Otherwise, not safe. -  //  - By one instruction outside of the loop (safe). -  //  - By further instructions outside of the loop (not safe). +  //  - By instructions outside of the loop (safe). +  //      * One value may have several outside users, but all outside +  //        uses must be of the same value.    //  - By an instruction that is not part of the reduction (not safe).    //    This is either:    //      * An instruction type other than PHI or the reduction operation. @@ -297,10 +298,15 @@ bool RecurrenceDescriptor::AddReductionVar(PHINode *Phi, RecurrenceKind Kind,        // Check if we found the exit user.        BasicBlock *Parent = UI->getParent();        if (!TheLoop->contains(Parent)) { -        // Exit if you find multiple outside users or if the header phi node is -        // being used. In this case the user uses the value of the previous -        // iteration, in which case we would loose "VF-1" iterations of the -        // reduction operation if we vectorize. +        // If we already know this instruction is used externally, move on to +        // the next user. +        if (ExitInstruction == Cur) +          continue; + +        // Exit if you find multiple values used outside or if the header phi +        // node is being used. In this case the user uses the value of the +        // previous iteration, in which case we would loose "VF-1" iterations of +        // the reduction operation if we vectorize.          
if (ExitInstruction != nullptr || Cur == Phi)            return false; @@ -547,13 +553,14 @@ bool RecurrenceDescriptor::isFirstOrderRecurrence(PHINode *Phi, Loop *TheLoop,    if (!Previous || !TheLoop->contains(Previous) || isa<PHINode>(Previous))      return false; -  // Ensure every user of the phi node is dominated by the previous value. The -  // dominance requirement ensures the loop vectorizer will not need to +  // Ensure every user of the phi node is dominated by the previous value. +  // The dominance requirement ensures the loop vectorizer will not need to    // vectorize the initial value prior to the first iteration of the loop.    for (User *U : Phi->users()) -    if (auto *I = dyn_cast<Instruction>(U)) +    if (auto *I = dyn_cast<Instruction>(U)) {        if (!DT->dominates(Previous, I))          return false; +    }    return true;  } diff --git a/lib/Transforms/Utils/LowerMemIntrinsics.cpp b/lib/Transforms/Utils/LowerMemIntrinsics.cpp new file mode 100644 index 000000000000..c7cb561b5e21 --- /dev/null +++ b/lib/Transforms/Utils/LowerMemIntrinsics.cpp @@ -0,0 +1,231 @@ +//===- LowerMemIntrinsics.cpp ----------------------------------*- C++ -*--===// +// +//                     The LLVM Compiler Infrastructure +// +// This file is distributed under the University of Illinois Open Source +// License. See LICENSE.TXT for details. +// +//===----------------------------------------------------------------------===// + +#include "llvm/Transforms/Utils/LowerMemIntrinsics.h" +#include "llvm/Transforms/Utils/BasicBlockUtils.h" +#include "llvm/IR/IntrinsicInst.h" +#include "llvm/IR/IRBuilder.h" + +using namespace llvm; + +void llvm::createMemCpyLoop(Instruction *InsertBefore, +                            Value *SrcAddr, Value *DstAddr, Value *CopyLen, +                            unsigned SrcAlign, unsigned DestAlign, +                            bool SrcIsVolatile, bool DstIsVolatile) { +  Type *TypeOfCopyLen = CopyLen->getType(); + +  BasicBlock *OrigBB = InsertBefore->getParent(); +  Function *F = OrigBB->getParent(); +  BasicBlock *NewBB = +    InsertBefore->getParent()->splitBasicBlock(InsertBefore, "split"); +  BasicBlock *LoopBB = BasicBlock::Create(F->getContext(), "loadstoreloop", +                                          F, NewBB); + +  OrigBB->getTerminator()->setSuccessor(0, LoopBB); +  IRBuilder<> Builder(OrigBB->getTerminator()); + +  // SrcAddr and DstAddr are expected to be pointer types, +  // so no check is made here. +  unsigned SrcAS = cast<PointerType>(SrcAddr->getType())->getAddressSpace(); +  unsigned DstAS = cast<PointerType>(DstAddr->getType())->getAddressSpace(); + +  // Cast pointers to (char *) +  SrcAddr = Builder.CreateBitCast(SrcAddr, Builder.getInt8PtrTy(SrcAS)); +  DstAddr = Builder.CreateBitCast(DstAddr, Builder.getInt8PtrTy(DstAS)); + +  IRBuilder<> LoopBuilder(LoopBB); +  PHINode *LoopIndex = LoopBuilder.CreatePHI(TypeOfCopyLen, 0); +  LoopIndex->addIncoming(ConstantInt::get(TypeOfCopyLen, 0), OrigBB); + +  // load from SrcAddr+LoopIndex +  // TODO: we can leverage the align parameter of llvm.memcpy for more efficient +  // word-sized loads and stores. 
+  Value *Element = +    LoopBuilder.CreateLoad(LoopBuilder.CreateInBoundsGEP( +                             LoopBuilder.getInt8Ty(), SrcAddr, LoopIndex), +                           SrcIsVolatile); +  // store at DstAddr+LoopIndex +  LoopBuilder.CreateStore(Element, +                          LoopBuilder.CreateInBoundsGEP(LoopBuilder.getInt8Ty(), +                                                        DstAddr, LoopIndex), +                          DstIsVolatile); + +  // The value for LoopIndex coming from backedge is (LoopIndex + 1) +  Value *NewIndex = +    LoopBuilder.CreateAdd(LoopIndex, ConstantInt::get(TypeOfCopyLen, 1)); +  LoopIndex->addIncoming(NewIndex, LoopBB); + +  LoopBuilder.CreateCondBr(LoopBuilder.CreateICmpULT(NewIndex, CopyLen), LoopBB, +                           NewBB); +} + +// Lower memmove to IR. memmove is required to correctly copy overlapping memory +// regions; therefore, it has to check the relative positions of the source and +// destination pointers and choose the copy direction accordingly. +// +// The code below is an IR rendition of this C function: +// +// void* memmove(void* dst, const void* src, size_t n) { +//   unsigned char* d = dst; +//   const unsigned char* s = src; +//   if (s < d) { +//     // copy backwards +//     while (n--) { +//       d[n] = s[n]; +//     } +//   } else { +//     // copy forward +//     for (size_t i = 0; i < n; ++i) { +//       d[i] = s[i]; +//     } +//   } +//   return dst; +// } +static void createMemMoveLoop(Instruction *InsertBefore, +                              Value *SrcAddr, Value *DstAddr, Value *CopyLen, +                              unsigned SrcAlign, unsigned DestAlign, +                              bool SrcIsVolatile, bool DstIsVolatile) { +  Type *TypeOfCopyLen = CopyLen->getType(); +  BasicBlock *OrigBB = InsertBefore->getParent(); +  Function *F = OrigBB->getParent(); + +  // Create the a comparison of src and dst, based on which we jump to either +  // the forward-copy part of the function (if src >= dst) or the backwards-copy +  // part (if src < dst). +  // SplitBlockAndInsertIfThenElse conveniently creates the basic if-then-else +  // structure. Its block terminators (unconditional branches) are replaced by +  // the appropriate conditional branches when the loop is built. +  ICmpInst *PtrCompare = new ICmpInst(InsertBefore, ICmpInst::ICMP_ULT, +                                      SrcAddr, DstAddr, "compare_src_dst"); +  TerminatorInst *ThenTerm, *ElseTerm; +  SplitBlockAndInsertIfThenElse(PtrCompare, InsertBefore, &ThenTerm, +                                &ElseTerm); + +  // Each part of the function consists of two blocks: +  //   copy_backwards:        used to skip the loop when n == 0 +  //   copy_backwards_loop:   the actual backwards loop BB +  //   copy_forward:          used to skip the loop when n == 0 +  //   copy_forward_loop:     the actual forward loop BB +  BasicBlock *CopyBackwardsBB = ThenTerm->getParent(); +  CopyBackwardsBB->setName("copy_backwards"); +  BasicBlock *CopyForwardBB = ElseTerm->getParent(); +  CopyForwardBB->setName("copy_forward"); +  BasicBlock *ExitBB = InsertBefore->getParent(); +  ExitBB->setName("memmove_done"); + +  // Initial comparison of n == 0 that lets us skip the loops altogether. Shared +  // between both backwards and forward copy clauses. +  ICmpInst *CompareN = +      new ICmpInst(OrigBB->getTerminator(), ICmpInst::ICMP_EQ, CopyLen, +                   ConstantInt::get(TypeOfCopyLen, 0), "compare_n_to_0"); + +  // Copying backwards. 
+  BasicBlock *LoopBB = +    BasicBlock::Create(F->getContext(), "copy_backwards_loop", F, CopyForwardBB); +  IRBuilder<> LoopBuilder(LoopBB); +  PHINode *LoopPhi = LoopBuilder.CreatePHI(TypeOfCopyLen, 0); +  Value *IndexPtr = LoopBuilder.CreateSub( +      LoopPhi, ConstantInt::get(TypeOfCopyLen, 1), "index_ptr"); +  Value *Element = LoopBuilder.CreateLoad( +      LoopBuilder.CreateInBoundsGEP(SrcAddr, IndexPtr), "element"); +  LoopBuilder.CreateStore(Element, +                          LoopBuilder.CreateInBoundsGEP(DstAddr, IndexPtr)); +  LoopBuilder.CreateCondBr( +      LoopBuilder.CreateICmpEQ(IndexPtr, ConstantInt::get(TypeOfCopyLen, 0)), +      ExitBB, LoopBB); +  LoopPhi->addIncoming(IndexPtr, LoopBB); +  LoopPhi->addIncoming(CopyLen, CopyBackwardsBB); +  BranchInst::Create(ExitBB, LoopBB, CompareN, ThenTerm); +  ThenTerm->eraseFromParent(); + +  // Copying forward. +  BasicBlock *FwdLoopBB = +    BasicBlock::Create(F->getContext(), "copy_forward_loop", F, ExitBB); +  IRBuilder<> FwdLoopBuilder(FwdLoopBB); +  PHINode *FwdCopyPhi = FwdLoopBuilder.CreatePHI(TypeOfCopyLen, 0, "index_ptr"); +  Value *FwdElement = FwdLoopBuilder.CreateLoad( +      FwdLoopBuilder.CreateInBoundsGEP(SrcAddr, FwdCopyPhi), "element"); +  FwdLoopBuilder.CreateStore( +      FwdElement, FwdLoopBuilder.CreateInBoundsGEP(DstAddr, FwdCopyPhi)); +  Value *FwdIndexPtr = FwdLoopBuilder.CreateAdd( +      FwdCopyPhi, ConstantInt::get(TypeOfCopyLen, 1), "index_increment"); +  FwdLoopBuilder.CreateCondBr(FwdLoopBuilder.CreateICmpEQ(FwdIndexPtr, CopyLen), +                              ExitBB, FwdLoopBB); +  FwdCopyPhi->addIncoming(FwdIndexPtr, FwdLoopBB); +  FwdCopyPhi->addIncoming(ConstantInt::get(TypeOfCopyLen, 0), CopyForwardBB); + +  BranchInst::Create(ExitBB, FwdLoopBB, CompareN, ElseTerm); +  ElseTerm->eraseFromParent(); +} + +static void createMemSetLoop(Instruction *InsertBefore, +                             Value *DstAddr, Value *CopyLen, Value *SetValue, +                             unsigned Align, bool IsVolatile) { +  BasicBlock *OrigBB = InsertBefore->getParent(); +  Function *F = OrigBB->getParent(); +  BasicBlock *NewBB = +      OrigBB->splitBasicBlock(InsertBefore, "split"); +  BasicBlock *LoopBB +    = BasicBlock::Create(F->getContext(), "loadstoreloop", F, NewBB); + +  OrigBB->getTerminator()->setSuccessor(0, LoopBB); +  IRBuilder<> Builder(OrigBB->getTerminator()); + +  // Cast pointer to the type of value getting stored +  unsigned dstAS = cast<PointerType>(DstAddr->getType())->getAddressSpace(); +  DstAddr = Builder.CreateBitCast(DstAddr, +                                  PointerType::get(SetValue->getType(), dstAS)); + +  IRBuilder<> LoopBuilder(LoopBB); +  PHINode *LoopIndex = LoopBuilder.CreatePHI(CopyLen->getType(), 0); +  LoopIndex->addIncoming(ConstantInt::get(CopyLen->getType(), 0), OrigBB); + +  LoopBuilder.CreateStore( +      SetValue, +      LoopBuilder.CreateInBoundsGEP(SetValue->getType(), DstAddr, LoopIndex), +      IsVolatile); + +  Value *NewIndex = +      LoopBuilder.CreateAdd(LoopIndex, ConstantInt::get(CopyLen->getType(), 1)); +  LoopIndex->addIncoming(NewIndex, LoopBB); + +  LoopBuilder.CreateCondBr(LoopBuilder.CreateICmpULT(NewIndex, CopyLen), LoopBB, +                           NewBB); +} + +void llvm::expandMemCpyAsLoop(MemCpyInst *Memcpy) { +  createMemCpyLoop(/* InsertBefore */ Memcpy, +                   /* SrcAddr */ Memcpy->getRawSource(), +                   /* DstAddr */ Memcpy->getRawDest(), +                   /* CopyLen */ Memcpy->getLength(), +                   /* 
SrcAlign */ Memcpy->getAlignment(), +                   /* DestAlign */ Memcpy->getAlignment(), +                   /* SrcIsVolatile */ Memcpy->isVolatile(), +                   /* DstIsVolatile */ Memcpy->isVolatile()); +} + +void llvm::expandMemMoveAsLoop(MemMoveInst *Memmove) { +  createMemMoveLoop(/* InsertBefore */ Memmove, +                    /* SrcAddr */ Memmove->getRawSource(), +                    /* DstAddr */ Memmove->getRawDest(), +                    /* CopyLen */ Memmove->getLength(), +                    /* SrcAlign */ Memmove->getAlignment(), +                    /* DestAlign */ Memmove->getAlignment(), +                    /* SrcIsVolatile */ Memmove->isVolatile(), +                    /* DstIsVolatile */ Memmove->isVolatile()); +} + +void llvm::expandMemSetAsLoop(MemSetInst *Memset) { +  createMemSetLoop(/* InsertBefore */ Memset, +                   /* DstAddr */ Memset->getRawDest(), +                   /* CopyLen */ Memset->getLength(), +                   /* SetValue */ Memset->getValue(), +                   /* Alignment */ Memset->getAlignment(), +                   Memset->isVolatile()); +} diff --git a/lib/Transforms/Utils/LowerSwitch.cpp b/lib/Transforms/Utils/LowerSwitch.cpp index 75cd3bc8b2bf..b375d51005d5 100644 --- a/lib/Transforms/Utils/LowerSwitch.cpp +++ b/lib/Transforms/Utils/LowerSwitch.cpp @@ -356,10 +356,10 @@ unsigned LowerSwitch::Clusterify(CaseVector& Cases, SwitchInst *SI) {    unsigned numCmps = 0;    // Start with "simple" cases -  for (SwitchInst::CaseIt i = SI->case_begin(), e = SI->case_end(); i != e; ++i) -    Cases.push_back(CaseRange(i.getCaseValue(), i.getCaseValue(), -                              i.getCaseSuccessor())); -   +  for (auto Case : SI->cases()) +    Cases.push_back(CaseRange(Case.getCaseValue(), Case.getCaseValue(), +                              Case.getCaseSuccessor())); +    std::sort(Cases.begin(), Cases.end(), CaseCmp());    // Merge case into clusters diff --git a/lib/Transforms/Utils/Mem2Reg.cpp b/lib/Transforms/Utils/Mem2Reg.cpp index 24b3b12930ac..b659a2e4463f 100644 --- a/lib/Transforms/Utils/Mem2Reg.cpp +++ b/lib/Transforms/Utils/Mem2Reg.cpp @@ -46,7 +46,7 @@ static bool promoteMemoryToRegister(Function &F, DominatorTree &DT,      if (Allocas.empty())        break; -    PromoteMemToReg(Allocas, DT, nullptr, &AC); +    PromoteMemToReg(Allocas, DT, &AC);      NumPromoted += Allocas.size();      Changed = true;    } @@ -59,8 +59,9 @@ PreservedAnalyses PromotePass::run(Function &F, FunctionAnalysisManager &AM) {    if (!promoteMemoryToRegister(F, DT, AC))      return PreservedAnalyses::all(); -  // FIXME: This should also 'preserve the CFG'. -  return PreservedAnalyses::none(); +  PreservedAnalyses PA; +  PA.preserveSet<CFGAnalyses>(); +  return PA;  }  namespace { diff --git a/lib/Transforms/Utils/MemorySSA.cpp b/lib/Transforms/Utils/MemorySSA.cpp deleted file mode 100644 index 1ce4225f09cc..000000000000 --- a/lib/Transforms/Utils/MemorySSA.cpp +++ /dev/null @@ -1,2305 +0,0 @@ -//===-- MemorySSA.cpp - Memory SSA Builder---------------------------===// -// -//                     The LLVM Compiler Infrastructure -// -// This file is distributed under the University of Illinois Open Source -// License. See LICENSE.TXT for details. -// -//===----------------------------------------------------------------===// -// -// This file implements the MemorySSA class. 
-// -//===----------------------------------------------------------------===// -#include "llvm/Transforms/Utils/MemorySSA.h" -#include "llvm/ADT/DenseMap.h" -#include "llvm/ADT/DenseSet.h" -#include "llvm/ADT/DepthFirstIterator.h" -#include "llvm/ADT/GraphTraits.h" -#include "llvm/ADT/PostOrderIterator.h" -#include "llvm/ADT/STLExtras.h" -#include "llvm/ADT/SmallBitVector.h" -#include "llvm/ADT/SmallPtrSet.h" -#include "llvm/ADT/SmallSet.h" -#include "llvm/ADT/Statistic.h" -#include "llvm/Analysis/AliasAnalysis.h" -#include "llvm/Analysis/CFG.h" -#include "llvm/Analysis/GlobalsModRef.h" -#include "llvm/Analysis/IteratedDominanceFrontier.h" -#include "llvm/Analysis/MemoryLocation.h" -#include "llvm/Analysis/PHITransAddr.h" -#include "llvm/IR/AssemblyAnnotationWriter.h" -#include "llvm/IR/DataLayout.h" -#include "llvm/IR/Dominators.h" -#include "llvm/IR/GlobalVariable.h" -#include "llvm/IR/IRBuilder.h" -#include "llvm/IR/IntrinsicInst.h" -#include "llvm/IR/LLVMContext.h" -#include "llvm/IR/Metadata.h" -#include "llvm/IR/Module.h" -#include "llvm/IR/PatternMatch.h" -#include "llvm/Support/Debug.h" -#include "llvm/Support/FormattedStream.h" -#include "llvm/Transforms/Scalar.h" -#include <algorithm> - -#define DEBUG_TYPE "memoryssa" -using namespace llvm; -STATISTIC(NumClobberCacheLookups, "Number of Memory SSA version cache lookups"); -STATISTIC(NumClobberCacheHits, "Number of Memory SSA version cache hits"); -STATISTIC(NumClobberCacheInserts, "Number of MemorySSA version cache inserts"); - -INITIALIZE_PASS_BEGIN(MemorySSAWrapperPass, "memoryssa", "Memory SSA", false, -                      true) -INITIALIZE_PASS_DEPENDENCY(DominatorTreeWrapperPass) -INITIALIZE_PASS_DEPENDENCY(AAResultsWrapperPass) -INITIALIZE_PASS_END(MemorySSAWrapperPass, "memoryssa", "Memory SSA", false, -                    true) - -INITIALIZE_PASS_BEGIN(MemorySSAPrinterLegacyPass, "print-memoryssa", -                      "Memory SSA Printer", false, false) -INITIALIZE_PASS_DEPENDENCY(MemorySSAWrapperPass) -INITIALIZE_PASS_END(MemorySSAPrinterLegacyPass, "print-memoryssa", -                    "Memory SSA Printer", false, false) - -static cl::opt<unsigned> MaxCheckLimit( -    "memssa-check-limit", cl::Hidden, cl::init(100), -    cl::desc("The maximum number of stores/phis MemorySSA" -             "will consider trying to walk past (default = 100)")); - -static cl::opt<bool> -    VerifyMemorySSA("verify-memoryssa", cl::init(false), cl::Hidden, -                    cl::desc("Verify MemorySSA in legacy printer pass.")); - -namespace llvm { -/// \brief An assembly annotator class to print Memory SSA information in -/// comments. -class MemorySSAAnnotatedWriter : public AssemblyAnnotationWriter { -  friend class MemorySSA; -  const MemorySSA *MSSA; - -public: -  MemorySSAAnnotatedWriter(const MemorySSA *M) : MSSA(M) {} - -  virtual void emitBasicBlockStartAnnot(const BasicBlock *BB, -                                        formatted_raw_ostream &OS) { -    if (MemoryAccess *MA = MSSA->getMemoryAccess(BB)) -      OS << "; " << *MA << "\n"; -  } - -  virtual void emitInstructionAnnot(const Instruction *I, -                                    formatted_raw_ostream &OS) { -    if (MemoryAccess *MA = MSSA->getMemoryAccess(I)) -      OS << "; " << *MA << "\n"; -  } -}; -} - -namespace { -/// Our current alias analysis API differentiates heavily between calls and -/// non-calls, and functions called on one usually assert on the other. 
-/// This class encapsulates the distinction to simplify other code that wants -/// "Memory affecting instructions and related data" to use as a key. -/// For example, this class is used as a densemap key in the use optimizer. -class MemoryLocOrCall { -public: -  MemoryLocOrCall() : IsCall(false) {} -  MemoryLocOrCall(MemoryUseOrDef *MUD) -      : MemoryLocOrCall(MUD->getMemoryInst()) {} -  MemoryLocOrCall(const MemoryUseOrDef *MUD) -      : MemoryLocOrCall(MUD->getMemoryInst()) {} - -  MemoryLocOrCall(Instruction *Inst) { -    if (ImmutableCallSite(Inst)) { -      IsCall = true; -      CS = ImmutableCallSite(Inst); -    } else { -      IsCall = false; -      // There is no such thing as a memorylocation for a fence inst, and it is -      // unique in that regard. -      if (!isa<FenceInst>(Inst)) -        Loc = MemoryLocation::get(Inst); -    } -  } - -  explicit MemoryLocOrCall(const MemoryLocation &Loc) -      : IsCall(false), Loc(Loc) {} - -  bool IsCall; -  ImmutableCallSite getCS() const { -    assert(IsCall); -    return CS; -  } -  MemoryLocation getLoc() const { -    assert(!IsCall); -    return Loc; -  } - -  bool operator==(const MemoryLocOrCall &Other) const { -    if (IsCall != Other.IsCall) -      return false; - -    if (IsCall) -      return CS.getCalledValue() == Other.CS.getCalledValue(); -    return Loc == Other.Loc; -  } - -private: -  union { -      ImmutableCallSite CS; -      MemoryLocation Loc; -  }; -}; -} - -namespace llvm { -template <> struct DenseMapInfo<MemoryLocOrCall> { -  static inline MemoryLocOrCall getEmptyKey() { -    return MemoryLocOrCall(DenseMapInfo<MemoryLocation>::getEmptyKey()); -  } -  static inline MemoryLocOrCall getTombstoneKey() { -    return MemoryLocOrCall(DenseMapInfo<MemoryLocation>::getTombstoneKey()); -  } -  static unsigned getHashValue(const MemoryLocOrCall &MLOC) { -    if (MLOC.IsCall) -      return hash_combine(MLOC.IsCall, -                          DenseMapInfo<const Value *>::getHashValue( -                              MLOC.getCS().getCalledValue())); -    return hash_combine( -        MLOC.IsCall, DenseMapInfo<MemoryLocation>::getHashValue(MLOC.getLoc())); -  } -  static bool isEqual(const MemoryLocOrCall &LHS, const MemoryLocOrCall &RHS) { -    return LHS == RHS; -  } -}; - -enum class Reorderability { Always, IfNoAlias, Never }; - -/// This does one-way checks to see if Use could theoretically be hoisted above -/// MayClobber. This will not check the other way around. -/// -/// This assumes that, for the purposes of MemorySSA, Use comes directly after -/// MayClobber, with no potentially clobbering operations in between them. -/// (Where potentially clobbering ops are memory barriers, aliased stores, etc.) -static Reorderability getLoadReorderability(const LoadInst *Use, -                                            const LoadInst *MayClobber) { -  bool VolatileUse = Use->isVolatile(); -  bool VolatileClobber = MayClobber->isVolatile(); -  // Volatile operations may never be reordered with other volatile operations. -  if (VolatileUse && VolatileClobber) -    return Reorderability::Never; - -  // The lang ref allows reordering of volatile and non-volatile operations. -  // Whether an aliasing nonvolatile load and volatile load can be reordered, -  // though, is ambiguous. Because it may not be best to exploit this ambiguity, -  // we only allow volatile/non-volatile reordering if the volatile and -  // non-volatile operations don't alias. 
-  Reorderability Result = VolatileUse || VolatileClobber -                              ? Reorderability::IfNoAlias -                              : Reorderability::Always; - -  // If a load is seq_cst, it cannot be moved above other loads. If its ordering -  // is weaker, it can be moved above other loads. We just need to be sure that -  // MayClobber isn't an acquire load, because loads can't be moved above -  // acquire loads. -  // -  // Note that this explicitly *does* allow the free reordering of monotonic (or -  // weaker) loads of the same address. -  bool SeqCstUse = Use->getOrdering() == AtomicOrdering::SequentiallyConsistent; -  bool MayClobberIsAcquire = isAtLeastOrStrongerThan(MayClobber->getOrdering(), -                                                     AtomicOrdering::Acquire); -  if (SeqCstUse || MayClobberIsAcquire) -    return Reorderability::Never; -  return Result; -} - -static bool instructionClobbersQuery(MemoryDef *MD, -                                     const MemoryLocation &UseLoc, -                                     const Instruction *UseInst, -                                     AliasAnalysis &AA) { -  Instruction *DefInst = MD->getMemoryInst(); -  assert(DefInst && "Defining instruction not actually an instruction"); - -  if (const IntrinsicInst *II = dyn_cast<IntrinsicInst>(DefInst)) { -    // These intrinsics will show up as affecting memory, but they are just -    // markers. -    switch (II->getIntrinsicID()) { -    case Intrinsic::lifetime_start: -    case Intrinsic::lifetime_end: -    case Intrinsic::invariant_start: -    case Intrinsic::invariant_end: -    case Intrinsic::assume: -      return false; -    default: -      break; -    } -  } - -  ImmutableCallSite UseCS(UseInst); -  if (UseCS) { -    ModRefInfo I = AA.getModRefInfo(DefInst, UseCS); -    return I != MRI_NoModRef; -  } - -  if (auto *DefLoad = dyn_cast<LoadInst>(DefInst)) { -    if (auto *UseLoad = dyn_cast<LoadInst>(UseInst)) { -      switch (getLoadReorderability(UseLoad, DefLoad)) { -      case Reorderability::Always: -        return false; -      case Reorderability::Never: -        return true; -      case Reorderability::IfNoAlias: -        return !AA.isNoAlias(UseLoc, MemoryLocation::get(DefLoad)); -      } -    } -  } - -  return AA.getModRefInfo(DefInst, UseLoc) & MRI_Mod; -} - -static bool instructionClobbersQuery(MemoryDef *MD, const MemoryUseOrDef *MU, -                                     const MemoryLocOrCall &UseMLOC, -                                     AliasAnalysis &AA) { -  // FIXME: This is a temporary hack to allow a single instructionClobbersQuery -  // to exist while MemoryLocOrCall is pushed through places. -  if (UseMLOC.IsCall) -    return instructionClobbersQuery(MD, MemoryLocation(), MU->getMemoryInst(), -                                    AA); -  return instructionClobbersQuery(MD, UseMLOC.getLoc(), MU->getMemoryInst(), -                                  AA); -} - -// Return true when MD may alias MU, return false otherwise. -bool defClobbersUseOrDef(MemoryDef *MD, const MemoryUseOrDef *MU, -                         AliasAnalysis &AA) { -  return instructionClobbersQuery(MD, MU, MemoryLocOrCall(MU), AA); -} -} - -namespace { -struct UpwardsMemoryQuery { -  // True if our original query started off as a call -  bool IsCall; -  // The pointer location we started the query with. This will be empty if -  // IsCall is true. -  MemoryLocation StartingLoc; -  // This is the instruction we were querying about. 
-  const Instruction *Inst; -  // The MemoryAccess we actually got called with, used to test local domination -  const MemoryAccess *OriginalAccess; - -  UpwardsMemoryQuery() -      : IsCall(false), Inst(nullptr), OriginalAccess(nullptr) {} - -  UpwardsMemoryQuery(const Instruction *Inst, const MemoryAccess *Access) -      : IsCall(ImmutableCallSite(Inst)), Inst(Inst), OriginalAccess(Access) { -    if (!IsCall) -      StartingLoc = MemoryLocation::get(Inst); -  } -}; - -static bool lifetimeEndsAt(MemoryDef *MD, const MemoryLocation &Loc, -                           AliasAnalysis &AA) { -  Instruction *Inst = MD->getMemoryInst(); -  if (IntrinsicInst *II = dyn_cast<IntrinsicInst>(Inst)) { -    switch (II->getIntrinsicID()) { -    case Intrinsic::lifetime_start: -    case Intrinsic::lifetime_end: -      return AA.isMustAlias(MemoryLocation(II->getArgOperand(1)), Loc); -    default: -      return false; -    } -  } -  return false; -} - -static bool isUseTriviallyOptimizableToLiveOnEntry(AliasAnalysis &AA, -                                                   const Instruction *I) { -  // If the memory can't be changed, then loads of the memory can't be -  // clobbered. -  // -  // FIXME: We should handle invariant groups, as well. It's a bit harder, -  // because we need to pay close attention to invariant group barriers. -  return isa<LoadInst>(I) && (I->getMetadata(LLVMContext::MD_invariant_load) || -                              AA.pointsToConstantMemory(I)); -} - -/// Cache for our caching MemorySSA walker. -class WalkerCache { -  DenseMap<ConstMemoryAccessPair, MemoryAccess *> Accesses; -  DenseMap<const MemoryAccess *, MemoryAccess *> Calls; - -public: -  MemoryAccess *lookup(const MemoryAccess *MA, const MemoryLocation &Loc, -                       bool IsCall) const { -    ++NumClobberCacheLookups; -    MemoryAccess *R = IsCall ? Calls.lookup(MA) : Accesses.lookup({MA, Loc}); -    if (R) -      ++NumClobberCacheHits; -    return R; -  } - -  bool insert(const MemoryAccess *MA, MemoryAccess *To, -              const MemoryLocation &Loc, bool IsCall) { -    // This is fine for Phis, since there are times where we can't optimize -    // them.  Making a def its own clobber is never correct, though. -    assert((MA != To || isa<MemoryPhi>(MA)) && -           "Something can't clobber itself!"); - -    ++NumClobberCacheInserts; -    bool Inserted; -    if (IsCall) -      Inserted = Calls.insert({MA, To}).second; -    else -      Inserted = Accesses.insert({{MA, Loc}, To}).second; - -    return Inserted; -  } - -  bool remove(const MemoryAccess *MA, const MemoryLocation &Loc, bool IsCall) { -    return IsCall ? Calls.erase(MA) : Accesses.erase({MA, Loc}); -  } - -  void clear() { -    Accesses.clear(); -    Calls.clear(); -  } - -  bool contains(const MemoryAccess *MA) const { -    for (auto &P : Accesses) -      if (P.first.first == MA || P.second == MA) -        return true; -    for (auto &P : Calls) -      if (P.first == MA || P.second == MA) -        return true; -    return false; -  } -}; - -/// Walks the defining uses of MemoryDefs. Stops after we hit something that has -/// no defining use (e.g. a MemoryPhi or liveOnEntry). Note that, when comparing -/// against a null def_chain_iterator, this will compare equal only after -/// walking said Phi/liveOnEntry. 
-struct def_chain_iterator -    : public iterator_facade_base<def_chain_iterator, std::forward_iterator_tag, -                                  MemoryAccess *> { -  def_chain_iterator() : MA(nullptr) {} -  def_chain_iterator(MemoryAccess *MA) : MA(MA) {} - -  MemoryAccess *operator*() const { return MA; } - -  def_chain_iterator &operator++() { -    // N.B. liveOnEntry has a null defining access. -    if (auto *MUD = dyn_cast<MemoryUseOrDef>(MA)) -      MA = MUD->getDefiningAccess(); -    else -      MA = nullptr; -    return *this; -  } - -  bool operator==(const def_chain_iterator &O) const { return MA == O.MA; } - -private: -  MemoryAccess *MA; -}; - -static iterator_range<def_chain_iterator> -def_chain(MemoryAccess *MA, MemoryAccess *UpTo = nullptr) { -#ifdef EXPENSIVE_CHECKS -  assert((!UpTo || find(def_chain(MA), UpTo) != def_chain_iterator()) && -         "UpTo isn't in the def chain!"); -#endif -  return make_range(def_chain_iterator(MA), def_chain_iterator(UpTo)); -} - -/// Verifies that `Start` is clobbered by `ClobberAt`, and that nothing -/// inbetween `Start` and `ClobberAt` can clobbers `Start`. -/// -/// This is meant to be as simple and self-contained as possible. Because it -/// uses no cache, etc., it can be relatively expensive. -/// -/// \param Start     The MemoryAccess that we want to walk from. -/// \param ClobberAt A clobber for Start. -/// \param StartLoc  The MemoryLocation for Start. -/// \param MSSA      The MemorySSA isntance that Start and ClobberAt belong to. -/// \param Query     The UpwardsMemoryQuery we used for our search. -/// \param AA        The AliasAnalysis we used for our search. -static void LLVM_ATTRIBUTE_UNUSED -checkClobberSanity(MemoryAccess *Start, MemoryAccess *ClobberAt, -                   const MemoryLocation &StartLoc, const MemorySSA &MSSA, -                   const UpwardsMemoryQuery &Query, AliasAnalysis &AA) { -  assert(MSSA.dominates(ClobberAt, Start) && "Clobber doesn't dominate start?"); - -  if (MSSA.isLiveOnEntryDef(Start)) { -    assert(MSSA.isLiveOnEntryDef(ClobberAt) && -           "liveOnEntry must clobber itself"); -    return; -  } - -  bool FoundClobber = false; -  DenseSet<MemoryAccessPair> VisitedPhis; -  SmallVector<MemoryAccessPair, 8> Worklist; -  Worklist.emplace_back(Start, StartLoc); -  // Walk all paths from Start to ClobberAt, while looking for clobbers. If one -  // is found, complain. -  while (!Worklist.empty()) { -    MemoryAccessPair MAP = Worklist.pop_back_val(); -    // All we care about is that nothing from Start to ClobberAt clobbers Start. -    // We learn nothing from revisiting nodes. -    if (!VisitedPhis.insert(MAP).second) -      continue; - -    for (MemoryAccess *MA : def_chain(MAP.first)) { -      if (MA == ClobberAt) { -        if (auto *MD = dyn_cast<MemoryDef>(MA)) { -          // instructionClobbersQuery isn't essentially free, so don't use `|=`, -          // since it won't let us short-circuit. -          // -          // Also, note that this can't be hoisted out of the `Worklist` loop, -          // since MD may only act as a clobber for 1 of N MemoryLocations. -          FoundClobber = -              FoundClobber || MSSA.isLiveOnEntryDef(MD) || -              instructionClobbersQuery(MD, MAP.second, Query.Inst, AA); -        } -        break; -      } - -      // We should never hit liveOnEntry, unless it's the clobber. 
-      assert(!MSSA.isLiveOnEntryDef(MA) && "Hit liveOnEntry before clobber?"); - -      if (auto *MD = dyn_cast<MemoryDef>(MA)) { -        (void)MD; -        assert(!instructionClobbersQuery(MD, MAP.second, Query.Inst, AA) && -               "Found clobber before reaching ClobberAt!"); -        continue; -      } - -      assert(isa<MemoryPhi>(MA)); -      Worklist.append(upward_defs_begin({MA, MAP.second}), upward_defs_end()); -    } -  } - -  // If ClobberAt is a MemoryPhi, we can assume something above it acted as a -  // clobber. Otherwise, `ClobberAt` should've acted as a clobber at some point. -  assert((isa<MemoryPhi>(ClobberAt) || FoundClobber) && -         "ClobberAt never acted as a clobber"); -} - -/// Our algorithm for walking (and trying to optimize) clobbers, all wrapped up -/// in one class. -class ClobberWalker { -  /// Save a few bytes by using unsigned instead of size_t. -  using ListIndex = unsigned; - -  /// Represents a span of contiguous MemoryDefs, potentially ending in a -  /// MemoryPhi. -  struct DefPath { -    MemoryLocation Loc; -    // Note that, because we always walk in reverse, Last will always dominate -    // First. Also note that First and Last are inclusive. -    MemoryAccess *First; -    MemoryAccess *Last; -    Optional<ListIndex> Previous; - -    DefPath(const MemoryLocation &Loc, MemoryAccess *First, MemoryAccess *Last, -            Optional<ListIndex> Previous) -        : Loc(Loc), First(First), Last(Last), Previous(Previous) {} - -    DefPath(const MemoryLocation &Loc, MemoryAccess *Init, -            Optional<ListIndex> Previous) -        : DefPath(Loc, Init, Init, Previous) {} -  }; - -  const MemorySSA &MSSA; -  AliasAnalysis &AA; -  DominatorTree &DT; -  WalkerCache &WC; -  UpwardsMemoryQuery *Query; -  bool UseCache; - -  // Phi optimization bookkeeping -  SmallVector<DefPath, 32> Paths; -  DenseSet<ConstMemoryAccessPair> VisitedPhis; -  DenseMap<const BasicBlock *, MemoryAccess *> WalkTargetCache; - -  void setUseCache(bool Use) { UseCache = Use; } -  bool shouldIgnoreCache() const { -    // UseCache will only be false when we're debugging, or when expensive -    // checks are enabled. In either case, we don't care deeply about speed. -    return LLVM_UNLIKELY(!UseCache); -  } - -  void addCacheEntry(const MemoryAccess *What, MemoryAccess *To, -                     const MemoryLocation &Loc) const { -// EXPENSIVE_CHECKS because most of these queries are redundant. -#ifdef EXPENSIVE_CHECKS -    assert(MSSA.dominates(To, What)); -#endif -    if (shouldIgnoreCache()) -      return; -    WC.insert(What, To, Loc, Query->IsCall); -  } - -  MemoryAccess *lookupCache(const MemoryAccess *MA, const MemoryLocation &Loc) { -    return shouldIgnoreCache() ? nullptr : WC.lookup(MA, Loc, Query->IsCall); -  } - -  void cacheDefPath(const DefPath &DN, MemoryAccess *Target) const { -    if (shouldIgnoreCache()) -      return; - -    for (MemoryAccess *MA : def_chain(DN.First, DN.Last)) -      addCacheEntry(MA, Target, DN.Loc); - -    // DefPaths only express the path we walked. So, DN.Last could either be a -    // thing we want to cache, or not. -    if (DN.Last != Target) -      addCacheEntry(DN.Last, Target, DN.Loc); -  } - -  /// Find the nearest def or phi that `From` can legally be optimized to. -  /// -  /// FIXME: Deduplicate this with MSSA::findDominatingDef. Ideally, MSSA should -  /// keep track of this information for us, and allow us O(1) lookups of this -  /// info. 
-  MemoryAccess *getWalkTarget(const MemoryPhi *From) { -    assert(From->getNumOperands() && "Phi with no operands?"); - -    BasicBlock *BB = From->getBlock(); -    auto At = WalkTargetCache.find(BB); -    if (At != WalkTargetCache.end()) -      return At->second; - -    SmallVector<const BasicBlock *, 8> ToCache; -    ToCache.push_back(BB); - -    MemoryAccess *Result = MSSA.getLiveOnEntryDef(); -    DomTreeNode *Node = DT.getNode(BB); -    while ((Node = Node->getIDom())) { -      auto At = WalkTargetCache.find(BB); -      if (At != WalkTargetCache.end()) { -        Result = At->second; -        break; -      } - -      auto *Accesses = MSSA.getBlockAccesses(Node->getBlock()); -      if (Accesses) { -        auto Iter = find_if(reverse(*Accesses), [](const MemoryAccess &MA) { -          return !isa<MemoryUse>(MA); -        }); -        if (Iter != Accesses->rend()) { -          Result = const_cast<MemoryAccess *>(&*Iter); -          break; -        } -      } - -      ToCache.push_back(Node->getBlock()); -    } - -    for (const BasicBlock *BB : ToCache) -      WalkTargetCache.insert({BB, Result}); -    return Result; -  } - -  /// Result of calling walkToPhiOrClobber. -  struct UpwardsWalkResult { -    /// The "Result" of the walk. Either a clobber, the last thing we walked, or -    /// both. -    MemoryAccess *Result; -    bool IsKnownClobber; -    bool FromCache; -  }; - -  /// Walk to the next Phi or Clobber in the def chain starting at Desc.Last. -  /// This will update Desc.Last as it walks. It will (optionally) also stop at -  /// StopAt. -  /// -  /// This does not test for whether StopAt is a clobber -  UpwardsWalkResult walkToPhiOrClobber(DefPath &Desc, -                                       MemoryAccess *StopAt = nullptr) { -    assert(!isa<MemoryUse>(Desc.Last) && "Uses don't exist in my world"); - -    for (MemoryAccess *Current : def_chain(Desc.Last)) { -      Desc.Last = Current; -      if (Current == StopAt) -        return {Current, false, false}; - -      if (auto *MD = dyn_cast<MemoryDef>(Current)) -        if (MSSA.isLiveOnEntryDef(MD) || -            instructionClobbersQuery(MD, Desc.Loc, Query->Inst, AA)) -          return {MD, true, false}; - -      // Cache checks must be done last, because if Current is a clobber, the -      // cache will contain the clobber for Current. -      if (MemoryAccess *MA = lookupCache(Current, Desc.Loc)) -        return {MA, true, true}; -    } - -    assert(isa<MemoryPhi>(Desc.Last) && -           "Ended at a non-clobber that's not a phi?"); -    return {Desc.Last, false, false}; -  } - -  void addSearches(MemoryPhi *Phi, SmallVectorImpl<ListIndex> &PausedSearches, -                   ListIndex PriorNode) { -    auto UpwardDefs = make_range(upward_defs_begin({Phi, Paths[PriorNode].Loc}), -                                 upward_defs_end()); -    for (const MemoryAccessPair &P : UpwardDefs) { -      PausedSearches.push_back(Paths.size()); -      Paths.emplace_back(P.second, P.first, PriorNode); -    } -  } - -  /// Represents a search that terminated after finding a clobber. This clobber -  /// may or may not be present in the path of defs from LastNode..SearchStart, -  /// since it may have been retrieved from cache. -  struct TerminatedPath { -    MemoryAccess *Clobber; -    ListIndex LastNode; -  }; - -  /// Get an access that keeps us from optimizing to the given phi. -  /// -  /// PausedSearches is an array of indices into the Paths array. 
Its incoming -  /// value is the indices of searches that stopped at the last phi optimization -  /// target. It's left in an unspecified state. -  /// -  /// If this returns None, NewPaused is a vector of searches that terminated -  /// at StopWhere. Otherwise, NewPaused is left in an unspecified state. -  Optional<TerminatedPath> -  getBlockingAccess(MemoryAccess *StopWhere, -                    SmallVectorImpl<ListIndex> &PausedSearches, -                    SmallVectorImpl<ListIndex> &NewPaused, -                    SmallVectorImpl<TerminatedPath> &Terminated) { -    assert(!PausedSearches.empty() && "No searches to continue?"); - -    // BFS vs DFS really doesn't make a difference here, so just do a DFS with -    // PausedSearches as our stack. -    while (!PausedSearches.empty()) { -      ListIndex PathIndex = PausedSearches.pop_back_val(); -      DefPath &Node = Paths[PathIndex]; - -      // If we've already visited this path with this MemoryLocation, we don't -      // need to do so again. -      // -      // NOTE: That we just drop these paths on the ground makes caching -      // behavior sporadic. e.g. given a diamond: -      //  A -      // B C -      //  D -      // -      // ...If we walk D, B, A, C, we'll only cache the result of phi -      // optimization for A, B, and D; C will be skipped because it dies here. -      // This arguably isn't the worst thing ever, since: -      //   - We generally query things in a top-down order, so if we got below D -      //     without needing cache entries for {C, MemLoc}, then chances are -      //     that those cache entries would end up ultimately unused. -      //   - We still cache things for A, so C only needs to walk up a bit. -      // If this behavior becomes problematic, we can fix without a ton of extra -      // work. -      if (!VisitedPhis.insert({Node.Last, Node.Loc}).second) -        continue; - -      UpwardsWalkResult Res = walkToPhiOrClobber(Node, /*StopAt=*/StopWhere); -      if (Res.IsKnownClobber) { -        assert(Res.Result != StopWhere || Res.FromCache); -        // If this wasn't a cache hit, we hit a clobber when walking. That's a -        // failure. -        TerminatedPath Term{Res.Result, PathIndex}; -        if (!Res.FromCache || !MSSA.dominates(Res.Result, StopWhere)) -          return Term; - -        // Otherwise, it's a valid thing to potentially optimize to. -        Terminated.push_back(Term); -        continue; -      } - -      if (Res.Result == StopWhere) { -        // We've hit our target. Save this path off for if we want to continue -        // walking. 
-        NewPaused.push_back(PathIndex); -        continue; -      } - -      assert(!MSSA.isLiveOnEntryDef(Res.Result) && "liveOnEntry is a clobber"); -      addSearches(cast<MemoryPhi>(Res.Result), PausedSearches, PathIndex); -    } - -    return None; -  } - -  template <typename T, typename Walker> -  struct generic_def_path_iterator -      : public iterator_facade_base<generic_def_path_iterator<T, Walker>, -                                    std::forward_iterator_tag, T *> { -    generic_def_path_iterator() : W(nullptr), N(None) {} -    generic_def_path_iterator(Walker *W, ListIndex N) : W(W), N(N) {} - -    T &operator*() const { return curNode(); } - -    generic_def_path_iterator &operator++() { -      N = curNode().Previous; -      return *this; -    } - -    bool operator==(const generic_def_path_iterator &O) const { -      if (N.hasValue() != O.N.hasValue()) -        return false; -      return !N.hasValue() || *N == *O.N; -    } - -  private: -    T &curNode() const { return W->Paths[*N]; } - -    Walker *W; -    Optional<ListIndex> N; -  }; - -  using def_path_iterator = generic_def_path_iterator<DefPath, ClobberWalker>; -  using const_def_path_iterator = -      generic_def_path_iterator<const DefPath, const ClobberWalker>; - -  iterator_range<def_path_iterator> def_path(ListIndex From) { -    return make_range(def_path_iterator(this, From), def_path_iterator()); -  } - -  iterator_range<const_def_path_iterator> const_def_path(ListIndex From) const { -    return make_range(const_def_path_iterator(this, From), -                      const_def_path_iterator()); -  } - -  struct OptznResult { -    /// The path that contains our result. -    TerminatedPath PrimaryClobber; -    /// The paths that we can legally cache back from, but that aren't -    /// necessarily the result of the Phi optimization. -    SmallVector<TerminatedPath, 4> OtherClobbers; -  }; - -  ListIndex defPathIndex(const DefPath &N) const { -    // The assert looks nicer if we don't need to do &N -    const DefPath *NP = &N; -    assert(!Paths.empty() && NP >= &Paths.front() && NP <= &Paths.back() && -           "Out of bounds DefPath!"); -    return NP - &Paths.front(); -  } - -  /// Try to optimize a phi as best as we can. Returns a SmallVector of Paths -  /// that act as legal clobbers. Note that this won't return *all* clobbers. -  /// -  /// Phi optimization algorithm tl;dr: -  ///   - Find the earliest def/phi, A, we can optimize to -  ///   - Find if all paths from the starting memory access ultimately reach A -  ///     - If not, optimization isn't possible. -  ///     - Otherwise, walk from A to another clobber or phi, A'. -  ///       - If A' is a def, we're done. -  ///       - If A' is a phi, try to optimize it. -  /// -  /// A path is a series of {MemoryAccess, MemoryLocation} pairs. A path -  /// terminates when a MemoryAccess that clobbers said MemoryLocation is found. -  OptznResult tryOptimizePhi(MemoryPhi *Phi, MemoryAccess *Start, -                             const MemoryLocation &Loc) { -    assert(Paths.empty() && VisitedPhis.empty() && -           "Reset the optimization state."); - -    Paths.emplace_back(Loc, Start, Phi, None); -    // Stores how many "valid" optimization nodes we had prior to calling -    // addSearches/getBlockingAccess. Necessary for caching if we had a blocker. 
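The search in getBlockingAccess above is a plain depth-first walk that uses the paused searches as its stack and drops any (access, location) pair it has already seen, which is exactly what gives up the sporadic caching described in the diamond comment. A self-contained sketch of that shape (not part of this patch; Node, Blocking, IsTarget and the integer location are hypothetical stand-ins):

// Illustrative sketch only; not part of the patch.
#include <optional>
#include <set>
#include <utility>
#include <vector>

struct Node {
  std::vector<Node *> Preds; // where a paused search resumes (upward defs)
  bool Blocking = false;     // a known clobber that blocks the optimization
  bool IsTarget = false;     // the StopWhere / walk target
};

// DFS with the paused searches as the stack; skip (node, loc) pairs we have
// already visited, mirroring the VisitedPhis set above.
static std::optional<Node *> findBlockingAccess(std::vector<Node *> Paused,
                                                int Loc) {
  std::set<std::pair<Node *, int>> Visited;
  while (!Paused.empty()) {
    Node *N = Paused.back();
    Paused.pop_back();
    if (!Visited.insert({N, Loc}).second)
      continue;
    if (N->Blocking)
      return N;        // a clobber on the way up: optimization fails
    if (N->IsTarget)
      continue;        // this path reached the target unblocked
    for (Node *P : N->Preds)
      Paused.push_back(P);
  }
  return std::nullopt; // every path reached the target without a blocker
}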
-    auto PriorPathsSize = Paths.size(); - -    SmallVector<ListIndex, 16> PausedSearches; -    SmallVector<ListIndex, 8> NewPaused; -    SmallVector<TerminatedPath, 4> TerminatedPaths; - -    addSearches(Phi, PausedSearches, 0); - -    // Moves the TerminatedPath with the "most dominated" Clobber to the end of -    // Paths. -    auto MoveDominatedPathToEnd = [&](SmallVectorImpl<TerminatedPath> &Paths) { -      assert(!Paths.empty() && "Need a path to move"); -      auto Dom = Paths.begin(); -      for (auto I = std::next(Dom), E = Paths.end(); I != E; ++I) -        if (!MSSA.dominates(I->Clobber, Dom->Clobber)) -          Dom = I; -      auto Last = Paths.end() - 1; -      if (Last != Dom) -        std::iter_swap(Last, Dom); -    }; - -    MemoryPhi *Current = Phi; -    while (1) { -      assert(!MSSA.isLiveOnEntryDef(Current) && -             "liveOnEntry wasn't treated as a clobber?"); - -      MemoryAccess *Target = getWalkTarget(Current); -      // If a TerminatedPath doesn't dominate Target, then it wasn't a legal -      // optimization for the prior phi. -      assert(all_of(TerminatedPaths, [&](const TerminatedPath &P) { -        return MSSA.dominates(P.Clobber, Target); -      })); - -      // FIXME: This is broken, because the Blocker may be reported to be -      // liveOnEntry, and we'll happily wait for that to disappear (read: never) -      // For the moment, this is fine, since we do nothing with blocker info. -      if (Optional<TerminatedPath> Blocker = getBlockingAccess( -              Target, PausedSearches, NewPaused, TerminatedPaths)) { -        // Cache our work on the blocking node, since we know that's correct. -        cacheDefPath(Paths[Blocker->LastNode], Blocker->Clobber); - -        // Find the node we started at. We can't search based on N->Last, since -        // we may have gone around a loop with a different MemoryLocation. -        auto Iter = find_if(def_path(Blocker->LastNode), [&](const DefPath &N) { -          return defPathIndex(N) < PriorPathsSize; -        }); -        assert(Iter != def_path_iterator()); - -        DefPath &CurNode = *Iter; -        assert(CurNode.Last == Current); - -        // Two things: -        // A. We can't reliably cache all of NewPaused back. Consider a case -        //    where we have two paths in NewPaused; one of which can't optimize -        //    above this phi, whereas the other can. If we cache the second path -        //    back, we'll end up with suboptimal cache entries. We can handle -        //    cases like this a bit better when we either try to find all -        //    clobbers that block phi optimization, or when our cache starts -        //    supporting unfinished searches. -        // B. We can't reliably cache TerminatedPaths back here without doing -        //    extra checks; consider a case like: -        //       T -        //      / \ -        //     D   C -        //      \ / -        //       S -        //    Where T is our target, C is a node with a clobber on it, D is a -        //    diamond (with a clobber *only* on the left or right node, N), and -        //    S is our start. Say we walk to D, through the node opposite N -        //    (read: ignoring the clobber), and see a cache entry in the top -        //    node of D. That cache entry gets put into TerminatedPaths. We then -        //    walk up to C (N is later in our worklist), find the clobber, and -        //    quit. 
If we append TerminatedPaths to OtherClobbers, we'll cache -        //    the bottom part of D to the cached clobber, ignoring the clobber -        //    in N. Again, this problem goes away if we start tracking all -        //    blockers for a given phi optimization. -        TerminatedPath Result{CurNode.Last, defPathIndex(CurNode)}; -        return {Result, {}}; -      } - -      // If there's nothing left to search, then all paths led to valid clobbers -      // that we got from our cache; pick the nearest to the start, and allow -      // the rest to be cached back. -      if (NewPaused.empty()) { -        MoveDominatedPathToEnd(TerminatedPaths); -        TerminatedPath Result = TerminatedPaths.pop_back_val(); -        return {Result, std::move(TerminatedPaths)}; -      } - -      MemoryAccess *DefChainEnd = nullptr; -      SmallVector<TerminatedPath, 4> Clobbers; -      for (ListIndex Paused : NewPaused) { -        UpwardsWalkResult WR = walkToPhiOrClobber(Paths[Paused]); -        if (WR.IsKnownClobber) -          Clobbers.push_back({WR.Result, Paused}); -        else -          // Micro-opt: If we hit the end of the chain, save it. -          DefChainEnd = WR.Result; -      } - -      if (!TerminatedPaths.empty()) { -        // If we couldn't find the dominating phi/liveOnEntry in the above loop, -        // do it now. -        if (!DefChainEnd) -          for (MemoryAccess *MA : def_chain(Target)) -            DefChainEnd = MA; - -        // If any of the terminated paths don't dominate the phi we'll try to -        // optimize, we need to figure out what they are and quit. -        const BasicBlock *ChainBB = DefChainEnd->getBlock(); -        for (const TerminatedPath &TP : TerminatedPaths) { -          // Because we know that DefChainEnd is as "high" as we can go, we -          // don't need local dominance checks; BB dominance is sufficient. -          if (DT.dominates(ChainBB, TP.Clobber->getBlock())) -            Clobbers.push_back(TP); -        } -      } - -      // If we have clobbers in the def chain, find the one closest to Current -      // and quit. -      if (!Clobbers.empty()) { -        MoveDominatedPathToEnd(Clobbers); -        TerminatedPath Result = Clobbers.pop_back_val(); -        return {Result, std::move(Clobbers)}; -      } - -      assert(all_of(NewPaused, -                    [&](ListIndex I) { return Paths[I].Last == DefChainEnd; })); - -      // Because liveOnEntry is a clobber, this must be a phi. -      auto *DefChainPhi = cast<MemoryPhi>(DefChainEnd); - -      PriorPathsSize = Paths.size(); -      PausedSearches.clear(); -      for (ListIndex I : NewPaused) -        addSearches(DefChainPhi, PausedSearches, I); -      NewPaused.clear(); - -      Current = DefChainPhi; -    } -  } - -  /// Caches everything in an OptznResult. -  void cacheOptResult(const OptznResult &R) { -    if (R.OtherClobbers.empty()) { -      // If we're not going to be caching OtherClobbers, don't bother with -      // marking visited/etc. -      for (const DefPath &N : const_def_path(R.PrimaryClobber.LastNode)) -        cacheDefPath(N, R.PrimaryClobber.Clobber); -      return; -    } - -    // PrimaryClobber is our answer. If we can cache anything back, we need to -    // stop caching when we visit PrimaryClobber. 
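The MoveDominatedPathToEnd lambda above picks, out of several candidate clobbers, the one lowest in the dominator tree (nearest the query) and swaps it to the back so that pop_back_val() returns it as the primary clobber. A self-contained reduction of that selection follows (not part of this patch); real dominance is replaced by an integer depth, which only models the real thing when all candidates lie on a single dominator-tree path.

// Illustrative sketch only; not part of the patch.
#include <algorithm>
#include <cassert>
#include <iterator>
#include <vector>

struct Clobber {
  int DomDepth; // deeper entries are dominated by shallower ones (toy model)
};

// True if A dominates B in this toy, single-path model.
static bool dominates(const Clobber &A, const Clobber &B) {
  return A.DomDepth <= B.DomDepth;
}

// Move the most-dominated clobber to the end, mirroring the lambda above.
static void moveDominatedToEnd(std::vector<Clobber> &Paths) {
  assert(!Paths.empty() && "Need a path to move");
  auto Dom = Paths.begin();
  for (auto I = std::next(Dom), E = Paths.end(); I != E; ++I)
    if (!dominates(*I, *Dom)) // I sits lower than the current pick
      Dom = I;
  auto Last = Paths.end() - 1;
  if (Last != Dom)
    std::iter_swap(Last, Dom);
}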
-    SmallBitVector Visited(Paths.size()); -    for (const DefPath &N : const_def_path(R.PrimaryClobber.LastNode)) { -      Visited[defPathIndex(N)] = true; -      cacheDefPath(N, R.PrimaryClobber.Clobber); -    } - -    for (const TerminatedPath &P : R.OtherClobbers) { -      for (const DefPath &N : const_def_path(P.LastNode)) { -        ListIndex NIndex = defPathIndex(N); -        if (Visited[NIndex]) -          break; -        Visited[NIndex] = true; -        cacheDefPath(N, P.Clobber); -      } -    } -  } - -  void verifyOptResult(const OptznResult &R) const { -    assert(all_of(R.OtherClobbers, [&](const TerminatedPath &P) { -      return MSSA.dominates(P.Clobber, R.PrimaryClobber.Clobber); -    })); -  } - -  void resetPhiOptznState() { -    Paths.clear(); -    VisitedPhis.clear(); -  } - -public: -  ClobberWalker(const MemorySSA &MSSA, AliasAnalysis &AA, DominatorTree &DT, -                WalkerCache &WC) -      : MSSA(MSSA), AA(AA), DT(DT), WC(WC), UseCache(true) {} - -  void reset() { WalkTargetCache.clear(); } - -  /// Finds the nearest clobber for the given query, optimizing phis if -  /// possible. -  MemoryAccess *findClobber(MemoryAccess *Start, UpwardsMemoryQuery &Q, -                            bool UseWalkerCache = true) { -    setUseCache(UseWalkerCache); -    Query = &Q; - -    MemoryAccess *Current = Start; -    // This walker pretends uses don't exist. If we're handed one, silently grab -    // its def. (This has the nice side-effect of ensuring we never cache uses) -    if (auto *MU = dyn_cast<MemoryUse>(Start)) -      Current = MU->getDefiningAccess(); - -    DefPath FirstDesc(Q.StartingLoc, Current, Current, None); -    // Fast path for the overly-common case (no crazy phi optimization -    // necessary) -    UpwardsWalkResult WalkResult = walkToPhiOrClobber(FirstDesc); -    MemoryAccess *Result; -    if (WalkResult.IsKnownClobber) { -      cacheDefPath(FirstDesc, WalkResult.Result); -      Result = WalkResult.Result; -    } else { -      OptznResult OptRes = tryOptimizePhi(cast<MemoryPhi>(FirstDesc.Last), -                                          Current, Q.StartingLoc); -      verifyOptResult(OptRes); -      cacheOptResult(OptRes); -      resetPhiOptznState(); -      Result = OptRes.PrimaryClobber.Clobber; -    } - -#ifdef EXPENSIVE_CHECKS -    checkClobberSanity(Current, Result, Q.StartingLoc, MSSA, Q, AA); -#endif -    return Result; -  } - -  void verify(const MemorySSA *MSSA) { assert(MSSA == &this->MSSA); } -}; - -struct RenamePassData { -  DomTreeNode *DTN; -  DomTreeNode::const_iterator ChildIt; -  MemoryAccess *IncomingVal; - -  RenamePassData(DomTreeNode *D, DomTreeNode::const_iterator It, -                 MemoryAccess *M) -      : DTN(D), ChildIt(It), IncomingVal(M) {} -  void swap(RenamePassData &RHS) { -    std::swap(DTN, RHS.DTN); -    std::swap(ChildIt, RHS.ChildIt); -    std::swap(IncomingVal, RHS.IncomingVal); -  } -}; -} // anonymous namespace - -namespace llvm { -/// \brief A MemorySSAWalker that does AA walks and caching of lookups to -/// disambiguate accesses. -/// -/// FIXME: The current implementation of this can take quadratic space in rare -/// cases. This can be fixed, but it is something to note until it is fixed. -/// -/// In order to trigger this behavior, you need to store to N distinct locations -/// (that AA can prove don't alias), perform M stores to other memory -/// locations that AA can prove don't alias any of the initial N locations, and -/// then load from all of the N locations. 
In this case, we insert M cache -/// entries for each of the N loads. -/// -/// For example: -/// define i32 @foo() { -///   %a = alloca i32, align 4 -///   %b = alloca i32, align 4 -///   store i32 0, i32* %a, align 4 -///   store i32 0, i32* %b, align 4 -/// -///   ; Insert M stores to other memory that doesn't alias %a or %b here -/// -///   %c = load i32, i32* %a, align 4 ; Caches M entries in -///                                   ; CachedUpwardsClobberingAccess for the -///                                   ; MemoryLocation %a -///   %d = load i32, i32* %b, align 4 ; Caches M entries in -///                                   ; CachedUpwardsClobberingAccess for the -///                                   ; MemoryLocation %b -/// -///   ; For completeness' sake, loading %a or %b again would not cache *another* -///   ; M entries. -///   %r = add i32 %c, %d -///   ret i32 %r -/// } -class MemorySSA::CachingWalker final : public MemorySSAWalker { -  WalkerCache Cache; -  ClobberWalker Walker; -  bool AutoResetWalker; - -  MemoryAccess *getClobberingMemoryAccess(MemoryAccess *, UpwardsMemoryQuery &); -  void verifyRemoved(MemoryAccess *); - -public: -  CachingWalker(MemorySSA *, AliasAnalysis *, DominatorTree *); -  ~CachingWalker() override; - -  using MemorySSAWalker::getClobberingMemoryAccess; -  MemoryAccess *getClobberingMemoryAccess(MemoryAccess *) override; -  MemoryAccess *getClobberingMemoryAccess(MemoryAccess *, -                                          const MemoryLocation &) override; -  void invalidateInfo(MemoryAccess *) override; - -  /// Whether we call resetClobberWalker() after each time we *actually* walk to -  /// answer a clobber query. -  void setAutoResetWalker(bool AutoReset) { AutoResetWalker = AutoReset; } - -  /// Drop the walker's persistent data structures. At the moment, this means -  /// "drop the walker's cache of BasicBlocks -> -  /// earliest-MemoryAccess-we-can-optimize-to". This is necessary if we're -  /// going to have DT updates, if we remove MemoryAccesses, etc. -  void resetClobberWalker() { Walker.reset(); } - -  void verify(const MemorySSA *MSSA) override { -    MemorySSAWalker::verify(MSSA); -    Walker.verify(MSSA); -  } -}; - -/// \brief Rename a single basic block into MemorySSA form. -/// Uses the standard SSA renaming algorithm. -/// \returns The new incoming value. -MemoryAccess *MemorySSA::renameBlock(BasicBlock *BB, -                                     MemoryAccess *IncomingVal) { -  auto It = PerBlockAccesses.find(BB); -  // Skip most processing if the list is empty. -  if (It != PerBlockAccesses.end()) { -    AccessList *Accesses = It->second.get(); -    for (MemoryAccess &L : *Accesses) { -      if (MemoryUseOrDef *MUD = dyn_cast<MemoryUseOrDef>(&L)) { -        if (MUD->getDefiningAccess() == nullptr) -          MUD->setDefiningAccess(IncomingVal); -        if (isa<MemoryDef>(&L)) -          IncomingVal = &L; -      } else { -        IncomingVal = &L; -      } -    } -  } - -  // Pass through values to our successors -  for (const BasicBlock *S : successors(BB)) { -    auto It = PerBlockAccesses.find(S); -    // Rename the phi nodes in our successor block -    if (It == PerBlockAccesses.end() || !isa<MemoryPhi>(It->second->front())) -      continue; -    AccessList *Accesses = It->second.get(); -    auto *Phi = cast<MemoryPhi>(&Accesses->front()); -    Phi->addIncoming(IncomingVal, BB); -  } - -  return IncomingVal; -} - -/// \brief This is the standard SSA renaming algorithm. 
-/// -/// We walk the dominator tree in preorder, renaming accesses, and then filling -/// in phi nodes in our successors. -void MemorySSA::renamePass(DomTreeNode *Root, MemoryAccess *IncomingVal, -                           SmallPtrSet<BasicBlock *, 16> &Visited) { -  SmallVector<RenamePassData, 32> WorkStack; -  IncomingVal = renameBlock(Root->getBlock(), IncomingVal); -  WorkStack.push_back({Root, Root->begin(), IncomingVal}); -  Visited.insert(Root->getBlock()); - -  while (!WorkStack.empty()) { -    DomTreeNode *Node = WorkStack.back().DTN; -    DomTreeNode::const_iterator ChildIt = WorkStack.back().ChildIt; -    IncomingVal = WorkStack.back().IncomingVal; - -    if (ChildIt == Node->end()) { -      WorkStack.pop_back(); -    } else { -      DomTreeNode *Child = *ChildIt; -      ++WorkStack.back().ChildIt; -      BasicBlock *BB = Child->getBlock(); -      Visited.insert(BB); -      IncomingVal = renameBlock(BB, IncomingVal); -      WorkStack.push_back({Child, Child->begin(), IncomingVal}); -    } -  } -} - -/// \brief Compute dominator levels, used by the phi insertion algorithm above. -void MemorySSA::computeDomLevels(DenseMap<DomTreeNode *, unsigned> &DomLevels) { -  for (auto DFI = df_begin(DT->getRootNode()), DFE = df_end(DT->getRootNode()); -       DFI != DFE; ++DFI) -    DomLevels[*DFI] = DFI.getPathLength() - 1; -} - -/// \brief This handles unreachable block accesses by deleting phi nodes in -/// unreachable blocks, and marking all other unreachable MemoryAccess's as -/// being uses of the live on entry definition. -void MemorySSA::markUnreachableAsLiveOnEntry(BasicBlock *BB) { -  assert(!DT->isReachableFromEntry(BB) && -         "Reachable block found while handling unreachable blocks"); - -  // Make sure phi nodes in our reachable successors end up with a -  // LiveOnEntryDef for our incoming edge, even though our block is forward -  // unreachable.  We could just disconnect these blocks from the CFG fully, -  // but we do not right now. -  for (const BasicBlock *S : successors(BB)) { -    if (!DT->isReachableFromEntry(S)) -      continue; -    auto It = PerBlockAccesses.find(S); -    // Rename the phi nodes in our successor block -    if (It == PerBlockAccesses.end() || !isa<MemoryPhi>(It->second->front())) -      continue; -    AccessList *Accesses = It->second.get(); -    auto *Phi = cast<MemoryPhi>(&Accesses->front()); -    Phi->addIncoming(LiveOnEntryDef.get(), BB); -  } - -  auto It = PerBlockAccesses.find(BB); -  if (It == PerBlockAccesses.end()) -    return; - -  auto &Accesses = It->second; -  for (auto AI = Accesses->begin(), AE = Accesses->end(); AI != AE;) { -    auto Next = std::next(AI); -    // If we have a phi, just remove it. We are going to replace all -    // users with live on entry. 
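renamePass above is the classic SSA renaming walk made iterative: an explicit stack of (node, child iterator, incoming value) frames replaces recursion, and each block may replace the incoming value before it is pushed down to its dominator-tree children. A self-contained sketch over a plain tree (not part of this patch; TreeNode, visit and the even/odd rule are hypothetical stand-ins for the dominator tree and renameBlock):

// Illustrative sketch only; not part of the patch.
#include <cstddef>
#include <cstdio>
#include <vector>

struct TreeNode {
  int Id;
  std::vector<TreeNode *> Children;
};

// Returns the value reaching the end of this block; here a block "defines" a
// new value when its Id is even, otherwise it passes the incoming one through.
static int visit(const TreeNode &N, int IncomingVal) {
  return (N.Id % 2 == 0) ? N.Id : IncomingVal;
}

static void renameWalk(TreeNode *Root, int IncomingVal) {
  struct Frame { TreeNode *Node; std::size_t ChildIt; int IncomingVal; };
  std::vector<Frame> WorkStack;
  IncomingVal = visit(*Root, IncomingVal);
  WorkStack.push_back({Root, 0, IncomingVal});
  while (!WorkStack.empty()) {
    Frame &F = WorkStack.back();
    if (F.ChildIt == F.Node->Children.size()) {
      WorkStack.pop_back();
      continue;
    }
    TreeNode *Child = F.Node->Children[F.ChildIt++];
    int Val = visit(*Child, F.IncomingVal);
    std::printf("block %d sees incoming %d\n", Child->Id, F.IncomingVal);
    WorkStack.push_back({Child, 0, Val}); // F may dangle past this point
  }
}

int main() {
  TreeNode C{3, {}}, B{2, {&C}}, A{1, {&B}};
  renameWalk(&A, 0); // block 2 sees incoming 0, block 3 sees incoming 2
}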
-    if (auto *UseOrDef = dyn_cast<MemoryUseOrDef>(AI)) -      UseOrDef->setDefiningAccess(LiveOnEntryDef.get()); -    else -      Accesses->erase(AI); -    AI = Next; -  } -} - -MemorySSA::MemorySSA(Function &Func, AliasAnalysis *AA, DominatorTree *DT) -    : AA(AA), DT(DT), F(Func), LiveOnEntryDef(nullptr), Walker(nullptr), -      NextID(INVALID_MEMORYACCESS_ID) { -  buildMemorySSA(); -} - -MemorySSA::~MemorySSA() { -  // Drop all our references -  for (const auto &Pair : PerBlockAccesses) -    for (MemoryAccess &MA : *Pair.second) -      MA.dropAllReferences(); -} - -MemorySSA::AccessList *MemorySSA::getOrCreateAccessList(const BasicBlock *BB) { -  auto Res = PerBlockAccesses.insert(std::make_pair(BB, nullptr)); - -  if (Res.second) -    Res.first->second = make_unique<AccessList>(); -  return Res.first->second.get(); -} - -/// This class is a batch walker of all MemoryUse's in the program, and points -/// their defining access at the thing that actually clobbers them.  Because it -/// is a batch walker that touches everything, it does not operate like the -/// other walkers.  This walker is basically performing a top-down SSA renaming -/// pass, where the version stack is used as the cache.  This enables it to be -/// significantly more time and memory efficient than using the regular walker, -/// which is walking bottom-up. -class MemorySSA::OptimizeUses { -public: -  OptimizeUses(MemorySSA *MSSA, MemorySSAWalker *Walker, AliasAnalysis *AA, -               DominatorTree *DT) -      : MSSA(MSSA), Walker(Walker), AA(AA), DT(DT) { -    Walker = MSSA->getWalker(); -  } - -  void optimizeUses(); - -private: -  /// This represents where a given memorylocation is in the stack. -  struct MemlocStackInfo { -    // This essentially is keeping track of versions of the stack. Whenever -    // the stack changes due to pushes or pops, these versions increase. -    unsigned long StackEpoch; -    unsigned long PopEpoch; -    // This is the lower bound of places on the stack to check. It is equal to -    // the place the last stack walk ended. -    // Note: Correctness depends on this being initialized to 0, which densemap -    // does -    unsigned long LowerBound; -    const BasicBlock *LowerBoundBlock; -    // This is where the last walk for this memory location ended. -    unsigned long LastKill; -    bool LastKillValid; -  }; -  void optimizeUsesInBlock(const BasicBlock *, unsigned long &, unsigned long &, -                           SmallVectorImpl<MemoryAccess *> &, -                           DenseMap<MemoryLocOrCall, MemlocStackInfo> &); -  MemorySSA *MSSA; -  MemorySSAWalker *Walker; -  AliasAnalysis *AA; -  DominatorTree *DT; -}; - -/// Optimize the uses in a given block This is basically the SSA renaming -/// algorithm, with one caveat: We are able to use a single stack for all -/// MemoryUses.  This is because the set of *possible* reaching MemoryDefs is -/// the same for every MemoryUse.  The *actual* clobbering MemoryDef is just -/// going to be some position in that stack of possible ones. -/// -/// We track the stack positions that each MemoryLocation needs -/// to check, and last ended at.  This is because we only want to check the -/// things that changed since last time.  
The same MemoryLocation should -/// get clobbered by the same store (getModRefInfo does not use invariantness or -/// things like this, and if they start, we can modify MemoryLocOrCall to -/// include relevant data) -void MemorySSA::OptimizeUses::optimizeUsesInBlock( -    const BasicBlock *BB, unsigned long &StackEpoch, unsigned long &PopEpoch, -    SmallVectorImpl<MemoryAccess *> &VersionStack, -    DenseMap<MemoryLocOrCall, MemlocStackInfo> &LocStackInfo) { - -  /// If no accesses, nothing to do. -  MemorySSA::AccessList *Accesses = MSSA->getWritableBlockAccesses(BB); -  if (Accesses == nullptr) -    return; - -  // Pop everything that doesn't dominate the current block off the stack, -  // increment the PopEpoch to account for this. -  while (!VersionStack.empty()) { -    BasicBlock *BackBlock = VersionStack.back()->getBlock(); -    if (DT->dominates(BackBlock, BB)) -      break; -    while (VersionStack.back()->getBlock() == BackBlock) -      VersionStack.pop_back(); -    ++PopEpoch; -  } -  for (MemoryAccess &MA : *Accesses) { -    auto *MU = dyn_cast<MemoryUse>(&MA); -    if (!MU) { -      VersionStack.push_back(&MA); -      ++StackEpoch; -      continue; -    } - -    if (isUseTriviallyOptimizableToLiveOnEntry(*AA, MU->getMemoryInst())) { -      MU->setDefiningAccess(MSSA->getLiveOnEntryDef(), true); -      continue; -    } - -    MemoryLocOrCall UseMLOC(MU); -    auto &LocInfo = LocStackInfo[UseMLOC]; -    // If the pop epoch changed, it means we've removed stuff from top of -    // stack due to changing blocks. We may have to reset the lower bound or -    // last kill info. -    if (LocInfo.PopEpoch != PopEpoch) { -      LocInfo.PopEpoch = PopEpoch; -      LocInfo.StackEpoch = StackEpoch; -      // If the lower bound was in something that no longer dominates us, we -      // have to reset it. -      // We can't simply track stack size, because the stack may have had -      // pushes/pops in the meantime. -      // XXX: This is non-optimal, but only is slower cases with heavily -      // branching dominator trees.  To get the optimal number of queries would -      // be to make lowerbound and lastkill a per-loc stack, and pop it until -      // the top of that stack dominates us.  This does not seem worth it ATM. -      // A much cheaper optimization would be to always explore the deepest -      // branch of the dominator tree first. This will guarantee this resets on -      // the smallest set of blocks. -      if (LocInfo.LowerBoundBlock && LocInfo.LowerBoundBlock != BB && -          !DT->dominates(LocInfo.LowerBoundBlock, BB)) { -        // Reset the lower bound of things to check. -        // TODO: Some day we should be able to reset to last kill, rather than -        // 0. -        LocInfo.LowerBound = 0; -        LocInfo.LowerBoundBlock = VersionStack[0]->getBlock(); -        LocInfo.LastKillValid = false; -      } -    } else if (LocInfo.StackEpoch != StackEpoch) { -      // If all that has changed is the StackEpoch, we only have to check the -      // new things on the stack, because we've checked everything before.  In -      // this case, the lower bound of things to check remains the same. -      LocInfo.PopEpoch = PopEpoch; -      LocInfo.StackEpoch = StackEpoch; -    } -    if (!LocInfo.LastKillValid) { -      LocInfo.LastKill = VersionStack.size() - 1; -      LocInfo.LastKillValid = true; -    } - -    // At this point, we should have corrected last kill and LowerBound to be -    // in bounds. 
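The StackEpoch/PopEpoch bookkeeping above lets each memory location know how much of the version stack it has already examined: a PopEpoch change means entries it relied on may be gone, while a StackEpoch change alone means only newly pushed entries need checking. A self-contained reduction of just that decision (not part of this patch); it drops the dominance refinement and the LastKill tracking, and all names are hypothetical.

// Illustrative sketch only; not part of the patch.
#include <cassert>
#include <cstddef>

struct LocInfo {
  unsigned long StackEpoch = 0;
  unsigned long PopEpoch = 0;
  std::size_t LowerBound = 0; // lowest stack slot still worth checking
};

// Decide which stack slots [Lower, Upper] a use at this location must check,
// given the current stack size and epochs.
static void refreshBounds(LocInfo &LI, unsigned long StackEpoch,
                          unsigned long PopEpoch, std::size_t StackSize,
                          std::size_t &Lower, std::size_t &Upper) {
  assert(StackSize != 0 && "stack always holds at least liveOnEntry");
  if (LI.PopEpoch != PopEpoch) {
    // Entries were popped since we last looked: our old lower bound may no
    // longer be on the stack, so start over from the bottom.
    LI.PopEpoch = PopEpoch;
    LI.StackEpoch = StackEpoch;
    LI.LowerBound = 0;
  } else if (LI.StackEpoch != StackEpoch) {
    // Only pushes happened: everything at or below the old lower bound has
    // already been checked, so only the new entries need a look.
    LI.PopEpoch = PopEpoch;
    LI.StackEpoch = StackEpoch;
  }
  Lower = LI.LowerBound;
  Upper = StackSize - 1;         // the new upper bound is the top of the stack
  LI.LowerBound = StackSize - 1; // next query starts where this one stopped
}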
-    assert(LocInfo.LowerBound < VersionStack.size() && -           "Lower bound out of range"); -    assert(LocInfo.LastKill < VersionStack.size() && -           "Last kill info out of range"); -    // In any case, the new upper bound is the top of the stack. -    unsigned long UpperBound = VersionStack.size() - 1; - -    if (UpperBound - LocInfo.LowerBound > MaxCheckLimit) { -      DEBUG(dbgs() << "MemorySSA skipping optimization of " << *MU << " (" -                   << *(MU->getMemoryInst()) << ")" -                   << " because there are " << UpperBound - LocInfo.LowerBound -                   << " stores to disambiguate\n"); -      // Because we did not walk, LastKill is no longer valid, as this may -      // have been a kill. -      LocInfo.LastKillValid = false; -      continue; -    } -    bool FoundClobberResult = false; -    while (UpperBound > LocInfo.LowerBound) { -      if (isa<MemoryPhi>(VersionStack[UpperBound])) { -        // For phis, use the walker, see where we ended up, go there -        Instruction *UseInst = MU->getMemoryInst(); -        MemoryAccess *Result = Walker->getClobberingMemoryAccess(UseInst); -        // We are guaranteed to find it or something is wrong -        while (VersionStack[UpperBound] != Result) { -          assert(UpperBound != 0); -          --UpperBound; -        } -        FoundClobberResult = true; -        break; -      } - -      MemoryDef *MD = cast<MemoryDef>(VersionStack[UpperBound]); -      // If the lifetime of the pointer ends at this instruction, it's live on -      // entry. -      if (!UseMLOC.IsCall && lifetimeEndsAt(MD, UseMLOC.getLoc(), *AA)) { -        // Reset UpperBound to liveOnEntryDef's place in the stack -        UpperBound = 0; -        FoundClobberResult = true; -        break; -      } -      if (instructionClobbersQuery(MD, MU, UseMLOC, *AA)) { -        FoundClobberResult = true; -        break; -      } -      --UpperBound; -    } -    // At the end of this loop, UpperBound is either a clobber, or lower bound -    // PHI walking may cause it to be < LowerBound, and in fact, < LastKill. -    if (FoundClobberResult || UpperBound < LocInfo.LastKill) { -      MU->setDefiningAccess(VersionStack[UpperBound], true); -      // We were last killed now by where we got to -      LocInfo.LastKill = UpperBound; -    } else { -      // Otherwise, we checked all the new ones, and now we know we can get to -      // LastKill. -      MU->setDefiningAccess(VersionStack[LocInfo.LastKill], true); -    } -    LocInfo.LowerBound = VersionStack.size() - 1; -    LocInfo.LowerBoundBlock = BB; -  } -} - -/// Optimize uses to point to their actual clobbering definitions. 
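After optimizeUses() runs, most MemoryUses already point at their real clobber, and the walker answers the same question on demand for anything that was skipped. A small usage sketch (not part of this patch), assuming only the interfaces shown in this file; the include path matches where MemorySSA.h lives at this revision (it later moved under llvm/Analysis/).

// Illustrative sketch only; not part of the patch.
#include "llvm/Transforms/Utils/MemorySSA.h"
using namespace llvm;

// Returns true when no store in this function clobbers the memory that I
// reads or writes. I must be a memory-accessing instruction known to MSSA.
static bool isClobberLiveOnEntry(MemorySSA &MSSA, Instruction &I) {
  MemorySSAWalker *Walker = MSSA.getWalker();
  MemoryAccess *Clobber = Walker->getClobberingMemoryAccess(&I);
  // liveOnEntry is the distinguished "defined before the function" access.
  return MSSA.isLiveOnEntryDef(Clobber);
}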
-void MemorySSA::OptimizeUses::optimizeUses() { - -  // We perform a non-recursive top-down dominator tree walk -  struct StackInfo { -    const DomTreeNode *Node; -    DomTreeNode::const_iterator Iter; -  }; - -  SmallVector<MemoryAccess *, 16> VersionStack; -  SmallVector<StackInfo, 16> DomTreeWorklist; -  DenseMap<MemoryLocOrCall, MemlocStackInfo> LocStackInfo; -  VersionStack.push_back(MSSA->getLiveOnEntryDef()); - -  unsigned long StackEpoch = 1; -  unsigned long PopEpoch = 1; -  for (const auto *DomNode : depth_first(DT->getRootNode())) -    optimizeUsesInBlock(DomNode->getBlock(), StackEpoch, PopEpoch, VersionStack, -                        LocStackInfo); -} - -void MemorySSA::placePHINodes( -    const SmallPtrSetImpl<BasicBlock *> &DefiningBlocks, -    const DenseMap<const BasicBlock *, unsigned int> &BBNumbers) { -  // Determine where our MemoryPhi's should go -  ForwardIDFCalculator IDFs(*DT); -  IDFs.setDefiningBlocks(DefiningBlocks); -  SmallVector<BasicBlock *, 32> IDFBlocks; -  IDFs.calculate(IDFBlocks); - -  std::sort(IDFBlocks.begin(), IDFBlocks.end(), -            [&BBNumbers](const BasicBlock *A, const BasicBlock *B) { -              return BBNumbers.lookup(A) < BBNumbers.lookup(B); -            }); - -  // Now place MemoryPhi nodes. -  for (auto &BB : IDFBlocks) { -    // Insert phi node -    AccessList *Accesses = getOrCreateAccessList(BB); -    MemoryPhi *Phi = new MemoryPhi(BB->getContext(), BB, NextID++); -    ValueToMemoryAccess[BB] = Phi; -    // Phi's always are placed at the front of the block. -    Accesses->push_front(Phi); -  } -} - -void MemorySSA::buildMemorySSA() { -  // We create an access to represent "live on entry", for things like -  // arguments or users of globals, where the memory they use is defined before -  // the beginning of the function. We do not actually insert it into the IR. -  // We do not define a live on exit for the immediate uses, and thus our -  // semantics do *not* imply that something with no immediate uses can simply -  // be removed. -  BasicBlock &StartingPoint = F.getEntryBlock(); -  LiveOnEntryDef = make_unique<MemoryDef>(F.getContext(), nullptr, nullptr, -                                          &StartingPoint, NextID++); -  DenseMap<const BasicBlock *, unsigned int> BBNumbers; -  unsigned NextBBNum = 0; - -  // We maintain lists of memory accesses per-block, trading memory for time. We -  // could just look up the memory access for every possible instruction in the -  // stream. -  SmallPtrSet<BasicBlock *, 32> DefiningBlocks; -  SmallPtrSet<BasicBlock *, 32> DefUseBlocks; -  // Go through each block, figure out where defs occur, and chain together all -  // the accesses. -  for (BasicBlock &B : F) { -    BBNumbers[&B] = NextBBNum++; -    bool InsertIntoDef = false; -    AccessList *Accesses = nullptr; -    for (Instruction &I : B) { -      MemoryUseOrDef *MUD = createNewAccess(&I); -      if (!MUD) -        continue; -      InsertIntoDef |= isa<MemoryDef>(MUD); - -      if (!Accesses) -        Accesses = getOrCreateAccessList(&B); -      Accesses->push_back(MUD); -    } -    if (InsertIntoDef) -      DefiningBlocks.insert(&B); -    if (Accesses) -      DefUseBlocks.insert(&B); -  } -  placePHINodes(DefiningBlocks, BBNumbers); - -  // Now do regular SSA renaming on the MemoryDef/MemoryUse. Visited will get -  // filled in with all blocks. 
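buildMemorySSA above leaves one ordered access list per block: a MemoryPhi first, if the block needs one, followed by a MemoryUse or MemoryDef for each memory-affecting instruction. A sketch of how a client walks those lists (not part of this patch; it assumes the interfaces shown in this file, and the counting is purely illustrative):

// Illustrative sketch only; not part of the patch.
#include "llvm/IR/Function.h"
#include "llvm/Transforms/Utils/MemorySSA.h"
using namespace llvm;

static void countAccesses(MemorySSA &MSSA, Function &F) {
  for (BasicBlock &BB : F) {
    // Null when the block has no memory-affecting instructions and no phi.
    const MemorySSA::AccessList *AL = MSSA.getBlockAccesses(&BB);
    if (!AL)
      continue;
    unsigned Defs = 0, Uses = 0, Phis = 0;
    for (const MemoryAccess &MA : *AL) {
      if (isa<MemoryDef>(MA))
        ++Defs;
      else if (isa<MemoryUse>(MA))
        ++Uses;
      else
        ++Phis; // MemoryPhis sit at the front of the list
    }
    (void)Defs; (void)Uses; (void)Phis; // e.g. feed a statistic or debug dump
  }
}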
-  SmallPtrSet<BasicBlock *, 16> Visited; -  renamePass(DT->getRootNode(), LiveOnEntryDef.get(), Visited); - -  CachingWalker *Walker = getWalkerImpl(); - -  // We're doing a batch of updates; don't drop useful caches between them. -  Walker->setAutoResetWalker(false); -  OptimizeUses(this, Walker, AA, DT).optimizeUses(); -  Walker->setAutoResetWalker(true); -  Walker->resetClobberWalker(); - -  // Mark the uses in unreachable blocks as live on entry, so that they go -  // somewhere. -  for (auto &BB : F) -    if (!Visited.count(&BB)) -      markUnreachableAsLiveOnEntry(&BB); -} - -MemorySSAWalker *MemorySSA::getWalker() { return getWalkerImpl(); } - -MemorySSA::CachingWalker *MemorySSA::getWalkerImpl() { -  if (Walker) -    return Walker.get(); - -  Walker = make_unique<CachingWalker>(this, AA, DT); -  return Walker.get(); -} - -MemoryPhi *MemorySSA::createMemoryPhi(BasicBlock *BB) { -  assert(!getMemoryAccess(BB) && "MemoryPhi already exists for this BB"); -  AccessList *Accesses = getOrCreateAccessList(BB); -  MemoryPhi *Phi = new MemoryPhi(BB->getContext(), BB, NextID++); -  ValueToMemoryAccess[BB] = Phi; -  // Phi's always are placed at the front of the block. -  Accesses->push_front(Phi); -  BlockNumberingValid.erase(BB); -  return Phi; -} - -MemoryUseOrDef *MemorySSA::createDefinedAccess(Instruction *I, -                                               MemoryAccess *Definition) { -  assert(!isa<PHINode>(I) && "Cannot create a defined access for a PHI"); -  MemoryUseOrDef *NewAccess = createNewAccess(I); -  assert( -      NewAccess != nullptr && -      "Tried to create a memory access for a non-memory touching instruction"); -  NewAccess->setDefiningAccess(Definition); -  return NewAccess; -} - -MemoryAccess *MemorySSA::createMemoryAccessInBB(Instruction *I, -                                                MemoryAccess *Definition, -                                                const BasicBlock *BB, -                                                InsertionPlace Point) { -  MemoryUseOrDef *NewAccess = createDefinedAccess(I, Definition); -  auto *Accesses = getOrCreateAccessList(BB); -  if (Point == Beginning) { -    // It goes after any phi nodes -    auto AI = find_if( -        *Accesses, [](const MemoryAccess &MA) { return !isa<MemoryPhi>(MA); }); - -    Accesses->insert(AI, NewAccess); -  } else { -    Accesses->push_back(NewAccess); -  } -  BlockNumberingValid.erase(BB); -  return NewAccess; -} - -MemoryUseOrDef *MemorySSA::createMemoryAccessBefore(Instruction *I, -                                                    MemoryAccess *Definition, -                                                    MemoryUseOrDef *InsertPt) { -  assert(I->getParent() == InsertPt->getBlock() && -         "New and old access must be in the same block"); -  MemoryUseOrDef *NewAccess = createDefinedAccess(I, Definition); -  auto *Accesses = getOrCreateAccessList(InsertPt->getBlock()); -  Accesses->insert(AccessList::iterator(InsertPt), NewAccess); -  BlockNumberingValid.erase(InsertPt->getBlock()); -  return NewAccess; -} - -MemoryUseOrDef *MemorySSA::createMemoryAccessAfter(Instruction *I, -                                                   MemoryAccess *Definition, -                                                   MemoryAccess *InsertPt) { -  assert(I->getParent() == InsertPt->getBlock() && -         "New and old access must be in the same block"); -  MemoryUseOrDef *NewAccess = createDefinedAccess(I, Definition); -  auto *Accesses = getOrCreateAccessList(InsertPt->getBlock()); -  
Accesses->insertAfter(AccessList::iterator(InsertPt), NewAccess); -  BlockNumberingValid.erase(InsertPt->getBlock()); -  return NewAccess; -} - -void MemorySSA::spliceMemoryAccessAbove(MemoryDef *Where, -                                        MemoryUseOrDef *What) { -  assert(What != getLiveOnEntryDef() && -         Where != getLiveOnEntryDef() && "Can't splice (above) LOE."); -  assert(dominates(Where, What) && "Only upwards splices are permitted."); - -  if (Where == What) -    return; -  if (isa<MemoryDef>(What)) { -    // TODO: possibly use removeMemoryAccess' more efficient RAUW -    What->replaceAllUsesWith(What->getDefiningAccess()); -    What->setDefiningAccess(Where->getDefiningAccess()); -    Where->setDefiningAccess(What); -  } -  AccessList *Src = getWritableBlockAccesses(What->getBlock()); -  AccessList *Dest = getWritableBlockAccesses(Where->getBlock()); -  Dest->splice(AccessList::iterator(Where), *Src, What); - -  BlockNumberingValid.erase(What->getBlock()); -  if (What->getBlock() != Where->getBlock()) -    BlockNumberingValid.erase(Where->getBlock()); -} - -/// \brief Helper function to create new memory accesses -MemoryUseOrDef *MemorySSA::createNewAccess(Instruction *I) { -  // The assume intrinsic has a control dependency which we model by claiming -  // that it writes arbitrarily. Ignore that fake memory dependency here. -  // FIXME: Replace this special casing with a more accurate modelling of -  // assume's control dependency. -  if (IntrinsicInst *II = dyn_cast<IntrinsicInst>(I)) -    if (II->getIntrinsicID() == Intrinsic::assume) -      return nullptr; - -  // Find out what affect this instruction has on memory. -  ModRefInfo ModRef = AA->getModRefInfo(I); -  bool Def = bool(ModRef & MRI_Mod); -  bool Use = bool(ModRef & MRI_Ref); - -  // It's possible for an instruction to not modify memory at all. During -  // construction, we ignore them. -  if (!Def && !Use) -    return nullptr; - -  assert((Def || Use) && -         "Trying to create a memory access with a non-memory instruction"); - -  MemoryUseOrDef *MUD; -  if (Def) -    MUD = new MemoryDef(I->getContext(), nullptr, I, I->getParent(), NextID++); -  else -    MUD = new MemoryUse(I->getContext(), nullptr, I, I->getParent()); -  ValueToMemoryAccess[I] = MUD; -  return MUD; -} - -MemoryAccess *MemorySSA::findDominatingDef(BasicBlock *UseBlock, -                                           enum InsertionPlace Where) { -  // Handle the initial case -  if (Where == Beginning) -    // The only thing that could define us at the beginning is a phi node -    if (MemoryPhi *Phi = getMemoryAccess(UseBlock)) -      return Phi; - -  DomTreeNode *CurrNode = DT->getNode(UseBlock); -  // Need to be defined by our dominator -  if (Where == Beginning) -    CurrNode = CurrNode->getIDom(); -  Where = End; -  while (CurrNode) { -    auto It = PerBlockAccesses.find(CurrNode->getBlock()); -    if (It != PerBlockAccesses.end()) { -      auto &Accesses = It->second; -      for (MemoryAccess &RA : reverse(*Accesses)) { -        if (isa<MemoryDef>(RA) || isa<MemoryPhi>(RA)) -          return &RA; -      } -    } -    CurrNode = CurrNode->getIDom(); -  } -  return LiveOnEntryDef.get(); -} - -/// \brief Returns true if \p Replacer dominates \p Replacee . 
-bool MemorySSA::dominatesUse(const MemoryAccess *Replacer, -                             const MemoryAccess *Replacee) const { -  if (isa<MemoryUseOrDef>(Replacee)) -    return DT->dominates(Replacer->getBlock(), Replacee->getBlock()); -  const auto *MP = cast<MemoryPhi>(Replacee); -  // For a phi node, the use occurs in the predecessor block of the phi node. -  // Since we may occur multiple times in the phi node, we have to check each -  // operand to ensure Replacer dominates each operand where Replacee occurs. -  for (const Use &Arg : MP->operands()) { -    if (Arg.get() != Replacee && -        !DT->dominates(Replacer->getBlock(), MP->getIncomingBlock(Arg))) -      return false; -  } -  return true; -} - -/// \brief If all arguments of a MemoryPHI are defined by the same incoming -/// argument, return that argument. -static MemoryAccess *onlySingleValue(MemoryPhi *MP) { -  MemoryAccess *MA = nullptr; - -  for (auto &Arg : MP->operands()) { -    if (!MA) -      MA = cast<MemoryAccess>(Arg); -    else if (MA != Arg) -      return nullptr; -  } -  return MA; -} - -/// \brief Properly remove \p MA from all of MemorySSA's lookup tables. -/// -/// Because of the way the intrusive list and use lists work, it is important to -/// do removal in the right order. -void MemorySSA::removeFromLookups(MemoryAccess *MA) { -  assert(MA->use_empty() && -         "Trying to remove memory access that still has uses"); -  BlockNumbering.erase(MA); -  if (MemoryUseOrDef *MUD = dyn_cast<MemoryUseOrDef>(MA)) -    MUD->setDefiningAccess(nullptr); -  // Invalidate our walker's cache if necessary -  if (!isa<MemoryUse>(MA)) -    Walker->invalidateInfo(MA); -  // The call below to erase will destroy MA, so we can't change the order we -  // are doing things here -  Value *MemoryInst; -  if (MemoryUseOrDef *MUD = dyn_cast<MemoryUseOrDef>(MA)) { -    MemoryInst = MUD->getMemoryInst(); -  } else { -    MemoryInst = MA->getBlock(); -  } -  auto VMA = ValueToMemoryAccess.find(MemoryInst); -  if (VMA->second == MA) -    ValueToMemoryAccess.erase(VMA); - -  auto AccessIt = PerBlockAccesses.find(MA->getBlock()); -  std::unique_ptr<AccessList> &Accesses = AccessIt->second; -  Accesses->erase(MA); -  if (Accesses->empty()) -    PerBlockAccesses.erase(AccessIt); -} - -void MemorySSA::removeMemoryAccess(MemoryAccess *MA) { -  assert(!isLiveOnEntryDef(MA) && "Trying to remove the live on entry def"); -  // We can only delete phi nodes if they have no uses, or we can replace all -  // uses with a single definition. -  MemoryAccess *NewDefTarget = nullptr; -  if (MemoryPhi *MP = dyn_cast<MemoryPhi>(MA)) { -    // Note that it is sufficient to know that all edges of the phi node have -    // the same argument.  If they do, by the definition of dominance frontiers -    // (which we used to place this phi), that argument must dominate this phi, -    // and thus, must dominate the phi's uses, and so we will not hit the assert -    // below. -    NewDefTarget = onlySingleValue(MP); -    assert((NewDefTarget || MP->use_empty()) && -           "We can't delete this memory phi"); -  } else { -    NewDefTarget = cast<MemoryUseOrDef>(MA)->getDefiningAccess(); -  } - -  // Re-point the uses at our defining access -  if (!MA->use_empty()) { -    // Reset optimized on users of this store, and reset the uses. -    // A few notes: -    // 1. This is a slightly modified version of RAUW to avoid walking the -    // uses twice here. -    // 2. 
If we wanted to be complete, we would have to reset the optimized -    // flags on users of phi nodes if doing the below makes a phi node have all -    // the same arguments. Instead, we prefer users to removeMemoryAccess those -    // phi nodes, because doing it here would be N^3. -    if (MA->hasValueHandle()) -      ValueHandleBase::ValueIsRAUWd(MA, NewDefTarget); -    // Note: We assume MemorySSA is not used in metadata since it's not really -    // part of the IR. - -    while (!MA->use_empty()) { -      Use &U = *MA->use_begin(); -      if (MemoryUse *MU = dyn_cast<MemoryUse>(U.getUser())) -        MU->resetOptimized(); -      U.set(NewDefTarget); -    } -  } - -  // The call below to erase will destroy MA, so we can't change the order we -  // are doing things here -  removeFromLookups(MA); -} - -void MemorySSA::print(raw_ostream &OS) const { -  MemorySSAAnnotatedWriter Writer(this); -  F.print(OS, &Writer); -} - -void MemorySSA::dump() const { -  MemorySSAAnnotatedWriter Writer(this); -  F.print(dbgs(), &Writer); -} - -void MemorySSA::verifyMemorySSA() const { -  verifyDefUses(F); -  verifyDomination(F); -  verifyOrdering(F); -  Walker->verify(this); -} - -/// \brief Verify that the order and existence of MemoryAccesses matches the -/// order and existence of memory affecting instructions. -void MemorySSA::verifyOrdering(Function &F) const { -  // Walk all the blocks, comparing what the lookups think and what the access -  // lists think, as well as the order in the blocks vs the order in the access -  // lists. -  SmallVector<MemoryAccess *, 32> ActualAccesses; -  for (BasicBlock &B : F) { -    const AccessList *AL = getBlockAccesses(&B); -    MemoryAccess *Phi = getMemoryAccess(&B); -    if (Phi) -      ActualAccesses.push_back(Phi); -    for (Instruction &I : B) { -      MemoryAccess *MA = getMemoryAccess(&I); -      assert((!MA || AL) && "We have memory affecting instructions " -                            "in this block but they are not in the " -                            "access list"); -      if (MA) -        ActualAccesses.push_back(MA); -    } -    // Either we hit the assert, really have no accesses, or we have both -    // accesses and an access list -    if (!AL) -      continue; -    assert(AL->size() == ActualAccesses.size() && -           "We don't have the same number of accesses in the block as on the " -           "access list"); -    auto ALI = AL->begin(); -    auto AAI = ActualAccesses.begin(); -    while (ALI != AL->end() && AAI != ActualAccesses.end()) { -      assert(&*ALI == *AAI && "Not the same accesses in the same order"); -      ++ALI; -      ++AAI; -    } -    ActualAccesses.clear(); -  } -} - -/// \brief Verify the domination properties of MemorySSA by checking that each -/// definition dominates all of its uses. -void MemorySSA::verifyDomination(Function &F) const { -#ifndef NDEBUG -  for (BasicBlock &B : F) { -    // Phi nodes are attached to basic blocks -    if (MemoryPhi *MP = getMemoryAccess(&B)) -      for (const Use &U : MP->uses()) -        assert(dominates(MP, U) && "Memory PHI does not dominate it's uses"); - -    for (Instruction &I : B) { -      MemoryAccess *MD = dyn_cast_or_null<MemoryDef>(getMemoryAccess(&I)); -      if (!MD) -        continue; - -      for (const Use &U : MD->uses()) -        assert(dominates(MD, U) && "Memory Def does not dominate it's uses"); -    } -  } -#endif -} - -/// \brief Verify the def-use lists in MemorySSA, by verifying that \p Use -/// appears in the use list of \p Def. 
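The verifiers above are intended for code that mutates MemorySSA in place, typically behind a flag, and since the checks are assertion-based they only fire in builds with assertions enabled. A usage sketch (not part of this patch; the flag name is hypothetical, the calls are the ones defined in this file):

// Illustrative sketch only; not part of the patch.
#include "llvm/Transforms/Utils/MemorySSA.h"
using namespace llvm;

static void afterManualUpdate(MemorySSA &MSSA, bool VerifyEveryStep) {
  // Re-checks def-use lists, domination, and per-block ordering, plus the
  // walker's own invariants. Cheap to call, but only asserts in +Asserts builds.
  if (VerifyEveryStep)
    MSSA.verifyMemorySSA();
#ifndef NDEBUG
  MSSA.dump(); // prints the function annotated with MemoryDef/Use/Phi lines
#endif
}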
- -void MemorySSA::verifyUseInDefs(MemoryAccess *Def, MemoryAccess *Use) const { -#ifndef NDEBUG -  // The live on entry use may cause us to get a NULL def here -  if (!Def) -    assert(isLiveOnEntryDef(Use) && -           "Null def but use not point to live on entry def"); -  else -    assert(is_contained(Def->users(), Use) && -           "Did not find use in def's use list"); -#endif -} - -/// \brief Verify the immediate use information, by walking all the memory -/// accesses and verifying that, for each use, it appears in the -/// appropriate def's use list -void MemorySSA::verifyDefUses(Function &F) const { -  for (BasicBlock &B : F) { -    // Phi nodes are attached to basic blocks -    if (MemoryPhi *Phi = getMemoryAccess(&B)) { -      assert(Phi->getNumOperands() == static_cast<unsigned>(std::distance( -                                          pred_begin(&B), pred_end(&B))) && -             "Incomplete MemoryPhi Node"); -      for (unsigned I = 0, E = Phi->getNumIncomingValues(); I != E; ++I) -        verifyUseInDefs(Phi->getIncomingValue(I), Phi); -    } - -    for (Instruction &I : B) { -      if (MemoryUseOrDef *MA = getMemoryAccess(&I)) { -        verifyUseInDefs(MA->getDefiningAccess(), MA); -      } -    } -  } -} - -MemoryUseOrDef *MemorySSA::getMemoryAccess(const Instruction *I) const { -  return cast_or_null<MemoryUseOrDef>(ValueToMemoryAccess.lookup(I)); -} - -MemoryPhi *MemorySSA::getMemoryAccess(const BasicBlock *BB) const { -  return cast_or_null<MemoryPhi>(ValueToMemoryAccess.lookup(cast<Value>(BB))); -} - -/// Perform a local numbering on blocks so that instruction ordering can be -/// determined in constant time. -/// TODO: We currently just number in order.  If we numbered by N, we could -/// allow at least N-1 sequences of insertBefore or insertAfter (and at least -/// log2(N) sequences of mixed before and after) without needing to invalidate -/// the numbering. -void MemorySSA::renumberBlock(const BasicBlock *B) const { -  // The pre-increment ensures the numbers really start at 1. -  unsigned long CurrentNumber = 0; -  const AccessList *AL = getBlockAccesses(B); -  assert(AL != nullptr && "Asking to renumber an empty block"); -  for (const auto &I : *AL) -    BlockNumbering[&I] = ++CurrentNumber; -  BlockNumberingValid.insert(B); -} - -/// \brief Determine, for two memory accesses in the same block, -/// whether \p Dominator dominates \p Dominatee. -/// \returns True if \p Dominator dominates \p Dominatee. -bool MemorySSA::locallyDominates(const MemoryAccess *Dominator, -                                 const MemoryAccess *Dominatee) const { - -  const BasicBlock *DominatorBlock = Dominator->getBlock(); - -  assert((DominatorBlock == Dominatee->getBlock()) && -         "Asking for local domination when accesses are in different blocks!"); -  // A node dominates itself. -  if (Dominatee == Dominator) -    return true; - -  // When Dominatee is defined on function entry, it is not dominated by another -  // memory access. -  if (isLiveOnEntryDef(Dominatee)) -    return false; - -  // When Dominator is defined on function entry, it dominates the other memory -  // access. 
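renumberBlock and locallyDominates above answer "which of two accesses in the same block comes first" in constant time by lazily numbering the block's access list and reusing that numbering until the block changes. A self-contained sketch of the same scheme over a std::list (not part of this patch; the owner must erase a list from Valid whenever it modifies it, just as BlockNumberingValid is erased above):

// Illustrative sketch only; not part of the patch.
#include <list>
#include <map>
#include <set>

struct LocalNumbering {
  std::map<const int *, unsigned long> Number;  // element -> 1-based position
  std::set<const std::list<int> *> Valid;       // lists with a current numbering

  void renumber(const std::list<int> &L) {
    unsigned long Current = 0;
    for (const int &I : L)
      Number[&I] = ++Current; // pre-increment so numbers really start at 1
    Valid.insert(&L);
  }

  // Does A come at or before B within L? A and B must be elements of L.
  bool locallyPrecedes(const std::list<int> &L, const int &A, const int &B) {
    if (&A == &B)
      return true; // an element "dominates" itself, as in the original
    if (!Valid.count(&L))
      renumber(L); // number lazily, on first query after a change
    return Number.at(&A) < Number.at(&B);
  }
};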
-  if (isLiveOnEntryDef(Dominator)) -    return true; - -  if (!BlockNumberingValid.count(DominatorBlock)) -    renumberBlock(DominatorBlock); - -  unsigned long DominatorNum = BlockNumbering.lookup(Dominator); -  // All numbers start with 1 -  assert(DominatorNum != 0 && "Block was not numbered properly"); -  unsigned long DominateeNum = BlockNumbering.lookup(Dominatee); -  assert(DominateeNum != 0 && "Block was not numbered properly"); -  return DominatorNum < DominateeNum; -} - -bool MemorySSA::dominates(const MemoryAccess *Dominator, -                          const MemoryAccess *Dominatee) const { -  if (Dominator == Dominatee) -    return true; - -  if (isLiveOnEntryDef(Dominatee)) -    return false; - -  if (Dominator->getBlock() != Dominatee->getBlock()) -    return DT->dominates(Dominator->getBlock(), Dominatee->getBlock()); -  return locallyDominates(Dominator, Dominatee); -} - -bool MemorySSA::dominates(const MemoryAccess *Dominator, -                          const Use &Dominatee) const { -  if (MemoryPhi *MP = dyn_cast<MemoryPhi>(Dominatee.getUser())) { -    BasicBlock *UseBB = MP->getIncomingBlock(Dominatee); -    // The def must dominate the incoming block of the phi. -    if (UseBB != Dominator->getBlock()) -      return DT->dominates(Dominator->getBlock(), UseBB); -    // If the UseBB and the DefBB are the same, compare locally. -    return locallyDominates(Dominator, cast<MemoryAccess>(Dominatee)); -  } -  // If it's not a PHI node use, the normal dominates can already handle it. -  return dominates(Dominator, cast<MemoryAccess>(Dominatee.getUser())); -} - -const static char LiveOnEntryStr[] = "liveOnEntry"; - -void MemoryDef::print(raw_ostream &OS) const { -  MemoryAccess *UO = getDefiningAccess(); - -  OS << getID() << " = MemoryDef("; -  if (UO && UO->getID()) -    OS << UO->getID(); -  else -    OS << LiveOnEntryStr; -  OS << ')'; -} - -void MemoryPhi::print(raw_ostream &OS) const { -  bool First = true; -  OS << getID() << " = MemoryPhi("; -  for (const auto &Op : operands()) { -    BasicBlock *BB = getIncomingBlock(Op); -    MemoryAccess *MA = cast<MemoryAccess>(Op); -    if (!First) -      OS << ','; -    else -      First = false; - -    OS << '{'; -    if (BB->hasName()) -      OS << BB->getName(); -    else -      BB->printAsOperand(OS, false); -    OS << ','; -    if (unsigned ID = MA->getID()) -      OS << ID; -    else -      OS << LiveOnEntryStr; -    OS << '}'; -  } -  OS << ')'; -} - -MemoryAccess::~MemoryAccess() {} - -void MemoryUse::print(raw_ostream &OS) const { -  MemoryAccess *UO = getDefiningAccess(); -  OS << "MemoryUse("; -  if (UO && UO->getID()) -    OS << UO->getID(); -  else -    OS << LiveOnEntryStr; -  OS << ')'; -} - -void MemoryAccess::dump() const { -  print(dbgs()); -  dbgs() << "\n"; -} - -char MemorySSAPrinterLegacyPass::ID = 0; - -MemorySSAPrinterLegacyPass::MemorySSAPrinterLegacyPass() : FunctionPass(ID) { -  initializeMemorySSAPrinterLegacyPassPass(*PassRegistry::getPassRegistry()); -} - -void MemorySSAPrinterLegacyPass::getAnalysisUsage(AnalysisUsage &AU) const { -  AU.setPreservesAll(); -  AU.addRequired<MemorySSAWrapperPass>(); -  AU.addPreserved<MemorySSAWrapperPass>(); -} - -bool MemorySSAPrinterLegacyPass::runOnFunction(Function &F) { -  auto &MSSA = getAnalysis<MemorySSAWrapperPass>().getMSSA(); -  MSSA.print(dbgs()); -  if (VerifyMemorySSA) -    MSSA.verifyMemorySSA(); -  return false; -} - -AnalysisKey MemorySSAAnalysis::Key; - -MemorySSAAnalysis::Result MemorySSAAnalysis::run(Function &F, -                                  
               FunctionAnalysisManager &AM) { -  auto &DT = AM.getResult<DominatorTreeAnalysis>(F); -  auto &AA = AM.getResult<AAManager>(F); -  return MemorySSAAnalysis::Result(make_unique<MemorySSA>(F, &AA, &DT)); -} - -PreservedAnalyses MemorySSAPrinterPass::run(Function &F, -                                            FunctionAnalysisManager &AM) { -  OS << "MemorySSA for function: " << F.getName() << "\n"; -  AM.getResult<MemorySSAAnalysis>(F).getMSSA().print(OS); - -  return PreservedAnalyses::all(); -} - -PreservedAnalyses MemorySSAVerifierPass::run(Function &F, -                                             FunctionAnalysisManager &AM) { -  AM.getResult<MemorySSAAnalysis>(F).getMSSA().verifyMemorySSA(); - -  return PreservedAnalyses::all(); -} - -char MemorySSAWrapperPass::ID = 0; - -MemorySSAWrapperPass::MemorySSAWrapperPass() : FunctionPass(ID) { -  initializeMemorySSAWrapperPassPass(*PassRegistry::getPassRegistry()); -} - -void MemorySSAWrapperPass::releaseMemory() { MSSA.reset(); } - -void MemorySSAWrapperPass::getAnalysisUsage(AnalysisUsage &AU) const { -  AU.setPreservesAll(); -  AU.addRequiredTransitive<DominatorTreeWrapperPass>(); -  AU.addRequiredTransitive<AAResultsWrapperPass>(); -} - -bool MemorySSAWrapperPass::runOnFunction(Function &F) { -  auto &DT = getAnalysis<DominatorTreeWrapperPass>().getDomTree(); -  auto &AA = getAnalysis<AAResultsWrapperPass>().getAAResults(); -  MSSA.reset(new MemorySSA(F, &AA, &DT)); -  return false; -} - -void MemorySSAWrapperPass::verifyAnalysis() const { MSSA->verifyMemorySSA(); } - -void MemorySSAWrapperPass::print(raw_ostream &OS, const Module *M) const { -  MSSA->print(OS); -} - -MemorySSAWalker::MemorySSAWalker(MemorySSA *M) : MSSA(M) {} - -MemorySSA::CachingWalker::CachingWalker(MemorySSA *M, AliasAnalysis *A, -                                        DominatorTree *D) -    : MemorySSAWalker(M), Walker(*M, *A, *D, Cache), AutoResetWalker(true) {} - -MemorySSA::CachingWalker::~CachingWalker() {} - -void MemorySSA::CachingWalker::invalidateInfo(MemoryAccess *MA) { -  // TODO: We can do much better cache invalidation with differently stored -  // caches.  For now, for MemoryUses, we simply remove them -  // from the cache, and kill the entire call/non-call cache for everything -  // else.  The problem is for phis or defs, currently we'd need to follow use -  // chains down and invalidate anything below us in the chain that currently -  // terminates at this access. - -  // See if this is a MemoryUse, if so, just remove the cached info. MemoryUse -  // is by definition never a barrier, so nothing in the cache could point to -  // this use. In that case, we only need invalidate the info for the use -  // itself. - -  if (MemoryUse *MU = dyn_cast<MemoryUse>(MA)) { -    UpwardsMemoryQuery Q(MU->getMemoryInst(), MU); -    Cache.remove(MU, Q.StartingLoc, Q.IsCall); -    MU->resetOptimized(); -  } else { -    // If it is not a use, the best we can do right now is destroy the cache. -    Cache.clear(); -  } - -#ifdef EXPENSIVE_CHECKS -  verifyRemoved(MA); -#endif -} - -/// \brief Walk the use-def chains starting at \p MA and find -/// the MemoryAccess that actually clobbers Loc. 
-/// -/// \returns our clobbering memory access -MemoryAccess *MemorySSA::CachingWalker::getClobberingMemoryAccess( -    MemoryAccess *StartingAccess, UpwardsMemoryQuery &Q) { -  MemoryAccess *New = Walker.findClobber(StartingAccess, Q); -#ifdef EXPENSIVE_CHECKS -  MemoryAccess *NewNoCache = -      Walker.findClobber(StartingAccess, Q, /*UseWalkerCache=*/false); -  assert(NewNoCache == New && "Cache made us hand back a different result?"); -#endif -  if (AutoResetWalker) -    resetClobberWalker(); -  return New; -} - -MemoryAccess *MemorySSA::CachingWalker::getClobberingMemoryAccess( -    MemoryAccess *StartingAccess, const MemoryLocation &Loc) { -  if (isa<MemoryPhi>(StartingAccess)) -    return StartingAccess; - -  auto *StartingUseOrDef = cast<MemoryUseOrDef>(StartingAccess); -  if (MSSA->isLiveOnEntryDef(StartingUseOrDef)) -    return StartingUseOrDef; - -  Instruction *I = StartingUseOrDef->getMemoryInst(); - -  // Conservatively, fences are always clobbers, so don't perform the walk if we -  // hit a fence. -  if (!ImmutableCallSite(I) && I->isFenceLike()) -    return StartingUseOrDef; - -  UpwardsMemoryQuery Q; -  Q.OriginalAccess = StartingUseOrDef; -  Q.StartingLoc = Loc; -  Q.Inst = I; -  Q.IsCall = false; - -  if (auto *CacheResult = Cache.lookup(StartingUseOrDef, Loc, Q.IsCall)) -    return CacheResult; - -  // Unlike the other function, do not walk to the def of a def, because we are -  // handed something we already believe is the clobbering access. -  MemoryAccess *DefiningAccess = isa<MemoryUse>(StartingUseOrDef) -                                     ? StartingUseOrDef->getDefiningAccess() -                                     : StartingUseOrDef; - -  MemoryAccess *Clobber = getClobberingMemoryAccess(DefiningAccess, Q); -  DEBUG(dbgs() << "Starting Memory SSA clobber for " << *I << " is "); -  DEBUG(dbgs() << *StartingUseOrDef << "\n"); -  DEBUG(dbgs() << "Final Memory SSA clobber for " << *I << " is "); -  DEBUG(dbgs() << *Clobber << "\n"); -  return Clobber; -} - -MemoryAccess * -MemorySSA::CachingWalker::getClobberingMemoryAccess(MemoryAccess *MA) { -  auto *StartingAccess = dyn_cast<MemoryUseOrDef>(MA); -  // If this is a MemoryPhi, we can't do anything. -  if (!StartingAccess) -    return MA; - -  // If this is an already optimized use or def, return the optimized result. -  // Note: Currently, we do not store the optimized def result because we'd need -  // a separate field, since we can't use it as the defining access. -  if (MemoryUse *MU = dyn_cast<MemoryUse>(StartingAccess)) -    if (MU->isOptimized()) -      return MU->getDefiningAccess(); - -  const Instruction *I = StartingAccess->getMemoryInst(); -  UpwardsMemoryQuery Q(I, StartingAccess); -  // We can't sanely do anything with a fences, they conservatively -  // clobber all memory, and have no locations to get pointers from to -  // try to disambiguate. 
-  if (!Q.IsCall && I->isFenceLike()) -    return StartingAccess; - -  if (auto *CacheResult = Cache.lookup(StartingAccess, Q.StartingLoc, Q.IsCall)) -    return CacheResult; - -  if (isUseTriviallyOptimizableToLiveOnEntry(*MSSA->AA, I)) { -    MemoryAccess *LiveOnEntry = MSSA->getLiveOnEntryDef(); -    Cache.insert(StartingAccess, LiveOnEntry, Q.StartingLoc, Q.IsCall); -    if (MemoryUse *MU = dyn_cast<MemoryUse>(StartingAccess)) -      MU->setDefiningAccess(LiveOnEntry, true); -    return LiveOnEntry; -  } - -  // Start with the thing we already think clobbers this location -  MemoryAccess *DefiningAccess = StartingAccess->getDefiningAccess(); - -  // At this point, DefiningAccess may be the live on entry def. -  // If it is, we will not get a better result. -  if (MSSA->isLiveOnEntryDef(DefiningAccess)) -    return DefiningAccess; - -  MemoryAccess *Result = getClobberingMemoryAccess(DefiningAccess, Q); -  DEBUG(dbgs() << "Starting Memory SSA clobber for " << *I << " is "); -  DEBUG(dbgs() << *DefiningAccess << "\n"); -  DEBUG(dbgs() << "Final Memory SSA clobber for " << *I << " is "); -  DEBUG(dbgs() << *Result << "\n"); -  if (MemoryUse *MU = dyn_cast<MemoryUse>(StartingAccess)) -    MU->setDefiningAccess(Result, true); - -  return Result; -} - -// Verify that MA doesn't exist in any of the caches. -void MemorySSA::CachingWalker::verifyRemoved(MemoryAccess *MA) { -  assert(!Cache.contains(MA) && "Found removed MemoryAccess in cache."); -} - -MemoryAccess * -DoNothingMemorySSAWalker::getClobberingMemoryAccess(MemoryAccess *MA) { -  if (auto *Use = dyn_cast<MemoryUseOrDef>(MA)) -    return Use->getDefiningAccess(); -  return MA; -} - -MemoryAccess *DoNothingMemorySSAWalker::getClobberingMemoryAccess( -    MemoryAccess *StartingAccess, const MemoryLocation &) { -  if (auto *Use = dyn_cast<MemoryUseOrDef>(StartingAccess)) -    return Use->getDefiningAccess(); -  return StartingAccess; -} -} // namespace llvm diff --git a/lib/Transforms/Utils/MetaRenamer.cpp b/lib/Transforms/Utils/MetaRenamer.cpp index c999bd008fef..481c6aa29c3a 100644 --- a/lib/Transforms/Utils/MetaRenamer.cpp +++ b/lib/Transforms/Utils/MetaRenamer.cpp @@ -16,6 +16,7 @@  #include "llvm/Transforms/IPO.h"  #include "llvm/ADT/STLExtras.h"  #include "llvm/ADT/SmallString.h" +#include "llvm/Analysis/TargetLibraryInfo.h"  #include "llvm/IR/DerivedTypes.h"  #include "llvm/IR/Function.h"  #include "llvm/IR/Module.h" @@ -67,6 +68,7 @@ namespace {      }      void getAnalysisUsage(AnalysisUsage &AU) const override { +      AU.addRequired<TargetLibraryInfoWrapperPass>();        AU.setPreservesAll();      } @@ -110,9 +112,15 @@ namespace {        }        // Rename all functions +      const TargetLibraryInfo &TLI = +          getAnalysis<TargetLibraryInfoWrapperPass>().getTLI();        for (auto &F : M) {          StringRef Name = F.getName(); -        if (Name.startswith("llvm.") || (!Name.empty() && Name[0] == 1)) +        LibFunc Tmp; +        // Leave library functions alone because their presence or absence could +        // affect the behavior of other passes. 
+        if (Name.startswith("llvm.") || (!Name.empty() && Name[0] == 1) || +            TLI.getLibFunc(F, Tmp))            continue;          F.setName(renamer.newName()); @@ -139,8 +147,11 @@ namespace {  }  char MetaRenamer::ID = 0; -INITIALIZE_PASS(MetaRenamer, "metarenamer",  -                "Assign new names to everything", false, false) +INITIALIZE_PASS_BEGIN(MetaRenamer, "metarenamer", +                      "Assign new names to everything", false, false) +INITIALIZE_PASS_DEPENDENCY(TargetLibraryInfoWrapperPass) +INITIALIZE_PASS_END(MetaRenamer, "metarenamer", +                    "Assign new names to everything", false, false)  //===----------------------------------------------------------------------===//  //  // MetaRenamer - Rename everything with metasyntactic names. diff --git a/lib/Transforms/Utils/ModuleUtils.cpp b/lib/Transforms/Utils/ModuleUtils.cpp index 0d623df77a67..dbe42c201dd4 100644 --- a/lib/Transforms/Utils/ModuleUtils.cpp +++ b/lib/Transforms/Utils/ModuleUtils.cpp @@ -130,13 +130,25 @@ void llvm::appendToCompilerUsed(Module &M, ArrayRef<GlobalValue *> Values) {  Function *llvm::checkSanitizerInterfaceFunction(Constant *FuncOrBitcast) {    if (isa<Function>(FuncOrBitcast))      return cast<Function>(FuncOrBitcast); -  FuncOrBitcast->dump(); +  FuncOrBitcast->print(errs()); +  errs() << '\n';    std::string Err;    raw_string_ostream Stream(Err);    Stream << "Sanitizer interface function redefined: " << *FuncOrBitcast;    report_fatal_error(Err);  } +Function *llvm::declareSanitizerInitFunction(Module &M, StringRef InitName, +                                             ArrayRef<Type *> InitArgTypes) { +  assert(!InitName.empty() && "Expected init function name"); +  Function *F = checkSanitizerInterfaceFunction(M.getOrInsertFunction( +      InitName, +      FunctionType::get(Type::getVoidTy(M.getContext()), InitArgTypes, false), +      AttributeList())); +  F->setLinkage(Function::ExternalLinkage); +  return F; +} +  std::pair<Function *, Function *> llvm::createSanitizerCtorAndInitFunctions(      Module &M, StringRef CtorName, StringRef InitName,      ArrayRef<Type *> InitArgTypes, ArrayRef<Value *> InitArgs, @@ -144,22 +156,19 @@ std::pair<Function *, Function *> llvm::createSanitizerCtorAndInitFunctions(    assert(!InitName.empty() && "Expected init function name");    assert(InitArgs.size() == InitArgTypes.size() &&           "Sanitizer's init function expects different number of arguments"); +  Function *InitFunction = +      declareSanitizerInitFunction(M, InitName, InitArgTypes);    Function *Ctor = Function::Create(        FunctionType::get(Type::getVoidTy(M.getContext()), false),        GlobalValue::InternalLinkage, CtorName, &M);    BasicBlock *CtorBB = BasicBlock::Create(M.getContext(), "", Ctor);    IRBuilder<> IRB(ReturnInst::Create(M.getContext(), CtorBB)); -  Function *InitFunction = -      checkSanitizerInterfaceFunction(M.getOrInsertFunction( -          InitName, FunctionType::get(IRB.getVoidTy(), InitArgTypes, false), -          AttributeSet())); -  InitFunction->setLinkage(Function::ExternalLinkage);    IRB.CreateCall(InitFunction, InitArgs);    if (!VersionCheckName.empty()) {      Function *VersionCheckFunction =          checkSanitizerInterfaceFunction(M.getOrInsertFunction(              VersionCheckName, FunctionType::get(IRB.getVoidTy(), {}, false), -            AttributeSet())); +            AttributeList()));      IRB.CreateCall(VersionCheckFunction, {});    }    return std::make_pair(Ctor, InitFunction); diff --git 
a/lib/Transforms/Utils/PredicateInfo.cpp b/lib/Transforms/Utils/PredicateInfo.cpp new file mode 100644 index 000000000000..8877aeafecde --- /dev/null +++ b/lib/Transforms/Utils/PredicateInfo.cpp @@ -0,0 +1,782 @@ +//===-- PredicateInfo.cpp - PredicateInfo Builder--------------------===// +// +//                     The LLVM Compiler Infrastructure +// +// This file is distributed under the University of Illinois Open Source +// License. See LICENSE.TXT for details. +// +//===----------------------------------------------------------------===// +// +// This file implements the PredicateInfo class. +// +//===----------------------------------------------------------------===// + +#include "llvm/Transforms/Utils/PredicateInfo.h" +#include "llvm/ADT/DenseMap.h" +#include "llvm/ADT/DepthFirstIterator.h" +#include "llvm/ADT/STLExtras.h" +#include "llvm/ADT/SmallPtrSet.h" +#include "llvm/ADT/Statistic.h" +#include "llvm/Analysis/AssumptionCache.h" +#include "llvm/Analysis/CFG.h" +#include "llvm/Analysis/OrderedBasicBlock.h" +#include "llvm/IR/AssemblyAnnotationWriter.h" +#include "llvm/IR/DataLayout.h" +#include "llvm/IR/Dominators.h" +#include "llvm/IR/GlobalVariable.h" +#include "llvm/IR/IRBuilder.h" +#include "llvm/IR/IntrinsicInst.h" +#include "llvm/IR/LLVMContext.h" +#include "llvm/IR/Metadata.h" +#include "llvm/IR/Module.h" +#include "llvm/IR/PatternMatch.h" +#include "llvm/Support/Debug.h" +#include "llvm/Support/DebugCounter.h" +#include "llvm/Support/FormattedStream.h" +#include "llvm/Transforms/Scalar.h" +#include <algorithm> +#define DEBUG_TYPE "predicateinfo" +using namespace llvm; +using namespace PatternMatch; +using namespace llvm::PredicateInfoClasses; + +INITIALIZE_PASS_BEGIN(PredicateInfoPrinterLegacyPass, "print-predicateinfo", +                      "PredicateInfo Printer", false, false) +INITIALIZE_PASS_DEPENDENCY(DominatorTreeWrapperPass) +INITIALIZE_PASS_DEPENDENCY(AssumptionCacheTracker) +INITIALIZE_PASS_END(PredicateInfoPrinterLegacyPass, "print-predicateinfo", +                    "PredicateInfo Printer", false, false) +static cl::opt<bool> VerifyPredicateInfo( +    "verify-predicateinfo", cl::init(false), cl::Hidden, +    cl::desc("Verify PredicateInfo in legacy printer pass.")); +namespace { +DEBUG_COUNTER(RenameCounter, "predicateinfo-rename", +              "Controls which variables are renamed with predicateinfo") +// Given a predicate info that is a type of branching terminator, get the +// branching block. +const BasicBlock *getBranchBlock(const PredicateBase *PB) { +  assert(isa<PredicateWithEdge>(PB) && +         "Only branches and switches should have PHIOnly defs that " +         "require branch blocks."); +  return cast<PredicateWithEdge>(PB)->From; +} + +// Given a predicate info that is a type of branching terminator, get the +// branching terminator. 
+static Instruction *getBranchTerminator(const PredicateBase *PB) { +  assert(isa<PredicateWithEdge>(PB) && +         "Not a predicate info type we know how to get a terminator from."); +  return cast<PredicateWithEdge>(PB)->From->getTerminator(); +} + +// Given a predicate info that is a type of branching terminator, get the +// edge this predicate info represents +const std::pair<BasicBlock *, BasicBlock *> +getBlockEdge(const PredicateBase *PB) { +  assert(isa<PredicateWithEdge>(PB) && +         "Not a predicate info type we know how to get an edge from."); +  const auto *PEdge = cast<PredicateWithEdge>(PB); +  return std::make_pair(PEdge->From, PEdge->To); +} +} + +namespace llvm { +namespace PredicateInfoClasses { +enum LocalNum { +  // Operations that must appear first in the block. +  LN_First, +  // Operations that are somewhere in the middle of the block, and are sorted on +  // demand. +  LN_Middle, +  // Operations that must appear last in a block, like successor phi node uses. +  LN_Last +}; + +// Associate global and local DFS info with defs and uses, so we can sort them +// into a global domination ordering. +struct ValueDFS { +  int DFSIn = 0; +  int DFSOut = 0; +  unsigned int LocalNum = LN_Middle; +  // Only one of Def or Use will be set. +  Value *Def = nullptr; +  Use *U = nullptr; +  // Neither PInfo nor EdgeOnly participates in the ordering. +  PredicateBase *PInfo = nullptr; +  bool EdgeOnly = false; +}; + +// This compares ValueDFS structures, creating OrderedBasicBlocks where +// necessary to compare uses/defs in the same block.  Doing so allows us to walk +// the minimum number of instructions necessary to compute our def/use ordering. +struct ValueDFS_Compare { +  DenseMap<const BasicBlock *, std::unique_ptr<OrderedBasicBlock>> &OBBMap; +  ValueDFS_Compare( +      DenseMap<const BasicBlock *, std::unique_ptr<OrderedBasicBlock>> &OBBMap) +      : OBBMap(OBBMap) {} +  bool operator()(const ValueDFS &A, const ValueDFS &B) const { +    if (&A == &B) +      return false; +    // The only case we can't directly compare them is when they are in the same +    // block, and both have localnum == middle.  In that case, we have to use +    // localComesBefore to see what the real ordering is, because they are in the +    // same basic block. + +    bool SameBlock = std::tie(A.DFSIn, A.DFSOut) == std::tie(B.DFSIn, B.DFSOut); + +    // We want to put the def that will get used for a given set of phi uses, +    // before those phi uses. +    // So we sort by edge, then by def. +    // Note that only phi node uses and defs can come last. +    if (SameBlock && A.LocalNum == LN_Last && B.LocalNum == LN_Last) +      return comparePHIRelated(A, B); + +    if (!SameBlock || A.LocalNum != LN_Middle || B.LocalNum != LN_Middle) +      return std::tie(A.DFSIn, A.DFSOut, A.LocalNum, A.Def, A.U) < +             std::tie(B.DFSIn, B.DFSOut, B.LocalNum, B.Def, B.U); +    return localComesBefore(A, B); +  } + +  // For a phi use, or a non-materialized def, return the edge it represents. +  const std::pair<BasicBlock *, BasicBlock *> +  getBlockEdge(const ValueDFS &VD) const { +    if (!VD.Def && VD.U) { +      auto *PHI = cast<PHINode>(VD.U->getUser()); +      return std::make_pair(PHI->getIncomingBlock(*VD.U), PHI->getParent()); +    } +    // This is really a non-materialized def. +    return ::getBlockEdge(VD.PInfo); +  } + +  // For two phi related values, return the ordering. 
+  bool comparePHIRelated(const ValueDFS &A, const ValueDFS &B) const { +    auto &ABlockEdge = getBlockEdge(A); +    auto &BBlockEdge = getBlockEdge(B); +    // Now sort by block edge and then defs before uses. +    return std::tie(ABlockEdge, A.Def, A.U) < std::tie(BBlockEdge, B.Def, B.U); +  } + +  // Get the definition of an instruction that occurs in the middle of a block. +  Value *getMiddleDef(const ValueDFS &VD) const { +    if (VD.Def) +      return VD.Def; +    // It's possible for the defs and uses to be null.  For branches, the local +    // numbering will say the placed predicateinfos should go first (i.e. +    // LN_First), so we won't be in this function. For assumes, we will end +    // up here, because we need to order the def we will place relative to the +    // assume.  So for the purpose of ordering, we pretend the def is the assume +    // because that is where we will insert the info. +    if (!VD.U) { +      assert(VD.PInfo && +             "No def, no use, and no predicateinfo should not occur"); +      assert(isa<PredicateAssume>(VD.PInfo) && +             "Middle of block should only occur for assumes"); +      return cast<PredicateAssume>(VD.PInfo)->AssumeInst; +    } +    return nullptr; +  } + +  // Return either the Def, if it's not null, or the user of the Use, if the def +  // is null. +  const Instruction *getDefOrUser(const Value *Def, const Use *U) const { +    if (Def) +      return cast<Instruction>(Def); +    return cast<Instruction>(U->getUser()); +  } + +  // This performs the necessary local basic block ordering checks to tell +  // whether A comes before B, where both are in the same basic block. +  bool localComesBefore(const ValueDFS &A, const ValueDFS &B) const { +    auto *ADef = getMiddleDef(A); +    auto *BDef = getMiddleDef(B); + +    // See if we have real values or uses. If we have real values, we are +    // guaranteed they are instructions or arguments. No matter what, we are +    // guaranteed they are in the same block if they are instructions. +    auto *ArgA = dyn_cast_or_null<Argument>(ADef); +    auto *ArgB = dyn_cast_or_null<Argument>(BDef); + +    if (ArgA && !ArgB) +      return true; +    if (ArgB && !ArgA) +      return false; +    if (ArgA && ArgB) +      return ArgA->getArgNo() < ArgB->getArgNo(); + +    auto *AInst = getDefOrUser(ADef, A.U); +    auto *BInst = getDefOrUser(BDef, B.U); + +    auto *BB = AInst->getParent(); +    auto LookupResult = OBBMap.find(BB); +    if (LookupResult != OBBMap.end()) +      return LookupResult->second->dominates(AInst, BInst); + +    auto Result = OBBMap.insert({BB, make_unique<OrderedBasicBlock>(BB)}); +    return Result.first->second->dominates(AInst, BInst); +  } +}; + +} // namespace PredicateInfoClasses + +bool PredicateInfo::stackIsInScope(const ValueDFSStack &Stack, +                                   const ValueDFS &VDUse) const { +  if (Stack.empty()) +    return false; +  // If it's a phi-only use, make sure it's for this phi node edge, and that the +  // use is in a phi node.  If it's anything else, and the top of the stack is +  // EdgeOnly, we need to pop the stack.  We deliberately sort phi uses next to +  // the defs they must go with so that we can know it's time to pop the stack +  // when we hit the end of the phi uses for a given def. 
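A minimal standalone sketch (not part of the patch) of the DFS-interval containment test that the non-edge case of this scope check reduces to; the helper name is hypothetical.

// Sketch only: dominator-tree DFS entry/exit numbers (DomTreeNode::getDFSNumIn,
// DomTreeNode::getDFSNumOut) nest, so a def's scope still covers a use exactly
// when the use's interval sits inside the def's interval.
static bool dfsScopeContains(int DefIn, int DefOut, int UseIn, int UseOut) {
  return UseIn >= DefIn && UseOut <= DefOut;
}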
+  if (Stack.back().EdgeOnly) { +    if (!VDUse.U) +      return false; +    auto *PHI = dyn_cast<PHINode>(VDUse.U->getUser()); +    if (!PHI) +      return false; +    // Check edge +    BasicBlock *EdgePred = PHI->getIncomingBlock(*VDUse.U); +    if (EdgePred != getBranchBlock(Stack.back().PInfo)) +      return false; + +    // Use dominates, which knows how to handle edge dominance. +    return DT.dominates(getBlockEdge(Stack.back().PInfo), *VDUse.U); +  } + +  return (VDUse.DFSIn >= Stack.back().DFSIn && +          VDUse.DFSOut <= Stack.back().DFSOut); +} + +void PredicateInfo::popStackUntilDFSScope(ValueDFSStack &Stack, +                                          const ValueDFS &VD) { +  while (!Stack.empty() && !stackIsInScope(Stack, VD)) +    Stack.pop_back(); +} + +// Convert the uses of Op into a vector of uses, associating global and local +// DFS info with each one. +void PredicateInfo::convertUsesToDFSOrdered( +    Value *Op, SmallVectorImpl<ValueDFS> &DFSOrderedSet) { +  for (auto &U : Op->uses()) { +    if (auto *I = dyn_cast<Instruction>(U.getUser())) { +      ValueDFS VD; +      // Put the phi node uses in the incoming block. +      BasicBlock *IBlock; +      if (auto *PN = dyn_cast<PHINode>(I)) { +        IBlock = PN->getIncomingBlock(U); +        // Make phi node users appear last in the incoming block +        // they are from. +        VD.LocalNum = LN_Last; +      } else { +        // If it's not a phi node use, it is somewhere in the middle of the +        // block. +        IBlock = I->getParent(); +        VD.LocalNum = LN_Middle; +      } +      DomTreeNode *DomNode = DT.getNode(IBlock); +      // It's possible our use is in an unreachable block. Skip it if so. +      if (!DomNode) +        continue; +      VD.DFSIn = DomNode->getDFSNumIn(); +      VD.DFSOut = DomNode->getDFSNumOut(); +      VD.U = &U; +      DFSOrderedSet.push_back(VD); +    } +  } +} + +// Collect relevant operations from Comparison that we may want to insert copies +// for. +void collectCmpOps(CmpInst *Comparison, SmallVectorImpl<Value *> &CmpOperands) { +  auto *Op0 = Comparison->getOperand(0); +  auto *Op1 = Comparison->getOperand(1); +  if (Op0 == Op1) +    return; +  CmpOperands.push_back(Comparison); +  // Only want real values, not constants.  Additionally, operands with one use +  // are only being used in the comparison, which means they will not be useful +  // for us to consider for predicateinfo. +  // +  if ((isa<Instruction>(Op0) || isa<Argument>(Op0)) && !Op0->hasOneUse()) +    CmpOperands.push_back(Op0); +  if ((isa<Instruction>(Op1) || isa<Argument>(Op1)) && !Op1->hasOneUse()) +    CmpOperands.push_back(Op1); +} + +// Add Op, PB to the list of value infos for Op, and mark Op to be renamed. +void PredicateInfo::addInfoFor(SmallPtrSetImpl<Value *> &OpsToRename, Value *Op, +                               PredicateBase *PB) { +  OpsToRename.insert(Op); +  auto &OperandInfo = getOrCreateValueInfo(Op); +  AllInfos.push_back(PB); +  OperandInfo.Infos.push_back(PB); +} + +// Process an assume instruction and place relevant operations we want to rename +// into OpsToRename. 
+void PredicateInfo::processAssume(IntrinsicInst *II, BasicBlock *AssumeBB, +                                  SmallPtrSetImpl<Value *> &OpsToRename) { +  // See if we have a comparison we support +  SmallVector<Value *, 8> CmpOperands; +  SmallVector<Value *, 2> ConditionsToProcess; +  CmpInst::Predicate Pred; +  Value *Operand = II->getOperand(0); +  if (m_c_And(m_Cmp(Pred, m_Value(), m_Value()), +              m_Cmp(Pred, m_Value(), m_Value())) +          .match(II->getOperand(0))) { +    ConditionsToProcess.push_back(cast<BinaryOperator>(Operand)->getOperand(0)); +    ConditionsToProcess.push_back(cast<BinaryOperator>(Operand)->getOperand(1)); +    ConditionsToProcess.push_back(Operand); +  } else if (isa<CmpInst>(Operand)) { + +    ConditionsToProcess.push_back(Operand); +  } +  for (auto Cond : ConditionsToProcess) { +    if (auto *Cmp = dyn_cast<CmpInst>(Cond)) { +      collectCmpOps(Cmp, CmpOperands); +      // Now add our copy infos for our operands +      for (auto *Op : CmpOperands) { +        auto *PA = new PredicateAssume(Op, II, Cmp); +        addInfoFor(OpsToRename, Op, PA); +      } +      CmpOperands.clear(); +    } else if (auto *BinOp = dyn_cast<BinaryOperator>(Cond)) { +      // Otherwise, it should be an AND. +      assert(BinOp->getOpcode() == Instruction::And && +             "Should have been an AND"); +      auto *PA = new PredicateAssume(BinOp, II, BinOp); +      addInfoFor(OpsToRename, BinOp, PA); +    } else { +      llvm_unreachable("Unknown type of condition"); +    } +  } +} + +// Process a block terminating branch, and place relevant operations to be +// renamed into OpsToRename. +void PredicateInfo::processBranch(BranchInst *BI, BasicBlock *BranchBB, +                                  SmallPtrSetImpl<Value *> &OpsToRename) { +  BasicBlock *FirstBB = BI->getSuccessor(0); +  BasicBlock *SecondBB = BI->getSuccessor(1); +  SmallVector<BasicBlock *, 2> SuccsToProcess; +  SuccsToProcess.push_back(FirstBB); +  SuccsToProcess.push_back(SecondBB); +  SmallVector<Value *, 2> ConditionsToProcess; + +  auto InsertHelper = [&](Value *Op, bool isAnd, bool isOr, Value *Cond) { +    for (auto *Succ : SuccsToProcess) { +      // Don't try to insert on a self-edge. This is mainly because we will +      // eliminate during renaming anyway. +      if (Succ == BranchBB) +        continue; +      bool TakenEdge = (Succ == FirstBB); +      // For and, only insert on the true edge +      // For or, only insert on the false edge +      if ((isAnd && !TakenEdge) || (isOr && TakenEdge)) +        continue; +      PredicateBase *PB = +          new PredicateBranch(Op, BranchBB, Succ, Cond, TakenEdge); +      addInfoFor(OpsToRename, Op, PB); +      if (!Succ->getSinglePredecessor()) +        EdgeUsesOnly.insert({BranchBB, Succ}); +    } +  }; + +  // Match combinations of conditions. 
+  CmpInst::Predicate Pred; +  bool isAnd = false; +  bool isOr = false; +  SmallVector<Value *, 8> CmpOperands; +  if (match(BI->getCondition(), m_And(m_Cmp(Pred, m_Value(), m_Value()), +                                      m_Cmp(Pred, m_Value(), m_Value()))) || +      match(BI->getCondition(), m_Or(m_Cmp(Pred, m_Value(), m_Value()), +                                     m_Cmp(Pred, m_Value(), m_Value())))) { +    auto *BinOp = cast<BinaryOperator>(BI->getCondition()); +    if (BinOp->getOpcode() == Instruction::And) +      isAnd = true; +    else if (BinOp->getOpcode() == Instruction::Or) +      isOr = true; +    ConditionsToProcess.push_back(BinOp->getOperand(0)); +    ConditionsToProcess.push_back(BinOp->getOperand(1)); +    ConditionsToProcess.push_back(BI->getCondition()); +  } else if (isa<CmpInst>(BI->getCondition())) { +    ConditionsToProcess.push_back(BI->getCondition()); +  } +  for (auto Cond : ConditionsToProcess) { +    if (auto *Cmp = dyn_cast<CmpInst>(Cond)) { +      collectCmpOps(Cmp, CmpOperands); +      // Now add our copy infos for our operands +      for (auto *Op : CmpOperands) +        InsertHelper(Op, isAnd, isOr, Cmp); +    } else if (auto *BinOp = dyn_cast<BinaryOperator>(Cond)) { +      // This must be an AND or an OR. +      assert((BinOp->getOpcode() == Instruction::And || +              BinOp->getOpcode() == Instruction::Or) && +             "Should have been an AND or an OR"); +      // The actual value of the binop is not subject to the same restrictions +      // as the comparison. It's either true or false on the true/false branch. +      InsertHelper(BinOp, false, false, BinOp); +    } else { +      llvm_unreachable("Unknown type of condition"); +    } +    CmpOperands.clear(); +  } +} +// Process a block terminating switch, and place relevant operations to be +// renamed into OpsToRename. +void PredicateInfo::processSwitch(SwitchInst *SI, BasicBlock *BranchBB, +                                  SmallPtrSetImpl<Value *> &OpsToRename) { +  Value *Op = SI->getCondition(); +  if ((!isa<Instruction>(Op) && !isa<Argument>(Op)) || Op->hasOneUse()) +    return; + +  // Remember how many outgoing edges there are to every successor. +  SmallDenseMap<BasicBlock *, unsigned, 16> SwitchEdges; +  for (unsigned i = 0, e = SI->getNumSuccessors(); i != e; ++i) { +    BasicBlock *TargetBlock = SI->getSuccessor(i); +    ++SwitchEdges[TargetBlock]; +  } + +  // Now propagate info for each case value +  for (auto C : SI->cases()) { +    BasicBlock *TargetBlock = C.getCaseSuccessor(); +    if (SwitchEdges.lookup(TargetBlock) == 1) { +      PredicateSwitch *PS = new PredicateSwitch( +          Op, SI->getParent(), TargetBlock, C.getCaseValue(), SI); +      addInfoFor(OpsToRename, Op, PS); +      if (!TargetBlock->getSinglePredecessor()) +        EdgeUsesOnly.insert({BranchBB, TargetBlock}); +    } +  } +} + +// Build predicate info for our function +void PredicateInfo::buildPredicateInfo() { +  DT.updateDFSNumbers(); +  // Collect operands to rename from all conditional branch terminators, as well +  // as assume statements. 
+  SmallPtrSet<Value *, 8> OpsToRename; +  for (auto DTN : depth_first(DT.getRootNode())) { +    BasicBlock *BranchBB = DTN->getBlock(); +    if (auto *BI = dyn_cast<BranchInst>(BranchBB->getTerminator())) { +      if (!BI->isConditional()) +        continue; +      processBranch(BI, BranchBB, OpsToRename); +    } else if (auto *SI = dyn_cast<SwitchInst>(BranchBB->getTerminator())) { +      processSwitch(SI, BranchBB, OpsToRename); +    } +  } +  for (auto &Assume : AC.assumptions()) { +    if (auto *II = dyn_cast_or_null<IntrinsicInst>(Assume)) +      processAssume(II, II->getParent(), OpsToRename); +  } +  // Now rename all our operations. +  renameUses(OpsToRename); +} + +// Given the renaming stack, make all the operands currently on the stack real +// by inserting them into the IR.  Return the last operation's value. +Value *PredicateInfo::materializeStack(unsigned int &Counter, +                                       ValueDFSStack &RenameStack, +                                       Value *OrigOp) { +  // Find the first thing we have to materialize +  auto RevIter = RenameStack.rbegin(); +  for (; RevIter != RenameStack.rend(); ++RevIter) +    if (RevIter->Def) +      break; + +  size_t Start = RevIter - RenameStack.rbegin(); +  // The maximum number of things we should be trying to materialize at once +  // right now is 4, depending on if we had an assume, a branch, and both used +  // and of conditions. +  for (auto RenameIter = RenameStack.end() - Start; +       RenameIter != RenameStack.end(); ++RenameIter) { +    auto *Op = +        RenameIter == RenameStack.begin() ? OrigOp : (RenameIter - 1)->Def; +    ValueDFS &Result = *RenameIter; +    auto *ValInfo = Result.PInfo; +    // For edge predicates, we can just place the operand in the block before +    // the terminator.  For assume, we have to place it right before the assume +    // to ensure we dominate all of our uses.  Always insert right before the +    // relevant instruction (terminator, assume), so that we insert in proper +    // order in the case of multiple predicateinfo in the same block. +    if (isa<PredicateWithEdge>(ValInfo)) { +      IRBuilder<> B(getBranchTerminator(ValInfo)); +      Function *IF = Intrinsic::getDeclaration( +          F.getParent(), Intrinsic::ssa_copy, Op->getType()); +      CallInst *PIC = +          B.CreateCall(IF, Op, Op->getName() + "." + Twine(Counter++)); +      PredicateMap.insert({PIC, ValInfo}); +      Result.Def = PIC; +    } else { +      auto *PAssume = dyn_cast<PredicateAssume>(ValInfo); +      assert(PAssume && +             "Should not have gotten here without it being an assume"); +      IRBuilder<> B(PAssume->AssumeInst); +      Function *IF = Intrinsic::getDeclaration( +          F.getParent(), Intrinsic::ssa_copy, Op->getType()); +      CallInst *PIC = B.CreateCall(IF, Op); +      PredicateMap.insert({PIC, ValInfo}); +      Result.Def = PIC; +    } +  } +  return RenameStack.back().Def; +} + +// Instead of the standard SSA renaming algorithm, which is O(Number of +// instructions), and walks the entire dominator tree, we walk only the defs + +// uses.  The standard SSA renaming algorithm does not really rely on the +// dominator tree except to order the stack push/pops of the renaming stacks, so +// that defs end up getting pushed before hitting the correct uses.  This does +// not require the dominator tree, only the *order* of the dominator tree. 
The +// complete and correct ordering of the defs and uses in the dominator tree is +// contained in the DFS numbering of the dominator tree. So we sort the defs and +// uses into the DFS ordering, and then just use the renaming stack as per +// normal, pushing when we hit a def (which is a predicateinfo instruction), +// popping when we are out of the dfs scope for that def, and replacing any uses +// with the top of stack if it exists.  In order to handle liveness without +// propagating liveness info, we don't actually insert the predicateinfo +// instruction def until we see a use that it would dominate.  Once we see such +// a use, we materialize the predicateinfo instruction in the right place and +// use it. +// +// TODO: Use this algorithm to perform fast single-variable renaming in +// promotememtoreg and memoryssa. +void PredicateInfo::renameUses(SmallPtrSetImpl<Value *> &OpsToRename) { +  ValueDFS_Compare Compare(OBBMap); +  // Compute liveness, and rename in O(uses) per Op. +  for (auto *Op : OpsToRename) { +    unsigned Counter = 0; +    SmallVector<ValueDFS, 16> OrderedUses; +    const auto &ValueInfo = getValueInfo(Op); +    // Insert the possible copies into the def/use list. +    // They will become real copies if we find a real use for them, and are +    // never created otherwise. +    for (auto &PossibleCopy : ValueInfo.Infos) { +      ValueDFS VD; +      // Determine where we are going to place the copy by the copy type. +      // The predicate info for branches always comes first; it will get +      // materialized in the split block at the top of the block. +      // The predicate info for assumes will be somewhere in the middle, +      // it will get materialized in front of the assume. +      if (const auto *PAssume = dyn_cast<PredicateAssume>(PossibleCopy)) { +        VD.LocalNum = LN_Middle; +        DomTreeNode *DomNode = DT.getNode(PAssume->AssumeInst->getParent()); +        if (!DomNode) +          continue; +        VD.DFSIn = DomNode->getDFSNumIn(); +        VD.DFSOut = DomNode->getDFSNumOut(); +        VD.PInfo = PossibleCopy; +        OrderedUses.push_back(VD); +      } else if (isa<PredicateWithEdge>(PossibleCopy)) { +        // If we can only do phi uses, we treat it like it's in the branch +        // block, and handle it specially. We know that it goes last, and only +        // dominates phi uses. +        auto BlockEdge = getBlockEdge(PossibleCopy); +        if (EdgeUsesOnly.count(BlockEdge)) { +          VD.LocalNum = LN_Last; +          auto *DomNode = DT.getNode(BlockEdge.first); +          if (DomNode) { +            VD.DFSIn = DomNode->getDFSNumIn(); +            VD.DFSOut = DomNode->getDFSNumOut(); +            VD.PInfo = PossibleCopy; +            VD.EdgeOnly = true; +            OrderedUses.push_back(VD); +          } +        } else { +          // Otherwise, we are in the split block (even though we perform +          // insertion in the branch block). +          // Insert a possible copy at the split block and before the branch. 
+          VD.LocalNum = LN_First; +          auto *DomNode = DT.getNode(BlockEdge.second); +          if (DomNode) { +            VD.DFSIn = DomNode->getDFSNumIn(); +            VD.DFSOut = DomNode->getDFSNumOut(); +            VD.PInfo = PossibleCopy; +            OrderedUses.push_back(VD); +          } +        } +      } +    } + +    convertUsesToDFSOrdered(Op, OrderedUses); +    std::sort(OrderedUses.begin(), OrderedUses.end(), Compare); +    SmallVector<ValueDFS, 8> RenameStack; +    // For each use, sorted into dfs order, push values and replaces uses with +    // top of stack, which will represent the reaching def. +    for (auto &VD : OrderedUses) { +      // We currently do not materialize copy over copy, but we should decide if +      // we want to. +      bool PossibleCopy = VD.PInfo != nullptr; +      if (RenameStack.empty()) { +        DEBUG(dbgs() << "Rename Stack is empty\n"); +      } else { +        DEBUG(dbgs() << "Rename Stack Top DFS numbers are (" +                     << RenameStack.back().DFSIn << "," +                     << RenameStack.back().DFSOut << ")\n"); +      } + +      DEBUG(dbgs() << "Current DFS numbers are (" << VD.DFSIn << "," +                   << VD.DFSOut << ")\n"); + +      bool ShouldPush = (VD.Def || PossibleCopy); +      bool OutOfScope = !stackIsInScope(RenameStack, VD); +      if (OutOfScope || ShouldPush) { +        // Sync to our current scope. +        popStackUntilDFSScope(RenameStack, VD); +        if (ShouldPush) { +          RenameStack.push_back(VD); +        } +      } +      // If we get to this point, and the stack is empty we must have a use +      // with no renaming needed, just skip it. +      if (RenameStack.empty()) +        continue; +      // Skip values, only want to rename the uses +      if (VD.Def || PossibleCopy) +        continue; +      if (!DebugCounter::shouldExecute(RenameCounter)) { +        DEBUG(dbgs() << "Skipping execution due to debug counter\n"); +        continue; +      } +      ValueDFS &Result = RenameStack.back(); + +      // If the possible copy dominates something, materialize our stack up to +      // this point. This ensures every comparison that affects our operation +      // ends up with predicateinfo. 
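A minimal sketch (hypothetical consumer code, not part of the patch) of how a client pass can read back the predicates attached to the materialized llvm.ssa_copy calls through PredicateInfo::getPredicateInfoFor; the function and variable names are illustrative only.

#include "llvm/IR/InstIterator.h"
#include "llvm/Transforms/Utils/PredicateInfo.h"
// Sketch only: each materialized copy is keyed in PredicateMap, so a consumer
// can recover which condition is known to hold for the copied operand.
static void walkPredicateCopies(llvm::Function &F, const llvm::PredicateInfo &PI) {
  using namespace llvm;
  for (Instruction &I : instructions(F))
    if (const PredicateBase *PB = PI.getPredicateInfoFor(&I))
      if (const auto *Branch = dyn_cast<PredicateBranch>(PB))
        // On the edge Branch->From -> Branch->To, Branch->Condition is known
        // to evaluate to Branch->TrueEdge.
        (void)Branch->Condition;
}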
+      if (!Result.Def) +        Result.Def = materializeStack(Counter, RenameStack, Op); + +      DEBUG(dbgs() << "Found replacement " << *Result.Def << " for " +                   << *VD.U->get() << " in " << *(VD.U->getUser()) << "\n"); +      assert(DT.dominates(cast<Instruction>(Result.Def), *VD.U) && +             "Predicateinfo def should have dominated this use"); +      VD.U->set(Result.Def); +    } +  } +} + +PredicateInfo::ValueInfo &PredicateInfo::getOrCreateValueInfo(Value *Operand) { +  auto OIN = ValueInfoNums.find(Operand); +  if (OIN == ValueInfoNums.end()) { +    // This will grow it +    ValueInfos.resize(ValueInfos.size() + 1); +    // This will use the new size and give us a 0 based number of the info +    auto InsertResult = ValueInfoNums.insert({Operand, ValueInfos.size() - 1}); +    assert(InsertResult.second && "Value info number already existed?"); +    return ValueInfos[InsertResult.first->second]; +  } +  return ValueInfos[OIN->second]; +} + +const PredicateInfo::ValueInfo & +PredicateInfo::getValueInfo(Value *Operand) const { +  auto OINI = ValueInfoNums.lookup(Operand); +  assert(OINI != 0 && "Operand was not really in the Value Info Numbers"); +  assert(OINI < ValueInfos.size() && +         "Value Info Number greater than size of Value Info Table"); +  return ValueInfos[OINI]; +} + +PredicateInfo::PredicateInfo(Function &F, DominatorTree &DT, +                             AssumptionCache &AC) +    : F(F), DT(DT), AC(AC) { +  // Push an empty operand info so that we can detect 0 as not finding one +  ValueInfos.resize(1); +  buildPredicateInfo(); +} + +PredicateInfo::~PredicateInfo() {} + +void PredicateInfo::verifyPredicateInfo() const {} + +char PredicateInfoPrinterLegacyPass::ID = 0; + +PredicateInfoPrinterLegacyPass::PredicateInfoPrinterLegacyPass() +    : FunctionPass(ID) { +  initializePredicateInfoPrinterLegacyPassPass( +      *PassRegistry::getPassRegistry()); +} + +void PredicateInfoPrinterLegacyPass::getAnalysisUsage(AnalysisUsage &AU) const { +  AU.setPreservesAll(); +  AU.addRequiredTransitive<DominatorTreeWrapperPass>(); +  AU.addRequired<AssumptionCacheTracker>(); +} + +bool PredicateInfoPrinterLegacyPass::runOnFunction(Function &F) { +  auto &DT = getAnalysis<DominatorTreeWrapperPass>().getDomTree(); +  auto &AC = getAnalysis<AssumptionCacheTracker>().getAssumptionCache(F); +  auto PredInfo = make_unique<PredicateInfo>(F, DT, AC); +  PredInfo->print(dbgs()); +  if (VerifyPredicateInfo) +    PredInfo->verifyPredicateInfo(); +  return false; +} + +PreservedAnalyses PredicateInfoPrinterPass::run(Function &F, +                                                FunctionAnalysisManager &AM) { +  auto &DT = AM.getResult<DominatorTreeAnalysis>(F); +  auto &AC = AM.getResult<AssumptionAnalysis>(F); +  OS << "PredicateInfo for function: " << F.getName() << "\n"; +  make_unique<PredicateInfo>(F, DT, AC)->print(OS); + +  return PreservedAnalyses::all(); +} + +/// \brief An assembly annotator class to print PredicateInfo information in +/// comments. 
+class PredicateInfoAnnotatedWriter : public AssemblyAnnotationWriter { +  friend class PredicateInfo; +  const PredicateInfo *PredInfo; + +public: +  PredicateInfoAnnotatedWriter(const PredicateInfo *M) : PredInfo(M) {} + +  virtual void emitBasicBlockStartAnnot(const BasicBlock *BB, +                                        formatted_raw_ostream &OS) {} + +  virtual void emitInstructionAnnot(const Instruction *I, +                                    formatted_raw_ostream &OS) { +    if (const auto *PI = PredInfo->getPredicateInfoFor(I)) { +      OS << "; Has predicate info\n"; +      if (const auto *PB = dyn_cast<PredicateBranch>(PI)) { +        OS << "; branch predicate info { TrueEdge: " << PB->TrueEdge +           << " Comparison:" << *PB->Condition << " Edge: ["; +        PB->From->printAsOperand(OS); +        OS << ","; +        PB->To->printAsOperand(OS); +        OS << "] }\n"; +      } else if (const auto *PS = dyn_cast<PredicateSwitch>(PI)) { +        OS << "; switch predicate info { CaseValue: " << *PS->CaseValue +           << " Switch:" << *PS->Switch << " Edge: ["; +        PS->From->printAsOperand(OS); +        OS << ","; +        PS->To->printAsOperand(OS); +        OS << "] }\n"; +      } else if (const auto *PA = dyn_cast<PredicateAssume>(PI)) { +        OS << "; assume predicate info {" +           << " Comparison:" << *PA->Condition << " }\n"; +      } +    } +  } +}; + +void PredicateInfo::print(raw_ostream &OS) const { +  PredicateInfoAnnotatedWriter Writer(this); +  F.print(OS, &Writer); +} + +void PredicateInfo::dump() const { +  PredicateInfoAnnotatedWriter Writer(this); +  F.print(dbgs(), &Writer); +} + +PreservedAnalyses PredicateInfoVerifierPass::run(Function &F, +                                                 FunctionAnalysisManager &AM) { +  auto &DT = AM.getResult<DominatorTreeAnalysis>(F); +  auto &AC = AM.getResult<AssumptionAnalysis>(F); +  make_unique<PredicateInfo>(F, DT, AC)->verifyPredicateInfo(); + +  return PreservedAnalyses::all(); +} +} diff --git a/lib/Transforms/Utils/PromoteMemoryToRegister.cpp b/lib/Transforms/Utils/PromoteMemoryToRegister.cpp index 35faa6f65efd..a33b85c4ee69 100644 --- a/lib/Transforms/Utils/PromoteMemoryToRegister.cpp +++ b/lib/Transforms/Utils/PromoteMemoryToRegister.cpp @@ -15,7 +15,6 @@  //  //===----------------------------------------------------------------------===// -#include "llvm/Transforms/Utils/PromoteMemToReg.h"  #include "llvm/ADT/ArrayRef.h"  #include "llvm/ADT/DenseMap.h"  #include "llvm/ADT/STLExtras.h" @@ -23,6 +22,7 @@  #include "llvm/ADT/SmallVector.h"  #include "llvm/ADT/Statistic.h"  #include "llvm/Analysis/AliasSetTracker.h" +#include "llvm/Analysis/AssumptionCache.h"  #include "llvm/Analysis/InstructionSimplify.h"  #include "llvm/Analysis/IteratedDominanceFrontier.h"  #include "llvm/Analysis/ValueTracking.h" @@ -38,6 +38,7 @@  #include "llvm/IR/Metadata.h"  #include "llvm/IR/Module.h"  #include "llvm/Transforms/Utils/Local.h" +#include "llvm/Transforms/Utils/PromoteMemToReg.h"  #include <algorithm>  using namespace llvm; @@ -225,9 +226,6 @@ struct PromoteMem2Reg {    DominatorTree &DT;    DIBuilder DIB; -  /// An AliasSetTracker object to update.  If null, don't update it. -  AliasSetTracker *AST; -    /// A cache of @llvm.assume intrinsics used by SimplifyInstruction.    
AssumptionCache *AC; @@ -269,10 +267,10 @@ struct PromoteMem2Reg {  public:    PromoteMem2Reg(ArrayRef<AllocaInst *> Allocas, DominatorTree &DT, -                 AliasSetTracker *AST, AssumptionCache *AC) +                 AssumptionCache *AC)        : Allocas(Allocas.begin(), Allocas.end()), DT(DT),          DIB(*DT.getRoot()->getParent()->getParent(), /*AllowUnresolved*/ false), -        AST(AST), AC(AC) {} +        AC(AC) {}    void run(); @@ -301,6 +299,18 @@ private:  } // end of anonymous namespace +/// Given a LoadInst LI this adds assume(LI != null) after it. +static void addAssumeNonNull(AssumptionCache *AC, LoadInst *LI) { +  Function *AssumeIntrinsic = +      Intrinsic::getDeclaration(LI->getModule(), Intrinsic::assume); +  ICmpInst *LoadNotNull = new ICmpInst(ICmpInst::ICMP_NE, LI, +                                       Constant::getNullValue(LI->getType())); +  LoadNotNull->insertAfter(LI); +  CallInst *CI = CallInst::Create(AssumeIntrinsic, {LoadNotNull}); +  CI->insertAfter(LoadNotNull); +  AC->registerAssumption(CI); +} +  static void removeLifetimeIntrinsicUsers(AllocaInst *AI) {    // Knowing that this alloca is promotable, we know that it's safe to kill all    // instructions except for load and store. @@ -334,9 +344,8 @@ static void removeLifetimeIntrinsicUsers(AllocaInst *AI) {  /// and thus must be phi-ed with undef. We fall back to the standard alloca  /// promotion algorithm in that case.  static bool rewriteSingleStoreAlloca(AllocaInst *AI, AllocaInfo &Info, -                                     LargeBlockInfo &LBI, -                                     DominatorTree &DT, -                                     AliasSetTracker *AST) { +                                     LargeBlockInfo &LBI, DominatorTree &DT, +                                     AssumptionCache *AC) {    StoreInst *OnlyStore = Info.OnlyStore;    bool StoringGlobalVal = !isa<Instruction>(OnlyStore->getOperand(0));    BasicBlock *StoreBB = OnlyStore->getParent(); @@ -387,9 +396,15 @@ static bool rewriteSingleStoreAlloca(AllocaInst *AI, AllocaInfo &Info,      // code.      if (ReplVal == LI)        ReplVal = UndefValue::get(LI->getType()); + +    // If the load was marked as nonnull we don't want to lose +    // that information when we erase this Load. So we preserve +    // it with an assume. +    if (AC && LI->getMetadata(LLVMContext::MD_nonnull) && +        !llvm::isKnownNonNullAt(ReplVal, LI, &DT)) +      addAssumeNonNull(AC, LI); +      LI->replaceAllUsesWith(ReplVal); -    if (AST && LI->getType()->isPointerTy()) -      AST->deleteValue(LI);      LI->eraseFromParent();      LBI.deleteValue(LI);    } @@ -410,8 +425,6 @@ static bool rewriteSingleStoreAlloca(AllocaInst *AI, AllocaInfo &Info,    Info.OnlyStore->eraseFromParent();    LBI.deleteValue(Info.OnlyStore); -  if (AST) -    AST->deleteValue(AI);    AI->eraseFromParent();    LBI.deleteValue(AI);    return true; @@ -435,7 +448,8 @@ static bool rewriteSingleStoreAlloca(AllocaInst *AI, AllocaInfo &Info,  ///  }  static bool promoteSingleBlockAlloca(AllocaInst *AI, const AllocaInfo &Info,                                       LargeBlockInfo &LBI, -                                     AliasSetTracker *AST) { +                                     DominatorTree &DT, +                                     AssumptionCache *AC) {    // The trickiest case to handle is when we have large blocks. Because of this,    // this code is optimized assuming that large blocks happen.  This does not    // significantly pessimize the small block case.  
This uses LargeBlockInfo to @@ -476,13 +490,18 @@ static bool promoteSingleBlockAlloca(AllocaInst *AI, const AllocaInfo &Info,          // There is no store before this load, bail out (load may be affected          // by the following stores - see main comment).          return false; -    } -    else +    } else {        // Otherwise, there was a store before this load, the load takes its value. -      LI->replaceAllUsesWith(std::prev(I)->second->getOperand(0)); +      // Note, if the load was marked as nonnull we don't want to lose that +      // information when we erase it. So we preserve it with an assume. +      Value *ReplVal = std::prev(I)->second->getOperand(0); +      if (AC && LI->getMetadata(LLVMContext::MD_nonnull) && +          !llvm::isKnownNonNullAt(ReplVal, LI, &DT)) +        addAssumeNonNull(AC, LI); + +      LI->replaceAllUsesWith(ReplVal); +    } -    if (AST && LI->getType()->isPointerTy()) -      AST->deleteValue(LI);      LI->eraseFromParent();      LBI.deleteValue(LI);    } @@ -499,8 +518,6 @@ static bool promoteSingleBlockAlloca(AllocaInst *AI, const AllocaInfo &Info,      LBI.deleteValue(SI);    } -  if (AST) -    AST->deleteValue(AI);    AI->eraseFromParent();    LBI.deleteValue(AI); @@ -517,8 +534,6 @@ static bool promoteSingleBlockAlloca(AllocaInst *AI, const AllocaInfo &Info,  void PromoteMem2Reg::run() {    Function &F = *DT.getRoot()->getParent(); -  if (AST) -    PointerAllocaValues.resize(Allocas.size());    AllocaDbgDeclares.resize(Allocas.size());    AllocaInfo Info; @@ -536,8 +551,6 @@ void PromoteMem2Reg::run() {      if (AI->use_empty()) {        // If there are no uses of the alloca, just delete it now. -      if (AST) -        AST->deleteValue(AI);        AI->eraseFromParent();        // Remove the alloca from the Allocas list, since it has been processed @@ -553,7 +566,7 @@ void PromoteMem2Reg::run() {      // If there is only a single store to this value, replace any loads of      // it that are directly dominated by the definition with the value stored.      if (Info.DefiningBlocks.size() == 1) { -      if (rewriteSingleStoreAlloca(AI, Info, LBI, DT, AST)) { +      if (rewriteSingleStoreAlloca(AI, Info, LBI, DT, AC)) {          // The alloca has been processed, move on.          RemoveFromAllocasList(AllocaNum);          ++NumSingleStore; @@ -564,7 +577,7 @@ void PromoteMem2Reg::run() {      // If the alloca is only read and written in one basic block, just perform a      // linear sweep over the block to eliminate it.      if (Info.OnlyUsedInOneBlock && -        promoteSingleBlockAlloca(AI, Info, LBI, AST)) { +        promoteSingleBlockAlloca(AI, Info, LBI, DT, AC)) {        // The alloca has been processed, move on.        RemoveFromAllocasList(AllocaNum);        continue; @@ -578,11 +591,6 @@ void PromoteMem2Reg::run() {          BBNumbers[&BB] = ID++;      } -    // If we have an AST to keep updated, remember some pointer value that is -    // stored into the alloca. -    if (AST) -      PointerAllocaValues[AllocaNum] = Info.AllocaPointerVal; -      // Remember the dbg.declare intrinsic describing this alloca, if any.      if (Info.DbgDeclare)        AllocaDbgDeclares[AllocaNum] = Info.DbgDeclare; @@ -662,8 +670,6 @@ void PromoteMem2Reg::run() {      // tree. Just delete the users now.      
if (!A->use_empty())        A->replaceAllUsesWith(UndefValue::get(A->getType())); -    if (AST) -      AST->deleteValue(A);      A->eraseFromParent();    } @@ -694,8 +700,6 @@ void PromoteMem2Reg::run() {        // If this PHI node merges one value and/or undefs, get the value.        if (Value *V = SimplifyInstruction(PN, DL, nullptr, &DT, AC)) { -        if (AST && PN->getType()->isPointerTy()) -          AST->deleteValue(PN);          PN->replaceAllUsesWith(V);          PN->eraseFromParent();          NewPhiNodes.erase(I++); @@ -863,10 +867,6 @@ bool PromoteMem2Reg::QueuePhiNode(BasicBlock *BB, unsigned AllocaNo,                         &BB->front());    ++NumPHIInsert;    PhiToAllocaMap[PN] = AllocaNo; - -  if (AST && PN->getType()->isPointerTy()) -    AST->copyValue(PointerAllocaValues[AllocaNo], PN); -    return true;  } @@ -940,10 +940,15 @@ NextIteration:        Value *V = IncomingVals[AI->second]; +      // If the load was marked as nonnull we don't want to lose +      // that information when we erase this Load. So we preserve +      // it with an assume. +      if (AC && LI->getMetadata(LLVMContext::MD_nonnull) && +          !llvm::isKnownNonNullAt(V, LI, &DT)) +        addAssumeNonNull(AC, LI); +        // Anything using the load now uses the current value.        LI->replaceAllUsesWith(V); -      if (AST && LI->getType()->isPointerTy()) -        AST->deleteValue(LI);        BB->getInstList().erase(LI);      } else if (StoreInst *SI = dyn_cast<StoreInst>(I)) {        // Delete this instruction and mark the name as the current holder of the @@ -987,10 +992,10 @@ NextIteration:  }  void llvm::PromoteMemToReg(ArrayRef<AllocaInst *> Allocas, DominatorTree &DT, -                           AliasSetTracker *AST, AssumptionCache *AC) { +                           AssumptionCache *AC) {    // If there is nothing to do, bail out...    
if (Allocas.empty())      return; -  PromoteMem2Reg(Allocas, DT, AST, AC).run(); +  PromoteMem2Reg(Allocas, DT, AC).run();  } diff --git a/lib/Transforms/Utils/SSAUpdater.cpp b/lib/Transforms/Utils/SSAUpdater.cpp index 8e93ee757a15..8b6a2c3766d2 100644 --- a/lib/Transforms/Utils/SSAUpdater.cpp +++ b/lib/Transforms/Utils/SSAUpdater.cpp @@ -11,20 +11,29 @@  //  //===----------------------------------------------------------------------===// -#include "llvm/Transforms/Utils/SSAUpdater.h"  #include "llvm/ADT/DenseMap.h" +#include "llvm/ADT/SmallVector.h" +#include "llvm/ADT/STLExtras.h" +#include "llvm/ADT/StringRef.h"  #include "llvm/ADT/TinyPtrVector.h"  #include "llvm/Analysis/InstructionSimplify.h" +#include "llvm/IR/BasicBlock.h"  #include "llvm/IR/CFG.h"  #include "llvm/IR/Constants.h" +#include "llvm/IR/DebugLoc.h" +#include "llvm/IR/Instruction.h"  #include "llvm/IR/Instructions.h" -#include "llvm/IR/IntrinsicInst.h"  #include "llvm/IR/Module.h" +#include "llvm/IR/Use.h" +#include "llvm/IR/Value.h" +#include "llvm/IR/ValueHandle.h" +#include "llvm/Support/Casting.h"  #include "llvm/Support/Debug.h"  #include "llvm/Support/raw_ostream.h" -#include "llvm/Transforms/Utils/BasicBlockUtils.h" -#include "llvm/Transforms/Utils/Local.h" +#include "llvm/Transforms/Utils/SSAUpdater.h"  #include "llvm/Transforms/Utils/SSAUpdaterImpl.h" +#include <cassert> +#include <utility>  using namespace llvm; @@ -36,7 +45,7 @@ static AvailableValsTy &getAvailableVals(void *AV) {  }  SSAUpdater::SSAUpdater(SmallVectorImpl<PHINode*> *NewPHI) -  : AV(nullptr), ProtoType(nullptr), ProtoName(), InsertedPHIs(NewPHI) {} +  : InsertedPHIs(NewPHI) {}  SSAUpdater::~SSAUpdater() {    delete static_cast<AvailableValsTy*>(AV); @@ -205,6 +214,7 @@ void SSAUpdater::RewriteUseAfterInsertions(Use &U) {  }  namespace llvm { +  template<>  class SSAUpdaterTraits<SSAUpdater> {  public: @@ -230,6 +240,7 @@ public:      PHI_iterator &operator++() { ++idx; return *this; }       bool operator==(const PHI_iterator& x) const { return idx == x.idx; }      bool operator!=(const PHI_iterator& x) const { return !operator==(x); } +      Value *getIncomingValue() { return PHI->getIncomingValue(idx); }      BasicBlock *getIncomingBlock() { return PHI->getIncomingBlock(idx); }    }; @@ -303,7 +314,7 @@ public:    }  }; -} // End llvm namespace +} // end namespace llvm  /// Check to see if AvailableVals has an entry for the specified BB and if so,  /// return it.  If not, construct SSA form by first calculating the required @@ -337,14 +348,12 @@ LoadAndStorePromoter(ArrayRef<const Instruction*> Insts,    SSA.Initialize(SomeVal->getType(), BaseName);  } -  void LoadAndStorePromoter::  run(const SmallVectorImpl<Instruction*> &Insts) const { -      // First step: bucket up uses of the alloca by the block they occur in.    // This is important because we have to handle multiple defs/uses in a block    // ourselves: SSAUpdater is purely for cross-block references. 
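A minimal sketch (not part of the patch; the helper and its parameters are hypothetical) of the cross-block SSAUpdater protocol that LoadAndStorePromoter builds on: register one reaching value per defining block, then rewrite each cross-block use.

#include "llvm/Transforms/Utils/SSAUpdater.h"
// Sketch only: SSAUpdater inserts whatever PHI nodes are needed between the
// registered definition and the use being rewritten.
static void rewriteOneUse(llvm::BasicBlock *DefBB, llvm::Value *StoredVal,
                          llvm::Use &U) {
  llvm::SSAUpdater SSA;
  SSA.Initialize(StoredVal->getType(), "promoted");
  SSA.AddAvailableValue(DefBB, StoredVal);
  SSA.RewriteUse(U); // replaces U with the reaching definition
}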
-  DenseMap<BasicBlock*, TinyPtrVector<Instruction*> > UsesByBlock; +  DenseMap<BasicBlock*, TinyPtrVector<Instruction*>> UsesByBlock;    for (Instruction *User : Insts)      UsesByBlock[User->getParent()].push_back(User); diff --git a/lib/Transforms/Utils/SimplifyCFG.cpp b/lib/Transforms/Utils/SimplifyCFG.cpp index 7b0bddbbb831..127a44df5344 100644 --- a/lib/Transforms/Utils/SimplifyCFG.cpp +++ b/lib/Transforms/Utils/SimplifyCFG.cpp @@ -22,6 +22,7 @@  #include "llvm/ADT/SmallVector.h"  #include "llvm/ADT/Statistic.h"  #include "llvm/ADT/STLExtras.h" +#include "llvm/Analysis/AssumptionCache.h"  #include "llvm/Analysis/ConstantFolding.h"  #include "llvm/Analysis/EHPersonalities.h"  #include "llvm/Analysis/InstructionSimplify.h" @@ -169,6 +170,8 @@ class SimplifyCFGOpt {    unsigned BonusInstThreshold;    AssumptionCache *AC;    SmallPtrSetImpl<BasicBlock *> *LoopHeaders; +  // See comments in SimplifyCFGOpt::SimplifySwitch. +  bool LateSimplifyCFG;    Value *isValueEqualityComparison(TerminatorInst *TI);    BasicBlock *GetValueEqualityComparisonCases(        TerminatorInst *TI, std::vector<ValueEqualityComparisonCase> &Cases); @@ -192,9 +195,10 @@ class SimplifyCFGOpt {  public:    SimplifyCFGOpt(const TargetTransformInfo &TTI, const DataLayout &DL,                   unsigned BonusInstThreshold, AssumptionCache *AC, -                 SmallPtrSetImpl<BasicBlock *> *LoopHeaders) +                 SmallPtrSetImpl<BasicBlock *> *LoopHeaders, +                 bool LateSimplifyCFG)        : TTI(TTI), DL(DL), BonusInstThreshold(BonusInstThreshold), AC(AC), -        LoopHeaders(LoopHeaders) {} +        LoopHeaders(LoopHeaders), LateSimplifyCFG(LateSimplifyCFG) {}    bool run(BasicBlock *BB);  }; @@ -710,10 +714,9 @@ BasicBlock *SimplifyCFGOpt::GetValueEqualityComparisonCases(      TerminatorInst *TI, std::vector<ValueEqualityComparisonCase> &Cases) {    if (SwitchInst *SI = dyn_cast<SwitchInst>(TI)) {      Cases.reserve(SI->getNumCases()); -    for (SwitchInst::CaseIt i = SI->case_begin(), e = SI->case_end(); i != e; -         ++i) -      Cases.push_back( -          ValueEqualityComparisonCase(i.getCaseValue(), i.getCaseSuccessor())); +    for (auto Case : SI->cases()) +      Cases.push_back(ValueEqualityComparisonCase(Case.getCaseValue(), +                                                  Case.getCaseSuccessor()));      return SI->getDefaultDest();    } @@ -846,12 +849,12 @@ bool SimplifyCFGOpt::SimplifyEqualityComparisonWithOnlyPredecessor(        }      for (SwitchInst::CaseIt i = SI->case_end(), e = SI->case_begin(); i != e;) {        --i; -      if (DeadCases.count(i.getCaseValue())) { +      if (DeadCases.count(i->getCaseValue())) {          if (HasWeight) { -          std::swap(Weights[i.getCaseIndex() + 1], Weights.back()); +          std::swap(Weights[i->getCaseIndex() + 1], Weights.back());            Weights.pop_back();          } -        i.getCaseSuccessor()->removePredecessor(TI->getParent()); +        i->getCaseSuccessor()->removePredecessor(TI->getParent());          SI->removeCase(i);        }      } @@ -996,8 +999,7 @@ bool SimplifyCFGOpt::FoldValueComparisonIntoPredecessors(TerminatorInst *TI,        SmallSetVector<BasicBlock*, 4> FailBlocks;        if (!SafeToMergeTerminators(TI, PTI, &FailBlocks)) {          for (auto *Succ : FailBlocks) { -          std::vector<BasicBlock*> Blocks = { TI->getParent() }; -          if (!SplitBlockPredecessors(Succ, Blocks, ".fold.split")) +          if (!SplitBlockPredecessors(Succ, TI->getParent(), ".fold.split"))              return false;   
       }        } @@ -1280,7 +1282,7 @@ static bool HoistThenElseCodeToIf(BranchInst *BI,      if (!isa<CallInst>(I1))        I1->setDebugLoc(            DILocation::getMergedLocation(I1->getDebugLoc(), I2->getDebugLoc())); -  +      I2->eraseFromParent();      Changed = true; @@ -1472,29 +1474,28 @@ static bool canSinkInstructions(        return false;    } +  // Because SROA can't handle speculating stores of selects, try not +  // to sink loads or stores of allocas when we'd have to create a PHI for +  // the address operand. Also, because it is likely that loads or stores +  // of allocas will disappear when Mem2Reg/SROA is run, don't sink them. +  // This can cause code churn which can have unintended consequences down +  // the line - see https://llvm.org/bugs/show_bug.cgi?id=30244. +  // FIXME: This is a workaround for a deficiency in SROA - see +  // https://llvm.org/bugs/show_bug.cgi?id=30188 +  if (isa<StoreInst>(I0) && any_of(Insts, [](const Instruction *I) { +        return isa<AllocaInst>(I->getOperand(1)); +      })) +    return false; +  if (isa<LoadInst>(I0) && any_of(Insts, [](const Instruction *I) { +        return isa<AllocaInst>(I->getOperand(0)); +      })) +    return false; +    for (unsigned OI = 0, OE = I0->getNumOperands(); OI != OE; ++OI) {      if (I0->getOperand(OI)->getType()->isTokenTy())        // Don't touch any operand of token type.        return false; -    // Because SROA can't handle speculating stores of selects, try not -    // to sink loads or stores of allocas when we'd have to create a PHI for -    // the address operand. Also, because it is likely that loads or stores -    // of allocas will disappear when Mem2Reg/SROA is run, don't sink them. -    // This can cause code churn which can have unintended consequences down -    // the line - see https://llvm.org/bugs/show_bug.cgi?id=30244. -    // FIXME: This is a workaround for a deficiency in SROA - see -    // https://llvm.org/bugs/show_bug.cgi?id=30188 -    if (OI == 1 && isa<StoreInst>(I0) && -        any_of(Insts, [](const Instruction *I) { -          return isa<AllocaInst>(I->getOperand(1)); -        })) -      return false; -    if (OI == 0 && isa<LoadInst>(I0) && any_of(Insts, [](const Instruction *I) { -          return isa<AllocaInst>(I->getOperand(0)); -        })) -      return false; -      auto SameAsI0 = [&I0, OI](const Instruction *I) {        assert(I->getNumOperands() == I0->getNumOperands());        return I->getOperand(OI) == I0->getOperand(OI); @@ -1546,7 +1547,7 @@ static bool sinkLastInstruction(ArrayRef<BasicBlock*> Blocks) {          }))        return false;    } -   +    // We don't need to do any more checking here; canSinkLastInstruction should    // have done it all for us.    SmallVector<Value*, 4> NewOperands; @@ -1653,7 +1654,7 @@ namespace {      bool isValid() const {        return !Fail;      } -     +      void operator -- () {        if (Fail)          return; @@ -1699,7 +1700,7 @@ static bool SinkThenElseCodeToEnd(BranchInst *BI1) {    //      /    \    //    [f(1)] [if]    //      |     | \ -  //      |     |  \ +  //      |     |  |    //      |  [f(2)]|    //       \    | /    //        [ end ] @@ -1737,7 +1738,7 @@ static bool SinkThenElseCodeToEnd(BranchInst *BI1) {    }    if (UnconditionalPreds.size() < 2)      return false; -   +    bool Changed = false;    // We take a two-step approach to tail sinking. First we scan from the end of    // each block upwards in lockstep. 
If the n'th instruction from the end of each @@ -1767,7 +1768,7 @@ static bool SinkThenElseCodeToEnd(BranchInst *BI1) {      unsigned NumPHIInsts = NumPHIdValues / UnconditionalPreds.size();      if ((NumPHIdValues % UnconditionalPreds.size()) != 0)          NumPHIInsts++; -     +      return NumPHIInsts <= 1;    }; @@ -1790,7 +1791,7 @@ static bool SinkThenElseCodeToEnd(BranchInst *BI1) {      }      if (!Profitable)        return false; -     +      DEBUG(dbgs() << "SINK: Splitting edge\n");      // We have a conditional edge and we're going to sink some instructions.      // Insert a new block postdominating all blocks we're going to sink from. @@ -1800,7 +1801,7 @@ static bool SinkThenElseCodeToEnd(BranchInst *BI1) {        return false;      Changed = true;    } -   +    // Now that we've analyzed all potential sinking candidates, perform the    // actual sink. We iteratively sink the last non-terminator of the source    // blocks into their common successor unless doing so would require too @@ -1826,7 +1827,7 @@ static bool SinkThenElseCodeToEnd(BranchInst *BI1) {        DEBUG(dbgs() << "SINK: stopping here, too many PHIs would be created!\n");        break;      } -     +      if (!sinkLastInstruction(UnconditionalPreds))        return Changed;      NumSinkCommons++; @@ -2078,6 +2079,9 @@ static bool SpeculativelyExecuteBB(BranchInst *BI, BasicBlock *ThenBB,      Value *S = Builder.CreateSelect(          BrCond, TrueV, FalseV, TrueV->getName() + "." + FalseV->getName(), BI);      SpeculatedStore->setOperand(0, S); +    SpeculatedStore->setDebugLoc( +        DILocation::getMergedLocation( +          BI->getDebugLoc(), SpeculatedStore->getDebugLoc()));    }    // Metadata can be dependent on the condition we are hoisting above. @@ -2147,7 +2151,8 @@ static bool BlockIsSimpleEnoughToThreadThrough(BasicBlock *BB) {  /// If we have a conditional branch on a PHI node value that is defined in the  /// same block as the branch and if any PHI entries are constants, thread edges  /// corresponding to that entry to be branches to their ultimate destination. -static bool FoldCondBranchOnPHI(BranchInst *BI, const DataLayout &DL) { +static bool FoldCondBranchOnPHI(BranchInst *BI, const DataLayout &DL, +                                AssumptionCache *AC) {    BasicBlock *BB = BI->getParent();    PHINode *PN = dyn_cast<PHINode>(BI->getCondition());    // NOTE: we currently cannot transform this case if the PHI node is used @@ -2239,6 +2244,11 @@ static bool FoldCondBranchOnPHI(BranchInst *BI, const DataLayout &DL) {        // Insert the new instruction into its new home.        if (N)          EdgeBB->getInstList().insert(InsertPt, N); + +      // Register the new instruction with the assumption cache if necessary. +      if (auto *II = dyn_cast_or_null<IntrinsicInst>(N)) +        if (II->getIntrinsicID() == Intrinsic::assume) +          AC->registerAssumption(II);      }      // Loop over all of the edges from PredBB to BB, changing them to branch @@ -2251,7 +2261,7 @@ static bool FoldCondBranchOnPHI(BranchInst *BI, const DataLayout &DL) {        }      // Recurse, simplifying any other constants. -    return FoldCondBranchOnPHI(BI, DL) | true; +    return FoldCondBranchOnPHI(BI, DL, AC) | true;    }    return false; @@ -3433,8 +3443,8 @@ static bool SimplifySwitchOnSelect(SwitchInst *SI, SelectInst *Select) {    // Find the relevant condition and destinations.    
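// A short sketch (hypothetical helper) of the SwitchInst case-handle API this
// patch migrates to: cases() yields handles by value, and explicit case
// iterators (such as the one returned by findCaseValue() below) are
// dereferenced with '->' instead of being used directly with '.'.
#include "llvm/IR/Instructions.h"

static llvm::BasicBlock *successorForCase(llvm::SwitchInst *SI,
                                          llvm::ConstantInt *C) {
  for (auto Case : SI->cases())
    if (Case.getCaseValue() == C)
      return Case.getCaseSuccessor();
  return SI->getDefaultDest();
}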
Value *Condition = Select->getCondition(); -  BasicBlock *TrueBB = SI->findCaseValue(TrueVal).getCaseSuccessor(); -  BasicBlock *FalseBB = SI->findCaseValue(FalseVal).getCaseSuccessor(); +  BasicBlock *TrueBB = SI->findCaseValue(TrueVal)->getCaseSuccessor(); +  BasicBlock *FalseBB = SI->findCaseValue(FalseVal)->getCaseSuccessor();    // Get weight for TrueBB and FalseBB.    uint32_t TrueWeight = 0, FalseWeight = 0; @@ -3444,9 +3454,9 @@ static bool SimplifySwitchOnSelect(SwitchInst *SI, SelectInst *Select) {      GetBranchWeights(SI, Weights);      if (Weights.size() == 1 + SI->getNumCases()) {        TrueWeight = -          (uint32_t)Weights[SI->findCaseValue(TrueVal).getSuccessorIndex()]; +          (uint32_t)Weights[SI->findCaseValue(TrueVal)->getSuccessorIndex()];        FalseWeight = -          (uint32_t)Weights[SI->findCaseValue(FalseVal).getSuccessorIndex()]; +          (uint32_t)Weights[SI->findCaseValue(FalseVal)->getSuccessorIndex()];      }    } @@ -4148,15 +4158,16 @@ bool SimplifyCFGOpt::SimplifyUnreachable(UnreachableInst *UI) {          }        }      } else if (auto *SI = dyn_cast<SwitchInst>(TI)) { -      for (SwitchInst::CaseIt i = SI->case_begin(), e = SI->case_end(); i != e; -           ++i) -        if (i.getCaseSuccessor() == BB) { -          BB->removePredecessor(SI->getParent()); -          SI->removeCase(i); -          --i; -          --e; -          Changed = true; +      for (auto i = SI->case_begin(), e = SI->case_end(); i != e;) { +        if (i->getCaseSuccessor() != BB) { +          ++i; +          continue;          } +        BB->removePredecessor(SI->getParent()); +        i = SI->removeCase(i); +        e = SI->case_end(); +        Changed = true; +      }      } else if (auto *II = dyn_cast<InvokeInst>(TI)) {        if (II->getUnwindDest() == BB) {          removeUnwindEdge(TI->getParent()); @@ -4239,18 +4250,18 @@ static bool TurnSwitchRangeIntoICmp(SwitchInst *SI, IRBuilder<> &Builder) {    SmallVector<ConstantInt *, 16> CasesA;    SmallVector<ConstantInt *, 16> CasesB; -  for (SwitchInst::CaseIt I : SI->cases()) { -    BasicBlock *Dest = I.getCaseSuccessor(); +  for (auto Case : SI->cases()) { +    BasicBlock *Dest = Case.getCaseSuccessor();      if (!DestA)        DestA = Dest;      if (Dest == DestA) { -      CasesA.push_back(I.getCaseValue()); +      CasesA.push_back(Case.getCaseValue());        continue;      }      if (!DestB)        DestB = Dest;      if (Dest == DestB) { -      CasesB.push_back(I.getCaseValue()); +      CasesB.push_back(Case.getCaseValue());        continue;      }      return false; // More than two destinations. @@ -4375,7 +4386,7 @@ static bool EliminateDeadSwitchCases(SwitchInst *SI, AssumptionCache *AC,    bool HasDefault =        !isa<UnreachableInst>(SI->getDefaultDest()->getFirstNonPHIOrDbg());    const unsigned NumUnknownBits = -      Bits - (KnownZero.Or(KnownOne)).countPopulation(); +      Bits - (KnownZero | KnownOne).countPopulation();    assert(NumUnknownBits <= Bits);    if (HasDefault && DeadCases.empty() &&        NumUnknownBits < 64 /* avoid overflow */ && @@ -4400,17 +4411,17 @@ static bool EliminateDeadSwitchCases(SwitchInst *SI, AssumptionCache *AC,    // Remove dead cases from the switch.    for (ConstantInt *DeadCase : DeadCases) { -    SwitchInst::CaseIt Case = SI->findCaseValue(DeadCase); -    assert(Case != SI->case_default() && +    SwitchInst::CaseIt CaseI = SI->findCaseValue(DeadCase); +    assert(CaseI != SI->case_default() &&             "Case was not found. 
Probably mistake in DeadCases forming.");      if (HasWeight) { -      std::swap(Weights[Case.getCaseIndex() + 1], Weights.back()); +      std::swap(Weights[CaseI->getCaseIndex() + 1], Weights.back());        Weights.pop_back();      }      // Prune unused values from PHI nodes. -    Case.getCaseSuccessor()->removePredecessor(SI->getParent()); -    SI->removeCase(Case); +    CaseI->getCaseSuccessor()->removePredecessor(SI->getParent()); +    SI->removeCase(CaseI);    }    if (HasWeight && Weights.size() >= 2) {      SmallVector<uint32_t, 8> MDWeights(Weights.begin(), Weights.end()); @@ -4464,10 +4475,9 @@ static bool ForwardSwitchConditionToPHI(SwitchInst *SI) {    typedef DenseMap<PHINode *, SmallVector<int, 4>> ForwardingNodesMap;    ForwardingNodesMap ForwardingNodes; -  for (SwitchInst::CaseIt I = SI->case_begin(), E = SI->case_end(); I != E; -       ++I) { -    ConstantInt *CaseValue = I.getCaseValue(); -    BasicBlock *CaseDest = I.getCaseSuccessor(); +  for (auto Case : SI->cases()) { +    ConstantInt *CaseValue = Case.getCaseValue(); +    BasicBlock *CaseDest = Case.getCaseSuccessor();      int PhiIndex;      PHINode *PHI = @@ -5202,8 +5212,8 @@ static bool SwitchToLookupTable(SwitchInst *SI, IRBuilder<> &Builder,    // common destination, as well as the min and max case values.    assert(SI->case_begin() != SI->case_end());    SwitchInst::CaseIt CI = SI->case_begin(); -  ConstantInt *MinCaseVal = CI.getCaseValue(); -  ConstantInt *MaxCaseVal = CI.getCaseValue(); +  ConstantInt *MinCaseVal = CI->getCaseValue(); +  ConstantInt *MaxCaseVal = CI->getCaseValue();    BasicBlock *CommonDest = nullptr;    typedef SmallVector<std::pair<ConstantInt *, Constant *>, 4> ResultListTy; @@ -5213,7 +5223,7 @@ static bool SwitchToLookupTable(SwitchInst *SI, IRBuilder<> &Builder,    SmallVector<PHINode *, 4> PHIs;    for (SwitchInst::CaseIt E = SI->case_end(); CI != E; ++CI) { -    ConstantInt *CaseVal = CI.getCaseValue(); +    ConstantInt *CaseVal = CI->getCaseValue();      if (CaseVal->getValue().slt(MinCaseVal->getValue()))        MinCaseVal = CaseVal;      if (CaseVal->getValue().sgt(MaxCaseVal->getValue())) @@ -5222,7 +5232,7 @@ static bool SwitchToLookupTable(SwitchInst *SI, IRBuilder<> &Builder,      // Resulting value at phi nodes for this case value.      
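// Sketch (hypothetical helper) of the erase-while-iterating idiom the switch
// hunks above use: removeCase() now returns the iterator that follows the
// erased case, and case_end() is re-fetched because removal invalidates it.
#include "llvm/IR/Instructions.h"

static void removeCasesTargeting(llvm::SwitchInst *SI, llvm::BasicBlock *BB) {
  for (auto I = SI->case_begin(), E = SI->case_end(); I != E;) {
    if (I->getCaseSuccessor() != BB) {
      ++I;
      continue;
    }
    BB->removePredecessor(SI->getParent());
    I = SI->removeCase(I); // returns the iterator for the next case
    E = SI->case_end();    // end() may have shifted
  }
}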
typedef SmallVector<std::pair<PHINode *, Constant *>, 4> ResultsTy;      ResultsTy Results; -    if (!GetCaseResults(SI, CaseVal, CI.getCaseSuccessor(), &CommonDest, +    if (!GetCaseResults(SI, CaseVal, CI->getCaseSuccessor(), &CommonDest,                          Results, DL, TTI))        return false; @@ -5503,11 +5513,10 @@ static bool ReduceSwitchRange(SwitchInst *SI, IRBuilder<> &Builder,    auto *Rot = Builder.CreateOr(LShr, Shl);    SI->replaceUsesOfWith(SI->getCondition(), Rot); -  for (SwitchInst::CaseIt C = SI->case_begin(), E = SI->case_end(); C != E; -       ++C) { -    auto *Orig = C.getCaseValue(); +  for (auto Case : SI->cases()) { +    auto *Orig = Case.getCaseValue();      auto Sub = Orig->getValue() - APInt(Ty->getBitWidth(), Base); -    C.setValue( +    Case.setValue(          cast<ConstantInt>(ConstantInt::get(Ty, Sub.lshr(ShiftC->getValue()))));    }    return true; @@ -5553,7 +5562,12 @@ bool SimplifyCFGOpt::SimplifySwitch(SwitchInst *SI, IRBuilder<> &Builder) {    if (ForwardSwitchConditionToPHI(SI))      return SimplifyCFG(BB, TTI, BonusInstThreshold, AC) | true; -  if (SwitchToLookupTable(SI, Builder, DL, TTI)) +  // The conversion from switch to lookup tables results in +  // difficult-to-analyze code and makes pruning branches much harder. +  // This is a problem if the switch expression itself can still be +  // restricted as a result of inlining or CVP. Therefore, only apply this +  // transformation during the late steps of the optimization chain. +  if (LateSimplifyCFG && SwitchToLookupTable(SI, Builder, DL, TTI))      return SimplifyCFG(BB, TTI, BonusInstThreshold, AC) | true;    if (ReduceSwitchRange(SI, Builder, DL, TTI)) @@ -5833,7 +5847,7 @@ bool SimplifyCFGOpt::SimplifyCondBranch(BranchInst *BI, IRBuilder<> &Builder) {    // through this block if any PHI node entries are constants.    if (PHINode *PN = dyn_cast<PHINode>(BI->getCondition()))      if (PN->getParent() == BI->getParent()) -      if (FoldCondBranchOnPHI(BI, DL)) +      if (FoldCondBranchOnPHI(BI, DL, AC))          return SimplifyCFG(BB, TTI, BonusInstThreshold, AC) | true;    // Scan predecessor blocks for conditional branches. 
@@ -6012,8 +6026,9 @@ bool SimplifyCFGOpt::run(BasicBlock *BB) {  ///  bool llvm::SimplifyCFG(BasicBlock *BB, const TargetTransformInfo &TTI,                         unsigned BonusInstThreshold, AssumptionCache *AC, -                       SmallPtrSetImpl<BasicBlock *> *LoopHeaders) { +                       SmallPtrSetImpl<BasicBlock *> *LoopHeaders, +                       bool LateSimplifyCFG) {    return SimplifyCFGOpt(TTI, BB->getModule()->getDataLayout(), -                        BonusInstThreshold, AC, LoopHeaders) +                        BonusInstThreshold, AC, LoopHeaders, LateSimplifyCFG)        .run(BB);  } diff --git a/lib/Transforms/Utils/SimplifyIndVar.cpp b/lib/Transforms/Utils/SimplifyIndVar.cpp index 6b1d3dc41330..a4cc6a031ad4 100644 --- a/lib/Transforms/Utils/SimplifyIndVar.cpp +++ b/lib/Transforms/Utils/SimplifyIndVar.cpp @@ -35,6 +35,9 @@ using namespace llvm;  STATISTIC(NumElimIdentity, "Number of IV identities eliminated");  STATISTIC(NumElimOperand,  "Number of IV operands folded into a use");  STATISTIC(NumElimRem     , "Number of IV remainder operations eliminated"); +STATISTIC( +    NumSimplifiedSDiv, +    "Number of IV signed division operations converted to unsigned division");  STATISTIC(NumElimCmp     , "Number of IV comparisons eliminated");  namespace { @@ -75,6 +78,7 @@ namespace {      void eliminateIVComparison(ICmpInst *ICmp, Value *IVOperand);      void eliminateIVRemainder(BinaryOperator *Rem, Value *IVOperand,                                bool IsSigned); +    bool eliminateSDiv(BinaryOperator *SDiv);      bool strengthenOverflowingOperation(BinaryOperator *OBO, Value *IVOperand);    };  } @@ -265,6 +269,33 @@ void SimplifyIndvar::eliminateIVComparison(ICmpInst *ICmp, Value *IVOperand) {    Changed = true;  } +bool SimplifyIndvar::eliminateSDiv(BinaryOperator *SDiv) { +  // Get the SCEVs for the SDiv operands. +  auto *N = SE->getSCEV(SDiv->getOperand(0)); +  auto *D = SE->getSCEV(SDiv->getOperand(1)); + +  // Simplify unnecessary loops away. +  const Loop *L = LI->getLoopFor(SDiv->getParent()); +  N = SE->getSCEVAtScope(N, L); +  D = SE->getSCEVAtScope(D, L); + +  // Replace sdiv with udiv if both of the operands are non-negative. +  if (SE->isKnownNonNegative(N) && SE->isKnownNonNegative(D)) { +    auto *UDiv = BinaryOperator::Create( +        BinaryOperator::UDiv, SDiv->getOperand(0), SDiv->getOperand(1), +        SDiv->getName() + ".udiv", SDiv); +    UDiv->setIsExact(SDiv->isExact()); +    SDiv->replaceAllUsesWith(UDiv); +    DEBUG(dbgs() << "INDVARS: Simplified sdiv: " << *SDiv << '\n'); +    ++NumSimplifiedSDiv; +    Changed = true; +    DeadInsts.push_back(SDiv); +    return true; +  } + +  return false; +} +  /// SimplifyIVUsers helper for eliminating useless  /// remainder operations operating on an induction variable.  
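// Reference model in plain C++ (a sketch, not the IR produced by
// eliminateSDiv above) of why the rewrite is sound: when both operands are
// known non-negative, signed and unsigned division produce the same quotient,
// and udiv is simpler to reason about and to lower.
#include <cstdint>

static uint64_t divideKnownNonNegative(int64_t N, int64_t D) {
  // Caller guarantees N >= 0 and D >= 0, mirroring the isKnownNonNegative
  // checks; division by zero is undefined for both forms either way.
  return static_cast<uint64_t>(N) / static_cast<uint64_t>(D);
}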
void SimplifyIndvar::eliminateIVRemainder(BinaryOperator *Rem, @@ -426,12 +457,15 @@ bool SimplifyIndvar::eliminateIVUser(Instruction *UseInst,      eliminateIVComparison(ICmp, IVOperand);      return true;    } -  if (BinaryOperator *Rem = dyn_cast<BinaryOperator>(UseInst)) { -    bool IsSigned = Rem->getOpcode() == Instruction::SRem; -    if (IsSigned || Rem->getOpcode() == Instruction::URem) { -      eliminateIVRemainder(Rem, IVOperand, IsSigned); +  if (BinaryOperator *Bin = dyn_cast<BinaryOperator>(UseInst)) { +    bool IsSRem = Bin->getOpcode() == Instruction::SRem; +    if (IsSRem || Bin->getOpcode() == Instruction::URem) { +      eliminateIVRemainder(Bin, IVOperand, IsSRem);        return true;      } + +    if (Bin->getOpcode() == Instruction::SDiv) +      return eliminateSDiv(Bin);    }    if (auto *CI = dyn_cast<CallInst>(UseInst)) diff --git a/lib/Transforms/Utils/SimplifyInstructions.cpp b/lib/Transforms/Utils/SimplifyInstructions.cpp index 1220490123ce..f6070868de44 100644 --- a/lib/Transforms/Utils/SimplifyInstructions.cpp +++ b/lib/Transforms/Utils/SimplifyInstructions.cpp @@ -20,6 +20,7 @@  #include "llvm/ADT/Statistic.h"  #include "llvm/Analysis/AssumptionCache.h"  #include "llvm/Analysis/InstructionSimplify.h" +#include "llvm/Analysis/OptimizationDiagnosticInfo.h"  #include "llvm/Analysis/TargetLibraryInfo.h"  #include "llvm/IR/DataLayout.h"  #include "llvm/IR/Dominators.h" @@ -35,7 +36,8 @@ using namespace llvm;  STATISTIC(NumSimplified, "Number of redundant instructions removed");  static bool runImpl(Function &F, const DominatorTree *DT, -                    const TargetLibraryInfo *TLI, AssumptionCache *AC) { +                    const TargetLibraryInfo *TLI, AssumptionCache *AC, +                    OptimizationRemarkEmitter *ORE) {    const DataLayout &DL = F.getParent()->getDataLayout();    SmallPtrSet<const Instruction *, 8> S1, S2, *ToSimplify = &S1, *Next = &S2;    bool Changed = false; @@ -54,7 +56,7 @@ static bool runImpl(Function &F, const DominatorTree *DT,          // Don't waste time simplifying unused instructions.          if (!I->use_empty()) { -          if (Value *V = SimplifyInstruction(I, DL, TLI, DT, AC)) { +          if (Value *V = SimplifyInstruction(I, DL, TLI, DT, AC, ORE)) {              // Mark all uses for resimplification next time round the loop.              for (User *U : I->users())                Next->insert(cast<Instruction>(U)); @@ -95,6 +97,7 @@ namespace {        AU.addRequired<DominatorTreeWrapperPass>();        AU.addRequired<AssumptionCacheTracker>();        AU.addRequired<TargetLibraryInfoWrapperPass>(); +      AU.addRequired<OptimizationRemarkEmitterWrapperPass>();      }      /// runOnFunction - Remove instructions that simplify. 
@@ -108,7 +111,10 @@ namespace {            &getAnalysis<TargetLibraryInfoWrapperPass>().getTLI();        AssumptionCache *AC =            &getAnalysis<AssumptionCacheTracker>().getAssumptionCache(F); -      return runImpl(F, DT, TLI, AC); +      OptimizationRemarkEmitter *ORE = +          &getAnalysis<OptimizationRemarkEmitterWrapperPass>().getORE(); + +      return runImpl(F, DT, TLI, AC, ORE);      }    };  } @@ -119,6 +125,7 @@ INITIALIZE_PASS_BEGIN(InstSimplifier, "instsimplify",  INITIALIZE_PASS_DEPENDENCY(AssumptionCacheTracker)  INITIALIZE_PASS_DEPENDENCY(DominatorTreeWrapperPass)  INITIALIZE_PASS_DEPENDENCY(TargetLibraryInfoWrapperPass) +INITIALIZE_PASS_DEPENDENCY(OptimizationRemarkEmitterWrapperPass)  INITIALIZE_PASS_END(InstSimplifier, "instsimplify",                      "Remove redundant instructions", false, false)  char &llvm::InstructionSimplifierID = InstSimplifier::ID; @@ -133,9 +140,12 @@ PreservedAnalyses InstSimplifierPass::run(Function &F,    auto &DT = AM.getResult<DominatorTreeAnalysis>(F);    auto &TLI = AM.getResult<TargetLibraryAnalysis>(F);    auto &AC = AM.getResult<AssumptionAnalysis>(F); -  bool Changed = runImpl(F, &DT, &TLI, &AC); +  auto &ORE = AM.getResult<OptimizationRemarkEmitterAnalysis>(F); +  bool Changed = runImpl(F, &DT, &TLI, &AC, &ORE);    if (!Changed)      return PreservedAnalyses::all(); -  // FIXME: This should also 'preserve the CFG'. -  return PreservedAnalyses::none(); + +  PreservedAnalyses PA; +  PA.preserveSet<CFGAnalyses>(); +  return PA;  } diff --git a/lib/Transforms/Utils/SimplifyLibCalls.cpp b/lib/Transforms/Utils/SimplifyLibCalls.cpp index 8eaeb1073a76..aa71e3669ea2 100644 --- a/lib/Transforms/Utils/SimplifyLibCalls.cpp +++ b/lib/Transforms/Utils/SimplifyLibCalls.cpp @@ -51,9 +51,9 @@ static cl::opt<bool>  // Helper Functions  //===----------------------------------------------------------------------===// -static bool ignoreCallingConv(LibFunc::Func Func) { -  return Func == LibFunc::abs || Func == LibFunc::labs || -         Func == LibFunc::llabs || Func == LibFunc::strlen; +static bool ignoreCallingConv(LibFunc Func) { +  return Func == LibFunc_abs || Func == LibFunc_labs || +         Func == LibFunc_llabs || Func == LibFunc_strlen;  }  static bool isCallingConvCCompatible(CallInst *CI) { @@ -123,8 +123,8 @@ static bool callHasFloatingPointArgument(const CallInst *CI) {  /// \brief Check whether the overloaded unary floating point function  /// corresponding to \a Ty is available.  static bool hasUnaryFloatFn(const TargetLibraryInfo *TLI, Type *Ty, -                            LibFunc::Func DoubleFn, LibFunc::Func FloatFn, -                            LibFunc::Func LongDoubleFn) { +                            LibFunc DoubleFn, LibFunc FloatFn, +                            LibFunc LongDoubleFn) {    switch (Ty->getTypeID()) {    case Type::FloatTyID:      return TLI->has(FloatFn); @@ -809,9 +809,9 @@ Value *LibCallSimplifier::optimizeMemMove(CallInst *CI, IRBuilder<> &B) {  // TODO: Does this belong in BuildLibCalls or should all of those similar  // functions be moved here? 
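// A small sketch (hypothetical pass body) of the new-pass-manager idiom
// adopted above for InstSimplifierPass::run: rather than returning
// PreservedAnalyses::none(), a transform that does not touch the CFG reports
// that all CFG-based analyses remain valid.
#include "llvm/IR/PassManager.h"

static llvm::PreservedAnalyses reportCFGPreserved(bool Changed) {
  if (!Changed)
    return llvm::PreservedAnalyses::all();
  llvm::PreservedAnalyses PA;
  PA.preserveSet<llvm::CFGAnalyses>();
  return PA;
}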
-static Value *emitCalloc(Value *Num, Value *Size, const AttributeSet &Attrs, +static Value *emitCalloc(Value *Num, Value *Size, const AttributeList &Attrs,                           IRBuilder<> &B, const TargetLibraryInfo &TLI) { -  LibFunc::Func Func; +  LibFunc Func;    if (!TLI.getLibFunc("calloc", Func) || !TLI.has(Func))      return nullptr; @@ -819,7 +819,7 @@ static Value *emitCalloc(Value *Num, Value *Size, const AttributeSet &Attrs,    const DataLayout &DL = M->getDataLayout();    IntegerType *PtrType = DL.getIntPtrType((B.GetInsertBlock()->getContext()));    Value *Calloc = M->getOrInsertFunction("calloc", Attrs, B.getInt8PtrTy(), -                                         PtrType, PtrType, nullptr); +                                         PtrType, PtrType);    CallInst *CI = B.CreateCall(Calloc, { Num, Size }, "calloc");    if (const auto *F = dyn_cast<Function>(Calloc->stripPointerCasts())) @@ -846,9 +846,9 @@ static Value *foldMallocMemset(CallInst *Memset, IRBuilder<> &B,    // Is the inner call really malloc()?    Function *InnerCallee = Malloc->getCalledFunction(); -  LibFunc::Func Func; +  LibFunc Func;    if (!TLI.getLibFunc(*InnerCallee, Func) || !TLI.has(Func) || -      Func != LibFunc::malloc) +      Func != LibFunc_malloc)      return nullptr;    // The memset must cover the same number of bytes that are malloc'd. @@ -948,6 +948,20 @@ static Value *optimizeUnaryDoubleFP(CallInst *CI, IRBuilder<> &B,    return B.CreateFPExt(V, B.getDoubleTy());  } +// Replace a libcall \p CI with a call to intrinsic \p IID +static Value *replaceUnaryCall(CallInst *CI, IRBuilder<> &B, Intrinsic::ID IID) { +  // Propagate fast-math flags from the existing call to the new call. +  IRBuilder<>::FastMathFlagGuard Guard(B); +  B.setFastMathFlags(CI->getFastMathFlags()); + +  Module *M = CI->getModule(); +  Value *V = CI->getArgOperand(0); +  Function *F = Intrinsic::getDeclaration(M, IID, CI->getType()); +  CallInst *NewCall = B.CreateCall(F, V); +  NewCall->takeName(CI); +  return NewCall; +} +  /// Shrink double -> float for binary functions like 'fmin/fmax'.  static Value *optimizeBinaryDoubleFP(CallInst *CI, IRBuilder<> &B) {    Function *Callee = CI->getCalledFunction(); @@ -1041,9 +1055,9 @@ Value *LibCallSimplifier::optimizePow(CallInst *CI, IRBuilder<> &B) {    if (ConstantFP *Op1C = dyn_cast<ConstantFP>(Op1)) {      // pow(10.0, x) -> exp10(x)      if (Op1C->isExactlyValue(10.0) && -        hasUnaryFloatFn(TLI, Op1->getType(), LibFunc::exp10, LibFunc::exp10f, -                        LibFunc::exp10l)) -      return emitUnaryFloatFnCall(Op2, TLI->getName(LibFunc::exp10), B, +        hasUnaryFloatFn(TLI, Op1->getType(), LibFunc_exp10, LibFunc_exp10f, +                        LibFunc_exp10l)) +      return emitUnaryFloatFnCall(Op2, TLI->getName(LibFunc_exp10), B,                                    Callee->getAttributes());    } @@ -1055,10 +1069,10 @@ Value *LibCallSimplifier::optimizePow(CallInst *CI, IRBuilder<> &B) {    // pow(exp(x), y) = pow(inf, 0.001) = inf, whereas exp(x*y) = exp(1).    
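// Numeric illustration (plain C++, not from the patch) of why the
// pow(exp(x), y) -> exp(x*y) rewrite discussed above needs unsafe-algebra:
// with x = 1000 and y = 0.001, exp(x) overflows to +inf in double, so the
// original expression evaluates to pow(inf, 0.001) = +inf while the rewritten
// form is exp(1), roughly 2.718.
#include <cmath>

static void powExpCounterexample(double &Original, double &Rewritten) {
  const double X = 1000.0, Y = 0.001;
  Original = std::pow(std::exp(X), Y); // +inf
  Rewritten = std::exp(X * Y);         // ~2.71828
}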
auto *OpC = dyn_cast<CallInst>(Op1);    if (OpC && OpC->hasUnsafeAlgebra() && CI->hasUnsafeAlgebra()) { -    LibFunc::Func Func; +    LibFunc Func;      Function *OpCCallee = OpC->getCalledFunction();      if (OpCCallee && TLI->getLibFunc(OpCCallee->getName(), Func) && -        TLI->has(Func) && (Func == LibFunc::exp || Func == LibFunc::exp2)) { +        TLI->has(Func) && (Func == LibFunc_exp || Func == LibFunc_exp2)) {        IRBuilder<>::FastMathFlagGuard Guard(B);        B.setFastMathFlags(CI->getFastMathFlags());        Value *FMul = B.CreateFMul(OpC->getArgOperand(0), Op2, "mul"); @@ -1075,17 +1089,20 @@ Value *LibCallSimplifier::optimizePow(CallInst *CI, IRBuilder<> &B) {      return ConstantFP::get(CI->getType(), 1.0);    if (Op2C->isExactlyValue(-0.5) && -      hasUnaryFloatFn(TLI, Op2->getType(), LibFunc::sqrt, LibFunc::sqrtf, -                      LibFunc::sqrtl)) { +      hasUnaryFloatFn(TLI, Op2->getType(), LibFunc_sqrt, LibFunc_sqrtf, +                      LibFunc_sqrtl)) {      // If -ffast-math:      // pow(x, -0.5) -> 1.0 / sqrt(x)      if (CI->hasUnsafeAlgebra()) {        IRBuilder<>::FastMathFlagGuard Guard(B);        B.setFastMathFlags(CI->getFastMathFlags()); -      // Here we cannot lower to an intrinsic because C99 sqrt() and llvm.sqrt -      // are not guaranteed to have the same semantics. -      Value *Sqrt = emitUnaryFloatFnCall(Op1, TLI->getName(LibFunc::sqrt), B, +      // TODO: If the pow call is an intrinsic, we should lower to the sqrt +      // intrinsic, so we match errno semantics.  We also should check that the +      // target can in fact lower the sqrt intrinsic -- we currently have no way +      // to ask this question other than asking whether the target has a sqrt +      // libcall, which is a sufficient but not necessary condition. +      Value *Sqrt = emitUnaryFloatFnCall(Op1, TLI->getName(LibFunc_sqrt), B,                                           Callee->getAttributes());        return B.CreateFDiv(ConstantFP::get(CI->getType(), 1.0), Sqrt, "sqrtrecip"); @@ -1093,19 +1110,17 @@ Value *LibCallSimplifier::optimizePow(CallInst *CI, IRBuilder<> &B) {    }    if (Op2C->isExactlyValue(0.5) && -      hasUnaryFloatFn(TLI, Op2->getType(), LibFunc::sqrt, LibFunc::sqrtf, -                      LibFunc::sqrtl) && -      hasUnaryFloatFn(TLI, Op2->getType(), LibFunc::fabs, LibFunc::fabsf, -                      LibFunc::fabsl)) { +      hasUnaryFloatFn(TLI, Op2->getType(), LibFunc_sqrt, LibFunc_sqrtf, +                      LibFunc_sqrtl)) {      // In -ffast-math, pow(x, 0.5) -> sqrt(x).      if (CI->hasUnsafeAlgebra()) {        IRBuilder<>::FastMathFlagGuard Guard(B);        B.setFastMathFlags(CI->getFastMathFlags()); -      // Unlike other math intrinsics, sqrt has differerent semantics -      // from the libc function. See LangRef for details. -      return emitUnaryFloatFnCall(Op1, TLI->getName(LibFunc::sqrt), B, +      // TODO: As above, we should lower to the sqrt intrinsic if the pow is an +      // intrinsic, to match errno semantics. +      return emitUnaryFloatFnCall(Op1, TLI->getName(LibFunc_sqrt), B,                                    Callee->getAttributes());      } @@ -1115,9 +1130,16 @@ Value *LibCallSimplifier::optimizePow(CallInst *CI, IRBuilder<> &B) {      // TODO: In finite-only mode, this could be just fabs(sqrt(x)).      
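// Reference semantics in plain C++ (a sketch, not the IR built below) for the
// pow(x, 0.5) expansion: fabs() corrects sqrt(-0.0) == -0.0 to the +0.0 that
// pow(-0.0, 0.5) returns, and the -inf check corrects sqrt(-inf) == NaN to
// the +inf that pow(-inf, 0.5) returns.
#include <cmath>

static double powHalfReference(double X) {
  return X == -INFINITY ? INFINITY : std::fabs(std::sqrt(X));
}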
Value *Inf = ConstantFP::getInfinity(CI->getType());      Value *NegInf = ConstantFP::getInfinity(CI->getType(), true); + +    // TODO: As above, we should lower to the sqrt intrinsic if the pow is an +    // intrinsic, to match errno semantics.      Value *Sqrt = emitUnaryFloatFnCall(Op1, "sqrt", B, Callee->getAttributes()); -    Value *FAbs = -        emitUnaryFloatFnCall(Sqrt, "fabs", B, Callee->getAttributes()); + +    Module *M = Callee->getParent(); +    Function *FabsF = Intrinsic::getDeclaration(M, Intrinsic::fabs, +                                                CI->getType()); +    Value *FAbs = B.CreateCall(FabsF, Sqrt); +      Value *FCmp = B.CreateFCmpOEQ(Op1, NegInf);      Value *Sel = B.CreateSelect(FCmp, Inf, FAbs);      return Sel; @@ -1173,11 +1195,11 @@ Value *LibCallSimplifier::optimizeExp2(CallInst *CI, IRBuilder<> &B) {    Value *Op = CI->getArgOperand(0);    // Turn exp2(sitofp(x)) -> ldexp(1.0, sext(x))  if sizeof(x) <= 32    // Turn exp2(uitofp(x)) -> ldexp(1.0, zext(x))  if sizeof(x) < 32 -  LibFunc::Func LdExp = LibFunc::ldexpl; +  LibFunc LdExp = LibFunc_ldexpl;    if (Op->getType()->isFloatTy()) -    LdExp = LibFunc::ldexpf; +    LdExp = LibFunc_ldexpf;    else if (Op->getType()->isDoubleTy()) -    LdExp = LibFunc::ldexp; +    LdExp = LibFunc_ldexp;    if (TLI->has(LdExp)) {      Value *LdExpArg = nullptr; @@ -1197,7 +1219,7 @@ Value *LibCallSimplifier::optimizeExp2(CallInst *CI, IRBuilder<> &B) {        Module *M = CI->getModule();        Value *NewCallee =            M->getOrInsertFunction(TLI->getName(LdExp), Op->getType(), -                                 Op->getType(), B.getInt32Ty(), nullptr); +                                 Op->getType(), B.getInt32Ty());        CallInst *CI = B.CreateCall(NewCallee, {One, LdExpArg});        if (const Function *F = dyn_cast<Function>(Callee->stripPointerCasts()))          CI->setCallingConv(F->getCallingConv()); @@ -1208,15 +1230,6 @@ Value *LibCallSimplifier::optimizeExp2(CallInst *CI, IRBuilder<> &B) {    return Ret;  } -Value *LibCallSimplifier::optimizeFabs(CallInst *CI, IRBuilder<> &B) { -  Function *Callee = CI->getCalledFunction(); -  StringRef Name = Callee->getName(); -  if (Name == "fabs" && hasFloatVersion(Name)) -    return optimizeUnaryDoubleFP(CI, B, false); - -  return nullptr; -} -  Value *LibCallSimplifier::optimizeFMinFMax(CallInst *CI, IRBuilder<> &B) {    Function *Callee = CI->getCalledFunction();    // If we can shrink the call to a float function rather than a double @@ -1280,17 +1293,17 @@ Value *LibCallSimplifier::optimizeLog(CallInst *CI, IRBuilder<> &B) {    FMF.setUnsafeAlgebra();    B.setFastMathFlags(FMF); -  LibFunc::Func Func; +  LibFunc Func;    Function *F = OpC->getCalledFunction();    if (F && ((TLI->getLibFunc(F->getName(), Func) && TLI->has(Func) && -      Func == LibFunc::pow) || F->getIntrinsicID() == Intrinsic::pow)) +      Func == LibFunc_pow) || F->getIntrinsicID() == Intrinsic::pow))      return B.CreateFMul(OpC->getArgOperand(1),        emitUnaryFloatFnCall(OpC->getOperand(0), Callee->getName(), B,                             Callee->getAttributes()), "mul");    // log(exp2(y)) -> y*log(2)    if (F && Name == "log" && TLI->getLibFunc(F->getName(), Func) && -      TLI->has(Func) && Func == LibFunc::exp2) +      TLI->has(Func) && Func == LibFunc_exp2)      return B.CreateFMul(          OpC->getArgOperand(0),          emitUnaryFloatFnCall(ConstantFP::get(CI->getType(), 2.0), @@ -1302,8 +1315,11 @@ Value *LibCallSimplifier::optimizeLog(CallInst *CI, IRBuilder<> &B) {  Value 
*LibCallSimplifier::optimizeSqrt(CallInst *CI, IRBuilder<> &B) {    Function *Callee = CI->getCalledFunction();    Value *Ret = nullptr; -  if (TLI->has(LibFunc::sqrtf) && (Callee->getName() == "sqrt" || -                                   Callee->getIntrinsicID() == Intrinsic::sqrt)) +  // TODO: Once we have a way (other than checking for the existence of the +  // libcall) to tell whether our target can lower @llvm.sqrt, relax the +  // condition below. +  if (TLI->has(LibFunc_sqrtf) && (Callee->getName() == "sqrt" || +                                  Callee->getIntrinsicID() == Intrinsic::sqrt))      Ret = optimizeUnaryDoubleFP(CI, B, true);    if (!CI->hasUnsafeAlgebra()) @@ -1385,12 +1401,12 @@ Value *LibCallSimplifier::optimizeTan(CallInst *CI, IRBuilder<> &B) {    // tan(atan(x)) -> x    // tanf(atanf(x)) -> x    // tanl(atanl(x)) -> x -  LibFunc::Func Func; +  LibFunc Func;    Function *F = OpC->getCalledFunction();    if (F && TLI->getLibFunc(F->getName(), Func) && TLI->has(Func) && -      ((Func == LibFunc::atan && Callee->getName() == "tan") || -       (Func == LibFunc::atanf && Callee->getName() == "tanf") || -       (Func == LibFunc::atanl && Callee->getName() == "tanl"))) +      ((Func == LibFunc_atan && Callee->getName() == "tan") || +       (Func == LibFunc_atanf && Callee->getName() == "tanf") || +       (Func == LibFunc_atanl && Callee->getName() == "tanl")))      Ret = OpC->getArgOperand(0);    return Ret;  } @@ -1427,7 +1443,7 @@ static void insertSinCosCall(IRBuilder<> &B, Function *OrigCallee, Value *Arg,    Module *M = OrigCallee->getParent();    Value *Callee = M->getOrInsertFunction(Name, OrigCallee->getAttributes(), -                                         ResTy, ArgTy, nullptr); +                                         ResTy, ArgTy);    if (Instruction *ArgInst = dyn_cast<Instruction>(Arg)) {      // If the argument is an instruction, it must dominate all uses so put our @@ -1508,24 +1524,24 @@ void LibCallSimplifier::classifyArgUse(      return;    Function *Callee = CI->getCalledFunction(); -  LibFunc::Func Func; +  LibFunc Func;    if (!Callee || !TLI->getLibFunc(*Callee, Func) || !TLI->has(Func) ||        !isTrigLibCall(CI))      return;    if (IsFloat) { -    if (Func == LibFunc::sinpif) +    if (Func == LibFunc_sinpif)        SinCalls.push_back(CI); -    else if (Func == LibFunc::cospif) +    else if (Func == LibFunc_cospif)        CosCalls.push_back(CI); -    else if (Func == LibFunc::sincospif_stret) +    else if (Func == LibFunc_sincospif_stret)        SinCosCalls.push_back(CI);    } else { -    if (Func == LibFunc::sinpi) +    if (Func == LibFunc_sinpi)        SinCalls.push_back(CI); -    else if (Func == LibFunc::cospi) +    else if (Func == LibFunc_cospi)        CosCalls.push_back(CI); -    else if (Func == LibFunc::sincospi_stret) +    else if (Func == LibFunc_sincospi_stret)        SinCosCalls.push_back(CI);    }  } @@ -1609,7 +1625,7 @@ Value *LibCallSimplifier::optimizeErrorReporting(CallInst *CI, IRBuilder<> &B,    // Proceedings of PACT'98, Oct. 1998, IEEE    if (!CI->hasFnAttr(Attribute::Cold) &&        isReportingError(Callee, CI, StreamArg)) { -    CI->addAttribute(AttributeSet::FunctionIndex, Attribute::Cold); +    CI->addAttribute(AttributeList::FunctionIndex, Attribute::Cold);    }    return nullptr; @@ -1699,7 +1715,7 @@ Value *LibCallSimplifier::optimizePrintF(CallInst *CI, IRBuilder<> &B) {    // printf(format, ...) -> iprintf(format, ...) if no floating point    // arguments. 
-  if (TLI->has(LibFunc::iprintf) && !callHasFloatingPointArgument(CI)) { +  if (TLI->has(LibFunc_iprintf) && !callHasFloatingPointArgument(CI)) {      Module *M = B.GetInsertBlock()->getParent()->getParent();      Constant *IPrintFFn =          M->getOrInsertFunction("iprintf", FT, Callee->getAttributes()); @@ -1780,7 +1796,7 @@ Value *LibCallSimplifier::optimizeSPrintF(CallInst *CI, IRBuilder<> &B) {    // sprintf(str, format, ...) -> siprintf(str, format, ...) if no floating    // point arguments. -  if (TLI->has(LibFunc::siprintf) && !callHasFloatingPointArgument(CI)) { +  if (TLI->has(LibFunc_siprintf) && !callHasFloatingPointArgument(CI)) {      Module *M = B.GetInsertBlock()->getParent()->getParent();      Constant *SIPrintFFn =          M->getOrInsertFunction("siprintf", FT, Callee->getAttributes()); @@ -1850,7 +1866,7 @@ Value *LibCallSimplifier::optimizeFPrintF(CallInst *CI, IRBuilder<> &B) {    // fprintf(stream, format, ...) -> fiprintf(stream, format, ...) if no    // floating point arguments. -  if (TLI->has(LibFunc::fiprintf) && !callHasFloatingPointArgument(CI)) { +  if (TLI->has(LibFunc_fiprintf) && !callHasFloatingPointArgument(CI)) {      Module *M = B.GetInsertBlock()->getParent()->getParent();      Constant *FIPrintFFn =          M->getOrInsertFunction("fiprintf", FT, Callee->getAttributes()); @@ -1929,7 +1945,7 @@ Value *LibCallSimplifier::optimizePuts(CallInst *CI, IRBuilder<> &B) {  }  bool LibCallSimplifier::hasFloatVersion(StringRef FuncName) { -  LibFunc::Func Func; +  LibFunc Func;    SmallString<20> FloatFuncName = FuncName;    FloatFuncName += 'f';    if (TLI->getLibFunc(FloatFuncName, Func)) @@ -1939,7 +1955,7 @@ bool LibCallSimplifier::hasFloatVersion(StringRef FuncName) {  Value *LibCallSimplifier::optimizeStringMemoryLibCall(CallInst *CI,                                                        IRBuilder<> &Builder) { -  LibFunc::Func Func; +  LibFunc Func;    Function *Callee = CI->getCalledFunction();    // Check for string/memory library functions.    
if (TLI->getLibFunc(*Callee, Func) && TLI->has(Func)) { @@ -1948,51 +1964,51 @@ Value *LibCallSimplifier::optimizeStringMemoryLibCall(CallInst *CI,              isCallingConvCCompatible(CI)) &&        "Optimizing string/memory libcall would change the calling convention");      switch (Func) { -    case LibFunc::strcat: +    case LibFunc_strcat:        return optimizeStrCat(CI, Builder); -    case LibFunc::strncat: +    case LibFunc_strncat:        return optimizeStrNCat(CI, Builder); -    case LibFunc::strchr: +    case LibFunc_strchr:        return optimizeStrChr(CI, Builder); -    case LibFunc::strrchr: +    case LibFunc_strrchr:        return optimizeStrRChr(CI, Builder); -    case LibFunc::strcmp: +    case LibFunc_strcmp:        return optimizeStrCmp(CI, Builder); -    case LibFunc::strncmp: +    case LibFunc_strncmp:        return optimizeStrNCmp(CI, Builder); -    case LibFunc::strcpy: +    case LibFunc_strcpy:        return optimizeStrCpy(CI, Builder); -    case LibFunc::stpcpy: +    case LibFunc_stpcpy:        return optimizeStpCpy(CI, Builder); -    case LibFunc::strncpy: +    case LibFunc_strncpy:        return optimizeStrNCpy(CI, Builder); -    case LibFunc::strlen: +    case LibFunc_strlen:        return optimizeStrLen(CI, Builder); -    case LibFunc::strpbrk: +    case LibFunc_strpbrk:        return optimizeStrPBrk(CI, Builder); -    case LibFunc::strtol: -    case LibFunc::strtod: -    case LibFunc::strtof: -    case LibFunc::strtoul: -    case LibFunc::strtoll: -    case LibFunc::strtold: -    case LibFunc::strtoull: +    case LibFunc_strtol: +    case LibFunc_strtod: +    case LibFunc_strtof: +    case LibFunc_strtoul: +    case LibFunc_strtoll: +    case LibFunc_strtold: +    case LibFunc_strtoull:        return optimizeStrTo(CI, Builder); -    case LibFunc::strspn: +    case LibFunc_strspn:        return optimizeStrSpn(CI, Builder); -    case LibFunc::strcspn: +    case LibFunc_strcspn:        return optimizeStrCSpn(CI, Builder); -    case LibFunc::strstr: +    case LibFunc_strstr:        return optimizeStrStr(CI, Builder); -    case LibFunc::memchr: +    case LibFunc_memchr:        return optimizeMemChr(CI, Builder); -    case LibFunc::memcmp: +    case LibFunc_memcmp:        return optimizeMemCmp(CI, Builder); -    case LibFunc::memcpy: +    case LibFunc_memcpy:        return optimizeMemCpy(CI, Builder); -    case LibFunc::memmove: +    case LibFunc_memmove:        return optimizeMemMove(CI, Builder); -    case LibFunc::memset: +    case LibFunc_memset:        return optimizeMemSet(CI, Builder);      default:        break; @@ -2005,7 +2021,7 @@ Value *LibCallSimplifier::optimizeCall(CallInst *CI) {    if (CI->isNoBuiltin())      return nullptr; -  LibFunc::Func Func; +  LibFunc Func;    Function *Callee = CI->getCalledFunction();    StringRef FuncName = Callee->getName(); @@ -2029,8 +2045,6 @@ Value *LibCallSimplifier::optimizeCall(CallInst *CI) {        return optimizePow(CI, Builder);      case Intrinsic::exp2:        return optimizeExp2(CI, Builder); -    case Intrinsic::fabs: -      return optimizeFabs(CI, Builder);      case Intrinsic::log:        return optimizeLog(CI, Builder);      case Intrinsic::sqrt: @@ -2067,114 +2081,117 @@ Value *LibCallSimplifier::optimizeCall(CallInst *CI) {      if (Value *V = optimizeStringMemoryLibCall(CI, Builder))        return V;      switch (Func) { -    case LibFunc::cosf: -    case LibFunc::cos: -    case LibFunc::cosl: +    case LibFunc_cosf: +    case LibFunc_cos: +    case LibFunc_cosl:        return optimizeCos(CI, 
Builder); -    case LibFunc::sinpif: -    case LibFunc::sinpi: -    case LibFunc::cospif: -    case LibFunc::cospi: +    case LibFunc_sinpif: +    case LibFunc_sinpi: +    case LibFunc_cospif: +    case LibFunc_cospi:        return optimizeSinCosPi(CI, Builder); -    case LibFunc::powf: -    case LibFunc::pow: -    case LibFunc::powl: +    case LibFunc_powf: +    case LibFunc_pow: +    case LibFunc_powl:        return optimizePow(CI, Builder); -    case LibFunc::exp2l: -    case LibFunc::exp2: -    case LibFunc::exp2f: +    case LibFunc_exp2l: +    case LibFunc_exp2: +    case LibFunc_exp2f:        return optimizeExp2(CI, Builder); -    case LibFunc::fabsf: -    case LibFunc::fabs: -    case LibFunc::fabsl: -      return optimizeFabs(CI, Builder); -    case LibFunc::sqrtf: -    case LibFunc::sqrt: -    case LibFunc::sqrtl: +    case LibFunc_fabsf: +    case LibFunc_fabs: +    case LibFunc_fabsl: +      return replaceUnaryCall(CI, Builder, Intrinsic::fabs); +    case LibFunc_sqrtf: +    case LibFunc_sqrt: +    case LibFunc_sqrtl:        return optimizeSqrt(CI, Builder); -    case LibFunc::ffs: -    case LibFunc::ffsl: -    case LibFunc::ffsll: +    case LibFunc_ffs: +    case LibFunc_ffsl: +    case LibFunc_ffsll:        return optimizeFFS(CI, Builder); -    case LibFunc::fls: -    case LibFunc::flsl: -    case LibFunc::flsll: +    case LibFunc_fls: +    case LibFunc_flsl: +    case LibFunc_flsll:        return optimizeFls(CI, Builder); -    case LibFunc::abs: -    case LibFunc::labs: -    case LibFunc::llabs: +    case LibFunc_abs: +    case LibFunc_labs: +    case LibFunc_llabs:        return optimizeAbs(CI, Builder); -    case LibFunc::isdigit: +    case LibFunc_isdigit:        return optimizeIsDigit(CI, Builder); -    case LibFunc::isascii: +    case LibFunc_isascii:        return optimizeIsAscii(CI, Builder); -    case LibFunc::toascii: +    case LibFunc_toascii:        return optimizeToAscii(CI, Builder); -    case LibFunc::printf: +    case LibFunc_printf:        return optimizePrintF(CI, Builder); -    case LibFunc::sprintf: +    case LibFunc_sprintf:        return optimizeSPrintF(CI, Builder); -    case LibFunc::fprintf: +    case LibFunc_fprintf:        return optimizeFPrintF(CI, Builder); -    case LibFunc::fwrite: +    case LibFunc_fwrite:        return optimizeFWrite(CI, Builder); -    case LibFunc::fputs: +    case LibFunc_fputs:        return optimizeFPuts(CI, Builder); -    case LibFunc::log: -    case LibFunc::log10: -    case LibFunc::log1p: -    case LibFunc::log2: -    case LibFunc::logb: +    case LibFunc_log: +    case LibFunc_log10: +    case LibFunc_log1p: +    case LibFunc_log2: +    case LibFunc_logb:        return optimizeLog(CI, Builder); -    case LibFunc::puts: +    case LibFunc_puts:        return optimizePuts(CI, Builder); -    case LibFunc::tan: -    case LibFunc::tanf: -    case LibFunc::tanl: +    case LibFunc_tan: +    case LibFunc_tanf: +    case LibFunc_tanl:        return optimizeTan(CI, Builder); -    case LibFunc::perror: +    case LibFunc_perror:        return optimizeErrorReporting(CI, Builder); -    case LibFunc::vfprintf: -    case LibFunc::fiprintf: +    case LibFunc_vfprintf: +    case LibFunc_fiprintf:        return optimizeErrorReporting(CI, Builder, 0); -    case LibFunc::fputc: +    case LibFunc_fputc:        return optimizeErrorReporting(CI, Builder, 1); -    case LibFunc::ceil: -    case LibFunc::floor: -    case LibFunc::rint: -    case LibFunc::round: -    case LibFunc::nearbyint: -    case LibFunc::trunc: -      if 
(hasFloatVersion(FuncName)) -        return optimizeUnaryDoubleFP(CI, Builder, false); -      return nullptr; -    case LibFunc::acos: -    case LibFunc::acosh: -    case LibFunc::asin: -    case LibFunc::asinh: -    case LibFunc::atan: -    case LibFunc::atanh: -    case LibFunc::cbrt: -    case LibFunc::cosh: -    case LibFunc::exp: -    case LibFunc::exp10: -    case LibFunc::expm1: -    case LibFunc::sin: -    case LibFunc::sinh: -    case LibFunc::tanh: +    case LibFunc_ceil: +      return replaceUnaryCall(CI, Builder, Intrinsic::ceil); +    case LibFunc_floor: +      return replaceUnaryCall(CI, Builder, Intrinsic::floor); +    case LibFunc_round: +      return replaceUnaryCall(CI, Builder, Intrinsic::round); +    case LibFunc_nearbyint: +      return replaceUnaryCall(CI, Builder, Intrinsic::nearbyint); +    case LibFunc_rint: +      return replaceUnaryCall(CI, Builder, Intrinsic::rint); +    case LibFunc_trunc: +      return replaceUnaryCall(CI, Builder, Intrinsic::trunc); +    case LibFunc_acos: +    case LibFunc_acosh: +    case LibFunc_asin: +    case LibFunc_asinh: +    case LibFunc_atan: +    case LibFunc_atanh: +    case LibFunc_cbrt: +    case LibFunc_cosh: +    case LibFunc_exp: +    case LibFunc_exp10: +    case LibFunc_expm1: +    case LibFunc_sin: +    case LibFunc_sinh: +    case LibFunc_tanh:        if (UnsafeFPShrink && hasFloatVersion(FuncName))          return optimizeUnaryDoubleFP(CI, Builder, true);        return nullptr; -    case LibFunc::copysign: +    case LibFunc_copysign:        if (hasFloatVersion(FuncName))          return optimizeBinaryDoubleFP(CI, Builder);        return nullptr; -    case LibFunc::fminf: -    case LibFunc::fmin: -    case LibFunc::fminl: -    case LibFunc::fmaxf: -    case LibFunc::fmax: -    case LibFunc::fmaxl: +    case LibFunc_fminf: +    case LibFunc_fmin: +    case LibFunc_fminl: +    case LibFunc_fmaxf: +    case LibFunc_fmax: +    case LibFunc_fmaxl:        return optimizeFMinFMax(CI, Builder);      default:        return nullptr; @@ -2211,16 +2228,10 @@ void LibCallSimplifier::replaceAllUsesWith(Instruction *I, Value *With) {  //   * log(exp10(y)) -> y*log(10)  //   * log(sqrt(x))  -> 0.5*log(x)  // -// lround, lroundf, lroundl: -//   * lround(cnst) -> cnst' -//  // pow, powf, powl:  //   * pow(sqrt(x),y) -> pow(x,y*0.5)  //   * pow(pow(x,y),z)-> pow(x,y*z)  // -// round, roundf, roundl: -//   * round(cnst) -> cnst' -//  // signbit:  //   * signbit(cnst) -> cnst'  //   * signbit(nncst) -> 0 (if pstv is a non-negative constant) @@ -2230,10 +2241,6 @@ void LibCallSimplifier::replaceAllUsesWith(Instruction *I, Value *With) {  //   * sqrt(Nroot(x)) -> pow(x,1/(2*N))  //   * sqrt(pow(x,y)) -> pow(|x|,y*0.5)  // -// trunc, truncf, truncl: -//   * trunc(cnst) -> cnst' -// -//  //===----------------------------------------------------------------------===//  // Fortified Library Call Optimizations @@ -2300,7 +2307,7 @@ Value *FortifiedLibCallSimplifier::optimizeMemSetChk(CallInst *CI,  Value *FortifiedLibCallSimplifier::optimizeStrpCpyChk(CallInst *CI,                                                        IRBuilder<> &B, -                                                      LibFunc::Func Func) { +                                                      LibFunc Func) {    Function *Callee = CI->getCalledFunction();    StringRef Name = Callee->getName();    const DataLayout &DL = CI->getModule()->getDataLayout(); @@ -2308,7 +2315,7 @@ Value *FortifiedLibCallSimplifier::optimizeStrpCpyChk(CallInst *CI,          *ObjSize = 
CI->getArgOperand(2);    // __stpcpy_chk(x,x,...)  -> x+strlen(x) -  if (Func == LibFunc::stpcpy_chk && !OnlyLowerUnknownSize && Dst == Src) { +  if (Func == LibFunc_stpcpy_chk && !OnlyLowerUnknownSize && Dst == Src) {      Value *StrLen = emitStrLen(Src, B, DL, TLI);      return StrLen ? B.CreateInBoundsGEP(B.getInt8Ty(), Dst, StrLen) : nullptr;    } @@ -2334,14 +2341,14 @@ Value *FortifiedLibCallSimplifier::optimizeStrpCpyChk(CallInst *CI,    Value *Ret = emitMemCpyChk(Dst, Src, LenV, ObjSize, B, DL, TLI);    // If the function was an __stpcpy_chk, and we were able to fold it into    // a __memcpy_chk, we still need to return the correct end pointer. -  if (Ret && Func == LibFunc::stpcpy_chk) +  if (Ret && Func == LibFunc_stpcpy_chk)      return B.CreateGEP(B.getInt8Ty(), Dst, ConstantInt::get(SizeTTy, Len - 1));    return Ret;  }  Value *FortifiedLibCallSimplifier::optimizeStrpNCpyChk(CallInst *CI,                                                         IRBuilder<> &B, -                                                       LibFunc::Func Func) { +                                                       LibFunc Func) {    Function *Callee = CI->getCalledFunction();    StringRef Name = Callee->getName();    if (isFortifiedCallFoldable(CI, 3, 2, false)) { @@ -2366,7 +2373,7 @@ Value *FortifiedLibCallSimplifier::optimizeCall(CallInst *CI) {    //    // PR23093. -  LibFunc::Func Func; +  LibFunc Func;    Function *Callee = CI->getCalledFunction();    SmallVector<OperandBundleDef, 2> OpBundles; @@ -2384,17 +2391,17 @@ Value *FortifiedLibCallSimplifier::optimizeCall(CallInst *CI) {      return nullptr;    switch (Func) { -  case LibFunc::memcpy_chk: +  case LibFunc_memcpy_chk:      return optimizeMemCpyChk(CI, Builder); -  case LibFunc::memmove_chk: +  case LibFunc_memmove_chk:      return optimizeMemMoveChk(CI, Builder); -  case LibFunc::memset_chk: +  case LibFunc_memset_chk:      return optimizeMemSetChk(CI, Builder); -  case LibFunc::stpcpy_chk: -  case LibFunc::strcpy_chk: +  case LibFunc_stpcpy_chk: +  case LibFunc_strcpy_chk:      return optimizeStrpCpyChk(CI, Builder, Func); -  case LibFunc::stpncpy_chk: -  case LibFunc::strncpy_chk: +  case LibFunc_stpncpy_chk: +  case LibFunc_strncpy_chk:      return optimizeStrpNCpyChk(CI, Builder, Func);    default:      break; diff --git a/lib/Transforms/Utils/Utils.cpp b/lib/Transforms/Utils/Utils.cpp index 7b9de2eadc61..7106483c3bd2 100644 --- a/lib/Transforms/Utils/Utils.cpp +++ b/lib/Transforms/Utils/Utils.cpp @@ -35,9 +35,8 @@ void llvm::initializeTransformUtils(PassRegistry &Registry) {    initializeUnifyFunctionExitNodesPass(Registry);    initializeInstSimplifierPass(Registry);    initializeMetaRenamerPass(Registry); -  initializeMemorySSAWrapperPassPass(Registry); -  initializeMemorySSAPrinterLegacyPassPass(Registry);    initializeStripGCRelocatesPass(Registry); +  initializePredicateInfoPrinterLegacyPassPass(Registry);  }  /// LLVMInitializeTransformUtils - C binding for initializeTransformUtilsPasses. 
diff --git a/lib/Transforms/Utils/VNCoercion.cpp b/lib/Transforms/Utils/VNCoercion.cpp new file mode 100644 index 000000000000..4aeea02b1b1b --- /dev/null +++ b/lib/Transforms/Utils/VNCoercion.cpp @@ -0,0 +1,482 @@ +#include "llvm/Transforms/Utils/VNCoercion.h" +#include "llvm/Analysis/AliasAnalysis.h" +#include "llvm/Analysis/ConstantFolding.h" +#include "llvm/Analysis/MemoryDependenceAnalysis.h" +#include "llvm/Analysis/ValueTracking.h" +#include "llvm/IR/IRBuilder.h" +#include "llvm/IR/IntrinsicInst.h" +#include "llvm/Support/Debug.h" + +#define DEBUG_TYPE "vncoerce" +namespace llvm { +namespace VNCoercion { + +/// Return true if coerceAvailableValueToLoadType will succeed. +bool canCoerceMustAliasedValueToLoad(Value *StoredVal, Type *LoadTy, +                                     const DataLayout &DL) { +  // If the loaded or stored value is an first class array or struct, don't try +  // to transform them.  We need to be able to bitcast to integer. +  if (LoadTy->isStructTy() || LoadTy->isArrayTy() || +      StoredVal->getType()->isStructTy() || StoredVal->getType()->isArrayTy()) +    return false; + +  // The store has to be at least as big as the load. +  if (DL.getTypeSizeInBits(StoredVal->getType()) < DL.getTypeSizeInBits(LoadTy)) +    return false; + +  return true; +} + +template <class T, class HelperClass> +static T *coerceAvailableValueToLoadTypeHelper(T *StoredVal, Type *LoadedTy, +                                               HelperClass &Helper, +                                               const DataLayout &DL) { +  assert(canCoerceMustAliasedValueToLoad(StoredVal, LoadedTy, DL) && +         "precondition violation - materialization can't fail"); +  if (auto *C = dyn_cast<Constant>(StoredVal)) +    if (auto *FoldedStoredVal = ConstantFoldConstant(C, DL)) +      StoredVal = FoldedStoredVal; + +  // If this is already the right type, just return it. +  Type *StoredValTy = StoredVal->getType(); + +  uint64_t StoredValSize = DL.getTypeSizeInBits(StoredValTy); +  uint64_t LoadedValSize = DL.getTypeSizeInBits(LoadedTy); + +  // If the store and reload are the same size, we can always reuse it. +  if (StoredValSize == LoadedValSize) { +    // Pointer to Pointer -> use bitcast. +    if (StoredValTy->getScalarType()->isPointerTy() && +        LoadedTy->getScalarType()->isPointerTy()) { +      StoredVal = Helper.CreateBitCast(StoredVal, LoadedTy); +    } else { +      // Convert source pointers to integers, which can be bitcast. +      if (StoredValTy->getScalarType()->isPointerTy()) { +        StoredValTy = DL.getIntPtrType(StoredValTy); +        StoredVal = Helper.CreatePtrToInt(StoredVal, StoredValTy); +      } + +      Type *TypeToCastTo = LoadedTy; +      if (TypeToCastTo->getScalarType()->isPointerTy()) +        TypeToCastTo = DL.getIntPtrType(TypeToCastTo); + +      if (StoredValTy != TypeToCastTo) +        StoredVal = Helper.CreateBitCast(StoredVal, TypeToCastTo); + +      // Cast to pointer if the load needs a pointer type. +      if (LoadedTy->getScalarType()->isPointerTy()) +        StoredVal = Helper.CreateIntToPtr(StoredVal, LoadedTy); +    } + +    if (auto *C = dyn_cast<ConstantExpr>(StoredVal)) +      if (auto *FoldedStoredVal = ConstantFoldConstant(C, DL)) +        StoredVal = FoldedStoredVal; + +    return StoredVal; +  } +  // If the loaded value is smaller than the available value, then we can +  // extract out a piece from it.  If the available value is too small, then we +  // can't do anything. 
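// Reference model in plain C++ (a sketch, not the IR the coercion helper
// emits) of the narrowing extraction performed next, here i64 -> i32 at the
// store's base address: big-endian targets keep the bytes such a load would
// read in the high bits, so the value is shifted down before truncating.
#include <cstdint>

static uint32_t extractLowLoadFromWideStore(uint64_t StoredVal,
                                            bool IsBigEndian) {
  if (IsBigEndian)
    StoredVal >>= 32; // move the first-stored bytes into the low bits
  return static_cast<uint32_t>(StoredVal); // truncate to the load's width
}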
+  assert(StoredValSize >= LoadedValSize && +         "canCoerceMustAliasedValueToLoad fail"); + +  // Convert source pointers to integers, which can be manipulated. +  if (StoredValTy->getScalarType()->isPointerTy()) { +    StoredValTy = DL.getIntPtrType(StoredValTy); +    StoredVal = Helper.CreatePtrToInt(StoredVal, StoredValTy); +  } + +  // Convert vectors and fp to integer, which can be manipulated. +  if (!StoredValTy->isIntegerTy()) { +    StoredValTy = IntegerType::get(StoredValTy->getContext(), StoredValSize); +    StoredVal = Helper.CreateBitCast(StoredVal, StoredValTy); +  } + +  // If this is a big-endian system, we need to shift the value down to the low +  // bits so that a truncate will work. +  if (DL.isBigEndian()) { +    uint64_t ShiftAmt = DL.getTypeStoreSizeInBits(StoredValTy) - +                        DL.getTypeStoreSizeInBits(LoadedTy); +    StoredVal = Helper.CreateLShr( +        StoredVal, ConstantInt::get(StoredVal->getType(), ShiftAmt)); +  } + +  // Truncate the integer to the right size now. +  Type *NewIntTy = IntegerType::get(StoredValTy->getContext(), LoadedValSize); +  StoredVal = Helper.CreateTruncOrBitCast(StoredVal, NewIntTy); + +  if (LoadedTy != NewIntTy) { +    // If the result is a pointer, inttoptr. +    if (LoadedTy->getScalarType()->isPointerTy()) +      StoredVal = Helper.CreateIntToPtr(StoredVal, LoadedTy); +    else +      // Otherwise, bitcast. +      StoredVal = Helper.CreateBitCast(StoredVal, LoadedTy); +  } + +  if (auto *C = dyn_cast<Constant>(StoredVal)) +    if (auto *FoldedStoredVal = ConstantFoldConstant(C, DL)) +      StoredVal = FoldedStoredVal; + +  return StoredVal; +} + +/// If we saw a store of a value to memory, and +/// then a load from a must-aliased pointer of a different type, try to coerce +/// the stored value.  LoadedTy is the type of the load we want to replace. +/// IRB is IRBuilder used to insert new instructions. +/// +/// If we can't do it, return null. +Value *coerceAvailableValueToLoadType(Value *StoredVal, Type *LoadedTy, +                                      IRBuilder<> &IRB, const DataLayout &DL) { +  return coerceAvailableValueToLoadTypeHelper(StoredVal, LoadedTy, IRB, DL); +} + +/// This function is called when we have a memdep query of a load that ends up +/// being a clobbering memory write (store, memset, memcpy, memmove).  This +/// means that the write *may* provide bits used by the load but we can't be +/// sure because the pointers don't must-alias. +/// +/// Check this case to see if there is anything more we can do before we give +/// up.  This returns -1 if we have to give up, or a byte number in the stored +/// value of the piece that feeds the load. +static int analyzeLoadFromClobberingWrite(Type *LoadTy, Value *LoadPtr, +                                          Value *WritePtr, +                                          uint64_t WriteSizeInBits, +                                          const DataLayout &DL) { +  // If the loaded or stored value is a first class array or struct, don't try +  // to transform them.  We need to be able to bitcast to integer. +  if (LoadTy->isStructTy() || LoadTy->isArrayTy()) +    return -1; + +  int64_t StoreOffset = 0, LoadOffset = 0; +  Value *StoreBase = +      GetPointerBaseWithConstantOffset(WritePtr, StoreOffset, DL); +  Value *LoadBase = GetPointerBaseWithConstantOffset(LoadPtr, LoadOffset, DL); +  if (StoreBase != LoadBase) +    return -1; + +  // If the load and store are to the exact same address, they should have been +  // a must alias.  
AA must have gotten confused. +  // FIXME: Study to see if/when this happens.  One case is forwarding a memset +  // to a load from the base of the memset. + +  // If the load and store don't overlap at all, the store doesn't provide +  // anything to the load.  In this case, they really don't alias at all, AA +  // must have gotten confused. +  uint64_t LoadSize = DL.getTypeSizeInBits(LoadTy); + +  if ((WriteSizeInBits & 7) | (LoadSize & 7)) +    return -1; +  uint64_t StoreSize = WriteSizeInBits / 8; // Convert to bytes. +  LoadSize /= 8; + +  bool isAAFailure = false; +  if (StoreOffset < LoadOffset) +    isAAFailure = StoreOffset + int64_t(StoreSize) <= LoadOffset; +  else +    isAAFailure = LoadOffset + int64_t(LoadSize) <= StoreOffset; + +  if (isAAFailure) +    return -1; + +  // If the Load isn't completely contained within the stored bits, we don't +  // have all the bits to feed it.  We could do something crazy in the future +  // (issue a smaller load then merge the bits in) but this seems unlikely to be +  // valuable. +  if (StoreOffset > LoadOffset || +      StoreOffset + StoreSize < LoadOffset + LoadSize) +    return -1; + +  // Okay, we can do this transformation.  Return the number of bytes into the +  // store that the load is. +  return LoadOffset - StoreOffset; +} + +/// This function is called when we have a +/// memdep query of a load that ends up being a clobbering store. +int analyzeLoadFromClobberingStore(Type *LoadTy, Value *LoadPtr, +                                   StoreInst *DepSI, const DataLayout &DL) { +  // Cannot handle reading from store of first-class aggregate yet. +  if (DepSI->getValueOperand()->getType()->isStructTy() || +      DepSI->getValueOperand()->getType()->isArrayTy()) +    return -1; + +  Value *StorePtr = DepSI->getPointerOperand(); +  uint64_t StoreSize = +      DL.getTypeSizeInBits(DepSI->getValueOperand()->getType()); +  return analyzeLoadFromClobberingWrite(LoadTy, LoadPtr, StorePtr, StoreSize, +                                        DL); +} + +/// This function is called when we have a +/// memdep query of a load that ends up being clobbered by another load.  See if +/// the other load can feed into the second load. +int analyzeLoadFromClobberingLoad(Type *LoadTy, Value *LoadPtr, LoadInst *DepLI, +                                  const DataLayout &DL) { +  // Cannot handle reading from store of first-class aggregate yet. +  if (DepLI->getType()->isStructTy() || DepLI->getType()->isArrayTy()) +    return -1; + +  Value *DepPtr = DepLI->getPointerOperand(); +  uint64_t DepSize = DL.getTypeSizeInBits(DepLI->getType()); +  int R = analyzeLoadFromClobberingWrite(LoadTy, LoadPtr, DepPtr, DepSize, DL); +  if (R != -1) +    return R; + +  // If we have a load/load clobber an DepLI can be widened to cover this load, +  // then we should widen it! 
+  int64_t LoadOffs = 0; +  const Value *LoadBase = +      GetPointerBaseWithConstantOffset(LoadPtr, LoadOffs, DL); +  unsigned LoadSize = DL.getTypeStoreSize(LoadTy); + +  unsigned Size = MemoryDependenceResults::getLoadLoadClobberFullWidthSize( +      LoadBase, LoadOffs, LoadSize, DepLI); +  if (Size == 0) +    return -1; + +  // Check non-obvious conditions enforced by MDA which we rely on for being +  // able to materialize this potentially available value +  assert(DepLI->isSimple() && "Cannot widen volatile/atomic load!"); +  assert(DepLI->getType()->isIntegerTy() && "Can't widen non-integer load"); + +  return analyzeLoadFromClobberingWrite(LoadTy, LoadPtr, DepPtr, Size * 8, DL); +} + +int analyzeLoadFromClobberingMemInst(Type *LoadTy, Value *LoadPtr, +                                     MemIntrinsic *MI, const DataLayout &DL) { +  // If the mem operation is a non-constant size, we can't handle it. +  ConstantInt *SizeCst = dyn_cast<ConstantInt>(MI->getLength()); +  if (!SizeCst) +    return -1; +  uint64_t MemSizeInBits = SizeCst->getZExtValue() * 8; + +  // If this is memset, we just need to see if the offset is valid in the size +  // of the memset.. +  if (MI->getIntrinsicID() == Intrinsic::memset) +    return analyzeLoadFromClobberingWrite(LoadTy, LoadPtr, MI->getDest(), +                                          MemSizeInBits, DL); + +  // If we have a memcpy/memmove, the only case we can handle is if this is a +  // copy from constant memory.  In that case, we can read directly from the +  // constant memory. +  MemTransferInst *MTI = cast<MemTransferInst>(MI); + +  Constant *Src = dyn_cast<Constant>(MTI->getSource()); +  if (!Src) +    return -1; + +  GlobalVariable *GV = dyn_cast<GlobalVariable>(GetUnderlyingObject(Src, DL)); +  if (!GV || !GV->isConstant()) +    return -1; + +  // See if the access is within the bounds of the transfer. +  int Offset = analyzeLoadFromClobberingWrite(LoadTy, LoadPtr, MI->getDest(), +                                              MemSizeInBits, DL); +  if (Offset == -1) +    return Offset; + +  unsigned AS = Src->getType()->getPointerAddressSpace(); +  // Otherwise, see if we can constant fold a load from the constant with the +  // offset applied as appropriate. +  Src = +      ConstantExpr::getBitCast(Src, Type::getInt8PtrTy(Src->getContext(), AS)); +  Constant *OffsetCst = +      ConstantInt::get(Type::getInt64Ty(Src->getContext()), (unsigned)Offset); +  Src = ConstantExpr::getGetElementPtr(Type::getInt8Ty(Src->getContext()), Src, +                                       OffsetCst); +  Src = ConstantExpr::getBitCast(Src, PointerType::get(LoadTy, AS)); +  if (ConstantFoldLoadFromConstPtr(Src, LoadTy, DL)) +    return Offset; +  return -1; +} + +template <class T, class HelperClass> +static T *getStoreValueForLoadHelper(T *SrcVal, unsigned Offset, Type *LoadTy, +                                     HelperClass &Helper, +                                     const DataLayout &DL) { +  LLVMContext &Ctx = SrcVal->getType()->getContext(); + +  uint64_t StoreSize = (DL.getTypeSizeInBits(SrcVal->getType()) + 7) / 8; +  uint64_t LoadSize = (DL.getTypeSizeInBits(LoadTy) + 7) / 8; +  // Compute which bits of the stored value are being used by the load.  Convert +  // to an integer type to start with. 
+  if (SrcVal->getType()->getScalarType()->isPointerTy()) +    SrcVal = Helper.CreatePtrToInt(SrcVal, DL.getIntPtrType(SrcVal->getType())); +  if (!SrcVal->getType()->isIntegerTy()) +    SrcVal = Helper.CreateBitCast(SrcVal, IntegerType::get(Ctx, StoreSize * 8)); + +  // Shift the bits to the least significant depending on endianness. +  unsigned ShiftAmt; +  if (DL.isLittleEndian()) +    ShiftAmt = Offset * 8; +  else +    ShiftAmt = (StoreSize - LoadSize - Offset) * 8; +  if (ShiftAmt) +    SrcVal = Helper.CreateLShr(SrcVal, +                               ConstantInt::get(SrcVal->getType(), ShiftAmt)); + +  if (LoadSize != StoreSize) +    SrcVal = Helper.CreateTruncOrBitCast(SrcVal, +                                         IntegerType::get(Ctx, LoadSize * 8)); +  return SrcVal; +} + +/// This function is called when we have a memdep query of a load that ends up +/// being a clobbering store.  This means that the store provides bits used by +/// the load but the pointers don't must-alias.  Check this case to see if +/// there is anything more we can do before we give up. +Value *getStoreValueForLoad(Value *SrcVal, unsigned Offset, Type *LoadTy, +                            Instruction *InsertPt, const DataLayout &DL) { + +  IRBuilder<> Builder(InsertPt); +  SrcVal = getStoreValueForLoadHelper(SrcVal, Offset, LoadTy, Builder, DL); +  return coerceAvailableValueToLoadTypeHelper(SrcVal, LoadTy, Builder, DL); +} + +Constant *getConstantStoreValueForLoad(Constant *SrcVal, unsigned Offset, +                                       Type *LoadTy, const DataLayout &DL) { +  ConstantFolder F; +  SrcVal = getStoreValueForLoadHelper(SrcVal, Offset, LoadTy, F, DL); +  return coerceAvailableValueToLoadTypeHelper(SrcVal, LoadTy, F, DL); +} + +/// This function is called when we have a memdep query of a load that ends up +/// being a clobbering load.  This means that the load *may* provide bits used +/// by the load but we can't be sure because the pointers don't must-alias. +/// Check this case to see if there is anything more we can do before we give +/// up. +Value *getLoadValueForLoad(LoadInst *SrcVal, unsigned Offset, Type *LoadTy, +                           Instruction *InsertPt, const DataLayout &DL) { +  // If Offset+LoadTy exceeds the size of SrcVal, then we must be wanting to +  // widen SrcVal out to a larger load. +  unsigned SrcValStoreSize = DL.getTypeStoreSize(SrcVal->getType()); +  unsigned LoadSize = DL.getTypeStoreSize(LoadTy); +  if (Offset + LoadSize > SrcValStoreSize) { +    assert(SrcVal->isSimple() && "Cannot widen volatile/atomic load!"); +    assert(SrcVal->getType()->isIntegerTy() && "Can't widen non-integer load"); +    // If we have a load/load clobber an DepLI can be widened to cover this +    // load, then we should widen it to the next power of 2 size big enough! +    unsigned NewLoadSize = Offset + LoadSize; +    if (!isPowerOf2_32(NewLoadSize)) +      NewLoadSize = NextPowerOf2(NewLoadSize); + +    Value *PtrVal = SrcVal->getPointerOperand(); +    // Insert the new load after the old load.  This ensures that subsequent +    // memdep queries will find the new load.  We can't easily remove the old +    // load completely because it is already in the value numbering table. 
+    IRBuilder<> Builder(SrcVal->getParent(), ++BasicBlock::iterator(SrcVal)); +    Type *DestPTy = IntegerType::get(LoadTy->getContext(), NewLoadSize * 8); +    DestPTy = +        PointerType::get(DestPTy, PtrVal->getType()->getPointerAddressSpace()); +    Builder.SetCurrentDebugLocation(SrcVal->getDebugLoc()); +    PtrVal = Builder.CreateBitCast(PtrVal, DestPTy); +    LoadInst *NewLoad = Builder.CreateLoad(PtrVal); +    NewLoad->takeName(SrcVal); +    NewLoad->setAlignment(SrcVal->getAlignment()); + +    DEBUG(dbgs() << "GVN WIDENED LOAD: " << *SrcVal << "\n"); +    DEBUG(dbgs() << "TO: " << *NewLoad << "\n"); + +    // Replace uses of the original load with the wider load.  On a big endian +    // system, we need to shift down to get the relevant bits. +    Value *RV = NewLoad; +    if (DL.isBigEndian()) +      RV = Builder.CreateLShr(RV, (NewLoadSize - SrcValStoreSize) * 8); +    RV = Builder.CreateTrunc(RV, SrcVal->getType()); +    SrcVal->replaceAllUsesWith(RV); + +    SrcVal = NewLoad; +  } + +  return getStoreValueForLoad(SrcVal, Offset, LoadTy, InsertPt, DL); +} + +Constant *getConstantLoadValueForLoad(Constant *SrcVal, unsigned Offset, +                                      Type *LoadTy, const DataLayout &DL) { +  unsigned SrcValStoreSize = DL.getTypeStoreSize(SrcVal->getType()); +  unsigned LoadSize = DL.getTypeStoreSize(LoadTy); +  if (Offset + LoadSize > SrcValStoreSize) +    return nullptr; +  return getConstantStoreValueForLoad(SrcVal, Offset, LoadTy, DL); +} + +template <class T, class HelperClass> +T *getMemInstValueForLoadHelper(MemIntrinsic *SrcInst, unsigned Offset, +                                Type *LoadTy, HelperClass &Helper, +                                const DataLayout &DL) { +  LLVMContext &Ctx = LoadTy->getContext(); +  uint64_t LoadSize = DL.getTypeSizeInBits(LoadTy) / 8; + +  // We know that this method is only called when the mem transfer fully +  // provides the bits for the load. +  if (MemSetInst *MSI = dyn_cast<MemSetInst>(SrcInst)) { +    // memset(P, 'x', 1234) -> splat('x'), even if x is a variable, and +    // independently of what the offset is. +    T *Val = cast<T>(MSI->getValue()); +    if (LoadSize != 1) +      Val = +          Helper.CreateZExtOrBitCast(Val, IntegerType::get(Ctx, LoadSize * 8)); +    T *OneElt = Val; + +    // Splat the value out to the right number of bits. +    for (unsigned NumBytesSet = 1; NumBytesSet != LoadSize;) { +      // If we can double the number of bytes set, do it. +      if (NumBytesSet * 2 <= LoadSize) { +        T *ShVal = Helper.CreateShl( +            Val, ConstantInt::get(Val->getType(), NumBytesSet * 8)); +        Val = Helper.CreateOr(Val, ShVal); +        NumBytesSet <<= 1; +        continue; +      } + +      // Otherwise insert one byte at a time. +      T *ShVal = Helper.CreateShl(Val, ConstantInt::get(Val->getType(), 1 * 8)); +      Val = Helper.CreateOr(OneElt, ShVal); +      ++NumBytesSet; +    } + +    return coerceAvailableValueToLoadTypeHelper(Val, LoadTy, Helper, DL); +  } + +  // Otherwise, this is a memcpy/memmove from a constant global. +  MemTransferInst *MTI = cast<MemTransferInst>(SrcInst); +  Constant *Src = cast<Constant>(MTI->getSource()); +  unsigned AS = Src->getType()->getPointerAddressSpace(); + +  // Otherwise, see if we can constant fold a load from the constant with the +  // offset applied as appropriate. 
+  Src = +      ConstantExpr::getBitCast(Src, Type::getInt8PtrTy(Src->getContext(), AS)); +  Constant *OffsetCst = +      ConstantInt::get(Type::getInt64Ty(Src->getContext()), (unsigned)Offset); +  Src = ConstantExpr::getGetElementPtr(Type::getInt8Ty(Src->getContext()), Src, +                                       OffsetCst); +  Src = ConstantExpr::getBitCast(Src, PointerType::get(LoadTy, AS)); +  return ConstantFoldLoadFromConstPtr(Src, LoadTy, DL); +} + +/// This function is called when we have a +/// memdep query of a load that ends up being a clobbering mem intrinsic. +Value *getMemInstValueForLoad(MemIntrinsic *SrcInst, unsigned Offset, +                              Type *LoadTy, Instruction *InsertPt, +                              const DataLayout &DL) { +  IRBuilder<> Builder(InsertPt); +  return getMemInstValueForLoadHelper<Value, IRBuilder<>>(SrcInst, Offset, +                                                          LoadTy, Builder, DL); +} + +Constant *getConstantMemInstValueForLoad(MemIntrinsic *SrcInst, unsigned Offset, +                                         Type *LoadTy, const DataLayout &DL) { +  // The only case analyzeLoadFromClobberingMemInst cannot be converted to a +  // constant is when it's a memset of a non-constant. +  if (auto *MSI = dyn_cast<MemSetInst>(SrcInst)) +    if (!isa<Constant>(MSI->getValue())) +      return nullptr; +  ConstantFolder F; +  return getMemInstValueForLoadHelper<Constant, ConstantFolder>(SrcInst, Offset, +                                                                LoadTy, F, DL); +} +} // namespace VNCoercion +} // namespace llvm diff --git a/lib/Transforms/Utils/ValueMapper.cpp b/lib/Transforms/Utils/ValueMapper.cpp index 0e9baaf8649d..f77c10b6dd47 100644 --- a/lib/Transforms/Utils/ValueMapper.cpp +++ b/lib/Transforms/Utils/ValueMapper.cpp @@ -681,6 +681,7 @@ void MDNodeMapper::mapNodesInPOT(UniquedGraph &G) {      remapOperands(*ClonedN, [this, &D, &G](Metadata *Old) {        if (Optional<Metadata *> MappedOp = getMappedOp(Old))          return *MappedOp; +      (void)D;        assert(G.Info[Old].ID > D.ID && "Expected a forward reference");        return &G.getFwdReference(*cast<MDNode>(Old));      }); diff --git a/lib/Transforms/Vectorize/BBVectorize.cpp b/lib/Transforms/Vectorize/BBVectorize.cpp index c01740b27d59..c83b3f7b225b 100644 --- a/lib/Transforms/Vectorize/BBVectorize.cpp +++ b/lib/Transforms/Vectorize/BBVectorize.cpp @@ -494,13 +494,13 @@ namespace {        if (StoreInst *SI = dyn_cast<StoreInst>(I)) {          // For stores, it is the value type, not the pointer type that matters          // because the value is what will come from a vector register. -   +          Value *IVal = SI->getValueOperand();          T1 = IVal->getType();        } else {          T1 = I->getType();        } -   +        if (CastInst *CI = dyn_cast<CastInst>(I))          T2 = CI->getSrcTy();        else @@ -547,10 +547,11 @@ namespace {      // Returns the cost of the provided instruction using TTI.      // This does not handle loads and stores.      
unsigned getInstrCost(unsigned Opcode, Type *T1, Type *T2, -                          TargetTransformInfo::OperandValueKind Op1VK =  +                          TargetTransformInfo::OperandValueKind Op1VK =                                TargetTransformInfo::OK_AnyValue,                            TargetTransformInfo::OperandValueKind Op2VK = -                              TargetTransformInfo::OK_AnyValue) { +                              TargetTransformInfo::OK_AnyValue, +                          const Instruction *I = nullptr) {        switch (Opcode) {        default: break;        case Instruction::GetElementPtr: @@ -584,7 +585,7 @@ namespace {        case Instruction::Select:        case Instruction::ICmp:        case Instruction::FCmp: -        return TTI->getCmpSelInstrCost(Opcode, T1, T2); +        return TTI->getCmpSelInstrCost(Opcode, T1, T2, I);        case Instruction::ZExt:        case Instruction::SExt:        case Instruction::FPToUI: @@ -598,7 +599,7 @@ namespace {        case Instruction::FPTrunc:        case Instruction::BitCast:        case Instruction::ShuffleVector: -        return TTI->getCastInstrCost(Opcode, T1, T2); +        return TTI->getCastInstrCost(Opcode, T1, T2, I);        }        return 1; @@ -894,7 +895,7 @@ namespace {        // vectors that has a scalar condition results in a malformed select.        // FIXME: We could probably be smarter about this by rewriting the select        // with different types instead. -      return (SI->getCondition()->getType()->isVectorTy() ==  +      return (SI->getCondition()->getType()->isVectorTy() ==                SI->getTrueValue()->getType()->isVectorTy());      } else if (isa<CmpInst>(I)) {        if (!Config.VectorizeCmp) @@ -1044,14 +1045,14 @@ namespace {          return false;        }      } else if (TTI) { -      unsigned ICost = getInstrCost(I->getOpcode(), IT1, IT2); -      unsigned JCost = getInstrCost(J->getOpcode(), JT1, JT2); -      Type *VT1 = getVecTypeForPair(IT1, JT1), -           *VT2 = getVecTypeForPair(IT2, JT2);        TargetTransformInfo::OperandValueKind Op1VK =            TargetTransformInfo::OK_AnyValue;        TargetTransformInfo::OperandValueKind Op2VK =            TargetTransformInfo::OK_AnyValue; +      unsigned ICost = getInstrCost(I->getOpcode(), IT1, IT2, Op1VK, Op2VK, I); +      unsigned JCost = getInstrCost(J->getOpcode(), JT1, JT2, Op1VK, Op2VK, J); +      Type *VT1 = getVecTypeForPair(IT1, JT1), +           *VT2 = getVecTypeForPair(IT2, JT2);        // On some targets (example X86) the cost of a vector shift may vary        // depending on whether the second operand is a Uniform or @@ -1090,7 +1091,7 @@ namespace {        // but this cost is ignored (because insert and extract element        // instructions are assigned a zero depth factor and are not really        // fused in general). 
-      unsigned VCost = getInstrCost(I->getOpcode(), VT1, VT2, Op1VK, Op2VK); +      unsigned VCost = getInstrCost(I->getOpcode(), VT1, VT2, Op1VK, Op2VK, I);        if (VCost > ICost + JCost)          return false; @@ -1127,39 +1128,51 @@ namespace {          FastMathFlags FMFCI;          if (auto *FPMOCI = dyn_cast<FPMathOperator>(CI))            FMFCI = FPMOCI->getFastMathFlags(); +        SmallVector<Value *, 4> IArgs(CI->arg_operands()); +        unsigned ICost = TTI->getIntrinsicInstrCost(IID, IT1, IArgs, FMFCI); -        SmallVector<Type*, 4> Tys; -        for (unsigned i = 0, ie = CI->getNumArgOperands(); i != ie; ++i) -          Tys.push_back(CI->getArgOperand(i)->getType()); -        unsigned ICost = TTI->getIntrinsicInstrCost(IID, IT1, Tys, FMFCI); - -        Tys.clear();          CallInst *CJ = cast<CallInst>(J);          FastMathFlags FMFCJ;          if (auto *FPMOCJ = dyn_cast<FPMathOperator>(CJ))            FMFCJ = FPMOCJ->getFastMathFlags(); -        for (unsigned i = 0, ie = CJ->getNumArgOperands(); i != ie; ++i) -          Tys.push_back(CJ->getArgOperand(i)->getType()); -        unsigned JCost = TTI->getIntrinsicInstrCost(IID, JT1, Tys, FMFCJ); +        SmallVector<Value *, 4> JArgs(CJ->arg_operands()); +        unsigned JCost = TTI->getIntrinsicInstrCost(IID, JT1, JArgs, FMFCJ); -        Tys.clear();          assert(CI->getNumArgOperands() == CJ->getNumArgOperands() &&                 "Intrinsic argument counts differ"); +        SmallVector<Type*, 4> Tys; +        SmallVector<Value *, 4> VecArgs;          for (unsigned i = 0, ie = CI->getNumArgOperands(); i != ie; ++i) {            if ((IID == Intrinsic::powi || IID == Intrinsic::ctlz || -               IID == Intrinsic::cttz) && i == 1) +               IID == Intrinsic::cttz) && i == 1) {              Tys.push_back(CI->getArgOperand(i)->getType()); -          else +            VecArgs.push_back(CI->getArgOperand(i)); +          } +          else {              Tys.push_back(getVecTypeForPair(CI->getArgOperand(i)->getType(),                                              CJ->getArgOperand(i)->getType())); +            // Add both operands, and then count their scalarization overhead +            // with VF 1. +            VecArgs.push_back(CI->getArgOperand(i)); +            VecArgs.push_back(CJ->getArgOperand(i)); +          }          } +        // Compute the scalarization cost here with the original operands (to +        // check for uniqueness etc), and then call getIntrinsicInstrCost() +        // with the constructed vector types. 
+        Type *RetTy = getVecTypeForPair(IT1, JT1); +        unsigned ScalarizationCost = 0; +        if (!RetTy->isVoidTy()) +          ScalarizationCost += TTI->getScalarizationOverhead(RetTy, true, false); +        ScalarizationCost += TTI->getOperandsScalarizationOverhead(VecArgs, 1); +          FastMathFlags FMFV = FMFCI;          FMFV &= FMFCJ; -        Type *RetTy = getVecTypeForPair(IT1, JT1); -        unsigned VCost = TTI->getIntrinsicInstrCost(IID, RetTy, Tys, FMFV); +        unsigned VCost = TTI->getIntrinsicInstrCost(IID, RetTy, Tys, FMFV, +                                                    ScalarizationCost);          if (VCost > ICost + JCost)            return false; @@ -2502,7 +2515,7 @@ namespace {          if (I2 == I1 || isa<UndefValue>(I2))            I2 = nullptr;        } -   +        if (HEE) {          Value *I3 = HEE->getOperand(0);          if (!I2 && I3 != I1) @@ -2693,14 +2706,14 @@ namespace {          // so extend the smaller vector to be the same length as the larger one.          Instruction *NLOp;          if (numElemL > 1) { -   +            std::vector<Constant *> Mask(numElemH);            unsigned v = 0;            for (; v < numElemL; ++v)              Mask[v] = ConstantInt::get(Type::getInt32Ty(Context), v);            for (; v < numElemH; ++v)              Mask[v] = UndefValue::get(Type::getInt32Ty(Context)); -     +            NLOp = new ShuffleVectorInst(LOp, UndefValue::get(ArgTypeL),                                         ConstantVector::get(Mask),                                         getReplacementName(IBeforeJ ? I : J, @@ -2710,7 +2723,7 @@ namespace {                                             getReplacementName(IBeforeJ ? I : J,                                                                true, o, 1));          } -   +          NLOp->insertBefore(IBeforeJ ? J : I);          LOp = NLOp;        } @@ -2720,7 +2733,7 @@ namespace {        if (numElemH == 1 && expandIEChain(Context, I, J, o, LOp, numElemL,                                           ArgTypeH, VArgType, IBeforeJ)) {          Instruction *S = -          InsertElementInst::Create(LOp, HOp,  +          InsertElementInst::Create(LOp, HOp,                                      ConstantInt::get(Type::getInt32Ty(Context),                                                       numElemL),                                      getReplacementName(IBeforeJ ? I : J, @@ -2737,7 +2750,7 @@ namespace {              Mask[v] = ConstantInt::get(Type::getInt32Ty(Context), v);            for (; v < numElemL; ++v)              Mask[v] = UndefValue::get(Type::getInt32Ty(Context)); -     +            NHOp = new ShuffleVectorInst(HOp, UndefValue::get(ArgTypeH),                                         ConstantVector::get(Mask),                                         getReplacementName(IBeforeJ ? 
I : J, diff --git a/lib/Transforms/Vectorize/LoadStoreVectorizer.cpp b/lib/Transforms/Vectorize/LoadStoreVectorizer.cpp index c44a393cf846..4409d7a404f8 100644 --- a/lib/Transforms/Vectorize/LoadStoreVectorizer.cpp +++ b/lib/Transforms/Vectorize/LoadStoreVectorizer.cpp @@ -432,9 +432,12 @@ Vectorizer::splitOddVectorElts(ArrayRef<Instruction *> Chain,    unsigned ElementSizeBytes = ElementSizeBits / 8;    unsigned SizeBytes = ElementSizeBytes * Chain.size();    unsigned NumLeft = (SizeBytes - (SizeBytes % 4)) / ElementSizeBytes; -  if (NumLeft == Chain.size()) -    --NumLeft; -  else if (NumLeft == 0) +  if (NumLeft == Chain.size()) { +    if ((NumLeft & 1) == 0) +      NumLeft /= 2; // Split even in half +    else +      --NumLeft;    // Split off last element +  } else if (NumLeft == 0)      NumLeft = 1;    return std::make_pair(Chain.slice(0, NumLeft), Chain.slice(NumLeft));  } @@ -588,7 +591,7 @@ Vectorizer::collectInstructions(BasicBlock *BB) {          continue;        // Make sure all the users of a vector are constant-index extracts. -      if (isa<VectorType>(Ty) && !all_of(LI->users(), [LI](const User *U) { +      if (isa<VectorType>(Ty) && !all_of(LI->users(), [](const User *U) {              const ExtractElementInst *EEI = dyn_cast<ExtractElementInst>(U);              return EEI && isa<ConstantInt>(EEI->getOperand(1));            })) @@ -622,7 +625,7 @@ Vectorizer::collectInstructions(BasicBlock *BB) {        if (TySize > VecRegSize / 2)          continue; -      if (isa<VectorType>(Ty) && !all_of(SI->users(), [SI](const User *U) { +      if (isa<VectorType>(Ty) && !all_of(SI->users(), [](const User *U) {              const ExtractElementInst *EEI = dyn_cast<ExtractElementInst>(U);              return EEI && isa<ConstantInt>(EEI->getOperand(1));            })) diff --git a/lib/Transforms/Vectorize/LoopVectorize.cpp b/lib/Transforms/Vectorize/LoopVectorize.cpp index dac7032fa08f..595b2ec88943 100644 --- a/lib/Transforms/Vectorize/LoopVectorize.cpp +++ b/lib/Transforms/Vectorize/LoopVectorize.cpp @@ -50,6 +50,7 @@  #include "llvm/ADT/DenseMap.h"  #include "llvm/ADT/Hashing.h"  #include "llvm/ADT/MapVector.h" +#include "llvm/ADT/Optional.h"  #include "llvm/ADT/SCCIterator.h"  #include "llvm/ADT/SetVector.h"  #include "llvm/ADT/SmallPtrSet.h" @@ -92,6 +93,7 @@  #include "llvm/Transforms/Scalar.h"  #include "llvm/Transforms/Utils/BasicBlockUtils.h"  #include "llvm/Transforms/Utils/Local.h" +#include "llvm/Transforms/Utils/LoopSimplify.h"  #include "llvm/Transforms/Utils/LoopUtils.h"  #include "llvm/Transforms/Utils/LoopVersioning.h"  #include "llvm/Transforms/Vectorize.h" @@ -266,21 +268,6 @@ static bool hasCyclesInLoopBody(const Loop &L) {    return false;  } -/// \brief This modifies LoopAccessReport to initialize message with -/// loop-vectorizer-specific part. -class VectorizationReport : public LoopAccessReport { -public: -  VectorizationReport(Instruction *I = nullptr) -      : LoopAccessReport("loop not vectorized: ", I) {} - -  /// \brief This allows promotion of the loop-access analysis report into the -  /// loop-vectorizer report.  It modifies the message to add the -  /// loop-vectorizer-specific part of the message. -  explicit VectorizationReport(const LoopAccessReport &R) -      : LoopAccessReport(Twine("loop not vectorized: ") + R.str(), -                         R.getInstr()) {} -}; -  /// A helper function for converting Scalar types to vector types.  /// If the incoming type is void, we return void. If the VF is 1, we return  /// the scalar type. 
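Elsewhere in this diff, Vectorizer::splitOddVectorElts in LoadStoreVectorizer.cpp now halves an even-length chain rather than always peeling off the last element. A minimal standalone sketch of that splitting arithmetic under the same byte-size assumptions (numLeftAfterSplit is an illustrative free function, not the patch's method):

#include <cassert>
#include <cstddef>

// Illustrative stand-in for the patched Vectorizer::splitOddVectorElts logic;
// the name and free-function form are assumptions made for this sketch.
std::size_t numLeftAfterSplit(std::size_t chainSize, std::size_t elemSizeBytes) {
  std::size_t sizeBytes = elemSizeBytes * chainSize;
  std::size_t numLeft = (sizeBytes - (sizeBytes % 4)) / elemSizeBytes;
  if (numLeft == chainSize) {
    if ((numLeft & 1) == 0)
      numLeft /= 2; // even-length chain: split it in half
    else
      --numLeft;    // odd-length chain: split off the last element
  } else if (numLeft == 0) {
    numLeft = 1;    // always keep at least one element in the first half
  }
  return numLeft;
}

int main() {
  assert(numLeftAfterSplit(4, 4) == 2); // 4 x i32: even, halved
  assert(numLeftAfterSplit(3, 4) == 2); // 3 x i32: peel the last element
  assert(numLeftAfterSplit(1, 2) == 1); // under 4 bytes total: keep one
  return 0;
}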
@@ -290,31 +277,9 @@ static Type *ToVectorTy(Type *Scalar, unsigned VF) {    return VectorType::get(Scalar, VF);  } -/// A helper function that returns GEP instruction and knows to skip a -/// 'bitcast'. The 'bitcast' may be skipped if the source and the destination -/// pointee types of the 'bitcast' have the same size. -/// For example: -///   bitcast double** %var to i64* - can be skipped -///   bitcast double** %var to i8*  - can not -static GetElementPtrInst *getGEPInstruction(Value *Ptr) { - -  if (isa<GetElementPtrInst>(Ptr)) -    return cast<GetElementPtrInst>(Ptr); - -  if (isa<BitCastInst>(Ptr) && -      isa<GetElementPtrInst>(cast<BitCastInst>(Ptr)->getOperand(0))) { -    Type *BitcastTy = Ptr->getType(); -    Type *GEPTy = cast<BitCastInst>(Ptr)->getSrcTy(); -    if (!isa<PointerType>(BitcastTy) || !isa<PointerType>(GEPTy)) -      return nullptr; -    Type *Pointee1Ty = cast<PointerType>(BitcastTy)->getPointerElementType(); -    Type *Pointee2Ty = cast<PointerType>(GEPTy)->getPointerElementType(); -    const DataLayout &DL = cast<BitCastInst>(Ptr)->getModule()->getDataLayout(); -    if (DL.getTypeSizeInBits(Pointee1Ty) == DL.getTypeSizeInBits(Pointee2Ty)) -      return cast<GetElementPtrInst>(cast<BitCastInst>(Ptr)->getOperand(0)); -  } -  return nullptr; -} +// FIXME: The following helper functions have multiple implementations +// in the project. They can be effectively organized in a common Load/Store +// utilities unit.  /// A helper function that returns the pointer operand of a load or store  /// instruction. @@ -326,6 +291,34 @@ static Value *getPointerOperand(Value *I) {    return nullptr;  } +/// A helper function that returns the type of loaded or stored value. +static Type *getMemInstValueType(Value *I) { +  assert((isa<LoadInst>(I) || isa<StoreInst>(I)) && +         "Expected Load or Store instruction"); +  if (auto *LI = dyn_cast<LoadInst>(I)) +    return LI->getType(); +  return cast<StoreInst>(I)->getValueOperand()->getType(); +} + +/// A helper function that returns the alignment of load or store instruction. +static unsigned getMemInstAlignment(Value *I) { +  assert((isa<LoadInst>(I) || isa<StoreInst>(I)) && +         "Expected Load or Store instruction"); +  if (auto *LI = dyn_cast<LoadInst>(I)) +    return LI->getAlignment(); +  return cast<StoreInst>(I)->getAlignment(); +} + +/// A helper function that returns the address space of the pointer operand of +/// load or store instruction. +static unsigned getMemInstAddressSpace(Value *I) { +  assert((isa<LoadInst>(I) || isa<StoreInst>(I)) && +         "Expected Load or Store instruction"); +  if (auto *LI = dyn_cast<LoadInst>(I)) +    return LI->getPointerAddressSpace(); +  return cast<StoreInst>(I)->getPointerAddressSpace(); +} +  /// A helper function that returns true if the given type is irregular. The  /// type is irregular if its allocated size doesn't equal the store size of an  /// element of the corresponding vector type at the given vectorization factor. @@ -351,6 +344,23 @@ static bool hasIrregularType(Type *Ty, const DataLayout &DL, unsigned VF) {  ///       we always assume predicated blocks have a 50% chance of executing.  static unsigned getReciprocalPredBlockProb() { return 2; } +/// A helper function that adds a 'fast' flag to floating-point operations. 
+static Value *addFastMathFlag(Value *V) { +  if (isa<FPMathOperator>(V)) { +    FastMathFlags Flags; +    Flags.setUnsafeAlgebra(); +    cast<Instruction>(V)->setFastMathFlags(Flags); +  } +  return V; +} + +/// A helper function that returns an integer or floating-point constant with +/// value C. +static Constant *getSignedIntOrFpConstant(Type *Ty, int64_t C) { +  return Ty->isIntegerTy() ? ConstantInt::getSigned(Ty, C) +                           : ConstantFP::get(Ty, C); +} +  /// InnerLoopVectorizer vectorizes loops which contain only one basic  /// block to a specified vectorization factor (VF).  /// This class performs the widening of scalars into vectors, or multiple @@ -428,10 +438,17 @@ protected:    /// Copy and widen the instructions from the old loop.    virtual void vectorizeLoop(); +  /// Handle all cross-iteration phis in the header. +  void fixCrossIterationPHIs(); +    /// Fix a first-order recurrence. This is the second phase of vectorizing    /// this phi node.    void fixFirstOrderRecurrence(PHINode *Phi); +  /// Fix a reduction cross-iteration phi. This is the second phase of +  /// vectorizing this phi node. +  void fixReduction(PHINode *Phi); +    /// \brief The Loop exit block may have single value PHI nodes where the    /// incoming value is 'Undef'. While vectorizing we only handled real values    /// that were defined inside the loop. Here we fix the 'undef case'. @@ -448,7 +465,8 @@ protected:    /// Collect the instructions from the original loop that would be trivially    /// dead in the vectorized loop if generated. -  void collectTriviallyDeadInstructions(); +  void collectTriviallyDeadInstructions( +      SmallPtrSetImpl<Instruction *> &DeadInstructions);    /// Shrinks vector element sizes to the smallest bitwidth they can be legally    /// represented as. @@ -462,14 +480,14 @@ protected:    /// and DST.    VectorParts createEdgeMask(BasicBlock *Src, BasicBlock *Dst); -  /// A helper function to vectorize a single BB within the innermost loop. -  void vectorizeBlockInLoop(BasicBlock *BB, PhiVector *PV); +  /// A helper function to vectorize a single instruction within the innermost +  /// loop. +  void vectorizeInstruction(Instruction &I);    /// Vectorize a single PHINode in a block. This method handles the induction    /// variable canonicalization. It supports both VF = 1 for unrolled loops and    /// arbitrary length vectors. -  void widenPHIInstruction(Instruction *PN, unsigned UF, unsigned VF, -                           PhiVector *PV); +  void widenPHIInstruction(Instruction *PN, unsigned UF, unsigned VF);    /// Insert the new loop to the loop hierarchy and pass manager    /// and update the analysis passes. @@ -504,20 +522,21 @@ protected:    /// \p EntryVal is the value from the original loop that maps to the steps.    /// Note that \p EntryVal doesn't have to be an induction variable (e.g., it    /// can be a truncate instruction). -  void buildScalarSteps(Value *ScalarIV, Value *Step, Value *EntryVal); - -  /// Create a vector induction phi node based on an existing scalar one. This -  /// currently only works for integer induction variables with a constant -  /// step. \p EntryVal is the value from the original loop that maps to the -  /// vector phi node. If \p EntryVal is a truncate instruction, instead of -  /// widening the original IV, we widen a version of the IV truncated to \p -  /// EntryVal's type. 
-  void createVectorIntInductionPHI(const InductionDescriptor &II, -                                   Instruction *EntryVal); - -  /// Widen an integer induction variable \p IV. If \p Trunc is provided, the -  /// induction variable will first be truncated to the corresponding type. -  void widenIntInduction(PHINode *IV, TruncInst *Trunc = nullptr); +  void buildScalarSteps(Value *ScalarIV, Value *Step, Value *EntryVal, +                        const InductionDescriptor &ID); + +  /// Create a vector induction phi node based on an existing scalar one. \p +  /// EntryVal is the value from the original loop that maps to the vector phi +  /// node, and \p Step is the loop-invariant step. If \p EntryVal is a +  /// truncate instruction, instead of widening the original IV, we widen a +  /// version of the IV truncated to \p EntryVal's type. +  void createVectorIntOrFpInductionPHI(const InductionDescriptor &II, +                                       Value *Step, Instruction *EntryVal); + +  /// Widen an integer or floating-point induction variable \p IV. If \p Trunc +  /// is provided, the integer induction variable will first be truncated to +  /// the corresponding type. +  void widenIntOrFpInduction(PHINode *IV, TruncInst *Trunc = nullptr);    /// Returns true if an instruction \p I should be scalarized instead of    /// vectorized for the chosen vectorization factor. @@ -583,6 +602,10 @@ protected:    /// vector of instructions.    void addMetadata(ArrayRef<Value *> To, Instruction *From); +  /// \brief Set the debug location in the builder using the debug location in +  /// the instruction. +  void setDebugLocFromInst(IRBuilder<> &B, const Value *Ptr); +    /// This is a helper class for maintaining vectorization state. It's used for    /// mapping values from the original loop to their corresponding values in    /// the new loop. Two mappings are maintained: one for vectorized values and @@ -777,14 +800,6 @@ protected:    // Record whether runtime checks are added.    bool AddedSafetyChecks; -  // Holds instructions from the original loop whose counterparts in the -  // vectorized loop would be trivially dead if generated. For example, -  // original induction update instructions can become dead because we -  // separately emit induction "steps" when generating code for the new loop. -  // Similarly, we create a new latch condition when setting up the structure -  // of the new loop, so the old one can become dead. -  SmallPtrSet<Instruction *, 4> DeadInstructions; -    // Holds the end values for each induction variable. We save the end values    // so we can later fix-up the external users of the induction variables.    DenseMap<PHINode *, Value *> IVEndValues; @@ -803,8 +818,6 @@ public:                              UnrollFactor, LVL, CM) {}  private: -  void scalarizeInstruction(Instruction *Instr, -                            bool IfPredicateInstr = false) override;    void vectorizeMemoryInstruction(Instruction *Instr) override;    Value *getBroadcastInstrs(Value *V) override;    Value *getStepVector(Value *Val, int StartIdx, Value *Step, @@ -832,12 +845,14 @@ static Instruction *getDebugLocFromInstOrOperands(Instruction *I) {    return I;  } -/// \brief Set the debug location in the builder using the debug location in the -/// instruction. 
-static void setDebugLocFromInst(IRBuilder<> &B, const Value *Ptr) { -  if (const Instruction *Inst = dyn_cast_or_null<Instruction>(Ptr)) -    B.SetCurrentDebugLocation(Inst->getDebugLoc()); -  else +void InnerLoopVectorizer::setDebugLocFromInst(IRBuilder<> &B, const Value *Ptr) { +  if (const Instruction *Inst = dyn_cast_or_null<Instruction>(Ptr)) { +    const DILocation *DIL = Inst->getDebugLoc(); +    if (DIL && Inst->getFunction()->isDebugInfoForProfiling()) +      B.SetCurrentDebugLocation(DIL->cloneWithDuplicationFactor(UF * VF)); +    else +      B.SetCurrentDebugLocation(DIL); +  } else      B.SetCurrentDebugLocation(DebugLoc());  } @@ -1497,14 +1512,6 @@ private:    OptimizationRemarkEmitter &ORE;  }; -static void emitAnalysisDiag(const Loop *TheLoop, -                             const LoopVectorizeHints &Hints, -                             OptimizationRemarkEmitter &ORE, -                             const LoopAccessReport &Message) { -  const char *Name = Hints.vectorizeAnalysisPassName(); -  LoopAccessReport::emitAnalysis(Message, TheLoop, Name, ORE); -} -  static void emitMissedWarning(Function *F, Loop *L,                                const LoopVectorizeHints &LH,                                OptimizationRemarkEmitter *ORE) { @@ -1512,13 +1519,17 @@ static void emitMissedWarning(Function *F, Loop *L,    if (LH.getForce() == LoopVectorizeHints::FK_Enabled) {      if (LH.getWidth() != 1) -      emitLoopVectorizeWarning( -          F->getContext(), *F, L->getStartLoc(), -          "failed explicitly specified loop vectorization"); +      ORE->emit(DiagnosticInfoOptimizationFailure( +                    DEBUG_TYPE, "FailedRequestedVectorization", +                    L->getStartLoc(), L->getHeader()) +                << "loop not vectorized: " +                << "failed explicitly specified loop vectorization");      else if (LH.getInterleave() != 1) -      emitLoopInterleaveWarning( -          F->getContext(), *F, L->getStartLoc(), -          "failed explicitly specified loop interleaving"); +      ORE->emit(DiagnosticInfoOptimizationFailure( +                    DEBUG_TYPE, "FailedRequestedInterleaving", L->getStartLoc(), +                    L->getHeader()) +                << "loop not interleaved: " +                << "failed explicitly specified loop interleaving");    }  } @@ -1546,7 +1557,7 @@ public:        LoopVectorizeHints *H)        : NumPredStores(0), TheLoop(L), PSE(PSE), TLI(TLI), TTI(TTI), DT(DT),          GetLAA(GetLAA), LAI(nullptr), ORE(ORE), InterleaveInfo(PSE, L, DT, LI), -        Induction(nullptr), WidestIndTy(nullptr), HasFunNoNaNAttr(false), +        PrimaryInduction(nullptr), WidestIndTy(nullptr), HasFunNoNaNAttr(false),          Requirements(R), Hints(H) {}    /// ReductionList contains the reduction descriptors for all @@ -1566,8 +1577,8 @@ public:    /// loop, only that it is legal to do so.    bool canVectorize(); -  /// Returns the Induction variable. -  PHINode *getInduction() { return Induction; } +  /// Returns the primary induction variable. +  PHINode *getPrimaryInduction() { return PrimaryInduction; }    /// Returns the reduction variables found in the loop.    ReductionList *getReductionVars() { return &Reductions; } @@ -1607,12 +1618,6 @@ public:    /// Returns true if the value V is uniform within the loop.    bool isUniform(Value *V); -  /// Returns true if \p I is known to be uniform after vectorization. 
-  bool isUniformAfterVectorization(Instruction *I) { return Uniforms.count(I); } - -  /// Returns true if \p I is known to be scalar after vectorization. -  bool isScalarAfterVectorization(Instruction *I) { return Scalars.count(I); } -    /// Returns the information that we collected about runtime memory check.    const RuntimePointerChecking *getRuntimePointerChecking() const {      return LAI->getRuntimePointerChecking(); @@ -1689,15 +1694,9 @@ public:    /// instructions that may divide by zero.    bool isScalarWithPredication(Instruction *I); -  /// Returns true if \p I is a memory instruction that has a consecutive or -  /// consecutive-like pointer operand. Consecutive-like pointers are pointers -  /// that are treated like consecutive pointers during vectorization. The -  /// pointer operands of interleaved accesses are an example. -  bool hasConsecutiveLikePtrOperand(Instruction *I); - -  /// Returns true if \p I is a memory instruction that must be scalarized -  /// during vectorization. -  bool memoryInstructionMustBeScalarized(Instruction *I, unsigned VF = 1); +  /// Returns true if \p I is a memory instruction with consecutive memory +  /// access that can be widened. +  bool memoryInstructionCanBeWidened(Instruction *I, unsigned VF = 1);  private:    /// Check if a single basic block loop is vectorizable. @@ -1715,24 +1714,6 @@ private:    /// transformation.    bool canVectorizeWithIfConvert(); -  /// Collect the instructions that are uniform after vectorization. An -  /// instruction is uniform if we represent it with a single scalar value in -  /// the vectorized loop corresponding to each vector iteration. Examples of -  /// uniform instructions include pointer operands of consecutive or -  /// interleaved memory accesses. Note that although uniformity implies an -  /// instruction will be scalar, the reverse is not true. In general, a -  /// scalarized instruction will be represented by VF scalar values in the -  /// vectorized loop, each corresponding to an iteration of the original -  /// scalar loop. -  void collectLoopUniforms(); - -  /// Collect the instructions that are scalar after vectorization. An -  /// instruction is scalar if it is known to be uniform or will be scalarized -  /// during vectorization. Non-uniform scalarized instructions will be -  /// represented by VF values in the vectorized loop, each corresponding to an -  /// iteration of the original scalar loop. -  void collectLoopScalars(); -    /// Return true if all of the instructions in the block can be speculatively    /// executed. \p SafePtrs is a list of addresses that are known to be legal    /// and we know that we can read from them without segfault. @@ -1744,14 +1725,6 @@ private:    void addInductionPhi(PHINode *Phi, const InductionDescriptor &ID,                         SmallPtrSetImpl<Value *> &AllowedExit); -  /// Report an analysis message to assist the user in diagnosing loops that are -  /// not vectorized.  These are handled as LoopAccessReport rather than -  /// VectorizationReport because the << operator of VectorizationReport returns -  /// LoopAccessReport. -  void emitAnalysis(const LoopAccessReport &Message) const { -    emitAnalysisDiag(TheLoop, *Hints, *ORE, Message); -  } -    /// Create an analysis remark that explains why vectorization failed    ///    /// \p RemarkName is the identifier for the remark.  If \p I is passed it is @@ -1804,9 +1777,9 @@ private:    //  ---  vectorization state --- // -  /// Holds the integer induction variable. 
This is the counter of the +  /// Holds the primary induction variable. This is the counter of the    /// loop. -  PHINode *Induction; +  PHINode *PrimaryInduction;    /// Holds the reduction variables.    ReductionList Reductions;    /// Holds all of the induction variables that we found in the loop. @@ -1822,12 +1795,6 @@ private:    /// vars which can be accessed from outside the loop.    SmallPtrSet<Value *, 4> AllowedExit; -  /// Holds the instructions known to be uniform after vectorization. -  SmallPtrSet<Instruction *, 4> Uniforms; - -  /// Holds the instructions known to be scalar after vectorization. -  SmallPtrSet<Instruction *, 4> Scalars; -    /// Can we assume the absence of NaNs.    bool HasFunNoNaNAttr; @@ -1861,16 +1828,26 @@ public:        : TheLoop(L), PSE(PSE), LI(LI), Legal(Legal), TTI(TTI), TLI(TLI), DB(DB),          AC(AC), ORE(ORE), TheFunction(F), Hints(Hints) {} +  /// \return An upper bound for the vectorization factor, or None if +  /// vectorization should be avoided up front. +  Optional<unsigned> computeMaxVF(bool OptForSize); +    /// Information about vectorization costs    struct VectorizationFactor {      unsigned Width; // Vector width with best cost      unsigned Cost;  // Cost of the loop with that width    };    /// \return The most profitable vectorization factor and the cost of that VF. -  /// This method checks every power of two up to VF. If UserVF is not ZERO +  /// This method checks every power of two up to MaxVF. If UserVF is not ZERO    /// then this vectorization factor will be selected if vectorization is    /// possible. -  VectorizationFactor selectVectorizationFactor(bool OptForSize); +  VectorizationFactor selectVectorizationFactor(unsigned MaxVF); + +  /// Setup cost-based decisions for user vectorization factor. +  void selectUserVectorizationFactor(unsigned UserVF) { +    collectUniformsAndScalars(UserVF); +    collectInstsToScalarize(UserVF); +  }    /// \return The size (in bits) of the smallest and widest types in the code    /// that needs to be vectorized. We ignore values that remain scalar such as @@ -1884,6 +1861,15 @@ public:    unsigned selectInterleaveCount(bool OptForSize, unsigned VF,                                   unsigned LoopCost); +  /// Memory access instruction may be vectorized in more than one way. +  /// Form of instruction after vectorization depends on cost. +  /// This function takes cost-based decisions for Load/Store instructions +  /// and collects them in a map. This decisions map is used for building +  /// the lists of loop-uniform and loop-scalar instructions. +  /// The calculated cost is saved with widening decision in order to +  /// avoid redundant calculations. +  void setCostBasedWideningDecision(unsigned VF); +    /// \brief A struct that represents some properties of the register usage    /// of a loop.    struct RegisterUsage { @@ -1918,14 +1904,118 @@ public:      return Scalars->second.count(I);    } +  /// Returns true if \p I is known to be uniform after vectorization. +  bool isUniformAfterVectorization(Instruction *I, unsigned VF) const { +    if (VF == 1) +      return true; +    assert(Uniforms.count(VF) && "VF not yet analyzed for uniformity"); +    auto UniformsPerVF = Uniforms.find(VF); +    return UniformsPerVF->second.count(I); +  } + +  /// Returns true if \p I is known to be scalar after vectorization. 
+  bool isScalarAfterVectorization(Instruction *I, unsigned VF) const { +    if (VF == 1) +      return true; +    assert(Scalars.count(VF) && "Scalar values are not calculated for VF"); +    auto ScalarsPerVF = Scalars.find(VF); +    return ScalarsPerVF->second.count(I); +  } +    /// \returns True if instruction \p I can be truncated to a smaller bitwidth    /// for vectorization factor \p VF.    bool canTruncateToMinimalBitwidth(Instruction *I, unsigned VF) const {      return VF > 1 && MinBWs.count(I) && !isProfitableToScalarize(I, VF) && -           !Legal->isScalarAfterVectorization(I); +           !isScalarAfterVectorization(I, VF); +  } + +  /// Decision that was taken during cost calculation for memory instruction. +  enum InstWidening { +    CM_Unknown, +    CM_Widen, +    CM_Interleave, +    CM_GatherScatter, +    CM_Scalarize +  }; + +  /// Save vectorization decision \p W and \p Cost taken by the cost model for +  /// instruction \p I and vector width \p VF. +  void setWideningDecision(Instruction *I, unsigned VF, InstWidening W, +                           unsigned Cost) { +    assert(VF >= 2 && "Expected VF >=2"); +    WideningDecisions[std::make_pair(I, VF)] = std::make_pair(W, Cost); +  } + +  /// Save vectorization decision \p W and \p Cost taken by the cost model for +  /// interleaving group \p Grp and vector width \p VF. +  void setWideningDecision(const InterleaveGroup *Grp, unsigned VF, +                           InstWidening W, unsigned Cost) { +    assert(VF >= 2 && "Expected VF >=2"); +    /// Broadcast this decicion to all instructions inside the group. +    /// But the cost will be assigned to one instruction only. +    for (unsigned i = 0; i < Grp->getFactor(); ++i) { +      if (auto *I = Grp->getMember(i)) { +        if (Grp->getInsertPos() == I) +          WideningDecisions[std::make_pair(I, VF)] = std::make_pair(W, Cost); +        else +          WideningDecisions[std::make_pair(I, VF)] = std::make_pair(W, 0); +      } +    } +  } + +  /// Return the cost model decision for the given instruction \p I and vector +  /// width \p VF. Return CM_Unknown if this instruction did not pass +  /// through the cost modeling. +  InstWidening getWideningDecision(Instruction *I, unsigned VF) { +    assert(VF >= 2 && "Expected VF >=2"); +    std::pair<Instruction *, unsigned> InstOnVF = std::make_pair(I, VF); +    auto Itr = WideningDecisions.find(InstOnVF); +    if (Itr == WideningDecisions.end()) +      return CM_Unknown; +    return Itr->second.first; +  } + +  /// Return the vectorization cost for the given instruction \p I and vector +  /// width \p VF. +  unsigned getWideningCost(Instruction *I, unsigned VF) { +    assert(VF >= 2 && "Expected VF >=2"); +    std::pair<Instruction *, unsigned> InstOnVF = std::make_pair(I, VF); +    assert(WideningDecisions.count(InstOnVF) && "The cost is not calculated"); +    return WideningDecisions[InstOnVF].second; +  } + +  /// Return True if instruction \p I is an optimizable truncate whose operand +  /// is an induction variable. Such a truncate will be removed by adding a new +  /// induction variable with the destination type. +  bool isOptimizableIVTruncate(Instruction *I, unsigned VF) { + +    // If the instruction is not a truncate, return false. +    auto *Trunc = dyn_cast<TruncInst>(I); +    if (!Trunc) +      return false; + +    // Get the source and destination types of the truncate. 
+    Type *SrcTy = ToVectorTy(cast<CastInst>(I)->getSrcTy(), VF); +    Type *DestTy = ToVectorTy(cast<CastInst>(I)->getDestTy(), VF); + +    // If the truncate is free for the given types, return false. Replacing a +    // free truncate with an induction variable would add an induction variable +    // update instruction to each iteration of the loop. We exclude from this +    // check the primary induction variable since it will need an update +    // instruction regardless. +    Value *Op = Trunc->getOperand(0); +    if (Op != Legal->getPrimaryInduction() && TTI.isTruncateFree(SrcTy, DestTy)) +      return false; + +    // If the truncated value is not an induction variable, return false. +    return Legal->isInductionVariable(Op);    }  private: +  /// \return An upper bound for the vectorization factor, larger than zero. +  /// One is returned if vectorization should best be avoided due to cost. +  unsigned computeFeasibleMaxVF(bool OptForSize); +    /// The vectorization cost is a combination of the cost itself and a boolean    /// indicating whether any of the contributing operations will actually    /// operate on @@ -1949,6 +2039,26 @@ private:    /// the vector type as an output parameter.    unsigned getInstructionCost(Instruction *I, unsigned VF, Type *&VectorTy); +  /// Calculate vectorization cost of memory instruction \p I. +  unsigned getMemoryInstructionCost(Instruction *I, unsigned VF); + +  /// The cost computation for scalarized memory instruction. +  unsigned getMemInstScalarizationCost(Instruction *I, unsigned VF); + +  /// The cost computation for interleaving group of memory instructions. +  unsigned getInterleaveGroupCost(Instruction *I, unsigned VF); + +  /// The cost computation for Gather/Scatter instruction. +  unsigned getGatherScatterCost(Instruction *I, unsigned VF); + +  /// The cost computation for widening instruction \p I with consecutive +  /// memory access. +  unsigned getConsecutiveMemOpCost(Instruction *I, unsigned VF); + +  /// The cost calculation for Load instruction \p I with uniform pointer - +  /// scalar load + broadcast. +  unsigned getUniformMemOpCost(Instruction *I, unsigned VF); +    /// Returns whether the instruction is a load or store and will be a emitted    /// as a vector operation.    bool isConsecutiveLoadOrStore(Instruction *I); @@ -1972,12 +2082,24 @@ private:    /// pairs.    typedef DenseMap<Instruction *, unsigned> ScalarCostsTy; +  /// A set containing all BasicBlocks that are known to present after +  /// vectorization as a predicated block. +  SmallPtrSet<BasicBlock *, 4> PredicatedBBsAfterVectorization; +    /// A map holding scalar costs for different vectorization factors. The    /// presence of a cost for an instruction in the mapping indicates that the    /// instruction will be scalarized when vectorizing with the associated    /// vectorization factor. The entries are VF-ScalarCostTy pairs.    DenseMap<unsigned, ScalarCostsTy> InstsToScalarize; +  /// Holds the instructions known to be uniform after vectorization. +  /// The data is collected per VF. +  DenseMap<unsigned, SmallPtrSet<Instruction *, 4>> Uniforms; + +  /// Holds the instructions known to be scalar after vectorization. +  /// The data is collected per VF. +  DenseMap<unsigned, SmallPtrSet<Instruction *, 4>> Scalars; +    /// Returns the expected difference in cost from scalarizing the expression    /// feeding a predicated instruction \p PredInst. The instructions to    /// scalarize and their scalar costs are collected in \p ScalarCosts. 
A @@ -1990,6 +2112,44 @@ private:    /// the loop.    void collectInstsToScalarize(unsigned VF); +  /// Collect the instructions that are uniform after vectorization. An +  /// instruction is uniform if we represent it with a single scalar value in +  /// the vectorized loop corresponding to each vector iteration. Examples of +  /// uniform instructions include pointer operands of consecutive or +  /// interleaved memory accesses. Note that although uniformity implies an +  /// instruction will be scalar, the reverse is not true. In general, a +  /// scalarized instruction will be represented by VF scalar values in the +  /// vectorized loop, each corresponding to an iteration of the original +  /// scalar loop. +  void collectLoopUniforms(unsigned VF); + +  /// Collect the instructions that are scalar after vectorization. An +  /// instruction is scalar if it is known to be uniform or will be scalarized +  /// during vectorization. Non-uniform scalarized instructions will be +  /// represented by VF values in the vectorized loop, each corresponding to an +  /// iteration of the original scalar loop. +  void collectLoopScalars(unsigned VF); + +  /// Collect Uniform and Scalar values for the given \p VF. +  /// The sets depend on CM decision for Load/Store instructions +  /// that may be vectorized as interleave, gather-scatter or scalarized. +  void collectUniformsAndScalars(unsigned VF) { +    // Do the analysis once. +    if (VF == 1 || Uniforms.count(VF)) +      return; +    setCostBasedWideningDecision(VF); +    collectLoopUniforms(VF); +    collectLoopScalars(VF); +  } + +  /// Keeps cost model vectorization decision and cost for instructions. +  /// Right now it is used for memory instructions only. +  typedef DenseMap<std::pair<Instruction *, unsigned>, +                   std::pair<InstWidening, unsigned>> +      DecisionList; + +  DecisionList WideningDecisions; +  public:    /// The loop that we evaluate.    Loop *TheLoop; @@ -2019,6 +2179,23 @@ public:    SmallPtrSet<const Value *, 16> VecValuesToIgnore;  }; +/// LoopVectorizationPlanner - drives the vectorization process after having +/// passed Legality checks. +class LoopVectorizationPlanner { +public: +  LoopVectorizationPlanner(LoopVectorizationCostModel &CM) : CM(CM) {} + +  ~LoopVectorizationPlanner() {} + +  /// Plan how to best vectorize, return the best VF and its cost. +  LoopVectorizationCostModel::VectorizationFactor plan(bool OptForSize, +                                                       unsigned UserVF); + +private: +  /// The profitablity analysis. +  LoopVectorizationCostModel &CM; +}; +  /// \brief This holds vectorization requirements that must be verified late in  /// the process. The requirements are set by legalize and costmodel. Once  /// vectorization has been determined to be possible and profitable the @@ -2134,8 +2311,6 @@ struct LoopVectorize : public FunctionPass {    void getAnalysisUsage(AnalysisUsage &AU) const override {      AU.addRequired<AssumptionCacheTracker>(); -    AU.addRequiredID(LoopSimplifyID); -    AU.addRequiredID(LCSSAID);      AU.addRequired<BlockFrequencyInfoWrapperPass>();      AU.addRequired<DominatorTreeWrapperPass>();      AU.addRequired<LoopInfoWrapperPass>(); @@ -2156,7 +2331,7 @@ struct LoopVectorize : public FunctionPass {  //===----------------------------------------------------------------------===//  // Implementation of LoopVectorizationLegality, InnerLoopVectorizer and -// LoopVectorizationCostModel. +// LoopVectorizationCostModel and LoopVectorizationPlanner.  
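The LoopVectorizationPlanner declared above is, at this point, a thin driver over the cost model: plan() returns the vectorization factor the cost model considers most profitable, or a width of 1 when vectorization is best avoided. A toy sketch of that kind of selection loop, with a made-up per-lane cost function standing in for the real cost model (none of the names or numbers below come from LLVM):

#include <cstdio>

struct VectorizationFactor {
  unsigned Width; // chosen VF; 1 means "do not vectorize"
  unsigned Cost;  // expected cost per scalar iteration (scaled)
};

// Stand-in for the cost model: pretend wider vectors amortize a fixed
// per-iteration overhead of 16 units over VF lanes, plus 3 units per lane.
static unsigned expectedCostPerLane(unsigned VF) { return 3 + 16 / VF; }

static VectorizationFactor plan(unsigned MaxVF) {
  VectorizationFactor Best = {1, expectedCostPerLane(1)};
  for (unsigned VF = 2; VF <= MaxVF; VF *= 2) {
    unsigned C = expectedCostPerLane(VF);
    if (C < Best.Cost)
      Best = {VF, C};
  }
  return Best;
}

int main() {
  VectorizationFactor VFC = plan(/*MaxVF=*/8);
  std::printf("selected VF=%u with per-lane cost %u\n", VFC.Width, VFC.Cost);
  return 0;
}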
//===----------------------------------------------------------------------===//  Value *InnerLoopVectorizer::getBroadcastInstrs(Value *V) { @@ -2176,27 +2351,51 @@ Value *InnerLoopVectorizer::getBroadcastInstrs(Value *V) {    return Shuf;  } -void InnerLoopVectorizer::createVectorIntInductionPHI( -    const InductionDescriptor &II, Instruction *EntryVal) { +void InnerLoopVectorizer::createVectorIntOrFpInductionPHI( +    const InductionDescriptor &II, Value *Step, Instruction *EntryVal) {    Value *Start = II.getStartValue(); -  ConstantInt *Step = II.getConstIntStepValue(); -  assert(Step && "Can not widen an IV with a non-constant step");    // Construct the initial value of the vector IV in the vector loop preheader    auto CurrIP = Builder.saveIP();    Builder.SetInsertPoint(LoopVectorPreHeader->getTerminator());    if (isa<TruncInst>(EntryVal)) { +    assert(Start->getType()->isIntegerTy() && +           "Truncation requires an integer type");      auto *TruncType = cast<IntegerType>(EntryVal->getType()); -    Step = ConstantInt::getSigned(TruncType, Step->getSExtValue()); +    Step = Builder.CreateTrunc(Step, TruncType);      Start = Builder.CreateCast(Instruction::Trunc, Start, TruncType);    }    Value *SplatStart = Builder.CreateVectorSplat(VF, Start); -  Value *SteppedStart = getStepVector(SplatStart, 0, Step); +  Value *SteppedStart = +      getStepVector(SplatStart, 0, Step, II.getInductionOpcode()); + +  // We create vector phi nodes for both integer and floating-point induction +  // variables. Here, we determine the kind of arithmetic we will perform. +  Instruction::BinaryOps AddOp; +  Instruction::BinaryOps MulOp; +  if (Step->getType()->isIntegerTy()) { +    AddOp = Instruction::Add; +    MulOp = Instruction::Mul; +  } else { +    AddOp = II.getInductionOpcode(); +    MulOp = Instruction::FMul; +  } + +  // Multiply the vectorization factor by the step using integer or +  // floating-point arithmetic as appropriate. +  Value *ConstVF = getSignedIntOrFpConstant(Step->getType(), VF); +  Value *Mul = addFastMathFlag(Builder.CreateBinOp(MulOp, Step, ConstVF)); + +  // Create a vector splat to use in the induction update. +  // +  // FIXME: If the step is non-constant, we create the vector splat with +  //        IRBuilder. IRBuilder can constant-fold the multiply, but it doesn't +  //        handle a constant vector splat. +  Value *SplatVF = isa<Constant>(Mul) +                       ? ConstantVector::getSplat(VF, cast<Constant>(Mul)) +                       : Builder.CreateVectorSplat(VF, Mul);    Builder.restoreIP(CurrIP); -  Value *SplatVF = -      ConstantVector::getSplat(VF, ConstantInt::getSigned(Start->getType(), -                               VF * Step->getSExtValue()));    // We may need to add the step a number of times, depending on the unroll    // factor. The last of those goes into the PHI.    
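With floating-point inductions now widened here as well, the initial vector value and the per-iteration update are built with FAdd/FMul instead of Add/Mul. As a purely illustrative example with arbitrary numbers: a scalar IV with start 1.0 and step 0.5 at VF = 4 starts as <1.0, 1.5, 2.0, 2.5> in the preheader and is advanced by a splat of VF * Step = 2.0 on every vector iteration:

#include <cstdio>

int main() {
  const unsigned VF = 4;
  // Example FP induction: start = 1.0, step = 0.5 per scalar iteration.
  double Start = 1.0, Step = 0.5;

  // SteppedStart: splat(Start) + <0,1,2,3> * Step, built with FAdd/FMul.
  double VecInd[VF];
  for (unsigned Lane = 0; Lane < VF; ++Lane)
    VecInd[Lane] = Start + Lane * Step;

  // SplatVF: each vector iteration advances every lane by VF * Step.
  double SplatVF = VF * Step;

  for (unsigned Iter = 0; Iter < 3; ++Iter) {
    std::printf("iter %u:", Iter);
    for (unsigned Lane = 0; Lane < VF; ++Lane)
      std::printf(" %.1f", VecInd[Lane]);
    std::printf("\n");
    for (unsigned Lane = 0; Lane < VF; ++Lane)  // "step.add"
      VecInd[Lane] += SplatVF;
  }
  return 0;
}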
PHINode *VecInd = PHINode::Create(SteppedStart->getType(), 2, "vec.ind", @@ -2205,8 +2404,8 @@ void InnerLoopVectorizer::createVectorIntInductionPHI(    VectorParts Entry(UF);    for (unsigned Part = 0; Part < UF; ++Part) {      Entry[Part] = LastInduction; -    LastInduction = cast<Instruction>( -        Builder.CreateAdd(LastInduction, SplatVF, "step.add")); +    LastInduction = cast<Instruction>(addFastMathFlag( +        Builder.CreateBinOp(AddOp, LastInduction, SplatVF, "step.add")));    }    VectorLoopValueMap.initVector(EntryVal, Entry);    if (isa<TruncInst>(EntryVal)) @@ -2225,7 +2424,7 @@ void InnerLoopVectorizer::createVectorIntInductionPHI(  }  bool InnerLoopVectorizer::shouldScalarizeInstruction(Instruction *I) const { -  return Legal->isScalarAfterVectorization(I) || +  return Cost->isScalarAfterVectorization(I, VF) ||           Cost->isProfitableToScalarize(I, VF);  } @@ -2239,7 +2438,10 @@ bool InnerLoopVectorizer::needsScalarInduction(Instruction *IV) const {    return any_of(IV->users(), isScalarInst);  } -void InnerLoopVectorizer::widenIntInduction(PHINode *IV, TruncInst *Trunc) { +void InnerLoopVectorizer::widenIntOrFpInduction(PHINode *IV, TruncInst *Trunc) { + +  assert((IV->getType()->isIntegerTy() || IV != OldInduction) && +         "Primary induction variable must have an integer type");    auto II = Legal->getInductionVars()->find(IV);    assert(II != Legal->getInductionVars()->end() && "IV is not an induction"); @@ -2251,9 +2453,6 @@ void InnerLoopVectorizer::widenIntInduction(PHINode *IV, TruncInst *Trunc) {    // induction variable.    Value *ScalarIV = nullptr; -  // The step of the induction. -  Value *Step = nullptr; -    // The value from the original loop to which we are mapping the new induction    // variable.    Instruction *EntryVal = Trunc ? cast<Instruction>(Trunc) : IV; @@ -2266,45 +2465,49 @@ void InnerLoopVectorizer::widenIntInduction(PHINode *IV, TruncInst *Trunc) {    // least one user in the loop that is not widened.    auto NeedsScalarIV = VF > 1 && needsScalarInduction(EntryVal); -  // If the induction variable has a constant integer step value, go ahead and -  // get it now. -  if (ID.getConstIntStepValue()) -    Step = ID.getConstIntStepValue(); +  // Generate code for the induction step. Note that induction steps are +  // required to be loop-invariant +  assert(PSE.getSE()->isLoopInvariant(ID.getStep(), OrigLoop) && +         "Induction step should be loop invariant"); +  auto &DL = OrigLoop->getHeader()->getModule()->getDataLayout(); +  Value *Step = nullptr; +  if (PSE.getSE()->isSCEVable(IV->getType())) { +    SCEVExpander Exp(*PSE.getSE(), DL, "induction"); +    Step = Exp.expandCodeFor(ID.getStep(), ID.getStep()->getType(), +                             LoopVectorPreHeader->getTerminator()); +  } else { +    Step = cast<SCEVUnknown>(ID.getStep())->getValue(); +  }    // Try to create a new independent vector induction variable. If we can't    // create the phi node, we will splat the scalar induction variable in each    // loop iteration. -  if (VF > 1 && IV->getType() == Induction->getType() && Step && -      !shouldScalarizeInstruction(EntryVal)) { -    createVectorIntInductionPHI(ID, EntryVal); +  if (VF > 1 && !shouldScalarizeInstruction(EntryVal)) { +    createVectorIntOrFpInductionPHI(ID, Step, EntryVal);      VectorizedIV = true;    }    // If we haven't yet vectorized the induction variable, or if we will create    // a scalar one, we need to define the scalar induction variable and step    // values. 
If we were given a truncation type, truncate the canonical -  // induction variable and constant step. Otherwise, derive these values from -  // the induction descriptor. +  // induction variable and step. Otherwise, derive these values from the +  // induction descriptor.    if (!VectorizedIV || NeedsScalarIV) { +    ScalarIV = Induction; +    if (IV != OldInduction) { +      ScalarIV = IV->getType()->isIntegerTy() +                     ? Builder.CreateSExtOrTrunc(Induction, IV->getType()) +                     : Builder.CreateCast(Instruction::SIToFP, Induction, +                                          IV->getType()); +      ScalarIV = ID.transform(Builder, ScalarIV, PSE.getSE(), DL); +      ScalarIV->setName("offset.idx"); +    }      if (Trunc) {        auto *TruncType = cast<IntegerType>(Trunc->getType()); -      assert(Step && "Truncation requires constant integer step"); -      auto StepInt = cast<ConstantInt>(Step)->getSExtValue(); -      ScalarIV = Builder.CreateCast(Instruction::Trunc, Induction, TruncType); -      Step = ConstantInt::getSigned(TruncType, StepInt); -    } else { -      ScalarIV = Induction; -      auto &DL = OrigLoop->getHeader()->getModule()->getDataLayout(); -      if (IV != OldInduction) { -        ScalarIV = Builder.CreateSExtOrTrunc(ScalarIV, IV->getType()); -        ScalarIV = ID.transform(Builder, ScalarIV, PSE.getSE(), DL); -        ScalarIV->setName("offset.idx"); -      } -      if (!Step) { -        SCEVExpander Exp(*PSE.getSE(), DL, "induction"); -        Step = Exp.expandCodeFor(ID.getStep(), ID.getStep()->getType(), -                                 &*Builder.GetInsertPoint()); -      } +      assert(Step->getType()->isIntegerTy() && +             "Truncation requires an integer step"); +      ScalarIV = Builder.CreateTrunc(ScalarIV, TruncType); +      Step = Builder.CreateTrunc(Step, TruncType);      }    } @@ -2314,7 +2517,8 @@ void InnerLoopVectorizer::widenIntInduction(PHINode *IV, TruncInst *Trunc) {      Value *Broadcasted = getBroadcastInstrs(ScalarIV);      VectorParts Entry(UF);      for (unsigned Part = 0; Part < UF; ++Part) -      Entry[Part] = getStepVector(Broadcasted, VF * Part, Step); +      Entry[Part] = +          getStepVector(Broadcasted, VF * Part, Step, ID.getInductionOpcode());      VectorLoopValueMap.initVector(EntryVal, Entry);      if (Trunc)        addMetadata(Entry, Trunc); @@ -2327,7 +2531,7 @@ void InnerLoopVectorizer::widenIntInduction(PHINode *IV, TruncInst *Trunc) {    // in the loop in the common case prior to InstCombine. We will be trading    // one vector extract for each scalar step.    if (NeedsScalarIV) -    buildScalarSteps(ScalarIV, Step, EntryVal); +    buildScalarSteps(ScalarIV, Step, EntryVal, ID);  }  Value *InnerLoopVectorizer::getStepVector(Value *Val, int StartIdx, Value *Step, @@ -2387,30 +2591,43 @@ Value *InnerLoopVectorizer::getStepVector(Value *Val, int StartIdx, Value *Step,  }  void InnerLoopVectorizer::buildScalarSteps(Value *ScalarIV, Value *Step, -                                           Value *EntryVal) { +                                           Value *EntryVal, +                                           const InductionDescriptor &ID) {    // We shouldn't have to build scalar steps if we aren't vectorizing.    assert(VF > 1 && "VF should be greater than one");    // Get the value type and ensure it and the step have the same integer type.    
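buildScalarSteps, whose body follows, materializes per-lane scalar copies of the IV as ScalarIV + (VF * Part + Lane) * Step, switching to FAdd/FMul for floating-point IVs and emitting only lane 0 for values known to be uniform. A standalone numeric sketch of that indexing (UF = 2, VF = 4, arbitrary integer start and step):

#include <cstdio>

int main() {
  const unsigned UF = 2, VF = 4;
  long ScalarIV = 100, Step = 3;   // arbitrary example values
  bool Uniform = false;            // a uniform value would need lane 0 only

  unsigned Lanes = Uniform ? 1 : VF;
  for (unsigned Part = 0; Part < UF; ++Part)
    for (unsigned Lane = 0; Lane < Lanes; ++Lane) {
      long StartIdx = VF * Part + Lane;       // 0 .. UF*VF-1
      long Val = ScalarIV + StartIdx * Step;  // Add(ScalarIV, Mul(Idx, Step))
      std::printf("part %u lane %u -> %ld\n", Part, Lane, Val);
    }
  return 0;
}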
Type *ScalarIVTy = ScalarIV->getType()->getScalarType(); -  assert(ScalarIVTy->isIntegerTy() && ScalarIVTy == Step->getType() && -         "Val and Step should have the same integer type"); +  assert(ScalarIVTy == Step->getType() && +         "Val and Step should have the same type"); + +  // We build scalar steps for both integer and floating-point induction +  // variables. Here, we determine the kind of arithmetic we will perform. +  Instruction::BinaryOps AddOp; +  Instruction::BinaryOps MulOp; +  if (ScalarIVTy->isIntegerTy()) { +    AddOp = Instruction::Add; +    MulOp = Instruction::Mul; +  } else { +    AddOp = ID.getInductionOpcode(); +    MulOp = Instruction::FMul; +  }    // Determine the number of scalars we need to generate for each unroll    // iteration. If EntryVal is uniform, we only need to generate the first    // lane. Otherwise, we generate all VF values.    unsigned Lanes = -      Legal->isUniformAfterVectorization(cast<Instruction>(EntryVal)) ? 1 : VF; +    Cost->isUniformAfterVectorization(cast<Instruction>(EntryVal), VF) ? 1 : VF;    // Compute the scalar steps and save the results in VectorLoopValueMap.    ScalarParts Entry(UF);    for (unsigned Part = 0; Part < UF; ++Part) {      Entry[Part].resize(VF);      for (unsigned Lane = 0; Lane < Lanes; ++Lane) { -      auto *StartIdx = ConstantInt::get(ScalarIVTy, VF * Part + Lane); -      auto *Mul = Builder.CreateMul(StartIdx, Step); -      auto *Add = Builder.CreateAdd(ScalarIV, Mul); +      auto *StartIdx = getSignedIntOrFpConstant(ScalarIVTy, VF * Part + Lane); +      auto *Mul = addFastMathFlag(Builder.CreateBinOp(MulOp, StartIdx, Step)); +      auto *Add = addFastMathFlag(Builder.CreateBinOp(AddOp, ScalarIV, Mul));        Entry[Part][Lane] = Add;      }    } @@ -2469,7 +2686,7 @@ InnerLoopVectorizer::getVectorValue(Value *V) {      // known to be uniform after vectorization, this corresponds to lane zero      // of the last unroll iteration. Otherwise, the last instruction is the one      // we created for the last vector lane of the last unroll iteration. -    unsigned LastLane = Legal->isUniformAfterVectorization(I) ? 0 : VF - 1; +    unsigned LastLane = Cost->isUniformAfterVectorization(I, VF) ? 0 : VF - 1;      auto *LastInst = cast<Instruction>(getScalarValue(V, UF - 1, LastLane));      // Set the insert point after the last scalarized instruction. This ensures @@ -2486,7 +2703,7 @@ InnerLoopVectorizer::getVectorValue(Value *V) {      // VectorLoopValueMap, we will only generate the insertelements once.      for (unsigned Part = 0; Part < UF; ++Part) {        Value *VectorValue = nullptr; -      if (Legal->isUniformAfterVectorization(I)) { +      if (Cost->isUniformAfterVectorization(I, VF)) {          VectorValue = getBroadcastInstrs(getScalarValue(V, Part, 0));        } else {          VectorValue = UndefValue::get(VectorType::get(V->getType(), VF)); @@ -2515,8 +2732,9 @@ Value *InnerLoopVectorizer::getScalarValue(Value *V, unsigned Part,    if (OrigLoop->isLoopInvariant(V))      return V; -  assert(Lane > 0 ? !Legal->isUniformAfterVectorization(cast<Instruction>(V)) -                  : true && "Uniform values only have lane zero"); +  assert(Lane > 0 ? +         !Cost->isUniformAfterVectorization(cast<Instruction>(V), VF) +         : true && "Uniform values only have lane zero");    // If the value from the original loop has not been vectorized, it is    // represented by UF x VF scalar values in the new loop. 
Return the requested @@ -2551,102 +2769,6 @@ Value *InnerLoopVectorizer::reverseVector(Value *Vec) {                                       "reverse");  } -// Get a mask to interleave \p NumVec vectors into a wide vector. -// I.e.  <0, VF, VF*2, ..., VF*(NumVec-1), 1, VF+1, VF*2+1, ...> -// E.g. For 2 interleaved vectors, if VF is 4, the mask is: -//      <0, 4, 1, 5, 2, 6, 3, 7> -static Constant *getInterleavedMask(IRBuilder<> &Builder, unsigned VF, -                                    unsigned NumVec) { -  SmallVector<Constant *, 16> Mask; -  for (unsigned i = 0; i < VF; i++) -    for (unsigned j = 0; j < NumVec; j++) -      Mask.push_back(Builder.getInt32(j * VF + i)); - -  return ConstantVector::get(Mask); -} - -// Get the strided mask starting from index \p Start. -// I.e.  <Start, Start + Stride, ..., Start + Stride*(VF-1)> -static Constant *getStridedMask(IRBuilder<> &Builder, unsigned Start, -                                unsigned Stride, unsigned VF) { -  SmallVector<Constant *, 16> Mask; -  for (unsigned i = 0; i < VF; i++) -    Mask.push_back(Builder.getInt32(Start + i * Stride)); - -  return ConstantVector::get(Mask); -} - -// Get a mask of two parts: The first part consists of sequential integers -// starting from 0, The second part consists of UNDEFs. -// I.e. <0, 1, 2, ..., NumInt - 1, undef, ..., undef> -static Constant *getSequentialMask(IRBuilder<> &Builder, unsigned NumInt, -                                   unsigned NumUndef) { -  SmallVector<Constant *, 16> Mask; -  for (unsigned i = 0; i < NumInt; i++) -    Mask.push_back(Builder.getInt32(i)); - -  Constant *Undef = UndefValue::get(Builder.getInt32Ty()); -  for (unsigned i = 0; i < NumUndef; i++) -    Mask.push_back(Undef); - -  return ConstantVector::get(Mask); -} - -// Concatenate two vectors with the same element type. The 2nd vector should -// not have more elements than the 1st vector. If the 2nd vector has less -// elements, extend it with UNDEFs. -static Value *ConcatenateTwoVectors(IRBuilder<> &Builder, Value *V1, -                                    Value *V2) { -  VectorType *VecTy1 = dyn_cast<VectorType>(V1->getType()); -  VectorType *VecTy2 = dyn_cast<VectorType>(V2->getType()); -  assert(VecTy1 && VecTy2 && -         VecTy1->getScalarType() == VecTy2->getScalarType() && -         "Expect two vectors with the same element type"); - -  unsigned NumElts1 = VecTy1->getNumElements(); -  unsigned NumElts2 = VecTy2->getNumElements(); -  assert(NumElts1 >= NumElts2 && "Unexpect the first vector has less elements"); - -  if (NumElts1 > NumElts2) { -    // Extend with UNDEFs. -    Constant *ExtMask = -        getSequentialMask(Builder, NumElts2, NumElts1 - NumElts2); -    V2 = Builder.CreateShuffleVector(V2, UndefValue::get(VecTy2), ExtMask); -  } - -  Constant *Mask = getSequentialMask(Builder, NumElts1 + NumElts2, 0); -  return Builder.CreateShuffleVector(V1, V2, Mask); -} - -// Concatenate vectors in the given list. All vectors have the same type. 
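The mask helpers deleted here follow simple index patterns, and the surviving code switches to the shared createInterleaveMask / createStrideMask utilities that produce the same shuffles. A standalone illustration of the two patterns for VF = 4 and an interleave factor of 2 (plain C++, not the IRBuilder-based helpers):

#include <cstdio>
#include <vector>

// Interleave mask: <0, VF, 1, VF+1, ...>, generalized to NumVecs inputs.
static std::vector<unsigned> interleaveMask(unsigned VF, unsigned NumVecs) {
  std::vector<unsigned> Mask;
  for (unsigned i = 0; i < VF; ++i)
    for (unsigned j = 0; j < NumVecs; ++j)
      Mask.push_back(j * VF + i);
  return Mask;
}

// Stride mask: <Start, Start+Stride, ..., Start+Stride*(VF-1)>.
static std::vector<unsigned> strideMask(unsigned Start, unsigned Stride,
                                        unsigned VF) {
  std::vector<unsigned> Mask;
  for (unsigned i = 0; i < VF; ++i)
    Mask.push_back(Start + i * Stride);
  return Mask;
}

int main() {
  for (unsigned M : interleaveMask(4, 2)) std::printf("%u ", M); // 0 4 1 5 2 6 3 7
  std::printf("\n");
  for (unsigned M : strideMask(1, 2, 4)) std::printf("%u ", M);  // 1 3 5 7
  std::printf("\n");
  return 0;
}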
-static Value *ConcatenateVectors(IRBuilder<> &Builder, -                                 ArrayRef<Value *> InputList) { -  unsigned NumVec = InputList.size(); -  assert(NumVec > 1 && "Should be at least two vectors"); - -  SmallVector<Value *, 8> ResList; -  ResList.append(InputList.begin(), InputList.end()); -  do { -    SmallVector<Value *, 8> TmpList; -    for (unsigned i = 0; i < NumVec - 1; i += 2) { -      Value *V0 = ResList[i], *V1 = ResList[i + 1]; -      assert((V0->getType() == V1->getType() || i == NumVec - 2) && -             "Only the last vector may have a different type"); - -      TmpList.push_back(ConcatenateTwoVectors(Builder, V0, V1)); -    } - -    // Push the last vector if the total number of vectors is odd. -    if (NumVec % 2 != 0) -      TmpList.push_back(ResList[NumVec - 1]); - -    ResList = TmpList; -    NumVec = ResList.size(); -  } while (NumVec > 1); - -  return ResList[0]; -} -  // Try to vectorize the interleave group that \p Instr belongs to.  //  // E.g. Translate following interleaved load group (factor = 3): @@ -2683,15 +2805,13 @@ void InnerLoopVectorizer::vectorizeInterleaveGroup(Instruction *Instr) {    if (Instr != Group->getInsertPos())      return; -  LoadInst *LI = dyn_cast<LoadInst>(Instr); -  StoreInst *SI = dyn_cast<StoreInst>(Instr);    Value *Ptr = getPointerOperand(Instr);    // Prepare for the vector type of the interleaved load/store. -  Type *ScalarTy = LI ? LI->getType() : SI->getValueOperand()->getType(); +  Type *ScalarTy = getMemInstValueType(Instr);    unsigned InterleaveFactor = Group->getFactor();    Type *VecTy = VectorType::get(ScalarTy, InterleaveFactor * VF); -  Type *PtrTy = VecTy->getPointerTo(Ptr->getType()->getPointerAddressSpace()); +  Type *PtrTy = VecTy->getPointerTo(getMemInstAddressSpace(Instr));    // Prepare for the new pointers.    setDebugLocFromInst(Builder, Ptr); @@ -2731,7 +2851,7 @@ void InnerLoopVectorizer::vectorizeInterleaveGroup(Instruction *Instr) {    Value *UndefVec = UndefValue::get(VecTy);    // Vectorize the interleaved load group. -  if (LI) { +  if (isa<LoadInst>(Instr)) {      // For each unroll part, create a wide load for the group.      SmallVector<Value *, 2> NewLoads; @@ -2752,7 +2872,7 @@ void InnerLoopVectorizer::vectorizeInterleaveGroup(Instruction *Instr) {          continue;        VectorParts Entry(UF); -      Constant *StrideMask = getStridedMask(Builder, I, InterleaveFactor, VF); +      Constant *StrideMask = createStrideMask(Builder, I, InterleaveFactor, VF);        for (unsigned Part = 0; Part < UF; Part++) {          Value *StridedVec = Builder.CreateShuffleVector(              NewLoads[Part], UndefVec, StrideMask, "strided.vec"); @@ -2796,10 +2916,10 @@ void InnerLoopVectorizer::vectorizeInterleaveGroup(Instruction *Instr) {      }      // Concatenate all vectors into a wide vector. -    Value *WideVec = ConcatenateVectors(Builder, StoredVecs); +    Value *WideVec = concatenateVectors(Builder, StoredVecs);      // Interleave the elements in the wide vector. -    Constant *IMask = getInterleavedMask(Builder, VF, InterleaveFactor); +    Constant *IMask = createInterleaveMask(Builder, VF, InterleaveFactor);      Value *IVec = Builder.CreateShuffleVector(WideVec, UndefVec, IMask,                                                "interleaved.vec"); @@ -2816,103 +2936,44 @@ void InnerLoopVectorizer::vectorizeMemoryInstruction(Instruction *Instr) {    assert((LI || SI) && "Invalid Load/Store instruction"); -  // Try to vectorize the interleave group if this access is interleaved. 
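The deleted ConcatenateVectors reduced its inputs pairwise, halving the list each round; the interleaved-store path below still relies on that shape through the shared concatenateVectors helper before applying the interleave mask. A simplified standalone sketch of the pairwise reduction over plain arrays (the real helper also pads a shorter trailing vector with undef elements):

#include <cstdio>
#include <vector>

using Vec = std::vector<int>;

static Vec concatTwo(const Vec &A, const Vec &B) {
  Vec R(A);
  R.insert(R.end(), B.begin(), B.end());
  return R;
}

// Concatenate a list of vectors pairwise until one remains, mirroring the
// log2-depth shuffle tree built for interleaved stores.
static Vec concatAll(std::vector<Vec> List) {
  while (List.size() > 1) {
    std::vector<Vec> Next;
    for (size_t i = 0; i + 1 < List.size(); i += 2)
      Next.push_back(concatTwo(List[i], List[i + 1]));
    if (List.size() % 2 != 0)   // an odd trailing vector is carried forward
      Next.push_back(List.back());
    List = Next;
  }
  return List.front();
}

int main() {
  Vec Wide = concatAll({{0, 1, 2, 3}, {10, 11, 12, 13}, {20, 21, 22, 23}});
  for (int V : Wide) std::printf("%d ", V);   // 0..3 10..13 20..23
  std::printf("\n");
  return 0;
}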
-  if (Legal->isAccessInterleaved(Instr)) +  LoopVectorizationCostModel::InstWidening Decision = +      Cost->getWideningDecision(Instr, VF); +  assert(Decision != LoopVectorizationCostModel::CM_Unknown && +         "CM decision should be taken at this point"); +  if (Decision == LoopVectorizationCostModel::CM_Interleave)      return vectorizeInterleaveGroup(Instr); -  Type *ScalarDataTy = LI ? LI->getType() : SI->getValueOperand()->getType(); +  Type *ScalarDataTy = getMemInstValueType(Instr);    Type *DataTy = VectorType::get(ScalarDataTy, VF);    Value *Ptr = getPointerOperand(Instr); -  unsigned Alignment = LI ? LI->getAlignment() : SI->getAlignment(); +  unsigned Alignment = getMemInstAlignment(Instr);    // An alignment of 0 means target abi alignment. We need to use the scalar's    // target abi alignment in such a case.    const DataLayout &DL = Instr->getModule()->getDataLayout();    if (!Alignment)      Alignment = DL.getABITypeAlignment(ScalarDataTy); -  unsigned AddressSpace = Ptr->getType()->getPointerAddressSpace(); +  unsigned AddressSpace = getMemInstAddressSpace(Instr);    // Scalarize the memory instruction if necessary. -  if (Legal->memoryInstructionMustBeScalarized(Instr, VF)) +  if (Decision == LoopVectorizationCostModel::CM_Scalarize)      return scalarizeInstruction(Instr, Legal->isScalarWithPredication(Instr));    // Determine if the pointer operand of the access is either consecutive or    // reverse consecutive.    int ConsecutiveStride = Legal->isConsecutivePtr(Ptr);    bool Reverse = ConsecutiveStride < 0; - -  // Determine if either a gather or scatter operation is legal.    bool CreateGatherScatter = -      !ConsecutiveStride && Legal->isLegalGatherOrScatter(Instr); +      (Decision == LoopVectorizationCostModel::CM_GatherScatter);    VectorParts VectorGep;    // Handle consecutive loads/stores. -  GetElementPtrInst *Gep = getGEPInstruction(Ptr);    if (ConsecutiveStride) { -    if (Gep) { -      unsigned NumOperands = Gep->getNumOperands(); -#ifndef NDEBUG -      // The original GEP that identified as a consecutive memory access -      // should have only one loop-variant operand. -      unsigned NumOfLoopVariantOps = 0; -      for (unsigned i = 0; i < NumOperands; ++i) -        if (!PSE.getSE()->isLoopInvariant(PSE.getSCEV(Gep->getOperand(i)), -                                          OrigLoop)) -          NumOfLoopVariantOps++; -      assert(NumOfLoopVariantOps == 1 && -             "Consecutive GEP should have only one loop-variant operand"); -#endif -      GetElementPtrInst *Gep2 = cast<GetElementPtrInst>(Gep->clone()); -      Gep2->setName("gep.indvar"); - -      // A new GEP is created for a 0-lane value of the first unroll iteration. -      // The GEPs for the rest of the unroll iterations are computed below as an -      // offset from this GEP. -      for (unsigned i = 0; i < NumOperands; ++i) -        // We can apply getScalarValue() for all GEP indices. It returns an -        // original value for loop-invariant operand and 0-lane for consecutive -        // operand. 
-        Gep2->setOperand(i, getScalarValue(Gep->getOperand(i), -                                           0, /* First unroll iteration */ -                                           0  /* 0-lane of the vector */ )); -      setDebugLocFromInst(Builder, Gep); -      Ptr = Builder.Insert(Gep2); - -    } else { // No GEP -      setDebugLocFromInst(Builder, Ptr); -      Ptr = getScalarValue(Ptr, 0, 0); -    } +    Ptr = getScalarValue(Ptr, 0, 0);    } else {      // At this point we should vector version of GEP for Gather or Scatter      assert(CreateGatherScatter && "The instruction should be scalarized"); -    if (Gep) { -      // Vectorizing GEP, across UF parts. We want to get a vector value for base -      // and each index that's defined inside the loop, even if it is -      // loop-invariant but wasn't hoisted out. Otherwise we want to keep them -      // scalar. -      SmallVector<VectorParts, 4> OpsV; -      for (Value *Op : Gep->operands()) { -        Instruction *SrcInst = dyn_cast<Instruction>(Op); -        if (SrcInst && OrigLoop->contains(SrcInst)) -          OpsV.push_back(getVectorValue(Op)); -        else -          OpsV.push_back(VectorParts(UF, Op)); -      } -      for (unsigned Part = 0; Part < UF; ++Part) { -        SmallVector<Value *, 4> Ops; -        Value *GEPBasePtr = OpsV[0][Part]; -        for (unsigned i = 1; i < Gep->getNumOperands(); i++) -          Ops.push_back(OpsV[i][Part]); -        Value *NewGep =  Builder.CreateGEP(GEPBasePtr, Ops, "VectorGep"); -        cast<GetElementPtrInst>(NewGep)->setIsInBounds(Gep->isInBounds()); -        assert(NewGep->getType()->isVectorTy() && "Expected vector GEP"); - -        NewGep = -            Builder.CreateBitCast(NewGep, VectorType::get(Ptr->getType(), VF)); -        VectorGep.push_back(NewGep); -      } -    } else -      VectorGep = getVectorValue(Ptr); +    VectorGep = getVectorValue(Ptr);    }    VectorParts Mask = createBlockInMask(Instr->getParent()); @@ -3027,7 +3088,7 @@ void InnerLoopVectorizer::scalarizeInstruction(Instruction *Instr,    // Determine the number of scalars we need to generate for each unroll    // iteration. If the instruction is uniform, we only need to generate the    // first lane. Otherwise, we generate all VF values. -  unsigned Lanes = Legal->isUniformAfterVectorization(Instr) ? 1 : VF; +  unsigned Lanes = Cost->isUniformAfterVectorization(Instr, VF) ? 1 : VF;    // For each vector unroll 'part':    for (unsigned Part = 0; Part < UF; ++Part) { @@ -3038,7 +3099,9 @@ void InnerLoopVectorizer::scalarizeInstruction(Instruction *Instr,        // Start if-block.        Value *Cmp = nullptr;        if (IfPredicateInstr) { -        Cmp = Builder.CreateExtractElement(Cond[Part], Builder.getInt32(Lane)); +        Cmp = Cond[Part]; +        if (Cmp->getType()->isVectorTy()) +          Cmp = Builder.CreateExtractElement(Cmp, Builder.getInt32(Lane));          Cmp = Builder.CreateICmp(ICmpInst::ICMP_EQ, Cmp,                                   ConstantInt::get(Cmp->getType(), 1));        } @@ -3346,7 +3409,7 @@ void InnerLoopVectorizer::createEmptyLoop() {    //   - counts from zero, stepping by one    //   - is the size of the widest induction variable type    // then we create a new one. -  OldInduction = Legal->getInduction(); +  OldInduction = Legal->getPrimaryInduction();    Type *IdxTy = Legal->getWidestInductionType();    // Split the single block loop into the two loop structure described above. 
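With the cost model now recording a widening decision per memory instruction, vectorizeMemoryInstruction no longer re-derives the strategy on the fly; it simply dispatches on the recorded decision. A schematic of that dispatch (the enum values mirror the ones introduced above, but the handler strings are placeholders, not LLVM code):

#include <cstdio>

enum Decision { Unknown, Widen, Interleave, GatherScatter, Scalarize };

static const char *handleMemoryInst(Decision D, bool ReverseConsecutive) {
  switch (D) {
  case Interleave:    return "emit wide load/store for the interleave group";
  case Scalarize:     return "emit VF scalar accesses (possibly predicated)";
  case GatherScatter: return "emit masked.gather / masked.scatter";
  case Widen:         return ReverseConsecutive ? "wide access + reverse shuffle"
                                                : "single wide consecutive access";
  case Unknown:       break;
  }
  return "error: cost-model decision missing";
}

int main() {
  std::printf("%s\n", handleMemoryInst(Interleave, false));
  std::printf("%s\n", handleMemoryInst(Widen, true));
  return 0;
}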
@@ -3543,7 +3606,7 @@ void InnerLoopVectorizer::fixupIVUsers(PHINode *OrigPhi,  namespace {  struct CSEDenseMapInfo { -  static bool canHandle(Instruction *I) { +  static bool canHandle(const Instruction *I) {      return isa<InsertElementInst>(I) || isa<ExtractElementInst>(I) ||             isa<ShuffleVectorInst>(I) || isa<GetElementPtrInst>(I);    } @@ -3553,12 +3616,12 @@ struct CSEDenseMapInfo {    static inline Instruction *getTombstoneKey() {      return DenseMapInfo<Instruction *>::getTombstoneKey();    } -  static unsigned getHashValue(Instruction *I) { +  static unsigned getHashValue(const Instruction *I) {      assert(canHandle(I) && "Unknown instruction!");      return hash_combine(I->getOpcode(), hash_combine_range(I->value_op_begin(),                                                             I->value_op_end()));    } -  static bool isEqual(Instruction *LHS, Instruction *RHS) { +  static bool isEqual(const Instruction *LHS, const Instruction *RHS) {      if (LHS == getEmptyKey() || RHS == getEmptyKey() ||          LHS == getTombstoneKey() || RHS == getTombstoneKey())        return LHS == RHS; @@ -3589,51 +3652,6 @@ static void cse(BasicBlock *BB) {    }  } -/// \brief Adds a 'fast' flag to floating point operations. -static Value *addFastMathFlag(Value *V) { -  if (isa<FPMathOperator>(V)) { -    FastMathFlags Flags; -    Flags.setUnsafeAlgebra(); -    cast<Instruction>(V)->setFastMathFlags(Flags); -  } -  return V; -} - -/// \brief Estimate the overhead of scalarizing a value based on its type. -/// Insert and Extract are set if the result needs to be inserted and/or -/// extracted from vectors. -static unsigned getScalarizationOverhead(Type *Ty, bool Insert, bool Extract, -                                         const TargetTransformInfo &TTI) { -  if (Ty->isVoidTy()) -    return 0; - -  assert(Ty->isVectorTy() && "Can only scalarize vectors"); -  unsigned Cost = 0; - -  for (unsigned I = 0, E = Ty->getVectorNumElements(); I < E; ++I) { -    if (Extract) -      Cost += TTI.getVectorInstrCost(Instruction::ExtractElement, Ty, I); -    if (Insert) -      Cost += TTI.getVectorInstrCost(Instruction::InsertElement, Ty, I); -  } - -  return Cost; -} - -/// \brief Estimate the overhead of scalarizing an Instruction based on the -/// types of its operands and return value. -static unsigned getScalarizationOverhead(SmallVectorImpl<Type *> &OpTys, -                                         Type *RetTy, -                                         const TargetTransformInfo &TTI) { -  unsigned ScalarizationCost = -      getScalarizationOverhead(RetTy, true, false, TTI); - -  for (Type *Ty : OpTys) -    ScalarizationCost += getScalarizationOverhead(Ty, false, true, TTI); - -  return ScalarizationCost; -} -  /// \brief Estimate the overhead of scalarizing an instruction. This is a  /// convenience wrapper for the type-based getScalarizationOverhead API.  
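The deleted type-based helpers charged one extract per operand element and one insert per result element; the rewritten wrapper that follows obtains the same quantities from TTI and can skip them on targets that support efficient vector element loads/stores. A back-of-the-envelope sketch of the original accounting, assuming a cost of 1 per insert and per extract (an arbitrary assumption):

#include <cstdio>

int main() {
  const unsigned VF = 4;
  const unsigned NumOperands = 2;      // e.g. a binary operator
  const unsigned InsertCost = 1;       // assumed per-element costs
  const unsigned ExtractCost = 1;

  // Result must be re-packed into a vector: one insert per lane.
  unsigned Cost = VF * InsertCost;
  // Each vector operand must be unpacked: one extract per lane per operand.
  Cost += NumOperands * VF * ExtractCost;

  std::printf("scalarization overhead at VF=%u: %u\n", VF, Cost); // 12
  return 0;
}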
static unsigned getScalarizationOverhead(Instruction *I, unsigned VF, @@ -3641,14 +3659,24 @@ static unsigned getScalarizationOverhead(Instruction *I, unsigned VF,    if (VF == 1)      return 0; +  unsigned Cost = 0;    Type *RetTy = ToVectorTy(I->getType(), VF); +  if (!RetTy->isVoidTy() && +      (!isa<LoadInst>(I) || +       !TTI.supportsEfficientVectorElementLoadStore())) +    Cost += TTI.getScalarizationOverhead(RetTy, true, false); -  SmallVector<Type *, 4> OpTys; -  unsigned OperandsNum = I->getNumOperands(); -  for (unsigned OpInd = 0; OpInd < OperandsNum; ++OpInd) -    OpTys.push_back(ToVectorTy(I->getOperand(OpInd)->getType(), VF)); +  if (CallInst *CI = dyn_cast<CallInst>(I)) { +    SmallVector<const Value *, 4> Operands(CI->arg_operands()); +    Cost += TTI.getOperandsScalarizationOverhead(Operands, VF); +  } +  else if (!isa<StoreInst>(I) || +           !TTI.supportsEfficientVectorElementLoadStore()) { +    SmallVector<const Value *, 4> Operands(I->operand_values()); +    Cost += TTI.getOperandsScalarizationOverhead(Operands, VF); +  } -  return getScalarizationOverhead(OpTys, RetTy, TTI); +  return Cost;  }  // Estimate cost of a call instruction CI if it were vectorized with factor VF. @@ -3681,7 +3709,7 @@ static unsigned getVectorCallCost(CallInst *CI, unsigned VF,    // Compute costs of unpacking argument values for the scalar calls and    // packing the return values to a vector. -  unsigned ScalarizationCost = getScalarizationOverhead(Tys, RetTy, TTI); +  unsigned ScalarizationCost = getScalarizationOverhead(CI, VF, TTI);    unsigned Cost = ScalarCallCost * VF + ScalarizationCost; @@ -3709,16 +3737,12 @@ static unsigned getVectorIntrinsicCost(CallInst *CI, unsigned VF,    Intrinsic::ID ID = getVectorIntrinsicIDForCall(CI, TLI);    assert(ID && "Expected intrinsic call!"); -  Type *RetTy = ToVectorTy(CI->getType(), VF); -  SmallVector<Type *, 4> Tys; -  for (Value *ArgOperand : CI->arg_operands()) -    Tys.push_back(ToVectorTy(ArgOperand->getType(), VF)); -    FastMathFlags FMF;    if (auto *FPMO = dyn_cast<FPMathOperator>(CI))      FMF = FPMO->getFastMathFlags(); -  return TTI.getIntrinsicInstrCost(ID, RetTy, Tys, FMF); +  SmallVector<Value *, 4> Operands(CI->arg_operands()); +  return TTI.getIntrinsicInstrCost(ID, CI->getType(), Operands, FMF, VF);  }  static Type *smallestIntegerVectorType(Type *T1, Type *T2) { @@ -3861,30 +3885,27 @@ void InnerLoopVectorizer::vectorizeLoop() {    // the cost-model.    //    //===------------------------------------------------===// -  Constant *Zero = Builder.getInt32(0); -  // In order to support recurrences we need to be able to vectorize Phi nodes. -  // Phi nodes have cycles, so we need to vectorize them in two stages. First, -  // we create a new vector PHI node with no incoming edges. We use this value -  // when we vectorize all of the instructions that use the PHI. Next, after -  // all of the instructions in the block are complete we add the new incoming -  // edges to the PHI. At this point all of the instructions in the basic block -  // are vectorized, so we can use them to construct the PHI. -  PhiVector PHIsToFix; - -  // Collect instructions from the original loop that will become trivially -  // dead in the vectorized loop. We don't need to vectorize these -  // instructions. -  collectTriviallyDeadInstructions(); +  // Collect instructions from the original loop that will become trivially dead +  // in the vectorized loop. We don't need to vectorize these instructions. 
For +  // example, original induction update instructions can become dead because we +  // separately emit induction "steps" when generating code for the new loop. +  // Similarly, we create a new latch condition when setting up the structure +  // of the new loop, so the old one can become dead. +  SmallPtrSet<Instruction *, 4> DeadInstructions; +  collectTriviallyDeadInstructions(DeadInstructions);    // Scan the loop in a topological order to ensure that defs are vectorized    // before users.    LoopBlocksDFS DFS(OrigLoop);    DFS.perform(LI); -  // Vectorize all of the blocks in the original loop. +  // Vectorize all instructions in the original loop that will not become +  // trivially dead when vectorized.    for (BasicBlock *BB : make_range(DFS.beginRPO(), DFS.endRPO())) -    vectorizeBlockInLoop(BB, &PHIsToFix); +    for (Instruction &I : *BB) +      if (!DeadInstructions.count(&I)) +        vectorizeInstruction(I);    // Insert truncates and extends for any truncated instructions as hints to    // InstCombine. @@ -3892,224 +3913,10 @@ void InnerLoopVectorizer::vectorizeLoop() {      truncateToMinimalBitwidths();    // At this point every instruction in the original loop is widened to a -  // vector form. Now we need to fix the recurrences in PHIsToFix. These PHI +  // vector form. Now we need to fix the recurrences in the loop. These PHI    // nodes are currently empty because we did not want to introduce cycles.    // This is the second stage of vectorizing recurrences. -  for (PHINode *Phi : PHIsToFix) { -    assert(Phi && "Unable to recover vectorized PHI"); - -    // Handle first-order recurrences that need to be fixed. -    if (Legal->isFirstOrderRecurrence(Phi)) { -      fixFirstOrderRecurrence(Phi); -      continue; -    } - -    // If the phi node is not a first-order recurrence, it must be a reduction. -    // Get it's reduction variable descriptor. -    assert(Legal->isReductionVariable(Phi) && -           "Unable to find the reduction variable"); -    RecurrenceDescriptor RdxDesc = (*Legal->getReductionVars())[Phi]; - -    RecurrenceDescriptor::RecurrenceKind RK = RdxDesc.getRecurrenceKind(); -    TrackingVH<Value> ReductionStartValue = RdxDesc.getRecurrenceStartValue(); -    Instruction *LoopExitInst = RdxDesc.getLoopExitInstr(); -    RecurrenceDescriptor::MinMaxRecurrenceKind MinMaxKind = -        RdxDesc.getMinMaxRecurrenceKind(); -    setDebugLocFromInst(Builder, ReductionStartValue); - -    // We need to generate a reduction vector from the incoming scalar. -    // To do so, we need to generate the 'identity' vector and override -    // one of the elements with the incoming scalar reduction. We need -    // to do it in the vector-loop preheader. -    Builder.SetInsertPoint(LoopBypassBlocks[1]->getTerminator()); - -    // This is the vector-clone of the value that leaves the loop. -    const VectorParts &VectorExit = getVectorValue(LoopExitInst); -    Type *VecTy = VectorExit[0]->getType(); - -    // Find the reduction identity variable. Zero for addition, or, xor, -    // one for multiplication, -1 for And. -    Value *Identity; -    Value *VectorStart; -    if (RK == RecurrenceDescriptor::RK_IntegerMinMax || -        RK == RecurrenceDescriptor::RK_FloatMinMax) { -      // MinMax reduction have the start value as their identify. 
-      if (VF == 1) { -        VectorStart = Identity = ReductionStartValue; -      } else { -        VectorStart = Identity = -            Builder.CreateVectorSplat(VF, ReductionStartValue, "minmax.ident"); -      } -    } else { -      // Handle other reduction kinds: -      Constant *Iden = RecurrenceDescriptor::getRecurrenceIdentity( -          RK, VecTy->getScalarType()); -      if (VF == 1) { -        Identity = Iden; -        // This vector is the Identity vector where the first element is the -        // incoming scalar reduction. -        VectorStart = ReductionStartValue; -      } else { -        Identity = ConstantVector::getSplat(VF, Iden); - -        // This vector is the Identity vector where the first element is the -        // incoming scalar reduction. -        VectorStart = -            Builder.CreateInsertElement(Identity, ReductionStartValue, Zero); -      } -    } - -    // Fix the vector-loop phi. - -    // Reductions do not have to start at zero. They can start with -    // any loop invariant values. -    const VectorParts &VecRdxPhi = getVectorValue(Phi); -    BasicBlock *Latch = OrigLoop->getLoopLatch(); -    Value *LoopVal = Phi->getIncomingValueForBlock(Latch); -    const VectorParts &Val = getVectorValue(LoopVal); -    for (unsigned part = 0; part < UF; ++part) { -      // Make sure to add the reduction stat value only to the -      // first unroll part. -      Value *StartVal = (part == 0) ? VectorStart : Identity; -      cast<PHINode>(VecRdxPhi[part]) -          ->addIncoming(StartVal, LoopVectorPreHeader); -      cast<PHINode>(VecRdxPhi[part]) -          ->addIncoming(Val[part], LoopVectorBody); -    } - -    // Before each round, move the insertion point right between -    // the PHIs and the values we are going to write. -    // This allows us to write both PHINodes and the extractelement -    // instructions. -    Builder.SetInsertPoint(&*LoopMiddleBlock->getFirstInsertionPt()); - -    VectorParts &RdxParts = VectorLoopValueMap.getVector(LoopExitInst); -    setDebugLocFromInst(Builder, LoopExitInst); - -    // If the vector reduction can be performed in a smaller type, we truncate -    // then extend the loop exit value to enable InstCombine to evaluate the -    // entire expression in the smaller type. -    if (VF > 1 && Phi->getType() != RdxDesc.getRecurrenceType()) { -      Type *RdxVecTy = VectorType::get(RdxDesc.getRecurrenceType(), VF); -      Builder.SetInsertPoint(LoopVectorBody->getTerminator()); -      for (unsigned part = 0; part < UF; ++part) { -        Value *Trunc = Builder.CreateTrunc(RdxParts[part], RdxVecTy); -        Value *Extnd = RdxDesc.isSigned() ? Builder.CreateSExt(Trunc, VecTy) -                                          : Builder.CreateZExt(Trunc, VecTy); -        for (Value::user_iterator UI = RdxParts[part]->user_begin(); -             UI != RdxParts[part]->user_end();) -          if (*UI != Trunc) { -            (*UI++)->replaceUsesOfWith(RdxParts[part], Extnd); -            RdxParts[part] = Extnd; -          } else { -            ++UI; -          } -      } -      Builder.SetInsertPoint(&*LoopMiddleBlock->getFirstInsertionPt()); -      for (unsigned part = 0; part < UF; ++part) -        RdxParts[part] = Builder.CreateTrunc(RdxParts[part], RdxVecTy); -    } - -    // Reduce all of the unrolled parts into a single vector. 
-    Value *ReducedPartRdx = RdxParts[0]; -    unsigned Op = RecurrenceDescriptor::getRecurrenceBinOp(RK); -    setDebugLocFromInst(Builder, ReducedPartRdx); -    for (unsigned part = 1; part < UF; ++part) { -      if (Op != Instruction::ICmp && Op != Instruction::FCmp) -        // Floating point operations had to be 'fast' to enable the reduction. -        ReducedPartRdx = addFastMathFlag( -            Builder.CreateBinOp((Instruction::BinaryOps)Op, RdxParts[part], -                                ReducedPartRdx, "bin.rdx")); -      else -        ReducedPartRdx = RecurrenceDescriptor::createMinMaxOp( -            Builder, MinMaxKind, ReducedPartRdx, RdxParts[part]); -    } - -    if (VF > 1) { -      // VF is a power of 2 so we can emit the reduction using log2(VF) shuffles -      // and vector ops, reducing the set of values being computed by half each -      // round. -      assert(isPowerOf2_32(VF) && -             "Reduction emission only supported for pow2 vectors!"); -      Value *TmpVec = ReducedPartRdx; -      SmallVector<Constant *, 32> ShuffleMask(VF, nullptr); -      for (unsigned i = VF; i != 1; i >>= 1) { -        // Move the upper half of the vector to the lower half. -        for (unsigned j = 0; j != i / 2; ++j) -          ShuffleMask[j] = Builder.getInt32(i / 2 + j); - -        // Fill the rest of the mask with undef. -        std::fill(&ShuffleMask[i / 2], ShuffleMask.end(), -                  UndefValue::get(Builder.getInt32Ty())); - -        Value *Shuf = Builder.CreateShuffleVector( -            TmpVec, UndefValue::get(TmpVec->getType()), -            ConstantVector::get(ShuffleMask), "rdx.shuf"); - -        if (Op != Instruction::ICmp && Op != Instruction::FCmp) -          // Floating point operations had to be 'fast' to enable the reduction. -          TmpVec = addFastMathFlag(Builder.CreateBinOp( -              (Instruction::BinaryOps)Op, TmpVec, Shuf, "bin.rdx")); -        else -          TmpVec = RecurrenceDescriptor::createMinMaxOp(Builder, MinMaxKind, -                                                        TmpVec, Shuf); -      } - -      // The result is in the first element of the vector. -      ReducedPartRdx = -          Builder.CreateExtractElement(TmpVec, Builder.getInt32(0)); - -      // If the reduction can be performed in a smaller type, we need to extend -      // the reduction to the wider type before we branch to the original loop. -      if (Phi->getType() != RdxDesc.getRecurrenceType()) -        ReducedPartRdx = -            RdxDesc.isSigned() -                ? Builder.CreateSExt(ReducedPartRdx, Phi->getType()) -                : Builder.CreateZExt(ReducedPartRdx, Phi->getType()); -    } - -    // Create a phi node that merges control-flow from the backedge-taken check -    // block and the middle block. -    PHINode *BCBlockPhi = PHINode::Create(Phi->getType(), 2, "bc.merge.rdx", -                                          LoopScalarPreHeader->getTerminator()); -    for (unsigned I = 0, E = LoopBypassBlocks.size(); I != E; ++I) -      BCBlockPhi->addIncoming(ReductionStartValue, LoopBypassBlocks[I]); -    BCBlockPhi->addIncoming(ReducedPartRdx, LoopMiddleBlock); - -    // Now, we need to fix the users of the reduction variable -    // inside and outside of the scalar remainder loop. -    // We know that the loop is in LCSSA form. We need to update the -    // PHI nodes in the exit blocks. 
-    for (BasicBlock::iterator LEI = LoopExitBlock->begin(), -                              LEE = LoopExitBlock->end(); -         LEI != LEE; ++LEI) { -      PHINode *LCSSAPhi = dyn_cast<PHINode>(LEI); -      if (!LCSSAPhi) -        break; - -      // All PHINodes need to have a single entry edge, or two if -      // we already fixed them. -      assert(LCSSAPhi->getNumIncomingValues() < 3 && "Invalid LCSSA PHI"); - -      // We found our reduction value exit-PHI. Update it with the -      // incoming bypass edge. -      if (LCSSAPhi->getIncomingValue(0) == LoopExitInst) { -        // Add an edge coming from the bypass. -        LCSSAPhi->addIncoming(ReducedPartRdx, LoopMiddleBlock); -        break; -      } -    } // end of the LCSSA phi scan. - -    // Fix the scalar loop reduction variable with the incoming reduction sum -    // from the vector body and from the backedge value. -    int IncomingEdgeBlockIdx = -        Phi->getBasicBlockIndex(OrigLoop->getLoopLatch()); -    assert(IncomingEdgeBlockIdx >= 0 && "Invalid block index"); -    // Pick the other block. -    int SelfEdgeBlockIdx = (IncomingEdgeBlockIdx ? 0 : 1); -    Phi->setIncomingValue(SelfEdgeBlockIdx, BCBlockPhi); -    Phi->setIncomingValue(IncomingEdgeBlockIdx, LoopExitInst); -  } // end of for each Phi in PHIsToFix. +  fixCrossIterationPHIs();    // Update the dominator tree.    // @@ -4134,6 +3941,25 @@ void InnerLoopVectorizer::vectorizeLoop() {    cse(LoopVectorBody);  } +void InnerLoopVectorizer::fixCrossIterationPHIs() { +  // In order to support recurrences we need to be able to vectorize Phi nodes. +  // Phi nodes have cycles, so we need to vectorize them in two stages. This is +  // stage #2: We now need to fix the recurrences by adding incoming edges to +  // the currently empty PHI nodes. At this point every instruction in the +  // original loop is widened to a vector form so we can use them to construct +  // the incoming edges. +  for (Instruction &I : *OrigLoop->getHeader()) { +    PHINode *Phi = dyn_cast<PHINode>(&I); +    if (!Phi) +      break; +    // Handle first-order recurrences and reductions that need to be fixed. +    if (Legal->isFirstOrderRecurrence(Phi)) +      fixFirstOrderRecurrence(Phi); +    else if (Legal->isReductionVariable(Phi)) +      fixReduction(Phi); +  } +} +  void InnerLoopVectorizer::fixFirstOrderRecurrence(PHINode *Phi) {    // This is the second phase of vectorizing first-order recurrences. An @@ -4212,15 +4038,17 @@ void InnerLoopVectorizer::fixFirstOrderRecurrence(PHINode *Phi) {    auto *VecPhi = Builder.CreatePHI(VectorInit->getType(), 2, "vector.recur");    VecPhi->addIncoming(VectorInit, LoopVectorPreHeader); -  // Get the vectorized previous value. We ensured the previous values was an -  // instruction when detecting the recurrence. +  // Get the vectorized previous value.    auto &PreviousParts = getVectorValue(Previous); -  // Set the insertion point to be after this instruction. We ensured the -  // previous value dominated all uses of the phi when detecting the -  // recurrence. -  Builder.SetInsertPoint( -      &*++BasicBlock::iterator(cast<Instruction>(PreviousParts[UF - 1]))); +  // Set the insertion point after the previous value if it is an instruction. +  // Note that the previous value may have been constant-folded so it is not +  // guaranteed to be an instruction in the vector loop. 
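For a first-order recurrence, each vector lane must see the value produced one scalar iteration earlier. That is obtained by shuffling the previous iteration's vector together with the current one; conceptually the combining mask is <VF-1, VF, VF+1, ..., 2*VF-2>. A standalone illustration at VF = 4 with made-up lane values (plain arrays, not the IRBuilder code):

#include <cstdio>

int main() {
  const unsigned VF = 4;
  // Values of the recurrence source in two consecutive vector iterations.
  int Previous[VF] = {10, 11, 12, 13};
  int Current[VF]  = {14, 15, 16, 17};

  // shuffle(Previous, Current, <VF-1, VF, VF+1, VF+2>): lane 0 takes the last
  // element of Previous, lanes 1..VF-1 take the first VF-1 elements of
  // Current, so every lane sees the value one scalar iteration behind.
  int Recur[VF];
  for (unsigned Lane = 0; Lane < VF; ++Lane) {
    unsigned Idx = (VF - 1) + Lane;                 // mask element
    Recur[Lane] = Idx < VF ? Previous[Idx] : Current[Idx - VF];
  }

  for (unsigned Lane = 0; Lane < VF; ++Lane)
    std::printf("%d ", Recur[Lane]);                // 13 14 15 16
  std::printf("\n");
  return 0;
}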
+  if (LI->getLoopFor(LoopVectorBody)->isLoopInvariant(PreviousParts[UF - 1])) +    Builder.SetInsertPoint(&*LoopVectorBody->getFirstInsertionPt()); +  else +    Builder.SetInsertPoint( +        &*++BasicBlock::iterator(cast<Instruction>(PreviousParts[UF - 1])));    // We will construct a vector for the recurrence by combining the values for    // the current and previous iterations. This is the required shuffle mask. @@ -4251,18 +4079,33 @@ void InnerLoopVectorizer::fixFirstOrderRecurrence(PHINode *Phi) {    // Extract the last vector element in the middle block. This will be the    // initial value for the recurrence when jumping to the scalar loop. -  auto *Extract = Incoming; +  auto *ExtractForScalar = Incoming;    if (VF > 1) {      Builder.SetInsertPoint(LoopMiddleBlock->getTerminator()); -    Extract = Builder.CreateExtractElement(Extract, Builder.getInt32(VF - 1), -                                           "vector.recur.extract"); -  } +    ExtractForScalar = Builder.CreateExtractElement( +        ExtractForScalar, Builder.getInt32(VF - 1), "vector.recur.extract"); +  } +  // Extract the second last element in the middle block if the +  // Phi is used outside the loop. We need to extract the phi itself +  // and not the last element (the phi update in the current iteration). This +  // will be the value when jumping to the exit block from the LoopMiddleBlock, +  // when the scalar loop is not run at all. +  Value *ExtractForPhiUsedOutsideLoop = nullptr; +  if (VF > 1) +    ExtractForPhiUsedOutsideLoop = Builder.CreateExtractElement( +        Incoming, Builder.getInt32(VF - 2), "vector.recur.extract.for.phi"); +  // When loop is unrolled without vectorizing, initialize +  // ExtractForPhiUsedOutsideLoop with the value just prior to unrolled value of +  // `Incoming`. This is analogous to the vectorized case above: extracting the +  // second last element when VF > 1. +  else if (UF > 1) +    ExtractForPhiUsedOutsideLoop = PreviousParts[UF - 2];    // Fix the initial value of the original recurrence in the scalar loop.    Builder.SetInsertPoint(&*LoopScalarPreHeader->begin());    auto *Start = Builder.CreatePHI(Phi->getType(), 2, "scalar.recur.init");    for (auto *BB : predecessors(LoopScalarPreHeader)) { -    auto *Incoming = BB == LoopMiddleBlock ? Extract : ScalarInit; +    auto *Incoming = BB == LoopMiddleBlock ? ExtractForScalar : ScalarInit;      Start->addIncoming(Incoming, BB);    } @@ -4279,12 +4122,218 @@ void InnerLoopVectorizer::fixFirstOrderRecurrence(PHINode *Phi) {      if (!LCSSAPhi)        break;      if (LCSSAPhi->getIncomingValue(0) == Phi) { -      LCSSAPhi->addIncoming(Extract, LoopMiddleBlock); +      LCSSAPhi->addIncoming(ExtractForPhiUsedOutsideLoop, LoopMiddleBlock);        break;      }    }  } +void InnerLoopVectorizer::fixReduction(PHINode *Phi) { +  Constant *Zero = Builder.getInt32(0); + +  // Get it's reduction variable descriptor. +  assert(Legal->isReductionVariable(Phi) && +         "Unable to find the reduction variable"); +  RecurrenceDescriptor RdxDesc = (*Legal->getReductionVars())[Phi]; + +  RecurrenceDescriptor::RecurrenceKind RK = RdxDesc.getRecurrenceKind(); +  TrackingVH<Value> ReductionStartValue = RdxDesc.getRecurrenceStartValue(); +  Instruction *LoopExitInst = RdxDesc.getLoopExitInstr(); +  RecurrenceDescriptor::MinMaxRecurrenceKind MinMaxKind = +    RdxDesc.getMinMaxRecurrenceKind(); +  setDebugLocFromInst(Builder, ReductionStartValue); + +  // We need to generate a reduction vector from the incoming scalar. 
+  // To do so, we need to generate the 'identity' vector and override
+  // one of the elements with the incoming scalar reduction. We need
+  // to do it in the vector-loop preheader.
+  Builder.SetInsertPoint(LoopBypassBlocks[1]->getTerminator());
+
+  // This is the vector-clone of the value that leaves the loop.
+  const VectorParts &VectorExit = getVectorValue(LoopExitInst);
+  Type *VecTy = VectorExit[0]->getType();
+
+  // Find the reduction identity variable. Zero for addition, or, xor,
+  // one for multiplication, -1 for And.
+  Value *Identity;
+  Value *VectorStart;
+  if (RK == RecurrenceDescriptor::RK_IntegerMinMax ||
+      RK == RecurrenceDescriptor::RK_FloatMinMax) {
+    // MinMax reductions have the start value as their identity.
+    if (VF == 1) {
+      VectorStart = Identity = ReductionStartValue;
+    } else {
+      VectorStart = Identity =
+        Builder.CreateVectorSplat(VF, ReductionStartValue, "minmax.ident");
+    }
+  } else {
+    // Handle other reduction kinds:
+    Constant *Iden = RecurrenceDescriptor::getRecurrenceIdentity(
+        RK, VecTy->getScalarType());
+    if (VF == 1) {
+      Identity = Iden;
+      // This vector is the Identity vector where the first element is the
+      // incoming scalar reduction.
+      VectorStart = ReductionStartValue;
+    } else {
+      Identity = ConstantVector::getSplat(VF, Iden);
+
+      // This vector is the Identity vector where the first element is the
+      // incoming scalar reduction.
+      VectorStart =
+        Builder.CreateInsertElement(Identity, ReductionStartValue, Zero);
+    }
+  }
+
+  // Fix the vector-loop phi.
+
+  // Reductions do not have to start at zero. They can start with
+  // any loop invariant values.
+  const VectorParts &VecRdxPhi = getVectorValue(Phi);
+  BasicBlock *Latch = OrigLoop->getLoopLatch();
+  Value *LoopVal = Phi->getIncomingValueForBlock(Latch);
+  const VectorParts &Val = getVectorValue(LoopVal);
+  for (unsigned part = 0; part < UF; ++part) {
+    // Make sure to add the reduction start value only to the
+    // first unroll part.
+    Value *StartVal = (part == 0) ? VectorStart : Identity;
+    cast<PHINode>(VecRdxPhi[part])
+      ->addIncoming(StartVal, LoopVectorPreHeader);
+    cast<PHINode>(VecRdxPhi[part])
+      ->addIncoming(Val[part], LI->getLoopFor(LoopVectorBody)->getLoopLatch());
+  }
+
+  // Before each round, move the insertion point right between
+  // the PHIs and the values we are going to write.
+  // This allows us to write both PHINodes and the extractelement
+  // instructions.
+  Builder.SetInsertPoint(&*LoopMiddleBlock->getFirstInsertionPt());
+
+  VectorParts &RdxParts = VectorLoopValueMap.getVector(LoopExitInst);
+  setDebugLocFromInst(Builder, LoopExitInst);
+
+  // If the vector reduction can be performed in a smaller type, we truncate
+  // then extend the loop exit value to enable InstCombine to evaluate the
+  // entire expression in the smaller type.
+  if (VF > 1 && Phi->getType() != RdxDesc.getRecurrenceType()) {
+    Type *RdxVecTy = VectorType::get(RdxDesc.getRecurrenceType(), VF);
+    Builder.SetInsertPoint(LoopVectorBody->getTerminator());
+    for (unsigned part = 0; part < UF; ++part) {
+      Value *Trunc = Builder.CreateTrunc(RdxParts[part], RdxVecTy);
+      Value *Extnd = RdxDesc.isSigned() ?
Builder.CreateSExt(Trunc, VecTy) +        : Builder.CreateZExt(Trunc, VecTy); +      for (Value::user_iterator UI = RdxParts[part]->user_begin(); +           UI != RdxParts[part]->user_end();) +        if (*UI != Trunc) { +          (*UI++)->replaceUsesOfWith(RdxParts[part], Extnd); +          RdxParts[part] = Extnd; +        } else { +          ++UI; +        } +    } +    Builder.SetInsertPoint(&*LoopMiddleBlock->getFirstInsertionPt()); +    for (unsigned part = 0; part < UF; ++part) +      RdxParts[part] = Builder.CreateTrunc(RdxParts[part], RdxVecTy); +  } + +  // Reduce all of the unrolled parts into a single vector. +  Value *ReducedPartRdx = RdxParts[0]; +  unsigned Op = RecurrenceDescriptor::getRecurrenceBinOp(RK); +  setDebugLocFromInst(Builder, ReducedPartRdx); +  for (unsigned part = 1; part < UF; ++part) { +    if (Op != Instruction::ICmp && Op != Instruction::FCmp) +      // Floating point operations had to be 'fast' to enable the reduction. +      ReducedPartRdx = addFastMathFlag( +          Builder.CreateBinOp((Instruction::BinaryOps)Op, RdxParts[part], +                              ReducedPartRdx, "bin.rdx")); +    else +      ReducedPartRdx = RecurrenceDescriptor::createMinMaxOp( +          Builder, MinMaxKind, ReducedPartRdx, RdxParts[part]); +  } + +  if (VF > 1) { +    // VF is a power of 2 so we can emit the reduction using log2(VF) shuffles +    // and vector ops, reducing the set of values being computed by half each +    // round. +    assert(isPowerOf2_32(VF) && +           "Reduction emission only supported for pow2 vectors!"); +    Value *TmpVec = ReducedPartRdx; +    SmallVector<Constant *, 32> ShuffleMask(VF, nullptr); +    for (unsigned i = VF; i != 1; i >>= 1) { +      // Move the upper half of the vector to the lower half. +      for (unsigned j = 0; j != i / 2; ++j) +        ShuffleMask[j] = Builder.getInt32(i / 2 + j); + +      // Fill the rest of the mask with undef. +      std::fill(&ShuffleMask[i / 2], ShuffleMask.end(), +                UndefValue::get(Builder.getInt32Ty())); + +      Value *Shuf = Builder.CreateShuffleVector( +          TmpVec, UndefValue::get(TmpVec->getType()), +          ConstantVector::get(ShuffleMask), "rdx.shuf"); + +      if (Op != Instruction::ICmp && Op != Instruction::FCmp) +        // Floating point operations had to be 'fast' to enable the reduction. +        TmpVec = addFastMathFlag(Builder.CreateBinOp( +                                     (Instruction::BinaryOps)Op, TmpVec, Shuf, "bin.rdx")); +      else +        TmpVec = RecurrenceDescriptor::createMinMaxOp(Builder, MinMaxKind, +                                                      TmpVec, Shuf); +    } + +    // The result is in the first element of the vector. +    ReducedPartRdx = +      Builder.CreateExtractElement(TmpVec, Builder.getInt32(0)); + +    // If the reduction can be performed in a smaller type, we need to extend +    // the reduction to the wider type before we branch to the original loop. +    if (Phi->getType() != RdxDesc.getRecurrenceType()) +      ReducedPartRdx = +        RdxDesc.isSigned() +        ? Builder.CreateSExt(ReducedPartRdx, Phi->getType()) +        : Builder.CreateZExt(ReducedPartRdx, Phi->getType()); +  } + +  // Create a phi node that merges control-flow from the backedge-taken check +  // block and the middle block. 
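The reduction epilogue built here halves the live width log2(VF) times: each round folds the upper half of the vector onto the lower half with a shuffle plus the reduction operation, leaving the final value in lane 0. A standalone add-reduction sketch at VF = 4 (plain arrays instead of shuffles):

#include <cstdio>

int main() {
  const unsigned VF = 4;
  int Tmp[VF] = {1, 2, 3, 4};          // already-combined unroll parts

  // log2(VF) rounds: move the upper half onto the lower half and combine.
  for (unsigned Width = VF; Width != 1; Width >>= 1)
    for (unsigned j = 0; j != Width / 2; ++j)
      Tmp[j] += Tmp[Width / 2 + j];    // "rdx.shuf" followed by "bin.rdx"

  std::printf("reduced value (lane 0): %d\n", Tmp[0]);   // 10
  return 0;
}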
+  PHINode *BCBlockPhi = PHINode::Create(Phi->getType(), 2, "bc.merge.rdx", +                                        LoopScalarPreHeader->getTerminator()); +  for (unsigned I = 0, E = LoopBypassBlocks.size(); I != E; ++I) +    BCBlockPhi->addIncoming(ReductionStartValue, LoopBypassBlocks[I]); +  BCBlockPhi->addIncoming(ReducedPartRdx, LoopMiddleBlock); + +  // Now, we need to fix the users of the reduction variable +  // inside and outside of the scalar remainder loop. +  // We know that the loop is in LCSSA form. We need to update the +  // PHI nodes in the exit blocks. +  for (BasicBlock::iterator LEI = LoopExitBlock->begin(), +         LEE = LoopExitBlock->end(); +       LEI != LEE; ++LEI) { +    PHINode *LCSSAPhi = dyn_cast<PHINode>(LEI); +    if (!LCSSAPhi) +      break; + +    // All PHINodes need to have a single entry edge, or two if +    // we already fixed them. +    assert(LCSSAPhi->getNumIncomingValues() < 3 && "Invalid LCSSA PHI"); + +    // We found a reduction value exit-PHI. Update it with the +    // incoming bypass edge. +    if (LCSSAPhi->getIncomingValue(0) == LoopExitInst) +      LCSSAPhi->addIncoming(ReducedPartRdx, LoopMiddleBlock); +  } // end of the LCSSA phi scan. + +    // Fix the scalar loop reduction variable with the incoming reduction sum +    // from the vector body and from the backedge value. +  int IncomingEdgeBlockIdx = +    Phi->getBasicBlockIndex(OrigLoop->getLoopLatch()); +  assert(IncomingEdgeBlockIdx >= 0 && "Invalid block index"); +  // Pick the other block. +  int SelfEdgeBlockIdx = (IncomingEdgeBlockIdx ? 0 : 1); +  Phi->setIncomingValue(SelfEdgeBlockIdx, BCBlockPhi); +  Phi->setIncomingValue(IncomingEdgeBlockIdx, LoopExitInst); +} +  void InnerLoopVectorizer::fixLCSSAPHIs() {    for (Instruction &LEI : *LoopExitBlock) {      auto *LCSSAPhi = dyn_cast<PHINode>(&LEI); @@ -4296,7 +4345,8 @@ void InnerLoopVectorizer::fixLCSSAPHIs() {    }  } -void InnerLoopVectorizer::collectTriviallyDeadInstructions() { +void InnerLoopVectorizer::collectTriviallyDeadInstructions( +    SmallPtrSetImpl<Instruction *> &DeadInstructions) {    BasicBlock *Latch = OrigLoop->getLoopLatch();    // We create new control-flow for the vectorized loop, so the original @@ -4563,9 +4613,12 @@ InnerLoopVectorizer::createBlockInMask(BasicBlock *BB) {  }  void InnerLoopVectorizer::widenPHIInstruction(Instruction *PN, unsigned UF, -                                              unsigned VF, PhiVector *PV) { +                                              unsigned VF) {    PHINode *P = cast<PHINode>(PN); -  // Handle recurrences. +  // In order to support recurrences we need to be able to vectorize Phi nodes. +  // Phi nodes have cycles, so we need to vectorize them in two stages. This is +  // stage #1: We create a new vector PHI node with no incoming edges. We'll use +  // this value when we vectorize all of the instructions that use the PHI.    
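// Standalone illustration (not part of the patch above) of the two-stage PHI
// handling the comment describes, written against the plain IRBuilder API:
// stage 1 creates a vector PHI with no incoming edges so later instructions
// can already use it; stage 2 adds the incoming values once they exist. The
// function name, block names, and types are invented for the example, and the
// calls assume the in-tree headers of this era.
#include "llvm/IR/IRBuilder.h"
#include "llvm/IR/LLVMContext.h"
#include "llvm/IR/Module.h"
#include <iterator>

static llvm::Function *makeTwoStagePhi(llvm::Module &M) {
  using namespace llvm;
  LLVMContext &Ctx = M.getContext();
  auto *VecTy = VectorType::get(Type::getInt32Ty(Ctx), 4);
  auto *FnTy = FunctionType::get(VecTy, {VecTy, Type::getInt1Ty(Ctx)},
                                 /*isVarArg=*/false);
  auto *F = Function::Create(FnTy, Function::ExternalLinkage, "two_stage", &M);
  Argument *Step = &*F->arg_begin();
  Argument *Done = &*std::next(F->arg_begin());

  auto *Preheader = BasicBlock::Create(Ctx, "vector.ph", F);
  auto *Body = BasicBlock::Create(Ctx, "vector.body", F);
  auto *Middle = BasicBlock::Create(Ctx, "middle.block", F);

  IRBuilder<> B(Preheader);
  B.CreateBr(Body);

  // Stage 1: an empty "vec.phi" that the widened loop body can use right away.
  B.SetInsertPoint(Body);
  PHINode *VecPhi = B.CreatePHI(VecTy, 2, "vec.phi");
  Value *Next = B.CreateAdd(VecPhi, Step, "vec.next");
  B.CreateCondBr(Done, Middle, Body);

  B.SetInsertPoint(Middle);
  B.CreateRet(Next);

  // Stage 2: wire the incoming edges now that the loop-carried value exists.
  VecPhi->addIncoming(Constant::getNullValue(VecTy), Preheader);
  VecPhi->addIncoming(Next, Body);
  return F;
}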
if (Legal->isReductionVariable(P) || Legal->isFirstOrderRecurrence(P)) {      VectorParts Entry(UF);      for (unsigned part = 0; part < UF; ++part) { @@ -4576,7 +4629,6 @@ void InnerLoopVectorizer::widenPHIInstruction(Instruction *PN, unsigned UF,            VecTy, 2, "vec.phi", &*LoopVectorBody->getFirstInsertionPt());      }      VectorLoopValueMap.initVector(P, Entry); -    PV->push_back(P);      return;    } @@ -4631,7 +4683,8 @@ void InnerLoopVectorizer::widenPHIInstruction(Instruction *PN, unsigned UF,    case InductionDescriptor::IK_NoInduction:      llvm_unreachable("Unknown induction");    case InductionDescriptor::IK_IntInduction: -    return widenIntInduction(P); +  case InductionDescriptor::IK_FpInduction: +    return widenIntOrFpInduction(P);    case InductionDescriptor::IK_PtrInduction: {      // Handle the pointer induction variable case.      assert(P->getType()->isPointerTy() && "Unexpected type."); @@ -4641,7 +4694,7 @@ void InnerLoopVectorizer::widenPHIInstruction(Instruction *PN, unsigned UF,      // Determine the number of scalars we need to generate for each unroll      // iteration. If the instruction is uniform, we only need to generate the      // first lane. Otherwise, we generate all VF values. -    unsigned Lanes = Legal->isUniformAfterVectorization(P) ? 1 : VF; +    unsigned Lanes = Cost->isUniformAfterVectorization(P, VF) ? 1 : VF;      // These are the scalar results. Notice that we don't generate vector GEPs      // because scalar GEPs result in better code.      ScalarParts Entry(UF); @@ -4658,30 +4711,6 @@ void InnerLoopVectorizer::widenPHIInstruction(Instruction *PN, unsigned UF,      VectorLoopValueMap.initScalar(P, Entry);      return;    } -  case InductionDescriptor::IK_FpInduction: { -    assert(P->getType() == II.getStartValue()->getType() && -           "Types must match"); -    // Handle other induction variables that are now based on the -    // canonical one. -    assert(P != OldInduction && "Primary induction can be integer only"); - -    Value *V = Builder.CreateCast(Instruction::SIToFP, Induction, P->getType()); -    V = II.transform(Builder, V, PSE.getSE(), DL); -    V->setName("fp.offset.idx"); - -    // Now we have scalar op: %fp.offset.idx = StartVal +/- Induction*StepVal - -    Value *Broadcasted = getBroadcastInstrs(V); -    // After broadcasting the induction variable we need to make the vector -    // consecutive by adding StepVal*0, StepVal*1, StepVal*2, etc. -    Value *StepVal = cast<SCEVUnknown>(II.getStep())->getValue(); -    VectorParts Entry(UF); -    for (unsigned part = 0; part < UF; ++part) -      Entry[part] = getStepVector(Broadcasted, VF * part, StepVal, -                                  II.getInductionOpcode()); -    VectorLoopValueMap.initVector(P, Entry); -    return; -  }    }  } @@ -4703,269 +4732,323 @@ static bool mayDivideByZero(Instruction &I) {    return !CInt || CInt->isZero();  } -void InnerLoopVectorizer::vectorizeBlockInLoop(BasicBlock *BB, PhiVector *PV) { -  // For each instruction in the old loop. -  for (Instruction &I : *BB) { - -    // If the instruction will become trivially dead when vectorized, we don't -    // need to generate it. -    if (DeadInstructions.count(&I)) -      continue; +void InnerLoopVectorizer::vectorizeInstruction(Instruction &I) { +  // Scalarize instructions that should remain scalar after vectorization. 
+  if (VF > 1 && +      !(isa<BranchInst>(&I) || isa<PHINode>(&I) || isa<DbgInfoIntrinsic>(&I)) && +      shouldScalarizeInstruction(&I)) { +    scalarizeInstruction(&I, Legal->isScalarWithPredication(&I)); +    return; +  } -    // Scalarize instructions that should remain scalar after vectorization. -    if (VF > 1 && -        !(isa<BranchInst>(&I) || isa<PHINode>(&I) || -          isa<DbgInfoIntrinsic>(&I)) && -        shouldScalarizeInstruction(&I)) { -      scalarizeInstruction(&I, Legal->isScalarWithPredication(&I)); -      continue; -    } +  switch (I.getOpcode()) { +  case Instruction::Br: +    // Nothing to do for PHIs and BR, since we already took care of the +    // loop control flow instructions. +    break; +  case Instruction::PHI: { +    // Vectorize PHINodes. +    widenPHIInstruction(&I, UF, VF); +    break; +  } // End of PHI. +  case Instruction::GetElementPtr: { +    // Construct a vector GEP by widening the operands of the scalar GEP as +    // necessary. We mark the vector GEP 'inbounds' if appropriate. A GEP +    // results in a vector of pointers when at least one operand of the GEP +    // is vector-typed. Thus, to keep the representation compact, we only use +    // vector-typed operands for loop-varying values. +    auto *GEP = cast<GetElementPtrInst>(&I); +    VectorParts Entry(UF); -    switch (I.getOpcode()) { -    case Instruction::Br: -      // Nothing to do for PHIs and BR, since we already took care of the -      // loop control flow instructions. -      continue; -    case Instruction::PHI: { -      // Vectorize PHINodes. -      widenPHIInstruction(&I, UF, VF, PV); -      continue; -    } // End of PHI. - -    case Instruction::UDiv: -    case Instruction::SDiv: -    case Instruction::SRem: -    case Instruction::URem: -      // Scalarize with predication if this instruction may divide by zero and -      // block execution is conditional, otherwise fallthrough. -      if (Legal->isScalarWithPredication(&I)) { -        scalarizeInstruction(&I, true); -        continue; -      } -    case Instruction::Add: -    case Instruction::FAdd: -    case Instruction::Sub: -    case Instruction::FSub: -    case Instruction::Mul: -    case Instruction::FMul: -    case Instruction::FDiv: -    case Instruction::FRem: -    case Instruction::Shl: -    case Instruction::LShr: -    case Instruction::AShr: -    case Instruction::And: -    case Instruction::Or: -    case Instruction::Xor: { -      // Just widen binops. -      auto *BinOp = cast<BinaryOperator>(&I); -      setDebugLocFromInst(Builder, BinOp); -      const VectorParts &A = getVectorValue(BinOp->getOperand(0)); -      const VectorParts &B = getVectorValue(BinOp->getOperand(1)); - -      // Use this vector value for all users of the original instruction. -      VectorParts Entry(UF); +    if (VF > 1 && OrigLoop->hasLoopInvariantOperands(GEP)) { +      // If we are vectorizing, but the GEP has only loop-invariant operands, +      // the GEP we build (by only using vector-typed operands for +      // loop-varying values) would be a scalar pointer. Thus, to ensure we +      // produce a vector of pointers, we need to either arbitrarily pick an +      // operand to broadcast, or broadcast a clone of the original GEP. +      // Here, we broadcast a clone of the original. +      // +      // TODO: If at some point we decide to scalarize instructions having +      //       loop-invariant operands, this special case will no longer be +      //       required. 
We would add the scalarization decision to +      //       collectLoopScalars() and teach getVectorValue() to broadcast +      //       the lane-zero scalar value. +      auto *Clone = Builder.Insert(GEP->clone()); +      for (unsigned Part = 0; Part < UF; ++Part) +        Entry[Part] = Builder.CreateVectorSplat(VF, Clone); +    } else { +      // If the GEP has at least one loop-varying operand, we are sure to +      // produce a vector of pointers. But if we are only unrolling, we want +      // to produce a scalar GEP for each unroll part. Thus, the GEP we +      // produce with the code below will be scalar (if VF == 1) or vector +      // (otherwise). Note that for the unroll-only case, we still maintain +      // values in the vector mapping with initVector, as we do for other +      // instructions.        for (unsigned Part = 0; Part < UF; ++Part) { -        Value *V = Builder.CreateBinOp(BinOp->getOpcode(), A[Part], B[Part]); -        if (BinaryOperator *VecOp = dyn_cast<BinaryOperator>(V)) -          VecOp->copyIRFlags(BinOp); +        // The pointer operand of the new GEP. If it's loop-invariant, we +        // won't broadcast it. +        auto *Ptr = OrigLoop->isLoopInvariant(GEP->getPointerOperand()) +                        ? GEP->getPointerOperand() +                        : getVectorValue(GEP->getPointerOperand())[Part]; + +        // Collect all the indices for the new GEP. If any index is +        // loop-invariant, we won't broadcast it. +        SmallVector<Value *, 4> Indices; +        for (auto &U : make_range(GEP->idx_begin(), GEP->idx_end())) { +          if (OrigLoop->isLoopInvariant(U.get())) +            Indices.push_back(U.get()); +          else +            Indices.push_back(getVectorValue(U.get())[Part]); +        } -        Entry[Part] = V; +        // Create the new GEP. Note that this GEP may be a scalar if VF == 1, +        // but it should be a vector, otherwise. +        auto *NewGEP = GEP->isInBounds() +                           ? Builder.CreateInBoundsGEP(Ptr, Indices) +                           : Builder.CreateGEP(Ptr, Indices); +        assert((VF == 1 || NewGEP->getType()->isVectorTy()) && +               "NewGEP is not a pointer vector"); +        Entry[Part] = NewGEP;        } +    } -      VectorLoopValueMap.initVector(&I, Entry); -      addMetadata(Entry, BinOp); +    VectorLoopValueMap.initVector(&I, Entry); +    addMetadata(Entry, GEP); +    break; +  } +  case Instruction::UDiv: +  case Instruction::SDiv: +  case Instruction::SRem: +  case Instruction::URem: +    // Scalarize with predication if this instruction may divide by zero and +    // block execution is conditional, otherwise fallthrough. +    if (Legal->isScalarWithPredication(&I)) { +      scalarizeInstruction(&I, true);        break;      } -    case Instruction::Select: { -      // Widen selects. -      // If the selector is loop invariant we can create a select -      // instruction with a scalar condition. Otherwise, use vector-select. -      auto *SE = PSE.getSE(); -      bool InvariantCond = -          SE->isLoopInvariant(PSE.getSCEV(I.getOperand(0)), OrigLoop); -      setDebugLocFromInst(Builder, &I); - -      // The condition can be loop invariant  but still defined inside the -      // loop. This means that we can't just use the original 'cond' value. -      // We have to take the 'vectorized' value and pick the first lane. -      // Instcombine will make this a no-op. 
-      const VectorParts &Cond = getVectorValue(I.getOperand(0)); -      const VectorParts &Op0 = getVectorValue(I.getOperand(1)); -      const VectorParts &Op1 = getVectorValue(I.getOperand(2)); - -      auto *ScalarCond = getScalarValue(I.getOperand(0), 0, 0); +  case Instruction::Add: +  case Instruction::FAdd: +  case Instruction::Sub: +  case Instruction::FSub: +  case Instruction::Mul: +  case Instruction::FMul: +  case Instruction::FDiv: +  case Instruction::FRem: +  case Instruction::Shl: +  case Instruction::LShr: +  case Instruction::AShr: +  case Instruction::And: +  case Instruction::Or: +  case Instruction::Xor: { +    // Just widen binops. +    auto *BinOp = cast<BinaryOperator>(&I); +    setDebugLocFromInst(Builder, BinOp); +    const VectorParts &A = getVectorValue(BinOp->getOperand(0)); +    const VectorParts &B = getVectorValue(BinOp->getOperand(1)); + +    // Use this vector value for all users of the original instruction. +    VectorParts Entry(UF); +    for (unsigned Part = 0; Part < UF; ++Part) { +      Value *V = Builder.CreateBinOp(BinOp->getOpcode(), A[Part], B[Part]); + +      if (BinaryOperator *VecOp = dyn_cast<BinaryOperator>(V)) +        VecOp->copyIRFlags(BinOp); + +      Entry[Part] = V; +    } -      VectorParts Entry(UF); -      for (unsigned Part = 0; Part < UF; ++Part) { -        Entry[Part] = Builder.CreateSelect( -            InvariantCond ? ScalarCond : Cond[Part], Op0[Part], Op1[Part]); -      } +    VectorLoopValueMap.initVector(&I, Entry); +    addMetadata(Entry, BinOp); +    break; +  } +  case Instruction::Select: { +    // Widen selects. +    // If the selector is loop invariant we can create a select +    // instruction with a scalar condition. Otherwise, use vector-select. +    auto *SE = PSE.getSE(); +    bool InvariantCond = +        SE->isLoopInvariant(PSE.getSCEV(I.getOperand(0)), OrigLoop); +    setDebugLocFromInst(Builder, &I); + +    // The condition can be loop invariant  but still defined inside the +    // loop. This means that we can't just use the original 'cond' value. +    // We have to take the 'vectorized' value and pick the first lane. +    // Instcombine will make this a no-op. +    const VectorParts &Cond = getVectorValue(I.getOperand(0)); +    const VectorParts &Op0 = getVectorValue(I.getOperand(1)); +    const VectorParts &Op1 = getVectorValue(I.getOperand(2)); + +    auto *ScalarCond = getScalarValue(I.getOperand(0), 0, 0); -      VectorLoopValueMap.initVector(&I, Entry); -      addMetadata(Entry, &I); -      break; +    VectorParts Entry(UF); +    for (unsigned Part = 0; Part < UF; ++Part) { +      Entry[Part] = Builder.CreateSelect( +          InvariantCond ? ScalarCond : Cond[Part], Op0[Part], Op1[Part]);      } -    case Instruction::ICmp: -    case Instruction::FCmp: { -      // Widen compares. Generate vector compares. 
-      bool FCmp = (I.getOpcode() == Instruction::FCmp); -      auto *Cmp = dyn_cast<CmpInst>(&I); -      setDebugLocFromInst(Builder, Cmp); -      const VectorParts &A = getVectorValue(Cmp->getOperand(0)); -      const VectorParts &B = getVectorValue(Cmp->getOperand(1)); -      VectorParts Entry(UF); -      for (unsigned Part = 0; Part < UF; ++Part) { -        Value *C = nullptr; -        if (FCmp) { -          C = Builder.CreateFCmp(Cmp->getPredicate(), A[Part], B[Part]); -          cast<FCmpInst>(C)->copyFastMathFlags(Cmp); -        } else { -          C = Builder.CreateICmp(Cmp->getPredicate(), A[Part], B[Part]); -        } -        Entry[Part] = C; +    VectorLoopValueMap.initVector(&I, Entry); +    addMetadata(Entry, &I); +    break; +  } + +  case Instruction::ICmp: +  case Instruction::FCmp: { +    // Widen compares. Generate vector compares. +    bool FCmp = (I.getOpcode() == Instruction::FCmp); +    auto *Cmp = dyn_cast<CmpInst>(&I); +    setDebugLocFromInst(Builder, Cmp); +    const VectorParts &A = getVectorValue(Cmp->getOperand(0)); +    const VectorParts &B = getVectorValue(Cmp->getOperand(1)); +    VectorParts Entry(UF); +    for (unsigned Part = 0; Part < UF; ++Part) { +      Value *C = nullptr; +      if (FCmp) { +        C = Builder.CreateFCmp(Cmp->getPredicate(), A[Part], B[Part]); +        cast<FCmpInst>(C)->copyFastMathFlags(Cmp); +      } else { +        C = Builder.CreateICmp(Cmp->getPredicate(), A[Part], B[Part]);        } +      Entry[Part] = C; +    } + +    VectorLoopValueMap.initVector(&I, Entry); +    addMetadata(Entry, &I); +    break; +  } -      VectorLoopValueMap.initVector(&I, Entry); -      addMetadata(Entry, &I); +  case Instruction::Store: +  case Instruction::Load: +    vectorizeMemoryInstruction(&I); +    break; +  case Instruction::ZExt: +  case Instruction::SExt: +  case Instruction::FPToUI: +  case Instruction::FPToSI: +  case Instruction::FPExt: +  case Instruction::PtrToInt: +  case Instruction::IntToPtr: +  case Instruction::SIToFP: +  case Instruction::UIToFP: +  case Instruction::Trunc: +  case Instruction::FPTrunc: +  case Instruction::BitCast: { +    auto *CI = dyn_cast<CastInst>(&I); +    setDebugLocFromInst(Builder, CI); + +    // Optimize the special case where the source is a constant integer +    // induction variable. Notice that we can only optimize the 'trunc' case +    // because (a) FP conversions lose precision, (b) sext/zext may wrap, and +    // (c) other casts depend on pointer size. +    if (Cost->isOptimizableIVTruncate(CI, VF)) { +      widenIntOrFpInduction(cast<PHINode>(CI->getOperand(0)), +                            cast<TruncInst>(CI));        break;      } -    case Instruction::Store: -    case Instruction::Load: -      vectorizeMemoryInstruction(&I); +    /// Vectorize casts. +    Type *DestTy = +        (VF == 1) ? CI->getType() : VectorType::get(CI->getType(), VF); + +    const VectorParts &A = getVectorValue(CI->getOperand(0)); +    VectorParts Entry(UF); +    for (unsigned Part = 0; Part < UF; ++Part) +      Entry[Part] = Builder.CreateCast(CI->getOpcode(), A[Part], DestTy); +    VectorLoopValueMap.initVector(&I, Entry); +    addMetadata(Entry, &I); +    break; +  } + +  case Instruction::Call: { +    // Ignore dbg intrinsics. 
+    if (isa<DbgInfoIntrinsic>(I))        break; -    case Instruction::ZExt: -    case Instruction::SExt: -    case Instruction::FPToUI: -    case Instruction::FPToSI: -    case Instruction::FPExt: -    case Instruction::PtrToInt: -    case Instruction::IntToPtr: -    case Instruction::SIToFP: -    case Instruction::UIToFP: -    case Instruction::Trunc: -    case Instruction::FPTrunc: -    case Instruction::BitCast: { -      auto *CI = dyn_cast<CastInst>(&I); -      setDebugLocFromInst(Builder, CI); - -      // Optimize the special case where the source is a constant integer -      // induction variable. Notice that we can only optimize the 'trunc' case -      // because (a) FP conversions lose precision, (b) sext/zext may wrap, and -      // (c) other casts depend on pointer size. -      auto ID = Legal->getInductionVars()->lookup(OldInduction); -      if (isa<TruncInst>(CI) && CI->getOperand(0) == OldInduction && -          ID.getConstIntStepValue()) { -        widenIntInduction(OldInduction, cast<TruncInst>(CI)); -        break; -      } +    setDebugLocFromInst(Builder, &I); -      /// Vectorize casts. -      Type *DestTy = -          (VF == 1) ? CI->getType() : VectorType::get(CI->getType(), VF); +    Module *M = I.getParent()->getParent()->getParent(); +    auto *CI = cast<CallInst>(&I); -      const VectorParts &A = getVectorValue(CI->getOperand(0)); -      VectorParts Entry(UF); -      for (unsigned Part = 0; Part < UF; ++Part) -        Entry[Part] = Builder.CreateCast(CI->getOpcode(), A[Part], DestTy); -      VectorLoopValueMap.initVector(&I, Entry); -      addMetadata(Entry, &I); +    StringRef FnName = CI->getCalledFunction()->getName(); +    Function *F = CI->getCalledFunction(); +    Type *RetTy = ToVectorTy(CI->getType(), VF); +    SmallVector<Type *, 4> Tys; +    for (Value *ArgOperand : CI->arg_operands()) +      Tys.push_back(ToVectorTy(ArgOperand->getType(), VF)); + +    Intrinsic::ID ID = getVectorIntrinsicIDForCall(CI, TLI); +    if (ID && (ID == Intrinsic::assume || ID == Intrinsic::lifetime_end || +               ID == Intrinsic::lifetime_start)) { +      scalarizeInstruction(&I); +      break; +    } +    // The flag shows whether we use Intrinsic or a usual Call for vectorized +    // version of the instruction. +    // Is it beneficial to perform intrinsic call compared to lib call? +    bool NeedToScalarize; +    unsigned CallCost = getVectorCallCost(CI, VF, *TTI, TLI, NeedToScalarize); +    bool UseVectorIntrinsic = +        ID && getVectorIntrinsicCost(CI, VF, *TTI, TLI) <= CallCost; +    if (!UseVectorIntrinsic && NeedToScalarize) { +      scalarizeInstruction(&I);        break;      } -    case Instruction::Call: { -      // Ignore dbg intrinsics. 
-      if (isa<DbgInfoIntrinsic>(I)) -        break; -      setDebugLocFromInst(Builder, &I); - -      Module *M = BB->getParent()->getParent(); -      auto *CI = cast<CallInst>(&I); - -      StringRef FnName = CI->getCalledFunction()->getName(); -      Function *F = CI->getCalledFunction(); -      Type *RetTy = ToVectorTy(CI->getType(), VF); -      SmallVector<Type *, 4> Tys; -      for (Value *ArgOperand : CI->arg_operands()) -        Tys.push_back(ToVectorTy(ArgOperand->getType(), VF)); - -      Intrinsic::ID ID = getVectorIntrinsicIDForCall(CI, TLI); -      if (ID && (ID == Intrinsic::assume || ID == Intrinsic::lifetime_end || -                 ID == Intrinsic::lifetime_start)) { -        scalarizeInstruction(&I); -        break; -      } -      // The flag shows whether we use Intrinsic or a usual Call for vectorized -      // version of the instruction. -      // Is it beneficial to perform intrinsic call compared to lib call? -      bool NeedToScalarize; -      unsigned CallCost = getVectorCallCost(CI, VF, *TTI, TLI, NeedToScalarize); -      bool UseVectorIntrinsic = -          ID && getVectorIntrinsicCost(CI, VF, *TTI, TLI) <= CallCost; -      if (!UseVectorIntrinsic && NeedToScalarize) { -        scalarizeInstruction(&I); -        break; -      } - -      VectorParts Entry(UF); -      for (unsigned Part = 0; Part < UF; ++Part) { -        SmallVector<Value *, 4> Args; -        for (unsigned i = 0, ie = CI->getNumArgOperands(); i != ie; ++i) { -          Value *Arg = CI->getArgOperand(i); -          // Some intrinsics have a scalar argument - don't replace it with a -          // vector. -          if (!UseVectorIntrinsic || !hasVectorInstrinsicScalarOpd(ID, i)) { -            const VectorParts &VectorArg = getVectorValue(CI->getArgOperand(i)); -            Arg = VectorArg[Part]; -          } -          Args.push_back(Arg); +    VectorParts Entry(UF); +    for (unsigned Part = 0; Part < UF; ++Part) { +      SmallVector<Value *, 4> Args; +      for (unsigned i = 0, ie = CI->getNumArgOperands(); i != ie; ++i) { +        Value *Arg = CI->getArgOperand(i); +        // Some intrinsics have a scalar argument - don't replace it with a +        // vector. +        if (!UseVectorIntrinsic || !hasVectorInstrinsicScalarOpd(ID, i)) { +          const VectorParts &VectorArg = getVectorValue(CI->getArgOperand(i)); +          Arg = VectorArg[Part];          } +        Args.push_back(Arg); +      } -        Function *VectorF; -        if (UseVectorIntrinsic) { -          // Use vector version of the intrinsic. -          Type *TysForDecl[] = {CI->getType()}; -          if (VF > 1) -            TysForDecl[0] = VectorType::get(CI->getType()->getScalarType(), VF); -          VectorF = Intrinsic::getDeclaration(M, ID, TysForDecl); -        } else { -          // Use vector version of the library call. -          StringRef VFnName = TLI->getVectorizedFunction(FnName, VF); -          assert(!VFnName.empty() && "Vector function name is empty."); -          VectorF = M->getFunction(VFnName); -          if (!VectorF) { -            // Generate a declaration -            FunctionType *FTy = FunctionType::get(RetTy, Tys, false); -            VectorF = -                Function::Create(FTy, Function::ExternalLinkage, VFnName, M); -            VectorF->copyAttributesFrom(F); -          } +      Function *VectorF; +      if (UseVectorIntrinsic) { +        // Use vector version of the intrinsic. 
+        Type *TysForDecl[] = {CI->getType()}; +        if (VF > 1) +          TysForDecl[0] = VectorType::get(CI->getType()->getScalarType(), VF); +        VectorF = Intrinsic::getDeclaration(M, ID, TysForDecl); +      } else { +        // Use vector version of the library call. +        StringRef VFnName = TLI->getVectorizedFunction(FnName, VF); +        assert(!VFnName.empty() && "Vector function name is empty."); +        VectorF = M->getFunction(VFnName); +        if (!VectorF) { +          // Generate a declaration +          FunctionType *FTy = FunctionType::get(RetTy, Tys, false); +          VectorF = +              Function::Create(FTy, Function::ExternalLinkage, VFnName, M); +          VectorF->copyAttributesFrom(F);          } -        assert(VectorF && "Can't create vector function."); - -        SmallVector<OperandBundleDef, 1> OpBundles; -        CI->getOperandBundlesAsDefs(OpBundles); -        CallInst *V = Builder.CreateCall(VectorF, Args, OpBundles); +      } +      assert(VectorF && "Can't create vector function."); -        if (isa<FPMathOperator>(V)) -          V->copyFastMathFlags(CI); +      SmallVector<OperandBundleDef, 1> OpBundles; +      CI->getOperandBundlesAsDefs(OpBundles); +      CallInst *V = Builder.CreateCall(VectorF, Args, OpBundles); -        Entry[Part] = V; -      } +      if (isa<FPMathOperator>(V)) +        V->copyFastMathFlags(CI); -      VectorLoopValueMap.initVector(&I, Entry); -      addMetadata(Entry, &I); -      break; +      Entry[Part] = V;      } -    default: -      // All other instructions are unsupported. Scalarize them. -      scalarizeInstruction(&I); -      break; -    } // end of switch. -  }   // end of for_each instr. +    VectorLoopValueMap.initVector(&I, Entry); +    addMetadata(Entry, &I); +    break; +  } + +  default: +    // All other instructions are unsupported. Scalarize them. +    scalarizeInstruction(&I); +    break; +  } // end of switch.  }  void InnerLoopVectorizer::updateAnalysis() { @@ -4976,11 +5059,10 @@ void InnerLoopVectorizer::updateAnalysis() {    assert(DT->properlyDominates(LoopBypassBlocks.front(), LoopExitBlock) &&           "Entry does not dominate exit."); -  // We don't predicate stores by this point, so the vector body should be a -  // single loop. -  DT->addNewBlock(LoopVectorBody, LoopVectorPreHeader); - -  DT->addNewBlock(LoopMiddleBlock, LoopVectorBody); +  DT->addNewBlock(LI->getLoopFor(LoopVectorBody)->getHeader(), +                  LoopVectorPreHeader); +  DT->addNewBlock(LoopMiddleBlock, +                  LI->getLoopFor(LoopVectorBody)->getLoopLatch());    DT->addNewBlock(LoopScalarPreHeader, LoopBypassBlocks[0]);    DT->changeImmediateDominator(LoopScalarBody, LoopScalarPreHeader);    DT->changeImmediateDominator(LoopExitBlock, LoopBypassBlocks[0]); @@ -5145,12 +5227,6 @@ bool LoopVectorizationLegality::canVectorize() {    if (UseInterleaved)      InterleaveInfo.analyzeInterleaving(*getSymbolicStrides()); -  // Collect all instructions that are known to be uniform after vectorization. -  collectLoopUniforms(); - -  // Collect all instructions that are known to be scalar after vectorization. -  collectLoopScalars(); -    unsigned SCEVThreshold = VectorizeSCEVCheckThreshold;    if (Hints->getForce() == LoopVectorizeHints::FK_Enabled)      SCEVThreshold = PragmaVectorizeSCEVCheckThreshold; @@ -5234,8 +5310,8 @@ void LoopVectorizationLegality::addInductionPhi(      // one if there are multiple (no good reason for doing this other      // than it is expedient). 
We've checked that it begins at zero and      // steps by one, so this is a canonical induction variable. -    if (!Induction || PhiTy == WidestIndTy) -      Induction = Phi; +    if (!PrimaryInduction || PhiTy == WidestIndTy) +      PrimaryInduction = Phi;    }    // Both the PHI node itself, and the "post-increment" value feeding @@ -5398,7 +5474,7 @@ bool LoopVectorizationLegality::canVectorizeInstrs() {      } // next instr.    } -  if (!Induction) { +  if (!PrimaryInduction) {      DEBUG(dbgs() << "LV: Did not find one integer induction var.\n");      if (Inductions.empty()) {        ORE->emit(createMissedAnalysis("NoInductionVariable") @@ -5410,46 +5486,166 @@ bool LoopVectorizationLegality::canVectorizeInstrs() {    // Now we know the widest induction type, check if our found induction    // is the same size. If it's not, unset it here and InnerLoopVectorizer    // will create another. -  if (Induction && WidestIndTy != Induction->getType()) -    Induction = nullptr; +  if (PrimaryInduction && WidestIndTy != PrimaryInduction->getType()) +    PrimaryInduction = nullptr;    return true;  } -void LoopVectorizationLegality::collectLoopScalars() { +void LoopVectorizationCostModel::collectLoopScalars(unsigned VF) { + +  // We should not collect Scalars more than once per VF. Right now, this +  // function is called from collectUniformsAndScalars(), which already does +  // this check. Collecting Scalars for VF=1 does not make any sense. +  assert(VF >= 2 && !Scalars.count(VF) && +         "This function should not be visited twice for the same VF"); + +  SmallSetVector<Instruction *, 8> Worklist; + +  // These sets are used to seed the analysis with pointers used by memory +  // accesses that will remain scalar. +  SmallSetVector<Instruction *, 8> ScalarPtrs; +  SmallPtrSet<Instruction *, 8> PossibleNonScalarPtrs; + +  // A helper that returns true if the use of Ptr by MemAccess will be scalar. +  // The pointer operands of loads and stores will be scalar as long as the +  // memory access is not a gather or scatter operation. The value operand of a +  // store will remain scalar if the store is scalarized. +  auto isScalarUse = [&](Instruction *MemAccess, Value *Ptr) { +    InstWidening WideningDecision = getWideningDecision(MemAccess, VF); +    assert(WideningDecision != CM_Unknown && +           "Widening decision should be ready at this moment"); +    if (auto *Store = dyn_cast<StoreInst>(MemAccess)) +      if (Ptr == Store->getValueOperand()) +        return WideningDecision == CM_Scalarize; +    assert(Ptr == getPointerOperand(MemAccess) && +           "Ptr is neither a value or pointer operand"); +    return WideningDecision != CM_GatherScatter; +  }; + +  // A helper that returns true if the given value is a bitcast or +  // getelementptr instruction contained in the loop. +  auto isLoopVaryingBitCastOrGEP = [&](Value *V) { +    return ((isa<BitCastInst>(V) && V->getType()->isPointerTy()) || +            isa<GetElementPtrInst>(V)) && +           !TheLoop->isLoopInvariant(V); +  }; -  // If an instruction is uniform after vectorization, it will remain scalar. -  Scalars.insert(Uniforms.begin(), Uniforms.end()); +  // A helper that evaluates a memory access's use of a pointer. If the use +  // will be a scalar use, and the pointer is only used by memory accesses, we +  // place the pointer in ScalarPtrs. Otherwise, the pointer is placed in +  // PossibleNonScalarPtrs. 
+  auto evaluatePtrUse = [&](Instruction *MemAccess, Value *Ptr) { -  // Collect the getelementptr instructions that will not be vectorized. A -  // getelementptr instruction is only vectorized if it is used for a legal -  // gather or scatter operation. +    // We only care about bitcast and getelementptr instructions contained in +    // the loop. +    if (!isLoopVaryingBitCastOrGEP(Ptr)) +      return; + +    // If the pointer has already been identified as scalar (e.g., if it was +    // also identified as uniform), there's nothing to do. +    auto *I = cast<Instruction>(Ptr); +    if (Worklist.count(I)) +      return; + +    // If the use of the pointer will be a scalar use, and all users of the +    // pointer are memory accesses, place the pointer in ScalarPtrs. Otherwise, +    // place the pointer in PossibleNonScalarPtrs. +    if (isScalarUse(MemAccess, Ptr) && all_of(I->users(), [&](User *U) { +          return isa<LoadInst>(U) || isa<StoreInst>(U); +        })) +      ScalarPtrs.insert(I); +    else +      PossibleNonScalarPtrs.insert(I); +  }; + +  // We seed the scalars analysis with three classes of instructions: (1) +  // instructions marked uniform-after-vectorization, (2) bitcast and +  // getelementptr instructions used by memory accesses requiring a scalar use, +  // and (3) pointer induction variables and their update instructions (we +  // currently only scalarize these). +  // +  // (1) Add to the worklist all instructions that have been identified as +  // uniform-after-vectorization. +  Worklist.insert(Uniforms[VF].begin(), Uniforms[VF].end()); + +  // (2) Add to the worklist all bitcast and getelementptr instructions used by +  // memory accesses requiring a scalar use. The pointer operands of loads and +  // stores will be scalar as long as the memory accesses is not a gather or +  // scatter operation. The value operand of a store will remain scalar if the +  // store is scalarized.    for (auto *BB : TheLoop->blocks())      for (auto &I : *BB) { -      if (auto *GEP = dyn_cast<GetElementPtrInst>(&I)) { -        Scalars.insert(GEP); -        continue; +      if (auto *Load = dyn_cast<LoadInst>(&I)) { +        evaluatePtrUse(Load, Load->getPointerOperand()); +      } else if (auto *Store = dyn_cast<StoreInst>(&I)) { +        evaluatePtrUse(Store, Store->getPointerOperand()); +        evaluatePtrUse(Store, Store->getValueOperand());        } -      auto *Ptr = getPointerOperand(&I); -      if (!Ptr) -        continue; -      auto *GEP = getGEPInstruction(Ptr); -      if (GEP && isLegalGatherOrScatter(&I)) -        Scalars.erase(GEP); +    } +  for (auto *I : ScalarPtrs) +    if (!PossibleNonScalarPtrs.count(I)) { +      DEBUG(dbgs() << "LV: Found scalar instruction: " << *I << "\n"); +      Worklist.insert(I);      } +  // (3) Add to the worklist all pointer induction variables and their update +  // instructions. +  // +  // TODO: Once we are able to vectorize pointer induction variables we should +  //       no longer insert them into the worklist here. 
+  auto *Latch = TheLoop->getLoopLatch(); +  for (auto &Induction : *Legal->getInductionVars()) { +    auto *Ind = Induction.first; +    auto *IndUpdate = cast<Instruction>(Ind->getIncomingValueForBlock(Latch)); +    if (Induction.second.getKind() != InductionDescriptor::IK_PtrInduction) +      continue; +    Worklist.insert(Ind); +    Worklist.insert(IndUpdate); +    DEBUG(dbgs() << "LV: Found scalar instruction: " << *Ind << "\n"); +    DEBUG(dbgs() << "LV: Found scalar instruction: " << *IndUpdate << "\n"); +  } + +  // Expand the worklist by looking through any bitcasts and getelementptr +  // instructions we've already identified as scalar. This is similar to the +  // expansion step in collectLoopUniforms(); however, here we're only +  // expanding to include additional bitcasts and getelementptr instructions. +  unsigned Idx = 0; +  while (Idx != Worklist.size()) { +    Instruction *Dst = Worklist[Idx++]; +    if (!isLoopVaryingBitCastOrGEP(Dst->getOperand(0))) +      continue; +    auto *Src = cast<Instruction>(Dst->getOperand(0)); +    if (all_of(Src->users(), [&](User *U) -> bool { +          auto *J = cast<Instruction>(U); +          return !TheLoop->contains(J) || Worklist.count(J) || +                 ((isa<LoadInst>(J) || isa<StoreInst>(J)) && +                  isScalarUse(J, Src)); +        })) { +      Worklist.insert(Src); +      DEBUG(dbgs() << "LV: Found scalar instruction: " << *Src << "\n"); +    } +  } +    // An induction variable will remain scalar if all users of the induction    // variable and induction variable update remain scalar. -  auto *Latch = TheLoop->getLoopLatch(); -  for (auto &Induction : *getInductionVars()) { +  for (auto &Induction : *Legal->getInductionVars()) {      auto *Ind = Induction.first;      auto *IndUpdate = cast<Instruction>(Ind->getIncomingValueForBlock(Latch)); +    // We already considered pointer induction variables, so there's no reason +    // to look at their users again. +    // +    // TODO: Once we are able to vectorize pointer induction variables we +    //       should no longer skip over them here. +    if (Induction.second.getKind() == InductionDescriptor::IK_PtrInduction) +      continue; +      // Determine if all users of the induction variable are scalar after      // vectorization.      auto ScalarInd = all_of(Ind->users(), [&](User *U) -> bool {        auto *I = cast<Instruction>(U); -      return I == IndUpdate || !TheLoop->contains(I) || Scalars.count(I); +      return I == IndUpdate || !TheLoop->contains(I) || Worklist.count(I);      });      if (!ScalarInd)        continue; @@ -5458,23 +5654,19 @@ void LoopVectorizationLegality::collectLoopScalars() {      // scalar after vectorization.      auto ScalarIndUpdate = all_of(IndUpdate->users(), [&](User *U) -> bool {        auto *I = cast<Instruction>(U); -      return I == Ind || !TheLoop->contains(I) || Scalars.count(I); +      return I == Ind || !TheLoop->contains(I) || Worklist.count(I);      });      if (!ScalarIndUpdate)        continue;      // The induction variable and its update instruction will remain scalar. 
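// Standalone illustration (not part of the patch above) of the seed-and-expand
// scheme collectLoopScalars uses, run on a toy use-graph instead of real IR:
// seed the worklist with values already known to stay scalar, then walk
// backwards through address computations whose users are all either memory
// accesses with a scalar use or already in the worklist. The graph, the names,
// and the seeding choice are invented for the example.
#include <cstdio>
#include <map>
#include <set>
#include <string>
#include <vector>

int main() {
  // A toy use graph: value -> its users inside the loop.
  std::map<std::string, std::vector<std::string>> Users = {
      {"bitcast", {"gep"}},                // %gep uses %bitcast
      {"gep", {"load"}}};                  // the load uses %gep as its address
  // value -> the single address operand we would look through.
  std::map<std::string, std::string> AddrOperand = {{"gep", "bitcast"}};

  // Seed: the GEP feeding a widened (non-gather) load keeps a scalar address.
  std::vector<std::string> Worklist = {"gep"};
  std::set<std::string> Known(Worklist.begin(), Worklist.end());

  // Expand: pull in operands whose users are all memory accesses or already
  // known to be scalar, mirroring the all_of check in the code above.
  for (size_t Idx = 0; Idx < Worklist.size(); ++Idx) {
    auto It = AddrOperand.find(Worklist[Idx]);
    if (It == AddrOperand.end())
      continue;
    const std::string &Src = It->second;
    bool AllUsersScalar = true;
    for (const std::string &U : Users[Src])
      AllUsersScalar &= (Known.count(U) != 0 || U == "load" || U == "store");
    if (AllUsersScalar && Known.insert(Src).second)
      Worklist.push_back(Src);
  }

  for (const std::string &V : Worklist)
    std::printf("scalar after vectorization: %%%s\n", V.c_str());
  return 0;
}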
-    Scalars.insert(Ind); -    Scalars.insert(IndUpdate); +    Worklist.insert(Ind); +    Worklist.insert(IndUpdate); +    DEBUG(dbgs() << "LV: Found scalar instruction: " << *Ind << "\n"); +    DEBUG(dbgs() << "LV: Found scalar instruction: " << *IndUpdate << "\n");    } -} -bool LoopVectorizationLegality::hasConsecutiveLikePtrOperand(Instruction *I) { -  if (isAccessInterleaved(I)) -    return true; -  if (auto *Ptr = getPointerOperand(I)) -    return isConsecutivePtr(Ptr); -  return false; +  Scalars[VF].insert(Worklist.begin(), Worklist.end());  }  bool LoopVectorizationLegality::isScalarWithPredication(Instruction *I) { @@ -5494,48 +5686,48 @@ bool LoopVectorizationLegality::isScalarWithPredication(Instruction *I) {    return false;  } -bool LoopVectorizationLegality::memoryInstructionMustBeScalarized( -    Instruction *I, unsigned VF) { - -  // If the memory instruction is in an interleaved group, it will be -  // vectorized and its pointer will remain uniform. -  if (isAccessInterleaved(I)) -    return false; - +bool LoopVectorizationLegality::memoryInstructionCanBeWidened(Instruction *I, +                                                              unsigned VF) {    // Get and ensure we have a valid memory instruction.    LoadInst *LI = dyn_cast<LoadInst>(I);    StoreInst *SI = dyn_cast<StoreInst>(I);    assert((LI || SI) && "Invalid memory instruction"); -  // If the pointer operand is uniform (loop invariant), the memory instruction -  // will be scalarized.    auto *Ptr = getPointerOperand(I); -  if (LI && isUniform(Ptr)) -    return true; -  // If the pointer operand is non-consecutive and neither a gather nor a -  // scatter operation is legal, the memory instruction will be scalarized. -  if (!isConsecutivePtr(Ptr) && !isLegalGatherOrScatter(I)) -    return true; +  // In order to be widened, the pointer should be consecutive, first of all. +  if (!isConsecutivePtr(Ptr)) +    return false;    // If the instruction is a store located in a predicated block, it will be    // scalarized.    if (isScalarWithPredication(I)) -    return true; +    return false;    // If the instruction's allocated size doesn't equal it's type size, it    // requires padding and will be scalarized.    auto &DL = I->getModule()->getDataLayout();    auto *ScalarTy = LI ? LI->getType() : SI->getValueOperand()->getType();    if (hasIrregularType(ScalarTy, DL, VF)) -    return true; +    return false; -  // Otherwise, the memory instruction should be vectorized if the rest of the -  // loop is. -  return false; +  return true;  } -void LoopVectorizationLegality::collectLoopUniforms() { +void LoopVectorizationCostModel::collectLoopUniforms(unsigned VF) { + +  // We should not collect Uniforms more than once per VF. Right now, +  // this function is called from collectUniformsAndScalars(), which  +  // already does this check. Collecting Uniforms for VF=1 does not make any +  // sense. + +  assert(VF >= 2 && !Uniforms.count(VF) && +         "This function should not be visited twice for the same VF"); + +  // Visit the list of Uniforms. If we'll not find any uniform value, we'll  +  // not analyze again.  Uniforms.count(VF) will return 1. +  Uniforms[VF].clear(); +    // We now know that the loop is vectorizable!    // Collect instructions inside the loop that will remain uniform after    // vectorization. @@ -5568,6 +5760,14 @@ void LoopVectorizationLegality::collectLoopUniforms() {    // Holds pointer operands of instructions that are possibly non-uniform.    
SmallPtrSet<Instruction *, 8> PossibleNonUniformPtrs; +  auto isUniformDecision = [&](Instruction *I, unsigned VF) { +    InstWidening WideningDecision = getWideningDecision(I, VF); +    assert(WideningDecision != CM_Unknown && +           "Widening decision should be ready at this moment"); + +    return (WideningDecision == CM_Widen || +            WideningDecision == CM_Interleave); +  };    // Iterate over the instructions in the loop, and collect all    // consecutive-like pointer operands in ConsecutiveLikePtrs. If it's possible    // that a consecutive-like pointer operand will be scalarized, we collect it @@ -5590,25 +5790,18 @@ void LoopVectorizationLegality::collectLoopUniforms() {          return getPointerOperand(U) == Ptr;        }); -      // Ensure the memory instruction will not be scalarized, making its -      // pointer operand non-uniform. If the pointer operand is used by some -      // instruction other than a memory access, we're not going to check if -      // that other instruction may be scalarized here. Thus, conservatively -      // assume the pointer operand may be non-uniform. -      if (!UsersAreMemAccesses || memoryInstructionMustBeScalarized(&I)) +      // Ensure the memory instruction will not be scalarized or used by +      // gather/scatter, making its pointer operand non-uniform. If the pointer +      // operand is used by any instruction other than a memory access, we +      // conservatively assume the pointer operand may be non-uniform. +      if (!UsersAreMemAccesses || !isUniformDecision(&I, VF))          PossibleNonUniformPtrs.insert(Ptr);        // If the memory instruction will be vectorized and its pointer operand -      // is consecutive-like, the pointer operand should remain uniform. -      else if (hasConsecutiveLikePtrOperand(&I)) -        ConsecutiveLikePtrs.insert(Ptr); - -      // Otherwise, if the memory instruction will be vectorized and its -      // pointer operand is non-consecutive-like, the memory instruction should -      // be a gather or scatter operation. Its pointer operand will be -      // non-uniform. +      // is consecutive-like, or interleaving - the pointer operand should +      // remain uniform.        else -        PossibleNonUniformPtrs.insert(Ptr); +        ConsecutiveLikePtrs.insert(Ptr);      }    // Add to the Worklist all consecutive and consecutive-like pointers that @@ -5632,7 +5825,9 @@ void LoopVectorizationLegality::collectLoopUniforms() {          continue;        auto *OI = cast<Instruction>(OV);        if (all_of(OI->users(), [&](User *U) -> bool { -            return isOutOfScope(U) || Worklist.count(cast<Instruction>(U)); +            auto *J = cast<Instruction>(U); +            return !TheLoop->contains(J) || Worklist.count(J) || +                   (OI == getPointerOperand(J) && isUniformDecision(J, VF));            })) {          Worklist.insert(OI);          DEBUG(dbgs() << "LV: Found uniform instruction: " << *OI << "\n"); @@ -5643,7 +5838,7 @@ void LoopVectorizationLegality::collectLoopUniforms() {    // Returns true if Ptr is the pointer operand of a memory access instruction    // I, and I is known to not require scalarization.    
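// Standalone illustration (not part of the patch above): the shape of the
// isUniformDecision check used here. Only an access that will be widened as a
// consecutive load/store or emitted through an interleave group keeps its
// address computation uniform; a gather/scatter or scalarized access needs a
// per-lane address. The enum below is a simplified stand-in for the cost
// model's widening decisions, not the real class.
#include <cassert>

enum InstWidening { CM_Unknown, CM_Widen, CM_Interleave, CM_GatherScatter, CM_Scalarize };

static bool pointerStaysUniform(InstWidening Decision) {
  assert(Decision != CM_Unknown && "widening decision should be known by now");
  return Decision == CM_Widen || Decision == CM_Interleave;
}

int main() {
  assert(pointerStaysUniform(CM_Widen));
  assert(pointerStaysUniform(CM_Interleave));
  assert(!pointerStaysUniform(CM_GatherScatter));
  assert(!pointerStaysUniform(CM_Scalarize));
  return 0;
}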
auto isVectorizedMemAccessUse = [&](Instruction *I, Value *Ptr) -> bool { -    return getPointerOperand(I) == Ptr && !memoryInstructionMustBeScalarized(I); +    return getPointerOperand(I) == Ptr && isUniformDecision(I, VF);    };    // For an instruction to be added into Worklist above, all its users inside @@ -5652,7 +5847,7 @@ void LoopVectorizationLegality::collectLoopUniforms() {    // nodes separately. An induction variable will remain uniform if all users    // of the induction variable and induction variable update remain uniform.    // The code below handles both pointer and non-pointer induction variables. -  for (auto &Induction : Inductions) { +  for (auto &Induction : *Legal->getInductionVars()) {      auto *Ind = Induction.first;      auto *IndUpdate = cast<Instruction>(Ind->getIncomingValueForBlock(Latch)); @@ -5683,7 +5878,7 @@ void LoopVectorizationLegality::collectLoopUniforms() {      DEBUG(dbgs() << "LV: Found uniform instruction: " << *IndUpdate << "\n");    } -  Uniforms.insert(Worklist.begin(), Worklist.end()); +  Uniforms[VF].insert(Worklist.begin(), Worklist.end());  }  bool LoopVectorizationLegality::canVectorizeMemory() { @@ -5823,7 +6018,7 @@ void InterleavedAccessInfo::collectConstStrideAccesses(        uint64_t Size = DL.getTypeAllocSize(PtrTy->getElementType());        // An alignment of 0 means target ABI alignment. -      unsigned Align = LI ? LI->getAlignment() : SI->getAlignment(); +      unsigned Align = getMemInstAlignment(&I);        if (!Align)          Align = DL.getABITypeAlignment(PtrTy->getElementType()); @@ -5978,6 +6173,11 @@ void InterleavedAccessInfo::analyzeInterleaving(        if (DesA.Stride != DesB.Stride || DesA.Size != DesB.Size)          continue; +      // Ignore A if the memory object of A and B don't belong to the same +      // address space +      if (getMemInstAddressSpace(A) != getMemInstAddressSpace(B)) +        continue; +        // Calculate the distance from A to B.        const SCEVConstant *DistToB = dyn_cast<SCEVConstant>(            PSE.getSE()->getMinusSCEV(DesA.Scev, DesB.Scev)); @@ -6020,36 +6220,36 @@ void InterleavedAccessInfo::analyzeInterleaving(      if (Group->getNumMembers() != Group->getFactor())        releaseGroup(Group); -  // Remove interleaved groups with gaps (currently only loads) whose memory  -  // accesses may wrap around. We have to revisit the getPtrStride analysis,  -  // this time with ShouldCheckWrap=true, since collectConstStrideAccesses does  +  // Remove interleaved groups with gaps (currently only loads) whose memory +  // accesses may wrap around. We have to revisit the getPtrStride analysis, +  // this time with ShouldCheckWrap=true, since collectConstStrideAccesses does    // not check wrapping (see documentation there). -  // FORNOW we use Assume=false;  -  // TODO: Change to Assume=true but making sure we don't exceed the threshold  +  // FORNOW we use Assume=false; +  // TODO: Change to Assume=true but making sure we don't exceed the threshold    // of runtime SCEV assumptions checks (thereby potentially failing to -  // vectorize altogether).  +  // vectorize altogether).    // Additional optional optimizations: -  // TODO: If we are peeling the loop and we know that the first pointer doesn't  +  // TODO: If we are peeling the loop and we know that the first pointer doesn't    // wrap then we can deduce that all pointers in the group don't wrap. -  // This means that we can forcefully peel the loop in order to only have to  -  // check the first pointer for no-wrap. 
When we'll change to use Assume=true
+  // This means that we can forcefully peel the loop in order to only have to
+  // check the first pointer for no-wrap. When we'll change to use Assume=true
   // we'll only need at most one runtime check per interleaved group.
   //
   for (InterleaveGroup *Group : LoadGroups) {
     // Case 1: A full group. Can skip the checks; For full groups, if the wide
-    // load would wrap around the address space we would do a memory access at
-    // nullptr even without the transformation.
-    if (Group->getNumMembers() == Group->getFactor())
+    // load would wrap around the address space we would do a memory access at
+    // nullptr even without the transformation.
+    if (Group->getNumMembers() == Group->getFactor())
       continue;
-    // Case 2: If first and last members of the group don't wrap this implies
+    // Case 2: If first and last members of the group don't wrap this implies
     // that all the pointers in the group don't wrap.
     // So we check only group member 0 (which is always guaranteed to exist),
-    // and group member Factor - 1; If the latter doesn't exist we rely on
+    // and group member Factor - 1; If the latter doesn't exist we rely on
     // peeling (if it is a non-reversed access -- see Case 3).
     Value *FirstMemberPtr = getPointerOperand(Group->getMember(0));
-    if (!getPtrStride(PSE, FirstMemberPtr, TheLoop, Strides, /*Assume=*/false,
+    if (!getPtrStride(PSE, FirstMemberPtr, TheLoop, Strides, /*Assume=*/false,
                       /*ShouldCheckWrap=*/true)) {
       DEBUG(dbgs() << "LV: Invalidate candidate interleaved group due to "
                       "first group member potentially pointer-wrapping.\n");
@@ -6065,8 +6265,7 @@ void InterleavedAccessInfo::analyzeInterleaving(
                         "last group member potentially pointer-wrapping.\n");
         releaseGroup(Group);
       }
-    }
-    else {
+    } else {
       // Case 3: A non-reversed interleaved load group with gaps: We need
       // to execute at least one scalar epilogue iteration. This will ensure
       // we don't speculatively access memory out-of-bounds. We only need
@@ -6082,27 +6281,62 @@
   }
 }
-LoopVectorizationCostModel::VectorizationFactor
-LoopVectorizationCostModel::selectVectorizationFactor(bool OptForSize) {
-  // Width 1 means no vectorize
-  VectorizationFactor Factor = {1U, 0U};
-  if (OptForSize && Legal->getRuntimePointerChecking()->Need) {
+Optional<unsigned> LoopVectorizationCostModel::computeMaxVF(bool OptForSize) {
+  if (!EnableCondStoresVectorization && Legal->getNumPredStores()) {
+    ORE->emit(createMissedAnalysis("ConditionalStore")
+              << "store that is conditionally executed prevents vectorization");
+    DEBUG(dbgs() << "LV: No vectorization. There are conditional stores.\n");
+    return None;
+  }
+
+  if (!OptForSize) // Remaining checks deal with scalar loop when OptForSize.
+    return computeFeasibleMaxVF(OptForSize);
+
+  if (Legal->getRuntimePointerChecking()->Need) {
     ORE->emit(createMissedAnalysis("CantVersionLoopWithOptForSize")
               << "runtime pointer checks needed. Enable vectorization of this "
                  "loop with '#pragma clang loop vectorize(enable)' when "
                  "compiling with -Os/-Oz");
     DEBUG(dbgs()
           << "LV: Aborting.
Runtime ptr check is required with -Os/-Oz.\n"); -    return Factor; +    return None;    } -  if (!EnableCondStoresVectorization && Legal->getNumPredStores()) { -    ORE->emit(createMissedAnalysis("ConditionalStore") -              << "store that is conditionally executed prevents vectorization"); -    DEBUG(dbgs() << "LV: No vectorization. There are conditional stores.\n"); -    return Factor; +  // If we optimize the program for size, avoid creating the tail loop. +  unsigned TC = PSE.getSE()->getSmallConstantTripCount(TheLoop); +  DEBUG(dbgs() << "LV: Found trip count: " << TC << '\n'); + +  // If we don't know the precise trip count, don't try to vectorize. +  if (TC < 2) { +    ORE->emit( +        createMissedAnalysis("UnknownLoopCountComplexCFG") +        << "unable to calculate the loop count due to complex control flow"); +    DEBUG(dbgs() << "LV: Aborting. A tail loop is required with -Os/-Oz.\n"); +    return None;    } +  unsigned MaxVF = computeFeasibleMaxVF(OptForSize); + +  if (TC % MaxVF != 0) { +    // If the trip count that we found modulo the vectorization factor is not +    // zero then we require a tail. +    // FIXME: look for a smaller MaxVF that does divide TC rather than give up. +    // FIXME: return None if loop requiresScalarEpilog(<MaxVF>), or look for a +    //        smaller MaxVF that does not require a scalar epilog. + +    ORE->emit(createMissedAnalysis("NoTailLoopWithOptForSize") +              << "cannot optimize for size and vectorize at the " +                 "same time. Enable vectorization of this loop " +                 "with '#pragma clang loop vectorize(enable)' " +                 "when compiling with -Os/-Oz"); +    DEBUG(dbgs() << "LV: Aborting. A tail loop is required with -Os/-Oz.\n"); +    return None; +  } + +  return MaxVF; +} + +unsigned LoopVectorizationCostModel::computeFeasibleMaxVF(bool OptForSize) {    MinBWs = computeMinimumValueSizes(TheLoop->getBlocks(), *DB, &TTI);    unsigned SmallestType, WidestType;    std::tie(SmallestType, WidestType) = getSmallestAndWidestTypes(); @@ -6136,7 +6370,7 @@ LoopVectorizationCostModel::selectVectorizationFactor(bool OptForSize) {    assert(MaxVectorSize <= 64 && "Did not expect to pack so many elements"                                  " into one vector!"); -  unsigned VF = MaxVectorSize; +  unsigned MaxVF = MaxVectorSize;    if (MaximizeBandwidth && !OptForSize) {      // Collect all viable vectorization factors.      SmallVector<unsigned, 8> VFs; @@ -6152,54 +6386,16 @@ LoopVectorizationCostModel::selectVectorizationFactor(bool OptForSize) {      unsigned TargetNumRegisters = TTI.getNumberOfRegisters(true);      for (int i = RUs.size() - 1; i >= 0; --i) {        if (RUs[i].MaxLocalUsers <= TargetNumRegisters) { -        VF = VFs[i]; +        MaxVF = VFs[i];          break;        }      }    } +  return MaxVF; +} -  // If we optimize the program for size, avoid creating the tail loop. -  if (OptForSize) { -    unsigned TC = PSE.getSE()->getSmallConstantTripCount(TheLoop); -    DEBUG(dbgs() << "LV: Found trip count: " << TC << '\n'); - -    // If we don't know the precise trip count, don't try to vectorize. -    if (TC < 2) { -      ORE->emit( -          createMissedAnalysis("UnknownLoopCountComplexCFG") -          << "unable to calculate the loop count due to complex control flow"); -      DEBUG(dbgs() << "LV: Aborting. A tail loop is required with -Os/-Oz.\n"); -      return Factor; -    } - -    // Find the maximum SIMD width that can fit within the trip count. 
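// Standalone illustration (not part of the patch above) of the -Os/-Oz
// decision being restructured here: with a known trip count TC and a feasible
// maximum VF, vectorization is only allowed when no scalar tail would be
// needed, i.e. when TC is a multiple of the chosen VF. The helper and numbers
// below are invented; the real code also rejects loops that need runtime
// pointer checks or whose trip count is unknown.
#include <cstdio>

// Returns the chosen VF, or 0 to mean "do not vectorize" (the None case).
static unsigned maxVFForOptSize(unsigned TC, unsigned FeasibleMaxVF) {
  if (TC < 2)                  // unknown or tiny trip count
    return 0;
  if (TC % FeasibleMaxVF != 0) // leftover iterations would need a scalar tail
    return 0;                  // (the FIXME above suggests trying a smaller VF)
  return FeasibleMaxVF;
}

int main() {
  std::printf("TC=64,  MaxVF=8 -> %u\n", maxVFForOptSize(64, 8));  // 8
  std::printf("TC=100, MaxVF=8 -> %u\n", maxVFForOptSize(100, 8)); // 0 (tail needed)
  std::printf("TC=0,   MaxVF=8 -> %u\n", maxVFForOptSize(0, 8));   // 0 (unknown TC)
  return 0;
}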
-    VF = TC % MaxVectorSize; - -    if (VF == 0) -      VF = MaxVectorSize; -    else { -      // If the trip count that we found modulo the vectorization factor is not -      // zero then we require a tail. -      ORE->emit(createMissedAnalysis("NoTailLoopWithOptForSize") -                << "cannot optimize for size and vectorize at the " -                   "same time. Enable vectorization of this loop " -                   "with '#pragma clang loop vectorize(enable)' " -                   "when compiling with -Os/-Oz"); -      DEBUG(dbgs() << "LV: Aborting. A tail loop is required with -Os/-Oz.\n"); -      return Factor; -    } -  } - -  int UserVF = Hints->getWidth(); -  if (UserVF != 0) { -    assert(isPowerOf2_32(UserVF) && "VF needs to be a power of two"); -    DEBUG(dbgs() << "LV: Using user VF " << UserVF << ".\n"); - -    Factor.Width = UserVF; -    collectInstsToScalarize(UserVF); -    return Factor; -  } - +LoopVectorizationCostModel::VectorizationFactor +LoopVectorizationCostModel::selectVectorizationFactor(unsigned MaxVF) {    float Cost = expectedCost(1).first;  #ifndef NDEBUG    const float ScalarCost = Cost; @@ -6209,12 +6405,12 @@ LoopVectorizationCostModel::selectVectorizationFactor(bool OptForSize) {    bool ForceVectorization = Hints->getForce() == LoopVectorizeHints::FK_Enabled;    // Ignore scalar width, because the user explicitly wants vectorization. -  if (ForceVectorization && VF > 1) { +  if (ForceVectorization && MaxVF > 1) {      Width = 2;      Cost = expectedCost(Width).first / (float)Width;    } -  for (unsigned i = 2; i <= VF; i *= 2) { +  for (unsigned i = 2; i <= MaxVF; i *= 2) {      // Notice that the vector loop needs to be executed less times, so      // we need to divide the cost of the vector loops by the width of      // the vector elements. @@ -6238,8 +6434,7 @@ LoopVectorizationCostModel::selectVectorizationFactor(bool OptForSize) {          << "LV: Vectorization seems to be not beneficial, "          << "but was forced by a user.\n");    DEBUG(dbgs() << "LV: Selecting VF: " << Width << ".\n"); -  Factor.Width = Width; -  Factor.Cost = Width * Cost; +  VectorizationFactor Factor = {Width, (unsigned)(Width * Cost)};    return Factor;  } @@ -6277,9 +6472,16 @@ LoopVectorizationCostModel::getSmallestAndWidestTypes() {          T = ST->getValueOperand()->getType();        // Ignore loaded pointer types and stored pointer types that are not -      // consecutive. However, we do want to take consecutive stores/loads of -      // pointer vectors into account. -      if (T->isPointerTy() && !isConsecutiveLoadOrStore(&I)) +      // vectorizable. +      // +      // FIXME: The check here attempts to predict whether a load or store will +      //        be vectorized. We only know this for certain after a VF has +      //        been selected. Here, we assume that if an access can be +      //        vectorized, it will be. We should also look at extending this +      //        optimization to non-pointer types. +      // +      if (T->isPointerTy() && !isConsecutiveLoadOrStore(&I) && +          !Legal->isAccessInterleaved(&I) && !Legal->isLegalGatherOrScatter(&I))          continue;        MinWidth = std::min(MinWidth, @@ -6562,12 +6764,13 @@ LoopVectorizationCostModel::calculateRegisterUsage(ArrayRef<unsigned> VFs) {          MaxUsages[j] = std::max(MaxUsages[j], OpenIntervals.size());          continue;        } - +      collectUniformsAndScalars(VFs[j]);        // Count the number of live intervals.        
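// Standalone illustration (not part of the patch above) of the selection loop
// in selectVectorizationFactor: each candidate width's expected cost is
// divided by the width, giving a cost per scalar iteration's worth of work,
// and the cheapest width wins. The cost table is invented for the example; the
// real code also special-cases user-forced vectorization.
#include <cstdio>
#include <map>

int main() {
  // Pretend expectedCost(VF).first returned these values.
  std::map<unsigned, float> ExpectedCost = {
      {1, 10.f}, {2, 12.f}, {4, 20.f}, {8, 48.f}};
  const unsigned MaxVF = 8;

  unsigned Width = 1;
  float Cost = ExpectedCost[1];                 // scalar cost per iteration
  for (unsigned VF = 2; VF <= MaxVF; VF *= 2) {
    float VectorCost = ExpectedCost[VF] / float(VF);
    std::printf("VF=%u costs %.2f per scalar iteration\n", VF, VectorCost);
    if (VectorCost < Cost) {
      Width = VF;
      Cost = VectorCost;
    }
  }
  std::printf("selected VF: %u\n", Width);      // 4: 20 / 4 = 5.0 is cheapest
  return 0;
}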
unsigned RegUsage = 0;        for (auto Inst : OpenIntervals) {          // Skip ignored values for VF > 1. -        if (VecValuesToIgnore.count(Inst)) +        if (VecValuesToIgnore.count(Inst) || +            isScalarAfterVectorization(Inst, VFs[j]))            continue;          RegUsage += GetRegUsage(Inst->getType(), VFs[j]);        } @@ -6628,6 +6831,9 @@ void LoopVectorizationCostModel::collectInstsToScalarize(unsigned VF) {          ScalarCostsTy ScalarCosts;          if (computePredInstDiscount(&I, ScalarCosts, VF) >= 0)            ScalarCostsVF.insert(ScalarCosts.begin(), ScalarCosts.end()); + +        // Remember that BB will remain after vectorization. +        PredicatedBBsAfterVectorization.insert(BB);        }    }  } @@ -6636,7 +6842,7 @@ int LoopVectorizationCostModel::computePredInstDiscount(      Instruction *PredInst, DenseMap<Instruction *, unsigned> &ScalarCosts,      unsigned VF) { -  assert(!Legal->isUniformAfterVectorization(PredInst) && +  assert(!isUniformAfterVectorization(PredInst, VF) &&           "Instruction marked uniform-after-vectorization will be predicated");    // Initialize the discount to zero, meaning that the scalar version and the @@ -6657,7 +6863,7 @@ int LoopVectorizationCostModel::computePredInstDiscount(      // already be scalar to avoid traversing chains that are unlikely to be      // beneficial.      if (!I->hasOneUse() || PredInst->getParent() != I->getParent() || -        Legal->isScalarAfterVectorization(I)) +        isScalarAfterVectorization(I, VF))        return false;      // If the instruction is scalar with predication, it will be analyzed @@ -6677,7 +6883,7 @@ int LoopVectorizationCostModel::computePredInstDiscount(      // the lane zero values for uniforms rather than asserting.      for (Use &U : I->operands())        if (auto *J = dyn_cast<Instruction>(U.get())) -        if (Legal->isUniformAfterVectorization(J)) +        if (isUniformAfterVectorization(J, VF))            return false;      // Otherwise, we can scalarize the instruction. @@ -6690,7 +6896,7 @@ int LoopVectorizationCostModel::computePredInstDiscount(    // and their return values are inserted into vectors. Thus, an extract would    // still be required.    auto needsExtract = [&](Instruction *I) -> bool { -    return TheLoop->contains(I) && !Legal->isScalarAfterVectorization(I); +    return TheLoop->contains(I) && !isScalarAfterVectorization(I, VF);    };    // Compute the expected cost discount from scalarizing the entire expression @@ -6717,8 +6923,8 @@ int LoopVectorizationCostModel::computePredInstDiscount(      // Compute the scalarization overhead of needed insertelement instructions      // and phi nodes.      
if (Legal->isScalarWithPredication(I) && !I->getType()->isVoidTy()) { -      ScalarCost += getScalarizationOverhead(ToVectorTy(I->getType(), VF), true, -                                             false, TTI); +      ScalarCost += TTI.getScalarizationOverhead(ToVectorTy(I->getType(), VF), +                                                 true, false);        ScalarCost += VF * TTI.getCFInstrCost(Instruction::PHI);      } @@ -6733,8 +6939,8 @@ int LoopVectorizationCostModel::computePredInstDiscount(          if (canBeScalarized(J))            Worklist.push_back(J);          else if (needsExtract(J)) -          ScalarCost += getScalarizationOverhead(ToVectorTy(J->getType(), VF), -                                                 false, true, TTI); +          ScalarCost += TTI.getScalarizationOverhead( +                              ToVectorTy(J->getType(),VF), false, true);        }      // Scale the total scalar cost by block probability. @@ -6753,6 +6959,9 @@ LoopVectorizationCostModel::VectorizationCostTy  LoopVectorizationCostModel::expectedCost(unsigned VF) {    VectorizationCostTy Cost; +  // Collect Uniform and Scalar instructions after vectorization with VF. +  collectUniformsAndScalars(VF); +    // Collect the instructions (and their associated costs) that will be more    // profitable to scalarize.    collectInstsToScalarize(VF); @@ -6832,11 +7041,141 @@ static bool isStrideMul(Instruction *I, LoopVectorizationLegality *Legal) {           Legal->hasStride(I->getOperand(1));  } +unsigned LoopVectorizationCostModel::getMemInstScalarizationCost(Instruction *I, +                                                                 unsigned VF) { +  Type *ValTy = getMemInstValueType(I); +  auto SE = PSE.getSE(); + +  unsigned Alignment = getMemInstAlignment(I); +  unsigned AS = getMemInstAddressSpace(I); +  Value *Ptr = getPointerOperand(I); +  Type *PtrTy = ToVectorTy(Ptr->getType(), VF); + +  // Figure out whether the access is strided and get the stride value +  // if it's known in compile time +  const SCEV *PtrSCEV = getAddressAccessSCEV(Ptr, Legal, SE, TheLoop); + +  // Get the cost of the scalar memory instruction and address computation. +  unsigned Cost = VF * TTI.getAddressComputationCost(PtrTy, SE, PtrSCEV); + +  Cost += VF * +          TTI.getMemoryOpCost(I->getOpcode(), ValTy->getScalarType(), Alignment, +                              AS, I); + +  // Get the overhead of the extractelement and insertelement instructions +  // we might create due to scalarization. +  Cost += getScalarizationOverhead(I, VF, TTI); + +  // If we have a predicated store, it may not be executed for each vector +  // lane. Scale the cost by the probability of executing the predicated +  // block. 
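Both the predication discount above and the scalarization cost being computed here scale predicated work by the likelihood that the guarded block actually executes. A standalone arithmetic sketch of that scaling, assuming (an assumption of this sketch, suggested only by the helper's name) that a predicated block runs on roughly half of the iterations, i.e. a reciprocal probability of 2:

// Reciprocal of the assumed execution probability of a predicated block.
constexpr unsigned ReciprocalPredBlockProbSketch = 2;

// Cost of a scalarized, predicated memory access at a given VF: per-lane
// address computation and memory-op costs plus extract/insert overhead,
// scaled down because the guarded block does not run on every iteration.
unsigned predicatedScalarCostSketch(unsigned VF, unsigned AddrCost,
                                    unsigned MemOpCost,
                                    unsigned ExtractInsertCost) {
  unsigned Cost = VF * (AddrCost + MemOpCost) + ExtractInsertCost;
  return Cost / ReciprocalPredBlockProbSketch;
}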
+  if (Legal->isScalarWithPredication(I)) +    Cost /= getReciprocalPredBlockProb(); + +  return Cost; +} + +unsigned LoopVectorizationCostModel::getConsecutiveMemOpCost(Instruction *I, +                                                             unsigned VF) { +  Type *ValTy = getMemInstValueType(I); +  Type *VectorTy = ToVectorTy(ValTy, VF); +  unsigned Alignment = getMemInstAlignment(I); +  Value *Ptr = getPointerOperand(I); +  unsigned AS = getMemInstAddressSpace(I); +  int ConsecutiveStride = Legal->isConsecutivePtr(Ptr); + +  assert((ConsecutiveStride == 1 || ConsecutiveStride == -1) && +         "Stride should be 1 or -1 for consecutive memory access"); +  unsigned Cost = 0; +  if (Legal->isMaskRequired(I)) +    Cost += TTI.getMaskedMemoryOpCost(I->getOpcode(), VectorTy, Alignment, AS); +  else +    Cost += TTI.getMemoryOpCost(I->getOpcode(), VectorTy, Alignment, AS, I); + +  bool Reverse = ConsecutiveStride < 0; +  if (Reverse) +    Cost += TTI.getShuffleCost(TargetTransformInfo::SK_Reverse, VectorTy, 0); +  return Cost; +} + +unsigned LoopVectorizationCostModel::getUniformMemOpCost(Instruction *I, +                                                         unsigned VF) { +  LoadInst *LI = cast<LoadInst>(I); +  Type *ValTy = LI->getType(); +  Type *VectorTy = ToVectorTy(ValTy, VF); +  unsigned Alignment = LI->getAlignment(); +  unsigned AS = LI->getPointerAddressSpace(); + +  return TTI.getAddressComputationCost(ValTy) + +         TTI.getMemoryOpCost(Instruction::Load, ValTy, Alignment, AS) + +         TTI.getShuffleCost(TargetTransformInfo::SK_Broadcast, VectorTy); +} + +unsigned LoopVectorizationCostModel::getGatherScatterCost(Instruction *I, +                                                          unsigned VF) { +  Type *ValTy = getMemInstValueType(I); +  Type *VectorTy = ToVectorTy(ValTy, VF); +  unsigned Alignment = getMemInstAlignment(I); +  Value *Ptr = getPointerOperand(I); + +  return TTI.getAddressComputationCost(VectorTy) + +         TTI.getGatherScatterOpCost(I->getOpcode(), VectorTy, Ptr, +                                    Legal->isMaskRequired(I), Alignment); +} + +unsigned LoopVectorizationCostModel::getInterleaveGroupCost(Instruction *I, +                                                            unsigned VF) { +  Type *ValTy = getMemInstValueType(I); +  Type *VectorTy = ToVectorTy(ValTy, VF); +  unsigned AS = getMemInstAddressSpace(I); + +  auto Group = Legal->getInterleavedAccessGroup(I); +  assert(Group && "Fail to get an interleaved access group."); + +  unsigned InterleaveFactor = Group->getFactor(); +  Type *WideVecTy = VectorType::get(ValTy, VF * InterleaveFactor); + +  // Holds the indices of existing members in an interleaved load group. +  // An interleaved store group doesn't need this as it doesn't allow gaps. +  SmallVector<unsigned, 4> Indices; +  if (isa<LoadInst>(I)) { +    for (unsigned i = 0; i < InterleaveFactor; i++) +      if (Group->getMember(i)) +        Indices.push_back(i); +  } + +  // Calculate the cost of the whole interleaved group. 
+  unsigned Cost = TTI.getInterleavedMemoryOpCost(I->getOpcode(), WideVecTy, +                                                 Group->getFactor(), Indices, +                                                 Group->getAlignment(), AS); + +  if (Group->isReverse()) +    Cost += Group->getNumMembers() * +            TTI.getShuffleCost(TargetTransformInfo::SK_Reverse, VectorTy, 0); +  return Cost; +} + +unsigned LoopVectorizationCostModel::getMemoryInstructionCost(Instruction *I, +                                                              unsigned VF) { + +  // Calculate scalar cost only. Vectorization cost should be ready at this +  // moment. +  if (VF == 1) { +    Type *ValTy = getMemInstValueType(I); +    unsigned Alignment = getMemInstAlignment(I); +    unsigned AS = getMemInstAlignment(I); + +    return TTI.getAddressComputationCost(ValTy) + +           TTI.getMemoryOpCost(I->getOpcode(), ValTy, Alignment, AS, I); +  } +  return getWideningCost(I, VF); +} +  LoopVectorizationCostModel::VectorizationCostTy  LoopVectorizationCostModel::getInstructionCost(Instruction *I, unsigned VF) {    // If we know that this instruction will remain uniform, check the cost of    // the scalar version. -  if (Legal->isUniformAfterVectorization(I)) +  if (isUniformAfterVectorization(I, VF))      VF = 1;    if (VF > 1 && isProfitableToScalarize(I, VF)) @@ -6850,6 +7189,79 @@ LoopVectorizationCostModel::getInstructionCost(Instruction *I, unsigned VF) {    return VectorizationCostTy(C, TypeNotScalarized);  } +void LoopVectorizationCostModel::setCostBasedWideningDecision(unsigned VF) { +  if (VF == 1) +    return; +  for (BasicBlock *BB : TheLoop->blocks()) { +    // For each instruction in the old loop. +    for (Instruction &I : *BB) { +      Value *Ptr = getPointerOperand(&I); +      if (!Ptr) +        continue; + +      if (isa<LoadInst>(&I) && Legal->isUniform(Ptr)) { +        // Scalar load + broadcast +        unsigned Cost = getUniformMemOpCost(&I, VF); +        setWideningDecision(&I, VF, CM_Scalarize, Cost); +        continue; +      } + +      // We assume that widening is the best solution when possible. +      if (Legal->memoryInstructionCanBeWidened(&I, VF)) { +        unsigned Cost = getConsecutiveMemOpCost(&I, VF); +        setWideningDecision(&I, VF, CM_Widen, Cost); +        continue; +      } + +      // Choose between Interleaving, Gather/Scatter or Scalarization. +      unsigned InterleaveCost = UINT_MAX; +      unsigned NumAccesses = 1; +      if (Legal->isAccessInterleaved(&I)) { +        auto Group = Legal->getInterleavedAccessGroup(&I); +        assert(Group && "Fail to get an interleaved access group."); + +        // Make one decision for the whole group. +        if (getWideningDecision(&I, VF) != CM_Unknown) +          continue; + +        NumAccesses = Group->getNumMembers(); +        InterleaveCost = getInterleaveGroupCost(&I, VF); +      } + +      unsigned GatherScatterCost = +          Legal->isLegalGatherOrScatter(&I) +              ? getGatherScatterCost(&I, VF) * NumAccesses +              : UINT_MAX; + +      unsigned ScalarizationCost = +          getMemInstScalarizationCost(&I, VF) * NumAccesses; + +      // Choose better solution for the current VF, +      // write down this decision and use it during vectorization. 
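The comment above summarizes the selection: once widening has been ruled out, the cost model records the cheapest of interleaving, gather/scatter, and scalarization as the per-instruction (or per-group) decision. A standalone sketch of that choice; the enum and names are stand-ins for the patch's InstWidening machinery, and a very large sentinel cost (UINT_MAX in the patch) marks a strategy that is not legal for the access:

enum class WideningSketch { Interleave, GatherScatter, Scalarize };

struct DecisionSketch {
  WideningSketch Kind;
  unsigned Cost;
};

// Pick the cheapest legal strategy, preferring interleaving on ties with
// gather/scatter, exactly mirroring the comparison in the code that follows.
DecisionSketch chooseWideningSketch(unsigned InterleaveCost,
                                    unsigned GatherScatterCost,
                                    unsigned ScalarizationCost) {
  if (InterleaveCost <= GatherScatterCost && InterleaveCost < ScalarizationCost)
    return {WideningSketch::Interleave, InterleaveCost};
  if (GatherScatterCost < ScalarizationCost)
    return {WideningSketch::GatherScatter, GatherScatterCost};
  return {WideningSketch::Scalarize, ScalarizationCost};
}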
+      unsigned Cost; +      InstWidening Decision; +      if (InterleaveCost <= GatherScatterCost && +          InterleaveCost < ScalarizationCost) { +        Decision = CM_Interleave; +        Cost = InterleaveCost; +      } else if (GatherScatterCost < ScalarizationCost) { +        Decision = CM_GatherScatter; +        Cost = GatherScatterCost; +      } else { +        Decision = CM_Scalarize; +        Cost = ScalarizationCost; +      } +      // If the instructions belongs to an interleave group, the whole group +      // receives the same decision. The whole group receives the cost, but +      // the cost will actually be assigned to one instruction. +      if (auto Group = Legal->getInterleavedAccessGroup(&I)) +        setWideningDecision(Group, VF, Decision, Cost); +      else +        setWideningDecision(&I, VF, Decision, Cost); +    } +  } +} +  unsigned LoopVectorizationCostModel::getInstructionCost(Instruction *I,                                                          unsigned VF,                                                          Type *&VectorTy) { @@ -6868,7 +7280,31 @@ unsigned LoopVectorizationCostModel::getInstructionCost(Instruction *I,      // instruction cost.      return 0;    case Instruction::Br: { -    return TTI.getCFInstrCost(I->getOpcode()); +    // In cases of scalarized and predicated instructions, there will be VF +    // predicated blocks in the vectorized loop. Each branch around these +    // blocks requires also an extract of its vector compare i1 element. +    bool ScalarPredicatedBB = false; +    BranchInst *BI = cast<BranchInst>(I); +    if (VF > 1 && BI->isConditional() && +        (PredicatedBBsAfterVectorization.count(BI->getSuccessor(0)) || +         PredicatedBBsAfterVectorization.count(BI->getSuccessor(1)))) +      ScalarPredicatedBB = true; + +    if (ScalarPredicatedBB) { +      // Return cost for branches around scalarized and predicated blocks. +      Type *Vec_i1Ty = +          VectorType::get(IntegerType::getInt1Ty(RetTy->getContext()), VF); +      return (TTI.getScalarizationOverhead(Vec_i1Ty, false, true) + +              (TTI.getCFInstrCost(Instruction::Br) * VF)); +    } else if (I->getParent() == TheLoop->getLoopLatch() || VF == 1) +      // The back-edge branch will remain, as will all scalar branches. +      return TTI.getCFInstrCost(Instruction::Br); +    else +      // This branch will be eliminated by if-conversion. +      return 0; +    // Note: We currently assume zero cost for an unconditional branch inside +    // a predicated block since it will become a fall-through, although we +    // may decide in the future to call TTI for all branches.    
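The branch-cost logic just added distinguishes three cases: branches around scalarized, predicated blocks, the loop latch (and all branches at VF == 1), and branches removed by if-conversion. A minimal standalone sketch, approximating the i1 scalarization overhead as one extract per lane and replacing the two TTI queries with plain parameters:

// Cost of the VF branches that guard a scalarized, predicated block, plus the
// extracts needed to pull each lane's i1 condition out of the vector compare.
unsigned predicatedBranchCostSketch(unsigned VF, unsigned ExtractI1Cost,
                                    unsigned ScalarBranchCost) {
  return VF * ExtractI1Cost + VF * ScalarBranchCost;
}

// The latch's back-edge branch (and every branch at VF == 1) keeps its scalar
// cost; branches eliminated by if-conversion cost nothing.
unsigned plainBranchCostSketch(bool IsLatchOrScalar, unsigned ScalarBranchCost) {
  return IsLatchOrScalar ? ScalarBranchCost : 0;
}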
}    case Instruction::PHI: {      auto *Phi = cast<PHINode>(I); @@ -6969,7 +7405,7 @@ unsigned LoopVectorizationCostModel::getInstructionCost(Instruction *I,      if (!ScalarCond)        CondTy = VectorType::get(CondTy, VF); -    return TTI.getCmpSelInstrCost(I->getOpcode(), VectorTy, CondTy); +    return TTI.getCmpSelInstrCost(I->getOpcode(), VectorTy, CondTy, I);    }    case Instruction::ICmp:    case Instruction::FCmp: { @@ -6978,130 +7414,12 @@ unsigned LoopVectorizationCostModel::getInstructionCost(Instruction *I,      if (canTruncateToMinimalBitwidth(Op0AsInstruction, VF))        ValTy = IntegerType::get(ValTy->getContext(), MinBWs[Op0AsInstruction]);      VectorTy = ToVectorTy(ValTy, VF); -    return TTI.getCmpSelInstrCost(I->getOpcode(), VectorTy); +    return TTI.getCmpSelInstrCost(I->getOpcode(), VectorTy, nullptr, I);    }    case Instruction::Store:    case Instruction::Load: { -    StoreInst *SI = dyn_cast<StoreInst>(I); -    LoadInst *LI = dyn_cast<LoadInst>(I); -    Type *ValTy = (SI ? SI->getValueOperand()->getType() : LI->getType()); -    VectorTy = ToVectorTy(ValTy, VF); - -    unsigned Alignment = SI ? SI->getAlignment() : LI->getAlignment(); -    unsigned AS = -        SI ? SI->getPointerAddressSpace() : LI->getPointerAddressSpace(); -    Value *Ptr = getPointerOperand(I); -    // We add the cost of address computation here instead of with the gep -    // instruction because only here we know whether the operation is -    // scalarized. -    if (VF == 1) -      return TTI.getAddressComputationCost(VectorTy) + -             TTI.getMemoryOpCost(I->getOpcode(), VectorTy, Alignment, AS); - -    if (LI && Legal->isUniform(Ptr)) { -      // Scalar load + broadcast -      unsigned Cost = TTI.getAddressComputationCost(ValTy->getScalarType()); -      Cost += TTI.getMemoryOpCost(I->getOpcode(), ValTy->getScalarType(), -                                  Alignment, AS); -      return Cost + -             TTI.getShuffleCost(TargetTransformInfo::SK_Broadcast, ValTy); -    } - -    // For an interleaved access, calculate the total cost of the whole -    // interleave group. -    if (Legal->isAccessInterleaved(I)) { -      auto Group = Legal->getInterleavedAccessGroup(I); -      assert(Group && "Fail to get an interleaved access group."); - -      // Only calculate the cost once at the insert position. -      if (Group->getInsertPos() != I) -        return 0; - -      unsigned InterleaveFactor = Group->getFactor(); -      Type *WideVecTy = -          VectorType::get(VectorTy->getVectorElementType(), -                          VectorTy->getVectorNumElements() * InterleaveFactor); - -      // Holds the indices of existing members in an interleaved load group. -      // An interleaved store group doesn't need this as it doesn't allow gaps. -      SmallVector<unsigned, 4> Indices; -      if (LI) { -        for (unsigned i = 0; i < InterleaveFactor; i++) -          if (Group->getMember(i)) -            Indices.push_back(i); -      } - -      // Calculate the cost of the whole interleaved group. -      unsigned Cost = TTI.getInterleavedMemoryOpCost( -          I->getOpcode(), WideVecTy, Group->getFactor(), Indices, -          Group->getAlignment(), AS); - -      if (Group->isReverse()) -        Cost += -            Group->getNumMembers() * -            TTI.getShuffleCost(TargetTransformInfo::SK_Reverse, VectorTy, 0); - -      // FIXME: The interleaved load group with a huge gap could be even more -      // expensive than scalar operations. 
Then we could ignore such group and -      // use scalar operations instead. -      return Cost; -    } - -    // Check if the memory instruction will be scalarized. -    if (Legal->memoryInstructionMustBeScalarized(I, VF)) { -      unsigned Cost = 0; -      Type *PtrTy = ToVectorTy(Ptr->getType(), VF); - -      // Figure out whether the access is strided and get the stride value -      // if it's known in compile time -      const SCEV *PtrSCEV = getAddressAccessSCEV(Ptr, Legal, SE, TheLoop);  - -      // Get the cost of the scalar memory instruction and address computation. -      Cost += VF * TTI.getAddressComputationCost(PtrTy, SE, PtrSCEV); -      Cost += VF * -              TTI.getMemoryOpCost(I->getOpcode(), ValTy->getScalarType(), -                                  Alignment, AS); - -      // Get the overhead of the extractelement and insertelement instructions -      // we might create due to scalarization. -      Cost += getScalarizationOverhead(I, VF, TTI); - -      // If we have a predicated store, it may not be executed for each vector -      // lane. Scale the cost by the probability of executing the predicated -      // block. -      if (Legal->isScalarWithPredication(I)) -        Cost /= getReciprocalPredBlockProb(); - -      return Cost; -    } - -    // Determine if the pointer operand of the access is either consecutive or -    // reverse consecutive. -    int ConsecutiveStride = Legal->isConsecutivePtr(Ptr); -    bool Reverse = ConsecutiveStride < 0; - -    // Determine if either a gather or scatter operation is legal. -    bool UseGatherOrScatter = -        !ConsecutiveStride && Legal->isLegalGatherOrScatter(I); - -    unsigned Cost = TTI.getAddressComputationCost(VectorTy); -    if (UseGatherOrScatter) { -      assert(ConsecutiveStride == 0 && -             "Gather/Scatter are not used for consecutive stride"); -      return Cost + -             TTI.getGatherScatterOpCost(I->getOpcode(), VectorTy, Ptr, -                                        Legal->isMaskRequired(I), Alignment); -    } -    // Wide load/stores. -    if (Legal->isMaskRequired(I)) -      Cost += -          TTI.getMaskedMemoryOpCost(I->getOpcode(), VectorTy, Alignment, AS); -    else -      Cost += TTI.getMemoryOpCost(I->getOpcode(), VectorTy, Alignment, AS); - -    if (Reverse) -      Cost += TTI.getShuffleCost(TargetTransformInfo::SK_Reverse, VectorTy, 0); -    return Cost; +    VectorTy = ToVectorTy(getMemInstValueType(I), VF); +    return getMemoryInstructionCost(I, VF);    }    case Instruction::ZExt:    case Instruction::SExt: @@ -7115,12 +7433,14 @@ unsigned LoopVectorizationCostModel::getInstructionCost(Instruction *I,    case Instruction::Trunc:    case Instruction::FPTrunc:    case Instruction::BitCast: { -    // We optimize the truncation of induction variable. -    // The cost of these is the same as the scalar operation. -    if (I->getOpcode() == Instruction::Trunc && -        Legal->isInductionVariable(I->getOperand(0))) -      return TTI.getCastInstrCost(I->getOpcode(), I->getType(), -                                  I->getOperand(0)->getType()); +    // We optimize the truncation of induction variables having constant +    // integer steps. The cost of these truncations is the same as the scalar +    // operation. 
+    if (isOptimizableIVTruncate(I, VF)) { +      auto *Trunc = cast<TruncInst>(I); +      return TTI.getCastInstrCost(Instruction::Trunc, Trunc->getDestTy(), +                                  Trunc->getSrcTy(), Trunc); +    }      Type *SrcScalarTy = I->getOperand(0)->getType();      Type *SrcVecTy = ToVectorTy(SrcScalarTy, VF); @@ -7143,7 +7463,7 @@ unsigned LoopVectorizationCostModel::getInstructionCost(Instruction *I,        }      } -    return TTI.getCastInstrCost(I->getOpcode(), VectorTy, SrcVecTy); +    return TTI.getCastInstrCost(I->getOpcode(), VectorTy, SrcVecTy, I);    }    case Instruction::Call: {      bool NeedToScalarize; @@ -7172,9 +7492,7 @@ INITIALIZE_PASS_DEPENDENCY(AssumptionCacheTracker)  INITIALIZE_PASS_DEPENDENCY(BlockFrequencyInfoWrapperPass)  INITIALIZE_PASS_DEPENDENCY(DominatorTreeWrapperPass)  INITIALIZE_PASS_DEPENDENCY(ScalarEvolutionWrapperPass) -INITIALIZE_PASS_DEPENDENCY(LCSSAWrapperPass)  INITIALIZE_PASS_DEPENDENCY(LoopInfoWrapperPass) -INITIALIZE_PASS_DEPENDENCY(LoopSimplify)  INITIALIZE_PASS_DEPENDENCY(LoopAccessLegacyAnalysis)  INITIALIZE_PASS_DEPENDENCY(DemandedBitsWrapperPass)  INITIALIZE_PASS_DEPENDENCY(OptimizationRemarkEmitterWrapperPass) @@ -7206,81 +7524,34 @@ void LoopVectorizationCostModel::collectValuesToIgnore() {      SmallPtrSetImpl<Instruction *> &Casts = RedDes.getCastInsts();      VecValuesToIgnore.insert(Casts.begin(), Casts.end());    } - -  // Insert values known to be scalar into VecValuesToIgnore. This is a -  // conservative estimation of the values that will later be scalarized. -  // -  // FIXME: Even though an instruction is not scalar-after-vectoriztion, it may -  //        still be scalarized. For example, we may find an instruction to be -  //        more profitable for a given vectorization factor if it were to be -  //        scalarized. But at this point, we haven't yet computed the -  //        vectorization factor. -  for (auto *BB : TheLoop->getBlocks()) -    for (auto &I : *BB) -      if (Legal->isScalarAfterVectorization(&I)) -        VecValuesToIgnore.insert(&I);  } -void InnerLoopUnroller::scalarizeInstruction(Instruction *Instr, -                                             bool IfPredicateInstr) { -  assert(!Instr->getType()->isAggregateType() && "Can't handle vectors"); -  // Holds vector parameters or scalars, in case of uniform vals. -  SmallVector<VectorParts, 4> Params; - -  setDebugLocFromInst(Builder, Instr); - -  // Does this instruction return a value ? -  bool IsVoidRetTy = Instr->getType()->isVoidTy(); - -  // Initialize a new scalar map entry. -  ScalarParts Entry(UF); - -  VectorParts Cond; -  if (IfPredicateInstr) -    Cond = createBlockInMask(Instr->getParent()); - -  // For each vector unroll 'part': -  for (unsigned Part = 0; Part < UF; ++Part) { -    Entry[Part].resize(1); -    // For each scalar that we create: - -    // Start an "if (pred) a[i] = ..." block. -    Value *Cmp = nullptr; -    if (IfPredicateInstr) { -      if (Cond[Part]->getType()->isVectorTy()) -        Cond[Part] = -            Builder.CreateExtractElement(Cond[Part], Builder.getInt32(0)); -      Cmp = Builder.CreateICmp(ICmpInst::ICMP_EQ, Cond[Part], -                               ConstantInt::get(Cond[Part]->getType(), 1)); -    } - -    Instruction *Cloned = Instr->clone(); -    if (!IsVoidRetTy) -      Cloned->setName(Instr->getName() + ".cloned"); - -    // Replace the operands of the cloned instructions with their scalar -    // equivalents in the new loop. 
-    for (unsigned op = 0, e = Instr->getNumOperands(); op != e; ++op) { -      auto *NewOp = getScalarValue(Instr->getOperand(op), Part, 0); -      Cloned->setOperand(op, NewOp); -    } +LoopVectorizationCostModel::VectorizationFactor +LoopVectorizationPlanner::plan(bool OptForSize, unsigned UserVF) { -    // Place the cloned scalar in the new loop. -    Builder.Insert(Cloned); +  // Width 1 means no vectorize, cost 0 means uncomputed cost. +  const LoopVectorizationCostModel::VectorizationFactor NoVectorization = {1U, +                                                                           0U}; +  Optional<unsigned> MaybeMaxVF = CM.computeMaxVF(OptForSize); +  if (!MaybeMaxVF.hasValue()) // Cases considered too costly to vectorize. +    return NoVectorization; -    // Add the cloned scalar to the scalar map entry. -    Entry[Part][0] = Cloned; +  if (UserVF) { +    DEBUG(dbgs() << "LV: Using user VF " << UserVF << ".\n"); +    assert(isPowerOf2_32(UserVF) && "VF needs to be a power of two"); +    // Collect the instructions (and their associated costs) that will be more +    // profitable to scalarize. +    CM.selectUserVectorizationFactor(UserVF); +    return {UserVF, 0}; +  } -    // If we just cloned a new assumption, add it the assumption cache. -    if (auto *II = dyn_cast<IntrinsicInst>(Cloned)) -      if (II->getIntrinsicID() == Intrinsic::assume) -        AC->registerAssumption(II); +  unsigned MaxVF = MaybeMaxVF.getValue(); +  assert(MaxVF != 0 && "MaxVF is zero."); +  if (MaxVF == 1) +    return NoVectorization; -    // End if-block. -    if (IfPredicateInstr) -      PredicatedInstructions.push_back(std::make_pair(Cloned, Cmp)); -  } -  VectorLoopValueMap.initScalar(Instr, Entry); +  // Select the optimal vectorization factor. +  return CM.selectVectorizationFactor(MaxVF);  }  void InnerLoopUnroller::vectorizeMemoryInstruction(Instruction *Instr) { @@ -7414,11 +7685,6 @@ bool LoopVectorizePass::processLoop(Loop *L) {      return false;    } -  // Use the cost model. -  LoopVectorizationCostModel CM(L, PSE, LI, &LVL, *TTI, TLI, DB, AC, ORE, F, -                                &Hints); -  CM.collectValuesToIgnore(); -    // Check the function attributes to find out if this function should be    // optimized for size.    bool OptForSize = @@ -7464,9 +7730,20 @@ bool LoopVectorizePass::processLoop(Loop *L) {      return false;    } -  // Select the optimal vectorization factor. -  const LoopVectorizationCostModel::VectorizationFactor VF = -      CM.selectVectorizationFactor(OptForSize); +  // Use the cost model. +  LoopVectorizationCostModel CM(L, PSE, LI, &LVL, *TTI, TLI, DB, AC, ORE, F, +                                &Hints); +  CM.collectValuesToIgnore(); + +  // Use the planner for vectorization. +  LoopVectorizationPlanner LVP(CM); + +  // Get user vectorization factor. +  unsigned UserVF = Hints.getWidth(); + +  // Plan how to best vectorize, return the best VF and its cost. +  LoopVectorizationCostModel::VectorizationFactor VF = +      LVP.plan(OptForSize, UserVF);    // Select the interleave count.    unsigned IC = CM.selectInterleaveCount(OptForSize, VF.Width, VF.Cost); @@ -7522,10 +7799,10 @@ bool LoopVectorizePass::processLoop(Loop *L) {    const char *VAPassName = Hints.vectorizeAnalysisPassName();    if (!VectorizeLoop && !InterleaveLoop) {      // Do not vectorize or interleaving the loop. 
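The planner introduced above (LoopVectorizationPlanner::plan) fixes the ordering of the decisions: compute the maximum feasible VF first, honor an explicit user-requested VF, and only then run the cost comparison. A standalone sketch of that flow, with the cost-model calls abstracted into callbacks and all names hypothetical:

#include <functional>
#include <optional>

struct VFResult { unsigned Width; unsigned Cost; }; // Width 1 means "do not vectorize"

VFResult planSketch(unsigned UserVF,
                    const std::function<std::optional<unsigned>()> &ComputeMaxVF,
                    const std::function<VFResult(unsigned)> &SelectVF) {
  const VFResult NoVectorization = {1, 0};
  std::optional<unsigned> MaxVF = ComputeMaxVF(); // e.g. -Os/-Oz constraints
  if (!MaxVF)
    return NoVectorization;
  if (UserVF != 0)          // an explicit user VF short-circuits the search
    return {UserVF, 0};
  if (*MaxVF == 1)
    return NoVectorization;
  return SelectVF(*MaxVF);  // pick the most profitable width up to MaxVF
}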
-    ORE->emit(OptimizationRemarkAnalysis(VAPassName, VecDiagMsg.first, +    ORE->emit(OptimizationRemarkMissed(VAPassName, VecDiagMsg.first,                                           L->getStartLoc(), L->getHeader())                << VecDiagMsg.second); -    ORE->emit(OptimizationRemarkAnalysis(LV_NAME, IntDiagMsg.first, +    ORE->emit(OptimizationRemarkMissed(LV_NAME, IntDiagMsg.first,                                           L->getStartLoc(), L->getHeader())                << IntDiagMsg.second);      return false; @@ -7621,6 +7898,16 @@ bool LoopVectorizePass::runImpl(    if (!TTI->getNumberOfRegisters(true) && TTI->getMaxInterleaveFactor(1) < 2)      return false; +  bool Changed = false; + +  // The vectorizer requires loops to be in simplified form. +  // Since simplification may add new inner loops, it has to run before the +  // legality and profitability checks. This means running the loop vectorizer +  // will simplify all loops, regardless of whether anything end up being +  // vectorized. +  for (auto &L : *LI) +    Changed |= simplifyLoop(L, DT, LI, SE, AC, false /* PreserveLCSSA */); +    // Build up a worklist of inner-loops to vectorize. This is necessary as    // the act of vectorizing or partially unrolling a loop creates new loops    // and can invalidate iterators across the loops. @@ -7632,9 +7919,15 @@ bool LoopVectorizePass::runImpl(    LoopsAnalyzed += Worklist.size();    // Now walk the identified inner loops. -  bool Changed = false; -  while (!Worklist.empty()) -    Changed |= processLoop(Worklist.pop_back_val()); +  while (!Worklist.empty()) { +    Loop *L = Worklist.pop_back_val(); + +    // For the inner loops we actually process, form LCSSA to simplify the +    // transform. +    Changed |= formLCSSARecursively(*L, *DT, LI, SE); + +    Changed |= processLoop(L); +  }    // Process each loop nest in the function.    return Changed; diff --git a/lib/Transforms/Vectorize/SLPVectorizer.cpp b/lib/Transforms/Vectorize/SLPVectorizer.cpp index 328f27002960..da3ac06ab464 100644 --- a/lib/Transforms/Vectorize/SLPVectorizer.cpp +++ b/lib/Transforms/Vectorize/SLPVectorizer.cpp @@ -39,6 +39,7 @@  #include "llvm/Pass.h"  #include "llvm/Support/CommandLine.h"  #include "llvm/Support/Debug.h" +#include "llvm/Support/GraphWriter.h"  #include "llvm/Support/raw_ostream.h"  #include "llvm/Transforms/Vectorize.h"  #include <algorithm> @@ -90,6 +91,10 @@ static cl::opt<unsigned> MinTreeSize(      "slp-min-tree-size", cl::init(3), cl::Hidden,      cl::desc("Only vectorize small trees if they are fully vectorizable")); +static cl::opt<bool> +    ViewSLPTree("view-slp-tree", cl::Hidden, +                cl::desc("Display the SLP trees with Graphviz")); +  // Limit the number of alias checks. The limit is chosen so that  // it has no negative effect on the llvm benchmarks.  static const unsigned AliasedCheckLimit = 10; @@ -212,14 +217,14 @@ static unsigned getSameOpcode(ArrayRef<Value *> VL) {  /// Flag set: NSW, NUW, exact, and all of fast-math.  static void propagateIRFlags(Value *I, ArrayRef<Value *> VL) {    if (auto *VecOp = dyn_cast<Instruction>(I)) { -    if (auto *Intersection = dyn_cast<Instruction>(VL[0])) { -      // Intersection is initialized to the 0th scalar, -      // so start counting from index '1'. +    if (auto *I0 = dyn_cast<Instruction>(VL[0])) { +      // VecOVp is initialized to the 0th scalar, so start counting from index +      // '1'. 
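The corrected flag propagation that follows keeps the intersection on the vector instruction itself: copy the flags of the first scalar onto the vector op, then and-in each remaining scalar, instead of mutating the first scalar and copying at the end. A standalone sketch of that copy-then-intersect pattern, using a plain bitmask in place of IR flags (names hypothetical):

#include <cstddef>
#include <cstdint>
#include <vector>

// Stand-in for an instruction's wrap/exact/fast-math flags.
struct FlagsSketch { std::uint32_t Bits; };

// Give the vector instruction the intersection of all scalar flags without
// modifying any of the scalars. Assumes at least one scalar is present.
FlagsSketch propagateFlagsSketch(const std::vector<FlagsSketch> &Scalars) {
  FlagsSketch Vec = Scalars.front();   // copy flags of the 0th scalar
  for (std::size_t i = 1; i < Scalars.size(); ++i)
    Vec.Bits &= Scalars[i].Bits;       // keep only flags common to all scalars
  return Vec;
}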
+      VecOp->copyIRFlags(I0);        for (int i = 1, e = VL.size(); i < e; ++i) {          if (auto *Scalar = dyn_cast<Instruction>(VL[i])) -          Intersection->andIRFlags(Scalar); +          VecOp->andIRFlags(Scalar);        } -      VecOp->copyIRFlags(Intersection);      }    }  } @@ -304,6 +309,8 @@ public:    typedef SmallVector<Instruction *, 16> InstrList;    typedef SmallPtrSet<Value *, 16> ValueSet;    typedef SmallVector<StoreInst *, 8> StoreList; +  typedef MapVector<Value *, SmallVector<Instruction *, 2>> +      ExtraValueToDebugLocsMap;    BoUpSLP(Function *Func, ScalarEvolution *Se, TargetTransformInfo *Tti,            TargetLibraryInfo *TLi, AliasAnalysis *Aa, LoopInfo *Li, @@ -330,6 +337,10 @@ public:    /// \brief Vectorize the tree that starts with the elements in \p VL.    /// Returns the vectorized root.    Value *vectorizeTree(); +  /// Vectorize the tree but with the list of externally used values \p +  /// ExternallyUsedValues. Values in this MapVector can be replaced but the +  /// generated extractvalue instructions. +  Value *vectorizeTree(ExtraValueToDebugLocsMap &ExternallyUsedValues);    /// \returns the cost incurred by unwanted spills and fills, caused by    /// holding live values over call sites. @@ -343,6 +354,13 @@ public:    /// the purpose of scheduling and extraction in the \p UserIgnoreLst.    void buildTree(ArrayRef<Value *> Roots,                   ArrayRef<Value *> UserIgnoreLst = None); +  /// Construct a vectorizable tree that starts at \p Roots, ignoring users for +  /// the purpose of scheduling and extraction in the \p UserIgnoreLst taking +  /// into account (anf updating it, if required) list of externally used +  /// values stored in \p ExternallyUsedValues. +  void buildTree(ArrayRef<Value *> Roots, +                 ExtraValueToDebugLocsMap &ExternallyUsedValues, +                 ArrayRef<Value *> UserIgnoreLst = None);    /// Clear the internal data structures that are created by 'buildTree'.    void deleteTree() { @@ -404,7 +422,7 @@ private:    int getEntryCost(TreeEntry *E);    /// This is the recursive part of buildTree. -  void buildTree_rec(ArrayRef<Value *> Roots, unsigned Depth); +  void buildTree_rec(ArrayRef<Value *> Roots, unsigned Depth, int);    /// \returns True if the ExtractElement/ExtractValue instructions in VL can    /// be vectorized to use the original vector (or aggregate "bitcast" to a vector). @@ -451,8 +469,9 @@ private:                                        SmallVectorImpl<Value *> &Left,                                        SmallVectorImpl<Value *> &Right);    struct TreeEntry { -    TreeEntry() : Scalars(), VectorizedValue(nullptr), -    NeedToGather(0) {} +    TreeEntry(std::vector<TreeEntry> &Container) +        : Scalars(), VectorizedValue(nullptr), NeedToGather(0), +          Container(Container) {}      /// \returns true if the scalars in VL are equal to this entry.      bool isSame(ArrayRef<Value *> VL) const { @@ -468,11 +487,24 @@ private:      /// Do we need to gather this sequence ?      bool NeedToGather; + +    /// Points back to the VectorizableTree. +    /// +    /// Only used for Graphviz right now.  Unfortunately GraphTrait::NodeRef has +    /// to be a pointer and needs to be able to initialize the child iterator. +    /// Thus we need a reference back to the container to translate the indices +    /// to entries. +    std::vector<TreeEntry> &Container; + +    /// The TreeEntry index containing the user of this entry.  
We can actually +    /// have multiple users so the data structure is not truly a tree. +    SmallVector<int, 1> UserTreeIndices;    };    /// Create a new VectorizableTree entry. -  TreeEntry *newTreeEntry(ArrayRef<Value *> VL, bool Vectorized) { -    VectorizableTree.emplace_back(); +  TreeEntry *newTreeEntry(ArrayRef<Value *> VL, bool Vectorized, +                          int &UserTreeIdx) { +    VectorizableTree.emplace_back(VectorizableTree);      int idx = VectorizableTree.size() - 1;      TreeEntry *Last = &VectorizableTree[idx];      Last->Scalars.insert(Last->Scalars.begin(), VL.begin(), VL.end()); @@ -485,6 +517,10 @@ private:      } else {        MustGather.insert(VL.begin(), VL.end());      } + +    if (UserTreeIdx >= 0) +      Last->UserTreeIndices.push_back(UserTreeIdx); +    UserTreeIdx = idx;      return Last;    } @@ -558,7 +594,9 @@ private:    SmallVector<std::unique_ptr<Instruction>, 8> DeletedInstructions;    /// A list of values that need to extracted out of the tree. -  /// This list holds pairs of (Internal Scalar : External User). +  /// This list holds pairs of (Internal Scalar : External User). External User +  /// can be nullptr, it means that this Internal Scalar will be used later, +  /// after vectorization.    UserList ExternalUses;    /// Values used only by @llvm.assume calls. @@ -706,6 +744,8 @@ private:      return os;    }  #endif +  friend struct GraphTraits<BoUpSLP *>; +  friend struct DOTGraphTraits<BoUpSLP *>;    /// Contains all scheduling data for a basic block.    /// @@ -916,17 +956,98 @@ private:    /// original width.    MapVector<Value *, std::pair<uint64_t, bool>> MinBWs;  }; +} // end namespace slpvectorizer + +template <> struct GraphTraits<BoUpSLP *> { +  typedef BoUpSLP::TreeEntry TreeEntry; + +  /// NodeRef has to be a pointer per the GraphWriter. +  typedef TreeEntry *NodeRef; + +  /// \brief Add the VectorizableTree to the index iterator to be able to return +  /// TreeEntry pointers. +  struct ChildIteratorType +      : public iterator_adaptor_base<ChildIteratorType, +                                     SmallVector<int, 1>::iterator> { + +    std::vector<TreeEntry> &VectorizableTree; + +    ChildIteratorType(SmallVector<int, 1>::iterator W, +                      std::vector<TreeEntry> &VT) +        : ChildIteratorType::iterator_adaptor_base(W), VectorizableTree(VT) {} + +    NodeRef operator*() { return &VectorizableTree[*I]; } +  }; + +  static NodeRef getEntryNode(BoUpSLP &R) { return &R.VectorizableTree[0]; } + +  static ChildIteratorType child_begin(NodeRef N) { +    return {N->UserTreeIndices.begin(), N->Container}; +  } +  static ChildIteratorType child_end(NodeRef N) { +    return {N->UserTreeIndices.end(), N->Container}; +  } + +  /// For the node iterator we just need to turn the TreeEntry iterator into a +  /// TreeEntry* iterator so that it dereferences to NodeRef. 
+  typedef pointer_iterator<std::vector<TreeEntry>::iterator> nodes_iterator; + +  static nodes_iterator nodes_begin(BoUpSLP *R) { +    return nodes_iterator(R->VectorizableTree.begin()); +  } +  static nodes_iterator nodes_end(BoUpSLP *R) { +    return nodes_iterator(R->VectorizableTree.end()); +  } + +  static unsigned size(BoUpSLP *R) { return R->VectorizableTree.size(); } +}; + +template <> struct DOTGraphTraits<BoUpSLP *> : public DefaultDOTGraphTraits { +  typedef BoUpSLP::TreeEntry TreeEntry; + +  DOTGraphTraits(bool isSimple = false) : DefaultDOTGraphTraits(isSimple) {} + +  std::string getNodeLabel(const TreeEntry *Entry, const BoUpSLP *R) { +    std::string Str; +    raw_string_ostream OS(Str); +    if (isSplat(Entry->Scalars)) { +      OS << "<splat> " << *Entry->Scalars[0]; +      return Str; +    } +    for (auto V : Entry->Scalars) { +      OS << *V; +      if (std::any_of( +              R->ExternalUses.begin(), R->ExternalUses.end(), +              [&](const BoUpSLP::ExternalUser &EU) { return EU.Scalar == V; })) +        OS << " <extract>"; +      OS << "\n"; +    } +    return Str; +  } + +  static std::string getNodeAttributes(const TreeEntry *Entry, +                                       const BoUpSLP *) { +    if (Entry->NeedToGather) +      return "color=red"; +    return ""; +  } +};  } // end namespace llvm -} // end namespace slpvectorizer  void BoUpSLP::buildTree(ArrayRef<Value *> Roots,                          ArrayRef<Value *> UserIgnoreLst) { +  ExtraValueToDebugLocsMap ExternallyUsedValues; +  buildTree(Roots, ExternallyUsedValues, UserIgnoreLst); +} +void BoUpSLP::buildTree(ArrayRef<Value *> Roots, +                        ExtraValueToDebugLocsMap &ExternallyUsedValues, +                        ArrayRef<Value *> UserIgnoreLst) {    deleteTree();    UserIgnoreList = UserIgnoreLst;    if (!allSameType(Roots))      return; -  buildTree_rec(Roots, 0); +  buildTree_rec(Roots, 0, -1);    // Collect the values that we need to extract from the tree.    for (TreeEntry &EIdx : VectorizableTree) { @@ -940,6 +1061,14 @@ void BoUpSLP::buildTree(ArrayRef<Value *> Roots,        if (Entry->NeedToGather)          continue; +      // Check if the scalar is externally used as an extra arg. +      auto ExtI = ExternallyUsedValues.find(Scalar); +      if (ExtI != ExternallyUsedValues.end()) { +        DEBUG(dbgs() << "SLP: Need to extract: Extra arg from lane " << +              Lane << " from " << *Scalar << ".\n"); +        ExternalUses.emplace_back(Scalar, nullptr, Lane); +        continue; +      }        for (User *U : Scalar->users()) {          DEBUG(dbgs() << "SLP: Checking user:" << *U << ".\n"); @@ -976,28 +1105,28 @@ void BoUpSLP::buildTree(ArrayRef<Value *> Roots,    }  } - -void BoUpSLP::buildTree_rec(ArrayRef<Value *> VL, unsigned Depth) { +void BoUpSLP::buildTree_rec(ArrayRef<Value *> VL, unsigned Depth, +                            int UserTreeIdx) {    bool isAltShuffle = false;    assert((allConstant(VL) || allSameType(VL)) && "Invalid types!");    if (Depth == RecursionMaxDepth) {      DEBUG(dbgs() << "SLP: Gathering due to max recursion depth.\n"); -    newTreeEntry(VL, false); +    newTreeEntry(VL, false, UserTreeIdx);      return;    }    // Don't handle vectors.    
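The GraphTraits specialization added for BoUpSLP has to hand out TreeEntry pointers while the tree stores child links as integer indices into VectorizableTree, which is why ChildIteratorType carries a reference back to the container. A standalone sketch of that index-to-pointer translation using plain STL types (names hypothetical); with the real traits in place, the tree can be rendered with the new -view-slp-tree option described above.

#include <cstddef>
#include <vector>

struct NodeSketch {
  std::vector<int> ChildIndices; // children stored as indices, not pointers
};

// Iterator that walks a node's child indices but dereferences to node
// pointers, using a back-reference to the owning container to translate.
class ChildIterSketch {
  std::vector<int>::const_iterator It;
  const std::vector<NodeSketch> *Nodes;

public:
  ChildIterSketch(std::vector<int>::const_iterator I,
                  const std::vector<NodeSketch> &N)
      : It(I), Nodes(&N) {}
  const NodeSketch *operator*() const { return &(*Nodes)[*It]; }
  ChildIterSketch &operator++() { ++It; return *this; }
  bool operator==(const ChildIterSketch &O) const { return It == O.It; }
  bool operator!=(const ChildIterSketch &O) const { return It != O.It; }
};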
if (VL[0]->getType()->isVectorTy()) {      DEBUG(dbgs() << "SLP: Gathering due to vector type.\n"); -    newTreeEntry(VL, false); +    newTreeEntry(VL, false, UserTreeIdx);      return;    }    if (StoreInst *SI = dyn_cast<StoreInst>(VL[0]))      if (SI->getValueOperand()->getType()->isVectorTy()) {        DEBUG(dbgs() << "SLP: Gathering due to store vector type.\n"); -      newTreeEntry(VL, false); +      newTreeEntry(VL, false, UserTreeIdx);        return;      }    unsigned Opcode = getSameOpcode(VL); @@ -1014,7 +1143,7 @@ void BoUpSLP::buildTree_rec(ArrayRef<Value *> VL, unsigned Depth) {    // If all of the operands are identical or constant we have a simple solution.    if (allConstant(VL) || isSplat(VL) || !allSameBlock(VL) || !Opcode) {      DEBUG(dbgs() << "SLP: Gathering due to C,S,B,O. \n"); -    newTreeEntry(VL, false); +    newTreeEntry(VL, false, UserTreeIdx);      return;    } @@ -1026,7 +1155,7 @@ void BoUpSLP::buildTree_rec(ArrayRef<Value *> VL, unsigned Depth) {      if (EphValues.count(VL[i])) {        DEBUG(dbgs() << "SLP: The instruction (" << *VL[i] <<              ") is ephemeral.\n"); -      newTreeEntry(VL, false); +      newTreeEntry(VL, false, UserTreeIdx);        return;      }    } @@ -1039,10 +1168,13 @@ void BoUpSLP::buildTree_rec(ArrayRef<Value *> VL, unsigned Depth) {        DEBUG(dbgs() << "SLP: \tChecking bundle: " << *VL[i] << ".\n");        if (E->Scalars[i] != VL[i]) {          DEBUG(dbgs() << "SLP: Gathering due to partial overlap.\n"); -        newTreeEntry(VL, false); +        newTreeEntry(VL, false, UserTreeIdx);          return;        }      } +    // Record the reuse of the tree node.  FIXME, currently this is only used to +    // properly draw the graph rather than for the actual vectorization. +    E->UserTreeIndices.push_back(UserTreeIdx);      DEBUG(dbgs() << "SLP: Perfect diamond merge at " << *VL[0] << ".\n");      return;    } @@ -1052,7 +1184,7 @@ void BoUpSLP::buildTree_rec(ArrayRef<Value *> VL, unsigned Depth) {      if (ScalarToTreeEntry.count(VL[i])) {        DEBUG(dbgs() << "SLP: The instruction (" << *VL[i] <<              ") is already in tree.\n"); -      newTreeEntry(VL, false); +      newTreeEntry(VL, false, UserTreeIdx);        return;      }    } @@ -1062,7 +1194,7 @@ void BoUpSLP::buildTree_rec(ArrayRef<Value *> VL, unsigned Depth) {    for (unsigned i = 0, e = VL.size(); i != e; ++i) {      if (MustGather.count(VL[i])) {        DEBUG(dbgs() << "SLP: Gathering due to gathered scalar.\n"); -      newTreeEntry(VL, false); +      newTreeEntry(VL, false, UserTreeIdx);        return;      }    } @@ -1076,7 +1208,7 @@ void BoUpSLP::buildTree_rec(ArrayRef<Value *> VL, unsigned Depth) {      // Don't go into unreachable blocks. They may contain instructions with      // dependency cycles which confuse the final scheduling.      
DEBUG(dbgs() << "SLP: bundle in unreachable block.\n"); -    newTreeEntry(VL, false); +    newTreeEntry(VL, false, UserTreeIdx);      return;    } @@ -1085,7 +1217,7 @@ void BoUpSLP::buildTree_rec(ArrayRef<Value *> VL, unsigned Depth) {      for (unsigned j = i+1; j < e; ++j)        if (VL[i] == VL[j]) {          DEBUG(dbgs() << "SLP: Scalar used twice in bundle.\n"); -        newTreeEntry(VL, false); +        newTreeEntry(VL, false, UserTreeIdx);          return;        } @@ -1100,7 +1232,7 @@ void BoUpSLP::buildTree_rec(ArrayRef<Value *> VL, unsigned Depth) {      assert((!BS.getScheduleData(VL[0]) ||              !BS.getScheduleData(VL[0])->isPartOfBundle()) &&             "tryScheduleBundle should cancelScheduling on failure"); -    newTreeEntry(VL, false); +    newTreeEntry(VL, false, UserTreeIdx);      return;    }    DEBUG(dbgs() << "SLP: We are able to schedule this bundle.\n"); @@ -1117,12 +1249,12 @@ void BoUpSLP::buildTree_rec(ArrayRef<Value *> VL, unsigned Depth) {            if (Term) {              DEBUG(dbgs() << "SLP: Need to swizzle PHINodes (TerminatorInst use).\n");              BS.cancelScheduling(VL); -            newTreeEntry(VL, false); +            newTreeEntry(VL, false, UserTreeIdx);              return;            }          } -      newTreeEntry(VL, true); +      newTreeEntry(VL, true, UserTreeIdx);        DEBUG(dbgs() << "SLP: added a vector of PHINodes.\n");        for (unsigned i = 0, e = PH->getNumIncomingValues(); i < e; ++i) { @@ -1132,7 +1264,7 @@ void BoUpSLP::buildTree_rec(ArrayRef<Value *> VL, unsigned Depth) {            Operands.push_back(cast<PHINode>(j)->getIncomingValueForBlock(                PH->getIncomingBlock(i))); -        buildTree_rec(Operands, Depth + 1); +        buildTree_rec(Operands, Depth + 1, UserTreeIdx);        }        return;      } @@ -1144,7 +1276,7 @@ void BoUpSLP::buildTree_rec(ArrayRef<Value *> VL, unsigned Depth) {        } else {          BS.cancelScheduling(VL);        } -      newTreeEntry(VL, Reuse); +      newTreeEntry(VL, Reuse, UserTreeIdx);        return;      }      case Instruction::Load: { @@ -1160,7 +1292,7 @@ void BoUpSLP::buildTree_rec(ArrayRef<Value *> VL, unsigned Depth) {        if (DL->getTypeSizeInBits(ScalarTy) !=            DL->getTypeAllocSizeInBits(ScalarTy)) {          BS.cancelScheduling(VL); -        newTreeEntry(VL, false); +        newTreeEntry(VL, false, UserTreeIdx);          DEBUG(dbgs() << "SLP: Gathering loads of non-packed type.\n");          return;        } @@ -1171,7 +1303,7 @@ void BoUpSLP::buildTree_rec(ArrayRef<Value *> VL, unsigned Depth) {          LoadInst *L = cast<LoadInst>(VL[i]);          if (!L->isSimple()) {            BS.cancelScheduling(VL); -          newTreeEntry(VL, false); +          newTreeEntry(VL, false, UserTreeIdx);            DEBUG(dbgs() << "SLP: Gathering non-simple loads.\n");            return;          } @@ -1193,7 +1325,7 @@ void BoUpSLP::buildTree_rec(ArrayRef<Value *> VL, unsigned Depth) {        if (Consecutive) {          ++NumLoadsWantToKeepOrder; -        newTreeEntry(VL, true); +        newTreeEntry(VL, true, UserTreeIdx);          DEBUG(dbgs() << "SLP: added a vector of loads.\n");          return;        } @@ -1208,7 +1340,7 @@ void BoUpSLP::buildTree_rec(ArrayRef<Value *> VL, unsigned Depth) {            }        BS.cancelScheduling(VL); -      newTreeEntry(VL, false); +      newTreeEntry(VL, false, UserTreeIdx);        if (ReverseConsecutive) {          ++NumLoadsWantToChangeOrder; @@ -1235,12 +1367,12 @@ void 
BoUpSLP::buildTree_rec(ArrayRef<Value *> VL, unsigned Depth) {          Type *Ty = cast<Instruction>(VL[i])->getOperand(0)->getType();          if (Ty != SrcTy || !isValidElementType(Ty)) {            BS.cancelScheduling(VL); -          newTreeEntry(VL, false); +          newTreeEntry(VL, false, UserTreeIdx);            DEBUG(dbgs() << "SLP: Gathering casts with different src types.\n");            return;          }        } -      newTreeEntry(VL, true); +      newTreeEntry(VL, true, UserTreeIdx);        DEBUG(dbgs() << "SLP: added a vector of casts.\n");        for (unsigned i = 0, e = VL0->getNumOperands(); i < e; ++i) { @@ -1249,7 +1381,7 @@ void BoUpSLP::buildTree_rec(ArrayRef<Value *> VL, unsigned Depth) {          for (Value *j : VL)            Operands.push_back(cast<Instruction>(j)->getOperand(i)); -        buildTree_rec(Operands, Depth+1); +        buildTree_rec(Operands, Depth + 1, UserTreeIdx);        }        return;      } @@ -1263,13 +1395,13 @@ void BoUpSLP::buildTree_rec(ArrayRef<Value *> VL, unsigned Depth) {          if (Cmp->getPredicate() != P0 ||              Cmp->getOperand(0)->getType() != ComparedTy) {            BS.cancelScheduling(VL); -          newTreeEntry(VL, false); +          newTreeEntry(VL, false, UserTreeIdx);            DEBUG(dbgs() << "SLP: Gathering cmp with different predicate.\n");            return;          }        } -      newTreeEntry(VL, true); +      newTreeEntry(VL, true, UserTreeIdx);        DEBUG(dbgs() << "SLP: added a vector of compares.\n");        for (unsigned i = 0, e = VL0->getNumOperands(); i < e; ++i) { @@ -1278,7 +1410,7 @@ void BoUpSLP::buildTree_rec(ArrayRef<Value *> VL, unsigned Depth) {          for (Value *j : VL)            Operands.push_back(cast<Instruction>(j)->getOperand(i)); -        buildTree_rec(Operands, Depth+1); +        buildTree_rec(Operands, Depth + 1, UserTreeIdx);        }        return;      } @@ -1301,7 +1433,7 @@ void BoUpSLP::buildTree_rec(ArrayRef<Value *> VL, unsigned Depth) {      case Instruction::And:      case Instruction::Or:      case Instruction::Xor: { -      newTreeEntry(VL, true); +      newTreeEntry(VL, true, UserTreeIdx);        DEBUG(dbgs() << "SLP: added a vector of bin op.\n");        // Sort operands of the instructions so that each side is more likely to @@ -1309,8 +1441,8 @@ void BoUpSLP::buildTree_rec(ArrayRef<Value *> VL, unsigned Depth) {        if (isa<BinaryOperator>(VL0) && VL0->isCommutative()) {          ValueList Left, Right;          reorderInputsAccordingToOpcode(VL, Left, Right); -        buildTree_rec(Left, Depth + 1); -        buildTree_rec(Right, Depth + 1); +        buildTree_rec(Left, Depth + 1, UserTreeIdx); +        buildTree_rec(Right, Depth + 1, UserTreeIdx);          return;        } @@ -1320,7 +1452,7 @@ void BoUpSLP::buildTree_rec(ArrayRef<Value *> VL, unsigned Depth) {          for (Value *j : VL)            Operands.push_back(cast<Instruction>(j)->getOperand(i)); -        buildTree_rec(Operands, Depth+1); +        buildTree_rec(Operands, Depth + 1, UserTreeIdx);        }        return;      } @@ -1330,7 +1462,7 @@ void BoUpSLP::buildTree_rec(ArrayRef<Value *> VL, unsigned Depth) {          if (cast<Instruction>(VL[j])->getNumOperands() != 2) {            DEBUG(dbgs() << "SLP: not-vectorizable GEP (nested indexes).\n");            BS.cancelScheduling(VL); -          newTreeEntry(VL, false); +          newTreeEntry(VL, false, UserTreeIdx);            return;          }        } @@ -1343,7 +1475,7 @@ void BoUpSLP::buildTree_rec(ArrayRef<Value *> VL, unsigned 
Depth) {
         if (Ty0 != CurTy) {
           DEBUG(dbgs() << "SLP: not-vectorizable GEP (different types).\n");
           BS.cancelScheduling(VL);
-          newTreeEntry(VL, false);
+          newTreeEntry(VL, false, UserTreeIdx);
           return;
         }
       }
@@ -1355,12 +1487,12 @@ void BoUpSLP::buildTree_rec(ArrayRef<Value *> VL, unsigned Depth) {
           DEBUG(
               dbgs() << "SLP: not-vectorizable GEP (non-constant indexes).\n");
           BS.cancelScheduling(VL);
-          newTreeEntry(VL, false);
+          newTreeEntry(VL, false, UserTreeIdx);
           return;
         }
       }
 
-      newTreeEntry(VL, true);
+      newTreeEntry(VL, true, UserTreeIdx);
       DEBUG(dbgs() << "SLP: added a vector of GEPs.\n");
       for (unsigned i = 0, e = 2; i < e; ++i) {
         ValueList Operands;
@@ -1368,7 +1500,7 @@ void BoUpSLP::buildTree_rec(ArrayRef<Value *> VL, unsigned Depth) {
         for (Value *j : VL)
           Operands.push_back(cast<Instruction>(j)->getOperand(i));
 
-        buildTree_rec(Operands, Depth + 1);
+        buildTree_rec(Operands, Depth + 1, UserTreeIdx);
       }
       return;
     }
@@ -1377,19 +1509,19 @@ void BoUpSLP::buildTree_rec(ArrayRef<Value *> VL, unsigned Depth) {
       for (unsigned i = 0, e = VL.size() - 1; i < e; ++i)
         if (!isConsecutiveAccess(VL[i], VL[i + 1], *DL, *SE)) {
           BS.cancelScheduling(VL);
-          newTreeEntry(VL, false);
+          newTreeEntry(VL, false, UserTreeIdx);
           DEBUG(dbgs() << "SLP: Non-consecutive store.\n");
           return;
         }
 
-      newTreeEntry(VL, true);
+      newTreeEntry(VL, true, UserTreeIdx);
       DEBUG(dbgs() << "SLP: added a vector of stores.\n");
 
       ValueList Operands;
       for (Value *j : VL)
         Operands.push_back(cast<Instruction>(j)->getOperand(0));
 
-      buildTree_rec(Operands, Depth + 1);
+      buildTree_rec(Operands, Depth + 1, UserTreeIdx);
       return;
     }
     case Instruction::Call: {
@@ -1400,7 +1532,7 @@ void BoUpSLP::buildTree_rec(ArrayRef<Value *> VL, unsigned Depth) {
       Intrinsic::ID ID = getVectorIntrinsicIDForCall(CI, TLI);
       if (!isTriviallyVectorizable(ID)) {
         BS.cancelScheduling(VL);
-        newTreeEntry(VL, false);
+        newTreeEntry(VL, false, UserTreeIdx);
         DEBUG(dbgs() << "SLP: Non-vectorizable call.\n");
         return;
       }
@@ -1414,7 +1546,7 @@ void BoUpSLP::buildTree_rec(ArrayRef<Value *> VL, unsigned Depth) {
             getVectorIntrinsicIDForCall(CI2, TLI) != ID ||
             !CI->hasIdenticalOperandBundleSchema(*CI2)) {
           BS.cancelScheduling(VL);
-          newTreeEntry(VL, false);
+          newTreeEntry(VL, false, UserTreeIdx);
           DEBUG(dbgs() << "SLP: mismatched calls:" << *CI << "!=" << *VL[i]
                        << "\n");
           return;
@@ -1425,7 +1557,7 @@ void BoUpSLP::buildTree_rec(ArrayRef<Value *> VL, unsigned Depth) {
           Value *A1J = CI2->getArgOperand(1);
           if (A1I != A1J) {
             BS.cancelScheduling(VL);
-            newTreeEntry(VL, false);
+            newTreeEntry(VL, false, UserTreeIdx);
             DEBUG(dbgs() << "SLP: mismatched arguments in call:" << *CI
                          << " argument "<< A1I<<"!=" << A1J
                          << "\n");
@@ -1438,14 +1570,14 @@ void BoUpSLP::buildTree_rec(ArrayRef<Value *> VL, unsigned Depth) {
                         CI->op_begin() + CI->getBundleOperandsEndIndex(),
                         CI2->op_begin() + CI2->getBundleOperandsStartIndex())) {
           BS.cancelScheduling(VL);
-          newTreeEntry(VL, false);
+          newTreeEntry(VL, false, UserTreeIdx);
           DEBUG(dbgs() << "SLP: mismatched bundle operands in calls:" << *CI << "!="
                        << *VL[i] << '\n');
           return;
         }
       }
 
-      newTreeEntry(VL, true);
+      newTreeEntry(VL, true, UserTreeIdx);
       for (unsigned i = 0, e = CI->getNumArgOperands(); i != e; ++i) {
         ValueList Operands;
         // Prepare the operand vector.
@@ -1453,7 +1585,7 @@ void BoUpSLP::buildTree_rec(ArrayRef<Value *> VL, unsigned Depth) {
           CallInst *CI2 = dyn_cast<CallInst>(j);
           Operands.push_back(CI2->getArgOperand(i));
         }
-        buildTree_rec(Operands, Depth + 1);
+        buildTree_rec(Operands, Depth + 1, UserTreeIdx);
       }
       return;
     }
@@ -1462,19 +1594,19 @@ void BoUpSLP::buildTree_rec(ArrayRef<Value *> VL, unsigned Depth) {
       // then do not vectorize this instruction.
       if (!isAltShuffle) {
         BS.cancelScheduling(VL);
-        newTreeEntry(VL, false);
+        newTreeEntry(VL, false, UserTreeIdx);
         DEBUG(dbgs() << "SLP: ShuffleVector are not vectorized.\n");
         return;
       }
-      newTreeEntry(VL, true);
+      newTreeEntry(VL, true, UserTreeIdx);
       DEBUG(dbgs() << "SLP: added a ShuffleVector op.\n");
 
       // Reorder operands if reordering would enable vectorization.
       if (isa<BinaryOperator>(VL0)) {
         ValueList Left, Right;
         reorderAltShuffleOperands(VL, Left, Right);
-        buildTree_rec(Left, Depth + 1);
-        buildTree_rec(Right, Depth + 1);
+        buildTree_rec(Left, Depth + 1, UserTreeIdx);
+        buildTree_rec(Right, Depth + 1, UserTreeIdx);
         return;
       }
@@ -1484,13 +1616,13 @@ void BoUpSLP::buildTree_rec(ArrayRef<Value *> VL, unsigned Depth) {
         for (Value *j : VL)
           Operands.push_back(cast<Instruction>(j)->getOperand(i));
 
-        buildTree_rec(Operands, Depth + 1);
+        buildTree_rec(Operands, Depth + 1, UserTreeIdx);
       }
       return;
     }
     default:
       BS.cancelScheduling(VL);
-      newTreeEntry(VL, false);
+      newTreeEntry(VL, false, UserTreeIdx);
       DEBUG(dbgs() << "SLP: Gathering unknown instruction.\n");
       return;
   }
@@ -1570,6 +1702,8 @@ int BoUpSLP::getEntryCost(TreeEntry *E) {
   Type *ScalarTy = VL[0]->getType();
   if (StoreInst *SI = dyn_cast<StoreInst>(VL[0]))
     ScalarTy = SI->getValueOperand()->getType();
+  else if (CmpInst *CI = dyn_cast<CmpInst>(VL[0]))
+    ScalarTy = CI->getOperand(0)->getType();
   VectorType *VecTy = VectorType::get(ScalarTy, VL.size());
 
   // If we have computed a smaller type for the expression, update VecTy so
@@ -1599,7 +1733,13 @@ int BoUpSLP::getEntryCost(TreeEntry *E) {
         int DeadCost = 0;
         for (unsigned i = 0, e = VL.size(); i < e; ++i) {
           Instruction *E = cast<Instruction>(VL[i]);
-          if (E->hasOneUse())
+          // If all users are going to be vectorized, the instruction can be
+          // considered as dead.
+          // Likewise, if it has only one user, it will be vectorized for sure.
+          if (E->hasOneUse() ||
+              std::all_of(E->user_begin(), E->user_end(), [this](User *U) {
+                return ScalarToTreeEntry.count(U) > 0;
+              }))
             // Take credit for instruction that will become dead.
             DeadCost +=
                 TTI->getVectorInstrCost(Instruction::ExtractElement, VecTy, i);
@@ -1624,10 +1764,10 @@ int BoUpSLP::getEntryCost(TreeEntry *E) {
       // Calculate the cost of this instruction.
       int ScalarCost = VL.size() * TTI->getCastInstrCost(VL0->getOpcode(),
-                                                         VL0->getType(), SrcTy);
+                                                         VL0->getType(), SrcTy, VL0);
 
       VectorType *SrcVecTy = VectorType::get(SrcTy, VL.size());
-      int VecCost = TTI->getCastInstrCost(VL0->getOpcode(), VecTy, SrcVecTy);
+      int VecCost = TTI->getCastInstrCost(VL0->getOpcode(), VecTy, SrcVecTy, VL0);
       return VecCost - ScalarCost;
     }
     case Instruction::FCmp:
@@ -1636,8 +1776,8 @@ int BoUpSLP::getEntryCost(TreeEntry *E) {
       // Calculate the cost of this instruction.
       VectorType *MaskTy = VectorType::get(Builder.getInt1Ty(), VL.size());
       int ScalarCost = VecTy->getNumElements() *
-          TTI->getCmpSelInstrCost(Opcode, ScalarTy, Builder.getInt1Ty());
-      int VecCost = TTI->getCmpSelInstrCost(Opcode, VecTy, MaskTy);
+          TTI->getCmpSelInstrCost(Opcode, ScalarTy, Builder.getInt1Ty(), VL0);
+      int VecCost = TTI->getCmpSelInstrCost(Opcode, VecTy, MaskTy, VL0);
       return VecCost - ScalarCost;
     }
     case Instruction::Add:
@@ -1720,18 +1860,18 @@ int BoUpSLP::getEntryCost(TreeEntry *E) {
       // Cost of wide load - cost of scalar loads.
       unsigned alignment = dyn_cast<LoadInst>(VL0)->getAlignment();
       int ScalarLdCost = VecTy->getNumElements() *
-            TTI->getMemoryOpCost(Instruction::Load, ScalarTy, alignment, 0);
+          TTI->getMemoryOpCost(Instruction::Load, ScalarTy, alignment, 0, VL0);
       int VecLdCost = TTI->getMemoryOpCost(Instruction::Load,
-                                           VecTy, alignment, 0);
+                                           VecTy, alignment, 0, VL0);
       return VecLdCost - ScalarLdCost;
     }
     case Instruction::Store: {
       // We know that we can merge the stores. Calculate the cost.
       unsigned alignment = dyn_cast<StoreInst>(VL0)->getAlignment();
       int ScalarStCost = VecTy->getNumElements() *
-            TTI->getMemoryOpCost(Instruction::Store, ScalarTy, alignment, 0);
+          TTI->getMemoryOpCost(Instruction::Store, ScalarTy, alignment, 0, VL0);
       int VecStCost = TTI->getMemoryOpCost(Instruction::Store,
-                                           VecTy, alignment, 0);
+                                           VecTy, alignment, 0, VL0);
       return VecStCost - ScalarStCost;
     }
     case Instruction::Call: {
@@ -1739,12 +1879,9 @@ int BoUpSLP::getEntryCost(TreeEntry *E) {
       CallInst *CI = cast<CallInst>(VL0);
       Intrinsic::ID ID = getVectorIntrinsicIDForCall(CI, TLI);
 
       // Calculate the cost of the scalar and vector calls.
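Every case in getEntryCost above follows the same convention: it returns the vector cost minus the summed scalar costs, so a negative result means vectorization is expected to pay off. A stand-alone sketch of that arithmetic with invented numbers (the real values come from TargetTransformInfo, which is why the hunks above now pass VL0 through to the cost hooks):

    #include <cstdio>

    // Toy illustration of the VecCost - ScalarCost convention used by
    // getEntryCost; the costs here are made up, not real TTI numbers.
    static int entryCostDelta(int PerScalarCost, unsigned VF, int VectorCost) {
      int ScalarCost = PerScalarCost * static_cast<int>(VF);
      return VectorCost - ScalarCost; // negative => the vector form is cheaper
    }

    int main() {
      // E.g. four scalar loads at cost 1 each vs. one wide load at cost 1.
      std::printf("%d\n", entryCostDelta(/*PerScalarCost=*/1, /*VF=*/4,
                                         /*VectorCost=*/1)); // prints -3
      return 0;
    }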
-      SmallVector<Type*, 4> ScalarTys, VecTys;
-      for (unsigned op = 0, opc = CI->getNumArgOperands(); op!= opc; ++op) {
+      SmallVector<Type*, 4> ScalarTys;
+      for (unsigned op = 0, opc = CI->getNumArgOperands(); op!= opc; ++op)
         ScalarTys.push_back(CI->getArgOperand(op)->getType());
-        VecTys.push_back(VectorType::get(CI->getArgOperand(op)->getType(),
-                                         VecTy->getNumElements()));
-      }
 
       FastMathFlags FMF;
       if (auto *FPMO = dyn_cast<FPMathOperator>(CI))
@@ -1753,7 +1890,9 @@ int BoUpSLP::getEntryCost(TreeEntry *E) {
       int ScalarCallCost = VecTy->getNumElements() *
           TTI->getIntrinsicInstrCost(ID, ScalarTy, ScalarTys, FMF);
 
-      int VecCallCost = TTI->getIntrinsicInstrCost(ID, VecTy, VecTys, FMF);
+      SmallVector<Value *, 4> Args(CI->arg_operands());
+      int VecCallCost = TTI->getIntrinsicInstrCost(ID, CI->getType(), Args, FMF,
+                                                   VecTy->getNumElements());
 
       DEBUG(dbgs() << "SLP: Call cost "<< VecCallCost - ScalarCallCost
             << " (" << VecCallCost  << "-" <<  ScalarCallCost << ")"
@@ -1947,9 +2086,18 @@ int BoUpSLP::getTreeCost() {
   int SpillCost = getSpillCost();
   Cost += SpillCost + ExtractCost;
 
-  DEBUG(dbgs() << "SLP: Spill Cost = " << SpillCost << ".\n"
-               << "SLP: Extract Cost = " << ExtractCost << ".\n"
-               << "SLP: Total Cost = " << Cost << ".\n");
+  std::string Str;
+  {
+    raw_string_ostream OS(Str);
+    OS << "SLP: Spill Cost = " << SpillCost << ".\n"
+       << "SLP: Extract Cost = " << ExtractCost << ".\n"
+       << "SLP: Total Cost = " << Cost << ".\n";
+  }
+  DEBUG(dbgs() << Str);
+
+  if (ViewSLPTree)
+    ViewGraph(this, "SLP" + F->getName(), false, Str);
+
   return Cost;
 }
 
@@ -2702,6 +2850,12 @@ Value *BoUpSLP::vectorizeTree(TreeEntry *E) {
 }
 
 Value *BoUpSLP::vectorizeTree() {
+  ExtraValueToDebugLocsMap ExternallyUsedValues;
+  return vectorizeTree(ExternallyUsedValues);
+}
+
+Value *
+BoUpSLP::vectorizeTree(ExtraValueToDebugLocsMap &ExternallyUsedValues) {
 
   // All blocks must be scheduled before any instructions are inserted.
   for (auto &BSIter : BlocksSchedules) {
@@ -2744,7 +2898,7 @@ Value *BoUpSLP::vectorizeTree() {
 
     // Skip users that we already RAUW. This happens when one instruction
     // has multiple uses of the same value.
-    if (!is_contained(Scalar->users(), User))
+    if (User && !is_contained(Scalar->users(), User))
       continue;
     assert(ScalarToTreeEntry.count(Scalar) && "Invalid scalar");
 
@@ -2756,6 +2910,28 @@ Value *BoUpSLP::vectorizeTree() {
     assert(Vec && "Can't find vectorizable value");
 
     Value *Lane = Builder.getInt32(ExternalUse.Lane);
+    // If User == nullptr, the Scalar is used as extra arg. Generate
+    // ExtractElement instruction and update the record for this scalar in
+    // ExternallyUsedValues.
+    if (!User) {
+      assert(ExternallyUsedValues.count(Scalar) &&
+             "Scalar with nullptr as an external user must be registered in "
+             "ExternallyUsedValues map");
+      if (auto *VecI = dyn_cast<Instruction>(Vec)) {
+        Builder.SetInsertPoint(VecI->getParent(),
+                               std::next(VecI->getIterator()));
+      } else {
+        Builder.SetInsertPoint(&F->getEntryBlock().front());
+      }
+      Value *Ex = Builder.CreateExtractElement(Vec, Lane);
+      Ex = extend(ScalarRoot, Ex, Scalar->getType());
+      CSEBlocks.insert(cast<Instruction>(Scalar)->getParent());
+      auto &Locs = ExternallyUsedValues[Scalar];
+      ExternallyUsedValues.insert({Ex, Locs});
+      ExternallyUsedValues.erase(Scalar);
+      continue;
+    }
+
     // Generate extracts for out-of-tree users.
     // Find the insertion point for the extractelement lane.
     if (auto *VecI = dyn_cast<Instruction>(Vec)) {
@@ -3264,7 +3440,7 @@ void BoUpSLP::scheduleBlock(BlockScheduling *BS) {
   // sorted by the original instruction location. This lets the final schedule
   // be as  close as possible to the original instruction order.
   struct ScheduleDataCompare {
-    bool operator()(ScheduleData *SD1, ScheduleData *SD2) {
+    bool operator()(ScheduleData *SD1, ScheduleData *SD2) const {
      return SD2->SchedulingPriority < SD1->SchedulingPriority;
     }
   };
@@ -3645,9 +3821,9 @@ PreservedAnalyses SLPVectorizerPass::run(Function &F, FunctionAnalysisManager &A
   bool Changed = runImpl(F, SE, TTI, TLI, AA, LI, DT, AC, DB);
   if (!Changed)
     return PreservedAnalyses::all();
+
   PreservedAnalyses PA;
-  PA.preserve<LoopAnalysis>();
-  PA.preserve<DominatorTreeAnalysis>();
+  PA.preserveSet<CFGAnalyses>();
   PA.preserve<AAManager>();
   PA.preserve<GlobalsAA>();
   return PA;
@@ -4026,36 +4202,40 @@ bool SLPVectorizerPass::tryToVectorize(BinaryOperator *V, BoUpSLP &R) {
   if (!V)
     return false;
 
+  Value *P = V->getParent();
+
+  // Vectorize in current basic block only.
+  auto *Op0 = dyn_cast<Instruction>(V->getOperand(0));
+  auto *Op1 = dyn_cast<Instruction>(V->getOperand(1));
+  if (!Op0 || !Op1 || Op0->getParent() != P || Op1->getParent() != P)
+    return false;
+
   // Try to vectorize V.
-  if (tryToVectorizePair(V->getOperand(0), V->getOperand(1), R))
     return true;
 
-  BinaryOperator *A = dyn_cast<BinaryOperator>(V->getOperand(0));
-  BinaryOperator *B = dyn_cast<BinaryOperator>(V->getOperand(1));
+  if (tryToVectorizePair(Op0, Op1, R))
+  auto *A = dyn_cast<BinaryOperator>(Op0);
+  auto *B = dyn_cast<BinaryOperator>(Op1);
 
   // Try to skip B.
   if (B && B->hasOneUse()) {
-    BinaryOperator *B0 = dyn_cast<BinaryOperator>(B->getOperand(0));
-    BinaryOperator *B1 = dyn_cast<BinaryOperator>(B->getOperand(1));
-    if (tryToVectorizePair(A, B0, R)) {
+    auto *B0 = dyn_cast<BinaryOperator>(B->getOperand(0));
+    auto *B1 = dyn_cast<BinaryOperator>(B->getOperand(1));
+    if (B0 && B0->getParent() == P && tryToVectorizePair(A, B0, R))
       return true;
-    }
-    if (tryToVectorizePair(A, B1, R)) {
+    if (B1 && B1->getParent() == P && tryToVectorizePair(A, B1, R))
      return true;
-    }
   }
 
   // Try to skip A.
   if (A && A->hasOneUse()) {
-    BinaryOperator *A0 = dyn_cast<BinaryOperator>(A->getOperand(0));
-    BinaryOperator *A1 = dyn_cast<BinaryOperator>(A->getOperand(1));
-    if (tryToVectorizePair(A0, B, R)) {
+    auto *A0 = dyn_cast<BinaryOperator>(A->getOperand(0));
+    auto *A1 = dyn_cast<BinaryOperator>(A->getOperand(1));
+    if (A0 && A0->getParent() == P && tryToVectorizePair(A0, B, R))
       return true;
-    }
-    if (tryToVectorizePair(A1, B, R)) {
+    if (A1 && A1->getParent() == P && tryToVectorizePair(A1, B, R))
      return true;
-    }
   }
-  return 0;
+  return false;
 }
 
 /// \brief Generate a shuffle mask to be used in a reduction tree.
@@ -4119,37 +4299,41 @@ namespace {
 class HorizontalReduction {
   SmallVector<Value *, 16> ReductionOps;
   SmallVector<Value *, 32> ReducedVals;
+  // Use map vector to make stable output.
+  MapVector<Instruction *, Value *> ExtraArgs;
 
-  BinaryOperator *ReductionRoot;
-  // After successfull horizontal reduction vectorization attempt for PHI node
-  // vectorizer tries to update root binary op by combining vectorized tree and
-  // the ReductionPHI node. But during vectorization this ReductionPHI can be
-  // vectorized itself and replaced by the undef value, while the instruction
-  // itself is marked for deletion. This 'marked for deletion' PHI node then can
-  // be used in new binary operation, causing "Use still stuck around after Def
-  // is destroyed" crash upon PHI node deletion.
-  WeakVH ReductionPHI;
+  BinaryOperator *ReductionRoot = nullptr;
 
   /// The opcode of the reduction.
-  unsigned ReductionOpcode;
+  Instruction::BinaryOps ReductionOpcode = Instruction::BinaryOpsEnd;
   /// The opcode of the values we perform a reduction on.
-  unsigned ReducedValueOpcode;
+  unsigned ReducedValueOpcode = 0;
   /// Should we model this reduction as a pairwise reduction tree or a tree that
   /// splits the vector in halves and adds those halves.
-  bool IsPairwiseReduction;
+  bool IsPairwiseReduction = false;
+
+  /// Checks if the ParentStackElem.first should be marked as a reduction
+  /// operation with an extra argument or as extra argument itself.
+  void markExtraArg(std::pair<Instruction *, unsigned> &ParentStackElem,
+                    Value *ExtraArg) {
+    if (ExtraArgs.count(ParentStackElem.first)) {
+      ExtraArgs[ParentStackElem.first] = nullptr;
+      // We ran into something like:
+      // ParentStackElem.first = ExtraArgs[ParentStackElem.first] + ExtraArg.
+      // The whole ParentStackElem.first should be considered as an extra value
+      // in this case.
+      // Do not perform analysis of remaining operands of ParentStackElem.first
+      // instruction, this whole instruction is an extra argument.
+      ParentStackElem.second = ParentStackElem.first->getNumOperands();
+    } else {
+      // We ran into something like:
+      // ParentStackElem.first += ... + ExtraArg + ...
+      ExtraArgs[ParentStackElem.first] = ExtraArg;
+    }
+  }
 
 public:
-  /// The width of one full horizontal reduction operation.
-  unsigned ReduxWidth;
-
-  /// Minimal width of available vector registers. It's used to determine
-  /// ReduxWidth.
-  unsigned MinVecRegSize;
-
-  HorizontalReduction(unsigned MinVecRegSize)
-      : ReductionRoot(nullptr), ReductionOpcode(0), ReducedValueOpcode(0),
-        IsPairwiseReduction(false), ReduxWidth(0),
-        MinVecRegSize(MinVecRegSize) {}
+  HorizontalReduction() = default;
 
   /// \brief Try to find a reduction tree.
   bool matchAssociativeReduction(PHINode *Phi, BinaryOperator *B) {
@@ -4176,21 +4360,14 @@ public:
     if (!isValidElementType(Ty))
       return false;
 
-    const DataLayout &DL = B->getModule()->getDataLayout();
     ReductionOpcode = B->getOpcode();
     ReducedValueOpcode = 0;
-    // FIXME: Register size should be a parameter to this function, so we can
-    // try different vectorization factors.
-    ReduxWidth = MinVecRegSize / DL.getTypeSizeInBits(Ty);
     ReductionRoot = B;
-    ReductionPHI = Phi;
-
-    if (ReduxWidth < 4)
-      return false;
 
     // We currently only support adds.
-    if (ReductionOpcode != Instruction::Add &&
-        ReductionOpcode != Instruction::FAdd)
+    if ((ReductionOpcode != Instruction::Add &&
+         ReductionOpcode != Instruction::FAdd) ||
+        !B->isAssociative())
       return false;
 
     // Post order traverse the reduction tree starting at B. We only handle true
@@ -4202,30 +4379,26 @@ public:
       unsigned EdgeToVist = Stack.back().second++;
       bool IsReducedValue = TreeN->getOpcode() != ReductionOpcode;
 
-      // Only handle trees in the current basic block.
-      if (TreeN->getParent() != B->getParent())
-        return false;
-
-      // Each tree node needs to have one user except for the ultimate
-      // reduction.
-      if (!TreeN->hasOneUse() && TreeN != B)
-        return false;
-
       // Postorder vist.
       if (EdgeToVist == 2 || IsReducedValue) {
-        if (IsReducedValue) {
-          // Make sure that the opcodes of the operations that we are going to
-          // reduce match.
-          if (!ReducedValueOpcode)
-            ReducedValueOpcode = TreeN->getOpcode();
-          else if (ReducedValueOpcode != TreeN->getOpcode())
-            return false;
+        if (IsReducedValue)
           ReducedVals.push_back(TreeN);
-        } else {
-          // We need to be able to reassociate the adds.
-          if (!TreeN->isAssociative())
-            return false;
-          ReductionOps.push_back(TreeN);
+        else {
+          auto I = ExtraArgs.find(TreeN);
+          if (I != ExtraArgs.end() && !I->second) {
+            // Check if TreeN is an extra argument of its parent operation.
+            if (Stack.size() <= 1) {
+              // TreeN can't be an extra argument as it is a root reduction
+              // operation.
+              return false;
+            }
+            // Yes, TreeN is an extra argument, do not add it to a list of
+            // reduction operations.
+            // Stack[Stack.size() - 2] always points to the parent operation.
+            markExtraArg(Stack[Stack.size() - 2], TreeN);
+            ExtraArgs.erase(TreeN);
+          } else
+            ReductionOps.push_back(TreeN);
         }
         // Retract.
         Stack.pop_back();
@@ -4242,13 +4415,44 @@ public:
         // reduced value class.
         if (I && (!ReducedValueOpcode || I->getOpcode() == ReducedValueOpcode ||
                   I->getOpcode() == ReductionOpcode)) {
-          if (!ReducedValueOpcode && I->getOpcode() != ReductionOpcode)
+          // Only handle trees in the current basic block.
+          if (I->getParent() != B->getParent()) {
+            // I is an extra argument for TreeN (its parent operation).
+            markExtraArg(Stack.back(), I);
+            continue;
+          }
+
+          // Each tree node needs to have one user except for the ultimate
+          // reduction.
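Where the old matcher simply returned false on a value from another basic block or with more than one use, the rewritten walk above records that value through markExtraArg and keeps going. A minimal stand-alone sketch of the bookkeeping rule, using a plain std::map and stand-in types instead of LLVM's MapVector and Instruction:

    #include <cassert>
    #include <map>
    #include <utility>

    struct Value {};                              // stand-in for llvm::Value
    struct Inst : Value { unsigned NumOps = 2; }; // stand-in for llvm::Instruction

    // (node, index of the next operand to visit), as kept on the traversal stack.
    using StackElem = std::pair<Inst *, unsigned>;

    static std::map<Inst *, Value *> ExtraArgs;

    static void markExtraArg(StackElem &Parent, Value *ExtraArg) {
      if (ExtraArgs.count(Parent.first)) {
        // Second non-reduction operand: the parent itself becomes an extra
        // argument (nullptr marker) and its remaining operands are skipped.
        ExtraArgs[Parent.first] = nullptr;
        Parent.second = Parent.first->NumOps;
      } else {
        // First non-reduction operand: just remember it for this parent.
        ExtraArgs[Parent.first] = ExtraArg;
      }
    }

    int main() {
      Inst Add;   // pretend this is an add inside the reduction tree
      Value X, Y; // operands that do not belong to the reduction
      StackElem Parent{&Add, 1};
      markExtraArg(Parent, &X); // X is recorded as the extra argument of Add
      assert(ExtraArgs[&Add] == &X);
      markExtraArg(Parent, &Y); // a second extra operand: Add itself becomes one
      assert(ExtraArgs[&Add] == nullptr && Parent.second == Add.NumOps);
      return 0;
    }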
+          if (!I->hasOneUse() && I != B) {
+            // I is an extra argument for TreeN (its parent operation).
+            markExtraArg(Stack.back(), I);
+            continue;
+          }
+
+          if (I->getOpcode() == ReductionOpcode) {
+            // We need to be able to reassociate the reduction operations.
+            if (!I->isAssociative()) {
+              // I is an extra argument for TreeN (its parent operation).
+              markExtraArg(Stack.back(), I);
+              continue;
+            }
+          } else if (ReducedValueOpcode &&
+                     ReducedValueOpcode != I->getOpcode()) {
+            // Make sure that the opcodes of the operations that we are going to
+            // reduce match.
+            // I is an extra argument for TreeN (its parent operation).
+            markExtraArg(Stack.back(), I);
+            continue;
+          } else if (!ReducedValueOpcode)
             ReducedValueOpcode = I->getOpcode();
+
           Stack.push_back(std::make_pair(I, 0));
           continue;
         }
-        return false;
       }
+      // NextV is an extra argument for TreeN (its parent operation).
+      markExtraArg(Stack.back(), NextV);
     }
     return true;
   }
@@ -4259,10 +4463,15 @@ public:
     if (ReducedVals.empty())
       return false;
 
+    // If there is a sufficient number of reduction values, reduce
+    // to a nearby power-of-2. Can safely generate oversized
+    // vectors and rely on the backend to split them to legal sizes.
     unsigned NumReducedVals = ReducedVals.size();
-    if (NumReducedVals < ReduxWidth)
      return false;
 
+    if (NumReducedVals < 4)
+    unsigned ReduxWidth = PowerOf2Floor(NumReducedVals);
+
     Value *VectorizedTree = nullptr;
     IRBuilder<> Builder(ReductionRoot);
     FastMathFlags Unsafe;
@@ -4270,20 +4479,26 @@ public:
     Builder.setFastMathFlags(Unsafe);
     unsigned i = 0;
 
-    for (; i < NumReducedVals - ReduxWidth + 1; i += ReduxWidth) {
+    BoUpSLP::ExtraValueToDebugLocsMap ExternallyUsedValues;
+    // The same extra argument may be used several times, so log each attempt
+    // to use it.
+    for (auto &Pair : ExtraArgs)
+      ExternallyUsedValues[Pair.second].push_back(Pair.first);
+    while (i < NumReducedVals - ReduxWidth + 1 && ReduxWidth > 2) {
       auto VL = makeArrayRef(&ReducedVals[i], ReduxWidth);
-      V.buildTree(VL, ReductionOps);
+      V.buildTree(VL, ExternallyUsedValues, ReductionOps);
       if (V.shouldReorder()) {
         SmallVector<Value *, 8> Reversed(VL.rbegin(), VL.rend());
-        V.buildTree(Reversed, ReductionOps);
+        V.buildTree(Reversed, ExternallyUsedValues, ReductionOps);
       }
       if (V.isTreeTinyAndNotFullyVectorizable())
-        continue;
+        break;
 
       V.computeMinimumValueSizes();
 
       // Estimate cost.
-      int Cost = V.getTreeCost() + getReductionCost(TTI, ReducedVals[i]);
+      int Cost =
+          V.getTreeCost() + getReductionCost(TTI, ReducedVals[i], ReduxWidth);
       if (Cost >= -SLPCostThreshold)
         break;
 
@@ -4292,33 +4507,44 @@ public:
       // Vectorize a tree.
       DebugLoc Loc = cast<Instruction>(ReducedVals[i])->getDebugLoc();
-      Value *VectorizedRoot = V.vectorizeTree();
+      Value *VectorizedRoot = V.vectorizeTree(ExternallyUsedValues);
 
       // Emit a reduction.
-      Value *ReducedSubTree = emitReduction(VectorizedRoot, Builder);
+      Value *ReducedSubTree =
+          emitReduction(VectorizedRoot, Builder, ReduxWidth, ReductionOps);
       if (VectorizedTree) {
         Builder.SetCurrentDebugLocation(Loc);
-        VectorizedTree = createBinOp(Builder, ReductionOpcode, VectorizedTree,
-                                     ReducedSubTree, "bin.rdx");
+        VectorizedTree = Builder.CreateBinOp(ReductionOpcode, VectorizedTree,
+                                             ReducedSubTree, "bin.rdx");
+        propagateIRFlags(VectorizedTree, ReductionOps);
       } else
         VectorizedTree = ReducedSubTree;
+      i += ReduxWidth;
+      ReduxWidth = PowerOf2Floor(NumReducedVals - i);
     }
 
     if (VectorizedTree) {
       // Finish the reduction.
       for (; i < NumReducedVals; ++i) {
-        Builder.SetCurrentDebugLocation(
-          cast<Instruction>(ReducedVals[i])->getDebugLoc());
-        VectorizedTree = createBinOp(Builder, ReductionOpcode, VectorizedTree,
-                                     ReducedVals[i]);
+        auto *I = cast<Instruction>(ReducedVals[i]);
+        Builder.SetCurrentDebugLocation(I->getDebugLoc());
+        VectorizedTree =
+            Builder.CreateBinOp(ReductionOpcode, VectorizedTree, I);
+        propagateIRFlags(VectorizedTree, ReductionOps);
+      }
+      for (auto &Pair : ExternallyUsedValues) {
+        assert(!Pair.second.empty() &&
+               "At least one DebugLoc must be inserted");
+        // Add each externally used value to the final reduction.
+        for (auto *I : Pair.second) {
+          Builder.SetCurrentDebugLocation(I->getDebugLoc());
+          VectorizedTree = Builder.CreateBinOp(ReductionOpcode, VectorizedTree,
+                                               Pair.first, "bin.extra");
+          propagateIRFlags(VectorizedTree, I);
+        }
       }
       // Update users.
-      if (ReductionPHI && !isa<UndefValue>(ReductionPHI)) {
-        assert(ReductionRoot && "Need a reduction operation");
-        ReductionRoot->setOperand(0, VectorizedTree);
-        ReductionRoot->setOperand(1, ReductionPHI);
-      } else
-        ReductionRoot->replaceAllUsesWith(VectorizedTree);
+      ReductionRoot->replaceAllUsesWith(VectorizedTree);
     }
     return VectorizedTree != nullptr;
   }
@@ -4329,7 +4555,8 @@ public:
 private:
   /// \brief Calculate the cost of a reduction.
-  int getReductionCost(TargetTransformInfo *TTI, Value *FirstReducedVal) {
+  int getReductionCost(TargetTransformInfo *TTI, Value *FirstReducedVal,
+                       unsigned ReduxWidth) {
     Type *ScalarTy = FirstReducedVal->getType();
     Type *VecTy = VectorType::get(ScalarTy, ReduxWidth);
 
@@ -4352,15 +4579,9 @@ private:
     return VecReduxCost - ScalarReduxCost;
   }
 
-  static Value *createBinOp(IRBuilder<> &Builder, unsigned Opcode, Value *L,
-                            Value *R, const Twine &Name = "") {
-    if (Opcode == Instruction::FAdd)
-      return Builder.CreateFAdd(L, R, Name);
-    return Builder.CreateBinOp((Instruction::BinaryOps)Opcode, L, R, Name);
-  }
-
   /// \brief Emit a horizontal reduction of the vectorized value.
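tryToReduce above no longer works with a fixed ReduxWidth derived from the register width; it slices the collected reduction values into power-of-two chunks (each chunk becomes one vectorized reduction tree) and leaves whatever remains, fewer than four values, to the scalar finish loop. A self-contained sketch of just that chunking arithmetic; powerOf2Floor is re-implemented here so the example compiles on its own (the pass itself uses llvm::PowerOf2Floor):

    #include <cstdio>
    #include <vector>

    static unsigned powerOf2Floor(unsigned N) {
      if (N == 0)
        return 0;
      unsigned P = 1;
      while (P * 2 <= N)
        P *= 2;
      return P;
    }

    // Compute the widths of the vector reduction trees emitted for
    // NumReducedVals reduction values, mirroring the loop bounds above.
    static std::vector<unsigned> reductionChunks(unsigned NumReducedVals) {
      std::vector<unsigned> Widths;
      if (NumReducedVals < 4)
        return Widths; // too few values: matches the early return above
      unsigned I = 0;
      unsigned ReduxWidth = powerOf2Floor(NumReducedVals);
      while (I < NumReducedVals - ReduxWidth + 1 && ReduxWidth > 2) {
        Widths.push_back(ReduxWidth);
        I += ReduxWidth;
        ReduxWidth = powerOf2Floor(NumReducedVals - I);
      }
      return Widths; // leftover values are folded scalarly in the finish loop
    }

    int main() {
      for (unsigned W : reductionChunks(11)) // prints "8": 3 values stay scalar
        std::printf("%u ", W);
      return 0;
    }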
-  Value *emitReduction(Value *VectorizedValue, IRBuilder<> &Builder) {
+  Value *emitReduction(Value *VectorizedValue, IRBuilder<> &Builder,
+                       unsigned ReduxWidth, ArrayRef<Value *> RedOps) {
     assert(VectorizedValue && "Need to have a vectorized tree node");
     assert(isPowerOf2_32(ReduxWidth) &&
            "We only handle power-of-two reductions for now");
@@ -4378,15 +4599,16 @@ private:
         Value *RightShuf = Builder.CreateShuffleVector(
           TmpVec, UndefValue::get(TmpVec->getType()), (RightMask),
           "rdx.shuf.r");
-        TmpVec = createBinOp(Builder, ReductionOpcode, LeftShuf, RightShuf,
-                             "bin.rdx");
+        TmpVec = Builder.CreateBinOp(ReductionOpcode, LeftShuf, RightShuf,
+                                     "bin.rdx");
       } else {
         Value *UpperHalf =
           createRdxShuffleMask(ReduxWidth, i, false, false, Builder);
         Value *Shuf = Builder.CreateShuffleVector(
           TmpVec, UndefValue::get(TmpVec->getType()), UpperHalf, "rdx.shuf");
-        TmpVec = createBinOp(Builder, ReductionOpcode, TmpVec, Shuf, "bin.rdx");
+        TmpVec = Builder.CreateBinOp(ReductionOpcode, TmpVec, Shuf, "bin.rdx");
       }
+      propagateIRFlags(TmpVec, RedOps);
     }
 
     // The result is in the first element of the vector.
@@ -4438,16 +4660,19 @@ static bool findBuildVector(InsertElementInst *FirstInsertElem,
 
 static bool findBuildAggregate(InsertValueInst *IV,
                                SmallVectorImpl<Value *> &BuildVector,
                                SmallVectorImpl<Value *> &BuildVectorOpds) {
-  if (!IV->hasOneUse())
-    return false;
-  Value *V = IV->getAggregateOperand();
-  if (!isa<UndefValue>(V)) {
-    InsertValueInst *I = dyn_cast<InsertValueInst>(V);
-    if (!I || !findBuildAggregate(I, BuildVector, BuildVectorOpds))
+  Value *V;
+  do {
+    BuildVector.push_back(IV);
+    BuildVectorOpds.push_back(IV->getInsertedValueOperand());
+    V = IV->getAggregateOperand();
+    if (isa<UndefValue>(V))
+      break;
+    IV = dyn_cast<InsertValueInst>(V);
+    if (!IV || !IV->hasOneUse())
       return false;
-  }
-  BuildVector.push_back(IV);
-  BuildVectorOpds.push_back(IV->getInsertedValueOperand());
+  } while (true);
+  std::reverse(BuildVector.begin(), BuildVector.end());
+  std::reverse(BuildVectorOpds.begin(), BuildVectorOpds.end());
   return true;
 }
 
@@ -4507,29 +4732,137 @@ static Value *getReductionValue(const DominatorTree *DT, PHINode *P,
   return nullptr;
 }
 
+namespace {
+/// Tracks an instruction and its children.
+class WeakVHWithLevel final : public CallbackVH {
+  /// Operand index of the instruction currently being analyzed.
+  unsigned Level = 0;
+  /// Is this the instruction that should be vectorized, or are we now
+  /// processing children (i.e. operands of this instruction) for potential
+  /// vectorization?
+  bool IsInitial = true;
+
+public:
+  explicit WeakVHWithLevel() = default;
+  WeakVHWithLevel(Value *V) : CallbackVH(V){};
+  /// Restart children analysis each time it is replaced by a new instruction.
+  void allUsesReplacedWith(Value *New) override {
+    setValPtr(New);
+    Level = 0;
+    IsInitial = true;
+  }
+  /// Check if the instruction was not deleted during vectorization.
+  bool isValid() const { return !getValPtr(); }
+  /// Must the instruction itself be vectorized?
+  bool isInitial() const { return IsInitial; }
+  /// Try to vectorize children.
+  void clearInitial() { IsInitial = false; }
+  /// Are all children processed already?
+  bool isFinal() const {
+    assert(getValPtr() &&
+           (isa<Instruction>(getValPtr()) &&
+            cast<Instruction>(getValPtr())->getNumOperands() >= Level));
+    return getValPtr() &&
+           cast<Instruction>(getValPtr())->getNumOperands() == Level;
+  }
+  /// Get next child operation.
+  Value *nextOperand() {
+    assert(getValPtr() && isa<Instruction>(getValPtr()) &&
+           cast<Instruction>(getValPtr())->getNumOperands() > Level);
+    return cast<Instruction>(getValPtr())->getOperand(Level++);
+  }
+  virtual ~WeakVHWithLevel() = default;
+};
+} // namespace
+
 /// \brief Attempt to reduce a horizontal reduction.
 /// If it is legal to match a horizontal reduction feeding
-/// the phi node P with reduction operators BI, then check if it
-/// can be done.
+/// the phi node P with reduction operators Root in a basic block BB, then check
+/// if it can be done.
 /// \returns true if a horizontal reduction was matched and reduced.
 /// \returns false if a horizontal reduction was not matched.
-static bool canMatchHorizontalReduction(PHINode *P, BinaryOperator *BI,
-                                        BoUpSLP &R, TargetTransformInfo *TTI,
-                                        unsigned MinRegSize) {
+static bool canBeVectorized(
+    PHINode *P, Instruction *Root, BasicBlock *BB, BoUpSLP &R,
+    TargetTransformInfo *TTI,
+    const function_ref<bool(BinaryOperator *, BoUpSLP &)> Vectorize) {
   if (!ShouldVectorizeHor)
     return false;
 
-  HorizontalReduction HorRdx(MinRegSize);
-  if (!HorRdx.matchAssociativeReduction(P, BI))
+  if (!Root)
     return false;
 
-  // If there is a sufficient number of reduction values, reduce
-  // to a nearby power-of-2. Can safely generate oversized
-  // vectors and rely on the backend to split them to legal sizes.
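canBeVectorized above replaces the single-shot reduction matcher with an explicit work-list: each stack entry is a value plus the index of the next operand to try, a node is offered for vectorization on its first visit, and descent is bounded by a visited set and a small fixed cap (RecursionMaxDepth in the pass). A stand-alone sketch of that traversal shape, with Node and tryVectorize as hypothetical stand-ins for the LLVM instruction type and the reduction/vectorization hooks:

    #include <set>
    #include <utility>
    #include <vector>

    struct Node {
      std::vector<Node *> Ops;
      int Block = 0; // stand-in for the parent basic block
    };

    static bool tryVectorize(Node *) { return false; } // placeholder hook

    static bool visitRoot(Node *Root, int BB, unsigned MaxDepth = 12) {
      if (!Root || Root->Block != BB)
        return false;
      bool Res = false;
      std::vector<std::pair<Node *, unsigned>> Stack{{Root, 0}};
      std::set<Node *> Visited{Root};
      while (!Stack.empty()) {
        Node *N = Stack.back().first;
        unsigned &NextOp = Stack.back().second;
        if (NextOp == 0)
          Res |= tryVectorize(N);       // first visit: try the node itself
        if (NextOp == N->Ops.size()) {
          Stack.pop_back();             // all operands processed
          continue;
        }
        Node *Child = N->Ops[NextOp++]; // advance to the next operand
        if (Child && Child->Block == BB && Visited.insert(Child).second &&
            Stack.size() < MaxDepth)
          Stack.push_back({Child, 0});  // descend, bounded by MaxDepth
      }
      return Res;
    }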
-  HorRdx.ReduxWidth =
-    std::max((uint64_t)4, PowerOf2Floor(HorRdx.numReductionValues()));
+  if (Root->getParent() != BB)
+    return false;
+  SmallVector<WeakVHWithLevel, 8> Stack(1, Root);
+  SmallSet<Value *, 8> VisitedInstrs;
+  bool Res = false;
+  while (!Stack.empty()) {
+    Value *V = Stack.back();
+    if (!V) {
+      Stack.pop_back();
+      continue;
+    }
+    auto *Inst = dyn_cast<Instruction>(V);
+    if (!Inst || isa<PHINode>(Inst)) {
+      Stack.pop_back();
+      continue;
+    }
+    if (Stack.back().isInitial()) {
+      Stack.back().clearInitial();
+      if (auto *BI = dyn_cast<BinaryOperator>(Inst)) {
+        HorizontalReduction HorRdx;
+        if (HorRdx.matchAssociativeReduction(P, BI)) {
+          if (HorRdx.tryToReduce(R, TTI)) {
+            Res = true;
+            P = nullptr;
+            continue;
+          }
+        }
+        if (P) {
+          Inst = dyn_cast<Instruction>(BI->getOperand(0));
+          if (Inst == P)
+            Inst = dyn_cast<Instruction>(BI->getOperand(1));
+          if (!Inst) {
+            P = nullptr;
+            continue;
+          }
+        }
+      }
+      P = nullptr;
+      if (Vectorize(dyn_cast<BinaryOperator>(Inst), R)) {
+        Res = true;
+        continue;
+      }
+    }
+    if (Stack.back().isFinal()) {
+      Stack.pop_back();
+      continue;
+    }
 
-  return HorRdx.tryToReduce(R, TTI);
+    if (auto *NextV = dyn_cast<Instruction>(Stack.back().nextOperand()))
+      if (NextV->getParent() == BB && VisitedInstrs.insert(NextV).second &&
+          Stack.size() < RecursionMaxDepth)
+        Stack.push_back(NextV);
+  }
+  return Res;
+}
+
+bool SLPVectorizerPass::vectorizeRootInstruction(PHINode *P, Value *V,
+                                                 BasicBlock *BB, BoUpSLP &R,
+                                                 TargetTransformInfo *TTI) {
+  if (!V)
+    return false;
+  auto *I = dyn_cast<Instruction>(V);
+  if (!I)
+    return false;
+
+  if (!isa<BinaryOperator>(I))
+    P = nullptr;
+  // Try to match and vectorize a horizontal reduction.
+  return canBeVectorized(P, I, BB, R, TTI,
+                         [this](BinaryOperator *BI, BoUpSLP &R) -> bool {
+                           return tryToVectorize(BI, R);
+                         });
 }
 
 bool SLPVectorizerPass::vectorizeChainsInBlock(BasicBlock *BB, BoUpSLP &R) {
@@ -4599,67 +4932,42 @@ bool SLPVectorizerPass::vectorizeChainsInBlock(BasicBlock *BB, BoUpSLP &R) {
       if (P->getNumIncomingValues() != 2)
         return Changed;
 
-      Value *Rdx = getReductionValue(DT, P, BB, LI);
-
-      // Check if this is a Binary Operator.
-      BinaryOperator *BI = dyn_cast_or_null<BinaryOperator>(Rdx);
-      if (!BI)
-        continue;
-
       // Try to match and vectorize a horizontal reduction.
-      if (canMatchHorizontalReduction(P, BI, R, TTI, R.getMinVecRegSize())) {
+      if (vectorizeRootInstruction(P, getReductionValue(DT, P, BB, LI), BB, R,
+                                   TTI)) {
         Changed = true;
         it = BB->begin();
         e = BB->end();
         continue;
       }
-
-     Value *Inst = BI->getOperand(0);
-      if (Inst == P)
-        Inst = BI->getOperand(1);
-
-      if (tryToVectorize(dyn_cast<BinaryOperator>(Inst), R)) {
-        // We would like to start over since some instructions are deleted
-        // and the iterator may become invalid value.
-        Changed = true;
-        it = BB->begin();
-        e = BB->end();
-        continue;
-      }
-
       continue;
     }
 
-    if (ShouldStartVectorizeHorAtStore)
-      if (StoreInst *SI = dyn_cast<StoreInst>(it))
-        if (BinaryOperator *BinOp =
-                dyn_cast<BinaryOperator>(SI->getValueOperand())) {
-          if (canMatchHorizontalReduction(nullptr, BinOp, R, TTI,
-                                          R.getMinVecRegSize()) ||
-              tryToVectorize(BinOp, R)) {
-            Changed = true;
-            it = BB->begin();
-            e = BB->end();
-            continue;
-          }
+    if (ShouldStartVectorizeHorAtStore) {
+      if (StoreInst *SI = dyn_cast<StoreInst>(it)) {
+        // Try to match and vectorize a horizontal reduction.
+        if (vectorizeRootInstruction(nullptr, SI->getValueOperand(), BB, R,
+                                     TTI)) {
+          Changed = true;
+          it = BB->begin();
+          e = BB->end();
+          continue;
         }
+      }
+    }
 
     // Try to vectorize horizontal reductions feeding into a return.
-    if (ReturnInst *RI = dyn_cast<ReturnInst>(it))
-      if (RI->getNumOperands() != 0)
-        if (BinaryOperator *BinOp =
-                dyn_cast<BinaryOperator>(RI->getOperand(0))) {
-          DEBUG(dbgs() << "SLP: Found a return to vectorize.\n");
-          if (canMatchHorizontalReduction(nullptr, BinOp, R, TTI,
-                                          R.getMinVecRegSize()) ||
-              tryToVectorizePair(BinOp->getOperand(0), BinOp->getOperand(1),
-                                 R)) {
-            Changed = true;
-            it = BB->begin();
-            e = BB->end();
-            continue;
-          }
+    if (ReturnInst *RI = dyn_cast<ReturnInst>(it)) {
+      if (RI->getNumOperands() != 0) {
+        // Try to match and vectorize a horizontal reduction.
+        if (vectorizeRootInstruction(nullptr, RI->getOperand(0), BB, R, TTI)) {
+          Changed = true;
          it = BB->begin();
+          e = BB->end();
+          continue;
        }
+      }
+    }
 
     // Try to vectorize trees that start at compare instructions.
     if (CmpInst *CI = dyn_cast<CmpInst>(it)) {
@@ -4672,16 +4980,14 @@ bool SLPVectorizerPass::vectorizeChainsInBlock(BasicBlock *BB, BoUpSLP &R) {
        continue;
       }
 
-      for (int i = 0; i < 2; ++i) {
-        if (BinaryOperator *BI = dyn_cast<BinaryOperator>(CI->getOperand(i))) {
-          if (tryToVectorizePair(BI->getOperand(0), BI->getOperand(1), R)) {
-            Changed = true;
-            // We would like to start over since some instructions are deleted
-            // and the iterator may become invalid value.
-            it = BB->begin();
-            e = BB->end();
-            break;
-          }
+      for (int I = 0; I < 2; ++I) {
+        if (vectorizeRootInstruction(nullptr, CI->getOperand(I), BB, R, TTI)) {
+          Changed = true;
+          // We would like to start over since some instructions are deleted
+          // and the iterator may become invalid value.
+          it = BB->begin();
+          e = BB->end();
+          break;
        }
       }
       continue;
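A pattern shared by all of the rewritten scans above: whenever a root is successfully vectorized, the loop over the basic block restarts from the beginning, because the transformation may have erased the very instruction the iterator refers to. A minimal generic sketch of that restart discipline, with std::list<int> standing in for the instruction list and tryRewrite for vectorizeRootInstruction:

    #include <cstdio>
    #include <list>

    static bool tryRewrite(std::list<int> &BB, std::list<int>::iterator It) {
      if (*It % 2 == 0) { // pretend even values are vectorizable roots
        BB.erase(It);     // the rewrite deletes instructions...
        return true;      // ...so previously saved iterators may be invalid
      }
      return false;
    }

    static bool scanBlock(std::list<int> &BB) {
      bool Changed = false;
      auto It = BB.begin(), E = BB.end();
      while (It != E) {
        if (tryRewrite(BB, It)) {
          Changed = true;
          It = BB.begin(); // start over: the rewrite may have invalidated
          E = BB.end();    // the iterators, including It itself
          continue;
        }
        ++It;
      }
      return Changed;
    }

    int main() {
      std::list<int> BB{1, 2, 3, 4, 5};
      scanBlock(BB); // erases the even "roots"
      for (int V : BB)
        std::printf("%d ", V); // prints 1 3 5
      return 0;
    }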
